当前位置: 首页>>代码示例>>Python>>正文


Python logger.add_tabular_output方法代码示例

本文整理汇总了Python中rllab.misc.logger.add_tabular_output方法的典型用法代码示例。如果您正苦于以下问题：Python logger.add_tabular_output方法的具体用法？Python logger.add_tabular_output怎么用？Python logger.add_tabular_output使用的例子？那么，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块rllab.misc.logger的用法示例。


在下文中一共展示了logger.add_tabular_output方法的4个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: setup_output

# 需要导入模块: from rllab.misc import logger [as 别名]
# 或者: from rllab.misc.logger import add_tabular_output [as 别名]
def setup_output(output_dir, clean=True, final_output_dir=None):
    """Direct rllab's text log, tabular CSV and snapshots into ``output_dir``.

    If an output directory is already configured (module-level ``OUTPUT_DIR``
    is not None), ``shutdown_output()`` is called first.

    Args:
        output_dir: Target directory; it is converted to an absolute path.
        clean: When true, ``ensure_clean_output_dir`` is called on the target.
            When false, the directory is only created if it does not exist.
        final_output_dir: Optional value stored in the module-level
            ``FINAL_OUTPUT_DIR``; this function only records and prints it.
    """
    global OUTPUT_DIR, FINAL_OUTPUT_DIR

    if OUTPUT_DIR is not None:
        shutdown_output()

    target = os.path.abspath(output_dir)
    if clean:
        ensure_clean_output_dir(target)
    elif not os.path.exists(target):
        os.makedirs(target)

    OUTPUT_DIR, FINAL_OUTPUT_DIR = target, final_output_dir

    print("** Output set to", OUTPUT_DIR)
    if FINAL_OUTPUT_DIR is not None:
        print("** Final output set to", FINAL_OUTPUT_DIR)

    # Everything rllab writes goes into the same (absolute) directory.
    text_path = os.path.join(target, "rllab.txt")
    csv_path = os.path.join(target, "rllab.csv")
    logger.add_text_output(text_path)
    logger.add_tabular_output(csv_path)
    logger.set_snapshot_mode('all')  # options: 'none', 'last', or 'all'
    logger.set_snapshot_dir(target)
开发者ID:vicariousinc,项目名称:pixelworld,代码行数:24,代码来源:output.py

示例2: logdir

# 需要导入模块: from rllab.misc import logger [as 别名]
# 或者: from rllab.misc.logger import add_tabular_output [as 别名]
def logdir(algo=None, dirname=None):
    """Register ``progress.csv`` as an rllab tabular output for a ``with`` block.

    This is the generator body of a context manager. It is presumably
    decorated with ``contextlib.contextmanager`` at the definition site; the
    decorator is not visible here, so confirm.

    Args:
        algo: Optional algorithm. If given, ``extract_hyperparams(algo)`` is
            dumped as JSON to ``params.json`` in the log directory.
        dirname: Optional directory. If given, it becomes rllab's snapshot
            directory. Otherwise the current snapshot directory is used.

    Yields:
        The log directory in use.
    """
    if dirname:
        rllablogger.set_snapshot_dir(dirname)
    dirname = rllablogger.get_snapshot_dir()
    progress_csv = os.path.join(dirname, 'progress.csv')
    rllablogger.add_tabular_output(progress_csv)
    # Unregister the CSV even if writing params.json or the caller's
    # with-block raises. Otherwise it stays attached to the global logger.
    try:
        if algo:
            with open(os.path.join(dirname, 'params.json'), 'w') as f:
                json.dump(extract_hyperparams(algo), f)
        yield dirname
    finally:
        rllablogger.remove_tabular_output(progress_csv)
开发者ID:nosyndicate,项目名称:pytorchrl,代码行数:13,代码来源:log_utils.py

示例3: rllab_logdir

# 需要导入模块: from rllab.misc import logger [as 别名]
# 或者: from rllab.misc.logger import add_tabular_output [as 别名]
def rllab_logdir(algo=None, dirname=None):
    """Register ``progress.csv`` as an rllab tabular output for a ``with`` block.

    This is the generator body of a context manager. It is presumably
    decorated with ``contextlib.contextmanager`` at the definition site; the
    decorator is not visible here, so confirm.

    Args:
        algo: Optional algorithm. If given, ``extract_hyperparams(algo)`` is
            dumped as JSON to ``params.json`` in the log directory.
        dirname: Optional directory. If given, it becomes rllab's snapshot
            directory. Otherwise the current snapshot directory is used.

    Yields:
        The log directory in use.
    """
    if dirname:
        rllablogger.set_snapshot_dir(dirname)
    dirname = rllablogger.get_snapshot_dir()
    progress_csv = os.path.join(dirname, 'progress.csv')
    rllablogger.add_tabular_output(progress_csv)
    # Unregister the CSV even if writing params.json or the caller's
    # with-block raises. Otherwise it stays attached to the global logger.
    try:
        if algo:
            with open(os.path.join(dirname, 'params.json'), 'w') as f:
                json.dump(extract_hyperparams(algo), f)
        yield dirname
    finally:
        rllablogger.remove_tabular_output(progress_csv)
开发者ID:ahq1993,项目名称:inverse_rl,代码行数:13,代码来源:log_utils.py

示例4: setup

# 需要导入模块: from rllab.misc import logger [as 别名]
# 或者: from rllab.misc.logger import add_tabular_output [as 别名]
def setup(self, env, policy, start_itr):

        if not self.args.algo == 'thddpg':
            # Baseline
            if self.args.baseline_type == 'linear':
                baseline = LinearFeatureBaseline(env_spec=env.spec)
            elif self.args.baseline_type == 'zero':
                baseline = ZeroBaseline(env_spec=env.spec)
            else:
                raise NotImplementedError(self.args.baseline_type)

            if self.args.control == 'concurrent':
                baseline = [baseline for _ in range(len(env.agents))]
        # Logger
        default_log_dir = config.LOG_DIR
        if self.args.log_dir is None:
            log_dir = osp.join(default_log_dir, self.args.exp_name)
        else:
            log_dir = self.args.log_dir

        tabular_log_file = osp.join(log_dir, self.args.tabular_log_file)
        text_log_file = osp.join(log_dir, self.args.text_log_file)
        params_log_file = osp.join(log_dir, self.args.params_log_file)

        logger.log_parameters_lite(params_log_file, self.args)
        logger.add_text_output(text_log_file)
        logger.add_tabular_output(tabular_log_file)
        prev_snapshot_dir = logger.get_snapshot_dir()
        prev_mode = logger.get_snapshot_mode()
        logger.set_snapshot_dir(log_dir)
        logger.set_snapshot_mode(self.args.snapshot_mode)
        logger.set_log_tabular_only(self.args.log_tabular_only)
        logger.push_prefix("[%s] " % self.args.exp_name)

        if self.args.algo == 'tftrpo':
            algo = MATRPO(env=env, policy_or_policies=policy, baseline_or_baselines=baseline,
                          batch_size=self.args.batch_size, start_itr=start_itr,
                          max_path_length=self.args.max_path_length, n_itr=self.args.n_iter,
                          discount=self.args.discount, gae_lambda=self.args.gae_lambda,
                          step_size=self.args.step_size, optimizer=ConjugateGradientOptimizer(
                              hvp_approach=FiniteDifferenceHvp(base_eps=1e-5)) if
                          self.args.recurrent else None, ma_mode=self.args.control)
        elif self.args.algo == 'thddpg':
            qfunc = thContinuousMLPQFunction(env_spec=env.spec)
            if self.args.exp_strategy == 'ou':
                es = OUStrategy(env_spec=env.spec)
            elif self.args.exp_strategy == 'gauss':
                es = GaussianStrategy(env_spec=env.spec)
            else:
                raise NotImplementedError()

            algo = thDDPG(env=env, policy=policy, qf=qfunc, es=es, batch_size=self.args.batch_size,
                          max_path_length=self.args.max_path_length,
                          epoch_length=self.args.epoch_length,
                          min_pool_size=self.args.min_pool_size,
                          replay_pool_size=self.args.replay_pool_size, n_epochs=self.args.n_iter,
                          discount=self.args.discount, scale_reward=0.01,
                          qf_learning_rate=self.args.qfunc_lr,
                          policy_learning_rate=self.args.policy_lr,
                          eval_samples=self.args.eval_samples, mode=self.args.control)
        return algo 
开发者ID:sisl,项目名称:MADRL,代码行数:63,代码来源:rurllab.py


注：本文中的rllab.misc.logger.add_tabular_output方法示例由纯净天空整理自GitHub/MSDocs等开源代码及文档管理平台，相关代码片段筛选自各路编程大神贡献的开源项目，源码版权归原作者所有，传播和使用请参考对应项目的License；未经允许，请勿转载。