

Python ext.extract Method Code Examples

This article collects typical usage examples of the ext.extract method from rllab.misc.ext in Python. If you are wondering how to call ext.extract, what it does, or what its use looks like in practice, the curated examples below may help. You can also explore further usage examples from the rllab.misc.ext module.


The following presents 15 code examples of the ext.extract method, ordered by popularity.
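
For orientation: ext.extract pulls the values stored under a set of keys out of a batch dictionary and returns them in order, so the result can be unpacked into variables or concatenated with additional optimizer inputs, which is exactly how the examples below use it. The snippet that follows is a minimal sketch approximating that behavior rather than the actual rllab source; the extract_sketch helper and the samples_data dictionary are made up for illustration.

import numpy as np

def extract_sketch(d, *keys):
    # Return the value stored under each requested key, in order, as a tuple.
    # This mirrors the call pattern ext.extract(samples_data, "observations", ...)
    # used throughout the examples below.
    return tuple(d[k] for k in keys)

# Hypothetical batch, shaped like the samples_data dictionaries in the examples.
samples_data = {
    "observations": np.zeros((100, 4)),
    "actions": np.zeros((100, 2)),
    "advantages": np.zeros(100),
}

obs, actions, advantages = extract_sketch(
    samples_data, "observations", "actions", "advantages"
)
print(obs.shape, actions.shape, advantages.shape)  # (100, 4) (100, 2) (100,)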

Example 1: optimize_policy

# Required import: from rllab.misc import ext [as alias]
# Alternatively: from rllab.misc.ext import extract [as alias]
def optimize_policy(self, itr, samples_data):
        all_input_values = tuple(ext.extract(
            samples_data,
            "observations", "actions", "advantages"
        ))
        agent_infos = samples_data["agent_infos"]
        state_info_list = [agent_infos[k] for k in self.policy.state_info_keys]
        dist_info_list = [agent_infos[k] for k in self.policy.distribution.dist_info_keys]
        all_input_values += tuple(state_info_list) + tuple(dist_info_list)
        if self.policy.recurrent:
            all_input_values += (samples_data["valids"],)
        loss_before = self.optimizer.loss(all_input_values)
        mean_kl_before = self.optimizer.constraint_val(all_input_values)
        self.optimizer.optimize(all_input_values)
        mean_kl = self.optimizer.constraint_val(all_input_values)
        loss_after = self.optimizer.loss(all_input_values)
        logger.record_tabular('LossBefore', loss_before)
        logger.record_tabular('LossAfter', loss_after)
        logger.record_tabular('MeanKLBefore', mean_kl_before)
        logger.record_tabular('MeanKL', mean_kl)
        logger.record_tabular('dLoss', loss_before - loss_after)
        return dict() 
Developer: bstadie, Project: third_person_im, Lines of code: 24, Source file: npo.py

Example 2: optimize_policy

# Required import: from rllab.misc import ext [as alias]
# Alternatively: from rllab.misc.ext import extract [as alias]
def optimize_policy(self, itr, samples_data):
        logger.log("optimizing policy")
        inputs = ext.extract(
            samples_data,
            "observations", "actions", "advantages"
        )
        agent_infos = samples_data["agent_infos"]
        state_info_list = [agent_infos[k] for k in self.policy.state_info_keys]
        inputs += tuple(state_info_list)
        if self.policy.recurrent:
            inputs += (samples_data["valids"],)
        dist_info_list = [agent_infos[k] for k in self.policy.distribution.dist_info_keys]
        loss_before = self.optimizer.loss(inputs)
        self.optimizer.optimize(inputs)
        loss_after = self.optimizer.loss(inputs)
        logger.record_tabular("LossBefore", loss_before)
        logger.record_tabular("LossAfter", loss_after)

        mean_kl, max_kl = self.opt_info['f_kl'](*(list(inputs) + dist_info_list))
        logger.record_tabular('MeanKL', mean_kl)
        logger.record_tabular('MaxKL', max_kl) 
Developer: bstadie, Project: third_person_im, Lines of code: 23, Source file: vpg.py

Example 3: do_policy_training

# Required import: from rllab.misc import ext [as alias]
# Alternatively: from rllab.misc.ext import extract [as alias]
def do_policy_training(self, itr, batch):
        target_policy = self.opt_info["target_policy"]
        obs, = ext.extract(batch, "observations")
        f_train_policy = self.opt_info["f_train_policy"]
        if isinstance(self.policy_update_method, FirstOrderOptimizer):
            policy_surr, _ = f_train_policy(obs)
        else:
            agent_infos = self.policy.dist_info(obs)
            state_info_list = [agent_infos[k] for k in self.policy.state_info_keys]
            dist_info_list = [agent_infos[k] for k in self.policy.distribution.dist_info_keys]
            all_input_values = (obs, obs, ) + tuple(state_info_list) + tuple(dist_info_list)
            policy_results = f_train_policy(all_input_values)
            policy_surr = policy_results["loss_after"]
        if self.policy_use_target:
            target_policy.set_param_values(
                target_policy.get_param_values() * (1.0 - self.soft_target_tau) +
                self.policy.get_param_values() * self.soft_target_tau)
        self.policy_surr_averages.append(policy_surr) 
Developer: shaneshixiang, Project: rllabplusplus, Lines of code: 20, Source file: ddpg.py

Example 4: optimize_policy

# Required import: from rllab.misc import ext [as alias]
# Alternatively: from rllab.misc.ext import extract [as alias]
def optimize_policy(self, itr, samples_data):
        logger.log("optimizing policy")
        inputs = ext.extract(
            samples_data,
            "observations", "actions", "advantages"
        )
        agent_infos = samples_data["agent_infos"]
        state_info_list = [agent_infos[k] for k in self.policy.state_info_keys]
        inputs += tuple(state_info_list)
        if self.policy.recurrent:
            inputs += (samples_data["valids"],)
        dist_info_list = [agent_infos[k]
                          for k in self.policy.distribution.dist_info_keys]
        loss_before = self.optimizer.loss(inputs)
        self.optimizer.optimize(inputs)
        loss_after = self.optimizer.loss(inputs)
        logger.record_tabular("LossBefore", loss_before)
        logger.record_tabular("LossAfter", loss_after)

        mean_kl, max_kl = self.opt_info['f_kl'](
            *(list(inputs) + dist_info_list))
        logger.record_tabular('MeanKL', mean_kl)
        logger.record_tabular('MaxKL', max_kl) 
Developer: sisl, Project: gail-driver, Lines of code: 25, Source file: vpg.py

Example 5: optimize_policy

# Required import: from rllab.misc import ext [as alias]
# Alternatively: from rllab.misc.ext import extract [as alias]
def optimize_policy(self, itr, samples_data):
        all_input_values = tuple(ext.extract(
            samples_data,
            "observations", "actions", "advantages"
        ))
        agent_infos = samples_data["agent_infos"]
        state_info_list = [agent_infos[k] for k in self.policy.state_info_keys]
        dist_info_list = [agent_infos[k]
                          for k in self.policy.distribution.dist_info_keys]
        all_input_values += tuple(state_info_list) + tuple(dist_info_list)
        if self.policy.recurrent:
            all_input_values += (samples_data["valids"],)
        loss_before = self.optimizer.loss(all_input_values)
        mean_kl_before = self.optimizer.constraint_val(all_input_values)
        self.optimizer.optimize(all_input_values)
        mean_kl = self.optimizer.constraint_val(all_input_values)
        loss_after = self.optimizer.loss(all_input_values)
        logger.record_tabular('LossBefore', loss_before)
        logger.record_tabular('LossAfter', loss_after)
        logger.record_tabular('MeanKLBefore', mean_kl_before)
        logger.record_tabular('MeanKL', mean_kl)
        logger.record_tabular('dLoss', loss_before - loss_after)
        return dict() 
Developer: sisl, Project: gail-driver, Lines of code: 25, Source file: npo.py

Example 6: optimize_policy

# Required import: from rllab.misc import ext [as alias]
# Alternatively: from rllab.misc.ext import extract [as alias]
def optimize_policy(self, itr, samples_data):
        logger.log("optimizing policy")
        inputs = ext.extract(
            samples_data,
            "observations", "actions", "advantages"
        )
        if self.policy.recurrent:
            inputs += (samples_data["valids"],)
        agent_infos = samples_data["agent_infos"]
        dist_info_list = [agent_infos[k] for k in self.policy.distribution.dist_info_keys]
        loss_before = self.optimizer.loss(inputs)
        self.optimizer.optimize(inputs)
        loss_after = self.optimizer.loss(inputs)
        logger.record_tabular("LossBefore", loss_before)
        logger.record_tabular("LossAfter", loss_after)

        mean_kl, max_kl = self.opt_info['f_kl'](*(list(inputs) + dist_info_list))
        logger.record_tabular('MeanKL', mean_kl)
        logger.record_tabular('MaxKL', max_kl) 
Developer: openai, Project: vime, Lines of code: 21, Source file: vpg_expl.py

Example 7: optimize_policy

# Required import: from rllab.misc import ext [as alias]
# Alternatively: from rllab.misc.ext import extract [as alias]
def optimize_policy(self, itr, samples_data):
        all_input_values = tuple(ext.extract(
            samples_data,
            "observations", "actions", "advantages"
        ))
        agent_infos = samples_data["agent_infos"]
        info_list = [agent_infos[k]
                     for k in self.policy.distribution.dist_info_keys]
        all_input_values += tuple(info_list)
        if self.policy.recurrent:
            all_input_values += (samples_data["valids"],)
        loss_before = self.optimizer.loss(all_input_values)
        self.optimizer.optimize(all_input_values)
        mean_kl = self.optimizer.constraint_val(all_input_values)
        loss_after = self.optimizer.loss(all_input_values)
        logger.record_tabular('LossAfter', loss_after)
        logger.record_tabular('MeanKL', mean_kl)
        logger.record_tabular('dLoss', loss_before - loss_after)
        return dict() 
Developer: openai, Project: vime, Lines of code: 21, Source file: npo_expl.py

Example 8: optimize_policy

# Required import: from rllab.misc import ext [as alias]
# Alternatively: from rllab.misc.ext import extract [as alias]
def optimize_policy(self, itr, samples_data):
        logger.log("optimizing policy")
        inputs = ext.extract(
            samples_data,
            "observations", "actions", "advantages"
        )
        agent_infos = samples_data["agent_infos"]
        state_info_list = [agent_infos[k] for k in self.policy.state_info_keys]
        inputs += tuple(state_info_list)
        if self.policy.recurrent:
            inputs += (samples_data["valids"],)
        dist_info_list = [agent_infos[k] for k in self.policy.distribution.dist_info_keys]
        loss_before = self.optimizer.loss(inputs)
        self.optimizer.optimize(inputs)  # TODO - actual optimize step happens here?
        loss_after = self.optimizer.loss(inputs)
        logger.record_tabular("LossBefore", loss_before)
        logger.record_tabular("LossAfter", loss_after)

        mean_kl, max_kl = self.opt_info['f_kl'](*(list(inputs) + dist_info_list))
        logger.record_tabular('MeanKL', mean_kl)
        logger.record_tabular('MaxKL', max_kl) 
Developer: cbfinn, Project: maml_rl, Lines of code: 23, Source file: vpg.py

Example 9: optimize_policy

# Required import: from rllab.misc import ext [as alias]
# Alternatively: from rllab.misc.ext import extract [as alias]
def optimize_policy(self, itr,
                        samples_data):  # make that samples_data comes with latents: see train in batch_polopt
        all_input_values = tuple(ext.extract(  # it will be in agent_infos!!! under key "latents"
            samples_data,
            "observations", "actions", "advantages"
        ))
        agent_infos = samples_data["agent_infos"]
        all_input_values += (agent_infos[
                                 "latents"],)  # latents has already been processed and is the concat of all latents, but keeps key "latents"
        info_list = [agent_infos[k] for k in
                     self.policy.distribution.dist_info_keys]  # these are the mean and var used at rollout, corresponding to
        all_input_values += tuple(info_list)  # old_dist_info_vars_list as symbolic var
        if self.policy.recurrent:
            all_input_values += (samples_data["valids"],)

        loss_before = self.optimizer.loss(all_input_values)
        # this should always be 0. If it's not there is a problem.
        mean_kl_before = self.optimizer.constraint_val(all_input_values)
        logger.record_tabular('MeanKL_Before', mean_kl_before)

        with logger.prefix(' PolicyOptimize | '):
            self.optimizer.optimize(all_input_values)

        mean_kl = self.optimizer.constraint_val(all_input_values)
        loss_after = self.optimizer.loss(all_input_values)
        logger.record_tabular('LossAfter', loss_after)
        logger.record_tabular('MeanKL', mean_kl)
        logger.record_tabular('dLoss', loss_before - loss_after)
        return dict() 
Developer: florensacc, Project: snn4hrl, Lines of code: 31, Source file: npo_snn_rewards.py

Example 10: do_training

# Required import: from rllab.misc import ext [as alias]
# Alternatively: from rllab.misc.ext import extract [as alias]
def do_training(self, itr, batch):

        obs, actions, rewards, next_obs, terminals = ext.extract(
            batch,
            "observations", "actions", "rewards", "next_observations",
            "terminals"
        )

        # compute the on-policy y values
        target_qf = self.opt_info["target_qf"]
        target_policy = self.opt_info["target_policy"]

        next_actions, _ = target_policy.get_actions(next_obs)
        next_qvals = target_qf.get_qval(next_obs, next_actions)

        ys = rewards + (1. - terminals) * self.discount * next_qvals

        f_train_qf = self.opt_info["f_train_qf"]
        f_train_policy = self.opt_info["f_train_policy"]

        qf_loss, qval = f_train_qf(ys, obs, actions)

        policy_surr = f_train_policy(obs)

        target_policy.set_param_values(
            target_policy.get_param_values() * (1.0 - self.soft_target_tau) +
            self.policy.get_param_values() * self.soft_target_tau)
        target_qf.set_param_values(
            target_qf.get_param_values() * (1.0 - self.soft_target_tau) +
            self.qf.get_param_values() * self.soft_target_tau)

        self.qf_loss_averages.append(qf_loss)
        self.policy_surr_averages.append(policy_surr)
        self.q_averages.append(qval)
        self.y_averages.append(ys) 
Developer: bstadie, Project: third_person_im, Lines of code: 37, Source file: ddpg.py

Example 11: __setstate__

# Required import: from rllab.misc import ext [as alias]
# Alternatively: from rllab.misc.ext import extract [as alias]
def __setstate__(self, d):
        super(ReplayPool, self).__setstate__(d)
        self.bottom, self.top, self.size, self.observations, self.actions, \
        self.rewards, self.terminals, self.extras, self.rng = extract(
            d,
            "bottom", "top", "size", "observations", "actions", "rewards",
            "terminals", "extras", "rng"
        ) 
Developer: bstadie, Project: third_person_im, Lines of code: 10, Source file: util.py

Example 12: optimize_policy

# Required import: from rllab.misc import ext [as alias]
# Alternatively: from rllab.misc.ext import extract [as alias]
def optimize_policy(self, itr, samples_data):
        all_input_values = tuple(ext.extract(
            samples_data,
            "observations", "actions", "advantages"
        ))
        agent_infos = samples_data["agent_infos"]
        state_info_list = [agent_infos[k] for k in self.policy.state_info_keys]
        dist_info_list = [agent_infos[k] for k in self.policy.distribution.dist_info_keys]
        all_input_values += tuple(state_info_list) + tuple(dist_info_list)
        if self.policy.recurrent:
            all_input_values += (samples_data["valids"],)
        logger.log("Computing loss before")
        loss_before = self.optimizer.loss(all_input_values)
        logger.log("Computing KL before")
        mean_kl_before = self.optimizer.constraint_val(all_input_values)
        logger.log("Optimizing")
        self.optimizer.optimize(all_input_values)
        logger.log("Computing KL after")
        mean_kl = self.optimizer.constraint_val(all_input_values)
        logger.log("Computing loss after")
        loss_after = self.optimizer.loss(all_input_values)
        logger.record_tabular('LossBefore', loss_before)
        logger.record_tabular('LossAfter', loss_after)
        logger.record_tabular('MeanKLBefore', mean_kl_before)
        logger.record_tabular('MeanKL', mean_kl)
        logger.record_tabular('dLoss', loss_before - loss_after)
        return dict() 
Developer: bstadie, Project: third_person_im, Lines of code: 29, Source file: npo.py

Example 13: optimize_policy

# Required import: from rllab.misc import ext [as alias]
# Alternatively: from rllab.misc.ext import extract [as alias]
def optimize_policy(self, itr, samples_data):
        logger.log("optimizing policy")
        inputs = ext.extract(
            samples_data,
            "observations", "actions", "advantages"
        )
        agent_infos = samples_data["agent_infos"]
        state_info_list = [agent_infos[k] for k in self.policy.state_info_keys]
        inputs += tuple(state_info_list)
        if self.policy.recurrent:
            inputs += (samples_data["valids"],)
        if self.qprop:
            inputs += (samples_data["etas"], )
            logger.log("Using Qprop optimizer")
        optimizer = self.optimizer
        dist_info_list = [agent_infos[k] for k in self.policy.distribution.dist_info_keys]
        loss_before = optimizer.loss(inputs)
        gc.collect()
        optimizer.optimize(inputs)
        gc.collect()
        loss_after = optimizer.loss(inputs)
        logger.record_tabular("LossBefore", loss_before)
        logger.record_tabular("LossAfter", loss_after)

        mean_kl, max_kl = self.opt_info['f_kl'](*(list(inputs) + dist_info_list))
        logger.record_tabular('MeanKL', mean_kl)
        logger.record_tabular('MaxKL', max_kl) 
Developer: shaneshixiang, Project: rllabplusplus, Lines of code: 29, Source file: vpg.py

Example 14: do_training

# Required import: from rllab.misc import ext [as alias]
# Alternatively: from rllab.misc.ext import extract [as alias]
def do_training(self, itr, batch):

        obs, actions, rewards, next_obs, terminals = ext.extract(
            batch,
            "observations", "actions", "rewards", "next_observations",
            "terminals"
        )

        # compute the on-policy y values
        target_qf = self.opt_info["target_qf"]
        target_policy = self.opt_info["target_policy"]

        next_actions, _ = target_policy.get_actions(next_obs)
        next_qvals = target_qf.get_qval(next_obs, next_actions)

        ys = rewards + (1. - terminals) * self.discount * next_qvals.reshape(-1)

        f_train_qf = self.opt_info["f_train_qf"]
        qf_loss, qval, _ = f_train_qf(ys, obs, actions)
        target_qf.set_param_values(
            target_qf.get_param_values() * (1.0 - self.soft_target_tau) +
            self.qf.get_param_values() * self.soft_target_tau)
        self.qf_loss_averages.append(qf_loss)
        self.q_averages.append(qval)
        self.y_averages.append(ys)

        self.train_policy_itr += self.policy_updates_ratio
        train_policy_itr = 0
        while self.train_policy_itr > 0:
            f_train_policy = self.opt_info["f_train_policy"]
            policy_surr, _ = f_train_policy(obs)
            target_policy.set_param_values(
                target_policy.get_param_values() * (1.0 - self.soft_target_tau) +
                self.policy.get_param_values() * self.soft_target_tau)
            self.policy_surr_averages.append(policy_surr)
            self.train_policy_itr -= 1
            train_policy_itr += 1
        return 1, train_policy_itr # number of itrs qf, policy are trained 
Developer: Breakend, Project: ReproducibilityInContinuousPolicyGradientMethods, Lines of code: 40, Source file: ddpg.py

Example 15: optimize_policy

# Required import: from rllab.misc import ext [as alias]
# Alternatively: from rllab.misc.ext import extract [as alias]
def optimize_policy(self, itr, samples_data):
        all_input_values = tuple(ext.extract(
            samples_data,
            "observations", "actions", "advantages"
        ))
        agent_infos = samples_data["agent_infos"]
        state_info_list = [agent_infos[k] for k in self.policy.state_info_keys]
        dist_info_list = [agent_infos[k]
                          for k in self.policy.distribution.dist_info_keys]
        all_input_values += tuple(state_info_list) + tuple(dist_info_list)
        if self.policy.recurrent:
            all_input_values += (samples_data["valids"],)

        self.optimize_policy_from_inputs(all_input_values)
        #logger.log("Computing loss before")
        #loss_before = self.optimizer.loss(all_input_values)
        #logger.log("Computing KL before")
        #mean_kl_before = self.optimizer.constraint_val(all_input_values)
        # logger.log("Optimizing")
        # self.optimizer.optimize(all_input_values)
        #logger.log("Computing KL after")
        #mean_kl = self.optimizer.constraint_val(all_input_values)
        #logger.log("Computing loss after")
        #loss_after = self.optimizer.loss(all_input_values)
        #logger.record_tabular('LossBefore', loss_before)
        #logger.record_tabular('LossAfter', loss_after)
        #logger.record_tabular('MeanKLBefore', mean_kl_before)
        #logger.record_tabular('MeanKL', mean_kl)
        #logger.record_tabular('dLoss', loss_before - loss_after)
        return dict() 
Developer: sisl, Project: gail-driver, Lines of code: 32, Source file: npo.py


Note: The rllab.misc.ext.extract examples in this article were compiled by 纯净天空 from GitHub, MSDocs, and other open-source code and documentation platforms. The snippets were selected from open-source projects contributed by their respective developers, and copyright remains with the original authors. Please consult each project's license before distributing or reusing the code; do not reproduce without permission.