

Python onmt.io Usage Examples

This article collects typical usage examples of Python's onmt.io, gathered from open-source code. If you are wondering what onmt.io is for, how to use it, or what calling it looks like in practice, the curated code examples below may help. You can also explore further usage examples from the onmt package it belongs to.


The following presents 15 code examples of onmt.io, ordered by popularity.

Example 1: load_test_model

# Required import: import onmt [as alias]
# Or: from onmt import io [as alias]
def load_test_model(opt, dummy_opt):
    checkpoint = torch.load(opt.model,
                            map_location=lambda storage, loc: storage)
    fields = onmt.io.load_fields_from_vocab(
        checkpoint['vocab'], data_type=opt.data_type)

    model_opt = checkpoint['opt']
    for arg in dummy_opt:
        if arg not in model_opt:
            model_opt.__dict__[arg] = dummy_opt[arg]

    model = make_base_model(model_opt, fields,
                            use_gpu(opt), checkpoint)
    model.eval()
    model.generator.eval()
    return fields, model, model_opt 
Author: xiadingZ, Project: video-caption-openNMT.pytorch, Lines: 18, Source: ModelConstructor.py
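
A minimal usage sketch for load_test_model (a hedged illustration: the checkpoint path, the SimpleNamespace option attributes, and the dummy_opt contents are assumptions, not part of the original project):

from types import SimpleNamespace

# Hypothetical options; use_gpu(opt) in this era checks opt.gpu / opt.gpuid.
opt = SimpleNamespace(model="model.pt", data_type="text", gpu=-1)
dummy_opt = {"copy_attn": False}  # fills options absent from older checkpoints
fields, model, model_opt = load_test_model(opt, dummy_opt)
print(type(model).__name__, sorted(fields.keys()))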

Example 2: _next_dataset_iterator

# Required import: import onmt [as alias]
# Or: from onmt import io [as alias]
def _next_dataset_iterator(self, dataset_iter):
        try:
            self.cur_dataset = next(dataset_iter)
        except StopIteration:
            return None

        # We clear `fields` when saving, restore when loading.
        self.cur_dataset.fields = self.fields

        # Sort batches by decreasing sentence length, as required by PyTorch.
        # sort=False means "use the dataset's sort_key instead of the iterator's".
        return onmt.io.OrderedIterator(
            dataset=self.cur_dataset, batch_size=self.batch_size,
            batch_size_fn=self.batch_size_fn,
            device=self.device, train=self.is_train,
            sort=False, sort_within_batch=True,
            repeat=False) 
Author: xiadingZ, Project: video-caption-openNMT.pytorch, Lines: 19, Source: train.py
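
For context, a sketch of the kind of batch_size_fn that can be passed above. torchtext's iterator calls it as batch_size_fn(new_example, count, size_so_far); this token-count variant is modeled on OpenNMT-py's train.py and is illustrative rather than the project's exact code:

max_src_in_batch = 0

def tokens_batch_size_fn(new, count, sofar):
    # Effective batch size if `new` were added: #examples * longest src so far
    # (+2 leaves room for BOS/EOS tokens).
    global max_src_in_batch
    if count == 1:
        max_src_in_batch = 0
    max_src_in_batch = max(max_src_in_batch, len(new.src) + 2)
    return count * max_src_in_batch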

Example 3: make_embeddings

# Required import: import onmt [as alias]
# Or: from onmt import io [as alias]
def make_embeddings(opt, word_dict, feature_dicts, for_encoder=True):
    """
    Make an Embeddings instance.
    Args:
        opt: the option in current environment.
        word_dict(Vocab): words dictionary.
        feature_dicts([Vocab], optional): a list of feature dictionary.
        for_encoder(bool): make Embeddings for encoder or decoder?
    """
    if for_encoder:
        embedding_dim = opt.src_word_vec_size  # Equal to opt.word_vec_size, because we set it that way at the start of train.py
    else:
        embedding_dim = opt.tgt_word_vec_size

    word_padding_idx = word_dict.stoi[onmt.io.PAD_WORD]
    num_word_embeddings = len(word_dict)

    return nn.Embedding(num_word_embeddings, embedding_dim, padding_idx=word_padding_idx) 
Author: matthewmackay, Project: reversible-rnn, Lines: 20, Source: ModelConstructor.py

Example 4: drop_checkpoint

# Required import: import onmt [as alias]
# Or: from onmt import io [as alias]
def drop_checkpoint(self, opt, epoch, fields, valid_stats):
        """ Save a resumable checkpoint.

        Args:
            opt (dict): option object
            epoch (int): epoch number
            fields (dict): fields and vocabulary
            valid_stats : statistics of last validation run
        """
        model_state_dict = self.model.state_dict()
        checkpoint = {
            'model': model_state_dict,
            'vocab': onmt.io.save_fields_to_vocab(fields),
            'opt': opt,
            'epoch': epoch,
            'optim': self.optim,
        }

        save_path = os.path.join(opt.save_dir, 'best_checkpoint.pt')
        torch.save(checkpoint, save_path)
        return save_path 
Author: matthewmackay, Project: reversible-rnn, Lines: 23, Source: Trainer.py
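
And the mirror image, a sketch of resuming from a checkpoint written by drop_checkpoint (the path and the pre-existing `model` object are assumptions):

import torch
import onmt

checkpoint = torch.load("best_checkpoint.pt",
                        map_location=lambda storage, loc: storage)
fields = onmt.io.load_fields_from_vocab(checkpoint['vocab'])
model.load_state_dict(checkpoint['model'])   # `model` built beforehand
optim = checkpoint['optim']
start_epoch = checkpoint['epoch'] + 1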

Example 5: load_dataset

# Required import: import onmt [as alias]
# Or: from onmt import io [as alias]
def load_dataset(data_type):
    assert data_type in ["train", "valid"]

    print("Loading %s data from '%s'" % (data_type, opt.data))

    pts = glob.glob(opt.data + '.' + data_type + '.[0-9]*.pt')
    if pts:
        # Multiple onmt.io.*Dataset shards; coalesce them all.
        # torch.load loads them immediately, which might eat up
        # too much memory. A lazy load would be better, but later,
        # when we create the data iterator, it still requires the
        # data to be loaded. So it seems we don't have a good way
        # to avoid this for now.
        datasets = []
        for pt in pts:
            datasets.append(torch.load(pt))
        dataset = onmt.io.ONMTDatasetBase.coalesce_datasets(datasets)
    else:
        # Only one onmt.io.*Dataset, simple!
        dataset = torch.load(opt.data + '.' + data_type + '.pt')

    print(' * number of %s sentences: %d' % (data_type, len(dataset)))

    return dataset 
Author: matthewmackay, Project: reversible-rnn, Lines: 26, Source: train.py
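
Illustrative use, assuming opt.data points at files produced by OpenNMT-py's preprocess.py (either shards such as demo.train.1.pt, demo.train.2.pt, ... or a single demo.train.pt):

train_dataset = load_dataset("train")
valid_dataset = load_dataset("valid")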

Example 6: load_fields

# Required import: import onmt [as alias]
# Or: from onmt import io [as alias]
def load_fields(train_dataset, valid_dataset, checkpoint):
    data_type = train_dataset.data_type

    fields = onmt.io.load_fields_from_vocab(
        torch.load(opt.data + '.vocab.pt'), data_type)
    fields = dict([(k, f) for (k, f) in fields.items()
                   if k in train_dataset.examples[0].__dict__])

    # We save fields in vocab.pt, so assign them back to dataset here.
    train_dataset.fields = fields
    valid_dataset.fields = fields

    if opt.train_from:
        print('Loading vocab from checkpoint at %s.' % opt.train_from)
        fields = onmt.io.load_fields_from_vocab(checkpoint['vocab'], data_type)

    if data_type == 'text':
        print(' * vocabulary size. source = %d; target = %d' %
              (len(fields['src'].vocab), len(fields['tgt'].vocab)))
    else:
        print(' * vocabulary size. target = %d' %
              (len(fields['tgt'].vocab)))

    return fields 
Author: matthewmackay, Project: reversible-rnn, Lines: 24, Source: train.py

Example 7: textDataFromString

# Required import: import onmt [as alias]
# Or: from onmt import io [as alias]
def textDataFromString(data, truncate, side):
        with io.StringIO(data) as corpus_file:
            for i, line in enumerate(corpus_file):
                line = line.strip().split()
                if truncate:
                    line = line[:truncate]

                words, feats, n_feats = \
                    TextDataset.extract_text_features(line)

                example_dict = {side: words, "indices": i}
                if feats:
                    prefix = side + "_feat_"
                    example_dict.update((prefix + str(j), f)
                                        for j, f in enumerate(feats))
                yield example_dict, n_feats 
Author: HendrikStrobelt, Project: Seq2Seq-Vis, Lines: 18, Source: opennmt_model.py
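
A hypothetical invocation on an in-memory two-line corpus; the printed dicts reflect what TextDataset.extract_text_features yields for feature-free text in this era of OpenNMT-py (exact container types may vary by version):

data = "the cat sat\na dog ran\n"
for example_dict, n_feats in textDataFromString(data, truncate=None, side="src"):
    print(example_dict, n_feats)
# roughly: {'src': ('the', 'cat', 'sat'), 'indices': 0} 0
#          {'src': ('a', 'dog', 'ran'), 'indices': 1} 0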

Example 8: _next_dataset_iterator

# Required import: import onmt [as alias]
# Or: from onmt import io [as alias]
def _next_dataset_iterator(self, dataset_iter):
        try:
            self.cur_dataset = next(dataset_iter)
        except StopIteration:
            return None

        # We clear `fields` when saving, restore when loading.
        self.cur_dataset.fields = self.fields

        # Sort batches by decreasing sentence length, as required by PyTorch.
        # sort=False means "use the dataset's sort_key instead of the iterator's".
        return onmt.io.OrderedIterator(
                dataset=self.cur_dataset, batch_size=self.batch_size,
                batch_size_fn=self.batch_size_fn,
                device=self.device, train=self.is_train,
                sort=False, sort_within_batch=True,
                repeat=False) 
Author: abaheti95, Project: DC-NeuralConversation, Lines: 19, Source: train.py

Example 9: load_fields

# Required import: import onmt [as alias]
# Or: from onmt import io [as alias]
def load_fields(dataset, data_type, checkpoint):

    fields = onmt.io.load_fields_from_vocab(
                torch.load(opt.data + '.vocab.pt'), data_type)
    fields = dict([(k, f) for (k, f) in fields.items()
                  if k in dataset.examples[0].__dict__])

    if checkpoint is not None:
        print('Loading vocab from checkpoint at %s.' % opt.train_from)
        fields = onmt.io.load_fields_from_vocab(
                    checkpoint['vocab'], data_type)

    if data_type == 'text':
        print(' * vocabulary size. source = %d; target = %d' %
              (len(fields['src'].vocab), len(fields['tgt'].vocab)))
    else:
        print(' * vocabulary size. target = %d' %
              (len(fields['tgt'].vocab)))

    return fields 
Author: abaheti95, Project: DC-NeuralConversation, Lines: 22, Source: train.py

Example 10: dataset_build

# Required import: import onmt [as alias]
# Or: from onmt import io [as alias]
def dataset_build(self, opt):
        fields = onmt.io.get_fields("text", 0, 0)

        if hasattr(opt, 'src_vocab') and len(opt.src_vocab) > 0:
            with codecs.open(opt.src_vocab, 'w', 'utf-8') as f:
                f.write('a\nb\nc\nd\ne\nf\n')
        if hasattr(opt, 'tgt_vocab') and len(opt.tgt_vocab) > 0:
            with codecs.open(opt.tgt_vocab, 'w', 'utf-8') as f:
                f.write('a\nb\nc\nd\ne\nf\n')

        train_data_files = preprocess.build_save_dataset('train', fields, opt)

        preprocess.build_save_vocab(train_data_files, fields, opt)

        preprocess.build_save_dataset('valid', fields, opt)

        # Remove the generated *pt files.
        for pt in glob.glob(SAVE_DATA_PREFIX + '*.pt'):
            os.remove(pt)
        if hasattr(opt, 'src_vocab') and os.path.exists(opt.src_vocab):
            os.remove(opt.src_vocab)
        if hasattr(opt, 'tgt_vocab') and os.path.exists(opt.tgt_vocab):
            os.remove(opt.tgt_vocab) 
Author: ratishsp, Project: data2text-entity-py, Lines: 25, Source: test_preprocess.py

Example 11: load_fields

# Required import: import onmt [as alias]
# Or: from onmt import io [as alias]
def load_fields(dataset, data_type, checkpoint):
    if checkpoint is not None:
        print('Loading vocab from checkpoint at %s.' % opt.train_from)
        fields = onmt.io.load_fields_from_vocab(
            checkpoint['vocab'], data_type)
    else:
        fields = onmt.io.load_fields_from_vocab(
            torch.load(opt.data + '.vocab.pt'), data_type)
    fields = dict([(k, f) for (k, f) in fields.items()
                   if k in dataset.examples[0].__dict__])

    if data_type == 'text':
        print(' * vocabulary size. source = %d; target = %d' %
              (len(fields['src'].vocab), len(fields['tgt'].vocab)))
    else:
        print(' * vocabulary size. target = %d' %
              (len(fields['tgt'].vocab)))

    return fields 
Author: ratishsp, Project: data2text-entity-py, Lines: 21, Source: train.py

Example 12: load_fields

# Required import: import onmt [as alias]
# Or: from onmt import io [as alias]
def load_fields(dataset, data_type, checkpoint):
    if checkpoint is not None:
        print('Loading vocab from checkpoint at %s.' % opt.train_from)
        fields = onmt.io.load_fields_from_vocab(
            checkpoint['vocab'], data_type)
    else:
        fields = onmt.io.load_fields_from_vocab(
            torch.load(opt.data + '.vocab.pt'), data_type)
    # print(dataset.examples[0].__dict__)
    # print([(k, f) for (k, f) in fields.items()])
    fields = dict([(k, f) for (k, f) in fields.items()
                   if k in dataset.examples[0].__dict__])

    if data_type == 'text' or data_type == 'gcn':
        print(' * vocabulary size. source = %d; target = %d' %
              (len(fields['src'].vocab), len(fields['tgt'].vocab)))
    else:
        print(' * vocabulary size. target = %d' %
              (len(fields['tgt'].vocab)))

    return fields 
Author: diegma, Project: graph-2-text, Lines: 23, Source: train.py

Example 13: forward

# Required import: import onmt [as alias]
# Or: from onmt import io [as alias]
def forward(self, hidden, attn, src_map):
        """
        Compute a distribution over the target dictionary
        extended by the dynamic dictionary implied by compying
        source words.

        Args:
           hidden (`FloatTensor`): hidden outputs `[batch*tlen, input_size]`
           attn (`FloatTensor`): attn for each `[batch*tlen, input_size]`
           src_map (`FloatTensor`):
             A sparse indicator matrix mapping each source word to
             its index in the "extended" vocab containing.
             `[src_len, batch, extra_words]`
        """
        # CHECKS
        batch_by_tlen, _ = hidden.size()
        batch_by_tlen_, slen = attn.size()
        slen_, batch, cvocab = src_map.size()
        aeq(batch_by_tlen, batch_by_tlen_)
        aeq(slen, slen_)

        # Original probabilities.
        logits = self.linear(hidden)
        logits[:, self.tgt_dict.stoi[onmt.io.PAD_WORD]] = -float('inf')
        prob = F.softmax(logits, dim=1)  # softmax over the fixed target vocab

        # Probability of copying, p(z=1), for each batch element.
        p_copy = F.sigmoid(self.linear_copy(hidden))
        # Probability of not copying: p_{word}(w) * (1 - p(z))
        out_prob = torch.mul(prob, 1 - p_copy.expand_as(prob))
        mul_attn = torch.mul(attn, p_copy.expand_as(attn))
        copy_prob = torch.bmm(mul_attn.view(-1, batch, slen)
                              .transpose(0, 1),
                              src_map.transpose(0, 1)).transpose(0, 1)
        copy_prob = copy_prob.contiguous().view(-1, cvocab)
        return torch.cat([out_prob, copy_prob], 1) 
Author: xiadingZ, Project: video-caption-openNMT.pytorch, Lines: 38, Source: CopyGenerator.py
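
A standalone shape sketch of the torch.bmm aggregation above, with made-up dimensions, showing how attention mass over source positions is scattered into the extended-vocabulary slots via src_map:

import torch

B, T, S, E = 2, 4, 5, 3                  # batch, tgt len, src len, extra words
mul_attn = torch.rand(B * T, S)          # p(copy)-scaled attention weights
src_map = torch.zeros(S, B, E)
src_map[0, :, 1] = 1.0                   # source position 0 -> extended id 1
copy_prob = torch.bmm(mul_attn.view(-1, B, S).transpose(0, 1),
                      src_map.transpose(0, 1)).transpose(0, 1)
print(copy_prob.contiguous().view(-1, E).shape)  # torch.Size([8, 3])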

Example 14: __init__

# Required import: import onmt [as alias]
# Or: from onmt import io [as alias]
def __init__(self, generator, tgt_vocab):
        super(LossComputeBase, self).__init__()
        self.generator = generator
        self.tgt_vocab = tgt_vocab
        self.padding_idx = tgt_vocab.stoi[onmt.io.PAD_WORD] 
Author: xiadingZ, Project: video-caption-openNMT.pytorch, Lines: 7, Source: Loss.py
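
A sketch of how this padding index is typically consumed downstream: OpenNMT-py's NMTLossCompute of this era zeroes the PAD weight in an NLLLoss so padded positions contribute no loss (variable names illustrative):

weight = torch.ones(len(tgt_vocab))
weight[tgt_vocab.stoi[onmt.io.PAD_WORD]] = 0
criterion = nn.NLLLoss(weight, size_average=False)  # pre-1.0 PyTorch API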

Example 15: make_embeddings

# Required import: import onmt [as alias]
# Or: from onmt import io [as alias]
def make_embeddings(opt, word_dict, feature_dicts, for_encoder=True):
    """
    Make an Embeddings instance.
    Args:
        opt: the option in current environment.
        word_dict(Vocab): words dictionary.
        feature_dicts([Vocab], optional): a list of feature dictionary.
        for_encoder(bool): make Embeddings for encoder or decoder?
    """
    if for_encoder:
        embedding_dim = opt.src_word_vec_size
    else:
        embedding_dim = opt.tgt_word_vec_size

    word_padding_idx = word_dict.stoi[onmt.io.PAD_WORD]
    num_word_embeddings = len(word_dict)

    feats_padding_idx = [feat_dict.stoi[onmt.io.PAD_WORD]
                         for feat_dict in feature_dicts]
    num_feat_embeddings = [len(feat_dict) for feat_dict in
                           feature_dicts]

    return Embeddings(word_vec_size=embedding_dim,
                      position_encoding=opt.position_encoding,
                      feat_merge=opt.feat_merge,
                      feat_vec_exponent=opt.feat_vec_exponent,
                      feat_vec_size=opt.feat_vec_size,
                      dropout=opt.dropout,
                      word_padding_idx=word_padding_idx,
                      feat_padding_idx=feats_padding_idx,
                      word_vocab_size=num_word_embeddings,
                      feat_vocab_sizes=num_feat_embeddings,
                      sparse=opt.optim == "sparseadam") 
Author: xiadingZ, Project: video-caption-openNMT.pytorch, Lines: 35, Source: ModelConstructor.py
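
Illustrative construction of encoder- and decoder-side embeddings with this helper, assuming `fields` and `model_opt` come from a loaded checkpoint as in Example 1; onmt.io.collect_feature_vocabs is the companion helper this era's ModelConstructor.py uses to gather feature vocabularies:

src_dict = fields["src"].vocab
src_feat_dicts = onmt.io.collect_feature_vocabs(fields, "src")
src_embeddings = make_embeddings(model_opt, src_dict, src_feat_dicts)

tgt_dict = fields["tgt"].vocab
tgt_feat_dicts = onmt.io.collect_feature_vocabs(fields, "tgt")
tgt_embeddings = make_embeddings(model_opt, tgt_dict, tgt_feat_dicts,
                                 for_encoder=False)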


Note: The onmt.io examples in this article were compiled by 纯净天空 from open-source code and documentation platforms such as GitHub and MSDocs. The snippets were selected from open-source projects contributed by many developers; copyright remains with the original authors, and any distribution or use must follow the corresponding project's license. Please do not reproduce without permission.