This article collects typical usage examples of the Python method torchtext.datasets.TranslationDataset. If you are wondering what datasets.TranslationDataset does, how to call it, or where it is used in practice, the curated examples below may help. You can also explore further usage examples from torchtext.datasets, the module in which the method lives.
Three code examples of datasets.TranslationDataset are shown below, sorted by popularity by default. You can upvote the examples you like or find useful; your feedback helps the system recommend better Python code examples.
Example 1: splits
# Required import: from torchtext import datasets [as alias]
# Or: from torchtext.datasets import TranslationDataset [as alias]
# Also requires: import os
@classmethod
def splits(cls, path, exts, fields, root='.data',
           train='train', validation='val', test='test', **kwargs):
    """Create dataset objects for splits of a TranslationDataset.

    Arguments:
        path: Common prefix of the splits' file paths.
        exts: A tuple containing the extension to path for each language.
        fields: A tuple containing the fields that will be used for data
            in each language.
        root: Root dataset storage directory. Default is '.data'.
        train: The prefix of the train data. Default: 'train'.
        validation: The prefix of the validation data. Default: 'val'.
        test: The prefix of the test data. Default: 'test'.
        Remaining keyword arguments: Passed to the splits method of
            Dataset.
    """
    # path = cls.download(root)
    train_data = None if train is None else cls(
        os.path.join(path, train), exts, fields, **kwargs)
    val_data = None if validation is None else cls(
        os.path.join(path, validation), exts, fields, **kwargs)
    test_data = None if test is None else cls(
        os.path.join(path, test), exts, fields, **kwargs)
    return tuple(d for d in (train_data, val_data, test_data)
                 if d is not None)
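A minimal usage sketch for splits (the field definitions, directory layout, and file names such as data/train.de are hypothetical, and the legacy torchtext 0.x API is assumed):

from torchtext import data, datasets

SRC = data.Field(tokenize=str.split, init_token='<sos>', eos_token='<eos>')
TRG = data.Field(tokenize=str.split, init_token='<sos>', eos_token='<eos>')

# Expects data/train.de, data/train.en, data/val.*, data/test.* on disk.
train, val, test = datasets.TranslationDataset.splits(
    path='data', exts=('.de', '.en'), fields=(SRC, TRG))

SRC.build_vocab(train, min_freq=2)
TRG.build_vocab(train, min_freq=2)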
Example 2: __init__
# Required import: from torchtext import datasets [as alias]
# Or: from torchtext.datasets import TranslationDataset [as alias]
# Also requires: import os, import torch, and from torchtext import data
def __init__(self, path, exts, fields, load_dataset=False, prefix='', **kwargs):
    # Attach default names when bare Field objects are passed in; this
    # variant reads three parallel files: source, target, and decoder input.
    if not isinstance(fields[0], (tuple, list)):
        fields = [('src', fields[0]), ('trg', fields[1]), ('dec', fields[2])]
    src_path, trg_path, dec_path = tuple(os.path.expanduser(path + x) for x in exts)
    # Reuse a cached, preprocessed example list when one exists on disk.
    if load_dataset and os.path.exists(path + '.processed.{}.pt'.format(prefix)):
        examples = torch.load(path + '.processed.{}.pt'.format(prefix))
    else:
        examples = []
        with open(src_path) as src_file, open(trg_path) as trg_file, open(dec_path) as dec_file:
            for src_line, trg_line, dec_line in zip(src_file, trg_file, dec_file):
                src_line, trg_line, dec_line = src_line.strip(), trg_line.strip(), dec_line.strip()
                # Drop triples in which any side is empty.
                if src_line != '' and trg_line != '' and dec_line != '':
                    examples.append(data.Example.fromlist(
                        [src_line, trg_line, dec_line], fields))
        if load_dataset:
            torch.save(examples, path + '.processed.{}.pt'.format(prefix))
    # Skip TranslationDataset.__init__ (which would re-read two files) and
    # call data.Dataset.__init__ directly with the examples built here.
    super(datasets.TranslationDataset, self).__init__(examples, fields, **kwargs)
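A usage sketch, assuming the __init__ above belongs to a TranslationDataset subclass; the class name TripleTranslationDataset, the file prefix, and the extensions are all placeholders:

from torchtext import data

SRC, TRG, DEC = data.Field(), data.Field(), data.Field()

# Expects data/train.src, data/train.trg, and data/train.dec on disk.
dataset = TripleTranslationDataset(
    path='data/train', exts=('.src', '.trg', '.dec'),
    fields=(SRC, TRG, DEC), load_dataset=True, prefix='demo')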
Example 3: prepare_dataloaders_from_bpe_files
# Required import: from torchtext import datasets [as alias]
# Or: from torchtext.datasets import TranslationDataset [as alias]
# Also requires: import pickle, from torchtext.data import BucketIterator,
# and a project-specific Constants module defining PAD_WORD
def prepare_dataloaders_from_bpe_files(opt, device):
    batch_size = opt.batch_size
    MIN_FREQ = 2  # unused in this path; the vocabulary is loaded pre-built below
    if not opt.embs_share_weight:
        # BPE preprocessing produces a single shared vocabulary, so shared
        # embedding weights are required (the original snippet had a bare raise).
        raise ValueError('embs_share_weight must be enabled for BPE data.')

    # The pickle holds the preprocessing settings and the shared Field.
    with open(opt.data_pkl, 'rb') as f:
        data = pickle.load(f)
    MAX_LEN = data['settings'].max_len
    field = data['vocab']
    fields = (field, field)  # source and target share one Field/vocabulary

    def filter_examples_with_length(x):
        return len(vars(x)['src']) <= MAX_LEN and len(vars(x)['trg']) <= MAX_LEN

    train = TranslationDataset(
        fields=fields,
        path=opt.train_path,
        exts=('.src', '.trg'),
        filter_pred=filter_examples_with_length)
    val = TranslationDataset(
        fields=fields,
        path=opt.val_path,
        exts=('.src', '.trg'),
        filter_pred=filter_examples_with_length)

    opt.max_token_seq_len = MAX_LEN + 2  # leave room for <bos> and <eos>
    opt.src_pad_idx = opt.trg_pad_idx = field.vocab.stoi[Constants.PAD_WORD]
    opt.src_vocab_size = opt.trg_vocab_size = len(field.vocab)

    train_iterator = BucketIterator(train, batch_size=batch_size, device=device, train=True)
    val_iterator = BucketIterator(val, batch_size=batch_size, device=device)
    return train_iterator, val_iterator
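A sketch of how this function might be driven; the opt attributes mirror exactly what the function reads above, while all paths and values are illustrative:

from types import SimpleNamespace
import torch

opt = SimpleNamespace(
    batch_size=128,
    embs_share_weight=True,
    data_pkl='data/prepared.pkl',  # pickle holding 'settings' and 'vocab'
    train_path='data/train',       # expects data/train.src and data/train.trg
    val_path='data/val')

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
train_iter, val_iter = prepare_dataloaders_from_bpe_files(opt, device)
for batch in train_iter:
    src, trg = batch.src, batch.trg  # tensors of token indices
    break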