This page collects typical usage examples of the Python method torch.autograd.Variable.gather. If you are unsure what Variable.gather does, how to call it, or what real code that uses it looks like, the curated examples below should help. You can also browse the other methods of the torch.autograd.Variable class for further usage examples.
Two code examples of Variable.gather are shown below, sorted by popularity by default. You can upvote the examples you like or find useful; your votes help the system recommend better Python code samples.
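Before the full examples, here is a minimal sketch of what Variable.gather does. The tensors and values below are made up for illustration, and since PyTorch 0.4 the same call works on plain tensors without the Variable wrapper:
import torch
from torch.autograd import Variable

# A batch of 3 states with 4 Q-values each.
q_values = Variable(torch.FloatTensor([[0.1, 0.5, 0.2, 0.9],
                                       [0.3, 0.7, 0.4, 0.1],
                                       [0.6, 0.2, 0.8, 0.5]]))
# Index of the action taken in each state, shaped as a column so it matches dim=1.
actions = Variable(torch.LongTensor([[3], [1], [2]]))

# gather(1, actions) picks, row by row, the Q-value of the chosen action.
chosen_q = q_values.gather(1, actions)  # shape (3, 1): [[0.9], [0.7], [0.8]]
Both examples below use exactly this pattern: once to select the predicted Q-value of the action actually taken, and once (in the Double Q-learning branch) to read the target network's Q-value for the action chosen by the online network.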
Example 1: train
# Required import: from torch.autograd import Variable [as alias]
# Or: from torch.autograd.Variable import gather [as alias]
# The snippet below also relies on torch, numpy as np, and os being imported.
def train(self):

    if self.T - self.target_sync_T > self.args.target:
        self.sync_target_network()
        self.target_sync_T = self.T

    info = {}

    for _ in range(self.args.iters):
        self.dqn.eval()

        # TODO: Use a named tuple for experience replay
        n_step_sample = self.args.n_step
        batch, indices, is_weights = self.replay.Sample_N(self.args.batch_size, n_step_sample, self.args.gamma)
        columns = list(zip(*batch))

        states = Variable(torch.from_numpy(np.array(columns[0])).float().transpose_(1, 3))
        actions = Variable(torch.LongTensor(columns[1]))
        terminal_states = Variable(torch.FloatTensor(columns[5]))
        rewards = Variable(torch.FloatTensor(columns[2]))
        # Have to clip rewards for DQN
        rewards = torch.clamp(rewards, -1, 1)
        steps = Variable(torch.FloatTensor(columns[4]))
        new_states = Variable(torch.from_numpy(np.array(columns[3])).float().transpose_(1, 3))

        target_dqn_qvals = self.target_dqn(new_states).cpu()
        # Make a new variable with those values so that these are treated as constants
        target_dqn_qvals_data = Variable(target_dqn_qvals.data)

        q_value_targets = (Variable(torch.ones(terminal_states.size()[0])) - terminal_states)
        inter = Variable(torch.ones(terminal_states.size()[0]) * self.args.gamma)
        # print(steps)
        q_value_targets = q_value_targets * torch.pow(inter, steps)
        if self.args.double:
            # Double Q Learning
            new_states_qvals = self.dqn(new_states).cpu()
            new_states_qvals_data = Variable(new_states_qvals.data)
            q_value_targets = q_value_targets * target_dqn_qvals_data.gather(1, new_states_qvals_data.max(1)[1])
        else:
            q_value_targets = q_value_targets * target_dqn_qvals_data.max(1)[0]
        q_value_targets = q_value_targets + rewards

        self.dqn.train()
        one_hot_actions = torch.zeros(self.args.batch_size, self.args.actions)
        for i in range(self.args.batch_size):
            one_hot_actions[i][actions[i].data] = 1
        if self.args.gpu:
            actions = actions.cuda()
            one_hot_actions = one_hot_actions.cuda()
            q_value_targets = q_value_targets.cuda()
            new_states = new_states.cuda()
        model_predictions_q_vals, model_predictions_state = self.dqn(states, Variable(one_hot_actions))
        model_predictions = model_predictions_q_vals.gather(1, actions.view(-1, 1))

        # info = {}
        td_error = model_predictions - q_value_targets
        info["TD_Error"] = td_error.mean().data[0]

        # Update the priorities
        if not self.args.density_priority:
            self.replay.Update_Indices(indices, td_error.cpu().data.numpy(), no_pseudo_in_priority=self.args.count_td_priority)

        # If using prioritised we need to weight the td_error
        if self.args.prioritized and self.args.prioritized_is:
            # print(td_error)
            weights_tensor = torch.from_numpy(is_weights).float()
            weights_tensor = Variable(weights_tensor)
            if self.args.gpu:
                weights_tensor = weights_tensor.cuda()
            # print(weights_tensor)
            td_error = td_error * weights_tensor

        # Model 1 step state transition error
        # Save them every x steps
        if self.T % self.args.model_save_image == 0:
            os.makedirs("{}/transition_model/{}".format(self.args.log_path, self.T))
            for ii, image, action, next_state, current_state in zip(range(self.args.batch_size), model_predictions_state.cpu().data, actions.data, new_states.cpu().data, states.cpu().data):
                image = image.numpy()[0]
                image = np.clip(image, 0, 1)
                # print(next_state)
                next_state = next_state.numpy()[0]
                current_state = current_state.numpy()[0]
                black_bars = np.zeros_like(next_state[:1, :])
                # print(black_bars.shape)
                joined_image = np.concatenate((current_state, black_bars, image, black_bars, next_state), axis=0)
                joined_image = np.transpose(joined_image)
                self.log_image("{}/transition_model/{}/{}_____Action_{}".format(self.args.log_path, self.T, ii + 1, action), joined_image * 255)
                # self.log_image("{}/transition_model/{}/{}_____Action_{}".format(self.args.log_path, self.T, ii + 1, action), image * 255)
                # self.log_image("{}/transition_model/{}/{}_____Correct".format(self.args.log_path, self.T, ii + 1), next_state * 255)
            # print(model_predictions_state)
#......... part of the code omitted here .........
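Example 1 calls gather twice: in the Double Q-learning branch, the online network's argmax action indexes the target network's Q-values, and later the predicted Q-value of the action actually taken is selected for the TD error. A stripped-down sketch of the Double DQN target selection follows; the random tensors stand in for the networks' outputs and are not part of the original code:
import torch
from torch.autograd import Variable

# Stand-ins for the online and target networks' Q-values on a batch of next states.
online_q = Variable(torch.rand(32, 6))   # (batch_size, num_actions)
target_q = Variable(torch.rand(32, 6))

# Double DQN target selection: the online network chooses the action,
# the target network evaluates it via gather along the action dimension.
best_actions = online_q.max(1)[1].unsqueeze(1)         # shape (32, 1)
next_q = target_q.gather(1, best_actions).squeeze(1)   # shape (32,)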
Example 2: train
# Required import: from torch.autograd import Variable [as alias]
# Or: from torch.autograd.Variable import gather [as alias]
# The snippet below also relies on torch, numpy as np, and a clip_grad_norm helper being available.
def train(self):

    if self.T - self.target_sync_T > self.args.target:
        self.sync_target_network()
        self.target_sync_T = self.T

    info = {}

    for _ in range(self.args.iters):
        self.dqn.eval()

        # TODO: Use a named tuple for experience replay
        n_step_sample = 1
        if np.random.random() < self.args.n_step_mixing:
            n_step_sample = self.args.n_step
        batch, indices, is_weights = self.replay.Sample_N(self.args.batch_size, n_step_sample, self.args.gamma)
        columns = list(zip(*batch))

        states = Variable(torch.from_numpy(np.array(columns[0])).float().transpose_(1, 3))
        actions = Variable(torch.LongTensor(columns[1]))
        terminal_states = Variable(torch.FloatTensor(columns[5]))
        rewards = Variable(torch.FloatTensor(columns[2]))
        # Have to clip rewards for DQN
        rewards = torch.clamp(rewards, -1, 1)
        steps = Variable(torch.FloatTensor(columns[4]))
        new_states = Variable(torch.from_numpy(np.array(columns[3])).float().transpose_(1, 3))

        target_dqn_qvals = self.target_dqn(new_states).cpu()
        # Make a new variable with those values so that these are treated as constants
        target_dqn_qvals_data = Variable(target_dqn_qvals.data)

        q_value_targets = (Variable(torch.ones(terminal_states.size()[0])) - terminal_states)
        inter = Variable(torch.ones(terminal_states.size()[0]) * self.args.gamma)
        # print(steps)
        q_value_targets = q_value_targets * torch.pow(inter, steps)
        if self.args.double:
            # Double Q Learning
            new_states_qvals = self.dqn(new_states).cpu()
            new_states_qvals_data = Variable(new_states_qvals.data)
            q_value_targets = q_value_targets * target_dqn_qvals_data.gather(1, new_states_qvals_data.max(1)[1])
        else:
            q_value_targets = q_value_targets * target_dqn_qvals_data.max(1)[0]
        q_value_targets = q_value_targets + rewards

        self.dqn.train()
        if self.args.gpu:
            actions = actions.cuda()
            q_value_targets = q_value_targets.cuda()
        model_predictions = self.dqn(states).gather(1, actions.view(-1, 1))

        # info = {}
        td_error = model_predictions - q_value_targets
        info["TD_Error"] = td_error.mean().data[0]

        # Update the priorities
        if not self.args.density_priority:
            self.replay.Update_Indices(indices, td_error.cpu().data.numpy(), no_pseudo_in_priority=self.args.count_td_priority)

        # If using prioritised we need to weight the td_error
        if self.args.prioritized and self.args.prioritized_is:
            # print(td_error)
            weights_tensor = torch.from_numpy(is_weights).float()
            weights_tensor = Variable(weights_tensor)
            if self.args.gpu:
                weights_tensor = weights_tensor.cuda()
            # print(weights_tensor)
            td_error = td_error * weights_tensor

        l2_loss = (td_error).pow(2).mean()
        info["Loss"] = l2_loss.data[0]

        # Update
        self.optimizer.zero_grad()
        l2_loss.backward()

        # Taken from pytorch clip_grad_norm
        # Remove once the pip version is up to date with source
        gradient_norm = clip_grad_norm(self.dqn.parameters(), self.args.clip_value)
        if gradient_norm is not None:
            info["Norm"] = gradient_norm

        self.optimizer.step()

        if "States" in info:
            states_trained = info["States"]
            info["States"] = states_trained + columns[0]
        else:
            info["States"] = columns[0]

    # Pad out the states to be of size batch_size
    if len(info["States"]) < self.args.batch_size:
        old_states = info["States"]
        new_states = old_states[0] * (self.args.batch_size - len(old_states))
        info["States"] = new_states

    return info