当前位置: 首页>>代码示例>>Python>>正文


Python utils.post_process_prediction方法代码示例

本文整理汇总了Python中fairseq.utils.post_process_prediction方法的典型用法代码示例。如果您正苦于以下问题:Python utils.post_process_prediction方法的具体用法?Python utils.post_process_prediction怎么用?Python utils.post_process_prediction使用的例子?那么恭喜您, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在fairseq.utils的用法示例。


在下文中一共展示了utils.post_process_prediction方法的2个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: translate

# 需要导入模块: from fairseq import utils [as 别名]
# 或者: from fairseq.utils import post_process_prediction [as 别名]
def translate(self, sent, verbose=False):
    """Translate one sentence and return the top-ranked hypothesis string.

    When the source language ``self.args.s`` is ``'zh'``, runs of spaces are
    removed and the sentence is re-segmented with ``jieba.tokenize`` (the
    first element of each returned item is used as the token) before
    batching.

    Args:
        sent: raw source sentence.
        verbose: when True, print fairseq-style ``S-`` (source, only if a
            source dictionary exists), ``H-`` (hypothesis), ``P-``
            (positional scores) lines, plus ``A-`` lines when
            ``self.args.print_alignment`` is set and an alignment exists.

    Returns:
        The post-processed hypothesis (passed through ``self.decoder`` when
        one is configured, then through ``self.corrector_module``) for the
        first entry of the n-best list, or ``''`` if the generator produced
        no hypotheses.
    """
    start_id = 0
    if self.args.s == 'zh':
        sent = re.sub(r' +', '', sent)
        sent = " ".join(piece[0] for piece in jieba.tokenize(sent))
    inputs = [sent]
    results = []
    for batch in make_batches(inputs, self.args, self.task, self.max_positions, self.encode_fn):
        src_tokens = batch.src_tokens
        src_lengths = batch.src_lengths
        if self.use_cuda:
            src_tokens = src_tokens.cuda()
            src_lengths = src_lengths.cuda()

        sample = {
            'net_input': {
                'src_tokens': src_tokens,
                'src_lengths': src_lengths,
            },
        }
        translations = self.task.inference_step(self.generator, self.models, sample)
        for i, (sample_id, hypos) in enumerate(zip(batch.ids.tolist(), translations)):
            src_tokens_i = utils.strip_pad(src_tokens[i], self.tgt_dict.pad())
            results.append((start_id + sample_id, src_tokens_i, hypos))

    # Only one sentence is submitted (inputs == [sent]), so this ends up
    # holding the best hypothesis for that sentence.
    best_str = ''

    # Sort output to match input order.
    for sample_id, src_tokens_i, hypos in sorted(results, key=lambda r: r[0]):
        # Default keeps src_str bound when the task has no source dictionary;
        # previously that path raised NameError at post_process_prediction.
        src_str = ''
        if self.src_dict is not None:
            src_str = self.src_dict.string(src_tokens_i, self.args.remove_bpe)
            if verbose:
                print('S-{}\t{}'.format(sample_id, src_str))

        # Process top predictions. The slice treats `hypos` as ranked
        # best-first, so rank 0 is the hypothesis we return (the old code
        # returned whichever came last, i.e. the worst of the n-best).
        for rank, hypo in enumerate(hypos[:self.args.nbest]):
            hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                hypo_tokens=hypo['tokens'].int().cpu(),
                src_str=src_str,
                alignment=hypo['alignment'].int().cpu() if hypo['alignment'] is not None else None,
                align_dict=self.align_dict,
                tgt_dict=self.tgt_dict,
                remove_bpe=self.args.remove_bpe,
            )
            if self.decoder is not None:
                hypo_str = self.decoder.decode(map(int, hypo_str.strip().split()))
            hypo_str = self.corrector_module(hypo_str)
            if rank == 0:
                best_str = hypo_str
            if verbose:
                print('H-{}\t{}\t{}'.format(sample_id, hypo['score'], hypo_str))
                print('P-{}\t{}'.format(
                    sample_id,
                    ' '.join(map(lambda x: '{:.4f}'.format(x), hypo['positional_scores'].tolist()))
                ))
                # alignment may be None (see the conditional above); only
                # print it when there is something to print.
                if self.args.print_alignment and alignment is not None:
                    print('A-{}\t{}'.format(
                        sample_id,
                        ' '.join(map(lambda x: str(utils.item(x)), alignment))
                    ))
    return best_str
开发者ID:plkmo,项目名称:NLP_Toolkit,代码行数:60,代码来源:interactive.py

示例2: main

# 需要导入模块: from fairseq import utils [as 别名]
# 或者: from fairseq.utils import post_process_prediction [as 别名]
def main(args):
    """Interactively translate sentences read line-by-line from stdin.

    Loads the model ensemble listed in ``args.path`` and builds a
    ``SequenceGenerator`` over it. For every stdin line it prints the
    stripped source (``O``). Then, for each of the first ``args.nbest``
    hypotheses, it prints the score and string (``H``) and, when available,
    the alignment (``A``).

    Args:
        args: parsed command-line namespace. Fields read here: ``cpu``,
            ``path``, ``data``, ``no_beamable_mm``, ``beam``,
            ``no_early_stop``, ``unnormalized``, ``lenpen``, ``unkpen``,
            ``replace_unk``, ``nbest`` and ``remove_bpe``.
    """
    print(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    # Load ensemble
    print('| loading model(s) from {}'.format(', '.join(args.path)))
    models, model_args = utils.load_ensemble_for_inference(args.path, data_dir=args.data)
    src_dict, dst_dict = models[0].src_dict, models[0].dst_dict

    print('| [{}] dictionary: {} types'.format(model_args.source_lang, len(src_dict)))
    print('| [{}] dictionary: {} types'.format(model_args.target_lang, len(dst_dict)))

    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
        )

    # Initialize generator
    translator = SequenceGenerator(
        models, beam_size=args.beam, stop_early=(not args.no_early_stop),
        normalize_scores=(not args.unnormalized), len_penalty=args.lenpen,
        unk_penalty=args.unkpen)
    if use_cuda:
        translator.cuda()

    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(args.replace_unk)

    print('| Type the input sentence and press return:')
    for src_str in sys.stdin:
        src_str = src_str.strip()
        src_tokens = tokenizer.Tokenizer.tokenize(src_str, src_dict, add_if_not_exist=False).long()
        if use_cuda:
            src_tokens = src_tokens.cuda()
        src_lengths = src_tokens.new([src_tokens.numel()])
        # The generator is fed a batch of one sentence: (1, len) tokens and
        # a length vector of shape (1,).
        translations = translator.generate(
            Variable(src_tokens.view(1, -1)),
            Variable(src_lengths.view(-1)),
        )
        hypos = translations[0]
        print('O\t{}'.format(src_str))

        # Process top predictions
        for hypo in hypos[:args.nbest]:
            raw_alignment = hypo['alignment']
            hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                hypo_tokens=hypo['tokens'].int().cpu(),
                src_str=src_str,
                # Guard mirrors the newer fairseq API, where a hypothesis may
                # carry no alignment.
                alignment=raw_alignment.int().cpu() if raw_alignment is not None else None,
                align_dict=align_dict,
                dst_dict=dst_dict,
                remove_bpe=args.remove_bpe,
            )
            print('H\t{}\t{}'.format(hypo['score'], hypo_str))
            if alignment is not None:
                # tolist() yields plain Python ints on every torch version;
                # iterating the tensor directly yields 0-d tensors on
                # torch>=0.4, whose str() is "tensor(...)" rather than a number.
                print('A\t{}'.format(' '.join(map(str, alignment.tolist()))))
开发者ID:EdinburghNLP,项目名称:XSum,代码行数:59,代码来源:interactive.py


注:本文中的fairseq.utils.post_process_prediction方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。