本文整理汇总了Python中utils.utils.save_checkpoint函数的典型用法代码示例。如果您正苦于以下问题：Python utils.save_checkpoint函数的具体用法？Python utils.save_checkpoint怎么用？Python utils.save_checkpoint使用的例子？那么，这里精选的代码示例或许可以为您提供帮助。您也可以进一步了解该函数所在模块utils.utils的用法示例。
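下文共展示了utils.save_checkpoint函数的3个代码示例，这些例子默认按受欢迎程度排序。注意：各示例中utils.save_checkpoint的调用签名并不相同——示例1以(state, is_best, dirpath=..., filename=...)调用，示例3以(state, 保存路径)调用，示例2则通过self.save_checkpoint间接调用（其签名与示例1中的方法一致）。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。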
示例1: save_checkpoint
# 需要导入模块: from utils import utils [as 别名]
# 或者: from utils.utils import save_checkpoint [as 别名]
def save_checkpoint(self, epoch, best_prec, is_best):
    """Write the training checkpoint and, for the best model, a weights-only file.

    Args:
        epoch: epoch index stored in the checkpoint.
        best_prec: best validation value reached so far, stored in the checkpoint.
        is_best: truthy when this checkpoint is the best so far; forwarded to
            ``utils.save_checkpoint`` and also triggers writing ``weight_best.pkl``.
    """
    checkpoint = dict(
        epoch=epoch,
        best_prec=best_prec,
        state_dict=self.model.state_dict(),
        optim=self.optim.state_dict(),
    )
    utils.save_checkpoint(
        checkpoint,
        is_best,
        dirpath=self.dirpath,
        filename='model_checkpoint.pkl',
    )
    if not is_best:
        return
    # Extra copy holding only the model weights (no epoch/optimizer entries).
    best_weights_path = os.path.join(self.dirpath, 'weight_best.pkl')
    torch.save({'state_dict': self.model.state_dict()}, best_weights_path)
示例2: start
# 需要导入模块: from utils import utils [as 别名]
# 或者: from utils.utils import save_checkpoint [as 别名]
def start(self):
    """Run the training loop, or a single validation pass when ``args.mode == 'test'``.

    For each epoch from ``self.epoch`` up to ``args.epochs - 1`` this:
      * adjusts the learning rate and the per-level loss weights and logs them;
      * trains one epoch and records the mean loss / EPE / D1;
      * every ``args.val_freq`` epochs and on the last epoch, validates,
        saves a checkpoint (flagging it as best when validation D1 improves)
        and persists the metric history to ``<dirpath>/loss.pkl``;
      * redraws train/val curves into a PNG and logs elapsed time and an ETA.

    If ``loss.pkl`` already exists, its history is reloaded so the curves
    continue from a previous run. The matplotlib figure is closed on exit.
    """
    args = self.args
    if args.mode == 'test':
        self.validate()
        return

    # Training metrics are appended every epoch; validation metrics only on
    # the epochs recorded in ``epochs_val``.
    losses, EPEs, D1s, epochs_val, losses_val, EPEs_val, D1s_val = [], [], [], [], [], [], []
    path_val = os.path.join(self.dirpath, "loss.pkl")
    if os.path.exists(path_val):
        # Resume the metric history written by an earlier run.
        losses, EPEs, D1s, epochs_val, losses_val, EPEs_val, D1s_val = torch.load(path_val)

    # Build the figure and its three panels once and draw through explicit
    # handles, so plotting does not depend on whichever figure/axes pyplot
    # considers "current" (train()/validate() could change that state).
    fig, (ax_loss, ax_epe, ax_d1) = plt.subplots(1, 3, figsize=(18, 5))
    # The history lists are only appended to below (never rebound), so these
    # references stay valid for the whole loop.
    panels = (
        (ax_loss, "Loss", losses, losses_val),
        (ax_epe, "EPE", EPEs, EPEs_val),
        (ax_d1, "D1", D1s, D1s_val),
    )
    fig_path = "check_%s_%s_%s_%s.png" % (args.mode, args.dataset, args.net, args.loss_name)

    time_start = time.time()
    epoch0 = self.epoch
    try:
        for epoch in range(epoch0, args.epochs):
            self.epoch = epoch
            # Project-defined LR schedule (self.lr_adjust) and loss-level weighting.
            self.lr_adjust(self.optim, args.lr_epoch0, args.lr_stride, args.lr, epoch)
            self.lossfun.Weight_Adjust_levels(epoch)
            msg = 'lr: %.6f | weight of levels: %s' % (self.optim.param_groups[0]['lr'], str(self.lossfun.weight_levels))
            logging.info(msg)

            # Train for one epoch.
            mloss, mEPE, mD1 = self.train()
            losses.append(mloss)
            EPEs.append(mEPE)
            D1s.append(mD1)

            if (epoch % args.val_freq == 0) or (epoch == args.epochs - 1):
                # Evaluate on the validation set.
                mloss_val, mEPE_val, mD1_val = self.validate()
                epochs_val.append(epoch)
                losses_val.append(mloss_val)
                EPEs_val.append(mEPE_val)
                D1s_val.append(mD1_val)
                # Lower D1 is better: checkpoint every validated epoch and
                # flag it when it improves on the best value so far.
                is_best = mD1_val < self.best_prec
                self.best_prec = min(mD1_val, self.best_prec)
                self.save_checkpoint(epoch, self.best_prec, is_best)
                torch.save([losses, EPEs, D1s, epochs_val, losses_val, EPEs_val, D1s_val], path_val)

            # Redraw train (per-epoch index) and val (at epochs_val) curves.
            for ax, ylabel, train_vals, val_vals in panels:
                ax.cla()
                ax.set_xlabel("epoch")
                ax.set_ylabel(ylabel)
                ax.plot(np.array(train_vals), label='train')
                ax.plot(np.array(epochs_val), np.array(val_vals), label='val')
                ax.legend()
            fig.savefig(fig_path)

            # Elapsed hours, and a linear extrapolation over all epochs of this run.
            time_curr = (time.time() - time_start) / 3600.0
            time_all = time_curr * (args.epochs - epoch0) / (epoch + 1 - epoch0)
            msg = 'Progress: %.2f | %.2f (hour)\n' % (time_curr, time_all)
            logging.info(msg)
    finally:
        # Release the figure even if training raises.
        plt.close(fig)
示例3: run
# 需要导入模块: from utils import utils [as 别名]
# 或者: from utils.utils import save_checkpoint [as 别名]
def run(self, epochs, train_loader, val_loader, test_loader, log_interval):
    """Train with ignite, keep the best checkpoint by validation score, then test it.

    Args:
        epochs: ``max_epochs`` passed to the trainer.
        train_loader: loader iterated by the trainer.
        val_loader: loader evaluated after every epoch.
        test_loader: loader evaluated once training completes.
        log_interval: print and log the training loss every ``log_interval``
            iterations within an epoch.

    Whenever the first metric in ``self.metrics`` exceeds ``self.best_score``,
    a checkpoint is saved via ``utils.save_checkpoint(..., self.model_id)``.
    On completion the checkpoint at ``self.model_id`` is reloaded and evaluated
    on ``test_loader``; if no checkpoint file exists there, the final in-memory
    weights are tested instead. ``self.writer`` is closed on exit, including
    when training raises.
    """
    cuda = self.device != -1
    try:
        with torch.cuda.device(self.device):
            trainer = create_supervised_trainer(self.model, self.optimizer, self.loss_fn, cuda=cuda)
            evaluator = create_supervised_evaluator(
                self.model,
                metrics=self.metrics,
                y_to_score=self.y_to_score,
                pred_to_score=self.pred_to_score,
                cuda=cuda,
            )

            def evaluate(loader, title, tag, epoch):
                """Evaluate ``loader``; print and log every metric; return values in ``self.metrics`` key order."""
                evaluator.run(loader)
                state_metrics = evaluator.state.metrics
                metric_keys = list(self.metrics.keys())
                metric_vals = [state_metrics[k] for k in metric_keys]
                format_str = title + ' - Epoch: {} ' + ' '.join([k + ': {:.4f}' for k in metric_keys])
                print(format_str.format(*([epoch] + metric_vals)))
                for k, v in zip(metric_keys, metric_vals):
                    self.writer.add_scalar(f'{tag}/{k}', v, epoch)
                return metric_vals

            @trainer.on(Events.ITERATION_COMPLETED)
            def log_training_loss(engine):
                # 1-based iteration index within the current epoch.
                iteration = (engine.state.iteration - 1) % len(train_loader) + 1
                if iteration % log_interval == 0:
                    print("Epoch[{}] Iteration[{}/{}] Loss: {:.2f}"
                          "".format(engine.state.epoch, iteration, len(train_loader), engine.state.output))
                    self.writer.add_scalar("train/loss", engine.state.output, engine.state.iteration)

            @trainer.on(Events.EPOCH_COMPLETED)
            def log_validation_results(engine):
                state_metric_vals = evaluate(val_loader, 'Validation Results', 'dev', engine.state.epoch)
                # Model selection uses only the first metric; higher is better.
                if state_metric_vals[0] > self.best_score:
                    state_dict = {
                        'epoch': engine.state.epoch,
                        'state_dict': self.model.state_dict(),
                        'optimizer': self.optimizer.state_dict(),
                        'eval_metric': state_metric_vals[0],
                    }
                    utils.save_checkpoint(state_dict, self.model_id)
                    self.best_score = state_metric_vals[0]

            @trainer.on(Events.COMPLETED)
            def log_test_results(engine):
                try:
                    checkpoint = torch.load(self.model_id)
                except FileNotFoundError:
                    # Nothing saved at self.model_id (e.g. no validation epoch
                    # beat self.best_score): test the final weights instead of
                    # crashing at the very end of training.
                    print('No checkpoint found at {}; testing the final model weights.'.format(self.model_id))
                else:
                    self.model.load_state_dict(checkpoint['state_dict'])
                evaluate(test_loader, 'Test Results', 'test', engine.state.epoch)

            trainer.run(train_loader, max_epochs=epochs)
    finally:
        self.writer.close()