本文整理匯總了Python中run.train方法的典型用法代碼示例。如果您正苦於以下問題：Python run.train方法的具體用法？Python run.train怎麼用？Python run.train使用的例子？那麼，這裡精選的方法代碼示例或許可以為您提供幫助。您也可以進一步瞭解該方法所在模塊run的用法示例。
在下文中一共展示了run.train方法的2個代碼示例，這些例子默認根據受歡迎程度排序。您可以為喜歡或者感覺有用的代碼點讚，您的評價將有助於系統推薦出更棒的Python代碼示例。
示例1: train_network
# 需要導入模塊: import run [as 別名]
# 或者: from run import train [as 別名]
def train_network(start_epoch, epochs, optim, model, train_loader, val_loader, criterion, mixup, device, dtype,
                  batch_size, log_interval, csv_logger, save_path, claimed_acc1, claimed_acc5, best_test, local_rank,
                  child):
    """Run the train/validate loop for epochs ``start_epoch`` .. ``epochs`` (inclusive).

    Every epoch trains the model, evaluates it on ``val_loader``, advances the
    optimizer wrapper's per-epoch step, logs metrics to ``csv_logger``, saves a
    checkpoint and refreshes the progress plot. After the loop the best top-1
    accuracy is written to the logger as text.

    Args:
        start_epoch: first epoch index to run.
        epochs: last epoch index to run (inclusive, the loop uses ``epochs + 1``).
        optim: optimizer wrapper; must provide ``epoch_step()`` and ``state_dict()``.
        model: the network; must provide ``state_dict()``.
        train_loader: forwarded to ``train``.
        val_loader: forwarded to ``test``.
        criterion: loss, forwarded to ``train`` and ``test``.
        mixup: forwarded to ``train`` (presumably a mixup configuration -- confirm).
        device: forwarded to ``train`` and ``test``.
        dtype: forwarded to ``train`` and ``test``.
        batch_size: forwarded to ``train``.
        log_interval: forwarded to ``train``.
        csv_logger: metrics logger; must provide ``write``, ``plot_progress``
            and ``write_text``.
        save_path: forwarded to ``save_checkpoint`` as ``filepath``.
        claimed_acc1: reference top-1 accuracy passed to ``plot_progress``.
        claimed_acc5: reference top-5 accuracy passed to ``plot_progress``.
        best_test: best top-1 accuracy seen before this call, as a fraction
            (errors are logged as ``1 - accuracy`` and the final message
            multiplies by 100 to print a percentage).
        local_rank: forwarded to ``save_checkpoint`` (presumably the
            distributed-training rank -- confirm).
        child: when truthy, iterate with plain ``range`` instead of the
            ``trange`` progress bar; also forwarded to ``train`` and ``test``.
    """
    # Without a progress bar when ``child`` is set.
    epoch_iter = range if child else trange
    for epoch in epoch_iter(start_epoch, epochs + 1):
        train_loss, train_accuracy1, train_accuracy5 = train(model, train_loader, mixup, epoch, optim, criterion,
                                                             device, dtype, batch_size, log_interval, child)
        test_loss, test_accuracy1, test_accuracy5 = test(model, val_loader, criterion, device, dtype, child)
        optim.epoch_step()

        csv_logger.write({'epoch': epoch + 1, 'val_error1': 1 - test_accuracy1, 'val_error5': 1 - test_accuracy5,
                          'val_loss': test_loss, 'train_error1': 1 - train_accuracy1,
                          'train_error5': 1 - train_accuracy5, 'train_loss': train_loss})

        # Update the running best *before* checkpointing so the stored
        # 'best_prec1' includes this epoch's result. Previously a checkpoint
        # that was itself the new best (is_best=True) recorded the previous,
        # lower best accuracy.
        is_best = test_accuracy1 > best_test
        if is_best:
            best_test = test_accuracy1
        save_checkpoint({'epoch': epoch + 1, 'state_dict': model.state_dict(), 'best_prec1': best_test,
                         'optimizer': optim.state_dict()}, is_best, filepath=save_path,
                        local_rank=local_rank)
        # TODO: save on the end of the cycle
        csv_logger.plot_progress(claimed_acc1=claimed_acc1, claimed_acc5=claimed_acc5)

    csv_logger.write_text('Best accuracy is {:.2f}% top-1'.format(best_test * 100.))
示例2: train_network
# 需要導入模塊: import run [as 別名]
# 或者: from run import train [as 別名]
def train_network(start_epoch, epochs, scheduler, model, train_loader, val_loader, optimizer, criterion, device, dtype,
                  batch_size, log_interval, csv_logger, save_path, claimed_acc1, claimed_acc5, best_test):
    """Run the train/validate loop for epochs ``start_epoch`` .. ``epochs`` (inclusive).

    Every epoch (optionally) steps the LR scheduler, trains the model,
    evaluates it on ``val_loader``, logs metrics to ``csv_logger``, saves a
    checkpoint and refreshes the progress plot. After the loop the best top-1
    accuracy is written to the logger as text.

    Args:
        start_epoch: first epoch index to run.
        epochs: last epoch index to run (inclusive, the loop uses ``epochs + 1``).
        scheduler: learning-rate scheduler. Stepped once per epoch here unless
            it is a ``CyclicLR``; always forwarded to ``train``.
        model: the network; must provide ``state_dict()``.
        train_loader: forwarded to ``train``.
        val_loader: forwarded to ``test``.
        optimizer: forwarded to ``train``; must provide ``state_dict()``.
        criterion: loss, forwarded to ``train`` and ``test``.
        device: forwarded to ``train`` and ``test``.
        dtype: forwarded to ``train`` and ``test``.
        batch_size: forwarded to ``train``.
        log_interval: forwarded to ``train``.
        csv_logger: metrics logger; must provide ``write``, ``plot_progress``
            and ``write_text``.
        save_path: forwarded to ``save_checkpoint`` as ``filepath``.
        claimed_acc1: reference top-1 accuracy passed to ``plot_progress``.
        claimed_acc5: reference top-5 accuracy passed to ``plot_progress``.
        best_test: best top-1 accuracy seen before this call, as a fraction
            (errors are logged as ``1 - accuracy`` and the final message
            multiplies by 100 to print a percentage).
    """
    for epoch in trange(start_epoch, epochs + 1):
        # NOTE(review): the scheduler is stepped *before* the epoch's optimizer
        # steps (pre-PyTorch-1.1 ordering). A CyclicLR is skipped here and
        # handed to ``train`` instead -- presumably stepped per batch there;
        # confirm before changing the order.
        if not isinstance(scheduler, CyclicLR):
            scheduler.step()
        train_loss, train_accuracy1, train_accuracy5 = train(model, train_loader, epoch, optimizer, criterion, device,
                                                             dtype, batch_size, log_interval, scheduler)
        test_loss, test_accuracy1, test_accuracy5 = test(model, val_loader, criterion, device, dtype)

        csv_logger.write({'epoch': epoch + 1, 'val_error1': 1 - test_accuracy1, 'val_error5': 1 - test_accuracy5,
                          'val_loss': test_loss, 'train_error1': 1 - train_accuracy1,
                          'train_error5': 1 - train_accuracy5, 'train_loss': train_loss})

        # Update the running best *before* checkpointing so the stored
        # 'best_prec1' includes this epoch's result. Previously a checkpoint
        # that was itself the new best (is_best=True) recorded the previous,
        # lower best accuracy.
        is_best = test_accuracy1 > best_test
        if is_best:
            best_test = test_accuracy1
        save_checkpoint({'epoch': epoch + 1, 'state_dict': model.state_dict(), 'best_prec1': best_test,
                         'optimizer': optimizer.state_dict()}, is_best, filepath=save_path)
        csv_logger.plot_progress(claimed_acc1=claimed_acc1, claimed_acc5=claimed_acc5)

    csv_logger.write_text('Best accuracy is {:.2f}% top-1'.format(best_test * 100.))