Python data.Dataset Code Examples

This article collects typical usage examples of torchtext.data.Dataset in Python. If you are wondering what exactly data.Dataset does, how to call it, or what it looks like in real code, the curated examples below should help. You can also browse further usage examples from the torchtext.data module.


The sections below present 15 code examples of data.Dataset, ordered by popularity by default.
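
Before the examples, here is a minimal, self-contained sketch of the construction pattern they all build on. It is written against the legacy torchtext data API (torchtext <= 0.8, later available as torchtext.legacy); the field names and sample sentences are illustrative:

from torchtext import data  # torchtext.legacy.data in torchtext >= 0.9

# A Field describes how one column of raw data is tokenized and numericalized.
TEXT = data.Field(sequential=True, tokenize=str.split, lower=True)
LABEL = data.Field(sequential=False, unk_token=None)

fields = [('text', TEXT), ('label', LABEL)]

# Raw values become Examples; Examples plus fields become a Dataset.
examples = [
    data.Example.fromlist(['a simple sentence', 'positive'], fields),
    data.Example.fromlist(['another short sentence', 'negative'], fields),
]
dataset = data.Dataset(examples, fields)

# Vocabularies are built per Field, from one or more Datasets.
TEXT.build_vocab(dataset)
LABEL.build_vocab(dataset)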

Example 1: splits

# Required import: from torchtext import data [as alias]
# Or: from torchtext.data import Dataset [as alias]
# Also needs: import os
def splits(cls, path, exts, fields, root='.data',
               train='train', validation='val', test='test', **kwargs):
        """Create dataset objects for splits of a TranslationDataset.
        Arguments:
            path: Common prefix of the paths to the data files for each split.
            root: Root dataset storage directory. Default is '.data'.
            exts: A tuple containing the extension to path for each language.
            fields: A tuple containing the fields that will be used for data
                in each language.
            train: The prefix of the train data. Default: 'train'.
            validation: The prefix of the validation data. Default: 'val'.
            test: The prefix of the test data. Default: 'test'.
            Remaining keyword arguments: Passed to the splits method of
                Dataset.
        """
        #path = cls.download(root)

        train_data = None if train is None else cls(
            os.path.join(path, train), exts, fields, **kwargs)
        val_data = None if validation is None else cls(
            os.path.join(path, validation), exts, fields, **kwargs)
        test_data = None if test is None else cls(
            os.path.join(path, test), exts, fields, **kwargs)
        return tuple(d for d in (train_data, val_data, test_data)
                     if d is not None) 
Author: nyu-dl, Project: dl4mt-nonauto, Lines of code: 26, Source file: data.py
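
For reference, a hedged usage sketch of this splits classmethod. It assumes the enclosing class behaves like torchtext's TranslationDataset and that parallel files such as train.en/train.de exist under the given path; the directory layout and tokenization below are assumptions, not part of the snippet:

from torchtext import data
from torchtext.datasets import TranslationDataset

SRC = data.Field(tokenize=str.split, init_token='<bos>', eos_token='<eos>')
TRG = data.Field(tokenize=str.split, init_token='<bos>', eos_token='<eos>')

# exts pairs each language with its file extension; fields pairs each with a Field.
train, val, test = TranslationDataset.splits(
    path='data/multi30k', exts=('.en', '.de'), fields=(SRC, TRG),
    train='train', validation='val', test='test')

SRC.build_vocab(train, min_freq=2)
TRG.build_vocab(train, min_freq=2)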

Example 2: prepare_dataloaders

# Required import: from torchtext import data [as alias]
# Or: from torchtext.data import Dataset [as alias]
# Also needs: import pickle; from torchtext.data import BucketIterator;
# plus the project's Constants module (for Constants.PAD_WORD)
def prepare_dataloaders(opt, device):
    batch_size = opt.batch_size
    data = pickle.load(open(opt.data_pkl, 'rb'))

    opt.max_token_seq_len = data['settings'].max_len
    opt.src_pad_idx = data['vocab']['src'].vocab.stoi[Constants.PAD_WORD]
    opt.trg_pad_idx = data['vocab']['trg'].vocab.stoi[Constants.PAD_WORD]

    opt.src_vocab_size = len(data['vocab']['src'].vocab)
    opt.trg_vocab_size = len(data['vocab']['trg'].vocab)

    #========= Preparing Model =========#
    if opt.embs_share_weight:
        assert data['vocab']['src'].vocab.stoi == data['vocab']['trg'].vocab.stoi, \
            'To share word embeddings, the src/trg word2idx tables must be the same.'

    fields = {'src': data['vocab']['src'], 'trg': data['vocab']['trg']}

    train = Dataset(examples=data['train'], fields=fields)
    val = Dataset(examples=data['valid'], fields=fields)

    train_iterator = BucketIterator(train, batch_size=batch_size, device=device, train=True)
    val_iterator = BucketIterator(val, batch_size=batch_size, device=device)

    return train_iterator, val_iterator 
Author: jadore801120, Project: attention-is-all-you-need-pytorch, Lines of code: 27, Source file: train.py
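
A sketch of calling prepare_dataloaders. The option names and the pickle layout follow the snippet above, but the Namespace construction and the .pkl filename are illustrative assumptions:

from argparse import Namespace
import torch

opt = Namespace(
    data_pkl='m30k_deen_shr.pkl',  # hypothetical path; must match the layout read above
    batch_size=128,
    embs_share_weight=False,
)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

train_iterator, val_iterator = prepare_dataloaders(opt, device)
for batch in train_iterator:
    src, trg = batch.src, batch.trg  # attribute names come from the fields dict
    break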

Example 3: extend_vocab

# Required import: from torchtext import data [as alias]
# Or: from torchtext.data import Dataset [as alias]
# Note: this method is defined on deepmatcher's MatchingField
def extend_vocab(self, *args, vectors=None, cache=None):
        sources = []
        for arg in args:
            if isinstance(arg, data.Dataset):
                sources += [
                    getattr(arg, name)
                    for name, field in arg.fields.items()
                    if field is self
                ]
            else:
                sources.append(arg)

        tokens = set()
        for source in sources:
            for x in source:
                if not self.sequential:
                    tokens.add(x)
                else:
                    tokens.update(x)

        if self.vocab.vectors is not None:
            vectors = MatchingField._get_vector_data(vectors, cache)
            self.vocab.extend_vectors(tokens, vectors) 
Author: anhaidgroup, Project: deepmatcher, Lines of code: 25, Source file: field.py
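
A hedged call sketch: field below is assumed to be a MatchingField whose vocab is already built, unlabeled a data.Dataset that uses it, and the vector file name and cache path are illustrative:

# Add tokens from a new dataset to the existing vocab, fetching embeddings
# for the new tokens from the given vector file.
field.extend_vocab(unlabeled, vectors='wiki.en.vec', cache='/tmp/vector_cache')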

Example 4: splits

# Required import: from torchtext import data [as alias]
# Or: from torchtext.data import Dataset [as alias]
# Also needs: import os
def splits(cls, text_field, label_field, root='./data',
               train='20news-bydate-train', test='20news-bydate-test',
               **kwargs):
        """Create dataset objects for splits of the 20news dataset.

        Arguments:
            text_field: The field that will be used for the sentence.
            label_field: The field that will be used for label data.

            train: The filename of the train data. Default: '20news-bydate-train'.
            test: The filename of the test data (unused in this snippet).
                Default: '20news-bydate-test'.
            Remaining keyword arguments: Passed to the splits method of
                Dataset.
        """

        path = cls.download_or_unzip(root)

        train_data = None if train is None else cls(
            text_field, label_field, os.path.join(path, train), 2000, **kwargs)

        dev_ratio = 0.1
        dev_index = -1 * int(dev_ratio * len(train_data))

        return (cls(text_field, label_field, examples=train_data[:dev_index]),
                cls(text_field, label_field, examples=train_data[dev_index:])) 
Author: xiaobaoonline, Project: pytorch-in-action, Lines of code: 26, Source file: mydatasets.py
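
A hedged call sketch; the class name NewsGroup is a stand-in for whatever class this method is defined on in the source project:

from torchtext import data

text_field = data.Field(lower=True)
label_field = data.Field(sequential=False)

# Returns (train_split, dev_split); the last 10% of the training file is held out.
train_ds, dev_ds = NewsGroup.splits(text_field, label_field, root='./data')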

Example 5: __init__

# Required import: from torchtext import data [as alias]
# Or: from torchtext.data import Dataset [as alias]
# Also needs: import os; import codecs
def __init__(self, text_field, label_field, path=None, examples=None, **kwargs):
        """Create an MR dataset instance given a path and fields.
        Arguments:
            text_field: The field that will be used for text data.
            label_field: The field that will be used for label data.
            path: Path to the data file.
            examples: The examples contain all the data.
            Remaining keyword arguments: Passed to the constructor of
                data.Dataset.
        """
        # text_field.preprocessing = data.Pipeline(clean_str)
        fields = [('text', text_field), ('label', label_field)]
        if examples is None:
            path = self.dirname if path is None else path
            examples = []
            with codecs.open(os.path.join(path, 'rt-polarity.neg'), 'r', 'utf-8') as f:
                examples += [
                    data.Example.fromlist([line, 'negative'], fields) for line in f]
            with codecs.open(os.path.join(path, 'rt-polarity.pos'), 'r', 'utf-8') as f:
                examples += [
                    data.Example.fromlist([line, 'positive'], fields) for line in f]
        super(MR, self).__init__(examples, fields, **kwargs) 
Author: malllabiisc, Project: DiPS, Lines of code: 24, Source file: classification_datasets.py
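
A minimal construction sketch, assuming rt-polarity.neg and rt-polarity.pos sit under the given (illustrative) path:

from torchtext import data

text_field = data.Field(lower=True)
label_field = data.Field(sequential=False)

# Every line of the two polarity files becomes one labeled Example.
mr = MR(text_field, label_field, path='data/rt-polaritydata')
print(len(mr.examples))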

Example 6: __init__

# Required import: from torchtext import data [as alias]
# Or: from torchtext.data import Dataset [as alias]
# Also needs: import io
def __init__(self, path, text_field, newline_eos=True,
                 encoding='utf-8', **kwargs):
        """Create a LanguageModelingDataset given a path and a field.

        Arguments:
            path: Path to the data file.
            text_field: The field that will be used for text data.
            newline_eos: Whether to add an <eos> token for every newline in the
                data file. Default: True.
            Remaining keyword arguments: Passed to the constructor of
                data.Dataset.
        """
        fields = [('text', text_field)]
        text = []
        with io.open(path, encoding=encoding) as f:
            for line in f:
                text += text_field.preprocess(line)
                if newline_eos:
                    text.append(u'<eos>')

        examples = [data.Example.fromlist([text], fields)]
        super(LanguageModelingDataset, self).__init__(
            examples, fields, **kwargs) 
Author: pytorch, Project: text, Lines of code: 25, Source file: language_modeling.py
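
Usage follows the stock torchtext pattern; the corpus path below is an assumption:

from torchtext import data
from torchtext.datasets import LanguageModelingDataset

TEXT = data.Field(lower=True)
# The whole corpus becomes a single Example holding one long token list.
lm_ds = LanguageModelingDataset('data/ptb.train.txt', TEXT)
TEXT.build_vocab(lm_ds)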

Example 7: test_build_vocab_from_dataset

# Required import: from torchtext import data [as alias]
# Or: from torchtext.data import Dataset [as alias]
# Also needs: from collections import Counter
def test_build_vocab_from_dataset(self):
        nesting_field = data.Field(tokenize=list, unk_token="<cunk>", pad_token="<cpad>",
                                   init_token="<w>", eos_token="</w>")
        CHARS = data.NestedField(nesting_field, init_token="<s>", eos_token="</s>")
        ex1 = data.Example.fromlist(["aaa bbb c"], [("chars", CHARS)])
        ex2 = data.Example.fromlist(["bbb aaa"], [("chars", CHARS)])
        dataset = data.Dataset([ex1, ex2], [("chars", CHARS)])

        CHARS.build_vocab(dataset, min_freq=2)

        expected = "a b <w> </w> <s> </s> <cunk> <cpad>".split()
        assert len(CHARS.vocab) == len(expected)
        for c in expected:
            assert c in CHARS.vocab.stoi

        expected_freqs = Counter({"a": 6, "b": 6, "c": 1})
        assert CHARS.vocab.freqs == CHARS.nesting_field.vocab.freqs == expected_freqs 
Author: pytorch, Project: text, Lines of code: 19, Source file: test_field.py

Example 8: __iter__

# Required import: from torchtext import data [as alias]
# Or: from torchtext.data import Dataset [as alias]
# Also needs: import math; from torchtext.data import Batch, Dataset
def __iter__(self):
        # Flatten the single-example corpus and pad it so its length is a
        # multiple of batch_size, then reshape to (batch_size, -1) so that
        # each row is a contiguous stream of tokens.
        text = self.dataset[0].text
        TEXT = self.dataset.fields['text']
        TEXT.eos_token = None
        pad_num = int(math.ceil(len(text) / self.batch_size) *
                      self.batch_size - len(text))
        text = text + ([TEXT.pad_token] * pad_num)
        data = TEXT.numericalize([text], device=self.device)
        data = data.view(self.batch_size, -1).contiguous()
        dataset = Dataset(examples=self.dataset.examples,
                          fields=[('text', TEXT), ('target', TEXT)])
        while True:
            # Yield BPTT-length windows; the target is the input shifted by one.
            for i in range(0, len(self) * self.bptt_len, self.bptt_len):
                self.iterations += 1
                seq_len = self.bptt_len
                yield Batch.fromvars(
                    dataset, self.batch_size,
                    text=data[:, i:i + seq_len],
                    target=data[:, i + 1:i + 1 + seq_len])
            if not self.repeat:
                return 
Author: asyml, Project: texar, Lines of code: 23, Source file: batchfirst_bptt.py

Example 9: __init__

# Required import: from torchtext import data [as alias]
# Or: from torchtext.data import Dataset [as alias]
# Also needs: from typing import Callable, List, Optional, Union;
# plus the project's HierarchicalIterator
def __init__(self, dataset: Dataset, batch_size: int, target_names: Optional[List[str]] = None,
                 sort_key: Union[Callable, str] = "sl", max_context_size: int = 130000, backwards=False,
                 **kwargs):
        self.dataset = dataset
        target_names = [target_names] if isinstance(target_names, str) else target_names
        # resolve the string shortcuts ("cl"/"sl") into sort-key functions
        if sort_key == "cl":
            def sort_key(x):
                """Sort examples by conversation length (number of roles)."""
                return len(x.roles)
        elif sort_key == 'sl':
            def sort_key(x):
                """Sort examples by their longest utterance."""
                return max(x.sl)
        else:
            assert callable(sort_key), "sort_key provided is not a function"
        self.dl = HierarchicalIterator(dataset, batch_size=batch_size, sort_key=sort_key, target_roles=target_names,
                                       max_context_size=max_context_size, **kwargs)
        self.bs = batch_size
        self.iter = 0 
Author: outcastofmusic, Project: quick-nlp, Lines of code: 22, Source file: torchtext_data_loaders.py

Example 10: splits

# Required import: from torchtext import data [as alias]
# Or: from torchtext.data import Dataset [as alias]
# Also needs: import random
def splits(cls, text_field, label_field, dev_ratio=.1, shuffle=True, root='.', **kwargs):
        """Create dataset objects for splits of the MR dataset.

        Arguments:
            text_field: The field that will be used for the sentence.
            label_field: The field that will be used for label data.
            dev_ratio: The fraction of examples held out as the validation split.
            shuffle: Whether to shuffle the data before splitting.
            root: The root directory into which the dataset's zip archive is
                expanded; the data files end up in a subdirectory of it.
            Remaining keyword arguments: Passed to the constructor of the
                dataset class.
        """
        path = cls.download_or_unzip(root)
        examples = cls(text_field, label_field, path=path, **kwargs).examples
        if shuffle:
            random.shuffle(examples)
        dev_index = -1 * int(dev_ratio * len(examples))

        return (cls(text_field, label_field, examples=examples[:dev_index]),
                cls(text_field, label_field, examples=examples[dev_index:])) 
Author: Shawn1993, Project: cnn-text-classification-pytorch, Lines of code: 24, Source file: mydatasets.py
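
A hedged call sketch for this variant of MR.splits; the field setup is illustrative:

from torchtext import data

text_field = data.Field(lower=True)
label_field = data.Field(sequential=False)

# 90/10 train/dev split over the shuffled MR examples.
train_ds, dev_ds = MR.splits(text_field, label_field, dev_ratio=0.1, shuffle=True)

text_field.build_vocab(train_ds, dev_ds)
label_field.build_vocab(train_ds, dev_ds)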

Example 11: splits

# Required import: from torchtext import data [as alias]
# Or: from torchtext.data import Dataset [as alias]
# Also needs: import random
def splits(cls, text_field, label_field, dev_ratio=.1, shuffle=True, root='.', **kwargs):
        """Create dataset objects for splits of the MR dataset.

        Arguments:
            text_field: The field that will be used for the sentence.
            label_field: The field that will be used for label data.
            dev_ratio: The fraction of examples held out as the validation split.
            shuffle: Whether to shuffle the data before splitting.
            root: The root directory into which the dataset's zip archive is
                expanded; the data files end up in a subdirectory of it.
            Remaining keyword arguments: Passed to the constructor of the
                dataset class.
        """
        path = cls.download_or_unzip(root)
        examples = cls(text_field, label_field, path=path, **kwargs).examples
        if shuffle:
            random.shuffle(examples)
        dev_index = -1 * int(dev_ratio * len(examples))

        return (cls(text_field, label_field, examples=examples[:dev_index]),
                cls(text_field, label_field, examples=examples[dev_index:])) 
Author: srviest, Project: char-cnn-text-classification-pytorch, Lines of code: 24, Source file: mydatasets.py

Example 12: build_vocab

# Required import: from torchtext import data [as alias]
# Or: from torchtext.data import Dataset [as alias]
# Also needs: from collections import Counter, OrderedDict; from itertools import chain
def build_vocab(self, *args, **kwargs):
        """Add unaligned_token to the list of special symbols."""
        counter = Counter()
        sources = []
        for arg in args:
            if isinstance(arg, data.Dataset):
                sources += [
                    getattr(arg, name)
                    for name, field in arg.fields.items()
                    if field is self
                ]
            else:
                sources.append(arg)
        for sample in sources:
            for x in sample:
                if not self.sequential:
                    x = [x]
                try:
                    counter.update(x)
                except TypeError:
                    # x is a nested sequence (e.g. a list of token lists)
                    counter.update(chain.from_iterable(x))
        specials = list(
            OrderedDict.fromkeys(
                tok
                for tok in [
                    self.unk_token,
                    self.pad_token,
                    self.init_token,
                    self.eos_token,
                    self.unaligned_token,
                ]
                if tok is not None
            )
        )
        self.vocab = self.vocab_cls(counter, specials=specials, **kwargs) 
Author: Unbabel, Project: OpenKiwi, Lines of code: 37, Source file: qe_field.py
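
A hedged usage sketch: field is assumed to be an instance of the QE field class this method belongs to (with an unaligned_token attribute), and train_ds a data.Dataset built with it:

# The overridden build_vocab registers unaligned_token among the specials,
# so it is guaranteed an entry in the resulting vocabulary.
field.unaligned_token = '<unaligned>'
field.build_vocab(train_ds, min_freq=1)
assert field.unaligned_token in field.vocab.stoi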

Example 13: get_iterator

# Required import: from torchtext import data [as alias]
# Or: from torchtext.data import Dataset [as alias]
# Also needs: from torchtext.data import BucketIterator;
# plus the project's DEFAULT_CONFIG and DEVICE
def get_iterator(self, dataset: Dataset, batch_size=DEFAULT_CONFIG['batch_size'], device=DEVICE,
                 sort_key=lambda x: len(x.texta)):
        return BucketIterator(dataset, batch_size=batch_size, device=device, sort_key=sort_key) 
Author: smilelight, Project: lightNLP, Lines of code: 5, Source file: tool.py

Example 14: get_iterator

# Required import: from torchtext import data [as alias]
# Or: from torchtext.data import Dataset [as alias]
# Also needs: from torchtext.data import BucketIterator;
# plus the project's DEFAULT_CONFIG and DEVICE
def get_iterator(self, dataset: Dataset, batch_size=DEFAULT_CONFIG['batch_size'], device=DEVICE,
                     sort_key=lambda x: len(x.text), sort_within_batch=True):
        return BucketIterator(dataset, batch_size=batch_size, device=device, sort_key=sort_key,
                              sort_within_batch=sort_within_batch) 
Author: smilelight, Project: lightNLP, Lines of code: 6, Source file: tool.py

Example 15: get_iterator

# Required import: from torchtext import data [as alias]
# Or: from torchtext.data import Dataset [as alias]
# Also needs: from torchtext.data import BucketIterator;
# plus the project's DEFAULT_CONFIG and DEVICE
def get_iterator(self, dataset: Dataset, batch_size=DEFAULT_CONFIG['batch_size'], device=DEVICE):
        return BucketIterator(dataset, batch_size=batch_size, device=device) 
Author: smilelight, Project: lightNLP, Lines of code: 4, Source file: tool.py
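
All three get_iterator variants return a BucketIterator, so consuming them looks the same. A sketch, assuming tool is an instance of the class these methods belong to and train_ds is a built dataset:

iterator = tool.get_iterator(train_ds, batch_size=32)
for batch in iterator:
    # Batch attributes mirror the dataset's field names, e.g. batch.text or batch.texta.
    inputs = batch.text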


Note: The torchtext.data.Dataset examples in this article were compiled by 纯净天空 from open-source code and documentation hosted on GitHub, MSDocs, and similar platforms. The snippets come from open-source projects by their respective authors, and copyright remains with them; consult each project's license before using or redistributing the code. Do not repost without permission.