

Python onmt.Trainer Method Code Examples

This article collects typical usage examples of the onmt.Trainer method in Python. If you are wondering what onmt.Trainer does, how to use it, or what it looks like in practice, the curated code examples below may help. You can also explore further usage examples from the enclosing onmt module.


The following shows 5 code examples of the onmt.Trainer method, sorted by popularity by default.
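Before the individual examples, here is a condensed sketch of the epoch-based pattern they all share, assuming the legacy OpenNMT-py API that these projects target. Note that make_loss_compute, make_dataset_iter, lazily_load_dataset, report_func and the opt namespace are project-specific helpers (not part of onmt itself), so treat this as an outline rather than a drop-in script.

import onmt

def run_training(model, fields, optim, data_type, model_opt, opt):
    # Build separate loss computations for training and validation.
    train_loss = make_loss_compute(model, fields["tgt"].vocab, opt)
    valid_loss = make_loss_compute(model, fields["tgt"].vocab, opt, train=False)

    trainer = onmt.Trainer(model, train_loss, valid_loss, optim,
                           opt.truncated_decoder, opt.max_generator_batches,
                           data_type, opt.normalization, opt.accum_count)

    for epoch in range(opt.start_epoch, opt.epochs + 1):
        # One pass over the training data, then one over the validation data.
        train_iter = make_dataset_iter(lazily_load_dataset("train"), fields, opt)
        train_stats = trainer.train(train_iter, epoch, report_func)
        valid_iter = make_dataset_iter(lazily_load_dataset("valid"), fields, opt,
                                       is_train=False)
        valid_stats = trainer.validate(valid_iter)

        # Learning-rate schedule and checkpointing, as in the examples below.
        trainer.epoch_step(valid_stats.ppl(), epoch)
        if epoch >= opt.start_checkpoint_at:
            trainer.drop_checkpoint(model_opt, epoch, fields, valid_stats)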

Example 1: evaluate_model

# Required module import: import onmt [as alias]
# Or: from onmt import Trainer [as alias]
def evaluate_model(model, test_loader, copy_attn=False, copy_attn_force=None):
    """Copied from the method in onmt.Trainer."""

    # Set model in validating mode.
    model.eval()

    stats = onmt.Statistics()
    valid_loss = make_loss_compute(model, test_loader.dataset.fields["tgt"].vocab, test_loader.dataset,
                                   copy_attn=copy_attn, copy_attn_force=copy_attn_force)

    for batch in test_loader:
        _, src_lengths = batch.src
        src = onmt.IO.make_features(batch, 'src')
        tgt = onmt.IO.make_features(batch, 'tgt')

        # F-prop through the model.
        outputs, attns, _ = model(src, tgt, src_lengths)

        # Compute loss.
        gen_state = onmt.Loss.make_gen_state(
            outputs, batch, attns, (0, batch.tgt.size(0)))
        _, batch_stats = valid_loss(batch, **gen_state)

        # Update statistics.
        stats.update(batch_stats)

    # Set model back to training mode.
    model.train()
    return stats 
Developer ID: antspy, Project: quantized_distillation, Lines of code: 32, Source: model.py
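Note that evaluate_model puts the model in eval() mode but does not disable gradient tracking. On any reasonably recent PyTorch, wrapping the call in torch.no_grad() avoids building an autograd graph during validation and saves memory; a minimal sketch, reusing model and test_loader from the example above:

import torch

# Run the evaluation loop without recording operations for autograd.
with torch.no_grad():
    stats = evaluate_model(model, test_loader)
print('Validation perplexity: %g' % stats.ppl())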

Example 2: train_model

# Required module import: import onmt [as alias]
# Or: from onmt import Trainer [as alias]
def train_model(model, fields, optim, data_type, model_opt):
    train_loss = make_loss_compute(model, fields["tgt"].vocab, opt)
    valid_loss = make_loss_compute(model, fields["tgt"].vocab, opt,
                                   train=False)

    trunc_size = opt.truncated_decoder  # Badly named...
    shard_size = opt.max_generator_batches
    norm_method = opt.normalization
    grad_accum_count = opt.accum_count

    trainer = onmt.Trainer(model, train_loss, valid_loss, optim,
                           trunc_size, shard_size, data_type,
                           norm_method, grad_accum_count)

    print('\nStart training...')
    print(' * number of epochs: %d, starting from Epoch %d' %
          (opt.epochs + 1 - opt.start_epoch, opt.start_epoch))
    print(' * batch size: %d' % opt.batch_size)

    for epoch in range(opt.start_epoch, opt.epochs + 1):
        print('')

        # 1. Train for one epoch on the training set.
        train_iter = make_dataset_iter(lazily_load_dataset("train"),
                                       fields, opt)
        train_stats = trainer.train(train_iter, epoch, report_func)
        print('Train perplexity: %g' % train_stats.ppl())
        print('Train accuracy: %g' % train_stats.accuracy())

        # 2. Validate on the validation set.
        valid_iter = make_dataset_iter(lazily_load_dataset("valid"),
                                       fields, opt,
                                       is_train=False)
        valid_stats = trainer.validate(valid_iter)
        print('Validation perplexity: %g' % valid_stats.ppl())
        print('Validation accuracy: %g' % valid_stats.accuracy())

        # 3. Log to remote server.
        if opt.exp_host:
            train_stats.log("train", experiment, optim.lr)
            valid_stats.log("valid", experiment, optim.lr)
        if opt.tensorboard:
            train_stats.log_tensorboard("train", writer, optim.lr, epoch)
            train_stats.log_tensorboard("valid", writer, optim.lr, epoch)

        # 4. Update the learning rate
        trainer.epoch_step(valid_stats.ppl(), epoch)

        # 5. Drop a checkpoint if needed.
        if epoch >= opt.start_checkpoint_at:
            trainer.drop_checkpoint(model_opt, epoch, fields, valid_stats) 
Developer ID: xiadingZ, Project: video-caption-openNMT.pytorch, Lines of code: 53, Source: train.py
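These train_model variants read a module-level opt namespace (along with experiment, writer and report_func) rather than receiving it as a parameter. A minimal, hypothetical sketch of how such a namespace is typically built in these projects' train.py; the real scripts register many more options:

import argparse

# Hypothetical subset of the command-line options the functions above rely on.
parser = argparse.ArgumentParser(description='train.py')
parser.add_argument('-epochs', type=int, default=13)
parser.add_argument('-start_epoch', type=int, default=1)
parser.add_argument('-batch_size', type=int, default=64)
parser.add_argument('-truncated_decoder', type=int, default=0)
parser.add_argument('-max_generator_batches', type=int, default=32)
parser.add_argument('-normalization', type=str, default='sents')
parser.add_argument('-accum_count', type=int, default=1)
parser.add_argument('-start_checkpoint_at', type=int, default=0)
opt = parser.parse_args()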

Example 3: train_model

# Required module import: import onmt [as alias]
# Or: from onmt import Trainer [as alias]
def train_model(model, fields, optim, data_type, model_opt):

    train_loss = make_loss_compute(model, fields["tgt"].vocab, opt)
    valid_loss = make_loss_compute(model, fields["tgt"].vocab, opt)

    trunc_size = opt.truncated_decoder  # Badly named...
    shard_size = opt.max_generator_batches

    trainer = onmt.Trainer(model, train_loss, valid_loss, optim,
                           trunc_size, shard_size, data_type,
                           opt.normalization, opt.accum_count)

    for epoch in range(opt.start_epoch, opt.epochs + 1):
        print('')

        # 1. Train for one epoch on the training set.
        train_datasets = lazily_load_dataset("train")
        train_iter = make_dataset_iter(train_datasets, fields, opt)
        train_stats = trainer.train(train_iter, epoch, report_func)
        print('Train perplexity: %g' % train_stats.ppl())
        print('Train accuracy: %g' % train_stats.accuracy())

        # 2. Validate on the validation set.
        valid_iter = make_dataset_iter(lazily_load_dataset("valid"),
                                       fields, opt,
                                       is_train=False)

        valid_stats = trainer.validate(valid_iter)
        print('Validation perplexity: %g' % valid_stats.ppl())
        print('Validation accuracy: %g' % valid_stats.accuracy())

        # 3. Log to remote server.
        if opt.exp_host:
            train_stats.log("train", experiment, optim.lr)
            valid_stats.log("valid", experiment, optim.lr)

        # 4. Update the learning rate
        trainer.epoch_step(valid_stats.ppl(), epoch)

        # 5. Drop a checkpoint if needed.
        if epoch >= opt.start_checkpoint_at:
            trainer.drop_checkpoint(model_opt, epoch, fields, valid_stats) 
Developer ID: abaheti95, Project: DC-NeuralConversation, Lines of code: 44, Source: train.py
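Unlike Examples 2 and 4, this variant builds valid_loss without passing train=False; if make_loss_compute follows the usual OpenNMT-py convention, where the train flag controls label smoothing, the validation loss here is computed with the training criterion.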

Example 4: train_model

# Required module import: import onmt [as alias]
# Or: from onmt import Trainer [as alias]
def train_model(model, fields, optim, data_type, model_opt):
    train_loss = make_loss_compute(model, fields["tgt"].vocab, opt)
    valid_loss = make_loss_compute(model, fields["tgt"].vocab, opt,
                                   train=False)

    trunc_size = opt.truncated_decoder  # Badly named...
    shard_size = opt.max_generator_batches
    norm_method = opt.normalization
    grad_accum_count = opt.accum_count

    trainer = onmt.Trainer(model, train_loss, valid_loss, optim,
                           trunc_size, shard_size, data_type,
                           norm_method, grad_accum_count)

    print('\nStart training...')
    print(' * number of epochs: %d, starting from Epoch %d' %
          (opt.epochs + 1 - opt.start_epoch, opt.start_epoch))
    print(' * batch size: %d' % opt.batch_size)

    for epoch in range(opt.start_epoch, opt.epochs + 1):
        print('')

        # 1. Train for one epoch on the training set.
        lld = lazily_load_dataset("train")
        train_iter = make_dataset_iter(lld, fields, opt)
        train_stats = trainer.train(train_iter, epoch, report_func)
        print('Train perplexity: %g' % train_stats.ppl())
        print('Train accuracy: %g' % train_stats.accuracy())

        # 2. Validate on the validation set.
        valid_iter = make_dataset_iter(lazily_load_dataset("valid"),
                                       fields, opt,
                                       is_train=False)
        valid_stats = trainer.validate(valid_iter)
        print('Validation perplexity: %g' % valid_stats.ppl())
        print('Validation accuracy: %g' % valid_stats.accuracy())

        # 3. Log to remote server.
        if opt.exp_host:
            train_stats.log("train", experiment, optim.lr)
            valid_stats.log("valid", experiment, optim.lr)
        if opt.tensorboard:
            train_stats.log_tensorboard("train", writer, optim.lr, epoch)
            train_stats.log_tensorboard("valid", writer, optim.lr, epoch)

        # 4. Update the learning rate
        trainer.epoch_step(valid_stats.ppl(), epoch)

        # 5. Drop a checkpoint if needed.
        if epoch >= opt.start_checkpoint_at:
            trainer.drop_checkpoint(model_opt, epoch, fields, valid_stats) 
Developer ID: diegma, Project: graph-2-text, Lines of code: 53, Source: train.py

Example 5: train_model

# Required module import: import onmt [as alias]
# Or: from onmt import Trainer [as alias]
def train_model(model, train_dataset, valid_dataset,
                fields, optim, model_opt):

    train_iter = make_train_data_iter(train_dataset, opt)
    valid_iter = make_valid_data_iter(valid_dataset, opt)

    train_loss = make_loss_compute(model, fields["tgt"].vocab,
                                   train_dataset, opt)
    valid_loss = make_loss_compute(model, fields["tgt"].vocab,
                                   valid_dataset, opt)

    trunc_size = opt.truncated_decoder  # Badly named...
    shard_size = opt.max_generator_batches
    data_type = train_dataset.data_type

    trainer = onmt.Trainer(model, train_iter, valid_iter,
                           train_loss, valid_loss, optim,
                           trunc_size, shard_size, data_type)

    for epoch in range(opt.start_epoch, opt.epochs + 1):
        print('')

        # 1. Train for one epoch on the training set.
        train_stats = trainer.train(epoch, report_func)
        print('Train perplexity: %g' % train_stats.ppl())
        print('Train accuracy: %g' % train_stats.accuracy())

        # 2. Validate on the validation set.
        valid_stats = trainer.validate()
        print('Validation perplexity: %g' % valid_stats.ppl())
        print('Validation accuracy: %g' % valid_stats.accuracy())

        # 3. Log to remote server.
        if opt.exp_host:
            train_stats.log("train", experiment, optim.lr)
            valid_stats.log("valid", experiment, optim.lr)

        # 4. Update the learning rate
        trainer.epoch_step(valid_stats.ppl(), epoch)

        # 5. Drop a checkpoint if needed.
        if epoch >= opt.start_checkpoint_at:
            trainer.drop_checkpoint(model_opt, epoch, fields, valid_stats) 
Developer ID: moonlightlane, Project: QG-Net, Lines of code: 45, Source: train.py
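Example 5 reflects an older onmt.Trainer signature: the training and validation iterators are passed to the constructor, so trainer.train(epoch, report_func) and trainer.validate() take no iterator argument, whereas Examples 2 through 4 build a fresh iterator each epoch and pass it in.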


Note: The onmt.Trainer method examples in this article were compiled by 純淨天空 from open-source code and documentation platforms such as GitHub and MSDocs. The code snippets are taken from open-source projects contributed by their respective developers, and copyright of the source code remains with the original authors. For distribution and use, please refer to the corresponding project's license; do not reproduce without permission.