

Python autograd.backward Method Code Examples

This article collects typical usage examples of the torch.autograd.backward method in Python. If you are wondering what autograd.backward does, how to call it, or what it looks like in real code, the curated examples below should help. You can also look further into usage examples for the torch.autograd module that the method belongs to.


The following presents two code examples of the autograd.backward method, sorted by popularity by default.
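Before diving into the full examples, here is a minimal sketch of how torch.autograd.backward is typically called on a recent PyTorch version. The tensors below are illustrative placeholders, not taken from the examples:

import torch
from torch import autograd

# A tiny graph: y depends on x.
x = torch.randn(3, requires_grad=True)
y = (x * 2).sum()

# Equivalent to y.backward(): accumulates d(y)/d(x) into x.grad.
autograd.backward([y])
print(x.grad)  # tensor([2., 2., 2.])

# For non-scalar roots an explicit gradient must be supplied; gradients
# accumulate on top of the previous call.
z = x * 3
autograd.backward([z], [torch.ones_like(z)])
print(x.grad)  # tensor([5., 5., 5.])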

Example 1: finish_episode

# Required import: from torch import autograd [as alias]
# Or: from torch.autograd import backward [as alias]
def finish_episode():
    # model, optimizer, args, torch, np (numpy), F (torch.nn.functional) and
    # Variable are module-level globals in the original actor_critic.py.
    # Compute discounted returns by walking the episode rewards backwards.
    R = 0
    saved_actions = model.saved_actions
    value_loss = 0
    rewards = []
    for r in model.rewards[::-1]:
        R = r + args.gamma * R
        rewards.insert(0, R)
    # Normalize the returns (eps avoids division by zero).
    rewards = torch.Tensor(rewards)
    rewards = (rewards - rewards.mean()) / (rewards.std() + np.finfo(np.float32).eps)
    # Accumulate the value loss and attach the advantage to each stochastic
    # action node via reinforce() (pre-0.4 PyTorch API; saved_actions holds
    # (action, value) namedtuples).
    for (action, value), r in zip(saved_actions, rewards):
        reward = r - value.data[0, 0]
        action.reinforce(reward)
        value_loss += F.smooth_l1_loss(value, Variable(torch.Tensor([r])))
    optimizer.zero_grad()
    # Backpropagate through the value loss and all action nodes in one call;
    # the action nodes take their gradient from reinforce(), hence None.
    final_nodes = [value_loss] + list(map(lambda p: p.action, saved_actions))
    gradients = [torch.ones(1)] + [None] * len(saved_actions)
    autograd.backward(final_nodes, gradients)
    optimizer.step()
    # Clear the episode buffers.
    del model.rewards[:]
    del model.saved_actions[:]
Author: nosyndicate | Project: pytorchrl | Source file: actor_critic.py
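Example 1 relies on the pre-0.4 Variable/reinforce() API, where autograd.backward is handed a list of graph roots plus matching gradient arguments (None for the stochastic action nodes). On current PyTorch, a single backward call over several roots can be written with explicit losses instead. The following is only a sketch of that pattern; log_probs, values, and returns are made-up placeholders, not names from the project:

import torch
import torch.nn.functional as F
from torch import autograd

# Hypothetical per-step data; in finish_episode these would come from
# model.saved_actions and the discounted, normalized rewards.
log_probs = [torch.randn((), requires_grad=True) for _ in range(3)]
values = [torch.randn((), requires_grad=True) for _ in range(3)]
returns = torch.randn(3)

# The policy term plays the role of action.reinforce(reward); the value term
# is the same smooth L1 loss as in the example above.
policy_loss = sum(-lp * (r - v.detach()) for lp, v, r in zip(log_probs, values, returns))
value_loss = sum(F.smooth_l1_loss(v, r) for v, r in zip(values, returns))

# One backward pass over both graph roots, analogous to
# autograd.backward(final_nodes, gradients) above.
autograd.backward([policy_loss, value_loss])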

Example 2: _off_policy_rollout

# Required import: from torch import autograd [as alias]
# Or: from torch.autograd import backward [as alias]
def _off_policy_rollout(self):
    # reset rollout experiences
    self._reset_rollout()

    # first sample trajectories from the replay memory
    trajectories = self.memory.sample_batch(self.master.batch_size, maxlen=self.master.rollout_steps)
    # NOTE: we also store another set of undetached, unsplitted policy_vb here to prepare for backward
    unsplitted_policy_vb = []

    # then fake the on-policy forward
    for t in range(len(trajectories) - 1):
        # we first get the data out of the sampled experience
        state0 = np.stack([trajectory.state0 for trajectory in trajectories[t]])
        action = np.expand_dims(np.stack([trajectory.action for trajectory in trajectories[t]]), axis=1)
        reward = np.expand_dims(np.stack([trajectory.reward for trajectory in trajectories[t]]), axis=1)
        state1 = np.stack([trajectory.state0 for trajectory in trajectories[t + 1]])
        terminal1 = np.expand_dims(np.stack([1 if trajectory.action is None else 0 for trajectory in trajectories[t + 1]]), axis=1)  # NOTE: here it is 0/1, in on-policy it is False/True
        detached_old_policy_vb = torch.cat([trajectory.detached_old_policy_vb for trajectory in trajectories[t]], 0)

        # NOTE: here first store the last frame: experience.state1 as rollout.state0
        self.rollout.state0.append(state0)
        # then get its corresponding output variables to fake the on-policy experience
        if self.master.enable_continuous:
            pass
        else:
            _, p_vb, q_vb, v_vb, avg_p_vb = self._forward(self._preprocessState(self.rollout.state0[-1], on_policy=False), on_policy=False)
        # push experience into rollout
        self.rollout.action.append(action)
        self.rollout.reward.append(reward)
        self.rollout.state1.append(state1)
        self.rollout.terminal1.append(terminal1)
        self.rollout.policy_vb.append(p_vb.split(1, 0))  # NOTE: must split before detach, otherwise the graph is cut
        self.rollout.q0_vb.append(q_vb)
        self.rollout.value0_vb.append(v_vb)
        self.rollout.detached_avg_policy_vb.append(avg_p_vb.detach())  # NOTE
        self.rollout.detached_old_policy_vb.append(detached_old_policy_vb)
        unsplitted_policy_vb.append(p_vb)

    # also need to log some training stats here maybe

    return unsplitted_policy_vb
Author: jingweiz | Project: pytorch-rl | Source file: acer_single_process.py
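The "must split before detach" note in Example 2 is about graph connectivity: torch.split returns pieces that still carry a grad_fn, so a later backward can reach the network parameters, whereas calling .detach() first severs that link. A minimal illustration, with purely made-up tensor names:

import torch

w = torch.randn(4, 2, requires_grad=True)
p = torch.softmax(w, dim=1)           # stands in for the policy output p_vb

# Splitting keeps each chunk attached to the graph ...
chunks = p.split(1, 0)
print(chunks[0].grad_fn is not None)  # True: backward can still reach w

# ... whereas detaching first cuts the graph, so no gradient reaches w.
detached_chunks = p.detach().split(1, 0)
print(detached_chunks[0].grad_fn)     # None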


Note: The torch.autograd.backward method examples in this article were compiled by 纯净天空 from open-source code and documentation platforms such as GitHub and MSDocs. The code snippets are taken from open-source projects contributed by their respective authors, and copyright remains with those authors; refer to each project's license before distributing or using the code. Do not reproduce this article without permission.