This page collects typical usage examples of onmt.Trainer in Python. If you have been wondering how onmt.Trainer is used in practice, the curated code examples below may help. You can also explore further usage examples from the onmt module it belongs to.
The following shows 5 code examples of onmt.Trainer, sorted by popularity by default.
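At a glance, Examples 2-4 below all follow the same construct-then-iterate pattern of the legacy (0.x) OpenNMT-py API. Here is a minimal sketch distilled from them; `model`, `fields`, `optim`, `opt`, `model_opt`, `data_type` and the helper functions are assumed to be set up as in the legacy train.py script.

import onmt

# Separate loss computations for training and validation.
train_loss = make_loss_compute(model, fields["tgt"].vocab, opt)
valid_loss = make_loss_compute(model, fields["tgt"].vocab, opt, train=False)

trainer = onmt.Trainer(model, train_loss, valid_loss, optim,
                       opt.truncated_decoder, opt.max_generator_batches,
                       data_type, opt.normalization, opt.accum_count)

for epoch in range(opt.start_epoch, opt.epochs + 1):
    # One pass over the (lazily loaded) training shards.
    train_iter = make_dataset_iter(lazily_load_dataset("train"), fields, opt)
    train_stats = trainer.train(train_iter, epoch, report_func)

    valid_iter = make_dataset_iter(lazily_load_dataset("valid"),
                                   fields, opt, is_train=False)
    valid_stats = trainer.validate(valid_iter)

    trainer.epoch_step(valid_stats.ppl(), epoch)  # learning-rate schedule
    if epoch >= opt.start_checkpoint_at:
        trainer.drop_checkpoint(model_opt, epoch, fields, valid_stats)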
Example 1: evaluate_model
# Required import: import onmt [as alias]
# Or: from onmt import Trainer [as alias]
def evaluate_model(model, test_loader, copy_attn=False, copy_attn_force=None):
    """ Copied from the method in onmt.Trainer """
    # Set model in validating mode.
    model.eval()

    stats = onmt.Statistics()
    valid_loss = make_loss_compute(model, test_loader.dataset.fields["tgt"].vocab,
                                   test_loader.dataset,
                                   copy_attn=copy_attn,
                                   copy_attn_force=copy_attn_force)

    for batch in test_loader:
        _, src_lengths = batch.src
        src = onmt.IO.make_features(batch, 'src')
        tgt = onmt.IO.make_features(batch, 'tgt')

        # F-prop through the model.
        outputs, attns, _ = model(src, tgt, src_lengths)

        # Compute loss.
        gen_state = onmt.Loss.make_gen_state(
            outputs, batch, attns, (0, batch.tgt.size(0)))
        _, batch_stats = valid_loss(batch, **gen_state)

        # Update statistics.
        stats.update(batch_stats)

    # Set model back to training mode.
    model.train()
    return stats
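A hedged usage sketch of evaluate_model: `model` and `test_loader` are hypothetical and assumed to be built elsewhere (a model plus a batch iterator whose dataset carries the "tgt" field). The returned onmt.Statistics object exposes ppl() and accuracy(), the same accessors the training examples below print:

# Hypothetical call site; `model` and `test_loader` must already exist.
test_stats = evaluate_model(model, test_loader, copy_attn=True)
print('Test perplexity: %g' % test_stats.ppl())
print('Test accuracy: %g' % test_stats.accuracy())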
Example 2: train_model
# Required import: import onmt [as alias]
# Or: from onmt import Trainer [as alias]
def train_model(model, fields, optim, data_type, model_opt):
    train_loss = make_loss_compute(model, fields["tgt"].vocab, opt)
    valid_loss = make_loss_compute(model, fields["tgt"].vocab, opt,
                                   train=False)

    trunc_size = opt.truncated_decoder  # Badly named...
    shard_size = opt.max_generator_batches
    norm_method = opt.normalization
    grad_accum_count = opt.accum_count

    trainer = onmt.Trainer(model, train_loss, valid_loss, optim,
                           trunc_size, shard_size, data_type,
                           norm_method, grad_accum_count)

    print('\nStart training...')
    print(' * number of epochs: %d, starting from Epoch %d' %
          (opt.epochs + 1 - opt.start_epoch, opt.start_epoch))
    print(' * batch size: %d' % opt.batch_size)

    for epoch in range(opt.start_epoch, opt.epochs + 1):
        print('')

        # 1. Train for one epoch on the training set.
        train_iter = make_dataset_iter(lazily_load_dataset("train"),
                                       fields, opt)
        train_stats = trainer.train(train_iter, epoch, report_func)
        print('Train perplexity: %g' % train_stats.ppl())
        print('Train accuracy: %g' % train_stats.accuracy())

        # 2. Validate on the validation set.
        valid_iter = make_dataset_iter(lazily_load_dataset("valid"),
                                       fields, opt,
                                       is_train=False)
        valid_stats = trainer.validate(valid_iter)
        print('Validation perplexity: %g' % valid_stats.ppl())
        print('Validation accuracy: %g' % valid_stats.accuracy())

        # 3. Log to remote server.
        if opt.exp_host:
            train_stats.log("train", experiment, optim.lr)
            valid_stats.log("valid", experiment, optim.lr)
        if opt.tensorboard:
            train_stats.log_tensorboard("train", writer, optim.lr, epoch)
            valid_stats.log_tensorboard("valid", writer, optim.lr, epoch)

        # 4. Update the learning rate.
        trainer.epoch_step(valid_stats.ppl(), epoch)

        # 5. Drop a checkpoint if needed.
        if epoch >= opt.start_checkpoint_at:
            trainer.drop_checkpoint(model_opt, epoch, fields, valid_stats)
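Examples 2-4 hand trainer.train() a report_func callback that is never shown. In the legacy train.py it is a batch-level progress hook roughly like the sketch below; the exact signature is an assumption based on that script, and `opt` is the usual module-level options object:

def report_func(epoch, batch, num_batches, start_time, lr, report_stats):
    # Print running statistics once per `opt.report_every` batches
    # (the modulo trick fires on the last batch of each window),
    # then reset the accumulator so each report covers a fresh window.
    if batch % opt.report_every == -1 % opt.report_every:
        report_stats.output(epoch, batch + 1, num_batches, start_time)
        report_stats = onmt.Statistics()
    return report_stats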
Example 3: train_model
# Required import: import onmt [as alias]
# Or: from onmt import Trainer [as alias]
def train_model(model, fields, optim, data_type, model_opt):
    train_loss = make_loss_compute(model, fields["tgt"].vocab, opt)
    valid_loss = make_loss_compute(model, fields["tgt"].vocab, opt)

    trunc_size = opt.truncated_decoder  # Badly named...
    shard_size = opt.max_generator_batches

    trainer = onmt.Trainer(model, train_loss, valid_loss, optim,
                           trunc_size, shard_size, data_type,
                           opt.normalization, opt.accum_count)

    for epoch in range(opt.start_epoch, opt.epochs + 1):
        print('')

        # 1. Train for one epoch on the training set.
        train_datasets = lazily_load_dataset("train")
        train_iter = make_dataset_iter(train_datasets, fields, opt)
        train_stats = trainer.train(train_iter, epoch, report_func)
        print('Train perplexity: %g' % train_stats.ppl())
        print('Train accuracy: %g' % train_stats.accuracy())

        # 2. Validate on the validation set.
        valid_iter = make_dataset_iter(lazily_load_dataset("valid"),
                                       fields, opt,
                                       is_train=False)
        valid_stats = trainer.validate(valid_iter)
        print('Validation perplexity: %g' % valid_stats.ppl())
        print('Validation accuracy: %g' % valid_stats.accuracy())

        # 3. Log to remote server.
        if opt.exp_host:
            train_stats.log("train", experiment, optim.lr)
            valid_stats.log("valid", experiment, optim.lr)

        # 4. Update the learning rate.
        trainer.epoch_step(valid_stats.ppl(), epoch)

        # 5. Drop a checkpoint if needed.
        if epoch >= opt.start_checkpoint_at:
            trainer.drop_checkpoint(model_opt, epoch, fields, valid_stats)
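Every example calls make_loss_compute, a helper defined in train.py rather than in the onmt package, and its signature clearly varies across versions (Examples 1 and 5 also pass a dataset). A hedged sketch of the later variant used by Examples 2-4, assuming the legacy onmt.Loss.NMTLossCompute and onmt.modules.CopyGeneratorLossCompute classes and the onmt.Utils.use_gpu helper:

import onmt
from onmt.Utils import use_gpu

def make_loss_compute(model, tgt_vocab, opt, train=True):
    # Sketch only; constructor arguments differ between onmt releases.
    if opt.copy_attn:
        # Copy-attention models need the copy-generator loss.
        compute = onmt.modules.CopyGeneratorLossCompute(
            model.generator, tgt_vocab, opt.copy_attn_force)
    else:
        # Label smoothing is only applied during training.
        compute = onmt.Loss.NMTLossCompute(
            model.generator, tgt_vocab,
            label_smoothing=opt.label_smoothing if train else 0.0)
    if use_gpu(opt):
        compute.cuda()
    return compute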
Example 4: train_model
# Required import: import onmt [as alias]
# Or: from onmt import Trainer [as alias]
def train_model(model, fields, optim, data_type, model_opt):
    train_loss = make_loss_compute(model, fields["tgt"].vocab, opt)
    valid_loss = make_loss_compute(model, fields["tgt"].vocab, opt,
                                   train=False)

    trunc_size = opt.truncated_decoder  # Badly named...
    shard_size = opt.max_generator_batches
    norm_method = opt.normalization
    grad_accum_count = opt.accum_count

    trainer = onmt.Trainer(model, train_loss, valid_loss, optim,
                           trunc_size, shard_size, data_type,
                           norm_method, grad_accum_count)

    print('\nStart training...')
    print(' * number of epochs: %d, starting from Epoch %d' %
          (opt.epochs + 1 - opt.start_epoch, opt.start_epoch))
    print(' * batch size: %d' % opt.batch_size)

    for epoch in range(opt.start_epoch, opt.epochs + 1):
        print('')

        # 1. Train for one epoch on the training set.
        lld = lazily_load_dataset("train")
        train_iter = make_dataset_iter(lld, fields, opt)
        train_stats = trainer.train(train_iter, epoch, report_func)
        print('Train perplexity: %g' % train_stats.ppl())
        print('Train accuracy: %g' % train_stats.accuracy())

        # 2. Validate on the validation set.
        valid_iter = make_dataset_iter(lazily_load_dataset("valid"),
                                       fields, opt,
                                       is_train=False)
        valid_stats = trainer.validate(valid_iter)
        print('Validation perplexity: %g' % valid_stats.ppl())
        print('Validation accuracy: %g' % valid_stats.accuracy())

        # 3. Log to remote server.
        if opt.exp_host:
            train_stats.log("train", experiment, optim.lr)
            valid_stats.log("valid", experiment, optim.lr)
        if opt.tensorboard:
            train_stats.log_tensorboard("train", writer, optim.lr, epoch)
            valid_stats.log_tensorboard("valid", writer, optim.lr, epoch)

        # 4. Update the learning rate.
        trainer.epoch_step(valid_stats.ppl(), epoch)

        # 5. Drop a checkpoint if needed.
        if epoch >= opt.start_checkpoint_at:
            trainer.drop_checkpoint(model_opt, epoch, fields, valid_stats)
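Examples 2-4 stream the training data through lazily_load_dataset, another train.py helper. A sketch under the assumption that preprocessing saved the corpus as `opt.data + '.train.[N].pt'` shards (with a single-file fallback):

import glob
import torch

def lazily_load_dataset(corpus_type):
    # Generator yielding dataset shards one at a time, so only a single
    # shard is held in memory; sketch of the legacy helper.
    assert corpus_type in ["train", "valid"]

    def lazy_dataset_loader(pt_file):
        dataset = torch.load(pt_file)
        print('Loading %s dataset from %s, number of examples: %d' %
              (corpus_type, pt_file, len(dataset)))
        return dataset

    # Shard files sorted by increasing index; otherwise a single .pt file.
    pts = sorted(glob.glob(opt.data + '.' + corpus_type + '.[0-9]*.pt'))
    if pts:
        for pt in pts:
            yield lazy_dataset_loader(pt)
    else:
        yield lazy_dataset_loader(opt.data + '.' + corpus_type + '.pt')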
Example 5: train_model
# Required import: import onmt [as alias]
# Or: from onmt import Trainer [as alias]
def train_model(model, train_dataset, valid_dataset,
                fields, optim, model_opt):
    train_iter = make_train_data_iter(train_dataset, opt)
    valid_iter = make_valid_data_iter(valid_dataset, opt)

    train_loss = make_loss_compute(model, fields["tgt"].vocab,
                                   train_dataset, opt)
    valid_loss = make_loss_compute(model, fields["tgt"].vocab,
                                   valid_dataset, opt)

    trunc_size = opt.truncated_decoder  # Badly named...
    shard_size = opt.max_generator_batches
    data_type = train_dataset.data_type

    trainer = onmt.Trainer(model, train_iter, valid_iter,
                           train_loss, valid_loss, optim,
                           trunc_size, shard_size, data_type)

    for epoch in range(opt.start_epoch, opt.epochs + 1):
        print('')

        # 1. Train for one epoch on the training set.
        train_stats = trainer.train(epoch, report_func)
        print('Train perplexity: %g' % train_stats.ppl())
        print('Train accuracy: %g' % train_stats.accuracy())

        # 2. Validate on the validation set.
        valid_stats = trainer.validate()
        print('Validation perplexity: %g' % valid_stats.ppl())
        print('Validation accuracy: %g' % valid_stats.accuracy())

        # 3. Log to remote server.
        if opt.exp_host:
            train_stats.log("train", experiment, optim.lr)
            valid_stats.log("valid", experiment, optim.lr)

        # 4. Update the learning rate.
        trainer.epoch_step(valid_stats.ppl(), epoch)

        # 5. Drop a checkpoint if needed.
        if epoch >= opt.start_checkpoint_at:
            trainer.drop_checkpoint(model_opt, epoch, fields, valid_stats)
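Note that Example 5 targets an older Trainer API that binds the data iterators at construction time, which is why trainer.train(epoch, report_func) and trainer.validate() take no iterator arguments. The iterator helpers it references might look like the sketch below; onmt.IO.OrderedIterator and its keyword arguments are assumptions based on the legacy codebase:

import onmt

def make_train_data_iter(train_data, opt):
    # Simple ordered iterator over the training set (sketch).
    return onmt.IO.OrderedIterator(
        dataset=train_data, batch_size=opt.batch_size,
        device=opt.gpuid[0] if opt.gpuid else -1,
        repeat=False)

def make_valid_data_iter(valid_data, opt):
    # Validation: no shuffling, deterministic sorted batches (sketch).
    return onmt.IO.OrderedIterator(
        dataset=valid_data, batch_size=opt.batch_size,
        device=opt.gpuid[0] if opt.gpuid else -1,
        train=False, sort=True)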