当前位置: 首页>>代码示例>>Python>>正文


Python logger.prefix方法代码示例

本文整理汇总了Python中rllab.misc.logger.prefix方法的典型用法代码示例。如果您正苦于以下问题：Python logger.prefix方法的具体用法是什么？Python logger.prefix怎么用？有哪些Python logger.prefix的使用例子？那么，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块rllab.misc.logger的用法示例。


下文一共展示了logger.prefix方法的8个代码示例，这些示例默认按受欢迎程度排序。您可以为喜欢或者觉得有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: train

# 需要导入模块: from rllab.misc import logger [as 别名]
# 或者: from rllab.misc.logger import prefix [as 别名]
def train(self):
        """Run the batch policy-optimization loop.

        Iterates from ``self.current_itr`` up to (not including) ``self.n_itr``.
        Each iteration samples paths, processes them, logs diagnostics,
        optimizes the policy and saves a snapshot. Every log line emitted
        during an iteration is prefixed with ``itr #<n> | ``.

        Sampler workers are started once up front. They are shut down when the
        loop finishes, and also when an iteration raises.
        """
        self.start_worker()
        try:
            for itr in range(self.current_itr, self.n_itr):
                with logger.prefix('itr #%d | ' % itr):
                    logger.log('Obtaining samples...')
                    paths = self.sampler.obtain_samples(itr)
                    logger.log('Processing samples...')
                    samples_data = self.sampler.process_samples(itr, paths)
                    logger.log('Logging diagnostics...')
                    self.log_diagnostics(paths)
                    logger.log('Optimizing policy...')
                    self.optimize_policy(itr, samples_data)
                    logger.log('Saving snapshot...')
                    params = self.get_itr_snapshot(itr, samples_data)
                    # Advance the resume point before saving: params['algo'] is
                    # this object, and the loop above starts at current_itr.
                    self.current_itr = itr + 1
                    params['algo'] = self
                    # Optionally keep the raw trajectories in the snapshot.
                    if self.store_paths:
                        params['paths'] = samples_data['paths']
                    logger.save_itr_params(itr, params)
                    logger.log('Saved')
                    logger.dump_tabular(with_prefix=False)
                    if self.plot:
                        self.update_plot()
                        if self.pause_for_plot:
                            input('Plotting evaluation run: Press Enter to '
                                  'continue...')
        finally:
            # Release sampler workers even if an iteration raised.
            self.shutdown_worker()
开发者ID:nosyndicate,项目名称:pytorchrl,代码行数:31,代码来源:batch_polopt.py

示例2: train

# 需要导入模块: from rllab.misc import logger [as 别名]
# 或者: from rllab.misc.logger import prefix [as 别名]
def train(self):
        """Alternate reward (IRL) updates and policy optimization.

        Before the loop, optionally loads ``self.init_pol_params`` into the
        policy and ``self.init_irl_params`` into the IRL model. Each iteration
        in ``[self.start_itr, self.n_itr)`` does the following:

        - samples paths;
        - passes them through ``self.compute_irl`` (the reward-function update);
        - processes the samples, logs diagnostics, optimizes the policy and
          saves a snapshot.

        Wall-clock ``Time`` (since training start) and ``ItrTime`` (this
        iteration) are recorded in the tabular log.

        Sampler workers are always shut down, even if an iteration raises.
        """
        if self.init_pol_params is not None:
            self.policy.set_param_values(self.init_pol_params)
        if self.init_irl_params is not None:
            self.irl_model.set_params(self.init_irl_params)
        self.start_worker()
        try:
            start_time = time.time()
            for itr in range(self.start_itr, self.n_itr):
                itr_start_time = time.time()
                with logger.prefix('itr #%d | ' % itr):
                    logger.log('Obtaining samples...')
                    paths = self.sampler.obtain_samples(itr)
                    logger.log('Processing samples...')
                    # Update the reward function and get the paths back.
                    paths = self.compute_irl(paths, itr=itr)
                    samples_data = self.sampler.process_samples(itr, paths)

                    logger.log('Logging diagnostics...')
                    self.log_diagnostics(paths)
                    logger.log('Optimizing policy...')
                    self.optimize_policy(itr, samples_data)
                    logger.log('Saving snapshot...')
                    params = self.get_itr_snapshot(itr, samples_data)
                    if self.store_paths:
                        params['paths'] = samples_data['paths']
                    logger.save_itr_params(itr, params)
                    logger.log('Saved')
                    logger.record_tabular('Time', time.time() - start_time)
                    logger.record_tabular('ItrTime', time.time() - itr_start_time)
                    logger.dump_tabular(with_prefix=False)
                    if self.plot:
                        self.update_plot()
                        if self.pause_for_plot:
                            input('Plotting evaluation run: Press Enter to '
                                  'continue...')
        finally:
            # Release sampler workers even if an iteration raised.
            self.shutdown_worker()
开发者ID:nosyndicate,项目名称:pytorchrl,代码行数:42,代码来源:irl_batch_polopt.py

示例3: train

# 需要导入模块: from rllab.misc import logger [as 别名]
# 或者: from rllab.misc.logger import prefix [as 别名]
def train(self, sess=None):
        """Run the batch policy-optimization loop inside a TensorFlow session.

        Setup: initializes all global TF variables in ``sess`` and starts the
        sampler workers.

        Loop: for each iteration in ``[self.start_itr, self.n_itr)``, samples
        paths, processes them, logs diagnostics, optimizes the policy and saves
        a snapshot. Wall-clock ``Time`` and ``ItrTime`` are recorded in the
        tabular log. When ``self.plot`` is set, one animated rollout is shown
        per iteration.

        Args:
            sess: TensorFlow session to use. If None, a new ``tf.Session`` is
                created and entered, which makes it the default session. It is
                exited and closed when training finishes or raises. A
                caller-supplied session is left open.
        """
        import contextlib

        # Cleanup runs in reverse registration order: workers are shut down
        # first, then a session created here is exited (which closes it).
        with contextlib.ExitStack() as cleanup:
            if sess is None:
                # Exiting (not merely close()-ing) also pops the default-session
                # and default-graph contexts that entering pushed.
                sess = cleanup.enter_context(tf.Session())

            sess.run(tf.global_variables_initializer())
            self.start_worker()
            cleanup.callback(self.shutdown_worker)

            start_time = time.time()
            for itr in range(self.start_itr, self.n_itr):
                itr_start_time = time.time()
                with logger.prefix('itr #%d | ' % itr):
                    logger.log("Obtaining samples...")
                    paths = self.obtain_samples(itr)
                    logger.log("Processing samples...")
                    samples_data = self.process_samples(itr, paths)
                    logger.log("Logging diagnostics...")
                    self.log_diagnostics(paths)
                    logger.log("Optimizing policy...")
                    self.optimize_policy(itr, samples_data)
                    logger.log("Saving snapshot...")
                    params = self.get_itr_snapshot(itr, samples_data)
                    if self.store_paths:
                        params["paths"] = samples_data["paths"]
                    logger.save_itr_params(itr, params)
                    logger.log("Saved")
                    logger.record_tabular('Time', time.time() - start_time)
                    logger.record_tabular('ItrTime', time.time() - itr_start_time)
                    logger.dump_tabular(with_prefix=False)
                    if self.plot:
                        rollout(self.env, self.policy, animated=True, max_path_length=self.max_path_length)
                        if self.pause_for_plot:
                            input("Plotting evaluation run: Press Enter to "
                                  "continue...")
开发者ID:ahq1993,项目名称:inverse_rl,代码行数:39,代码来源:batch_polopt.py

示例4: process_samples

# 需要导入模块: from rllab.misc import logger [as 别名]
# 或者: from rllab.misc.logger import prefix [as 别名]
def process_samples(self, itr, paths):
        """Shape the rewards of ``paths`` with MI/bonus terms, then run BatchSampler processing.

        Steps, in order:

        1. Every bonus evaluator is fitted on the raw paths. Per the original
           note, it is not supposed to modify them.
        2. The mean undiscounted return of the unmodified rewards is logged as
           ``TrueAverageReturn``. A copy of each path's rewards is stored under
           ``'true_rewards'``.
        3. If ``self.latent_regressor`` is a ``Latent_regressor``, it is fitted
           on the paths. When ``self.reward_regressor_mi`` is truthy, each path
           gets ``'logli_latent_regressor'``, and its rewards are increased
           in place by ``reward_regressor_mi`` times that value.
        4. For every bonus evaluator ``k``, each path's rewards are increased
           in place by ``self.reward_coef_bonus[k]`` times the evaluator's
           prediction.
        5. ``BatchSampler.process_samples`` runs on the shaped paths. Only the
           observations, actions, advantages, env_infos and agent_infos are
           kept, plus all-ones ``importance_weights`` shaped like the
           advantages.

        Returns:
            The dict described in step 5.
        """
        # 1. Let bonus evaluators count visitations etc. before any shaping.
        for evaluator in self.bonus_evaluator:
            logger.log("fitting bonus evaluator before processing...")
            evaluator.fit_before_process_samples(paths)
            logger.log("fitted")

        # 2. Remember the unshaped environment reward.
        true_returns = [sum(p["rewards"]) for p in paths]
        logger.record_tabular('TrueAverageReturn', np.mean(true_returns))
        for p in paths:
            p['true_rewards'] = list(p['rewards'])

        # 3. Optional mutual-information reward from the latent regressor.
        if isinstance(self.latent_regressor, Latent_regressor):
            with logger.prefix(' Latent_regressor '):
                self.latent_regressor.fit(paths)
                if self.reward_regressor_mi:
                    for p in paths:
                        p['logli_latent_regressor'] = self.latent_regressor.predict_log_likelihood(
                            [p], [p['agent_infos']['latents']])[0]
                        # The latent's log-likelihood plays the role of the MI term.
                        p['rewards'] += self.reward_regressor_mi * p['logli_latent_regressor']

        # 4. Extra exploration bonuses, each scaled by its own coefficient.
        for k, evaluator in enumerate(self.bonus_evaluator):
            for p in paths:
                bonus = evaluator.predict(p)
                p['rewards'] += self.reward_coef_bonus[k] * bonus

        # 5. Hallucinated samples would share R, A, ... so only real samples are processed.
        processed = BatchSampler.process_samples(self, itr, paths)
        real_samples = ext.extract_dict(
            processed, "observations", "actions", "advantages", "env_infos", "agent_infos"
        )
        real_samples["importance_weights"] = np.ones_like(real_samples["advantages"])
        return real_samples
开发者ID:florensacc,项目名称:snn4hrl,代码行数:41,代码来源:npo_snn_rewards.py

示例5: log_diagnostics

# 需要导入模块: from rllab.misc import logger [as 别名]
# 或者: from rllab.misc.logger import prefix [as 别名]
def log_diagnostics(self, paths):
        """Forward diagnostic logging to the bonus evaluators and the latent regressor.

        Each bonus evaluator logs first. The latent regressor logs only when it
        is a ``Latent_regressor`` instance, and does so under the
        ``' Latent regressor logging | '`` prefix.
        """
        for evaluator in self.bonus_evaluator:
            evaluator.log_diagnostics(paths)

        if not isinstance(self.latent_regressor, Latent_regressor):
            return
        with logger.prefix(' Latent regressor logging | '):
            self.latent_regressor.log_diagnostics(paths)
开发者ID:florensacc,项目名称:snn4hrl,代码行数:9,代码来源:npo_snn_rewards.py

示例6: optimize_policy

# 需要导入模块: from rllab.misc import logger [as 别名]
# 或者: from rllab.misc.logger import prefix [as 别名]
def optimize_policy(self, itr, samples_data):
        """Run one optimizer step on the policy and record loss/KL statistics.

        The optimizer input tuple is built in this order:

        - observations, actions and advantages (extracted from ``samples_data``);
        - ``samples_data["agent_infos"]["latents"]``;
        - one ``agent_infos`` entry per key in
          ``self.policy.distribution.dist_info_keys``, i.e. the old
          distribution parameters;
        - ``samples_data["valids"]``, only for recurrent policies.

        Records ``MeanKL_Before``, ``LossAfter``, ``MeanKL`` and ``dLoss`` in the
        tabular log. ``itr`` is not used.

        Returns:
            An empty dict.
        """
        # samples_data must already carry the latents (see train in batch_polopt).
        inputs = list(ext.extract(samples_data, "observations", "actions", "advantages"))
        agent_infos = samples_data["agent_infos"]
        # Latents were already concatenated across paths, but keep the "latents" key.
        inputs.append(agent_infos["latents"])
        # Dist-info values used at rollout time (e.g. mean and var), fed as the
        # old_dist_info_vars_list symbolic inputs.
        inputs.extend([agent_infos[key] for key in self.policy.distribution.dist_info_keys])
        if self.policy.recurrent:
            inputs.append(samples_data["valids"])
        all_input_values = tuple(inputs)

        loss_before = self.optimizer.loss(all_input_values)
        # Measured before optimizing, so this should always be 0; anything else
        # signals a problem with the inputs.
        mean_kl_before = self.optimizer.constraint_val(all_input_values)
        logger.record_tabular('MeanKL_Before', mean_kl_before)

        with logger.prefix(' PolicyOptimize | '):
            self.optimizer.optimize(all_input_values)

        mean_kl = self.optimizer.constraint_val(all_input_values)
        loss_after = self.optimizer.loss(all_input_values)
        logger.record_tabular('LossAfter', loss_after)
        logger.record_tabular('MeanKL', mean_kl)
        logger.record_tabular('dLoss', loss_before - loss_after)
        return {}
开发者ID:florensacc,项目名称:snn4hrl,代码行数:31,代码来源:npo_snn_rewards.py

示例7: train

# 需要导入模块: from rllab.misc import logger [as 别名]
# 或者: from rllab.misc.logger import prefix [as 别名]
def train(self):
        """Run IRL training with qvar and empowerment models in the default TF session.

        Requires an active default TensorFlow session
        (``tf.get_default_session()``); all global variables are initialized in
        it. Optional initial parameters are then loaded for the policy,
        ``qvar_model``, ``irl_model`` and ``empw``.

        Each iteration in ``[self.start_itr, self.n_itr)`` does the following:

        - samples paths;
        - runs ``compute_irl``, ``compute_qvar`` and ``compute_empw`` on them;
        - processes the samples, logs diagnostics, optimizes the policy and
          saves a snapshot.

        When ``self.train_empw`` is set, the target empowerment model
        ``self.tempw`` is refreshed whenever ``itr`` is a multiple of
        ``self.target_empw_update``.

        Sampler workers are always shut down, even if an iteration raises.
        """
        sess = tf.get_default_session()
        sess.run(tf.global_variables_initializer())
        if self.init_pol_params is not None:
            self.policy.set_param_values(self.init_pol_params)

        if self.init_qvar_params is not None:
            self.qvar_model.set_params(self.init_qvar_params)

        if self.init_irl_params is not None:
            self.irl_model.set_params(self.init_irl_params)

        if self.init_empw_params is not None:
            self.empw.set_params(self.init_empw_params)

        self.start_worker()
        try:
            start_time = time.time()
            # Second value returned by compute_irl for each iteration; only
            # consumed by the optional pickle dump below.
            rew = []
            for itr in range(self.start_itr, self.n_itr):
                itr_start_time = time.time()

                with logger.prefix('itr #%d | ' % itr):
                    logger.log("Obtaining samples...")
                    paths = self.obtain_samples(itr)

                    logger.log("Processing samples...")
                    paths, r = self.compute_irl(paths, itr=itr)
                    rew.append(r)
                    # Called for its side effects (presumably logging); the
                    # return value was previously collected but never read.
                    self.log_avg_returns(paths)
                    self.compute_qvar(paths, itr=itr)
                    self.compute_empw(paths, itr=itr)
                    samples_data = self.process_samples(itr, paths)
                    logger.log("Logging diagnostics...")
                    self.log_diagnostics(paths)
                    logger.log("Optimizing policy...")
                    self.optimize_policy(itr, samples_data)
                    logger.log("Saving snapshot...")
                    params = self.get_itr_snapshot(itr, samples_data)
                    if self.store_paths:
                        params["paths"] = samples_data["paths"]
                    logger.save_itr_params(itr, params)
                    logger.log("Saved")
                    logger.record_tabular('Time', time.time() - start_time)
                    logger.record_tabular('ItrTime', time.time() - itr_start_time)
                    logger.dump_tabular(with_prefix=False)
                    if self.plot:
                        self.update_plot()
                        if self.pause_for_plot:
                            input("Plotting evaluation run: Press Enter to "
                                  "continue...")
                if itr % self.target_empw_update == 0 and self.train_empw:
                    print('updating target empowerment parameters')
                    # NOTE(review): inside a class body, `self.__empw_params` is
                    # name-mangled to `self._<ClassName>__empw_params`, and it is
                    # passed here without being called -- confirm it holds the
                    # parameters rather than being a method that returns them.
                    self.tempw.set_params(self.__empw_params)

            #pickle.dump(rew, open("rewards.p", "wb" )) # uncomment to store rewards in every iteration
        finally:
            # Release sampler workers even if an iteration raised.
            self.shutdown_worker()
开发者ID:ahq1993,项目名称:inverse_rl,代码行数:62,代码来源:irl_batch_polopt.py

示例8: train

# 需要导入模块: from rllab.misc import logger [as 别名]
# 或者: from rllab.misc.logger import prefix [as 别名]
def train(self):
        """Alternate reward (IRL) updates and policy optimization in the default TF session.

        Requires an active default TensorFlow session
        (``tf.get_default_session()``); all global variables are initialized in
        it. Optionally loads ``self.init_pol_params`` into the policy and
        ``self.init_irl_params`` into the IRL model.

        Each iteration in ``[self.start_itr, self.n_itr)`` does the following:

        - samples paths;
        - passes them through ``self.compute_irl`` and ``self.log_avg_returns``;
        - processes the samples, logs diagnostics, optimizes the policy and
          saves a snapshot.

        Wall-clock ``Time`` and ``ItrTime`` are recorded in the tabular log.

        Sampler workers are always shut down, even if an iteration raises.
        """
        sess = tf.get_default_session()
        sess.run(tf.global_variables_initializer())
        if self.init_pol_params is not None:
            self.policy.set_param_values(self.init_pol_params)
        if self.init_irl_params is not None:
            self.irl_model.set_params(self.init_irl_params)
        self.start_worker()
        try:
            start_time = time.time()
            for itr in range(self.start_itr, self.n_itr):
                itr_start_time = time.time()
                with logger.prefix('itr #%d | ' % itr):
                    logger.log("Obtaining samples...")
                    paths = self.obtain_samples(itr)

                    logger.log("Processing samples...")
                    paths = self.compute_irl(paths, itr=itr)
                    # Called for its side effects (presumably logging); the
                    # return value was previously collected but never read.
                    self.log_avg_returns(paths)
                    samples_data = self.process_samples(itr, paths)

                    logger.log("Logging diagnostics...")
                    self.log_diagnostics(paths)
                    logger.log("Optimizing policy...")
                    self.optimize_policy(itr, samples_data)
                    logger.log("Saving snapshot...")
                    params = self.get_itr_snapshot(itr, samples_data)
                    if self.store_paths:
                        params["paths"] = samples_data["paths"]
                    logger.save_itr_params(itr, params)
                    logger.log("Saved")
                    logger.record_tabular('Time', time.time() - start_time)
                    logger.record_tabular('ItrTime', time.time() - itr_start_time)
                    logger.dump_tabular(with_prefix=False)
                    if self.plot:
                        self.update_plot()
                        if self.pause_for_plot:
                            input("Plotting evaluation run: Press Enter to "
                                  "continue...")
        finally:
            # Release sampler workers even if an iteration raised.
            self.shutdown_worker()
开发者ID:justinjfu,项目名称:inverse_rl,代码行数:44,代码来源:irl_batch_polopt.py


注:本文中的rllab.misc.logger.prefix方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。