当前位置: 首页>>代码示例>>Python>>正文


Python data.Iterator方法代码示例

本文整理汇总了Python中torchtext.data.Iterator方法的典型用法代码示例。如果您正苦于以下问题：Python data.Iterator方法的具体用法？Python data.Iterator怎么用？Python data.Iterator使用的例子？那么恭喜您，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块torchtext.data的用法示例。


在下文中一共展示了data.Iterator方法的14个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: test_batch_iter

# 需要导入模块: from torchtext import data [as 别名]
# 或者: from torchtext.data import Iterator [as 别名]
def test_batch_iter(self):
        """Check that a batch unpacks into ``((x1, x2), y)`` with the expected values.

        Builds a three-column TSV dataset (float, int target, text), takes the
        first batch of size 2 without shuffling, and compares the target and
        float columns against fixed values.
        """
        # Helper defined elsewhere on the test case; presumably it writes the
        # TSV file read below — confirm against the fixture.
        self.write_test_numerical_features_dataset()
        FLOAT = data.Field(use_vocab=False, sequential=False,
                           dtype=torch.float)
        INT = data.Field(use_vocab=False, sequential=False, is_target=True)
        TEXT = data.Field(sequential=False)

        dst = data.TabularDataset(path=self.test_numerical_features_dataset_path,
                                  format="tsv", skip_header=False,
                                  fields=[("float", FLOAT),
                                          ("int", INT),
                                          ("text", TEXT)])
        TEXT.build_vocab(dst)
        itr = data.Iterator(dst, batch_size=2, device=-1, shuffle=False)
        # Names of the non-target fields, used to find which of the two
        # unpacked inputs is the float column.
        fld_order = [k for k, v in dst.fields.items() if
                     v is not None and not v.is_target]
        batch = next(iter(itr))
        (x1, x2), y = batch
        x = (x1, x2)[fld_order.index("float")]
        # assertEqual replaces the deprecated assertEquals alias, which was
        # removed from unittest in Python 3.12.
        self.assertEqual(y.data[0], 1)
        self.assertEqual(y.data[1], 12)
        self.assertAlmostEqual(x.data[0], 0.1, places=4)
        self.assertAlmostEqual(x.data[1], 0.5, places=4)
开发者ID:pytorch,项目名称:text,代码行数:25,代码来源:test_batch.py

示例2: init_train_set

# 需要导入模块: from torchtext import data [as 别名]
# 或者: from torchtext.data import Iterator [as 别名]
def init_train_set(self):
        """Load the training CSV and create its batch iterator.

        Calls ``set_all_random_seed`` with ``self.config['random_seed']``,
        reads the file named by ``self.config['train_file']`` and stores the
        results on ``self.train_set`` and ``self.train_iter``.
        """
        set_all_random_seed(self.config['random_seed'])
        csv_path = self.config['train_file']
        print('Loading train set from {}'.format(csv_path))

        # Column layout of the CSV; both position columns share the POS field.
        columns = [('Id', self.ID),
                   ('Text', self.TEXT),
                   ('Pos1', self.POS),
                   ('Pos2', self.POS),
                   ('Label', self.TRAIN_LABEL)]
        self.train_set = tt_data.TabularDataset(path=csv_path,
                                                format='csv',
                                                fields=columns,
                                                skip_header=False)

        def text_length(example):
            # Sort key: length of the example's Text attribute.
            return len(example.Text)

        self.train_iter = tt_data.Iterator(self.train_set,
                                           batch_size=self.config['train_batch_size'],
                                           sort_key=text_length,
                                           sort_within_batch=True,
                                           train=True,
                                           repeat=False,
                                           device=self.device)
开发者ID:thunlp,项目名称:DIAG-NRE,代码行数:21,代码来源:relation_task.py

示例3: init_dev_set

# 需要导入模块: from torchtext import data [as 别名]
# 或者: from torchtext.data import Iterator [as 别名]
def init_dev_set(self):
        """Load the development CSV and create its batch iterator.

        Reads the file named by ``self.config['dev_file']`` and stores the
        results on ``self.dev_set`` and ``self.dev_iter``. The iterator uses
        ``self.config['test_batch_size']`` as its batch size.
        """
        csv_path = self.config['dev_file']
        print('Loading dev set from {}'.format(csv_path))

        # Column layout of the CSV; both position columns share the POS field.
        columns = [('Id', self.ID),
                   ('Text', self.TEXT),
                   ('Pos1', self.POS),
                   ('Pos2', self.POS),
                   ('Label', self.LABEL)]
        self.dev_set = tt_data.TabularDataset(path=csv_path,
                                              format='csv',
                                              fields=columns,
                                              skip_header=False)

        def text_length(example):
            # Sort key: length of the example's Text attribute.
            return len(example.Text)

        self.dev_iter = tt_data.Iterator(self.dev_set,
                                         batch_size=self.config['test_batch_size'],
                                         sort_key=text_length,
                                         sort_within_batch=True,
                                         train=False,
                                         repeat=False,
                                         device=self.device)
开发者ID:thunlp,项目名称:DIAG-NRE,代码行数:20,代码来源:relation_task.py

示例4: init_test_set

# 需要导入模块: from torchtext import data [as 别名]
# 或者: from torchtext.data import Iterator [as 别名]
def init_test_set(self):
        """Load the test CSV and create its batch iterator.

        Reads the file named by ``self.config['test_file']`` and stores the
        results on ``self.test_set`` and ``self.test_iter``.
        """
        csv_path = self.config['test_file']
        print('Loading test set {}'.format(csv_path))

        # Column layout of the CSV; both position columns share the POS field.
        columns = [('Id', self.ID),
                   ('Text', self.TEXT),
                   ('Pos1', self.POS),
                   ('Pos2', self.POS),
                   ('Label', self.LABEL)]
        self.test_set = tt_data.TabularDataset(path=csv_path,
                                               format='csv',
                                               fields=columns,
                                               skip_header=False)

        def text_length(example):
            # Sort key: length of the example's Text attribute.
            return len(example.Text)

        self.test_iter = tt_data.Iterator(self.test_set,
                                          batch_size=self.config['test_batch_size'],
                                          sort_key=text_length,
                                          sort_within_batch=True,
                                          train=False,
                                          repeat=False,
                                          device=self.device)
开发者ID:thunlp,项目名称:DIAG-NRE,代码行数:20,代码来源:relation_task.py

示例5: set_data_iter

# 需要导入模块: from torchtext import data [as 别名]
# 或者: from torchtext.data import Iterator [as 别名]
def set_data_iter(self, data_type='train', train_mode=True, batch_size=1):
        """Select one data split and build a fresh iterator over it.

        Stores ``iter(...)`` of the new iterator on ``self.env_data_iter`` and
        the chosen dataset on ``self.env_data_set``.

        :param data_type: which split to use: 'train', 'dev' or 'test'
        :param train_mode: when truthy the iterator repeats and shuffles,
            otherwise it does neither
        :param batch_size: batch size handed to the iterator (default 1)
        :raises ValueError: if ``data_type`` is not a supported split name
        """
        split_attrs = (('train', 'train_set'),
                       ('dev', 'dev_set'),
                       ('test', 'test_set'))
        for split_name, attr_name in split_attrs:
            if data_type == split_name:
                # Only the requested split attribute is looked up.
                chosen_set = getattr(self, attr_name)
                break
        else:
            raise ValueError('Unsupported data_type value {}, must be in [train, dev, test]'.format(data_type))

        # Repeat and shuffle are both on in training mode and both off
        # otherwise; dataset-wide sorting is off in either mode.
        repeat_and_shuffle = True if train_mode else False

        iterator = tt_data.Iterator(chosen_set,
                                    batch_size=batch_size,
                                    sort_key=lambda x: len(x.Text),
                                    repeat=repeat_and_shuffle,
                                    shuffle=repeat_and_shuffle,
                                    sort=False,
                                    sort_within_batch=True,
                                    device=self.device)
        self.env_data_iter = iter(iterator)
        self.env_data_set = chosen_set

        print("Set environment data iterator, data_type='{}', train_mode={}, batch_size={}".format(
            data_type, train_mode, batch_size))
开发者ID:thunlp,项目名称:DIAG-NRE,代码行数:25,代码来源:agent_task.py

示例6: test_subword_trec

# 需要导入模块: from torchtext import data [as 别名]
# 或者: from torchtext.data import Iterator [as 别名]
def test_subword_trec(self):
        """Round-trip the first TREC example through subword segmentation.

        Loads the split twice (once with a raw field, once with a subword
        field), segments the subword copy, and asserts that reversing the
        first batch reproduces the raw text of the first example.
        """
        subword_field = data.SubwordField()
        label_field = data.Field(sequential=False)
        raw_field = data.Field(sequential=False, use_vocab=False)

        # train=None is passed, and splits() is unpacked as a one-element tuple.
        (raw,) = TREC.splits(raw_field, label_field, train=None)
        (cooked,) = TREC.splits(subword_field, label_field, train=None)

        label_field.build_vocab(cooked)
        subword_field.build_vocab(cooked, max_size=100)
        subword_field.segment(cooked)
        print(cooked[0].text)

        iterator = data.Iterator(cooked, 1, shuffle=False, device=-1)
        batch = next(iter(iterator))
        restored = subword_field.reverse(batch.text.data)
        self.assertEqual(restored[0], raw[0].text)
开发者ID:salesforce,项目名称:decaNLP,代码行数:14,代码来源:test_subword.py

示例7: get_iterator

# 需要导入模块: from torchtext import data [as 别名]
# 或者: from torchtext.data import Iterator [as 别名]
def get_iterator(self, dataset, batch_size=DEFAULT_CONFIG['batch_size'], device=DEVICE,
                     sort_key=lambda x: len(x.text)):
        """Wrap ``dataset`` in an ``Iterator`` and return it.

        :param dataset: dataset passed through as the iterator's first argument
        :param batch_size: batch size; the default is read from
            ``DEFAULT_CONFIG['batch_size']`` once, when the function is defined
        :param device: device passed to the iterator; defaults to ``DEVICE``
        :param sort_key: key function applied to examples; the default returns
            the length of an example's ``text`` attribute
        :return: the constructed ``Iterator``
        """
        return Iterator(dataset, batch_size=batch_size, device=device, sort_key=sort_key) 
开发者ID:smilelight,项目名称:lightNLP,代码行数:5,代码来源:tool.py

示例8: get_batch_iter

# 需要导入模块: from torchtext import data [as 别名]
# 或者: from torchtext.data import Iterator [as 别名]
def get_batch_iter(self, batch_size: int):
        """Yield one non-repeating training iterator per dataset in ``self.dataset``.

        :param batch_size: batch size passed to every iterator
        :return: generator of ``data.Iterator`` objects, one per dataset, each
            created with ``train=True``, ``repeat=False`` and ``self.device``
        """

        def sentence_length(example) -> int:
            # A sort key is applied to one example at a time, so the argument
            # is an example, not a dataset. Naming it ``example`` also avoids
            # shadowing the ``data`` module used below.
            return len(example.sentence)

        for dataset in self.dataset:
            yield data.Iterator(dataset=dataset,
                                batch_size=batch_size,
                                sort_key=sentence_length,
                                train=True,
                                repeat=False,
                                device=self.device
                                )
开发者ID:shibing624,项目名称:pycorrector,代码行数:15,代码来源:reader.py

示例9: make_data_iter

# 需要导入模块: from torchtext import data [as 别名]
# 或者: from torchtext.data import Iterator [as 别名]
def make_data_iter(dataset: Dataset,
                   batch_size: int,
                   batch_type: str = "sentence",
                   train: bool = False,
                   shuffle: bool = False) -> Iterator:
    """
    Build a non-repeating torchtext bucket iterator over a dataset.

    :param dataset: torchtext dataset containing src and optionally trg
    :param batch_size: size of the batches the iterator prepares
    :param batch_type: "token" measures batch size with
        ``token_batch_size_fn``; any other value passes no size function
    :param train: whether it's training time; when falsy, no sort key,
        within-batch sorting or shuffle flag is passed to the iterator
    :param shuffle: whether to shuffle the data before each epoch
        (only forwarded when ``train`` is truthy)
    :return: torchtext iterator
    """
    if batch_type == "token":
        size_fn = token_batch_size_fn
    else:
        size_fn = None

    if not train:
        # Validation/inference: no sorting, and shuffle is not forwarded.
        return data.BucketIterator(
            dataset=dataset,
            batch_size=batch_size,
            batch_size_fn=size_fn,
            repeat=False,
            train=False,
            sort=False)

    # Training: sort inside each batch by source length and shuffle on request.
    return data.BucketIterator(
        dataset=dataset,
        batch_size=batch_size,
        batch_size_fn=size_fn,
        repeat=False,
        train=True,
        sort=False,
        sort_within_batch=True,
        sort_key=lambda x: len(x.src),
        shuffle=shuffle)
开发者ID:joeynmt,项目名称:joeynmt,代码行数:37,代码来源:data.py

示例10: test_batch_with_missing_field

# 需要导入模块: from torchtext import data [as 别名]
# 或者: from torchtext.data import Iterator [as 别名]
def test_batch_with_missing_field(self):
        """Smoke test: a batch from a dataset with a ``None`` field can be stringified."""
        csv_path = self.test_missing_field_dataset_path
        with open(csv_path, "wt") as handle:
            handle.write("text,label\n1,0")

        text_field = data.Field(use_vocab=False, sequential=False)
        # The "label" column is mapped to None instead of a Field.
        column_spec = [("text", text_field), ("label", None)]
        dst = data.TabularDataset(path=csv_path, format="csv",
                                  skip_header=True, fields=column_spec)

        itr = data.Iterator(dst, batch_size=64)
        first_batch = next(iter(itr))
        # No assertion: the test passes if str() does not raise.
        str(first_batch)
开发者ID:pytorch,项目名称:text,代码行数:14,代码来源:test_batch.py

示例11: test_subword_trec

# 需要导入模块: from torchtext import data [as 别名]
# 或者: from torchtext.data import Iterator [as 别名]
def test_subword_trec(self):
        """Round-trip the first TREC training example through subword segmentation.

        Loads the splits twice (once with a raw field, once with a subword
        field), segments the subword copy, and asserts that reversing the
        first batch reproduces the raw text of the first example.
        """
        subword_field = data.SubwordField()
        label_field = data.Field(sequential=False)
        raw_field = data.Field(sequential=False, use_vocab=False)

        # Only the first of the two returned splits is used.
        raw, _ = TREC.splits(raw_field, label_field)
        cooked, _ = TREC.splits(subword_field, label_field)

        label_field.build_vocab(cooked)
        subword_field.build_vocab(cooked, max_size=100)
        subword_field.segment(cooked)
        print(cooked[0].text)

        iterator = data.Iterator(cooked, 1, shuffle=False)
        batch = next(iter(iterator))
        restored = subword_field.reverse(batch.text.data)
        self.assertEqual(restored[0], raw[0].text)
开发者ID:pytorch,项目名称:text,代码行数:14,代码来源:test_subword.py

示例12: init_heldout_test_set

# 需要导入模块: from torchtext import data [as 别名]
# 或者: from torchtext.data import Iterator [as 别名]
def init_heldout_test_set(self):
        """Load the held-out test CSV, its entity-pair file and a batch iterator.

        Both files are looked up in the directory of
        ``self.config['test_file']`` under fixed names. Results are stored on
        ``self.heldout_test_set``, ``self.heldout_entity_pairs`` and
        ``self.heldout_test_iter``.
        """
        # TODO: change this into input arguments
        base_dir = os.path.dirname(self.config['test_file'])
        heldout_csv = os.path.join(base_dir, 'nyt_heldout_test.csv')
        entitypair_csv = os.path.join(base_dir, 'nyt_heldout_test_entitypair.csv')

        def read_entity_pair_info(entitypair_file_path):
            # Header-less CSV with four columns; only the two guid columns
            # are kept, as (span1_guid, span2_guid) tuples in row order.
            frame = pd.read_csv(entitypair_file_path, header=None)
            frame.columns = ['span1_guid', 'span2_guid', 'span1', 'span2']
            return [(row['span1_guid'], row['span2_guid'])
                    for row in frame.to_dict(orient='records')]

        print('Loading heldout test set {}'.format(heldout_csv))
        # Column layout of the CSV; both position columns share the POS field.
        columns = [('Id', self.ID),
                   ('Text', self.TEXT),
                   ('Pos1', self.POS),
                   ('Pos2', self.POS),
                   ('Label', self.LABEL)]
        self.heldout_test_set = tt_data.TabularDataset(path=heldout_csv,
                                                       format='csv',
                                                       fields=columns,
                                                       skip_header=False)
        self.heldout_entity_pairs = read_entity_pair_info(entitypair_csv)

        def text_length(example):
            # Sort key: length of the example's Text attribute.
            return len(example.Text)

        self.heldout_test_iter = tt_data.Iterator(self.heldout_test_set,
                                                  batch_size=self.config['test_batch_size'],
                                                  sort_key=text_length,
                                                  sort_within_batch=True,
                                                  train=False,
                                                  repeat=False,
                                                  device=self.device)
开发者ID:thunlp,项目名称:DIAG-NRE,代码行数:35,代码来源:relation_task.py

示例13: __init__

# 需要导入模块: from torchtext import data [as 别名]
# 或者: from torchtext.data import Iterator [as 别名]
def __init__(self, root_dir='data', batch_size=64, use_vector=True):
        """Load the train/dev/test TSV splits and create one iterator per split.

        :param root_dir: directory containing ``train.tsv``, ``dev.tsv`` and
            ``test.tsv``
        :param batch_size: batch size used by every iterator
        :param use_vector: when truthy, attach the vectors loaded from
            ``mr_vocab.txt`` to the text vocabulary

        Populates ``self.TEXT``, ``self.LABEL``, ``self.dataset`` and
        ``self.dataloader``; the last two are dicts keyed by split name.
        """
        self.TEXT = Field(sequential=True, use_vocab=True,
                          tokenize='spacy', lower=True, batch_first=True)
        # NOTE(review): ``tensor_type`` looks like the keyword of older
        # torchtext releases (newer ones use ``dtype``) — confirm against the
        # pinned torchtext version before changing it.
        self.LABEL = LabelField(tensor_type=torch.FloatTensor)

        dataset_path = os.path.join(root_dir, '{}.tsv')
        splits = ('train', 'dev', 'test')

        self.dataset = {
            target: TabularDataset(
                path=dataset_path.format(target),
                format='tsv',
                fields=[('text', self.TEXT), ('label', self.LABEL)]
            )
            for target in splits
        }

        # TEXT and LABEL are shared by all three splits, so building the
        # vocabulary once per split would leave only the last split's ("test")
        # vocabulary in place. Build it once, from the training split only.
        train_set = self.dataset['train']
        if use_vector:
            # Load the vectors only when they are actually used.
            vectors = Vectors(name='mr_vocab.txt', cache='./')
            self.TEXT.build_vocab(train_set, max_size=25000, vectors=vectors)
        else:
            self.TEXT.build_vocab(train_set, max_size=25000)
        self.LABEL.build_vocab(train_set)

        self.dataloader = {
            target: Iterator(self.dataset[target],
                             batch_size=batch_size,
                             device=None,
                             repeat=False,
                             sort_key=lambda x: len(x.text),
                             shuffle=True)
            for target in splits
        }
开发者ID:slaysd,项目名称:pytorch-sentiment-analysis-classification,代码行数:29,代码来源:dataset.py

示例14: test_csv_file_with_header

# 需要导入模块: from torchtext import data [as 别名]
# 或者: from torchtext.data import Iterator [as 别名]
def test_csv_file_with_header(self):
        """Check dict-style ``fields`` against CSV and TSV files that have a header row.

        For each format the test writes the file, asserts that naming a
        column missing from the header raises ``ValueError``, compares every
        loaded example with the written rows, checks the token frequencies of
        the built vocabulary, and pulls one batch from an iterator.
        """
        example_with_header = [("text", "label"),
                               ("HELLO WORLD", "0"),
                               ("goodbye world", "1")]

        TEXT = data.Field(lower=True, tokenize=lambda x: x.split())
        label_field = data.Field(use_vocab=False, sequential=False)
        fields = {
            "label": ("label", label_field),
            "text": ("text", TEXT)
        }

        for format_, delim in (("csv", ","), ("tsv", "\t")):
            with open(self.test_has_header_dataset_path, "wt") as f:
                f.writelines("{}\n".format(delim.join(row))
                             for row in example_with_header)

            # check that an error is raised here if a non-existent field is specified
            with self.assertRaises(ValueError):
                data.TabularDataset(
                    path=self.test_has_header_dataset_path, format=format_,
                    fields={"non_existent": ("label", data.Field())})

            dataset = data.TabularDataset(
                path=self.test_has_header_dataset_path, format=format_,
                skip_header=False, fields=fields)

            TEXT.build_vocab(dataset)

            # Row 0 of example_with_header is the header, so examples are
            # compared with rows starting at index 1.
            for row_number, example in enumerate(dataset, start=1):
                expected_text, expected_label = example_with_header[row_number]
                self.assertEqual(example.text, expected_text.lower().split())
                self.assertEqual(example.label, expected_label)

            # check that the vocabulary is built correctly (#225)
            expected_freqs = {"hello": 1, "world": 2, "goodbye": 1, "text": 0}
            for token, count in expected_freqs.items():
                self.assertEqual(TEXT.vocab.freqs[token], count)

            data_iter = data.Iterator(dataset, batch_size=1,
                                      sort_within_batch=False, repeat=False)
            # Smoke check: fetching the first batch must not raise.
            next(iter(data_iter))
开发者ID:pytorch,项目名称:text,代码行数:44,代码来源:test_dataset.py


注：本文中的torchtext.data.Iterator方法示例由纯净天空整理自GitHub/MSDocs等开源代码及文档管理平台，相关代码片段筛选自各路编程大神贡献的开源项目，源码版权归原作者所有，传播和使用请参考对应项目的License；未经允许，请勿转载。