

Python autograd.backward Method Code Examples

This article collects typical usage examples of the Python method torch.autograd.backward. If you are wondering how autograd.backward is used in practice, what it does, or what real-world calls look like, the curated code examples below should help. You can also explore further usage examples for the containing module, torch.autograd.


Below are 2 code examples of the autograd.backward method, sorted by popularity by default.
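For orientation, here is a minimal, self-contained sketch of the basic torch.autograd.backward call, the pattern both examples below build on. The tensor names (x, loss1, loss2) are illustrative only; gradients from all listed roots accumulate into x.grad.

import torch

x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
loss1 = (x ** 2).sum()   # scalar root 1
loss2 = (3 * x).sum()    # scalar root 2

# backward accepts a sequence of roots; scalar roots need no grad_tensors
torch.autograd.backward([loss1, loss2])

print(x.grad)  # tensor([5., 7., 9.]) == 2*x + 3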

Example 1: finish_episode

# Required import: from torch import autograd [as alias]
# Or: from torch.autograd import backward [as alias]
def finish_episode():
    R = 0
    saved_actions = model.saved_actions
    value_loss = 0
    rewards = []
    # Walk the episode backwards to compute discounted returns
    for r in model.rewards[::-1]:
        R = r + args.gamma * R
        rewards.insert(0, R)
    rewards = torch.Tensor(rewards)
    # Normalize returns for numerical stability
    rewards = (rewards - rewards.mean()) / (rewards.std() + np.finfo(np.float32).eps)
    for (action, value), r in zip(saved_actions, rewards):
        reward = r - value.data[0, 0]  # advantage estimate
        action.reinforce(reward)       # legacy stochastic-node API (removed in later PyTorch versions)
        value_loss += F.smooth_l1_loss(value, Variable(torch.Tensor([r])))
    optimizer.zero_grad()
    # Backpropagate through the value loss and every stochastic action node in one call
    final_nodes = [value_loss] + list(map(lambda p: p.action, saved_actions))
    gradients = [torch.ones(1)] + [None] * len(saved_actions)
    autograd.backward(final_nodes, gradients)
    optimizer.step()
    del model.rewards[:]
    del model.saved_actions[:]
Author: nosyndicate | Project: pytorchrl | Lines: 23 | Source: actor_critic.py
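The snippet above targets the pre-0.4 autograd API (Variable, stochastic .reinforce() nodes), which no longer exists in current PyTorch. A rough modern equivalent, assuming saved_actions stores (log_prob, value) pairs rather than (action, value) and the same surrounding model/args/optimizer context, might look like the sketch below; it is not the original author's code.

import torch
import torch.nn.functional as F

def finish_episode_modern():
    R = 0.0
    returns = []
    for r in model.rewards[::-1]:          # discounted returns
        R = r + args.gamma * R
        returns.insert(0, R)
    returns = torch.tensor(returns)
    returns = (returns - returns.mean()) / (returns.std() + 1e-8)

    policy_losses, value_losses = [], []
    for (log_prob, value), R in zip(model.saved_actions, returns):
        advantage = R - value.item()
        policy_losses.append(-log_prob * advantage)                # REINFORCE term
        value_losses.append(F.smooth_l1_loss(value.squeeze(), R))

    optimizer.zero_grad()
    loss = torch.stack(policy_losses).sum() + torch.stack(value_losses).sum()
    loss.backward()   # single root; equivalent to torch.autograd.backward([loss])
    optimizer.step()
    del model.rewards[:]
    del model.saved_actions[:]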

Example 2: _off_policy_rollout

# Required import: from torch import autograd [as alias]
# Or: from torch.autograd import backward [as alias]
def _off_policy_rollout(self):
        # reset rollout experiences
        self._reset_rollout()

        # first sample trajectories
        trajectories = self.memory.sample_batch(self.master.batch_size, maxlen=self.master.rollout_steps)
        # NOTE: we also store another set of undetached unsplitted policy_vb here to prepare for backward
        unsplitted_policy_vb = []

        # then fake the on-policy forward
        for t in range(len(trajectories) - 1):
            # we first get the data out of the sampled experience
            # NOTE: np.stack needs a sequence, so use list comprehensions (newer NumPy rejects generators)
            state0 = np.stack([trajectory.state0 for trajectory in trajectories[t]])
            action = np.expand_dims(np.stack([trajectory.action for trajectory in trajectories[t]]), axis=1)
            reward = np.expand_dims(np.stack([trajectory.reward for trajectory in trajectories[t]]), axis=1)
            state1 = np.stack([trajectory.state0 for trajectory in trajectories[t+1]])
            terminal1 = np.expand_dims(np.stack([1 if trajectory.action is None else 0 for trajectory in trajectories[t+1]]), axis=1) # NOTE: 0/1 here; the on-policy version uses False/True
            detached_old_policy_vb = torch.cat([trajectory.detached_old_policy_vb for trajectory in trajectories[t]], 0)

            # NOTE: here first store the last frame: experience.state1 as rollout.state0
            self.rollout.state0.append(state0)
            # then get its corresponding output variables to fake the on policy experience
            if self.master.enable_continuous:
                pass
            else:
                _, p_vb, q_vb, v_vb, avg_p_vb = self._forward(self._preprocessState(self.rollout.state0[-1], on_policy=False), on_policy=False)
            # push experience into rollout
            self.rollout.action.append(action)
            self.rollout.reward.append(reward)
            self.rollout.state1.append(state1)
            self.rollout.terminal1.append(terminal1)
            self.rollout.policy_vb.append(p_vb.split(1, 0)) # NOTE: must split before detach !!! otherwise graph is cut
            self.rollout.q0_vb.append(q_vb)
            self.rollout.value0_vb.append(v_vb)
            self.rollout.detached_avg_policy_vb.append(avg_p_vb.detach()) # NOTE
            self.rollout.detached_old_policy_vb.append(detached_old_policy_vb)
            unsplitted_policy_vb.append(p_vb)

        # also need to log some training stats here maybe

        return unsplitted_policy_vb 
Author: jingweiz | Project: pytorch-rl | Lines: 43 | Source: acer_single_process.py
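In this second example the method never calls backward itself; it returns the non-detached, unsplit policy outputs (unsplitted_policy_vb) precisely so that the caller can later drive torch.autograd.backward over several policy roots with gradients computed outside autograd (the ACER-style update). A minimal, self-contained sketch of that multi-root pattern, using a hypothetical toy policy network rather than the repository's classes, might be:

import torch
import torch.nn as nn

# Toy stand-in for the policy model (hypothetical, not from the repository)
policy_net = nn.Linear(4, 3)
optimizer = torch.optim.SGD(policy_net.parameters(), lr=1e-2)

# Two "rollout steps" of non-detached policy outputs, kept whole like unsplitted_policy_vb
states = [torch.randn(2, 4) for _ in range(2)]
policy_roots = [torch.softmax(policy_net(s), dim=1) for s in states]

# One externally computed gradient per root, shape-matched (e.g. a clipped policy gradient)
manual_grads = [torch.randn_like(p) for p in policy_roots]

optimizer.zero_grad()
torch.autograd.backward(policy_roots, manual_grads)  # multi-root backward with explicit grad_tensors
optimizer.step()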


Note: The torch.autograd.backward examples in this article were compiled by 純淨天空 from open-source code and documentation hosted on GitHub, MSDocs, and similar platforms. The snippets were selected from open-source projects contributed by their authors; copyright remains with the original authors, and any distribution or use should follow the corresponding project's license. Please do not reproduce this article without permission.