This article collects typical usage examples of the Python method onmt.io. If you have been wondering what onmt.io does, how to use it, or where to find working examples, the curated snippets below may help. You can also explore further usage of the parent module onmt.
The 14 code examples of onmt.io shown below are ordered by popularity by default.
Example 1: load_test_model
# Required import: import onmt [as alias]
# Or: from onmt import io [as alias]
def load_test_model(opt, dummy_opt):
    checkpoint = torch.load(opt.model,
                            map_location=lambda storage, loc: storage)
    fields = onmt.io.load_fields_from_vocab(
        checkpoint['vocab'], data_type=opt.data_type)
    model_opt = checkpoint['opt']
    for arg in dummy_opt:
        if arg not in model_opt:
            model_opt.__dict__[arg] = dummy_opt[arg]
    model = make_base_model(model_opt, fields,
                            use_gpu(opt), checkpoint)
    model.eval()
    model.generator.eval()
    return fields, model, model_opt
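A minimal usage sketch for the snippet above. This is hedged: the exact attributes on opt and the checkpoint filename are assumptions inferred from the code, not taken from the original repository.

# Hypothetical usage -- `opt`/`dummy_opt` shapes are assumed from the snippet above.
import argparse
opt = argparse.Namespace(model='model_step_10000.pt', data_type='text', gpu=-1)
dummy_opt = {'copy_attn': False}  # defaults for options missing from older checkpoints
fields, model, model_opt = load_test_model(opt, dummy_opt)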
Example 2: _next_dataset_iterator
# Required import: import onmt [as alias]
# Or: from onmt import io [as alias]
def _next_dataset_iterator(self, dataset_iter):
    try:
        self.cur_dataset = next(dataset_iter)
    except StopIteration:
        return None

    # We clear `fields` when saving, so restore them when loading.
    self.cur_dataset.fields = self.fields

    # Sort each batch by decreasing sentence length, as required by pytorch.
    # sort=False means "use the dataset's sort_key instead of the iterator's".
    return onmt.io.OrderedIterator(
        dataset=self.cur_dataset, batch_size=self.batch_size,
        batch_size_fn=self.batch_size_fn,
        device=self.device, train=self.is_train,
        sort=False, sort_within_batch=True,
        repeat=False)
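For context, here is a hedged sketch of how such per-dataset iterators are typically consumed; the wrapper object train_iter and the batch attributes are assumptions, not shown on this page.

# Hypothetical consumption loop; `train_iter` is assumed to be the lazy
# multi-dataset iterator that calls _next_dataset_iterator internally.
for batch in train_iter:
    src, src_lengths = batch.src  # text-type batches carry (data, lengths)
    tgt = batch.tgt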
Example 3: make_embeddings
# Required import: import onmt [as alias]
# Or: from onmt import io [as alias]
def make_embeddings(opt, word_dict, feature_dicts, for_encoder=True):
    """
    Build a word embedding layer.
    Args:
        opt: the options in the current environment.
        word_dict (Vocab): word dictionary.
        feature_dicts ([Vocab], optional): a list of feature dictionaries.
        for_encoder (bool): build embeddings for the encoder (True) or the
            decoder (False)?
    Note: this simplified variant ignores feature_dicts and returns a plain
    nn.Embedding (compare Example 14).
    """
    if for_encoder:
        # Equal to opt.word_vec_size, because we set it that way at the
        # start of train.py.
        embedding_dim = opt.src_word_vec_size
    else:
        embedding_dim = opt.tgt_word_vec_size

    word_padding_idx = word_dict.stoi[onmt.io.PAD_WORD]
    num_word_embeddings = len(word_dict)
    return nn.Embedding(num_word_embeddings, embedding_dim,
                        padding_idx=word_padding_idx)
Example 4: drop_checkpoint
# Required import: import onmt [as alias]
# Or: from onmt import io [as alias]
def drop_checkpoint(self, opt, epoch, fields, valid_stats):
    """ Save a resumable checkpoint.
    Args:
        opt (dict): option object
        epoch (int): epoch number
        fields (dict): fields and vocabulary
        valid_stats: statistics of last validation run
    """
    model_state_dict = self.model.state_dict()
    checkpoint = {
        'model': model_state_dict,
        'vocab': onmt.io.save_fields_to_vocab(fields),
        'opt': opt,
        'epoch': epoch,
        'optim': self.optim,
    }
    save_path = os.path.join(opt.save_dir, 'best_checkpoint.pt')
    torch.save(checkpoint, save_path)
    return save_path
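The saved dictionary mirrors what load_test_model in Example 1 expects, so a round trip looks roughly like the sketch below (the filename and the default data_type are assumptions):

# Hypothetical reload of the checkpoint written above.
checkpoint = torch.load('best_checkpoint.pt',
                        map_location=lambda storage, loc: storage)
fields = onmt.io.load_fields_from_vocab(checkpoint['vocab'], data_type='text')
model_opt = checkpoint['opt']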
Example 5: load_dataset
# Required import: import onmt [as alias]
# Or: from onmt import io [as alias]
def load_dataset(data_type):
    assert data_type in ["train", "valid"]
    print("Loading %s data from '%s'" % (data_type, opt.data))

    pts = glob.glob(opt.data + '.' + data_type + '.[0-9]*.pt')
    if pts:
        # Multiple onmt.io.*Dataset files; coalesce them all.
        # torch.load loads them immediately, which might eat up too much
        # memory. A lazy load would be better, but later, when we create
        # the data iterator, it still requires these data to be loaded.
        # So it seems we don't have a good way to avoid this for now.
        datasets = []
        for pt in pts:
            datasets.append(torch.load(pt))
        dataset = onmt.io.ONMTDatasetBase.coalesce_datasets(datasets)
    else:
        # Only one onmt.io.*Dataset -- simple!
        dataset = torch.load(opt.data + '.' + data_type + '.pt')

    print(' * number of %s sentences: %d' % (data_type, len(dataset)))
    return dataset
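The glob above matches the shard naming scheme produced by preprocessing. A hedged illustration of the two branches, with made-up filenames:

# Hypothetical file layouts matched by load_dataset, assuming opt.data == 'demo':
#   demo.train.1.pt, demo.train.2.pt, ...  -> multi-shard branch (coalesced)
#   demo.valid.pt                          -> single-file branch
train_dataset = load_dataset('train')
valid_dataset = load_dataset('valid')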
Example 6: load_fields
# Required import: import onmt [as alias]
# Or: from onmt import io [as alias]
def load_fields(train_dataset, valid_dataset, checkpoint):
    data_type = train_dataset.data_type
    fields = onmt.io.load_fields_from_vocab(
        torch.load(opt.data + '.vocab.pt'), data_type)
    fields = dict([(k, f) for (k, f) in fields.items()
                   if k in train_dataset.examples[0].__dict__])

    # We save fields in vocab.pt, so assign them back to the datasets here.
    train_dataset.fields = fields
    valid_dataset.fields = fields

    if opt.train_from:
        print('Loading vocab from checkpoint at %s.' % opt.train_from)
        fields = onmt.io.load_fields_from_vocab(checkpoint['vocab'], data_type)

    if data_type == 'text':
        print(' * vocabulary size. source = %d; target = %d' %
              (len(fields['src'].vocab), len(fields['tgt'].vocab)))
    else:
        print(' * vocabulary size. target = %d' %
              (len(fields['tgt'].vocab)))
    return fields
Example 7: textDataFromString
# Required import: import onmt [as alias]
# Or: from onmt import io [as alias]
def textDataFromString(data, truncate, side):
    # Note: `io` here is the Python standard-library module, not onmt.io.
    with io.StringIO(data) as corpus_file:
        for i, line in enumerate(corpus_file):
            line = line.strip().split()
            if truncate:
                line = line[:truncate]

            words, feats, n_feats = \
                TextDataset.extract_text_features(line)

            example_dict = {side: words, "indices": i}
            if feats:
                prefix = side + "_feat_"
                example_dict.update((prefix + str(j), f)
                                    for j, f in enumerate(feats))
            yield example_dict, n_feats
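A hedged usage sketch; the input string and arguments are illustrative only:

# Hypothetical call: stream example dicts from an in-memory corpus string.
data = "the cat sat\non the mat\n"
for example_dict, n_feats in textDataFromString(data, truncate=50, side="src"):
    print(example_dict)  # e.g. {'src': ('the', 'cat', 'sat'), 'indices': 0}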
Example 8: load_fields
# Required import: import onmt [as alias]
# Or: from onmt import io [as alias]
def load_fields(dataset, data_type, checkpoint):
    fields = onmt.io.load_fields_from_vocab(
        torch.load(opt.data + '.vocab.pt'), data_type)
    fields = dict([(k, f) for (k, f) in fields.items()
                   if k in dataset.examples[0].__dict__])

    if checkpoint is not None:
        print('Loading vocab from checkpoint at %s.' % opt.train_from)
        fields = onmt.io.load_fields_from_vocab(
            checkpoint['vocab'], data_type)

    if data_type == 'text':
        print(' * vocabulary size. source = %d; target = %d' %
              (len(fields['src'].vocab), len(fields['tgt'].vocab)))
    else:
        print(' * vocabulary size. target = %d' %
              (len(fields['tgt'].vocab)))
    return fields
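Note that this variant filters fields against the dataset's example attributes before the checkpoint branch, so fields reloaded from a checkpoint are not re-filtered; Examples 10 and 11 below reorder the logic so the filter always runs last.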
Example 9: dataset_build
# Required import: import onmt [as alias]
# Or: from onmt import io [as alias]
def dataset_build(self, opt):
    fields = onmt.io.get_fields("text", 0, 0)

    if hasattr(opt, 'src_vocab') and len(opt.src_vocab) > 0:
        with codecs.open(opt.src_vocab, 'w', 'utf-8') as f:
            f.write('a\nb\nc\nd\ne\nf\n')
    if hasattr(opt, 'tgt_vocab') and len(opt.tgt_vocab) > 0:
        with codecs.open(opt.tgt_vocab, 'w', 'utf-8') as f:
            f.write('a\nb\nc\nd\ne\nf\n')

    train_data_files = preprocess.build_save_dataset('train', fields, opt)
    preprocess.build_save_vocab(train_data_files, fields, opt)
    preprocess.build_save_dataset('valid', fields, opt)

    # Remove the generated *.pt files.
    for pt in glob.glob(SAVE_DATA_PREFIX + '*.pt'):
        os.remove(pt)
    if hasattr(opt, 'src_vocab') and os.path.exists(opt.src_vocab):
        os.remove(opt.src_vocab)
    if hasattr(opt, 'tgt_vocab') and os.path.exists(opt.tgt_vocab):
        os.remove(opt.tgt_vocab)
Example 10: load_fields
# Required import: import onmt [as alias]
# Or: from onmt import io [as alias]
def load_fields(dataset, data_type, checkpoint):
    if checkpoint is not None:
        print('Loading vocab from checkpoint at %s.' % opt.train_from)
        fields = onmt.io.load_fields_from_vocab(
            checkpoint['vocab'], data_type)
    else:
        fields = onmt.io.load_fields_from_vocab(
            torch.load(opt.data + '.vocab.pt'), data_type)

    fields = dict([(k, f) for (k, f) in fields.items()
                   if k in dataset.examples[0].__dict__])

    if data_type == 'text':
        print(' * vocabulary size. source = %d; target = %d' %
              (len(fields['src'].vocab), len(fields['tgt'].vocab)))
    else:
        print(' * vocabulary size. target = %d' %
              (len(fields['tgt'].vocab)))
    return fields
Example 11: load_fields
# Required import: import onmt [as alias]
# Or: from onmt import io [as alias]
def load_fields(dataset, data_type, checkpoint):
    if checkpoint is not None:
        print('Loading vocab from checkpoint at %s.' % opt.train_from)
        fields = onmt.io.load_fields_from_vocab(
            checkpoint['vocab'], data_type)
    else:
        fields = onmt.io.load_fields_from_vocab(
            torch.load(opt.data + '.vocab.pt'), data_type)

    # print(dataset.examples[0].__dict__)
    # print([(k, f) for (k, f) in fields.items()])
    fields = dict([(k, f) for (k, f) in fields.items()
                   if k in dataset.examples[0].__dict__])

    if data_type == 'text' or data_type == 'gcn':
        print(' * vocabulary size. source = %d; target = %d' %
              (len(fields['src'].vocab), len(fields['tgt'].vocab)))
    else:
        print(' * vocabulary size. target = %d' %
              (len(fields['tgt'].vocab)))
    return fields
Example 12: forward
# Required import: import onmt [as alias]
# Or: from onmt import io [as alias]
def forward(self, hidden, attn, src_map):
    """
    Compute a distribution over the target dictionary
    extended by the dynamic dictionary implied by copying
    source words.
    Args:
        hidden (`FloatTensor`): hidden outputs `[batch*tlen, input_size]`
        attn (`FloatTensor`): attention over source tokens `[batch*tlen, slen]`
        src_map (`FloatTensor`):
            A sparse indicator matrix mapping each source word to
            its index in the "extended" vocabulary.
            `[src_len, batch, extra_words]`
    """
    # CHECKS
    batch_by_tlen, _ = hidden.size()
    batch_by_tlen_, slen = attn.size()
    slen_, batch, cvocab = src_map.size()
    aeq(batch_by_tlen, batch_by_tlen_)
    aeq(slen, slen_)

    # Original probabilities over the fixed target vocabulary.
    logits = self.linear(hidden)
    logits[:, self.tgt_dict.stoi[onmt.io.PAD_WORD]] = -float('inf')
    prob = F.softmax(logits, dim=1)

    # Probability of copying, p(z=1), for each position.
    p_copy = F.sigmoid(self.linear_copy(hidden))

    # Probability of not copying: p_{word}(w) * (1 - p(z)).
    out_prob = torch.mul(prob, 1 - p_copy.expand_as(prob))
    mul_attn = torch.mul(attn, p_copy.expand_as(attn))
    copy_prob = torch.bmm(mul_attn.view(-1, batch, slen)
                          .transpose(0, 1),
                          src_map.transpose(0, 1)).transpose(0, 1)
    copy_prob = copy_prob.contiguous().view(-1, cvocab)
    return torch.cat([out_prob, copy_prob], 1)
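Put differently, this is a pointer-generator style copy mechanism: the returned distribution concatenates (1 - p_copy) · p_vocab(w) over the fixed target vocabulary with, for each extended-vocabulary entry, p_copy times the attention mass that src_map routes from source positions to that entry.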
Example 13: __init__
# Required import: import onmt [as alias]
# Or: from onmt import io [as alias]
def __init__(self, generator, tgt_vocab):
    super(LossComputeBase, self).__init__()
    self.generator = generator
    self.tgt_vocab = tgt_vocab
    self.padding_idx = tgt_vocab.stoi[onmt.io.PAD_WORD]
Example 14: make_embeddings
# Required import: import onmt [as alias]
# Or: from onmt import io [as alias]
def make_embeddings(opt, word_dict, feature_dicts, for_encoder=True):
    """
    Make an Embeddings instance.
    Args:
        opt: the options in the current environment.
        word_dict (Vocab): word dictionary.
        feature_dicts ([Vocab], optional): a list of feature dictionaries.
        for_encoder (bool): build embeddings for the encoder (True) or the
            decoder (False)?
    """
    if for_encoder:
        embedding_dim = opt.src_word_vec_size
    else:
        embedding_dim = opt.tgt_word_vec_size

    word_padding_idx = word_dict.stoi[onmt.io.PAD_WORD]
    num_word_embeddings = len(word_dict)

    feats_padding_idx = [feat_dict.stoi[onmt.io.PAD_WORD]
                         for feat_dict in feature_dicts]
    num_feat_embeddings = [len(feat_dict) for feat_dict in
                           feature_dicts]

    return Embeddings(word_vec_size=embedding_dim,
                      position_encoding=opt.position_encoding,
                      feat_merge=opt.feat_merge,
                      feat_vec_exponent=opt.feat_vec_exponent,
                      feat_vec_size=opt.feat_vec_size,
                      dropout=opt.dropout,
                      word_padding_idx=word_padding_idx,
                      feat_padding_idx=feats_padding_idx,
                      word_vocab_size=num_word_embeddings,
                      feat_vocab_sizes=num_feat_embeddings,
                      sparse=opt.optim == "sparseadam")
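A hedged wiring sketch that combines this with the field-loading examples above; treat collect_feature_vocabs as an assumed helper from the same onmt.io-era API:

# Hypothetical wiring: build encoder-side embeddings from loaded fields.
word_dict = fields['src'].vocab
feature_dicts = onmt.io.collect_feature_vocabs(fields, 'src')  # assumed helper
src_embeddings = make_embeddings(opt, word_dict, feature_dicts,
                                 for_encoder=True)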