This article collects typical usage examples of the Python method torch.autograd.Variable.volatile. If you have been wondering what exactly Variable.volatile does, how to use it, or what it looks like in real code, the curated examples below should help. You can also read more about the class it belongs to, torch.autograd.Variable.
Four code examples of Variable.volatile are shown below, sorted by popularity by default. You can upvote the examples you like or find useful; your feedback helps the system recommend better Python code samples.
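Before diving into the examples, here is a minimal background sketch of what the flag does. It applies to PyTorch releases before 0.4, where Variable and the volatile flag still existed (the flag was removed in 0.4 in favor of torch.no_grad()); the variable names below are illustrative only:

import torch
from torch.autograd import Variable

# Marking an input as volatile puts the whole downstream computation in
# "inference mode": no graph is recorded and no gradients can be computed.
x = Variable(torch.randn(3, 3), volatile=True)
y = x * 2 + 1                 # volatility propagates to every result
print(y.volatile)             # True
print(y.requires_grad)        # False -- calling y.backward() would raise an error

# On PyTorch 0.4 and later the equivalent is:
# with torch.no_grad():
#     y = some_model(x)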
Example 1: accumulate_gradient
# Required import: from torch.autograd import Variable [as alias]
# or: from torch.autograd.Variable import volatile [as alias]
def accumulate_gradient(self, batch_sz, states, actions, rewards,
                        next_states, mask):
    """Compute the temporal difference error.

       td_error = (r + gamma * max_a Q(s_, a)) - Q(s, a)
    """
    states = Variable(states)
    actions = Variable(actions)
    rewards = Variable(rewards)
    next_states = Variable(next_states, volatile=True)

    # Compute Q(s, a)
    q_values = self.policy(states)
    q_values = q_values.gather(1, actions.unsqueeze(1))

    # Compute Q(s_, a) from the target network
    if next_states.is_cuda:
        q_target_values = Variable(torch.zeros(batch_sz).cuda())
    else:
        q_target_values = Variable(torch.zeros(batch_sz))

    # Bootstrap only for non-terminal states
    q_target_values[mask] = self.target_policy(next_states).max(1)[0][mask]
    q_target_values.volatile = False  # so the flag doesn't propagate into the Huber loss
    expected_q_values = (q_target_values * self.gamma) + rewards

    # Compute Huber loss
    loss = F.smooth_l1_loss(q_values, expected_q_values)

    # Accumulate gradients
    loss.backward()
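The key trick above is that next_states is wrapped with volatile=True so the target-network pass builds no graph, and the flag is then reset on q_target_values so the loss itself is not volatile. For reference only, on PyTorch 0.4+ (where volatile has no effect) the same idea is normally expressed with torch.no_grad(); the sketch below uses generic policy_net/target_net/td_loss names, not the attributes of the class above:

import torch
import torch.nn.functional as F

def td_loss(policy_net, target_net, states, actions, rewards,
            next_states, non_final_mask, gamma):
    # Q(s, a) for the actions that were actually taken
    q_values = policy_net(states).gather(1, actions.unsqueeze(1)).squeeze(1)

    # Target values are computed without building a graph (replaces volatile=True)
    with torch.no_grad():
        next_q = target_net(next_states).max(1)[0]

    # Zero the bootstrap term for terminal states, then form the TD target
    target = rewards + gamma * next_q * non_final_mask.float()
    return F.smooth_l1_loss(q_values, target)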
Example 2: optimize_model
# Required import: from torch.autograd import Variable [as alias]
# or: from torch.autograd.Variable import volatile [as alias]
def optimize_model():
    global last_sync
    if len(memory) < BATCH_SIZE:
        return
    transitions = memory.sample(BATCH_SIZE)
    # Transpose the batch (see http://stackoverflow.com/a/19343/3343043 for
    # detailed explanation).
    batch = Transition(*zip(*transitions))

    # Compute a mask of non-final states and concatenate the batch elements
    non_final_mask = ByteTensor(tuple(map(lambda s: s is not None,
                                          batch.next_state)))
    # We don't want to backprop through the expected action values, and volatile
    # saves us from temporarily changing the model parameters'
    # requires_grad to False!
    non_final_next_states = Variable(torch.cat([s for s in batch.next_state
                                                if s is not None]),
                                     volatile=True)
    state_batch = Variable(torch.cat(batch.state))
    action_batch = Variable(torch.cat(batch.action))
    reward_batch = Variable(torch.cat(batch.reward))

    # Compute Q(s_t, a) - the model computes Q(s_t), then we select the
    # columns of actions taken
    state_action_values = model(state_batch).gather(1, action_batch)

    # Compute V(s_{t+1}) for all next states.
    next_state_values = Variable(torch.zeros(BATCH_SIZE).type(Tensor))
    next_state_values[non_final_mask] = model(non_final_next_states).max(1)[0]
    # Now, we don't want to mess up the loss with a volatile flag, so let's
    # clear it. After this, we'll just end up with a Variable that has
    # requires_grad=False
    next_state_values.volatile = False
    # Compute the expected Q values
    expected_state_action_values = (next_state_values * GAMMA) + reward_batch

    # Compute Huber loss
    loss = F.smooth_l1_loss(state_action_values, expected_state_action_values)

    # Optimize the model
    optimizer.zero_grad()
    loss.backward()
    for param in model.parameters():
        param.grad.data.clamp_(-1, 1)
    optimizer.step()
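optimize_model() assumes a Transition namedtuple and a memory replay buffer defined elsewhere (the function closely follows the official PyTorch DQN tutorial). A minimal sketch of those two assumed pieces, so the snippet can be read in isolation:

from collections import namedtuple
import random

Transition = namedtuple('Transition',
                        ('state', 'action', 'next_state', 'reward'))

class ReplayMemory(object):
    """Fixed-size buffer of transitions that optimize_model() samples from."""

    def __init__(self, capacity):
        self.capacity = capacity
        self.memory = []
        self.position = 0

    def push(self, *args):
        """Save a transition, overwriting the oldest one when full."""
        if len(self.memory) < self.capacity:
            self.memory.append(None)
        self.memory[self.position] = Transition(*args)
        self.position = (self.position + 1) % self.capacity

    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)

memory = ReplayMemory(10000)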
Example 3: test_volatile_fallback
# Required import: from torch.autograd import Variable [as alias]
# or: from torch.autograd.Variable import volatile [as alias]
def test_volatile_fallback(self):
    """Check that Traceable falls back to num_backwards=0 if given volatile inputs"""
    x = Variable(torch.randn(2, 2))
    y = Variable(torch.randn(2, 2), requires_grad=True)

    @torch.jit.compile
    def fn(x, y):
        return x * x + x * y

    out = fn(x, y)
    self.assertFalse(fn.has_trace_for(x, y))

    x.volatile = True
    self.assertFalse(fn.has_trace_for(x, y))
    out = fn(x, y)
    self.assertTrue(fn.has_trace_for(x, y))

    with self.assertCompiled(fn):
        out2 = fn(x, y)
    self.assertEqual(out, out2)
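The test relies on the flag's propagation rule: as soon as one input to an expression is volatile, every result is volatile and requires no gradient, so the tracer only has to record a forward pass (num_backwards=0). A tiny illustration of that rule on pre-0.4 PyTorch (arbitrary shapes):

import torch
from torch.autograd import Variable

x = Variable(torch.randn(2, 2))
y = Variable(torch.randn(2, 2), requires_grad=True)
x.volatile = True

out = x * x + x * y
assert out.volatile            # volatility wins over requires_grad=True
assert not out.requires_grad   # no backward graph was built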
Example 4: iter
# Required import: from torch.autograd import Variable [as alias]
# or: from torch.autograd.Variable import volatile [as alias]
if ctr % plot_every == 0:
    plot_loss_avg = plot_loss_total / plot_every
    plot_loss_total = 0
    plot_losses.append(plot_loss_avg)

    val_loss_total = 0  # validation / early stopping
    model.eval()
    if (not args.freeze_models) and args.interpolated_model:
        for member in model.members:
            member.eval()
    for batch in iter(val_iter):
        x_de = batch.src.transpose(1, 0).cuda()
        x_en = batch.trg.transpose(1, 0).cuda()
        if model_type == 1:
            x_de = flip(x_de, 1)  # reverse the input direction
        x_de.volatile = True  # "inference mode" supposedly speeds things up
        loss, reinforce_loss, avg_reward, _ = model.forward(x_de, x_en)
        # too lazy to implement reward or accuracy for validation
        val_loss_total -= avg_reward
    val_loss_avg = val_loss_total / len(val_iter)
    timenow = timeSince(start)
    current_ppl = np.exp(val_loss_avg)
    print('Validation. Time %s, PPL: %.2f' % (timenow, current_ppl))
    if args.frequent_ckpt:
        torch.save(model.state_dict(), args.model_file)  # I'm paranoid!
        print("Saved Checkpoint")
    elif args.save_best and (current_ppl < best_ppl):
        torch.save(model.state_dict(), args.model_file)
        print("Saved Checkpoint")
        best_ppl = current_ppl
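As in Example 1, setting x_de.volatile = True here is the pre-0.4 way of telling autograd that the validation forward passes never need gradients. On current PyTorch the whole loop would simply be wrapped in torch.no_grad(); a self-contained sketch of that pattern with generic model/val_loader names (not the script above):

import torch
import torch.nn.functional as F

def validate(model, val_loader, device='cuda'):
    model.eval()                      # disable dropout / batch-norm updates
    total_loss, n_batches = 0.0, 0
    with torch.no_grad():             # replaces setting .volatile = True on each batch
        for inputs, targets in val_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            loss = F.cross_entropy(model(inputs), targets)
            total_loss += loss.item()
            n_batches += 1
    return total_loss / max(n_batches, 1)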