

Python data.EpochBatchIterator Code Examples

This article collects typical usage examples of fairseq.data.EpochBatchIterator in Python. If you are unsure what data.EpochBatchIterator does or how to use it, the curated examples below may help. You can also explore other usage examples from the fairseq.data module.


Six code examples of data.EpochBatchIterator are shown below, sorted by popularity by default.
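All six examples share the same basic pattern: build a dataset, wrap it in an EpochBatchIterator, call next_epoch_itr() to obtain a one-pass iterator over the current epoch, and loop over the batches. The sketch below distills that pattern; it is a minimal sketch assuming the older fairseq constructor used in most of the examples (the arguments changed in later releases), and `dataset` stands for any fairseq dataset prepared elsewhere.

# Minimal usage sketch (assumes the older fairseq API used in most of
# the examples below; constructor arguments differ in newer releases).
from fairseq import data

def iterate_one_epoch(dataset, max_tokens=None, max_sentences=None):
    epoch_itr = data.EpochBatchIterator(
        dataset=dataset,              # any fairseq dataset (assumed given)
        max_tokens=max_tokens,        # cap on tokens per batch
        max_sentences=max_sentences,  # cap on sentences per batch
    )
    # next_epoch_itr() returns an iterator over this epoch's batches;
    # shuffle=False keeps the original example order.
    for batch in epoch_itr.next_epoch_itr(shuffle=False):
        yield batch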

Example 1: make_batches

# Required import: from fairseq import data [as alias]
# Or: from fairseq.data import EpochBatchIterator [as alias]
def make_batches(lines, args, src_dict, max_positions):
    # Tokenize each raw input line with the source dictionary.
    tokens = [
        tokenizer.Tokenizer.tokenize(src_str, src_dict, add_if_not_exist=False).long()
        for src_str in lines
    ]
    lengths = np.array([t.numel() for t in tokens])
    # Wrap the sentences in a LanguagePairDataset and batch them by
    # token/sentence budget; shuffle=False preserves the input order.
    itr = data.EpochBatchIterator(
        dataset=data.LanguagePairDataset(tokens, lengths, src_dict),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=max_positions,
    ).next_epoch_itr(shuffle=False)
    for batch in itr:
        # Batch is a namedtuple defined elsewhere in interactive.py.
        yield Batch(
            srcs=[lines[i] for i in batch['id']],
            tokens=batch['net_input']['src_tokens'],
            lengths=batch['net_input']['src_lengths'],
        ), batch['id']
Developer: nusnlp, Project: crosentgec, Lines: 20, Source file: interactive.py
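For context, the generator above might be driven like this. The call site is illustrative (hypothetical, not taken from the project); `args`, `src_dict`, and `max_positions` are assumed to be set up earlier in interactive.py:

# Hypothetical call site for the helper above; `args`, `src_dict`,
# and `max_positions` are assumed to come from the surrounding setup.
lines = ['this is a test .', 'another input sentence .']
for batch, batch_ids in make_batches(lines, args, src_dict, max_positions):
    print(batch.tokens.size(), batch_ids)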

Example 2: make_batches

# Required import: from fairseq import data [as alias]
# Or: from fairseq.data import EpochBatchIterator [as alias]
def make_batches(inputs_buffer, args, src_dict, ctx_dict, max_positions):
    # Each buffered input is a (source, context) pair; tokenize both
    # streams with their respective dictionaries.
    ctx_tokens = [
        tokenizer.Tokenizer.tokenize(inputs[1], ctx_dict, add_if_not_exist=False).long()
        for inputs in inputs_buffer
    ]
    tokens = [
        tokenizer.Tokenizer.tokenize(inputs[0], src_dict, add_if_not_exist=False).long()
        for inputs in inputs_buffer
    ]

    src_sizes = np.array([t.numel() for t in tokens])
    ctx_sizes = np.array([t.numel() for t in ctx_tokens])
    # Workaround: pad max_positions to three entries so the context
    # stream gets its own length limit (reusing the source limit).
    if len(max_positions) < 3:
        max_positions += (max_positions[0],)
    itr = data.EpochBatchIterator(
        dataset=data.LanguageTripleDataset(
            src=tokens, src_sizes=src_sizes, src_dict=src_dict,
            ctx=ctx_tokens, ctx_sizes=ctx_sizes, ctx_dict=ctx_dict,
        ),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=max_positions,
    ).next_epoch_itr(shuffle=False)

    for batch in itr:
        yield Batch(
            srcs=[inputs_buffer[i][0] for i in batch['id']],
            tokens=batch['net_input']['src_tokens'],
            lengths=batch['net_input']['src_lengths'],
            ctxs=[inputs_buffer[i][1] for i in batch['id']],
            ctx_tokens=batch['net_input']['ctx_tokens'],
            ctx_lengths=batch['net_input']['ctx_lengths'],
        ), batch['id']
Developer: nusnlp, Project: crosentgec, Lines: 37, Source file: interactive_multi.py
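Compared with Example 1, this variant batches (source, context) sentence pairs through a LanguageTripleDataset; this is how the crosentgec project feeds cross-sentence context to the model, and the yielded Batch accordingly carries a second token/length stream for the context.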

Example 3: get_trainer_and_epoch_itr

# Required import: from fairseq import data [as alias]
# Or: from fairseq.data import EpochBatchIterator [as alias]
def get_trainer_and_epoch_itr(epoch, epoch_size, num_updates, iterations_in_epoch):
    # Build a tiny synthetic dataset of `epoch_size` tokens for testing.
    tokens = torch.LongTensor(list(range(epoch_size)))
    tokens_ds = data.TokenBlockDataset(tokens, [len(tokens)], 1, include_targets=False)
    # mock_trainer and mock_dict are test helpers defined in test_train.py.
    trainer = mock_trainer(epoch, num_updates, iterations_in_epoch)
    epoch_itr = data.EpochBatchIterator(
        dataset=data.LanguagePairDataset(tokens_ds, tokens_ds.sizes, mock_dict(), shuffle=False),
        max_tokens=1,
    )
    return trainer, epoch_itr
Developer: nusnlp, Project: crosentgec, Lines: 11, Source file: test_train.py

Example 4: get_trainer_and_epoch_itr

# Required import: from fairseq import data [as alias]
# Or: from fairseq.data import EpochBatchIterator [as alias]
def get_trainer_and_epoch_itr(epoch, epoch_size, num_updates, iterations_in_epoch):
    tokens = torch.LongTensor(list(range(epoch_size))).view(1, -1)
    tokens_ds = data.TokenBlockDataset(
        tokens, sizes=[tokens.size(-1)], block_size=1, pad=0, eos=1, include_targets=False,
    )
    trainer = mock_trainer(epoch, num_updates, iterations_in_epoch)
    dataset = data.LanguagePairDataset(tokens_ds, tokens_ds.sizes, mock_dict(), shuffle=False)
    # Newer fairseq API: batching is precomputed as a batch_sampler
    # (here, one example per batch) and the dataset supplies collate_fn.
    epoch_itr = data.EpochBatchIterator(
        dataset=dataset,
        collate_fn=dataset.collater,
        batch_sampler=[[i] for i in range(epoch_size)],
    )
    return trainer, epoch_itr
Developer: pytorch, Project: fairseq, Lines: 15, Source file: test_train.py
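Examples 3 and 4 are two versions of the same test helper and together illustrate an API change: the older constructor (Example 3) batches internally from max_tokens, while the newer one (Example 4) expects an explicit collate_fn plus a precomputed batch_sampler, and TokenBlockDataset now takes pad and eos symbols explicitly.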

Example 5: main

# Required import: from fairseq import data [as alias]
# Or: from fairseq.data import EpochBatchIterator [as alias]
def main(args):
    assert args.path is not None, '--path required for evaluation!'

    if args.tokens_per_sample is None:
        args.tokens_per_sample = 1024
    print(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    # Load dataset splits
    task = tasks.setup_task(args)
    task.load_dataset(args.gen_subset)
    print('| {} {} {} examples'.format(args.data, args.gen_subset, len(task.dataset(args.gen_subset))))

    # Load ensemble
    print('| loading model(s) from {}'.format(args.path))
    models, _ = utils.load_ensemble_for_inference(args.path.split(':'), task)

    # Optimize ensemble for generation and set the source and dest dicts on the model (required by scorer)
    for model in models:
        model.make_generation_fast_()

    # Note: `model` here is the last entry of `models`, leaked from the
    # loop above; its position limits are applied to the whole ensemble.
    itr = data.EpochBatchIterator(
        dataset=task.dataset(args.gen_subset),
        max_sentences=args.max_sentences or 4,
        max_positions=model.max_positions(),
        num_shards=args.num_shards,
        shard_id=args.shard_id,
    ).next_epoch_itr(shuffle=False)

    gen_timer = StopwatchMeter()
    scorer = SequenceScorer(models, task.target_dictionary)
    if use_cuda:
        scorer.cuda()

    score_sum = 0.
    count = 0
    with progress_bar.build_progress_bar(args, itr) as t:
        results = scorer.score_batched_itr(t, cuda=use_cuda, timer=gen_timer)
        wps_meter = TimeMeter()
        for _, src_tokens, __, hypos in results:
            for hypo in hypos:
                pos_scores = hypo['positional_scores']
                inf_scores = pos_scores.eq(float('inf')) | pos_scores.eq(float('-inf'))
                if inf_scores.any():
                    print('| Skipping tokens with inf scores:',
                          task.target_dictionary.string(hypo['tokens'][inf_scores.nonzero()]))
                    pos_scores = pos_scores[(~inf_scores).nonzero()]
                score_sum += pos_scores.sum()
                count += pos_scores.numel()
            wps_meter.update(src_tokens.size(0))
            t.log({'wps': round(wps_meter.avg)})

    avg_nll_loss = -score_sum / count
    print('| Evaluated {} tokens in {:.1f}s ({:.2f} tokens/s)'.format(gen_timer.n, gen_timer.sum, 1. / gen_timer.avg))
    print('| Loss: {:.4f}, Perplexity: {:.2f}'.format(avg_nll_loss, np.exp(avg_nll_loss))) 
Developer: nusnlp, Project: crosentgec, Lines: 58, Source file: eval_lm.py
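Since positional_scores are per-token log-probabilities (natural log), avg_nll_loss is the average negative log-likelihood per token and np.exp(avg_nll_loss) is the corresponding token-level perplexity; tokens scored ±inf are filtered out first so they do not poison the average.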

Example 6: validate

# Required import: from fairseq import data [as alias]
# Or: from fairseq.data import EpochBatchIterator [as alias]
def validate(args, trainer, task, epoch_itr, subsets):
    """Evaluate the model on the validation set(s) and return the losses."""
    valid_losses = []
    for subset in subsets:
        # Initialize data iterator
        itr = data.EpochBatchIterator(
            dataset=task.dataset(subset),
            max_tokens=args.max_tokens,
            max_sentences=args.max_sentences_valid,
            max_positions=trainer.get_model().max_positions(),
            ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
            required_batch_size_multiple=8,
            seed=args.seed,
            num_shards=args.distributed_world_size,
            shard_id=args.distributed_rank,
        ).next_epoch_itr(shuffle=False)
        progress = progress_bar.build_progress_bar(
            args, itr, epoch_itr.epoch,
            prefix='valid on \'{}\' subset'.format(subset),
            no_progress_bar='simple'
        )

        # reset validation loss meters
        for k in ['valid_loss', 'valid_nll_loss']:
            meter = trainer.get_meter(k)
            if meter is not None:
                meter.reset()
        extra_meters = collections.defaultdict(lambda: AverageMeter())

        for sample in progress:
            log_output = trainer.valid_step(sample)

            for k, v in log_output.items():
                if k in ['loss', 'nll_loss', 'sample_size']:
                    continue
                extra_meters[k].update(v)

        # log validation stats
        stats = get_valid_stats(trainer)
        for k, meter in extra_meters.items():
            stats[k] = meter.avg
        progress.print(stats)

        valid_losses.append(stats['valid_loss'])
    return valid_losses 
Developer: nusnlp, Project: crosentgec, Lines: 47, Source file: train.py
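Note the num_shards/shard_id arguments: under distributed training each worker validates a disjoint shard of the validation set (sharded by args.distributed_world_size and args.distributed_rank), so the meters on each worker reflect only its own slice of the data.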


Note: the fairseq.data.EpochBatchIterator examples in this article were compiled by 纯净天空 from open-source code and documentation platforms such as GitHub and MSDocs. The code snippets are drawn from community open-source projects and remain the property of their original authors; consult each project's license before redistributing or using them. Do not reproduce this article without permission.