本文整理匯總了Python中onmt.Trainer方法的典型用法代碼示例。如果您正苦於以下問題:Python onmt.Trainer方法的具體用法?Python onmt.Trainer怎麼用?Python onmt.Trainer使用的例子?那麽, 這裏精選的方法代碼示例或許可以為您提供幫助。您也可以進一步了解該方法所在模塊onmt的用法示例。
在下文中一共展示了onmt.Trainer方法的5個代碼示例,這些例子默認根據受歡迎程度排序。您可以為喜歡或者感覺有用的代碼點讚,您的評價將有助於系統推薦出更棒的Python代碼示例。
示例1: evaluate_model
# 需要導入模塊: import onmt [as 別名]
# 或者: from onmt import Trainer [as 別名]
def evaluate_model(model, test_loader, copy_attn=False, copy_attn_force=None):
    """Accumulate loss statistics for ``model`` over every batch in ``test_loader``.

    Copied from the validation method in ``onmt.Trainer``. The model is put
    in evaluation mode for the pass and switched back to training mode
    afterwards. The switch back happens in a ``finally`` clause, so an
    exception raised during evaluation cannot leave the model stuck in eval
    mode for whatever training code runs next.

    Args:
        model: the model under evaluation; called as
            ``model(src, tgt, src_lengths)`` and unpacked into three values,
            of which the first two (outputs, attentions) are used.
        test_loader: iterable of batches; ``test_loader.dataset`` and
            ``test_loader.dataset.fields["tgt"].vocab`` are passed to
            ``make_loss_compute``.
        copy_attn: forwarded to ``make_loss_compute``.
        copy_attn_force: forwarded to ``make_loss_compute``.

    Returns:
        The ``onmt.Statistics`` object updated with every batch's statistics.
    """
    # Set model in validating mode.
    model.eval()
    try:
        stats = onmt.Statistics()
        valid_loss = make_loss_compute(
            model, test_loader.dataset.fields["tgt"].vocab, test_loader.dataset,
            copy_attn=copy_attn, copy_attn_force=copy_attn_force)
        for batch in test_loader:
            # batch.src unpacks to a pair; only the second item (the source
            # lengths) is needed here.
            _, src_lengths = batch.src
            src = onmt.IO.make_features(batch, 'src')
            tgt = onmt.IO.make_features(batch, 'tgt')
            # F-prop through the model; the third return value is unused.
            outputs, attns, _ = model(src, tgt, src_lengths)
            # Compute loss over the range (0, batch.tgt.size(0)), i.e. the
            # whole of dimension 0 of the target (presumably time -- confirm).
            gen_state = onmt.Loss.make_gen_state(
                outputs, batch, attns, (0, batch.tgt.size(0)))
            _, batch_stats = valid_loss(batch, **gen_state)
            # Update statistics.
            stats.update(batch_stats)
    finally:
        # Set model back to training mode, even if evaluation failed.
        model.train()
    return stats
示例2: train_model
# 需要導入模塊: import onmt [as 別名]
# 或者: from onmt import Trainer [as 別名]
def train_model(model, fields, optim, data_type, model_opt):
    """Run the training loop for epochs ``opt.start_epoch`` .. ``opt.epochs``.

    Each epoch:
      1. trains on a freshly created iterator over the "train" dataset,
      2. validates on the "valid" dataset,
      3. optionally logs train/valid statistics to a remote experiment server
         (``opt.exp_host``) and/or TensorBoard (``opt.tensorboard``),
      4. calls ``trainer.epoch_step`` with the validation perplexity,
      5. drops a checkpoint once ``epoch >= opt.start_checkpoint_at``.

    Relies on module-level names defined elsewhere in the script: ``opt``,
    ``report_func``, ``experiment``, ``writer``, ``make_loss_compute``,
    ``make_dataset_iter`` and ``lazily_load_dataset``.

    Args:
        model: the model to train.
        fields: field mapping; ``fields["tgt"].vocab`` is handed to the loss.
        optim: optimizer wrapper; its ``lr`` attribute is logged each epoch.
        data_type: forwarded to ``onmt.Trainer``.
        model_opt: forwarded to ``trainer.drop_checkpoint``.
    """
    train_loss = make_loss_compute(model, fields["tgt"].vocab, opt)
    # NOTE(review): ``train=False`` presumably selects an evaluation-time
    # loss configuration -- confirm against make_loss_compute.
    valid_loss = make_loss_compute(model, fields["tgt"].vocab, opt,
                                   train=False)

    trunc_size = opt.truncated_decoder  # Badly named...
    shard_size = opt.max_generator_batches
    norm_method = opt.normalization
    grad_accum_count = opt.accum_count

    trainer = onmt.Trainer(model, train_loss, valid_loss, optim,
                           trunc_size, shard_size, data_type,
                           norm_method, grad_accum_count)

    print('\nStart training...')
    print(' * number of epochs: %d, starting from Epoch %d' %
          (opt.epochs + 1 - opt.start_epoch, opt.start_epoch))
    print(' * batch size: %d' % opt.batch_size)

    for epoch in range(opt.start_epoch, opt.epochs + 1):
        print('')

        # 1. Train for one epoch on the training set.
        train_iter = make_dataset_iter(lazily_load_dataset("train"),
                                       fields, opt)
        train_stats = trainer.train(train_iter, epoch, report_func)
        print('Train perplexity: %g' % train_stats.ppl())
        print('Train accuracy: %g' % train_stats.accuracy())

        # 2. Validate on the validation set.
        valid_iter = make_dataset_iter(lazily_load_dataset("valid"),
                                       fields, opt,
                                       is_train=False)
        valid_stats = trainer.validate(valid_iter)
        print('Validation perplexity: %g' % valid_stats.ppl())
        print('Validation accuracy: %g' % valid_stats.accuracy())

        # 3. Log to remote server and/or TensorBoard.
        if opt.exp_host:
            train_stats.log("train", experiment, optim.lr)
            valid_stats.log("valid", experiment, optim.lr)
        if opt.tensorboard:
            train_stats.log_tensorboard("train", writer, optim.lr, epoch)
            # The "valid" tag must carry the validation statistics; the
            # original logged train_stats under both tags.
            valid_stats.log_tensorboard("valid", writer, optim.lr, epoch)

        # 4. Update the learning rate (driven by validation perplexity).
        trainer.epoch_step(valid_stats.ppl(), epoch)

        # 5. Drop a checkpoint if needed.
        if epoch >= opt.start_checkpoint_at:
            trainer.drop_checkpoint(model_opt, epoch, fields, valid_stats)
示例3: train_model
# 需要導入模塊: import onmt [as 別名]
# 或者: from onmt import Trainer [as 別名]
def train_model(model, fields, optim, data_type, model_opt):
    """Train, validate, log, step the LR and checkpoint, once per epoch.

    Epochs run from ``opt.start_epoch`` through ``opt.epochs`` inclusive.
    Module-level names this depends on (defined elsewhere): ``opt``,
    ``report_func``, ``experiment``, ``make_loss_compute``,
    ``make_dataset_iter`` and ``lazily_load_dataset``.

    Args:
        model: the model to train.
        fields: field mapping; ``fields["tgt"].vocab`` is given to the losses.
        optim: optimizer wrapper whose ``lr`` is reported to the experiment.
        data_type: forwarded to ``onmt.Trainer``.
        model_opt: forwarded to ``trainer.drop_checkpoint``.
    """
    tgt_vocab = fields["tgt"].vocab
    train_loss = make_loss_compute(model, tgt_vocab, opt)
    # NOTE(review): unlike sibling variants of this function, the validation
    # loss is built with the same arguments as the training loss (no
    # ``train=False``) -- confirm whether that is intended for this version.
    valid_loss = make_loss_compute(model, tgt_vocab, opt)

    trainer = onmt.Trainer(
        model, train_loss, valid_loss, optim,
        opt.truncated_decoder,      # truncation size (option is badly named)
        opt.max_generator_batches,  # shard size
        data_type,
        opt.normalization,
        opt.accum_count,
    )

    last_epoch = opt.epochs
    for epoch in range(opt.start_epoch, last_epoch + 1):
        print('')

        # Training pass over a freshly loaded training dataset.
        train_stats = trainer.train(
            make_dataset_iter(lazily_load_dataset("train"), fields, opt),
            epoch,
            report_func,
        )
        print('Train perplexity: %g' % train_stats.ppl())
        print('Train accuracy: %g' % train_stats.accuracy())

        # Validation pass.
        valid_stats = trainer.validate(
            make_dataset_iter(lazily_load_dataset("valid"), fields, opt,
                              is_train=False)
        )
        print('Validation perplexity: %g' % valid_stats.ppl())
        print('Validation accuracy: %g' % valid_stats.accuracy())

        # Optional reporting to the remote experiment server.
        if opt.exp_host:
            for tag, epoch_stats in (("train", train_stats),
                                     ("valid", valid_stats)):
                epoch_stats.log(tag, experiment, optim.lr)

        # Learning-rate update driven by validation perplexity.
        trainer.epoch_step(valid_stats.ppl(), epoch)

        # Checkpointing only starts at opt.start_checkpoint_at.
        if epoch < opt.start_checkpoint_at:
            continue
        trainer.drop_checkpoint(model_opt, epoch, fields, valid_stats)
示例4: train_model
# 需要導入模塊: import onmt [as 別名]
# 或者: from onmt import Trainer [as 別名]
def train_model(model, fields, optim, data_type, model_opt):
    """Run the training loop for epochs ``opt.start_epoch`` .. ``opt.epochs``.

    Each epoch:
      1. trains on a freshly created iterator over the "train" dataset,
      2. validates on the "valid" dataset,
      3. optionally logs train/valid statistics to a remote experiment server
         (``opt.exp_host``) and/or TensorBoard (``opt.tensorboard``),
      4. calls ``trainer.epoch_step`` with the validation perplexity,
      5. drops a checkpoint once ``epoch >= opt.start_checkpoint_at``.

    Relies on module-level names defined elsewhere in the script: ``opt``,
    ``report_func``, ``experiment``, ``writer``, ``make_loss_compute``,
    ``make_dataset_iter`` and ``lazily_load_dataset``.

    Args:
        model: the model to train.
        fields: field mapping; ``fields["tgt"].vocab`` is handed to the loss.
        optim: optimizer wrapper; its ``lr`` attribute is logged each epoch.
        data_type: forwarded to ``onmt.Trainer``.
        model_opt: forwarded to ``trainer.drop_checkpoint``.
    """
    train_loss = make_loss_compute(model, fields["tgt"].vocab, opt)
    # NOTE(review): ``train=False`` presumably selects an evaluation-time
    # loss configuration -- confirm against make_loss_compute.
    valid_loss = make_loss_compute(model, fields["tgt"].vocab, opt,
                                   train=False)

    trunc_size = opt.truncated_decoder  # Badly named...
    shard_size = opt.max_generator_batches
    norm_method = opt.normalization
    grad_accum_count = opt.accum_count

    trainer = onmt.Trainer(model, train_loss, valid_loss, optim,
                           trunc_size, shard_size, data_type,
                           norm_method, grad_accum_count)

    print('\nStart training...')
    print(' * number of epochs: %d, starting from Epoch %d' %
          (opt.epochs + 1 - opt.start_epoch, opt.start_epoch))
    print(' * batch size: %d' % opt.batch_size)

    for epoch in range(opt.start_epoch, opt.epochs + 1):
        print('')

        # 1. Train for one epoch on the training set.
        lld = lazily_load_dataset("train")
        train_iter = make_dataset_iter(lld, fields, opt)
        train_stats = trainer.train(train_iter, epoch, report_func)
        print('Train perplexity: %g' % train_stats.ppl())
        print('Train accuracy: %g' % train_stats.accuracy())

        # 2. Validate on the validation set.
        valid_iter = make_dataset_iter(lazily_load_dataset("valid"),
                                       fields, opt,
                                       is_train=False)
        valid_stats = trainer.validate(valid_iter)
        print('Validation perplexity: %g' % valid_stats.ppl())
        print('Validation accuracy: %g' % valid_stats.accuracy())

        # 3. Log to remote server and/or TensorBoard.
        if opt.exp_host:
            train_stats.log("train", experiment, optim.lr)
            valid_stats.log("valid", experiment, optim.lr)
        if opt.tensorboard:
            train_stats.log_tensorboard("train", writer, optim.lr, epoch)
            # The "valid" tag must carry the validation statistics; the
            # original logged train_stats under both tags.
            valid_stats.log_tensorboard("valid", writer, optim.lr, epoch)

        # 4. Update the learning rate (driven by validation perplexity).
        trainer.epoch_step(valid_stats.ppl(), epoch)

        # 5. Drop a checkpoint if needed.
        if epoch >= opt.start_checkpoint_at:
            trainer.drop_checkpoint(model_opt, epoch, fields, valid_stats)
示例5: train_model
# 需要導入模塊: import onmt [as 別名]
# 或者: from onmt import Trainer [as 別名]
def train_model(model, train_dataset, valid_dataset,
                fields, optim, model_opt):
    """Epoch loop for the older ``onmt.Trainer`` API.

    In this API the train/valid iterators are handed to ``onmt.Trainer`` at
    construction time, so ``trainer.train(epoch, report_func)`` and
    ``trainer.validate()`` take no iterator argument. Epochs run from
    ``opt.start_epoch`` through ``opt.epochs`` inclusive.

    Module-level names this depends on (defined elsewhere): ``opt``,
    ``report_func``, ``experiment``, ``make_train_data_iter``,
    ``make_valid_data_iter`` and ``make_loss_compute``.

    Args:
        model: the model to train.
        train_dataset: training dataset; its ``data_type`` is forwarded to
            the trainer.
        valid_dataset: validation dataset.
        fields: field mapping; ``fields["tgt"].vocab`` is given to the losses.
        optim: optimizer wrapper whose ``lr`` is reported to the experiment.
        model_opt: forwarded to ``trainer.drop_checkpoint``.
    """
    iterators = (make_train_data_iter(train_dataset, opt),
                 make_valid_data_iter(valid_dataset, opt))

    tgt_vocab = fields["tgt"].vocab
    losses = (make_loss_compute(model, tgt_vocab, train_dataset, opt),
              make_loss_compute(model, tgt_vocab, valid_dataset, opt))

    trainer = onmt.Trainer(
        model, *iterators,
        *losses, optim,
        opt.truncated_decoder,      # truncation size (option is badly named)
        opt.max_generator_batches,  # shard size
        train_dataset.data_type,
    )

    last_epoch = opt.epochs
    for epoch in range(opt.start_epoch, last_epoch + 1):
        print('')

        # Training pass (iterator was bound when the trainer was built).
        train_stats = trainer.train(epoch, report_func)
        print('Train perplexity: %g' % train_stats.ppl())
        print('Train accuracy: %g' % train_stats.accuracy())

        # Validation pass.
        valid_stats = trainer.validate()
        print('Validation perplexity: %g' % valid_stats.ppl())
        print('Validation accuracy: %g' % valid_stats.accuracy())

        # Optional reporting to the remote experiment server.
        if opt.exp_host:
            for tag, epoch_stats in (("train", train_stats),
                                     ("valid", valid_stats)):
                epoch_stats.log(tag, experiment, optim.lr)

        # Learning-rate update driven by validation perplexity.
        trainer.epoch_step(valid_stats.ppl(), epoch)

        # Checkpointing only starts at opt.start_checkpoint_at.
        if epoch < opt.start_checkpoint_at:
            continue
        trainer.drop_checkpoint(model_opt, epoch, fields, valid_stats)