

Python datasets.TranslationDataset Method Code Examples

This article collects typical usage examples of the Python method torchtext.datasets.TranslationDataset. If you are wondering how to use datasets.TranslationDataset in practice, the curated code examples below may help. You can also explore further usage examples from the containing module, torchtext.datasets.


The sections below present 3 code examples of the datasets.TranslationDataset method, sorted by popularity by default.

Example 1: splits

# Required import: from torchtext import datasets [as alias]
# Or: from torchtext.datasets import TranslationDataset [as alias]
# (the classmethod body below also needs: import os)
def splits(cls, path, exts, fields, root='.data',
               train='train', validation='val', test='test', **kwargs):
        """Create dataset objects for splits of a TranslationDataset.
        Arguments:
            root: Root dataset storage directory. Default is '.data'.
            exts: A tuple containing the extension to path for each language.
            fields: A tuple containing the fields that will be used for data
                in each language.
            train: The prefix of the train data. Default: 'train'.
            validation: The prefix of the validation data. Default: 'val'.
            test: The prefix of the test data. Default: 'test'.
            Remaining keyword arguments: Passed to the splits method of
                Dataset.
        """
        #path = cls.download(root)

        train_data = None if train is None else cls(
            os.path.join(path, train), exts, fields, **kwargs)
        val_data = None if validation is None else cls(
            os.path.join(path, validation), exts, fields, **kwargs)
        test_data = None if test is None else cls(
            os.path.join(path, test), exts, fields, **kwargs)
        return tuple(d for d in (train_data, val_data, test_data)
                     if d is not None) 
Developer: nyu-dl, Project: dl4mt-nonauto, Lines of code: 26, Source: data.py
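
For orientation, here is a minimal usage sketch of splits (not taken from the dl4mt-nonauto repo). It assumes the legacy torchtext Field/Dataset API and parallel files ./data/train.en, ./data/train.de, ./data/val.en, ./data/val.de, ./data/test.en and ./data/test.de:

from torchtext import data
from torchtext.datasets import TranslationDataset

# Whitespace-tokenized fields for the source and target languages.
SRC = data.Field(tokenize=str.split, lower=True)
TRG = data.Field(tokenize=str.split, lower=True,
                 init_token='<sos>', eos_token='<eos>')

# splits() builds one dataset per prefix (train/val/test) found under path.
train, val, test = TranslationDataset.splits(
    path='./data', exts=('.en', '.de'), fields=(SRC, TRG))

# Vocabularies are built from the training split only.
SRC.build_vocab(train, min_freq=2)
TRG.build_vocab(train, min_freq=2)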

Example 2: __init__

# Required import: from torchtext import datasets [as alias]
# Or: from torchtext.datasets import TranslationDataset [as alias]
# (the method body below also needs: import os, import torch,
#  and from torchtext import data for data.Example)
def __init__(self, path, exts, fields, load_dataset=False, prefix='', **kwargs):
        if not isinstance(fields[0], (tuple, list)):
            fields = [('src', fields[0]), ('trg', fields[1]), ('dec', fields[2])]

        src_path, trg_path, dec_path = tuple(os.path.expanduser(path + x) for x in exts)
        if load_dataset and (os.path.exists(path + '.processed.{}.pt'.format(prefix))):
            examples = torch.load(path + '.processed.{}.pt'.format(prefix))
        else:
            examples = []
            with open(src_path) as src_file, open(trg_path) as trg_file, open(dec_path) as dec_file:
                for src_line, trg_line, dec_line in zip(src_file, trg_file, dec_file):
                    src_line, trg_line, dec_line = src_line.strip(), trg_line.strip(), dec_line.strip()
                    if src_line != '' and trg_line != '' and dec_line != '':
                        examples.append(data.Example.fromlist(
                            [src_line, trg_line, dec_line], fields))
            if load_dataset:
                torch.save(examples, path + '.processed.{}.pt'.format(prefix))

        super(datasets.TranslationDataset, self).__init__(examples, fields, **kwargs) 
Developer: nyu-dl, Project: dl4mt-nonauto, Lines of code: 21, Source: data.py
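
A hypothetical construction sketch for the three-way dataset above. The excerpt does not show the enclosing class's name, so ParallelDataset below is a stand-in; the sketch assumes sibling files corpus.src, corpus.trg and corpus.dec:

from torchtext import data

SRC, TRG, DEC = data.Field(), data.Field(), data.Field()

# ParallelDataset is a placeholder name for the TranslationDataset
# subclass whose __init__ is shown above. With load_dataset=True the
# parsed examples are cached to corpus.processed.demo.pt and reloaded
# on the next run.
dataset = ParallelDataset(
    'corpus', exts=('.src', '.trg', '.dec'),
    fields=(SRC, TRG, DEC),
    load_dataset=True, prefix='demo')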

Example 3: prepare_dataloaders_from_bpe_files

# Required import: from torchtext.datasets import TranslationDataset [as alias]
# (the function below also needs: import pickle,
#  from torchtext.data import BucketIterator,
#  and the repo's transformer.Constants module for Constants.PAD_WORD)
def prepare_dataloaders_from_bpe_files(opt, device):
    batch_size = opt.batch_size
    MIN_FREQ = 2
    if not opt.embs_share_weight:
        raise ValueError(
            'loading BPE data with a shared vocab requires -embs_share_weight')

    data = pickle.load(open(opt.data_pkl, 'rb'))
    MAX_LEN = data['settings'].max_len
    field = data['vocab']
    fields = (field, field)

    def filter_examples_with_length(x):
        return len(vars(x)['src']) <= MAX_LEN and len(vars(x)['trg']) <= MAX_LEN

    train = TranslationDataset(
        fields=fields,
        path=opt.train_path, 
        exts=('.src', '.trg'),
        filter_pred=filter_examples_with_length)
    val = TranslationDataset(
        fields=fields,
        path=opt.val_path, 
        exts=('.src', '.trg'),
        filter_pred=filter_examples_with_length)

    opt.max_token_seq_len = MAX_LEN + 2  # +2 leaves room for the BOS/EOS tokens the field adds
    opt.src_pad_idx = opt.trg_pad_idx = field.vocab.stoi[Constants.PAD_WORD]
    opt.src_vocab_size = opt.trg_vocab_size = len(field.vocab)

    train_iterator = BucketIterator(train, batch_size=batch_size, device=device, train=True)
    val_iterator = BucketIterator(val, batch_size=batch_size, device=device)
    return train_iterator, val_iterator 
Developer: jadore801120, Project: attention-is-all-you-need-pytorch, Lines of code: 34, Source: train.py
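
A hedged driver sketch (not part of train.py itself) showing the options the function reads. The pickle and file paths are placeholders for artifacts produced by the repo's preprocessing step:

import torch
from argparse import Namespace

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

opt = Namespace(
    batch_size=64,
    embs_share_weight=True,   # required by the guard at the top of the function
    data_pkl='data.pkl',      # assumed to hold {'settings': ..., 'vocab': ...}
    train_path='train_bpe',   # expects train_bpe.src / train_bpe.trg
    val_path='val_bpe')       # expects val_bpe.src / val_bpe.trg

train_iter, val_iter = prepare_dataloaders_from_bpe_files(opt, device)
for batch in train_iter:
    src, trg = batch.src, batch.trg  # LongTensors of token indices
    break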


Note: The torchtext.datasets.TranslationDataset examples in this article were collected by 纯净天空 from open-source code and documentation platforms such as GitHub and MSDocs. The snippets are drawn from open-source projects contributed by their respective authors; copyright remains with the original authors, and any distribution or use should follow the corresponding project's license. Please do not repost without permission.