本文整理汇总了Python中run.train方法的典型用法代码示例。如果您正苦于以下问题：Python run.train方法的具体用法是什么？Python run.train怎么用？有哪些Python run.train的使用示例？那么，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块run的用法示例。
在下文中一共展示了run.train方法的2个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者觉得有用的代码点赞，您的评价将有助于系统推荐出更好的Python代码示例。
示例1: train_network
# 需要导入模块: import run [as 别名]
# 或者: from run import train [as 别名]
def train_network(start_epoch, epochs, optim, model, train_loader, val_loader, criterion, mixup, device, dtype,
                  batch_size, log_interval, csv_logger, save_path, claimed_acc1, claimed_acc5, best_test, local_rank,
                  child):
    """Run the train/validate loop for epochs ``start_epoch`` through ``epochs`` (inclusive).

    Each epoch does the following, in order:

    - trains the model with ``train``;
    - evaluates it with ``test``;
    - advances the optimizer wrapper's per-epoch schedule via ``optim.epoch_step()``;
    - appends the epoch's error/loss metrics to ``csv_logger``;
    - writes a checkpoint;
    - refreshes the progress plot.

    After the loop, the best top-1 validation accuracy is written to the log.

    Args:
        start_epoch: first epoch number to run.
        epochs: last epoch number to run (inclusive).
        optim: optimizer wrapper; must provide ``epoch_step()`` and ``state_dict()``.
        model: model being trained; its ``state_dict()`` is stored in every checkpoint.
        train_loader: data source forwarded to ``train``.
        val_loader: data source forwarded to ``test``.
        criterion: loss function forwarded to ``train`` and ``test``.
        mixup: forwarded to ``train``.
        device, dtype: forwarded to ``train`` and ``test``.
        batch_size, log_interval: forwarded to ``train``.
        csv_logger: logger providing ``write``, ``plot_progress`` and ``write_text``.
        save_path: forwarded to ``save_checkpoint`` as ``filepath``.
        claimed_acc1, claimed_acc5: reference accuracies forwarded to ``plot_progress``.
        best_test: best top-1 validation accuracy seen before this call, as a fraction.
            Errors are logged as ``1 - accuracy`` and the final message multiplies by 100.
        local_rank: forwarded to ``save_checkpoint``.
        child: forwarded to ``train`` and ``test``. When truthy, the loop uses plain ``range``
            instead of the ``trange`` progress bar.
    """
    epoch_range = range if child else trange
    for epoch in epoch_range(start_epoch, epochs + 1):
        train_loss, train_accuracy1, train_accuracy5 = train(model, train_loader, mixup, epoch, optim, criterion,
                                                             device, dtype, batch_size, log_interval, child)
        test_loss, test_accuracy1, test_accuracy5 = test(model, val_loader, criterion, device, dtype, child)
        optim.epoch_step()
        csv_logger.write({'epoch': epoch + 1, 'val_error1': 1 - test_accuracy1, 'val_error5': 1 - test_accuracy5,
                          'val_loss': test_loss, 'train_error1': 1 - train_accuracy1,
                          'train_error5': 1 - train_accuracy5, 'train_loss': train_loss})

        # Update the running best BEFORE checkpointing, so the stored 'best_prec1' already
        # includes this epoch. Otherwise an improving epoch saves the previous (stale) best,
        # and a run presumably resumed from that checkpoint compares later epochs against a
        # too-low threshold.
        is_best = test_accuracy1 > best_test
        if is_best:
            best_test = test_accuracy1
        save_checkpoint({'epoch': epoch + 1, 'state_dict': model.state_dict(), 'best_prec1': best_test,
                         'optimizer': optim.state_dict()}, is_best, filepath=save_path,
                        local_rank=local_rank)
        # TODO: save on the end of the cycle
        csv_logger.plot_progress(claimed_acc1=claimed_acc1, claimed_acc5=claimed_acc5)
    csv_logger.write_text('Best accuracy is {:.2f}% top-1'.format(best_test * 100.))
示例2: train_network
# 需要导入模块: import run [as 别名]
# 或者: from run import train [as 别名]
def train_network(start_epoch, epochs, scheduler, model, train_loader, val_loader, optimizer, criterion, device, dtype,
                  batch_size, log_interval, csv_logger, save_path, claimed_acc1, claimed_acc5, best_test):
    """Run the train/validate loop for epochs ``start_epoch`` through ``epochs`` (inclusive).

    Each epoch does the following, in order:

    - steps ``scheduler``, unless it is a ``CyclicLR``, which is handed to ``train`` instead;
    - trains the model with ``train``;
    - evaluates it with ``test``;
    - appends the epoch's error/loss metrics to ``csv_logger``;
    - writes a checkpoint;
    - refreshes the progress plot.

    After the loop, the best top-1 validation accuracy is written to the log.

    Args:
        start_epoch: first epoch number to run.
        epochs: last epoch number to run (inclusive).
        scheduler: learning-rate scheduler. It is stepped here once per epoch unless it is a
            ``CyclicLR``; it is always forwarded to ``train``.
        model: model being trained; its ``state_dict()`` is stored in every checkpoint.
        train_loader: data source forwarded to ``train``.
        val_loader: data source forwarded to ``test``.
        optimizer: forwarded to ``train``; its ``state_dict()`` is stored in every checkpoint.
        criterion: loss function forwarded to ``train`` and ``test``.
        device, dtype: forwarded to ``train`` and ``test``.
        batch_size, log_interval: forwarded to ``train``.
        csv_logger: logger providing ``write``, ``plot_progress`` and ``write_text``.
        save_path: forwarded to ``save_checkpoint`` as ``filepath``.
        claimed_acc1, claimed_acc5: reference accuracies forwarded to ``plot_progress``.
        best_test: best top-1 validation accuracy seen before this call, as a fraction.
            Errors are logged as ``1 - accuracy`` and the final message multiplies by 100.
    """
    for epoch in trange(start_epoch, epochs + 1):
        # A CyclicLR scheduler is not stepped per epoch here; it is passed to `train`
        # (presumably stepped per batch there — confirm).
        # NOTE(review): stepping before this epoch's optimizer updates follows the pre-1.1
        # PyTorch convention. On PyTorch >= 1.1 it skips the schedule's first value, so
        # confirm the targeted version before reordering.
        if not isinstance(scheduler, CyclicLR):
            scheduler.step()
        train_loss, train_accuracy1, train_accuracy5 = train(model, train_loader, epoch, optimizer, criterion, device,
                                                             dtype, batch_size, log_interval, scheduler)
        test_loss, test_accuracy1, test_accuracy5 = test(model, val_loader, criterion, device, dtype)
        csv_logger.write({'epoch': epoch + 1, 'val_error1': 1 - test_accuracy1, 'val_error5': 1 - test_accuracy5,
                          'val_loss': test_loss, 'train_error1': 1 - train_accuracy1,
                          'train_error5': 1 - train_accuracy5, 'train_loss': train_loss})

        # Update the running best BEFORE checkpointing, so the stored 'best_prec1' already
        # includes this epoch. Otherwise an improving epoch saves the previous (stale) best,
        # and a run presumably resumed from that checkpoint compares later epochs against a
        # too-low threshold.
        is_best = test_accuracy1 > best_test
        if is_best:
            best_test = test_accuracy1
        save_checkpoint({'epoch': epoch + 1, 'state_dict': model.state_dict(), 'best_prec1': best_test,
                         'optimizer': optimizer.state_dict()}, is_best, filepath=save_path)
        csv_logger.plot_progress(claimed_acc1=claimed_acc1, claimed_acc5=claimed_acc5)
    csv_logger.write_text('Best accuracy is {:.2f}% top-1'.format(best_test * 100.))