本文整理汇总了Python中misc.utils.clip_gradient方法的典型用法代码示例。如果您正苦于以下问题:Python utils.clip_gradient方法的具体用法?Python utils.clip_gradient怎么用?Python utils.clip_gradient使用的例子?那么恭喜您, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在类misc.utils
的用法示例。
在下文中一共展示了utils.clip_gradient方法的1个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: train
# 需要导入模块: from misc import utils [as 别名]
# 或者: from misc.utils import clip_gradient [as 别名]
def train(loader, model, crit, optimizer, lr_scheduler, opt, rl_crit=None):
model.train()
model = nn.DataParallel(model)
for epoch in range(opt["epochs"]):
lr_scheduler.step()
iteration = 0
# If start self crit training
if opt["self_crit_after"] != -1 and epoch >= opt["self_crit_after"]:
sc_flag = True
init_cider_scorer(opt["cached_tokens"])
else:
sc_flag = False
for data in loader:
torch.cuda.synchronize()
fc_feats = Variable(data['fc_feats']).cuda()
labels = Variable(data['labels']).long().cuda()
masks = Variable(data['masks']).cuda()
optimizer.zero_grad()
if not sc_flag:
seq_probs, _ = model(fc_feats, labels, 'train')
loss = crit(seq_probs, labels[:, 1:], masks[:, 1:])
else:
seq_probs, seq_preds = model(
fc_feats, mode='inference', opt=opt)
reward = get_self_critical_reward(model, fc_feats, data,
seq_preds)
print(reward.shape)
loss = rl_crit(seq_probs, seq_preds,
Variable(
torch.from_numpy(reward).float().cuda()))
loss.backward()
utils.clip_gradient(optimizer, opt["grad_clip"])
optimizer.step()
train_loss = loss.data[0]
torch.cuda.synchronize()
iteration += 1
if not sc_flag:
print("iter %d (epoch %d), train_loss = %.6f" %
(iteration, epoch, train_loss))
else:
print("iter %d (epoch %d), avg_reward = %.6f" %
(iteration, epoch, np.mean(reward[:, 0])))
if epoch != 0 and epoch % opt["save_checkpoint_every"] == 0:
model_path = os.path.join(opt["checkpoint_path"],
'model_%d.pth' % (epoch))
model_info_path = os.path.join(opt["checkpoint_path"],
'model_score.txt')
torch.save(model.state_dict(), model_path)
print("model saved to %s" % (model_path))
with open(model_info_path, 'a') as f:
f.write("model_%d, loss: %.6f\n" % (epoch, train_loss))