本文整理汇总了Python中torchtext.data.Iterator类的典型用法代码示例。如果您正苦于以下问题:Python data.Iterator的具体用法?Python data.Iterator怎么用?Python data.Iterator使用的例子?那么, 这里精选的代码示例或许可以为您提供帮助。您也可以进一步了解该类所在模块torchtext.data的用法示例。
在下文中一共展示了data.Iterator的14个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: test_batch_iter
# 需要导入模块: from torchtext import data [as 别名]
# 或者: from torchtext.data import Iterator [as 别名]
def test_batch_iter(self):
    """Check that an Iterator over a TSV dataset yields correctly typed batches.

    The dataset has a float column, an integer target column and a text
    column. With ``shuffle=False`` the first batch (size 2) is compared
    against fixed values; those values depend on the fixture written by
    ``write_test_numerical_features_dataset``. Unpacking a batch yields
    ``(inputs, target)``, where the inputs follow the non-target field order
    of ``dst.fields``.
    """
    self.write_test_numerical_features_dataset()
    FLOAT = data.Field(use_vocab=False, sequential=False,
                       dtype=torch.float)
    INT = data.Field(use_vocab=False, sequential=False, is_target=True)
    TEXT = data.Field(sequential=False)
    dst = data.TabularDataset(path=self.test_numerical_features_dataset_path,
                              format="tsv", skip_header=False,
                              fields=[("float", FLOAT),
                                      ("int", INT),
                                      ("text", TEXT)])
    TEXT.build_vocab(dst)
    itr = data.Iterator(dst, batch_size=2, device=-1, shuffle=False)
    # Input field names in declaration order, skipping ignored (None) fields
    # and the target; the unpacked input tuple is indexed with this order.
    fld_order = [k for k, v in dst.fields.items() if
                 v is not None and not v.is_target]
    batch = next(iter(itr))
    (x1, x2), y = batch
    x = (x1, x2)[fld_order.index("float")]
    # assertEqual, not the deprecated alias assertEquals (removed in 3.12).
    self.assertEqual(y.data[0], 1)
    self.assertEqual(y.data[1], 12)
    self.assertAlmostEqual(x.data[0], 0.1, places=4)
    self.assertAlmostEqual(x.data[1], 0.5, places=4)
示例2: init_train_set
# 需要导入模块: from torchtext import data [as 别名]
# 或者: from torchtext.data import Iterator [as 别名]
def init_train_set(self):
    """Load the training CSV into ``self.train_set`` and wrap it in ``self.train_iter``.

    RNG seeding via ``set_all_random_seed(config['random_seed'])`` happens
    first. The CSV has no header row and five columns: Id, Text, Pos1, Pos2
    and Label. The two position columns share ``self.POS``, and labels use
    ``self.TRAIN_LABEL``. The iterator runs in training mode
    (``train=True``), makes a single pass per epoch (``repeat=False``), and
    sorts each batch by the token length of ``Text``.
    """
    cfg = self.config
    set_all_random_seed(cfg['random_seed'])
    csv_path = cfg['train_file']
    print('Loading train set from {}'.format(csv_path))
    column_fields = [
        ('Id', self.ID),
        ('Text', self.TEXT),
        ('Pos1', self.POS),
        ('Pos2', self.POS),
        ('Label', self.TRAIN_LABEL),
    ]
    self.train_set = tt_data.TabularDataset(path=csv_path,
                                            format='csv',
                                            fields=column_fields,
                                            skip_header=False)
    self.train_iter = tt_data.Iterator(
        self.train_set,
        sort_key=lambda example: len(example.Text),
        batch_size=cfg['train_batch_size'],
        train=True,
        repeat=False,
        sort_within_batch=True,
        device=self.device,
    )
示例3: init_dev_set
# 需要导入模块: from torchtext import data [as 别名]
# 或者: from torchtext.data import Iterator [as 别名]
def init_dev_set(self):
    """Load the dev CSV into ``self.dev_set`` and wrap it in ``self.dev_iter``.

    Uses the same five-column, header-less layout as the training set, with
    evaluation labels taken from ``self.LABEL``. The iterator is built with
    ``train=False`` and ``repeat=False``, uses ``config['test_batch_size']``,
    and sorts each batch by the token length of ``Text``.
    """
    cfg = self.config
    csv_path = cfg['dev_file']
    print('Loading dev set from {}'.format(csv_path))
    column_fields = [
        ('Id', self.ID),
        ('Text', self.TEXT),
        ('Pos1', self.POS),
        ('Pos2', self.POS),
        ('Label', self.LABEL),
    ]
    self.dev_set = tt_data.TabularDataset(path=csv_path,
                                          format='csv',
                                          fields=column_fields,
                                          skip_header=False)
    self.dev_iter = tt_data.Iterator(
        self.dev_set,
        sort_key=lambda example: len(example.Text),
        batch_size=cfg['test_batch_size'],
        train=False,
        repeat=False,
        sort_within_batch=True,
        device=self.device,
    )
示例4: init_test_set
# 需要导入模块: from torchtext import data [as 别名]
# 或者: from torchtext.data import Iterator [as 别名]
def init_test_set(self):
    """Load the test CSV into ``self.test_set`` and wrap it in ``self.test_iter``.

    Uses the same five-column, header-less layout as the dev set, with
    labels from ``self.LABEL``. The iterator is built with ``train=False``
    and ``repeat=False``, uses ``config['test_batch_size']``, and sorts each
    batch by the token length of ``Text``.
    """
    cfg = self.config
    csv_path = cfg['test_file']
    print('Loading test set {}'.format(csv_path))
    column_fields = [
        ('Id', self.ID),
        ('Text', self.TEXT),
        ('Pos1', self.POS),
        ('Pos2', self.POS),
        ('Label', self.LABEL),
    ]
    self.test_set = tt_data.TabularDataset(path=csv_path,
                                           format='csv',
                                           fields=column_fields,
                                           skip_header=False)
    self.test_iter = tt_data.Iterator(
        self.test_set,
        sort_key=lambda example: len(example.Text),
        batch_size=cfg['test_batch_size'],
        train=False,
        repeat=False,
        sort_within_batch=True,
        device=self.device,
    )
示例5: set_data_iter
# 需要导入模块: from torchtext import data [as 别名]
# 或者: from torchtext.data import Iterator [as 别名]
def set_data_iter(self, data_type='train', train_mode=True, batch_size=1):
    """Point the environment at a fresh batch iterator over one data split.

    Args:
        data_type: which split to iterate, one of ``'train'``, ``'dev'``,
            ``'test'``. It selects ``self.train_set``, ``self.dev_set`` or
            ``self.test_set``.
        train_mode: when truthy the iterator repeats and shuffles. Otherwise
            it makes one pass in dataset order. Global sorting is disabled
            in both modes; examples are only sorted by ``Text`` length
            inside each batch.
        batch_size: examples per batch (1 unless the caller overrides it).

    Side effects:
        Sets ``self.env_data_iter`` (an already-started ``iter()`` over the
        torchtext Iterator) and ``self.env_data_set``, then prints a summary.

    Raises:
        ValueError: if ``data_type`` is not one of the known split names.
    """
    split_attributes = (('train', 'train_set'),
                        ('dev', 'dev_set'),
                        ('test', 'test_set'))
    for split_name, attr_name in split_attributes:
        if data_type == split_name:
            chosen_set = getattr(self, attr_name)
            break
    else:
        raise ValueError('Unsupported data_type value {}, must be in [train, dev, test]'.format(data_type))

    # Training mode repeats and reshuffles; evaluation mode is one ordered pass.
    repeat_flag = shuffle_flag = bool(train_mode)
    batches = tt_data.Iterator(chosen_set,
                               batch_size=batch_size,
                               sort_key=lambda example: len(example.Text),
                               repeat=repeat_flag,
                               shuffle=shuffle_flag,
                               sort=False,
                               sort_within_batch=True,
                               device=self.device)
    self.env_data_iter = iter(batches)
    self.env_data_set = chosen_set
    print("Set environment data iterator, data_type='{}', train_mode={}, batch_size={}".format(
        data_type, train_mode, batch_size))
示例6: test_subword_trec
# 需要导入模块: from torchtext import data [as 别名]
# 或者: from torchtext.data import Iterator [as 别名]
def test_subword_trec(self):
    """Round-trip a TREC question through SubwordField segmentation.

    The TREC split (``train=None``, so ``splits`` returns a 1-tuple) is
    loaded twice: once with a raw, non-vocab text field and once with a
    ``SubwordField``. After building vocabularies and segmenting, reversing
    the first un-shuffled batch must reproduce the first raw example's text.
    """
    subword_text = data.SubwordField()
    label_field = data.Field(sequential=False)
    raw_text = data.Field(sequential=False, use_vocab=False)
    (raw,) = TREC.splits(raw_text, label_field, train=None)
    (cooked,) = TREC.splits(subword_text, label_field, train=None)
    label_field.build_vocab(cooked)
    subword_text.build_vocab(cooked, max_size=100)
    subword_text.segment(cooked)
    print(cooked[0].text)
    loader = data.Iterator(cooked, 1, shuffle=False, device=-1)
    first_batch = next(iter(loader))
    decoded = subword_text.reverse(first_batch.text.data)
    self.assertEqual(decoded[0], raw[0].text)
示例7: get_iterator
# 需要导入模块: from torchtext import data [as 别名]
# 或者: from torchtext.data import Iterator [as 别名]
def get_iterator(self, dataset, batch_size=DEFAULT_CONFIG['batch_size'], device=DEVICE,
                 sort_key=lambda x: len(x.text)):
    """Wrap ``dataset`` in a torchtext ``Iterator``.

    Note that the ``batch_size`` and ``device`` defaults are read from
    ``DEFAULT_CONFIG`` / ``DEVICE`` once, when this function is defined, not
    on every call.

    Args:
        dataset: the torchtext dataset to batch.
        batch_size: examples per batch.
        device: device the batches are placed on.
        sort_key: key used by the iterator when sorting examples (by default
            the length of ``example.text``).

    Returns:
        A new ``Iterator`` over ``dataset``.
    """
    iterator_options = dict(batch_size=batch_size, device=device, sort_key=sort_key)
    return Iterator(dataset, **iterator_options)
示例8: get_batch_iter
# 需要导入模块: from torchtext import data [as 别名]
# 或者: from torchtext.data import Iterator [as 别名]
def get_batch_iter(self, batch_size: int):
    """Yield one torchtext ``Iterator`` per dataset in ``self.dataset``.

    Each iterator is built with ``train=True`` and ``repeat=False``. It uses
    the length of each example's ``sentence`` attribute as its sort key, and
    places batches on ``self.device``.

    Args:
        batch_size: number of examples per batch.

    Yields:
        A ``data.Iterator`` for each dataset, in ``self.dataset`` order.
    """
    def _sentence_length(example: data.Example) -> int:
        # The sort key is called with a single Example, not a Dataset. The
        # parameter is not named ``data``, so it does not shadow the module.
        return len(example.sentence)

    for dataset in self.dataset:
        yield data.Iterator(dataset=dataset,
                            batch_size=batch_size,
                            sort_key=_sentence_length,
                            train=True,
                            repeat=False,
                            device=self.device)
示例9: make_data_iter
# 需要导入模块: from torchtext import data [as 别名]
# 或者: from torchtext.data import Iterator [as 别名]
def make_data_iter(dataset: Dataset,
                   batch_size: int,
                   batch_type: str = "sentence",
                   train: bool = False,
                   shuffle: bool = False) -> Iterator:
    """
    Build a torchtext ``BucketIterator`` over ``dataset``.

    :param dataset: torchtext dataset whose examples have ``src`` (and
        optionally ``trg``)
    :param batch_size: batch size, counted in sentences or tokens depending
        on ``batch_type``
    :param batch_type: ``"token"`` to measure batches with
        ``token_batch_size_fn``; any other value counts sentences
    :param train: training mode. Only then are examples sorted by ``src``
        length within each batch and the ``shuffle`` flag honoured
    :param shuffle: shuffle the data each epoch (ignored when ``train`` is
        False)
    :return: a single-pass (``repeat=False``) torchtext iterator
    """
    batch_size_fn = None
    if batch_type == "token":
        batch_size_fn = token_batch_size_fn

    if not train:
        # Validation / inference: no global sort, no within-batch sort, and
        # no shuffle argument is passed.
        return data.BucketIterator(
            dataset=dataset,
            batch_size=batch_size,
            batch_size_fn=batch_size_fn,
            repeat=False,
            train=False,
            sort=False)

    # Training: no global sort, but each batch is sorted by source length.
    return data.BucketIterator(
        dataset=dataset,
        batch_size=batch_size,
        batch_size_fn=batch_size_fn,
        repeat=False,
        train=True,
        sort=False,
        sort_within_batch=True,
        sort_key=lambda example: len(example.src),
        shuffle=shuffle)
示例10: test_batch_with_missing_field
# 需要导入模块: from torchtext import data [as 别名]
# 或者: from torchtext.data import Iterator [as 别名]
def test_batch_with_missing_field(self):
    """Smoke test: a batch from a dataset with an ignored column must stringify.

    The CSV has ``text`` and ``label`` columns, but ``label`` is mapped to
    ``None`` (not loaded). Calling ``str`` on the first batch must not raise.
    """
    csv_path = self.test_missing_field_dataset_path
    with open(csv_path, "wt") as handle:
        handle.write("text,label\n1,0")
    text_field = data.Field(use_vocab=False, sequential=False)
    dataset = data.TabularDataset(path=csv_path,
                                  format="csv",
                                  skip_header=True,
                                  fields=[("text", text_field),
                                          ("label", None)])
    batches = iter(data.Iterator(dataset, batch_size=64))
    str(next(batches))
示例11: test_subword_trec
# 需要导入模块: from torchtext import data [as 别名]
# 或者: from torchtext.data import Iterator [as 别名]
def test_subword_trec(self):
    """Round-trip a TREC question through SubwordField segmentation.

    ``TREC.splits`` returns two splits here. The first split is loaded
    twice: once with a raw, non-vocab text field and once with a
    ``SubwordField``. After building vocabularies and segmenting, reversing
    the first un-shuffled batch must reproduce the first raw example's text.
    """
    subword_text = data.SubwordField()
    label_field = data.Field(sequential=False)
    raw_text = data.Field(sequential=False, use_vocab=False)
    raw = TREC.splits(raw_text, label_field)[0]
    cooked = TREC.splits(subword_text, label_field)[0]
    label_field.build_vocab(cooked)
    subword_text.build_vocab(cooked, max_size=100)
    subword_text.segment(cooked)
    print(cooked[0].text)
    loader = data.Iterator(cooked, 1, shuffle=False)
    first_batch = next(iter(loader))
    decoded = subword_text.reverse(first_batch.text.data)
    self.assertEqual(decoded[0], raw[0].text)
示例12: init_heldout_test_set
# 需要导入模块: from torchtext import data [as 别名]
# 或者: from torchtext.data import Iterator [as 别名]
def init_heldout_test_set(self, heldout_test_file_name='nyt_heldout_test.csv',
                          heldout_entitypair_file_name='nyt_heldout_test_entitypair.csv'):
    """Load the held-out test split, its entity pairs and an iterator over it.

    Both files are looked up in the directory that contains
    ``config['test_file']``. The file names used to be hard-coded; they are
    now keyword arguments whose defaults keep the previous behaviour.

    Args:
        heldout_test_file_name: header-less CSV with the same five columns
            as the test set (Id, Text, Pos1, Pos2, Label).
        heldout_entitypair_file_name: header-less CSV with exactly four
            columns: span1_guid, span2_guid, span1, span2.

    Side effects:
        Sets ``self.heldout_test_set``, ``self.heldout_entity_pairs`` (a list
        of ``(span1_guid, span2_guid)`` tuples in file order) and
        ``self.heldout_test_iter``.
    """
    data_dir_path = os.path.dirname(self.config['test_file'])
    heldout_test_file_path = os.path.join(data_dir_path, heldout_test_file_name)
    heldout_test_entitypair_fp = os.path.join(data_dir_path, heldout_entitypair_file_name)

    def read_entity_pair_info(entitypair_file_path):
        """Return one ``(span1_guid, span2_guid)`` tuple per CSV row."""
        tmp_df = pd.read_csv(entitypair_file_path, header=None)
        # Assigning four names raises if the file has a different column count.
        tmp_df.columns = ['span1_guid', 'span2_guid', 'span1', 'span2']
        return [(record['span1_guid'], record['span2_guid'])
                for record in tmp_df.to_dict(orient='records')]

    print('Loading heldout test set {}'.format(heldout_test_file_path))
    self.heldout_test_set = tt_data.TabularDataset(path=heldout_test_file_path,
                                                   format='csv',
                                                   fields=[('Id', self.ID),
                                                           ('Text', self.TEXT),
                                                           ('Pos1', self.POS),
                                                           ('Pos2', self.POS),
                                                           ('Label', self.LABEL)],
                                                   skip_header=False)
    self.heldout_entity_pairs = read_entity_pair_info(heldout_test_entitypair_fp)
    self.heldout_test_iter = tt_data.Iterator(self.heldout_test_set,
                                              sort_key=lambda x: len(x.Text),
                                              batch_size=self.config['test_batch_size'],
                                              train=False,
                                              repeat=False,
                                              sort_within_batch=True,
                                              device=self.device)
示例13: __init__
# 需要导入模块: from torchtext import data [as 别名]
# 或者: from torchtext.data import Iterator [as 别名]
def __init__(self, root_dir='data', batch_size=64, use_vector=True):
    """Load the train/dev/test TSV splits and build one iterator per split.

    Each ``{root_dir}/{split}.tsv`` file is read with two fields, ``text``
    and ``label``. The text and label vocabularies are built once, from the
    training split, and shared by all three iterators.

    Args:
        root_dir: directory holding ``train.tsv``, ``dev.tsv`` and ``test.tsv``.
        batch_size: examples per batch for every split.
        use_vector: if True, attach the vectors from ``Vectors(name=
            'mr_vocab.txt', cache='./')`` to the text vocabulary. The vector
            file is only loaded in this case.
    """
    self.TEXT = Field(sequential=True, use_vocab=True,
                      tokenize='spacy', lower=True, batch_first=True)
    self.LABEL = LabelField(tensor_type=torch.FloatTensor)
    dataset_path = os.path.join(root_dir, '{}.tsv')
    splits = ['train', 'dev', 'test']
    self.dataset = {}
    self.dataloader = {}
    for target in splits:
        self.dataset[target] = TabularDataset(
            path=dataset_path.format(target),
            format='tsv',
            fields=[('text', self.TEXT), ('label', self.LABEL)]
        )

    # TEXT and LABEL are shared by every split, and each build_vocab call
    # replaces the field's vocab. Building per split would leave all
    # iterators numericalizing with the last (test) split's vocabulary.
    # Build once, from train only.
    train_set = self.dataset['train']
    if use_vector:
        vectors = Vectors(name='mr_vocab.txt', cache='./')
        self.TEXT.build_vocab(train_set, max_size=25000, vectors=vectors)
    else:
        self.TEXT.build_vocab(train_set, max_size=25000)
    self.LABEL.build_vocab(train_set)

    for target in splits:
        self.dataloader[target] = Iterator(self.dataset[target],
                                           batch_size=batch_size,
                                           device=None,
                                           repeat=False,
                                           sort_key=lambda x: len(x.text),
                                           shuffle=True)
示例14: test_csv_file_with_header
# 需要导入模块: from torchtext import data [as 别名]
# 或者: from torchtext.data import Iterator [as 别名]
def test_csv_file_with_header(self):
    """Header-mapped TabularDataset: column lookup, parsing and vocab counts.

    The same three rows (header + two examples) are written as CSV and then
    as TSV. For each format the test checks four things:
    - mapping a column absent from the header raises ``ValueError``;
    - parsed examples match the written rows (text lower-cased and split);
    - the header row is not counted in the vocabulary (issue #225);
    - an iterator over the dataset yields a batch.
    """
    example_with_header = [("text", "label"),
                           ("HELLO WORLD", "0"),
                           ("goodbye world", "1")]
    text_field = data.Field(lower=True, tokenize=lambda x: x.split())
    fields = {
        "label": ("label", data.Field(use_vocab=False, sequential=False)),
        "text": ("text", text_field),
    }
    for format_, delim in (("csv", ","), ("tsv", "\t")):
        with open(self.test_has_header_dataset_path, "wt") as out:
            out.writelines("{}\n".format(delim.join(row))
                           for row in example_with_header)

        # Naming a column that does not exist in the header must be rejected.
        with self.assertRaises(ValueError):
            data.TabularDataset(
                path=self.test_has_header_dataset_path, format=format_,
                fields={"non_existent": ("label", data.Field())})

        dataset = data.TabularDataset(
            path=self.test_has_header_dataset_path, format=format_,
            skip_header=False, fields=fields)
        text_field.build_vocab(dataset)

        for row_idx, example in enumerate(dataset, start=1):
            raw_text, raw_label = example_with_header[row_idx]
            self.assertEqual(example.text, raw_text.lower().split())
            self.assertEqual(example.label, raw_label)

        # The header token "text" must not leak into the vocabulary (#225).
        expected_freqs = {"hello": 1, "world": 2, "goodbye": 1, "text": 0}
        for token, count in expected_freqs.items():
            self.assertEqual(text_field.vocab.freqs[token], count)

        data_iter = data.Iterator(dataset, batch_size=1,
                                  sort_within_batch=False, repeat=False)
        next(iter(data_iter))