This article collects typical usage examples of Python's torchtext.data.Dataset. If you are wondering what data.Dataset is for, how to use it, or what real-world examples look like, the curated code examples below may help. You can also browse further usage examples from the torchtext.data module it belongs to.
The following presents 14 code examples of data.Dataset, sorted by popularity by default. You can upvote the examples you like or find useful; your feedback helps the system recommend better Python code examples.
Example 1: splits

# Required import: from torchtext import data [as alias]
# Or: from torchtext.data import Dataset [as alias]
def splits(cls, path, exts, fields, root='.data',
           train='train', validation='val', test='test', **kwargs):
    """Create dataset objects for splits of a TranslationDataset.

    Arguments:
        path: Common prefix of the paths to the data files for both languages.
        root: Root dataset storage directory. Default is '.data'.
        exts: A tuple containing the extension to path for each language.
        fields: A tuple containing the fields that will be used for data
            in each language.
        train: The prefix of the train data. Default: 'train'.
        validation: The prefix of the validation data. Default: 'val'.
        test: The prefix of the test data. Default: 'test'.
        Remaining keyword arguments: Passed to the splits method of
            Dataset.
    """
    # path = cls.download(root)
    train_data = None if train is None else cls(
        os.path.join(path, train), exts, fields, **kwargs)
    val_data = None if validation is None else cls(
        os.path.join(path, validation), exts, fields, **kwargs)
    test_data = None if test is None else cls(
        os.path.join(path, test), exts, fields, **kwargs)
    return tuple(d for d in (train_data, val_data, test_data)
                 if d is not None)
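This splits classmethod mirrors torchtext's legacy TranslationDataset API. A minimal usage sketch, assuming a local parallel corpus laid out as train.de/train.en, val.de/val.en, test.de/test.en under a hypothetical data/multi30k directory (the field configuration is also an assumption):

from torchtext import data
from torchtext.datasets import TranslationDataset

SRC = data.Field(tokenize=str.split, init_token='<sos>', eos_token='<eos>', lower=True)
TRG = data.Field(tokenize=str.split, init_token='<sos>', eos_token='<eos>', lower=True)

# Build the three splits from local files; exts selects the per-language file extension.
train, val, test = TranslationDataset.splits(
    path='data/multi30k', exts=('.de', '.en'), fields=(SRC, TRG))

SRC.build_vocab(train, min_freq=2)
TRG.build_vocab(train, min_freq=2)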
Example 2: prepare_dataloaders

# Required import: from torchtext import data [as alias]
# Or: from torchtext.data import Dataset [as alias]
def prepare_dataloaders(opt, device):
    batch_size = opt.batch_size
    data = pickle.load(open(opt.data_pkl, 'rb'))

    opt.max_token_seq_len = data['settings'].max_len
    opt.src_pad_idx = data['vocab']['src'].vocab.stoi[Constants.PAD_WORD]
    opt.trg_pad_idx = data['vocab']['trg'].vocab.stoi[Constants.PAD_WORD]

    opt.src_vocab_size = len(data['vocab']['src'].vocab)
    opt.trg_vocab_size = len(data['vocab']['trg'].vocab)

    #========= Preparing Model =========#
    if opt.embs_share_weight:
        assert data['vocab']['src'].vocab.stoi == data['vocab']['trg'].vocab.stoi, \
            'To share word embeddings, the src/trg word2idx tables must be the same.'

    fields = {'src': data['vocab']['src'], 'trg': data['vocab']['trg']}

    train = Dataset(examples=data['train'], fields=fields)
    val = Dataset(examples=data['valid'], fields=fields)

    train_iterator = BucketIterator(train, batch_size=batch_size, device=device, train=True)
    val_iterator = BucketIterator(val, batch_size=batch_size, device=device)

    return train_iterator, val_iterator
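A hedged usage sketch for prepare_dataloaders; the argument namespace, pickle path, and device below are illustrative assumptions rather than the original training script:

import argparse
import torch

parser = argparse.ArgumentParser()
parser.add_argument('-data_pkl', default='m30k_deen_shr.pkl')    # hypothetical preprocessed pickle
parser.add_argument('-batch_size', type=int, default=64)
parser.add_argument('-embs_share_weight', action='store_true')
opt = parser.parse_args([])

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
train_iterator, val_iterator = prepare_dataloaders(opt, device)

for batch in train_iterator:
    src_seq, trg_seq = batch.src, batch.trg   # attribute names follow the 'src'/'trg' field keys above
    break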
Example 3: extend_vocab

# Required import: from torchtext import data [as alias]
# Or: from torchtext.data import Dataset [as alias]
def extend_vocab(self, *args, vectors=None, cache=None):
    sources = []
    for arg in args:
        if isinstance(arg, data.Dataset):
            sources += [
                getattr(arg, name)
                for name, field in arg.fields.items()
                if field is self
            ]
        else:
            sources.append(arg)

    tokens = set()
    for source in sources:
        for x in source:
            if not self.sequential:
                tokens.add(x)
            else:
                tokens.update(x)

    if self.vocab.vectors is not None:
        vectors = MatchingField._get_vector_data(vectors, cache)
        self.vocab.extend_vectors(tokens, vectors)
Example 4: splits

# Required import: from torchtext import data [as alias]
# Or: from torchtext.data import Dataset [as alias]
def splits(cls, text_field, label_field, root='./data',
           train='20news-bydate-train', test='20news-bydate-test',
           **kwargs):
    """Create dataset objects for splits of the 20news dataset.

    Arguments:
        text_field: The field that will be used for the sentence.
        label_field: The field that will be used for label data.
        train: The directory name of the train data. Default: '20news-bydate-train'.
        Remaining keyword arguments: Passed to the splits method of
            Dataset.
    """
    path = cls.download_or_unzip(root)
    train_data = None if train is None else cls(
        text_field, label_field, os.path.join(path, train), 2000, **kwargs)
    dev_ratio = 0.1
    dev_index = -1 * int(dev_ratio * len(train_data))

    return (cls(text_field, label_field, examples=train_data[:dev_index]),
            cls(text_field, label_field, examples=train_data[dev_index:]))
Example 5: __init__

# Required import: from torchtext import data [as alias]
# Or: from torchtext.data import Dataset [as alias]
def __init__(self, text_field, label_field, path=None, examples=None, **kwargs):
    """Create an MR dataset instance given a path and fields.

    Arguments:
        text_field: The field that will be used for text data.
        label_field: The field that will be used for label data.
        path: Path to the data file.
        examples: The examples containing all the data.
        Remaining keyword arguments: Passed to the constructor of
            data.Dataset.
    """
    # text_field.preprocessing = data.Pipeline(clean_str)
    fields = [('text', text_field), ('label', label_field)]

    if examples is None:
        path = self.dirname if path is None else path
        examples = []
        with codecs.open(os.path.join(path, 'rt-polarity.neg'), 'r', 'utf8') as f:
            examples += [
                data.Example.fromlist([line, 'negative'], fields) for line in f]
        with codecs.open(os.path.join(path, 'rt-polarity.pos'), 'r', 'utf8') as f:
            examples += [
                data.Example.fromlist([line, 'positive'], fields) for line in f]
    super(MR, self).__init__(examples, fields, **kwargs)
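A minimal usage sketch for this MR constructor, assuming the rt-polarity.pos/rt-polarity.neg files already sit in a local directory (the path and field settings are assumptions):

from torchtext import data

text_field = data.Field(lower=True)
label_field = data.Field(sequential=False)

mr = MR(text_field, label_field, path='data/rt-polaritydata')   # hypothetical local path
text_field.build_vocab(mr)
label_field.build_vocab(mr)
print(len(mr), mr.examples[0].label)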
Example 6: __init__

# Required import: from torchtext import data [as alias]
# Or: from torchtext.data import Dataset [as alias]
def __init__(self, path, text_field, newline_eos=True,
             encoding='utf-8', **kwargs):
    """Create a LanguageModelingDataset given a path and a field.

    Arguments:
        path: Path to the data file.
        text_field: The field that will be used for text data.
        newline_eos: Whether to add an <eos> token for every newline in the
            data file. Default: True.
        Remaining keyword arguments: Passed to the constructor of
            data.Dataset.
    """
    fields = [('text', text_field)]
    text = []
    with io.open(path, encoding=encoding) as f:
        for line in f:
            text += text_field.preprocess(line)
            if newline_eos:
                text.append(u'<eos>')

    examples = [data.Example.fromlist([text], fields)]
    super(LanguageModelingDataset, self).__init__(
        examples, fields, **kwargs)
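A short usage sketch for the constructor above; the corpus path is a placeholder:

from torchtext import data

TEXT = data.Field(lower=True, tokenize=str.split)
lm_dataset = LanguageModelingDataset('data/corpus.txt', TEXT)   # hypothetical plain-text corpus
TEXT.build_vocab(lm_dataset)
print(len(lm_dataset[0].text), len(TEXT.vocab))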
Example 7: test_build_vocab_from_dataset

# Required import: from torchtext import data [as alias]
# Or: from torchtext.data import Dataset [as alias]
def test_build_vocab_from_dataset(self):
    nesting_field = data.Field(tokenize=list, unk_token="<cunk>", pad_token="<cpad>",
                               init_token="<w>", eos_token="</w>")
    CHARS = data.NestedField(nesting_field, init_token="<s>", eos_token="</s>")
    ex1 = data.Example.fromlist(["aaa bbb c"], [("chars", CHARS)])
    ex2 = data.Example.fromlist(["bbb aaa"], [("chars", CHARS)])
    dataset = data.Dataset([ex1, ex2], [("chars", CHARS)])

    CHARS.build_vocab(dataset, min_freq=2)

    expected = "a b <w> </w> <s> </s> <cunk> <cpad>".split()
    assert len(CHARS.vocab) == len(expected)
    for c in expected:
        assert c in CHARS.vocab.stoi

    expected_freqs = Counter({"a": 6, "b": 6, "c": 1})
    assert CHARS.vocab.freqs == CHARS.nesting_field.vocab.freqs == expected_freqs
Example 8: __iter__

# Required import: from torchtext import data [as alias]
# Or: from torchtext.data import Dataset [as alias]
def __iter__(self):
    text = self.dataset[0].text
    TEXT = self.dataset.fields['text']
    TEXT.eos_token = None
    pad_num = int(math.ceil(len(text) / self.batch_size) *
                  self.batch_size - len(text))
    text = text + ([TEXT.pad_token] * pad_num)
    data = TEXT.numericalize([text], device=self.device)
    data = data.view(self.batch_size, -1).contiguous()
    dataset = Dataset(examples=self.dataset.examples,
                      fields=[('text', TEXT), ('target', TEXT)])
    while True:
        for i in range(0, len(self) * self.bptt_len, self.bptt_len):
            self.iterations += 1
            seq_len = self.bptt_len
            yield Batch.fromvars(
                dataset, self.batch_size,
                text=data[:, i:i + seq_len],
                target=data[:, i + 1:i + 1 + seq_len])
        if not self.repeat:
            return
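The method above mirrors the behavior of torchtext's legacy data.BPTTIterator (batch-first here). A sketch using the built-in iterator on a language-modeling dataset; the corpus path and hyperparameters are assumptions:

from torchtext import data

TEXT = data.Field(lower=True, tokenize=str.split)
lm_dataset = data.LanguageModelingDataset('data/corpus.txt', TEXT)   # hypothetical corpus
TEXT.build_vocab(lm_dataset)

bptt_iter = data.BPTTIterator(lm_dataset, batch_size=32, bptt_len=35, repeat=False)
for batch in bptt_iter:
    x, y = batch.text, batch.target   # target is the text shifted by one token
    break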
Example 9: __init__

# Required import: from torchtext import data [as alias]
# Or: from torchtext.data import Dataset [as alias]
def __init__(self, dataset: Dataset, batch_size: int, target_names: Optional[List[str]] = None,
             sort_key: Union[Callable, str] = "sl", max_context_size: int = 130000, backwards=False,
             **kwargs):
    self.dataset = dataset
    target_names = [target_names] if isinstance(target_names, str) else target_names
    # resolve the sort key: either a named shortcut ("cl"/"sl") or a user-provided callable
    if sort_key == "cl":
        def sort_key(x):
            """Sort examples by the largest conversation length in the example."""
            return len(x.roles)
    elif sort_key == 'sl':
        def sort_key(x):
            """Sort examples by the largest utterance length in the example."""
            return max(x.sl)
    else:
        assert callable(sort_key), "sort_key provided is not a function"
    self.dl = HierarchicalIterator(dataset, batch_size=batch_size, sort_key=sort_key, target_roles=target_names,
                                   max_context_size=max_context_size, **kwargs)
    self.bs = batch_size
    self.iter = 0
Example 10: splits

# Required import: from torchtext import data [as alias]
# Or: from torchtext.data import Dataset [as alias]
def splits(cls, text_field, label_field, dev_ratio=.1, shuffle=True, root='.', **kwargs):
    """Create dataset objects for splits of the MR dataset.

    Arguments:
        text_field: The field that will be used for the sentence.
        label_field: The field that will be used for label data.
        dev_ratio: The ratio that will be used to split off the validation dataset.
        shuffle: Whether to shuffle the data before the split.
        root: The root directory that the dataset's zip archive will be
            expanded into; the data files are stored in a subdirectory of it.
        Remaining keyword arguments: Passed to the splits method of
            Dataset.
    """
    path = cls.download_or_unzip(root)
    examples = cls(text_field, label_field, path=path, **kwargs).examples
    if shuffle:
        random.shuffle(examples)
    dev_index = -1 * int(dev_ratio * len(examples))

    return (cls(text_field, label_field, examples=examples[:dev_index]),
            cls(text_field, label_field, examples=examples[dev_index:]))
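A usage sketch for this splits variant, assuming an MR class like the one in Example 5 with a working download_or_unzip (the field setup and iterators below are assumptions):

from torchtext import data

text_field = data.Field(lower=True)
label_field = data.Field(sequential=False)

train_data, dev_data = MR.splits(text_field, label_field, dev_ratio=0.1, root='.')
text_field.build_vocab(train_data, dev_data)
label_field.build_vocab(train_data, dev_data)

train_iter, dev_iter = data.BucketIterator.splits(
    (train_data, dev_data), batch_sizes=(64, 64), sort_key=lambda x: len(x.text))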
Example 11: build_vocab

# Required import: from torchtext import data [as alias]
# Or: from torchtext.data import Dataset [as alias]
def build_vocab(self, *args, **kwargs):
    """Add unaligned_token to the list of special symbols."""
    counter = Counter()
    sources = []
    for arg in args:
        if isinstance(arg, data.Dataset):
            sources += [
                getattr(arg, name)
                for name, field in arg.fields.items()
                if field is self
            ]
        else:
            sources.append(arg)
    for sample in sources:
        for x in sample:
            if not self.sequential:
                x = [x]
            try:
                counter.update(x)
            except TypeError:
                counter.update(chain.from_iterable(x))
    specials = list(
        OrderedDict.fromkeys(
            tok
            for tok in [
                self.unk_token,
                self.pad_token,
                self.init_token,
                self.eos_token,
                self.unaligned_token,
            ]
            if tok is not None
        )
    )
    self.vocab = self.vocab_cls(counter, specials=specials, **kwargs)
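This build_vocab override assumes a Field subclass that defines self.unaligned_token. A sketch with a hypothetical subclass name (AlignedField and its constructor arguments are not from the original source):

from torchtext import data

field = AlignedField(unaligned_token='<unaligned>', lower=True)   # hypothetical Field subclass
examples = [data.Example.fromlist(['a short sentence'], [('text', field)])]
dataset = data.Dataset(examples, [('text', field)])

field.build_vocab(dataset)
assert '<unaligned>' in field.vocab.stoi   # the extra special symbol is registered alongside <unk>/<pad>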
Example 12: get_iterator

# Required import: from torchtext import data [as alias]
# Or: from torchtext.data import Dataset [as alias]
def get_iterator(self, dataset: Dataset, batch_size=DEFAULT_CONFIG['batch_size'], device=DEVICE,
                 sort_key=lambda x: len(x.texta)):
    return BucketIterator(dataset, batch_size=batch_size, device=device, sort_key=sort_key)
Example 13: get_iterator

# Required import: from torchtext import data [as alias]
# Or: from torchtext.data import Dataset [as alias]
def get_iterator(self, dataset: Dataset, batch_size=DEFAULT_CONFIG['batch_size'], device=DEVICE,
                 sort_key=lambda x: len(x.text), sort_within_batch=True):
    return BucketIterator(dataset, batch_size=batch_size, device=device, sort_key=sort_key,
                          sort_within_batch=sort_within_batch)
Example 14: get_iterator

# Required import: from torchtext import data [as alias]
# Or: from torchtext.data import Dataset [as alias]
def get_iterator(self, dataset: Dataset, batch_size=DEFAULT_CONFIG['batch_size'], device=DEVICE):
    return BucketIterator(dataset, batch_size=batch_size, device=device)
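All three get_iterator variants are thin wrappers around torchtext's BucketIterator. A standalone sketch; the TSV file, column layout, and field names are assumptions:

import torch
from torchtext import data
from torchtext.data import BucketIterator

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

TEXT = data.Field(lower=True)
LABEL = data.Field(sequential=False, unk_token=None)
dataset = data.TabularDataset(path='data/train.tsv', format='tsv',    # hypothetical file
                              fields=[('text', TEXT), ('label', LABEL)])
TEXT.build_vocab(dataset)
LABEL.build_vocab(dataset)

iterator = BucketIterator(dataset, batch_size=32, device=DEVICE,
                          sort_key=lambda x: len(x.text), sort_within_batch=True)
for batch in iterator:
    inputs, labels = batch.text, batch.label
    break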