本文整理汇总了Python中rllab.misc.logger.get_snapshot_dir方法的典型用法代码示例。如果您正苦于以下问题:Python logger.get_snapshot_dir方法的具体用法?Python logger.get_snapshot_dir怎么用?Python logger.get_snapshot_dir使用的例子?那么恭喜您, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块rllab.misc.logger的用法示例。
在下文中一共展示了logger.get_snapshot_dir方法的7个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: logdir
# 需要导入模块: from rllab.misc import logger [as 别名]
# 或者: from rllab.misc.logger import get_snapshot_dir [as 别名]
def logdir(algo=None, dirname=None):
    """Attach a ``progress.csv`` tabular output to the rllab logger around a block.

    Written as a generator meant to be driven as a context manager; the
    ``contextlib.contextmanager`` decorator, if the original code had one,
    is not visible in this snippet.

    Args:
        algo: Optional algorithm object. When truthy, the result of
            ``extract_hyperparams(algo)`` is JSON-dumped to ``params.json``
            in the log directory.
        dirname: Optional directory. When truthy it is installed as the rllab
            logger's snapshot directory before being read back.

    Yields:
        The snapshot directory that holds ``progress.csv`` / ``params.json``.
    """
    if dirname:
        rllablogger.set_snapshot_dir(dirname)
    # NOTE(review): if no snapshot dir is configured this is None and
    # os.path.join below raises TypeError.
    dirname = rllablogger.get_snapshot_dir()
    progress_csv = os.path.join(dirname, 'progress.csv')
    rllablogger.add_tabular_output(progress_csv)
    try:
        if algo:
            with open(os.path.join(dirname, 'params.json'), 'w') as f:
                params = extract_hyperparams(algo)
                json.dump(params, f)
        yield dirname
    finally:
        # Detach the CSV output even when the params dump or the caller's
        # with-body raised; otherwise it stays registered with the logger.
        rllablogger.remove_tabular_output(progress_csv)
示例2: __init__
# 需要导入模块: from rllab.misc import logger [as 别名]
# 或者: from rllab.misc.logger import get_snapshot_dir [as 别名]
def __init__(self, env_name, record_video=True, video_schedule=None, log_dir=None, record_log=True,
             force_reset=False):
    """Create a gym environment and optionally wrap it in ``gym.wrappers.Monitor``.

    Args:
        env_name: Environment id passed to ``gym.envs.make``.
        record_video: Record videos through the Monitor. Requires
            ``record_log`` to be truthy (asserted below).
        video_schedule: Used as the Monitor's ``video_callable``. Defaults to
            ``CappedCubicVideoSchedule()`` when recording video, and is
            replaced by ``NoVideoSchedule()`` when not recording video.
        log_dir: Monitor output directory. When None, it becomes
            ``<snapshot_dir>/gym_log`` if the rllab logger has a snapshot
            directory, otherwise monitoring is skipped with a warning.
        record_log: Enable the Monitor. Monitoring also needs a resolved
            ``log_dir``.
        force_reset: Stored as ``self._force_reset``.
    """
    if log_dir is None:
        if logger.get_snapshot_dir() is None:
            logger.log("Warning: skipping Gym environment monitoring since snapshot_dir not configured.")
        else:
            log_dir = os.path.join(logger.get_snapshot_dir(), "gym_log")
    # locals() is handed over after log_dir is resolved and before any other
    # local is bound, so it holds exactly self plus the (resolved) arguments.
    Serializable.quick_init(self, locals())

    env = gym.envs.make(env_name)
    self.env = env
    self.env_id = env.spec.id

    # record_video=True is only valid together with a truthy record_log.
    assert not (not record_log and record_video)
    # `not record_log` (rather than `record_log is False`) so that any falsy
    # value, e.g. 0 or None, disables monitoring as well.
    if log_dir is None or not record_log:
        self.monitoring = False
    else:
        if not record_video:
            video_schedule = NoVideoSchedule()
        elif video_schedule is None:
            video_schedule = CappedCubicVideoSchedule()
        self.env = gym.wrappers.Monitor(self.env, log_dir, video_callable=video_schedule, force=True)
        self.monitoring = True

    # Spaces are read from the unwrapped env created above.
    self._observation_space = convert_gym_space(env.observation_space)
    logger.log("observation space: {}".format(self._observation_space))
    self._action_space = convert_gym_space(env.action_space)
    logger.log("action space: {}".format(self._action_space))
    # Not every spec carries this tag. Use None instead of raising KeyError,
    # which is consistent with the wrapper-aware GymEnv variant.
    self._horizon = env.spec.tags.get('wrapper_config.TimeLimit.max_episode_steps')
    self._log_dir = log_dir
    self._force_reset = force_reset
示例3: rllab_logdir
# 需要导入模块: from rllab.misc import logger [as 别名]
# 或者: from rllab.misc.logger import get_snapshot_dir [as 别名]
def rllab_logdir(algo=None, dirname=None):
    """Attach a ``progress.csv`` tabular output to the rllab logger around a block.

    Written as a generator meant to be driven as a context manager; the
    ``contextlib.contextmanager`` decorator, if the original code had one,
    is not visible in this snippet.

    Args:
        algo: Optional algorithm object. When truthy, the result of
            ``extract_hyperparams(algo)`` is JSON-dumped to ``params.json``
            in the log directory.
        dirname: Optional directory. When truthy it is installed as the rllab
            logger's snapshot directory before being read back.

    Yields:
        The snapshot directory that holds ``progress.csv`` / ``params.json``.
    """
    if dirname:
        rllablogger.set_snapshot_dir(dirname)
    # NOTE(review): if no snapshot dir is configured this is None and
    # os.path.join below raises TypeError.
    dirname = rllablogger.get_snapshot_dir()
    progress_csv = os.path.join(dirname, 'progress.csv')
    rllablogger.add_tabular_output(progress_csv)
    try:
        if algo:
            with open(os.path.join(dirname, 'params.json'), 'w') as f:
                params = extract_hyperparams(algo)
                json.dump(params, f)
        yield dirname
    finally:
        # Detach the CSV output even when the params dump or the caller's
        # with-body raised; otherwise it stays registered with the logger.
        rllablogger.remove_tabular_output(progress_csv)
示例4: __init__
# 需要导入模块: from rllab.misc import logger [as 别名]
# 或者: from rllab.misc.logger import get_snapshot_dir [as 别名]
def __init__(self, path_dir=None):
    """Bind a ``PathsReader`` to a directory.

    Args:
        path_dir: Directory passed to ``PathsReader``. When None, the value
            returned by ``get_snapshot_dir()`` is used instead.
    """
    resolved_dir = get_snapshot_dir() if path_dir is None else path_dir
    self.path_dir = resolved_dir
    self.paths_reader = PathsReader(resolved_dir)
示例5: __init__
# 需要导入模块: from rllab.misc import logger [as 别名]
# 或者: from rllab.misc.logger import get_snapshot_dir [as 别名]
def __init__(self, env_name, record_video=False, video_schedule=None, log_dir=None, record_log=False,
             force_reset=True):
    """Create a gym environment with its outer wrapper removed, optionally monitored.

    Args:
        env_name: Environment id passed to ``gym.envs.make``.
        record_video: Record videos through the Monitor. Requires
            ``record_log`` to be truthy (asserted below).
        video_schedule: Used as the Monitor's ``video_callable``. Defaults to
            ``CappedCubicVideoSchedule()`` when recording video, and is
            replaced by ``NoVideoSchedule()`` when not recording video.
        log_dir: Monitor output directory. When None, it becomes
            ``<snapshot_dir>/gym_log`` if the rllab logger has a snapshot
            directory, otherwise monitoring is skipped with a warning.
        record_log: Enable the Monitor. Monitoring also needs a resolved
            ``log_dir``.
        force_reset: Stored as ``self._force_reset``.
    """
    if log_dir is None:
        if logger.get_snapshot_dir() is None:
            logger.log("Warning: skipping Gym environment monitoring since snapshot_dir not configured.")
        else:
            log_dir = os.path.join(logger.get_snapshot_dir(), "gym_log")
    # locals() is handed over after log_dir is resolved and before any other
    # local is bound, so it holds exactly self plus the (resolved) arguments.
    Serializable.quick_init(self, locals())
    env = gym.envs.make(env_name)
    # HACK: Gets rid of the TimeLimit wrapper that sets 'done = True' when
    # the time limit specified for each environment has been passed and
    # therefore the environment is not Markovian (terminal condition depends
    # on time rather than state).
    # NOTE(review): assumes gym.envs.make always returns a wrapper exposing
    # `.env`; a bare env would raise AttributeError here.
    env = env.env
    self.env = env
    self.env_id = env.spec.id

    # record_video=True is only valid together with a truthy record_log.
    assert not (not record_log and record_video)
    # `not record_log` (rather than `record_log is False`) so that any falsy
    # value, e.g. 0 or None, disables monitoring as well.
    if log_dir is None or not record_log:
        self.monitoring = False
    else:
        if not record_video:
            video_schedule = NoVideoSchedule()
        elif video_schedule is None:
            video_schedule = CappedCubicVideoSchedule()
        self.env = gym.wrappers.Monitor(self.env, log_dir, video_callable=video_schedule, force=True)
        self.monitoring = True

    self._observation_space = convert_gym_space(env.observation_space)
    logger.log("observation space: {}".format(self._observation_space))
    self._action_space = convert_gym_space(env.action_space)
    logger.log("action space: {}".format(self._action_space))
    # The horizon still comes from the spec even though the wrapper was
    # stripped above; raises KeyError if the spec lacks this tag.
    self._horizon = env.spec.tags['wrapper_config.TimeLimit.max_episode_steps']
    self._log_dir = log_dir
    self._force_reset = force_reset
示例6: __init__
# 需要导入模块: from rllab.misc import logger [as 别名]
# 或者: from rllab.misc.logger import get_snapshot_dir [as 别名]
def __init__(self, env_name, wrappers=(), wrapper_args=(),
             record_video=True, video_schedule=None, log_dir=None, record_log=True,
             post_create_env_seed=None,
             force_reset=False):
    """Create a gym environment, apply extra wrappers, and optionally monitor it.

    Args:
        env_name: Environment id passed to ``gym.envs.make``.
        wrappers: Callables applied in order, each as ``wrapper(env, ...)``.
        wrapper_args: Per-wrapper keyword-argument mappings, matched to
            ``wrappers`` by position. They are used only when their count
            equals ``len(wrappers)``; otherwise every wrapper is called with
            just the env.
        record_video: Record videos through the Monitor. Requires
            ``record_log`` to be truthy (asserted below).
        video_schedule: Used as the Monitor's ``video_callable``. Defaults to
            ``CappedCubicVideoSchedule()`` when recording video, and is
            replaced by ``NoVideoSchedule()`` when not recording video.
        log_dir: Monitor output directory. When None, it becomes
            ``<snapshot_dir>/gym_log`` if the rllab logger has a snapshot
            directory, otherwise monitoring is skipped with a warning.
        record_log: Enable the Monitor. Monitoring also needs a resolved
            ``log_dir``.
        post_create_env_seed: If not None, passed to ``env.set_env_seed``
            right after creation, before the wrappers are applied.
        force_reset: Stored as ``self._force_reset``.
    """
    if log_dir is None:
        if logger.get_snapshot_dir() is None:
            logger.log("Warning: skipping Gym environment monitoring since snapshot_dir not configured.")
        else:
            log_dir = os.path.join(logger.get_snapshot_dir(), "gym_log")
    # locals() is handed over after log_dir is resolved and before any other
    # local is bound, so it holds exactly self plus the (resolved) arguments.
    Serializable.quick_init(self, locals())
    env = gym.envs.make(env_name)
    if post_create_env_seed is not None:
        env.set_env_seed(post_create_env_seed)
    for i, wrapper in enumerate(wrappers):
        # NOTE(review): a wrapper_args of the wrong length is silently
        # ignored rather than reported.
        if wrapper_args and len(wrapper_args) == len(wrappers):
            env = wrapper(env, **wrapper_args[i])
        else:
            env = wrapper(env)
    self.env = env
    self.env_id = env.spec.id

    # record_video=True is only valid together with a truthy record_log.
    assert not (not record_log and record_video)
    # `not record_log` (rather than `record_log is False`) so that any falsy
    # value, e.g. 0 or None, disables monitoring as well.
    if log_dir is None or not record_log:
        self.monitoring = False
    else:
        if not record_video:
            video_schedule = NoVideoSchedule()
        elif video_schedule is None:
            video_schedule = CappedCubicVideoSchedule()
        self.env = gym.wrappers.Monitor(self.env, log_dir, video_callable=video_schedule, force=True)
        self.monitoring = True

    self._observation_space = convert_gym_space(env.observation_space)
    logger.log("observation space: {}".format(self._observation_space))
    self._action_space = convert_gym_space(env.action_space)
    logger.log("action space: {}".format(self._action_space))
    # None when the spec has no TimeLimit tag.
    self._horizon = env.spec.tags.get('wrapper_config.TimeLimit.max_episode_steps')
    self._log_dir = log_dir
    self._force_reset = force_reset
示例7: setup
# 需要导入模块: from rllab.misc import logger [as 别名]
# 或者: from rllab.misc.logger import get_snapshot_dir [as 别名]
def setup(self, env, policy, start_itr):
    """Configure the rllab logger from ``self.args`` and build the training algorithm.

    Args:
        env: Environment. ``env.spec`` is passed to the baselines, Q-function
            and exploration strategies; ``env.agents`` sizes the per-agent
            baseline list in 'concurrent' control mode.
        policy: Policy (or policies) handed to the algorithm.
        start_itr: First iteration index, passed to ``MATRPO``.

    Returns:
        ``MATRPO`` when ``args.algo == 'tftrpo'``, ``thDDPG`` when it is 'thddpg'.

    Raises:
        NotImplementedError: For an unsupported ``baseline_type``,
            ``exp_strategy`` or ``algo``.
    """
    if self.args.algo != 'thddpg':
        # Baseline (the thddpg branch below does not take one).
        if self.args.baseline_type == 'linear':
            baseline = LinearFeatureBaseline(env_spec=env.spec)
        elif self.args.baseline_type == 'zero':
            baseline = ZeroBaseline(env_spec=env.spec)
        else:
            raise NotImplementedError(self.args.baseline_type)
        if self.args.control == 'concurrent':
            # One entry per agent; every entry is the same baseline object.
            baseline = [baseline for _ in range(len(env.agents))]

    # Logger
    default_log_dir = config.LOG_DIR
    if self.args.log_dir is None:
        log_dir = osp.join(default_log_dir, self.args.exp_name)
    else:
        log_dir = self.args.log_dir
    tabular_log_file = osp.join(log_dir, self.args.tabular_log_file)
    text_log_file = osp.join(log_dir, self.args.text_log_file)
    params_log_file = osp.join(log_dir, self.args.params_log_file)

    logger.log_parameters_lite(params_log_file, self.args)
    logger.add_text_output(text_log_file)
    logger.add_tabular_output(tabular_log_file)
    # NOTE(review): the previous snapshot dir/mode are read but not restored
    # within this method; confirm whether a caller is meant to restore them.
    prev_snapshot_dir = logger.get_snapshot_dir()
    prev_mode = logger.get_snapshot_mode()
    logger.set_snapshot_dir(log_dir)
    logger.set_snapshot_mode(self.args.snapshot_mode)
    logger.set_log_tabular_only(self.args.log_tabular_only)
    # NOTE(review): this prefix is pushed but never popped here.
    logger.push_prefix("[%s] " % self.args.exp_name)

    if self.args.algo == 'tftrpo':
        algo = MATRPO(env=env, policy_or_policies=policy, baseline_or_baselines=baseline,
                      batch_size=self.args.batch_size, start_itr=start_itr,
                      max_path_length=self.args.max_path_length, n_itr=self.args.n_iter,
                      discount=self.args.discount, gae_lambda=self.args.gae_lambda,
                      step_size=self.args.step_size, optimizer=ConjugateGradientOptimizer(
                          hvp_approach=FiniteDifferenceHvp(base_eps=1e-5)) if
                      self.args.recurrent else None, ma_mode=self.args.control)
    elif self.args.algo == 'thddpg':
        qfunc = thContinuousMLPQFunction(env_spec=env.spec)
        if self.args.exp_strategy == 'ou':
            es = OUStrategy(env_spec=env.spec)
        elif self.args.exp_strategy == 'gauss':
            es = GaussianStrategy(env_spec=env.spec)
        else:
            raise NotImplementedError()
        algo = thDDPG(env=env, policy=policy, qf=qfunc, es=es, batch_size=self.args.batch_size,
                      max_path_length=self.args.max_path_length,
                      epoch_length=self.args.epoch_length,
                      min_pool_size=self.args.min_pool_size,
                      replay_pool_size=self.args.replay_pool_size, n_epochs=self.args.n_iter,
                      discount=self.args.discount, scale_reward=0.01,
                      qf_learning_rate=self.args.qfunc_lr,
                      policy_learning_rate=self.args.policy_lr,
                      eval_samples=self.args.eval_samples, mode=self.args.control)
    else:
        # Previously fell through and failed at `return algo` with UnboundLocalError.
        raise NotImplementedError(self.args.algo)
    return algo