This article collects typical usage examples of the Python method rllab.misc.logger.dump_tabular. If you are wondering what exactly logger.dump_tabular does, how to call it, or what real code that uses it looks like, the selected examples below should help. You can also read further into the rllab.misc.logger module in which the method is defined.
Seven code examples of logger.dump_tabular are shown below, sorted by popularity by default.
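Before the full training-loop examples, here is a minimal, self-contained sketch of the pattern they all follow (this sketch is not taken from the examples below; the CSV file name and metric values are illustrative only): record key/value pairs with record_tabular during an iteration, then call dump_tabular once to flush them as a single table row.

from rllab.misc import logger

logger.add_tabular_output('progress.csv')  # optionally mirror each table row to a CSV file
for itr in range(3):
    with logger.prefix('itr #%d | ' % itr):
        logger.record_tabular('Iteration', itr)
        logger.record_tabular('AverageReturn', 0.0)  # placeholder metric value
        # Print all recorded keys as one table row and clear them for the next
        # iteration; with_prefix=False omits the 'itr #...' prefix on the console.
        logger.dump_tabular(with_prefix=False)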
Example 1: train
# Required import: from rllab.misc import logger [as alias]
# Or: from rllab.misc.logger import dump_tabular [as alias]
def train(self):
    self.start_worker()
    for itr in range(self.current_itr, self.n_itr):
        with logger.prefix('itr #%d | ' % itr):
            logger.log('Obtaining samples...')
            paths = self.sampler.obtain_samples(itr)
            logger.log('Processing samples...')
            samples_data = self.sampler.process_samples(itr, paths)
            logger.log('Logging diagnostics...')
            self.log_diagnostics(paths)
            logger.log('Optimizing policy...')
            self.optimize_policy(itr, samples_data)
            logger.log('Saving snapshot...')
            params = self.get_itr_snapshot(itr, samples_data)
            self.current_itr = itr + 1
            params['algo'] = self
            # Save the trajectories into the param
            if self.store_paths:
                params['paths'] = samples_data['paths']
            logger.save_itr_params(itr, params)
            logger.log('Saved')
            logger.dump_tabular(with_prefix=False)
            if self.plot:
                self.update_plot()
                if self.pause_for_plot:
                    input('Plotting evaluation run: Press Enter to '
                          'continue...')
    self.shutdown_worker()
Example 2: train
# Required import: from rllab.misc import logger [as alias]
# Or: from rllab.misc.logger import dump_tabular [as alias]
def train(self):
    if self.init_pol_params is not None:
        self.policy.set_param_values(self.init_pol_params)
    if self.init_irl_params is not None:
        self.irl_model.set_params(self.init_irl_params)
    self.start_worker()
    start_time = time.time()
    returns = []
    for itr in range(self.start_itr, self.n_itr):
        itr_start_time = time.time()
        with logger.prefix('itr #%d | ' % itr):
            logger.log('Obtaining samples...')
            paths = self.sampler.obtain_samples(itr)
            logger.log('Processing samples...')
            # Update the reward function
            paths = self.compute_irl(paths, itr=itr)
            # returns.append(self.log_avg_returns(paths))
            samples_data = self.sampler.process_samples(itr, paths)
            logger.log('Logging diagnostics...')
            self.log_diagnostics(paths)
            logger.log('Optimizing policy...')
            self.optimize_policy(itr, samples_data)
            logger.log('Saving snapshot...')
            params = self.get_itr_snapshot(itr, samples_data)  # , **kwargs)
            if self.store_paths:
                params['paths'] = samples_data['paths']
            logger.save_itr_params(itr, params)
            logger.log('Saved')
            logger.record_tabular('Time', time.time() - start_time)
            logger.record_tabular('ItrTime', time.time() - itr_start_time)
            logger.dump_tabular(with_prefix=False)
            if self.plot:
                self.update_plot()
                if self.pause_for_plot:
                    input('Plotting evaluation run: Press Enter to '
                          'continue...')
    self.shutdown_worker()
    return
Example 3: train
# Required import: from rllab.misc import logger [as alias]
# Or: from rllab.misc.logger import dump_tabular [as alias]
def train(self, sess=None):
    created_session = True if (sess is None) else False
    if sess is None:
        sess = tf.Session()
        sess.__enter__()
    sess.run(tf.global_variables_initializer())
    self.start_worker()
    start_time = time.time()
    for itr in range(self.start_itr, self.n_itr):
        itr_start_time = time.time()
        with logger.prefix('itr #%d | ' % itr):
            logger.log("Obtaining samples...")
            paths = self.obtain_samples(itr)
            logger.log("Processing samples...")
            samples_data = self.process_samples(itr, paths)
            logger.log("Logging diagnostics...")
            self.log_diagnostics(paths)
            logger.log("Optimizing policy...")
            self.optimize_policy(itr, samples_data)
            logger.log("Saving snapshot...")
            params = self.get_itr_snapshot(itr, samples_data)  # , **kwargs)
            if self.store_paths:
                params["paths"] = samples_data["paths"]
            logger.save_itr_params(itr, params)
            logger.log("Saved")
            logger.record_tabular('Time', time.time() - start_time)
            logger.record_tabular('ItrTime', time.time() - itr_start_time)
            logger.dump_tabular(with_prefix=False)
            if self.plot:
                rollout(self.env, self.policy, animated=True,
                        max_path_length=self.max_path_length)
                if self.pause_for_plot:
                    input("Plotting evaluation run: Press Enter to "
                          "continue...")
    self.shutdown_worker()
    if created_session:
        sess.close()
Example 4: train
# Required import: from rllab.misc import logger [as alias]
# Or: from rllab.misc.logger import dump_tabular [as alias]
def train(self):
    memory = ReplayMem(
        obs_dim=self.env.observation_space.flat_dim,
        act_dim=self.env.action_space.flat_dim,
        memory_size=self.memory_size)
    itr = 0
    path_length = 0
    path_return = 0
    end = False
    obs = self.env.reset()
    for epoch in range(self.n_epochs):
        logger.push_prefix("epoch #%d | " % epoch)
        logger.log("Training started")
        for epoch_itr in pyprind.prog_bar(range(self.epoch_length)):
            # run the policy
            if end:
                # reset the environment and strategy when an episode ends
                obs = self.env.reset()
                self.strategy.reset()
                # self.policy.reset()
                self.strategy_path_returns.append(path_return)
                path_length = 0
                path_return = 0
            # note: the action is sampled from the policy, not the target policy
            act = self.strategy.get_action(obs, self.policy)
            nxt, rwd, end, _ = self.env.step(act)
            path_length += 1
            path_return += rwd
            if not end and path_length >= self.max_path_length:
                end = True
                if self.include_horizon_terminal:
                    memory.add_sample(obs, act, rwd, end)
            else:
                memory.add_sample(obs, act, rwd, end)
            obs = nxt
            if memory.size >= self.memory_start_size:
                for update_time in range(self.n_updates_per_sample):
                    batch = memory.get_batch(self.batch_size)
                    self.do_update(itr, batch)
            itr += 1
        logger.log("Training finished")
        if memory.size >= self.memory_start_size:
            self.evaluate(epoch, memory)
        logger.dump_tabular(with_prefix=False)
        logger.pop_prefix()
    # self.env.terminate()
    # self.policy.terminate()
Example 5: train
# Required import: from rllab.misc import logger [as alias]
# Or: from rllab.misc.logger import dump_tabular [as alias]
def train(self):
    sess = tf.get_default_session()
    sess.run(tf.global_variables_initializer())
    if self.init_pol_params is not None:
        self.policy.set_param_values(self.init_pol_params)
    if self.init_qvar_params is not None:
        self.qvar_model.set_params(self.init_qvar_params)
    if self.init_irl_params is not None:
        self.irl_model.set_params(self.init_irl_params)
    if self.init_empw_params is not None:
        self.empw.set_params(self.init_empw_params)
    self.start_worker()
    start_time = time.time()
    returns = []
    rew = []  # stores score at each step
    for itr in range(self.start_itr, self.n_itr):
        itr_start_time = time.time()
        with logger.prefix('itr #%d | ' % itr):
            logger.log("Obtaining samples...")
            paths = self.obtain_samples(itr)
            logger.log("Processing samples...")
            paths, r = self.compute_irl(paths, itr=itr)
            rew.append(r)
            returns.append(self.log_avg_returns(paths))
            self.compute_qvar(paths, itr=itr)
            self.compute_empw(paths, itr=itr)
            samples_data = self.process_samples(itr, paths)
            logger.log("Logging diagnostics...")
            self.log_diagnostics(paths)
            logger.log("Optimizing policy...")
            self.optimize_policy(itr, samples_data)
            logger.log("Saving snapshot...")
            params = self.get_itr_snapshot(itr, samples_data)  # , **kwargs)
            if self.store_paths:
                params["paths"] = samples_data["paths"]
            logger.save_itr_params(itr, params)
            logger.log("Saved")
            logger.record_tabular('Time', time.time() - start_time)
            logger.record_tabular('ItrTime', time.time() - itr_start_time)
            logger.dump_tabular(with_prefix=False)
            if self.plot:
                self.update_plot()
                if self.pause_for_plot:
                    input("Plotting evaluation run: Press Enter to "
                          "continue...")
            if itr % self.target_empw_update == 0 and self.train_empw:  # reward 5
                print('updating target empowerment parameters')
                self.tempw.set_params(self.__empw_params)
            # pickle.dump(rew, open("rewards.p", "wb"))  # uncomment to store rewards in every iteration
    self.shutdown_worker()
    return
Example 6: train
# Required import: from rllab.misc import logger [as alias]
# Or: from rllab.misc.logger import dump_tabular [as alias]
def train(self):
    sess = tf.get_default_session()
    sess.run(tf.global_variables_initializer())
    if self.init_pol_params is not None:
        self.policy.set_param_values(self.init_pol_params)
    if self.init_irl_params is not None:
        self.irl_model.set_params(self.init_irl_params)
    self.start_worker()
    start_time = time.time()
    returns = []
    for itr in range(self.start_itr, self.n_itr):
        itr_start_time = time.time()
        with logger.prefix('itr #%d | ' % itr):
            logger.log("Obtaining samples...")
            paths = self.obtain_samples(itr)
            logger.log("Processing samples...")
            paths = self.compute_irl(paths, itr=itr)
            returns.append(self.log_avg_returns(paths))
            samples_data = self.process_samples(itr, paths)
            logger.log("Logging diagnostics...")
            self.log_diagnostics(paths)
            logger.log("Optimizing policy...")
            self.optimize_policy(itr, samples_data)
            logger.log("Saving snapshot...")
            params = self.get_itr_snapshot(itr, samples_data)  # , **kwargs)
            if self.store_paths:
                params["paths"] = samples_data["paths"]
            logger.save_itr_params(itr, params)
            logger.log("Saved")
            logger.record_tabular('Time', time.time() - start_time)
            logger.record_tabular('ItrTime', time.time() - itr_start_time)
            logger.dump_tabular(with_prefix=False)
            if self.plot:
                self.update_plot()
                if self.pause_for_plot:
                    input("Plotting evaluation run: Press Enter to "
                          "continue...")
    self.shutdown_worker()
    return
Example 7: train
# Required import: from rllab.misc import logger [as alias]
# Or: from rllab.misc.logger import dump_tabular [as alias]
def train(self):
    memory = ReplayMem(
        obs_dim=self.env.observation_space.flat_dim,
        act_dim=self.env.action_space.flat_dim,
        memory_size=self.memory_size)
    itr = 0
    path_length = 0
    path_return = 0
    end = False
    obs = self.env.reset()
    for epoch in xrange(self.n_epochs):
        logger.push_prefix("epoch #%d | " % epoch)
        logger.log("Training started")
        for epoch_itr in pyprind.prog_bar(range(self.epoch_length)):
            # run the policy
            if end:
                # reset the environment and strategy when an episode ends
                obs = self.env.reset()
                self.strategy.reset()
                # self.policy.reset()
                self.strategy_path_returns.append(path_return)
                path_length = 0
                path_return = 0
            # note: the action is sampled from the policy, not the target policy
            act = self.strategy.get_action(obs, self.policy)
            nxt, rwd, end, _ = self.env.step(act)
            path_length += 1
            path_return += rwd
            if not end and path_length >= self.max_path_length:
                end = True
                if self.include_horizon_terminal:
                    memory.add_sample(obs, act, rwd, end)
            else:
                memory.add_sample(obs, act, rwd, end)
            obs = nxt
            if memory.size >= self.memory_start_size:
                for update_time in xrange(self.n_updates_per_sample):
                    batch = memory.get_batch(self.batch_size)
                    self.do_update(itr, batch)
            itr += 1
        logger.log("Training finished")
        if memory.size >= self.memory_start_size:
            self.evaluate(epoch, memory)
        logger.dump_tabular(with_prefix=False)
        logger.pop_prefix()
    # self.env.terminate()
    # self.policy.terminate()