当前位置: 首页>>代码示例>>Python>>正文


Python utils.resolve_max_positions方法代码示例

本文整理汇总了 Python 中 fairseq.utils.resolve_max_positions 方法的典型用法代码示例。如果您想了解 utils.resolve_max_positions 方法的具体用法、调用方式或实际使用示例，这里精选的代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块 fairseq.utils 的其他用法示例。


在下文中一共展示了utils.resolve_max_positions方法的12个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: get_valid_iterator

# 需要导入模块: from fairseq import utils [as 别名]
# 或者: from fairseq.utils import resolve_max_positions [as 别名]
def get_valid_iterator(
    self,
    subset,
):
    """Build an EpochBatchIterator over the named validation ``subset``.

    The position limit passed to the task is the combination (via
    ``utils.resolve_max_positions``) of the task's and the model's
    ``max_positions()``. Batch sizing uses the ``*_valid`` arguments, and
    sharding follows this trainer's data-parallel world size and rank.
    """
    task = self.task
    args = self.args
    position_limit = utils.resolve_max_positions(
        task.max_positions(),
        self.model.max_positions(),
    )
    return task.get_batch_iterator(
        dataset=task.dataset(subset),
        max_tokens=args.max_tokens_valid,
        max_sentences=args.max_sentences_valid,
        max_positions=position_limit,
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=args.required_batch_size_multiple,
        seed=args.seed,
        num_shards=self.data_parallel_world_size,
        shard_id=self.data_parallel_rank,
        num_workers=args.num_workers,
    )
开发者ID:pytorch,项目名称:fairseq,代码行数:22,代码来源:trainer.py

示例2: make_batches

# 需要导入模块: from fairseq import utils [as 别名]
# 或者: from fairseq.utils import resolve_max_positions [as 别名]
def make_batches(self, lines):
    """Encode raw text ``lines`` and return an unshuffled batch iterator.

    Each line is encoded with the task's source dictionary (unknown words are
    not added to it). The token tensors are wrapped in an EOS-broken
    ``TokenBlockDataset``, then in a ``MonolingualDataset``, and batched with
    the combined position limit of every model in ``self.models``.
    ``max_tokens`` falls back to 3000 when ``args.max_tokens`` is falsy.
    """
    src_dict = self.task.source_dictionary
    lm_dict = self.task.dictionary

    token_lst = []
    for line in lines:
        token_lst.append(src_dict.encode_line(line, add_if_not_exist=False).long())
    length_lst = torch.LongTensor([toks.numel() for toks in token_lst])

    ds = data.TokenBlockDataset(
        token_lst,
        length_lst,
        self.args.tokens_per_sample,
        pad=lm_dict.pad(),
        eos=lm_dict.eos(),
        break_mode='eos',
        include_targets=True,
    )
    # Extra EOS handling is only wanted when a real break mode is configured.
    add_eos_for_other_targets = self.args.sample_break_mode not in (None, 'none')
    mono_ds = data.MonolingualDataset(
        ds,
        ds.sizes,
        lm_dict,
        self.task.target_dictionary,
        add_eos_for_other_targets,
        shuffle=False,
        targets=self.task.targets,
    )
    model_limits = [m.max_positions() for m in self.models]
    batch_itr = self.task.get_batch_iterator(
        dataset=mono_ds,
        max_tokens=self.args.max_tokens or 3000,
        max_sentences=self.args.max_sentences,
        max_positions=utils.resolve_max_positions(*model_limits),
        num_shards=self.args.num_shards,
        shard_id=self.args.shard_id,
        ignore_invalid_inputs=True,
        num_workers=self.args.num_workers,
    )
    return batch_itr.next_epoch_itr(shuffle=False)
开发者ID:kakaobrain,项目名称:helo_word,代码行数:26,代码来源:lm_scorer.py

示例3: get_dummy_batch

# 需要导入模块: from fairseq import utils [as 别名]
# 或者: from fairseq.utils import resolve_max_positions [as 别名]
def get_dummy_batch(self, num_tokens, max_positions, src_len=128, tgt_len=128):
    """Collate a synthetic batch holding roughly ``num_tokens`` tokens.

    ``src_len``/``tgt_len`` are first combined with ``max_positions`` and the
    dataset's own source/target limits via ``utils.resolve_max_positions``.
    The batch size is ``num_tokens`` divided by the longer of the two lengths,
    with a minimum of 1. Targets are ``None`` when there is no target dictionary.
    """
    src_len, tgt_len = utils.resolve_max_positions(
        (src_len, tgt_len),
        max_positions,
        (self.max_source_positions, self.max_target_positions),
    )
    longest = max(src_len, tgt_len)
    bsz = max(num_tokens // longest, 1)
    has_target = self.tgt_dict is not None

    samples = []
    for idx in range(bsz):
        samples.append({
            'id': idx,
            'source': self.src_dict.dummy_sentence(src_len),
            'target': self.tgt_dict.dummy_sentence(tgt_len) if has_target else None,
        })
    return self.collater(samples)
开发者ID:kakaobrain,项目名称:helo_word,代码行数:18,代码来源:language_pair_dataset.py

示例4: get_dummy_batch

# 需要导入模块: from fairseq import utils [as 别名]
# 或者: from fairseq.utils import resolve_max_positions [as 别名]
def get_dummy_batch(self, num_tokens, max_positions, src_len=128, tgt_len=128):
    """Collate a synthetic token-labeled batch of roughly ``num_tokens`` tokens.

    ``src_len``/``tgt_len`` are combined with ``max_positions`` and the
    dataset's own source/target limits via ``utils.resolve_max_positions``.
    The batch size is ``num_tokens`` divided by the longer length, with a
    minimum of 1. Each sample's source and target is a dict with ``"tokens"``
    (a dummy sentence) and ``"labels"`` (zeros shaped like the tokens).
    The target is ``None`` when the dataset has no target dictionary.
    """
    src_len, tgt_len = utils.resolve_max_positions(
        (src_len, tgt_len),
        max_positions,
        (self.max_source_positions, self.max_target_positions),
    )
    bsz = max(num_tokens // max(src_len, tgt_len), 1)

    src_dummy = self.src_dict.dummy_sentence(src_len)
    # Only touch tgt_dict when it exists; calling dummy_sentence on None
    # would raise AttributeError before the per-sample None check is reached.
    tgt_dummy = (
        self.tgt_dict.dummy_sentence(tgt_len) if self.tgt_dict is not None else None
    )

    def _labeled(tokens):
        # Fresh label tensor per sample, matching the original per-sample construction.
        return {"tokens": tokens, "labels": torch.zeros_like(tokens)}

    return self.collater([
        {
            'id': i,
            'source': _labeled(src_dummy),
            'target': _labeled(tgt_dummy) if tgt_dummy is not None else None,
        }
        for i in range(bsz)
    ])
开发者ID:kakaobrain,项目名称:helo_word,代码行数:27,代码来源:token_labeled_language_pair_dataset.py

示例5: test_resolve_max_positions_with_tuple

# 需要导入模块: from fairseq import utils [as 别名]
# 或者: from fairseq.utils import resolve_max_positions [as 别名]
def test_resolve_max_positions_with_tuple(self):
    """Resolving ``None``, a per-field tuple and a scalar 12000 yields the tuple."""
    tuple_limit = (2000, 100, 2000)
    self.assertEqual(
        utils.resolve_max_positions(None, tuple_limit, 12000),
        (2000, 100, 2000),
    )
开发者ID:pytorch,项目名称:fairseq,代码行数:5,代码来源:test_utils.py

示例6: get_train_iterator

# 需要导入模块: from fairseq import utils [as 别名]
# 或者: from fairseq.utils import resolve_max_positions [as 别名]
def get_train_iterator(
    self,
    epoch,
    combine=True,
    load_dataset=True,
    data_selector=None,
    shard_batch_itr=True,
):
    """Return an EpochBatchIterator over the training subset for ``epoch``.

    When ``load_dataset`` is true, the training subset is (re)loaded for this
    epoch first. The position limit combines the task's and model's
    ``max_positions()`` with ``args.max_tokens``. Presumably this keeps a
    single example within the token budget; confirm against
    ``resolve_max_positions``. Invalid-size inputs are always ignored. With
    ``shard_batch_itr=False`` the iterator is not sharded (1 shard, id 0).
    """
    args = self.args
    if load_dataset:
        logger.info("loading train data for epoch {}".format(epoch))
        self.task.load_dataset(
            args.train_subset,
            epoch=epoch,
            combine=combine,
            data_selector=data_selector,
        )

    if shard_batch_itr:
        num_shards = self.data_parallel_world_size
        shard_id = self.data_parallel_rank
    else:
        num_shards, shard_id = 1, 0

    max_positions = utils.resolve_max_positions(
        self.task.max_positions(),
        self.model.max_positions(),
        args.max_tokens,
    )
    return self.task.get_batch_iterator(
        dataset=self.task.dataset(args.train_subset),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=max_positions,
        ignore_invalid_inputs=True,
        required_batch_size_multiple=args.required_batch_size_multiple,
        seed=args.seed,
        num_shards=num_shards,
        shard_id=shard_id,
        num_workers=args.num_workers,
        epoch=epoch,
    )
开发者ID:pytorch,项目名称:fairseq,代码行数:36,代码来源:trainer.py

示例7: __init__

# 需要导入模块: from fairseq import utils [as 别名]
# 或者: from fairseq.utils import resolve_max_positions [as 别名]
def __init__(self, args, task, model):
    """Store ``args``/``task``/``model``, build the BPE encoder and derive a scalar limit.

    ``self.max_positions`` is the ``min`` of the task/model limits as
    combined by ``utils.resolve_max_positions``. This assumes that result is
    an iterable, e.g. a tuple of per-field limits; TODO confirm for
    scalar-limit models.
    """
    super().__init__()
    self.args, self.task, self.model = args, task, model

    self.bpe = encoders.build_bpe(args)

    combined_limits = utils.resolve_max_positions(
        self.task.max_positions(),
        self.model.max_positions(),
    )
    self.max_positions = min(combined_limits)

    # this is useful for determining the device
    self.register_buffer('_float_tensor', torch.tensor([0], dtype=torch.float))
开发者ID:pytorch,项目名称:fairseq,代码行数:17,代码来源:hub_interface.py

示例8: __init__

# 需要导入模块: from fairseq import utils [as 别名]
# 或者: from fairseq.utils import resolve_max_positions [as 别名]
def __init__(self, args, task, models):
    """Wrap ``task`` and an ensemble of ``models`` for generation.

    Every model is prepared with ``make_generation_fast_``. The beamable-mm
    beam size comes from ``args.beam`` (default 5), or is ``None`` when
    ``args.no_beamable_mm`` is set. Attention is requested when
    ``args.print_alignment`` is set. Also loads the unknown-word alignment
    dictionary, the tokenizer and the BPE encoder. ``self.max_positions``
    combines the task's limit with every model's.
    """
    super().__init__()
    self.args = args
    self.task = task
    self.models = nn.ModuleList(models)
    self.src_dict = task.source_dictionary
    self.tgt_dict = task.target_dictionary

    # optimize model for generation
    if getattr(args, 'no_beamable_mm', False):
        beam_size = None
    else:
        beam_size = getattr(args, 'beam', 5)
    need_attn = getattr(args, 'print_alignment', False)
    for model in self.models:
        model.make_generation_fast_(
            beamable_mm_beam_size=beam_size,
            need_attn=need_attn,
        )

    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    self.align_dict = utils.load_align_dict(getattr(args, 'replace_unk', None))

    self.tokenizer = encoders.build_tokenizer(args)
    self.bpe = encoders.build_bpe(args)

    limits = [self.task.max_positions()]
    limits.extend(m.max_positions() for m in models)
    self.max_positions = utils.resolve_max_positions(*limits)

    # this is useful for determining the device
    self.register_buffer('_float_tensor', torch.tensor([0], dtype=torch.float))
开发者ID:pytorch,项目名称:fairseq,代码行数:33,代码来源:hub_utils.py

示例9: build_trainer

# 需要导入模块: from fairseq import utils [as 别名]
# 或者: from fairseq.utils import resolve_max_positions [as 别名]
def build_trainer(args, task, model, criterion, trainer_class):
    """Build a trainer from ``trainer_class`` and its training batch iterator.

    Args:
        args: parsed training arguments.
        task: task providing the training dataset and batching.
        model: model whose ``max_positions()`` is combined with the task's.
        criterion: loss criterion handed to the trainer.
        trainer_class: callable invoked as
            ``trainer_class(args, task, model, criterion)``.

    Returns:
        tuple: ``(trainer, epoch_itr)``, where ``epoch_itr`` comes from
        ``task.get_batch_iterator`` over ``args.train_subset``, sharded by
        ``args.distributed_world_size`` / ``args.distributed_rank``.
    """
    # Build trainer
    trainer = trainer_class(args, task, model, criterion)

    # Adjacent f-strings are concatenated implicitly. A backslash continuation
    # inside the literal would copy the next line's indentation into the
    # printed message.
    print(
        f"| training on {args.distributed_world_size} total GPUs "
        f"({torch.cuda.device_count()} GPUs locally on this machine).\n"
        f"| max tokens per GPU = {args.max_tokens} and "
        f"max sentences per GPU = {args.max_sentences}",
        flush=True,
    )

    epoch_itr = task.get_batch_iterator(
        dataset=task.dataset(args.train_subset),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=utils.resolve_max_positions(
            task.max_positions(), model.max_positions()
        ),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=8,
        seed=args.seed,
        num_shards=args.distributed_world_size,
        shard_id=args.distributed_rank,
        num_workers=args.num_workers,
    )
    return trainer, epoch_itr
开发者ID:pytorch,项目名称:translate,代码行数:31,代码来源:train.py

示例10: get_dummy_batch

# 需要导入模块: from fairseq import utils [as 别名]
# 或者: from fairseq.utils import resolve_max_positions [as 别名]
def get_dummy_batch(self, num_tokens, max_positions, src_len=128, tgt_len=128):
    """Return a dummy batch built by ``generate_dummy_batch`` for ``num_tokens`` tokens.

    The requested ``src_len``/``tgt_len`` are first combined with
    ``max_positions`` and this dataset's source/target limits via
    ``utils.resolve_max_positions``.
    """
    requested = (src_len, tgt_len)
    dataset_limits = (self.max_source_positions, self.max_target_positions)
    src_len, tgt_len = utils.resolve_max_positions(requested, max_positions, dataset_limits)
    return generate_dummy_batch(
        num_tokens,
        self.collater,
        self.src_vocab,
        self.tgt_vocab,
        src_len,
        tgt_len,
    )
开发者ID:plkmo,项目名称:NLP_Toolkit,代码行数:10,代码来源:noisy_language_pair_dataset.py

示例11: score_sentence

# 需要导入模块: from fairseq import utils [as 别名]
# 或者: from fairseq.utils import resolve_max_positions [as 别名]
def score_sentence(self, line):
    """Return the fluency score of a single sentence ``line``.

    The sentence is tokenized with the task dictionary into a one-item batch
    and scored by ``self.scorer``. The finite positional scores of the first
    hypothesis are passed to ``self._fluency_score``. Returns 0.0 if the
    scorer yields no hypothesis.
    """
    args = self.args
    lm_dict = self.task.dictionary

    # Tokenize the input sentence into a batch of size one.
    tokens = tokenizer.Tokenizer.tokenize(line, lm_dict, add_if_not_exist=False).long()
    lengths = np.array([tokens.numel()])
    ds = data.TokenBlockDataset(
        tokens,
        lengths,
        args.tokens_per_sample,
        pad=lm_dict.pad(),
        eos=lm_dict.eos(),
        break_mode=args.sample_break_mode,
        include_targets=True,
    )

    # Create a batch iterator to wrap the data.
    add_eos_for_other_targets = args.sample_break_mode not in (None, 'none')
    mono_ds = data.MonolingualDataset(
        ds,
        ds.sizes,
        lm_dict,
        self.task.target_dictionary,
        add_eos_for_other_targets=add_eos_for_other_targets,
        shuffle=False,
        targets=self.task.targets,
    )
    model_limits = [m.max_positions() for m in self.models]
    itr = self.task.get_batch_iterator(
        dataset=mono_ds,
        max_tokens=args.max_tokens or 3000,
        max_sentences=args.max_sentences,
        max_positions=utils.resolve_max_positions(*model_limits),
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        ignore_invalid_inputs=True,
    ).next_epoch_itr(shuffle=False)

    # Evaluate the sentence and return the fluency score.
    neg_inf, pos_inf = float('-inf'), float('inf')
    for _, _, _, hypos in self.scorer.score_batched_itr(itr, cuda=self.use_cuda):
        for hypo in hypos:
            # Ignore words with infinite probability; this can happen when
            # running low-precision inference on the GPU.
            finite_scores = [
                s for s in hypo['positional_scores'] if s != neg_inf and s != pos_inf
            ]
            return self._fluency_score(finite_scores)
    return 0.0
开发者ID:rgcottrell,项目名称:pytorch-human-performance-gec,代码行数:32,代码来源:fluency_scorer.py

示例12: validate

# 需要导入模块: from fairseq import utils [as 别名]
# 或者: from fairseq.utils import resolve_max_positions [as 别名]
def validate(args, trainer, task, epoch_itr, subsets):
    """Evaluate the model on each validation subset and return the losses.

    For every name in ``subsets`` this builds an unshuffled, sharded batch
    iterator and resets the trainer's ``valid_loss``/``valid_nll_loss``
    meters. It then runs ``trainer.valid_step`` on each batch, averaging any
    extra logged keys, and prints the resulting stats.

    Returns:
        list: ``stats['loss'].avg`` for each subset, in ``subsets`` order.
    """
    # Keys the trainer already accounts for; everything else gets its own meter.
    tracked_keys = {'loss', 'nll_loss', 'ntokens', 'nsentences', 'sample_size'}
    losses = []
    for subset in subsets:
        # Initialize data iterator
        batches = task.get_batch_iterator(
            dataset=task.dataset(subset),
            max_tokens=args.max_tokens,
            max_sentences=args.max_sentences_valid,
            max_positions=utils.resolve_max_positions(
                task.max_positions(),
                trainer.get_model().max_positions(),
            ),
            ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
            required_batch_size_multiple=8,
            seed=args.seed,
            num_shards=args.distributed_world_size,
            shard_id=args.distributed_rank,
            num_workers=args.num_workers,
        ).next_epoch_itr(shuffle=False)
        progress = progress_bar.build_progress_bar(
            args, batches, epoch_itr.epoch,
            prefix='valid on \'{}\' subset'.format(subset),
            no_progress_bar='simple',
        )

        # reset validation loss meters
        for meter_name in ('valid_loss', 'valid_nll_loss'):
            meter = trainer.get_meter(meter_name)
            if meter is not None:
                meter.reset()
        extra_meters = collections.defaultdict(AverageMeter)

        for sample in progress:
            for key, value in trainer.valid_step(sample).items():
                if key not in tracked_keys:
                    extra_meters[key].update(value)

        # log validation stats
        stats = get_valid_stats(trainer)
        for key, meter in extra_meters.items():
            stats[key] = meter.avg
        progress.print(stats, tag=subset, step=trainer.get_num_updates())

        losses.append(stats['loss'].avg)
    return losses
开发者ID:kakaobrain,项目名称:helo_word,代码行数:51,代码来源:train.py


注:本文中的fairseq.utils.resolve_max_positions方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。