本文整理汇总了Python中rllab.misc.logger.set_snapshot_dir方法的典型用法代码示例。如果您正苦于以下问题：Python logger.set_snapshot_dir方法的具体用法？Python logger.set_snapshot_dir怎么用？Python logger.set_snapshot_dir使用的例子？那么恭喜您，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块rllab.misc.logger的用法示例。
在下文中一共展示了logger.set_snapshot_dir方法的5个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: setup_output
# 需要导入模块: from rllab.misc import logger [as 别名]
# 或者: from rllab.misc.logger import set_snapshot_dir [as 别名]
def setup_output(output_dir, clean=True, final_output_dir=None):
    """Direct module-wide output and the rllab logger to *output_dir*.

    If an output directory was configured before (``OUTPUT_DIR`` is not
    None), ``shutdown_output()`` is called first. The directory is made
    absolute and then either handed to ``ensure_clean_output_dir`` (when
    *clean* is true) or merely created if it does not exist yet.
    Afterwards rllab text and tabular logs go to ``rllab.txt`` and
    ``rllab.csv`` inside it, and snapshots are saved there in 'all' mode.

    Args:
        output_dir: Directory for logs and snapshots; converted to an
            absolute path before use.
        clean: If true, delegate directory preparation to
            ``ensure_clean_output_dir``; otherwise only create the
            directory (and parents) when missing.
        final_output_dir: Optional value stored in ``FINAL_OUTPUT_DIR``
            as given (it is not made absolute here).
    """
    global OUTPUT_DIR
    global FINAL_OUTPUT_DIR
    if OUTPUT_DIR is not None:
        # NOTE(review): presumably tears down the previous logger outputs;
        # shutdown_output is defined elsewhere — confirm.
        shutdown_output()
    output_dir = os.path.abspath(output_dir)
    if clean:
        ensure_clean_output_dir(output_dir)
    else:
        # exist_ok=True replaces the racy exists()-then-makedirs() check.
        os.makedirs(output_dir, exist_ok=True)
    OUTPUT_DIR = output_dir
    FINAL_OUTPUT_DIR = final_output_dir
    print("** Output set to", OUTPUT_DIR)
    if FINAL_OUTPUT_DIR is not None:
        print("** Final output set to", FINAL_OUTPUT_DIR)
    logger.add_text_output(os.path.join(OUTPUT_DIR, "rllab.txt"))
    logger.add_tabular_output(os.path.join(OUTPUT_DIR, "rllab.csv"))
    # rllab snapshot modes: 'none', 'last', 'all', or 'gap'
    # (the latter together with logger.set_snapshot_gap).
    logger.set_snapshot_mode('all')
    logger.set_snapshot_dir(OUTPUT_DIR)
示例2: set_up_experiment
# 需要导入模块: from rllab.misc import logger [as 别名]
# 或者: from rllab.misc.logger import set_snapshot_dir [as 别名]
def set_up_experiment(
        exp_name,
        phase,
        exp_home='../data/experiments/',
        snapshot_gap=5):
    """Create ``<exp_home>/<exp_name>/<phase>/log`` and snapshot into it.

    Every level of the path is passed to ``maybe_mkdir``, outermost
    first. The rllab logger is then pointed at the ``log`` directory in
    'gap' snapshot mode with the given *snapshot_gap*.

    Returns:
        The experiment directory ``<exp_home>/<exp_name>`` — not the log
        directory.
    """
    # Build the nested path one component at a time, creating each level
    # immediately after its parent so the call order matches a manual walk.
    levels = [exp_home]
    maybe_mkdir(exp_home)
    for component in (exp_name, phase, 'log'):
        levels.append(os.path.join(levels[-1], component))
        maybe_mkdir(levels[-1])
    experiment_dir, snapshot_dir = levels[1], levels[-1]

    logger.set_snapshot_dir(snapshot_dir)
    logger.set_snapshot_mode('gap')
    logger.set_snapshot_gap(snapshot_gap)
    return experiment_dir
示例3: logdir
# 需要导入模块: from rllab.misc import logger [as 别名]
# 或者: from rllab.misc.logger import set_snapshot_dir [as 别名]
def logdir(algo=None, dirname=None):
    """Route rllab tabular output to ``<dirname>/progress.csv`` for a scope.

    This is a generator intended to be used as a context manager
    (presumably wrapped with ``contextlib.contextmanager``; the decorator
    is not visible in this excerpt).

    If *dirname* is given it becomes the rllab snapshot directory;
    otherwise the current snapshot directory is used. When *algo* is
    given, the result of ``extract_hyperparams(algo)`` is written as JSON
    to ``params.json`` in that directory.

    The tabular output is detached on exit even if the body (or writing
    ``params.json``) raises. The snapshot directory set here is not
    restored afterwards.

    Yields:
        The directory in use.
    """
    if dirname:
        rllablogger.set_snapshot_dir(dirname)
    dirname = rllablogger.get_snapshot_dir()
    progress_csv = os.path.join(dirname, 'progress.csv')
    rllablogger.add_tabular_output(progress_csv)
    try:
        if algo:
            with open(os.path.join(dirname, 'params.json'), 'w') as f:
                params = extract_hyperparams(algo)
                json.dump(params, f)
        yield dirname
    finally:
        # Without try/finally an exception in the with-body would leave
        # progress.csv attached to the global logger.
        rllablogger.remove_tabular_output(progress_csv)
示例4: rllab_logdir
# 需要导入模块: from rllab.misc import logger [as 别名]
# 或者: from rllab.misc.logger import set_snapshot_dir [as 别名]
def rllab_logdir(algo=None, dirname=None):
    """Route rllab tabular output to ``<dirname>/progress.csv`` for a scope.

    This is a generator intended to be used as a context manager
    (presumably wrapped with ``contextlib.contextmanager``; the decorator
    is not visible in this excerpt).

    If *dirname* is given it becomes the rllab snapshot directory;
    otherwise the current snapshot directory is used. When *algo* is
    given, the result of ``extract_hyperparams(algo)`` is written as JSON
    to ``params.json`` in that directory.

    The tabular output is detached on exit even if the body (or writing
    ``params.json``) raises. The snapshot directory set here is not
    restored afterwards.

    Yields:
        The directory in use.
    """
    if dirname:
        rllablogger.set_snapshot_dir(dirname)
    dirname = rllablogger.get_snapshot_dir()
    progress_csv = os.path.join(dirname, 'progress.csv')
    rllablogger.add_tabular_output(progress_csv)
    try:
        if algo:
            with open(os.path.join(dirname, 'params.json'), 'w') as f:
                params = extract_hyperparams(algo)
                json.dump(params, f)
        yield dirname
    finally:
        # Without try/finally an exception in the with-body would leave
        # progress.csv attached to the global logger.
        rllablogger.remove_tabular_output(progress_csv)
示例5: setup
# 需要导入模块: from rllab.misc import logger [as 别名]
# 或者: from rllab.misc.logger import set_snapshot_dir [as 别名]
def setup(self, env, policy, start_itr):
    """Build baseline(s), configure the rllab logger and construct the algorithm.

    All settings are read from ``self.args``. For every algorithm other
    than ``'thddpg'`` a baseline is created: a single one, or (when
    ``args.control == 'concurrent'``) an independent one per agent in
    ``env.agents``. The logger is then configured: parameter, text and
    tabular log files, the snapshot directory and mode, the tabular-only
    flag, and a ``"[<exp_name>] "`` prefix.

    Args:
        env: Environment; ``env.spec`` is passed to the baselines,
            Q-function and exploration strategies, and ``env.agents``
            sizes the per-agent baseline list in concurrent mode.
        policy: Policy (or policies) handed to the algorithm.
        start_itr: Starting iteration, forwarded to ``MATRPO`` only.

    Returns:
        The constructed ``MATRPO`` (for ``'tftrpo'``) or ``thDDPG``
        (for ``'thddpg'``) instance.

    Raises:
        NotImplementedError: For an unknown ``args.algo``,
            ``args.baseline_type`` or ``args.exp_strategy``.
    """
    args = self.args

    # Reject unknown algorithms before touching global logger state;
    # previously this surfaced later as an UnboundLocalError on `algo`.
    if args.algo not in ('tftrpo', 'thddpg'):
        raise NotImplementedError(args.algo)

    if args.algo != 'thddpg':
        # Baseline (the DDPG branch below uses a Q-function instead).
        if args.baseline_type == 'linear':
            baseline_cls = LinearFeatureBaseline
        elif args.baseline_type == 'zero':
            baseline_cls = ZeroBaseline
        else:
            raise NotImplementedError(args.baseline_type)
        if args.control == 'concurrent':
            # One independent instance per agent. The former
            # `[baseline for _ in ...]` repeated a single object, so any
            # state a baseline keeps (e.g. presumably the fitted
            # coefficients of LinearFeatureBaseline) was shared by all agents.
            baseline = [baseline_cls(env_spec=env.spec)
                        for _ in range(len(env.agents))]
        else:
            baseline = baseline_cls(env_spec=env.spec)

    # Logger
    if args.log_dir is None:
        log_dir = osp.join(config.LOG_DIR, args.exp_name)
    else:
        log_dir = args.log_dir
    tabular_log_file = osp.join(log_dir, args.tabular_log_file)
    text_log_file = osp.join(log_dir, args.text_log_file)
    params_log_file = osp.join(log_dir, args.params_log_file)

    logger.log_parameters_lite(params_log_file, args)
    logger.add_text_output(text_log_file)
    logger.add_tabular_output(tabular_log_file)
    logger.set_snapshot_dir(log_dir)
    logger.set_snapshot_mode(args.snapshot_mode)
    logger.set_log_tabular_only(args.log_tabular_only)
    logger.push_prefix("[%s] " % args.exp_name)

    # Algorithm
    if args.algo == 'tftrpo':
        # Recurrent policies get conjugate gradient with a finite-difference
        # Hessian-vector product; otherwise None is passed (presumably
        # MATRPO then falls back to its default optimizer).
        if args.recurrent:
            optimizer = ConjugateGradientOptimizer(
                hvp_approach=FiniteDifferenceHvp(base_eps=1e-5))
        else:
            optimizer = None
        algo = MATRPO(env=env, policy_or_policies=policy, baseline_or_baselines=baseline,
                      batch_size=args.batch_size, start_itr=start_itr,
                      max_path_length=args.max_path_length, n_itr=args.n_iter,
                      discount=args.discount, gae_lambda=args.gae_lambda,
                      step_size=args.step_size, optimizer=optimizer,
                      ma_mode=args.control)
    else:  # 'thddpg' (validated above)
        qfunc = thContinuousMLPQFunction(env_spec=env.spec)
        if args.exp_strategy == 'ou':
            es = OUStrategy(env_spec=env.spec)
        elif args.exp_strategy == 'gauss':
            es = GaussianStrategy(env_spec=env.spec)
        else:
            raise NotImplementedError(args.exp_strategy)
        algo = thDDPG(env=env, policy=policy, qf=qfunc, es=es, batch_size=args.batch_size,
                      max_path_length=args.max_path_length,
                      epoch_length=args.epoch_length,
                      min_pool_size=args.min_pool_size,
                      replay_pool_size=args.replay_pool_size, n_epochs=args.n_iter,
                      discount=args.discount, scale_reward=0.01,
                      qf_learning_rate=args.qfunc_lr,
                      policy_learning_rate=args.policy_lr,
                      eval_samples=args.eval_samples, mode=args.control)
    return algo