当前位置: 首页>>代码示例>>Python>>正文


Python utils.load_align_dict方法代码示例

本文整理汇总了Python中fairseq.utils.load_align_dict方法的典型用法代码示例。如果您正苦于以下问题:Python utils.load_align_dict方法的具体用法?Python utils.load_align_dict怎么用?Python utils.load_align_dict使用的例子?那么恭喜您, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在fairseq.utils的用法示例。


在下文中一共展示了utils.load_align_dict方法的2个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: __init__

# 需要导入模块: from fairseq import utils [as 别名]
# 或者: from fairseq.utils import load_align_dict [as 别名]
def __init__(self, args, task, models):
        super().__init__()
        self.args = args
        self.task = task
        self.models = nn.ModuleList(models)
        self.src_dict = task.source_dictionary
        self.tgt_dict = task.target_dictionary

        # optimize model for generation
        for model in self.models:
            model.make_generation_fast_(
                beamable_mm_beam_size=(
                    None if getattr(args, 'no_beamable_mm', False)
                    else getattr(args, 'beam', 5)
                ),
                need_attn=getattr(args, 'print_alignment', False),
            )

        # Load alignment dictionary for unknown word replacement
        # (None if no unknown word replacement, empty if no path to align dictionary)
        self.align_dict = utils.load_align_dict(getattr(args, 'replace_unk', None))

        self.tokenizer = encoders.build_tokenizer(args)
        self.bpe = encoders.build_bpe(args)

        self.max_positions = utils.resolve_max_positions(
            self.task.max_positions(), *[model.max_positions() for model in models]
        )

        # this is useful for determining the device
        self.register_buffer('_float_tensor', torch.tensor([0], dtype=torch.float)) 
开发者ID:pytorch,项目名称:fairseq,代码行数:33,代码来源:hub_utils.py

示例2: main

# 需要导入模块: from fairseq import utils [as 别名]
# 或者: from fairseq.utils import load_align_dict [as 别名]
def main(args):
    print(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    # Load ensemble
    print('| loading model(s) from {}'.format(', '.join(args.path)))
    models, model_args = utils.load_ensemble_for_inference(args.path, data_dir=args.data)
    src_dict, dst_dict = models[0].src_dict, models[0].dst_dict

    print('| [{}] dictionary: {} types'.format(model_args.source_lang, len(src_dict)))
    print('| [{}] dictionary: {} types'.format(model_args.target_lang, len(dst_dict)))

    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
        )

    # Initialize generator
    translator = SequenceGenerator(
        models, beam_size=args.beam, stop_early=(not args.no_early_stop),
        normalize_scores=(not args.unnormalized), len_penalty=args.lenpen,
        unk_penalty=args.unkpen)
    if use_cuda:
        translator.cuda()

    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(args.replace_unk)

    print('| Type the input sentence and press return:')
    for src_str in sys.stdin:
        src_str = src_str.strip()
        src_tokens = tokenizer.Tokenizer.tokenize(src_str, src_dict, add_if_not_exist=False).long()
        if use_cuda:
            src_tokens = src_tokens.cuda()
        src_lengths = src_tokens.new([src_tokens.numel()])
        translations = translator.generate(
            Variable(src_tokens.view(1, -1)),
            Variable(src_lengths.view(-1)),
        )
        hypos = translations[0]
        print('O\t{}'.format(src_str))

        # Process top predictions
        for hypo in hypos[:min(len(hypos), args.nbest)]:
            hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                hypo_tokens=hypo['tokens'].int().cpu(),
                src_str=src_str,
                alignment=hypo['alignment'].int().cpu(),
                align_dict=align_dict,
                dst_dict=dst_dict,
                remove_bpe=args.remove_bpe,
            )
            print('H\t{}\t{}'.format(hypo['score'], hypo_str))
            print('A\t{}'.format(' '.join(map(str, alignment)))) 
开发者ID:EdinburghNLP,项目名称:XSum,代码行数:59,代码来源:interactive.py


注:本文中的fairseq.utils.load_align_dict方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。