

Python ext.extract Method Code Examples

This article collects typical usage examples of the Python method rllab.misc.ext.extract. If you are wondering what ext.extract does, how to call it, or what real-world uses look like, the hand-picked code examples below may help. You can also explore further usage examples from the module it belongs to, rllab.misc.ext.


Fifteen code examples of the ext.extract method are presented below, sorted by popularity by default.
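Before the examples, a quick orientation may help: ext.extract takes a dict (for example, the samples_data batch returned by a rllab sampler) plus any number of key names, and returns the corresponding values as a tuple in the order the keys were requested, so the result can be unpacked directly or concatenated with further inputs. The snippet below is a minimal illustrative sketch, not the library's exact implementation; the extract_sketch function and the toy samples_data dict are invented here purely for demonstration.

import numpy as np


def extract_sketch(x, *keys):
    # Simplified stand-in for rllab.misc.ext.extract: pull `keys` out of a
    # dict (or a list of dicts) and return their values as a tuple,
    # preserving the order in which the keys were requested.
    if isinstance(x, dict):
        return tuple(x[k] for k in keys)
    elif isinstance(x, list):
        # list of dicts -> one list per key, gathered across the dicts
        return tuple([xi[k] for xi in x] for k in keys)
    else:
        raise NotImplementedError


# Hypothetical samples_data shaped roughly like a sampler batch
samples_data = {
    "observations": np.zeros((4, 3)),
    "actions": np.zeros((4, 2)),
    "advantages": np.ones(4),
}

obs, actions, advantages = extract_sketch(
    samples_data, "observations", "actions", "advantages"
)
print(obs.shape, actions.shape, advantages.shape)  # (4, 3) (4, 2) (4,)

This ordering guarantee is what the examples below rely on when they unpack the result directly (obs, actions, ... = ext.extract(...)) or build up an all_input_values tuple with +=.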

Example 1: optimize_policy

# Required import: from rllab.misc import ext [as alias]
# Or: from rllab.misc.ext import extract [as alias]
def optimize_policy(self, itr, samples_data):
        all_input_values = tuple(ext.extract(
            samples_data,
            "observations", "actions", "advantages"
        ))
        agent_infos = samples_data["agent_infos"]
        state_info_list = [agent_infos[k] for k in self.policy.state_info_keys]
        dist_info_list = [agent_infos[k] for k in self.policy.distribution.dist_info_keys]
        all_input_values += tuple(state_info_list) + tuple(dist_info_list)
        if self.policy.recurrent:
            all_input_values += (samples_data["valids"],)
        loss_before = self.optimizer.loss(all_input_values)
        mean_kl_before = self.optimizer.constraint_val(all_input_values)
        self.optimizer.optimize(all_input_values)
        mean_kl = self.optimizer.constraint_val(all_input_values)
        loss_after = self.optimizer.loss(all_input_values)
        logger.record_tabular('LossBefore', loss_before)
        logger.record_tabular('LossAfter', loss_after)
        logger.record_tabular('MeanKLBefore', mean_kl_before)
        logger.record_tabular('MeanKL', mean_kl)
        logger.record_tabular('dLoss', loss_before - loss_after)
        return dict() 
Developer: bstadie, Project: third_person_im, Lines: 24, Source: npo.py

Example 2: optimize_policy

# Required import: from rllab.misc import ext [as alias]
# Or: from rllab.misc.ext import extract [as alias]
def optimize_policy(self, itr, samples_data):
        logger.log("optimizing policy")
        inputs = ext.extract(
            samples_data,
            "observations", "actions", "advantages"
        )
        agent_infos = samples_data["agent_infos"]
        state_info_list = [agent_infos[k] for k in self.policy.state_info_keys]
        inputs += tuple(state_info_list)
        if self.policy.recurrent:
            inputs += (samples_data["valids"],)
        dist_info_list = [agent_infos[k] for k in self.policy.distribution.dist_info_keys]
        loss_before = self.optimizer.loss(inputs)
        self.optimizer.optimize(inputs)
        loss_after = self.optimizer.loss(inputs)
        logger.record_tabular("LossBefore", loss_before)
        logger.record_tabular("LossAfter", loss_after)

        mean_kl, max_kl = self.opt_info['f_kl'](*(list(inputs) + dist_info_list))
        logger.record_tabular('MeanKL', mean_kl)
        logger.record_tabular('MaxKL', max_kl) 
Developer: bstadie, Project: third_person_im, Lines: 23, Source: vpg.py

Example 3: do_policy_training

# Required import: from rllab.misc import ext [as alias]
# Or: from rllab.misc.ext import extract [as alias]
def do_policy_training(self, itr, batch):
        target_policy = self.opt_info["target_policy"]
        obs, = ext.extract(batch, "observations")
        f_train_policy = self.opt_info["f_train_policy"]
        if isinstance(self.policy_update_method, FirstOrderOptimizer):
            policy_surr, _ = f_train_policy(obs)
        else:
            agent_infos = self.policy.dist_info(obs)
            state_info_list = [agent_infos[k] for k in self.policy.state_info_keys]
            dist_info_list = [agent_infos[k] for k in self.policy.distribution.dist_info_keys]
            all_input_values = (obs, obs, ) + tuple(state_info_list) + tuple(dist_info_list)
            policy_results = f_train_policy(all_input_values)
            policy_surr = policy_results["loss_after"]
        if self.policy_use_target:
            target_policy.set_param_values(
                target_policy.get_param_values() * (1.0 - self.soft_target_tau) +
                self.policy.get_param_values() * self.soft_target_tau)
        self.policy_surr_averages.append(policy_surr) 
Developer: shaneshixiang, Project: rllabplusplus, Lines: 20, Source: ddpg.py

Example 4: optimize_policy

# Required import: from rllab.misc import ext [as alias]
# Or: from rllab.misc.ext import extract [as alias]
def optimize_policy(self, itr, samples_data):
        logger.log("optimizing policy")
        inputs = ext.extract(
            samples_data,
            "observations", "actions", "advantages"
        )
        agent_infos = samples_data["agent_infos"]
        state_info_list = [agent_infos[k] for k in self.policy.state_info_keys]
        inputs += tuple(state_info_list)
        if self.policy.recurrent:
            inputs += (samples_data["valids"],)
        dist_info_list = [agent_infos[k]
                          for k in self.policy.distribution.dist_info_keys]
        loss_before = self.optimizer.loss(inputs)
        self.optimizer.optimize(inputs)
        loss_after = self.optimizer.loss(inputs)
        logger.record_tabular("LossBefore", loss_before)
        logger.record_tabular("LossAfter", loss_after)

        mean_kl, max_kl = self.opt_info['f_kl'](
            *(list(inputs) + dist_info_list))
        logger.record_tabular('MeanKL', mean_kl)
        logger.record_tabular('MaxKL', max_kl) 
Developer: sisl, Project: gail-driver, Lines: 25, Source: vpg.py

Example 5: optimize_policy

# Required import: from rllab.misc import ext [as alias]
# Or: from rllab.misc.ext import extract [as alias]
def optimize_policy(self, itr, samples_data):
        all_input_values = tuple(ext.extract(
            samples_data,
            "observations", "actions", "advantages"
        ))
        agent_infos = samples_data["agent_infos"]
        state_info_list = [agent_infos[k] for k in self.policy.state_info_keys]
        dist_info_list = [agent_infos[k]
                          for k in self.policy.distribution.dist_info_keys]
        all_input_values += tuple(state_info_list) + tuple(dist_info_list)
        if self.policy.recurrent:
            all_input_values += (samples_data["valids"],)
        loss_before = self.optimizer.loss(all_input_values)
        mean_kl_before = self.optimizer.constraint_val(all_input_values)
        self.optimizer.optimize(all_input_values)
        mean_kl = self.optimizer.constraint_val(all_input_values)
        loss_after = self.optimizer.loss(all_input_values)
        logger.record_tabular('LossBefore', loss_before)
        logger.record_tabular('LossAfter', loss_after)
        logger.record_tabular('MeanKLBefore', mean_kl_before)
        logger.record_tabular('MeanKL', mean_kl)
        logger.record_tabular('dLoss', loss_before - loss_after)
        return dict() 
Developer: sisl, Project: gail-driver, Lines: 25, Source: npo.py

Example 6: optimize_policy

# Required import: from rllab.misc import ext [as alias]
# Or: from rllab.misc.ext import extract [as alias]
def optimize_policy(self, itr, samples_data):
        logger.log("optimizing policy")
        inputs = ext.extract(
            samples_data,
            "observations", "actions", "advantages"
        )
        if self.policy.recurrent:
            inputs += (samples_data["valids"],)
        agent_infos = samples_data["agent_infos"]
        dist_info_list = [agent_infos[k] for k in self.policy.distribution.dist_info_keys]
        loss_before = self.optimizer.loss(inputs)
        self.optimizer.optimize(inputs)
        loss_after = self.optimizer.loss(inputs)
        logger.record_tabular("LossBefore", loss_before)
        logger.record_tabular("LossAfter", loss_after)

        mean_kl, max_kl = self.opt_info['f_kl'](*(list(inputs) + dist_info_list))
        logger.record_tabular('MeanKL', mean_kl)
        logger.record_tabular('MaxKL', max_kl) 
Developer: openai, Project: vime, Lines: 21, Source: vpg_expl.py

Example 7: optimize_policy

# Required import: from rllab.misc import ext [as alias]
# Or: from rllab.misc.ext import extract [as alias]
def optimize_policy(self, itr, samples_data):
        all_input_values = tuple(ext.extract(
            samples_data,
            "observations", "actions", "advantages"
        ))
        agent_infos = samples_data["agent_infos"]
        info_list = [agent_infos[k]
                     for k in self.policy.distribution.dist_info_keys]
        all_input_values += tuple(info_list)
        if self.policy.recurrent:
            all_input_values += (samples_data["valids"],)
        loss_before = self.optimizer.loss(all_input_values)
        self.optimizer.optimize(all_input_values)
        mean_kl = self.optimizer.constraint_val(all_input_values)
        loss_after = self.optimizer.loss(all_input_values)
        logger.record_tabular('LossAfter', loss_after)
        logger.record_tabular('MeanKL', mean_kl)
        logger.record_tabular('dLoss', loss_before - loss_after)
        return dict() 
Developer: openai, Project: vime, Lines: 21, Source: npo_expl.py

Example 8: optimize_policy

# Required import: from rllab.misc import ext [as alias]
# Or: from rllab.misc.ext import extract [as alias]
def optimize_policy(self, itr, samples_data):
        logger.log("optimizing policy")
        inputs = ext.extract(
            samples_data,
            "observations", "actions", "advantages"
        )
        agent_infos = samples_data["agent_infos"]
        state_info_list = [agent_infos[k] for k in self.policy.state_info_keys]
        inputs += tuple(state_info_list)
        if self.policy.recurrent:
            inputs += (samples_data["valids"],)
        dist_info_list = [agent_infos[k] for k in self.policy.distribution.dist_info_keys]
        loss_before = self.optimizer.loss(inputs)
        self.optimizer.optimize(inputs)  # TODO - actual optimize step happens here?
        loss_after = self.optimizer.loss(inputs)
        logger.record_tabular("LossBefore", loss_before)
        logger.record_tabular("LossAfter", loss_after)

        mean_kl, max_kl = self.opt_info['f_kl'](*(list(inputs) + dist_info_list))
        logger.record_tabular('MeanKL', mean_kl)
        logger.record_tabular('MaxKL', max_kl) 
Developer: cbfinn, Project: maml_rl, Lines: 23, Source: vpg.py

Example 9: optimize_policy

# Required import: from rllab.misc import ext [as alias]
# Or: from rllab.misc.ext import extract [as alias]
def optimize_policy(self, itr, samples_data):
        # samples_data comes with latents: see train() in batch_polopt;
        # they will be in agent_infos under the key "latents"
        all_input_values = tuple(ext.extract(
            samples_data,
            "observations", "actions", "advantages"
        ))
        agent_infos = samples_data["agent_infos"]
        # latents have already been processed and concatenated, but keep the key "latents"
        all_input_values += (agent_infos["latents"],)
        # these are the mean and var used at rollout, corresponding to
        # old_dist_info_vars_list as symbolic var
        info_list = [agent_infos[k] for k in self.policy.distribution.dist_info_keys]
        all_input_values += tuple(info_list)
        if self.policy.recurrent:
            all_input_values += (samples_data["valids"],)

        loss_before = self.optimizer.loss(all_input_values)
        # this should always be 0. If it's not there is a problem.
        mean_kl_before = self.optimizer.constraint_val(all_input_values)
        logger.record_tabular('MeanKL_Before', mean_kl_before)

        with logger.prefix(' PolicyOptimize | '):
            self.optimizer.optimize(all_input_values)

        mean_kl = self.optimizer.constraint_val(all_input_values)
        loss_after = self.optimizer.loss(all_input_values)
        logger.record_tabular('LossAfter', loss_after)
        logger.record_tabular('MeanKL', mean_kl)
        logger.record_tabular('dLoss', loss_before - loss_after)
        return dict() 
Developer: florensacc, Project: snn4hrl, Lines: 31, Source: npo_snn_rewards.py

Example 10: do_training

# Required import: from rllab.misc import ext [as alias]
# Or: from rllab.misc.ext import extract [as alias]
def do_training(self, itr, batch):

        obs, actions, rewards, next_obs, terminals = ext.extract(
            batch,
            "observations", "actions", "rewards", "next_observations",
            "terminals"
        )

        # compute the on-policy y values
        target_qf = self.opt_info["target_qf"]
        target_policy = self.opt_info["target_policy"]

        next_actions, _ = target_policy.get_actions(next_obs)
        next_qvals = target_qf.get_qval(next_obs, next_actions)

        ys = rewards + (1. - terminals) * self.discount * next_qvals

        f_train_qf = self.opt_info["f_train_qf"]
        f_train_policy = self.opt_info["f_train_policy"]

        qf_loss, qval = f_train_qf(ys, obs, actions)

        policy_surr = f_train_policy(obs)

        target_policy.set_param_values(
            target_policy.get_param_values() * (1.0 - self.soft_target_tau) +
            self.policy.get_param_values() * self.soft_target_tau)
        target_qf.set_param_values(
            target_qf.get_param_values() * (1.0 - self.soft_target_tau) +
            self.qf.get_param_values() * self.soft_target_tau)

        self.qf_loss_averages.append(qf_loss)
        self.policy_surr_averages.append(policy_surr)
        self.q_averages.append(qval)
        self.y_averages.append(ys) 
Developer: bstadie, Project: third_person_im, Lines: 37, Source: ddpg.py

Example 11: __setstate__

# Required import: from rllab.misc import ext [as alias]
# Or: from rllab.misc.ext import extract [as alias]
def __setstate__(self, d):
        super(ReplayPool, self).__setstate__(d)
        self.bottom, self.top, self.size, self.observations, self.actions, \
        self.rewards, self.terminals, self.extras, self.rng = extract(
            d,
            "bottom", "top", "size", "observations", "actions", "rewards",
            "terminals", "extras", "rng"
        ) 
Developer: bstadie, Project: third_person_im, Lines: 10, Source: util.py

Example 12: optimize_policy

# Required import: from rllab.misc import ext [as alias]
# Or: from rllab.misc.ext import extract [as alias]
def optimize_policy(self, itr, samples_data):
        all_input_values = tuple(ext.extract(
            samples_data,
            "observations", "actions", "advantages"
        ))
        agent_infos = samples_data["agent_infos"]
        state_info_list = [agent_infos[k] for k in self.policy.state_info_keys]
        dist_info_list = [agent_infos[k] for k in self.policy.distribution.dist_info_keys]
        all_input_values += tuple(state_info_list) + tuple(dist_info_list)
        if self.policy.recurrent:
            all_input_values += (samples_data["valids"],)
        logger.log("Computing loss before")
        loss_before = self.optimizer.loss(all_input_values)
        logger.log("Computing KL before")
        mean_kl_before = self.optimizer.constraint_val(all_input_values)
        logger.log("Optimizing")
        self.optimizer.optimize(all_input_values)
        logger.log("Computing KL after")
        mean_kl = self.optimizer.constraint_val(all_input_values)
        logger.log("Computing loss after")
        loss_after = self.optimizer.loss(all_input_values)
        logger.record_tabular('LossBefore', loss_before)
        logger.record_tabular('LossAfter', loss_after)
        logger.record_tabular('MeanKLBefore', mean_kl_before)
        logger.record_tabular('MeanKL', mean_kl)
        logger.record_tabular('dLoss', loss_before - loss_after)
        return dict() 
Developer: bstadie, Project: third_person_im, Lines: 29, Source: npo.py

Example 13: optimize_policy

# Required import: from rllab.misc import ext [as alias]
# Or: from rllab.misc.ext import extract [as alias]
def optimize_policy(self, itr, samples_data):
        logger.log("optimizing policy")
        inputs = ext.extract(
            samples_data,
            "observations", "actions", "advantages"
        )
        agent_infos = samples_data["agent_infos"]
        state_info_list = [agent_infos[k] for k in self.policy.state_info_keys]
        inputs += tuple(state_info_list)
        if self.policy.recurrent:
            inputs += (samples_data["valids"],)
        if self.qprop:
            inputs += (samples_data["etas"], )
            logger.log("Using Qprop optimizer")
        optimizer = self.optimizer
        dist_info_list = [agent_infos[k] for k in self.policy.distribution.dist_info_keys]
        loss_before = optimizer.loss(inputs)
        gc.collect()
        optimizer.optimize(inputs)
        gc.collect()
        loss_after = optimizer.loss(inputs)
        logger.record_tabular("LossBefore", loss_before)
        logger.record_tabular("LossAfter", loss_after)

        mean_kl, max_kl = self.opt_info['f_kl'](*(list(inputs) + dist_info_list))
        logger.record_tabular('MeanKL', mean_kl)
        logger.record_tabular('MaxKL', max_kl) 
Developer: shaneshixiang, Project: rllabplusplus, Lines: 29, Source: vpg.py

Example 14: do_training

# Required import: from rllab.misc import ext [as alias]
# Or: from rllab.misc.ext import extract [as alias]
def do_training(self, itr, batch):

        obs, actions, rewards, next_obs, terminals = ext.extract(
            batch,
            "observations", "actions", "rewards", "next_observations",
            "terminals"
        )

        # compute the on-policy y values
        target_qf = self.opt_info["target_qf"]
        target_policy = self.opt_info["target_policy"]

        next_actions, _ = target_policy.get_actions(next_obs)
        next_qvals = target_qf.get_qval(next_obs, next_actions)

        ys = rewards + (1. - terminals) * self.discount * next_qvals.reshape(-1)

        f_train_qf = self.opt_info["f_train_qf"]
        qf_loss, qval, _ = f_train_qf(ys, obs, actions)
        target_qf.set_param_values(
            target_qf.get_param_values() * (1.0 - self.soft_target_tau) +
            self.qf.get_param_values() * self.soft_target_tau)
        self.qf_loss_averages.append(qf_loss)
        self.q_averages.append(qval)
        self.y_averages.append(ys)

        self.train_policy_itr += self.policy_updates_ratio
        train_policy_itr = 0
        while self.train_policy_itr > 0:
            f_train_policy = self.opt_info["f_train_policy"]
            policy_surr, _ = f_train_policy(obs)
            target_policy.set_param_values(
                target_policy.get_param_values() * (1.0 - self.soft_target_tau) +
                self.policy.get_param_values() * self.soft_target_tau)
            self.policy_surr_averages.append(policy_surr)
            self.train_policy_itr -= 1
            train_policy_itr += 1
        return 1, train_policy_itr # number of itrs qf, policy are trained 
Developer: Breakend, Project: ReproducibilityInContinuousPolicyGradientMethods, Lines: 40, Source: ddpg.py

Example 15: optimize_policy

# Required import: from rllab.misc import ext [as alias]
# Or: from rllab.misc.ext import extract [as alias]
def optimize_policy(self, itr, samples_data):
        all_input_values = tuple(ext.extract(
            samples_data,
            "observations", "actions", "advantages"
        ))
        agent_infos = samples_data["agent_infos"]
        state_info_list = [agent_infos[k] for k in self.policy.state_info_keys]
        dist_info_list = [agent_infos[k]
                          for k in self.policy.distribution.dist_info_keys]
        all_input_values += tuple(state_info_list) + tuple(dist_info_list)
        if self.policy.recurrent:
            all_input_values += (samples_data["valids"],)

        self.optimize_policy_from_inputs(all_input_values)
        #logger.log("Computing loss before")
        #loss_before = self.optimizer.loss(all_input_values)
        #logger.log("Computing KL before")
        #mean_kl_before = self.optimizer.constraint_val(all_input_values)
        # logger.log("Optimizing")
        # self.optimizer.optimize(all_input_values)
        #logger.log("Computing KL after")
        #mean_kl = self.optimizer.constraint_val(all_input_values)
        #logger.log("Computing loss after")
        #loss_after = self.optimizer.loss(all_input_values)
        #logger.record_tabular('LossBefore', loss_before)
        #logger.record_tabular('LossAfter', loss_after)
        #logger.record_tabular('MeanKLBefore', mean_kl_before)
        #logger.record_tabular('MeanKL', mean_kl)
        #logger.record_tabular('dLoss', loss_before - loss_after)
        return dict() 
Developer: sisl, Project: gail-driver, Lines: 32, Source: npo.py


Note: The rllab.misc.ext.extract method examples in this article were compiled by 純淨天空 from open-source code and documentation platforms such as GitHub and MSDocs. The code snippets come from open-source projects contributed by their respective authors, who retain copyright over the source code; please consult the corresponding project's License before distributing or using it. Do not republish without permission.