當前位置: 首頁>>代碼示例>>Python>>正文


Python logger.set_snapshot_dir方法代碼示例

本文整理匯總了Python中rllab.misc.logger.set_snapshot_dir方法的典型用法代碼示例。如果您正苦於以下問題：Python logger.set_snapshot_dir方法的具體用法？Python logger.set_snapshot_dir怎麼用？Python logger.set_snapshot_dir使用的例子？那麼，這裏精選的方法代碼示例或許可以為您提供幫助。您也可以進一步了解該方法所在rllab.misc.logger的用法示例。


在下文中一共展示了logger.set_snapshot_dir方法的5個代碼示例，這些例子默認根據受歡迎程度排序。您可以為喜歡或者感覺有用的代碼點讚，您的評價將有助於系統推薦出更棒的Python代碼示例。

示例1: setup_output

# 需要導入模塊: from rllab.misc import logger [as 別名]
# 或者: from rllab.misc.logger import set_snapshot_dir [as 別名]
def setup_output(output_dir, clean=True, final_output_dir=None):
    """Point the module-level output state and the rllab logger at a directory.

    If an output directory is already active (``OUTPUT_DIR`` is not ``None``),
    ``shutdown_output()`` is called before the new one is configured.

    Args:
        output_dir: Directory that receives ``rllab.txt``, ``rllab.csv`` and
            the snapshots. It is converted to an absolute path and stored in
            the global ``OUTPUT_DIR``.
        clean: When true the directory is passed to
            ``ensure_clean_output_dir``; otherwise it is only created if it
            does not exist yet.
        final_output_dir: Stored unchanged in the global ``FINAL_OUTPUT_DIR``.
    """
    global OUTPUT_DIR
    global FINAL_OUTPUT_DIR

    # Tear down whatever output was configured by an earlier call.
    if OUTPUT_DIR is not None:
        shutdown_output()

    resolved_dir = os.path.abspath(output_dir)
    if clean:
        ensure_clean_output_dir(resolved_dir)
    elif not os.path.exists(resolved_dir):
        os.makedirs(resolved_dir)

    OUTPUT_DIR = resolved_dir
    FINAL_OUTPUT_DIR = final_output_dir

    print("** Output set to", OUTPUT_DIR)
    if FINAL_OUTPUT_DIR is not None:
        print("** Final output set to", FINAL_OUTPUT_DIR)

    text_log_path = os.path.join(OUTPUT_DIR, "rllab.txt")
    logger.add_text_output(text_log_path)
    tabular_log_path = os.path.join(OUTPUT_DIR, "rllab.csv")
    logger.add_tabular_output(tabular_log_path)

    # Snapshot mode options: 'none', 'last', or 'all'.
    logger.set_snapshot_mode('all')
    logger.set_snapshot_dir(OUTPUT_DIR)
開發者ID:vicariousinc,項目名稱:pixelworld,代碼行數:24,代碼來源:output.py

示例2: set_up_experiment

# 需要導入模塊: from rllab.misc import logger [as 別名]
# 或者: from rllab.misc.logger import set_snapshot_dir [as 別名]
def set_up_experiment(
        exp_name,
        phase,
        exp_home='../data/experiments/',
        snapshot_gap=5):
    """Create ``<exp_home>/<exp_name>/<phase>/log`` and send snapshots there.

    Every directory level is passed to ``maybe_mkdir`` from the outermost
    inwards. The rllab logger is then set to snapshot mode ``'gap'`` with the
    given gap, writing into the ``log`` directory.

    Args:
        exp_name: Name of the experiment directory below ``exp_home``.
        phase: Name of the phase directory below the experiment directory.
        exp_home: Root directory that holds all experiments.
        snapshot_gap: Value forwarded to ``logger.set_snapshot_gap``.

    Returns:
        The experiment directory ``<exp_home>/<exp_name>`` (not the log
        directory).
    """
    maybe_mkdir(exp_home)
    exp_dir = os.path.join(exp_home, exp_name)
    maybe_mkdir(exp_dir)

    # Descend through the remaining levels, creating each one on the way.
    log_dir = exp_dir
    for component in (phase, 'log'):
        log_dir = os.path.join(log_dir, component)
        maybe_mkdir(log_dir)

    logger.set_snapshot_dir(log_dir)
    logger.set_snapshot_mode('gap')
    logger.set_snapshot_gap(snapshot_gap)
    return exp_dir
開發者ID:sisl,項目名稱:hgail,代碼行數:18,代碼來源:utils.py

示例3: logdir

# 需要導入模塊: from rllab.misc import logger [as 別名]
# 或者: from rllab.misc.logger import set_snapshot_dir [as 別名]
def logdir(algo=None, dirname=None):
    """Route rllab tabular output to ``progress.csv`` while the caller runs.

    The generator yields exactly once. Setup happens before the ``yield`` and
    the tabular output is unregistered afterwards, even when the code running
    at the ``yield`` (or the ``params.json`` write) raises.

    NOTE(review): the single-yield shape suggests this is meant to be wrapped
    by ``contextlib.contextmanager``; no decorator is visible here -- confirm.

    Args:
        algo: Optional object. When truthy, ``extract_hyperparams(algo)`` is
            dumped as JSON to ``params.json`` in the snapshot directory.
        dirname: Optional snapshot directory. When truthy it is installed via
            ``rllablogger.set_snapshot_dir``; otherwise the logger's current
            snapshot directory is used.

    Yields:
        The snapshot directory reported by ``rllablogger.get_snapshot_dir()``.
    """
    if dirname:
        rllablogger.set_snapshot_dir(dirname)
    dirname = rllablogger.get_snapshot_dir()

    progress_csv = os.path.join(dirname, 'progress.csv')
    rllablogger.add_tabular_output(progress_csv)
    try:
        if algo:
            # Extract first so a failure here does not truncate params.json.
            params = extract_hyperparams(algo)
            with open(os.path.join(dirname, 'params.json'), 'w') as f:
                json.dump(params, f)
        yield dirname
    finally:
        # Always unregister, otherwise a failed run leaves progress.csv
        # attached to the logger for every later run in this process.
        rllablogger.remove_tabular_output(progress_csv)
開發者ID:nosyndicate,項目名稱:pytorchrl,代碼行數:13,代碼來源:log_utils.py

示例4: rllab_logdir

# 需要導入模塊: from rllab.misc import logger [as 別名]
# 或者: from rllab.misc.logger import set_snapshot_dir [as 別名]
def rllab_logdir(algo=None, dirname=None):
    """Route rllab tabular output to ``progress.csv`` while the caller runs.

    The generator yields exactly once. Setup happens before the ``yield`` and
    the tabular output is unregistered afterwards, even when the code running
    at the ``yield`` (or the ``params.json`` write) raises.

    NOTE(review): the single-yield shape suggests this is meant to be wrapped
    by ``contextlib.contextmanager``; no decorator is visible here -- confirm.

    Args:
        algo: Optional object. When truthy, ``extract_hyperparams(algo)`` is
            dumped as JSON to ``params.json`` in the snapshot directory.
        dirname: Optional snapshot directory. When truthy it is installed via
            ``rllablogger.set_snapshot_dir``; otherwise the logger's current
            snapshot directory is used.

    Yields:
        The snapshot directory reported by ``rllablogger.get_snapshot_dir()``.
    """
    if dirname:
        rllablogger.set_snapshot_dir(dirname)
    dirname = rllablogger.get_snapshot_dir()

    progress_csv = os.path.join(dirname, 'progress.csv')
    rllablogger.add_tabular_output(progress_csv)
    try:
        if algo:
            # Extract first so a failure here does not truncate params.json.
            params = extract_hyperparams(algo)
            with open(os.path.join(dirname, 'params.json'), 'w') as f:
                json.dump(params, f)
        yield dirname
    finally:
        # Always unregister, otherwise a failed run leaves progress.csv
        # attached to the logger for every later run in this process.
        rllablogger.remove_tabular_output(progress_csv)
開發者ID:ahq1993,項目名稱:inverse_rl,代碼行數:13,代碼來源:log_utils.py

示例5: setup

# 需要導入模塊: from rllab.misc import logger [as 別名]
# 或者: from rllab.misc.logger import set_snapshot_dir [as 別名]
def setup(self, env, policy, start_itr):
        """Configure the rllab logger and build the algorithm named by ``self.args.algo``.

        The logger receives a parameter dump, a text output and a tabular
        output inside the log directory, which also becomes the snapshot
        directory. The log directory is ``self.args.log_dir`` when given,
        otherwise ``config.LOG_DIR/<exp_name>``.

        Args:
            env: Environment; ``env.spec`` is passed to the baseline, Q-function
                and exploration strategy, and ``env.agents`` is read when
                ``self.args.control == 'concurrent'``.
            policy: Policy (or policies) forwarded to the algorithm.
            start_itr: Forwarded to ``MATRPO`` as ``start_itr``; not used for
                ``thddpg``.

        Returns:
            The constructed ``MATRPO`` or ``thDDPG`` instance.

        Raises:
            NotImplementedError: If ``baseline_type``, ``exp_strategy`` or
                ``algo`` in ``self.args`` is not one of the handled values.
        """
        algo_name = self.args.algo

        if algo_name != 'thddpg':
            # Baseline (only the non-DDPG algorithms use one)
            if self.args.baseline_type == 'linear':
                baseline = LinearFeatureBaseline(env_spec=env.spec)
            elif self.args.baseline_type == 'zero':
                baseline = ZeroBaseline(env_spec=env.spec)
            else:
                raise NotImplementedError(self.args.baseline_type)

            if self.args.control == 'concurrent':
                # One entry per agent; every entry is the same baseline object.
                baseline = [baseline for _ in range(len(env.agents))]

        # Logger
        if self.args.log_dir is None:
            log_dir = osp.join(config.LOG_DIR, self.args.exp_name)
        else:
            log_dir = self.args.log_dir

        tabular_log_file = osp.join(log_dir, self.args.tabular_log_file)
        text_log_file = osp.join(log_dir, self.args.text_log_file)
        params_log_file = osp.join(log_dir, self.args.params_log_file)

        logger.log_parameters_lite(params_log_file, self.args)
        logger.add_text_output(text_log_file)
        logger.add_tabular_output(tabular_log_file)
        logger.set_snapshot_dir(log_dir)
        logger.set_snapshot_mode(self.args.snapshot_mode)
        logger.set_log_tabular_only(self.args.log_tabular_only)
        logger.push_prefix("[%s] " % self.args.exp_name)

        if algo_name == 'tftrpo':
            # A custom optimizer is supplied only for recurrent policies;
            # otherwise MATRPO receives None.
            if self.args.recurrent:
                optimizer = ConjugateGradientOptimizer(
                    hvp_approach=FiniteDifferenceHvp(base_eps=1e-5))
            else:
                optimizer = None

            algo = MATRPO(env=env, policy_or_policies=policy, baseline_or_baselines=baseline,
                          batch_size=self.args.batch_size, start_itr=start_itr,
                          max_path_length=self.args.max_path_length, n_itr=self.args.n_iter,
                          discount=self.args.discount, gae_lambda=self.args.gae_lambda,
                          step_size=self.args.step_size, optimizer=optimizer,
                          ma_mode=self.args.control)
        elif algo_name == 'thddpg':
            qfunc = thContinuousMLPQFunction(env_spec=env.spec)
            if self.args.exp_strategy == 'ou':
                es = OUStrategy(env_spec=env.spec)
            elif self.args.exp_strategy == 'gauss':
                es = GaussianStrategy(env_spec=env.spec)
            else:
                raise NotImplementedError(self.args.exp_strategy)

            algo = thDDPG(env=env, policy=policy, qf=qfunc, es=es, batch_size=self.args.batch_size,
                          max_path_length=self.args.max_path_length,
                          epoch_length=self.args.epoch_length,
                          min_pool_size=self.args.min_pool_size,
                          replay_pool_size=self.args.replay_pool_size, n_epochs=self.args.n_iter,
                          discount=self.args.discount, scale_reward=0.01,
                          qf_learning_rate=self.args.qfunc_lr,
                          policy_learning_rate=self.args.policy_lr,
                          eval_samples=self.args.eval_samples, mode=self.args.control)
        else:
            # Fail with a clear message instead of an UnboundLocalError on
            # the return statement below.
            raise NotImplementedError(algo_name)
        return algo
開發者ID:sisl,項目名稱:MADRL,代碼行數:63,代碼來源:rurllab.py


注:本文中的rllab.misc.logger.set_snapshot_dir方法示例由純淨天空整理自Github/MSDocs等開源代碼及文檔管理平台,相關代碼片段篩選自各路編程大神貢獻的開源項目,源碼版權歸原作者所有,傳播和使用請參考對應項目的License;未經允許,請勿轉載。