

Python Variable.volatile Method Code Examples

This article collects typical usage examples of the Python method torch.autograd.Variable.volatile. If you are wondering what Variable.volatile does, how to call it, or what it looks like in real code, the curated examples below may help. You can also explore further usage examples of torch.autograd.Variable.


The following presents 4 code examples of the Variable.volatile method, sorted by popularity by default. You can upvote the examples you like or find useful; your feedback helps the system recommend better Python code examples.
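Background, for readers new to this API: volatile belongs to the autograd interface of PyTorch 0.3 and earlier. Marking a Variable as volatile disabled graph construction for everything computed from it, which was the idiomatic way to run cheap inference. The flag was deprecated in PyTorch 0.4, where the torch.no_grad() context manager takes its place. A minimal sketch of both patterns, using a small stand-in model purely for illustration:

import torch
import torch.nn as nn
from torch.autograd import Variable

model = nn.Linear(4, 2)
x = torch.randn(1, 4)

# Old API (PyTorch <= 0.3): volatile=True disables graph construction
# for this Variable and everything derived from it.
inp = Variable(x, volatile=True)
out = model(inp)          # no backward graph is built; out.volatile is True

# Current API (PyTorch >= 0.4): wrap inference in torch.no_grad() instead.
with torch.no_grad():
    out = model(x)        # out.requires_grad is False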

Example 1: accumulate_gradient

# Required import: from torch.autograd import Variable [as alias]
# Or: from torch.autograd.Variable import volatile [as alias]
    def accumulate_gradient(self, batch_sz, states, actions, rewards,
                            next_states, mask):
        """ Compute the temporal difference error.
            td_error = (r + gamma * max Q(s_,a)) - Q(s,a)
        """
        states = Variable(states)
        actions = Variable(actions)
        rewards = Variable(rewards)
        next_states = Variable(next_states, volatile=True)

        # Compute Q(s, a)
        q_values = self.policy(states)
        q_values = q_values.gather(1, actions.unsqueeze(1))

        # Compute Q(s_, a)
        q_target_values = None
        if next_states.is_cuda:
            q_target_values = Variable(torch.zeros(batch_sz).cuda())
        else:
            q_target_values = Variable(torch.zeros(batch_sz))

        # Bootstrap for non-terminal states
        q_target_values[mask] = self.target_policy(next_states).max(1)[0][mask]

        q_target_values.volatile = False      # so we don't mess up the Huber loss
        expected_q_values = (q_target_values * self.gamma) + rewards

        # Compute Huber loss
        loss = F.smooth_l1_loss(q_values, expected_q_values)

        # Accumulate gradients
        loss.backward()
Developer: kastnerkyle, Project: categorical-dqn, Lines: 34, Source: dqn_update.py
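For reference, on PyTorch 0.4 and later the same bootstrap step is normally written with torch.no_grad(), which also removes the need to flip volatile back to False before computing the loss. Below is a hedged, self-contained sketch with a toy Q-network and a randomly generated batch (all names here are placeholders, not the class attributes used above):

import torch
import torch.nn as nn
import torch.nn.functional as F

# Toy setup for illustration only.
policy = nn.Linear(4, 3)                  # Q-network: 4-dim states, 3 actions
target_policy = nn.Linear(4, 3)
gamma, batch_sz = 0.99, 8

states      = torch.randn(batch_sz, 4)
next_states = torch.randn(batch_sz, 4)
actions     = torch.randint(0, 3, (batch_sz,))
rewards     = torch.randn(batch_sz)
mask        = torch.rand(batch_sz) > 0.2  # True for non-terminal transitions

# Q(s, a) for the actions actually taken
q_values = policy(states).gather(1, actions.unsqueeze(1)).squeeze(1)

# Bootstrap targets: torch.no_grad() replaces volatile=True / volatile=False
with torch.no_grad():
    q_target_values = torch.zeros(batch_sz)
    q_target_values[mask] = target_policy(next_states).max(1)[0][mask]

loss = F.smooth_l1_loss(q_values, q_target_values * gamma + rewards)
loss.backward()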

Example 2: optimize_model

# Required import: from torch.autograd import Variable [as alias]
# Or: from torch.autograd.Variable import volatile [as alias]
def optimize_model():
    global last_sync
    if len(memory) < BATCH_SIZE:
        return
    transitions = memory.sample(BATCH_SIZE)
    # Transpose the batch (see http://stackoverflow.com/a/19343/3343043 for
    # detailed explanation).
    batch = Transition(*zip(*transitions))

    # Compute a mask of non-final states and concatenate the batch elements
    non_final_mask = ByteTensor(tuple(map(lambda s: s is not None,
                                          batch.next_state)))

    # We don't want to backprop through the expected action values; marking the
    # next states as volatile saves us from temporarily setting the model
    # parameters' requires_grad to False.
    non_final_next_states = Variable(torch.cat([s for s in batch.next_state
                                                if s is not None]),
                                     volatile=True)
    state_batch = Variable(torch.cat(batch.state))
    action_batch = Variable(torch.cat(batch.action))
    reward_batch = Variable(torch.cat(batch.reward))

    # Compute Q(s_t, a) - the model computes Q(s_t), then we select the
    # columns of actions taken
    state_action_values = model(state_batch).gather(1, action_batch)

    # Compute V(s_{t+1}) for all next states.
    next_state_values = Variable(torch.zeros(BATCH_SIZE).type(Tensor))
    next_state_values[non_final_mask] = model(non_final_next_states).max(1)[0]
    # Now, we don't want to mess up the loss with a volatile flag, so let's
    # clear it. After this, we'll just end up with a Variable that has
    # requires_grad=False
    next_state_values.volatile = False
    # Compute the expected Q values
    expected_state_action_values = (next_state_values * GAMMA) + reward_batch

    # Compute Huber loss
    loss = F.smooth_l1_loss(state_action_values, expected_state_action_values)

    # Optimize the model
    optimizer.zero_grad()
    loss.backward()
    for param in model.parameters():
        param.grad.data.clamp_(-1, 1)
    optimizer.step()
Developer: Alpslee, Project: jetson-reinforcement, Lines: 48, Source: gym-DQN.py
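One more modernization note on the last lines of optimize_model: the per-parameter param.grad.data.clamp_(-1, 1) loop is in-place gradient value clipping, and current PyTorch ships a utility for exactly this, torch.nn.utils.clip_grad_value_. A small self-contained sketch (the linear model and RMSprop optimizer here are placeholders, not the DQN from the example):

import torch
import torch.nn as nn

model = nn.Linear(4, 2)                        # stand-in for the DQN model
optimizer = torch.optim.RMSprop(model.parameters())

loss = model(torch.randn(3, 4)).sum()
optimizer.zero_grad()
loss.backward()

# Same effect as the manual per-parameter clamp loop above.
nn.utils.clip_grad_value_(model.parameters(), clip_value=1.0)
optimizer.step()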

Example 3: test_volatile_fallback

# Required import: from torch.autograd import Variable [as alias]
# Or: from torch.autograd.Variable import volatile [as alias]
    def test_volatile_fallback(self):
        """Check that Traceable falls back to num_backwards=0 if given volatile inputs"""
        x = Variable(torch.randn(2, 2))
        y = Variable(torch.randn(2, 2), requires_grad=True)

        @torch.jit.compile
        def fn(x, y):
            return x * x + x * y

        out = fn(x, y)
        self.assertFalse(fn.has_trace_for(x, y))

        x.volatile = True
        self.assertFalse(fn.has_trace_for(x, y))
        out = fn(x, y)
        self.assertTrue(fn.has_trace_for(x, y))
        with self.assertCompiled(fn):
            out2 = fn(x, y)
        self.assertEqual(out, out2)
Developer: Northrend, Project: pytorch, Lines: 21, Source: test_jit.py
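torch.jit.compile and the volatile fallback tested here belong to the experimental tracer shipped before PyTorch 0.4. On current versions the closest equivalent workflow is torch.jit.trace, which records a graph from example inputs and involves no volatile flag at all; a minimal sketch:

import torch

def fn(x, y):
    return x * x + x * y

x = torch.randn(2, 2)
y = torch.randn(2, 2, requires_grad=True)

# Record a trace of fn from example inputs; no volatile flag is involved.
traced_fn = torch.jit.trace(fn, (x, y))
out = traced_fn(x, y)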

Example 4: iter

# Required import: from torch.autograd import Variable [as alias]
# Or: from torch.autograd.Variable import volatile [as alias]
            if ctr % plot_every == 0:
                plot_loss_avg = plot_loss_total / plot_every
                plot_loss_total = 0
                plot_losses.append(plot_loss_avg)

        val_loss_total = 0 # Validation/early stopping
        model.eval()
        if (not args.freeze_models) and args.interpolated_model:
            for member in model.members:
                member.eval()
        for batch in iter(val_iter):
            x_de = batch.src.transpose(1,0).cuda()
            x_en = batch.trg.transpose(1,0).cuda()
            if model_type == 1:
                x_de = flip(x_de,1) # reverse direction
            x_de.volatile = True # "inference mode": skips building the autograd graph, which speeds up validation
            loss, reinforce_loss, avg_reward, _ = model.forward(x_de, x_en)
            # too lazy to implement reward or accuracy for validation
            val_loss_total -= avg_reward
        val_loss_avg = val_loss_total / len(val_iter)
        timenow = timeSince(start)
        current_ppl = np.exp(val_loss_avg)
        print('Validation. Time %s, PPL: %.2f' %(timenow, current_ppl))
        if args.frequent_ckpt:
            torch.save(model.state_dict(), args.model_file) # I'm Paranoid!!!!!!!!!!!!!!!!
            print("Saved Checkpoint")
        elif args.save_best and (current_ppl < best_ppl):
            torch.save(model.state_dict(), args.model_file)
            print("Saved Checkpoint")
            best_ppl = current_ppl
Developer: anihamde, Project: cs287-s18, Lines: 32, Source: main.py
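In example 4, setting x_de.volatile = True switches the validation forward pass into inference mode. On PyTorch 0.4 and later the usual idiom is to pair model.eval() with a torch.no_grad() block instead; a simplified sketch with a placeholder model and placeholder validation batches (not the seq2seq model from the example):

import torch
import torch.nn as nn

model = nn.Linear(10, 5)                           # placeholder model
val_iter = [torch.randn(4, 10) for _ in range(3)]  # placeholder validation batches

model.eval()                    # disable dropout / freeze batch-norm statistics
val_loss_total = 0.0
with torch.no_grad():           # replaces x_de.volatile = True
    for x in val_iter:
        out = model(x)
        val_loss_total += out.pow(2).mean().item()
val_loss_avg = val_loss_total / len(val_iter)
print('Validation loss: %.4f' % val_loss_avg)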


Note: The torch.autograd.Variable.volatile examples in this article were compiled by 纯净天空 from GitHub, MSDocs, and other open-source code and documentation platforms. The code snippets were selected from open-source projects contributed by various developers; copyright of the source code belongs to the original authors, and distribution and use should follow the corresponding project's License. Do not reproduce without permission.