本文整理汇总了Python中rllab.misc.logger.push_prefix方法的典型用法代码示例。如果您正苦于以下问题:Python logger.push_prefix方法的具体用法?Python logger.push_prefix怎么用?Python logger.push_prefix使用的例子?那么恭喜您,这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块rllab.misc.logger的用法示例。
在下文中一共展示了logger.push_prefix方法的3个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: train
# 需要导入模块: from rllab.misc import logger [as 别名]
# 或者: from rllab.misc.logger import push_prefix [as 别名]
def train(self):
    """Run the training loop for ``self.n_epochs`` epochs.

    Each epoch performs ``self.epoch_length`` environment steps. Actions come
    from ``self.strategy.get_action(obs, self.policy)``; transitions are stored
    in a freshly created ``ReplayMem``; once the memory holds at least
    ``self.memory_start_size`` samples, ``self.do_update`` is called
    ``self.n_updates_per_sample`` times per step on batches of
    ``self.batch_size``. At the end of every epoch ``self.evaluate`` is called
    (subject to the same memory-size condition) and the tabular log is dumped.

    Returns:
        None.
    """
    # NOTE(review): the listing this block came from had lost its indentation.
    # The nesting below assumes (a) the ``else`` pairs with the horizon check,
    # so a horizon-truncated step is stored only when
    # ``include_horizon_terminal`` is set, and (b) ``itr`` advances once per
    # environment step -- confirm against the upstream file.
    memory = ReplayMem(
        obs_dim=self.env.observation_space.flat_dim,
        act_dim=self.env.action_space.flat_dim,
        memory_size=self.memory_size)

    itr = 0
    path_length = 0
    path_return = 0
    end = False
    obs = self.env.reset()

    for epoch in range(self.n_epochs):
        logger.push_prefix("epoch #%d | " % epoch)
        # try/finally keeps the logger's prefix stack balanced even when an
        # epoch aborts with an exception.
        try:
            logger.log("Training started")
            for _step in pyprind.prog_bar(range(self.epoch_length)):
                if end:
                    # Reset the environment and strategy when an episode ends.
                    obs = self.env.reset()
                    self.strategy.reset()
                    self.strategy_path_returns.append(path_return)
                    path_length = 0
                    path_return = 0

                # The action is sampled from the policy, not the target policy.
                act = self.strategy.get_action(obs, self.policy)
                nxt, rwd, end, _ = self.env.step(act)
                path_length += 1
                path_return += rwd

                if not end and path_length >= self.max_path_length:
                    # Episode cut off by the step limit.
                    end = True
                    if self.include_horizon_terminal:
                        memory.add_sample(obs, act, rwd, end)
                else:
                    memory.add_sample(obs, act, rwd, end)
                obs = nxt

                if memory.size >= self.memory_start_size:
                    for _update in range(self.n_updates_per_sample):
                        batch = memory.get_batch(self.batch_size)
                        self.do_update(itr, batch)
                itr += 1

            logger.log("Training finished")
            if memory.size >= self.memory_start_size:
                self.evaluate(epoch, memory)
            logger.dump_tabular(with_prefix=False)
        finally:
            logger.pop_prefix()
    # NOTE: self.env.terminate() / self.policy.terminate() were commented out
    # in the original and are intentionally still not called here.
示例2: setup
# 需要导入模块: from rllab.misc import logger [as 别名]
# 或者: from rllab.misc.logger import push_prefix [as 别名]
def setup(self, env, policy, start_itr):
if not self.args.algo == 'thddpg':
# Baseline
if self.args.baseline_type == 'linear':
baseline = LinearFeatureBaseline(env_spec=env.spec)
elif self.args.baseline_type == 'zero':
baseline = ZeroBaseline(env_spec=env.spec)
else:
raise NotImplementedError(self.args.baseline_type)
if self.args.control == 'concurrent':
baseline = [baseline for _ in range(len(env.agents))]
# Logger
default_log_dir = config.LOG_DIR
if self.args.log_dir is None:
log_dir = osp.join(default_log_dir, self.args.exp_name)
else:
log_dir = self.args.log_dir
tabular_log_file = osp.join(log_dir, self.args.tabular_log_file)
text_log_file = osp.join(log_dir, self.args.text_log_file)
params_log_file = osp.join(log_dir, self.args.params_log_file)
logger.log_parameters_lite(params_log_file, self.args)
logger.add_text_output(text_log_file)
logger.add_tabular_output(tabular_log_file)
prev_snapshot_dir = logger.get_snapshot_dir()
prev_mode = logger.get_snapshot_mode()
logger.set_snapshot_dir(log_dir)
logger.set_snapshot_mode(self.args.snapshot_mode)
logger.set_log_tabular_only(self.args.log_tabular_only)
logger.push_prefix("[%s] " % self.args.exp_name)
if self.args.algo == 'tftrpo':
algo = MATRPO(env=env, policy_or_policies=policy, baseline_or_baselines=baseline,
batch_size=self.args.batch_size, start_itr=start_itr,
max_path_length=self.args.max_path_length, n_itr=self.args.n_iter,
discount=self.args.discount, gae_lambda=self.args.gae_lambda,
step_size=self.args.step_size, optimizer=ConjugateGradientOptimizer(
hvp_approach=FiniteDifferenceHvp(base_eps=1e-5)) if
self.args.recurrent else None, ma_mode=self.args.control)
elif self.args.algo == 'thddpg':
qfunc = thContinuousMLPQFunction(env_spec=env.spec)
if self.args.exp_strategy == 'ou':
es = OUStrategy(env_spec=env.spec)
elif self.args.exp_strategy == 'gauss':
es = GaussianStrategy(env_spec=env.spec)
else:
raise NotImplementedError()
algo = thDDPG(env=env, policy=policy, qf=qfunc, es=es, batch_size=self.args.batch_size,
max_path_length=self.args.max_path_length,
epoch_length=self.args.epoch_length,
min_pool_size=self.args.min_pool_size,
replay_pool_size=self.args.replay_pool_size, n_epochs=self.args.n_iter,
discount=self.args.discount, scale_reward=0.01,
qf_learning_rate=self.args.qfunc_lr,
policy_learning_rate=self.args.policy_lr,
eval_samples=self.args.eval_samples, mode=self.args.control)
return algo
示例3: train
# 需要导入模块: from rllab.misc import logger [as 别名]
# 或者: from rllab.misc.logger import push_prefix [as 别名]
def train(self):
    """Run the training loop for ``self.n_epochs`` epochs.

    Every epoch takes ``self.epoch_length`` environment steps with actions from
    ``self.strategy.get_action(obs, self.policy)``, records transitions in a
    new ``ReplayMem``, and -- once the memory size reaches
    ``self.memory_start_size`` -- calls ``self.do_update`` on
    ``self.n_updates_per_sample`` batches of ``self.batch_size`` per step.
    After the steps of an epoch, ``self.evaluate`` runs under the same
    memory-size condition and the tabular log is dumped.

    Returns:
        None.
    """
    # NOTE(review): indentation was missing from the listing this block came
    # from. The structure below assumes the ``else`` belongs to the horizon
    # check (truncated steps are stored only if ``include_horizon_terminal``)
    # and that ``itr`` is incremented once per environment step -- confirm
    # against the upstream file.
    memory = ReplayMem(
        obs_dim=self.env.observation_space.flat_dim,
        act_dim=self.env.action_space.flat_dim,
        memory_size=self.memory_size)

    itr = 0
    path_length = 0
    path_return = 0
    end = False
    obs = self.env.reset()

    # ``range`` instead of the Python-2-only ``xrange``: works on both major
    # versions and matches the inner progress-bar loop.
    for epoch in range(self.n_epochs):
        logger.push_prefix("epoch #%d | " % epoch)
        # Guarantee the prefix is popped even if the epoch raises.
        try:
            logger.log("Training started")
            for _step in pyprind.prog_bar(range(self.epoch_length)):
                if end:
                    # Reset the environment and strategy when an episode ends.
                    obs = self.env.reset()
                    self.strategy.reset()
                    self.strategy_path_returns.append(path_return)
                    path_length = 0
                    path_return = 0

                # The action is sampled from the policy, not the target policy.
                act = self.strategy.get_action(obs, self.policy)
                nxt, rwd, end, _ = self.env.step(act)
                path_length += 1
                path_return += rwd

                if not end and path_length >= self.max_path_length:
                    # Episode cut off by the step limit.
                    end = True
                    if self.include_horizon_terminal:
                        memory.add_sample(obs, act, rwd, end)
                else:
                    memory.add_sample(obs, act, rwd, end)
                obs = nxt

                if memory.size >= self.memory_start_size:
                    for _update in range(self.n_updates_per_sample):
                        batch = memory.get_batch(self.batch_size)
                        self.do_update(itr, batch)
                itr += 1

            logger.log("Training finished")
            if memory.size >= self.memory_start_size:
                self.evaluate(epoch, memory)
            logger.dump_tabular(with_prefix=False)
        finally:
            logger.pop_prefix()
    # NOTE: self.env.terminate() / self.policy.terminate() were commented out
    # in the original and are intentionally still not called here.