This article collects typical usage examples of the rllab.misc.logger.prefix method in Python. If you are wondering what logger.prefix does, how to call it, and what it looks like in real code, the curated examples below should help. You can also look at further usage examples from rllab.misc.logger, the module in which the method is defined.
Eight code examples of the logger.prefix method are shown below.
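Before the full examples, here is a minimal sketch of the pattern they all share: logger.prefix is a context manager that prepends a label to every logger.log message emitted inside the with block. The loop bound, function name, and messages below are made up for illustration.

from rllab.misc import logger

def toy_loop(n_itr=3):
    for itr in range(n_itr):
        with logger.prefix('itr #%d | ' % itr):
            # Inside the block, each message is printed with the
            # 'itr #<n> | ' label (plus the logger's usual timestamp).
            logger.log('Obtaining samples...')
            logger.log('Optimizing policy...')
        # Outside the block, the prefix has been popped again.
        logger.log('Iteration %d finished' % itr)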
Example 1: train
# Module to import: from rllab.misc import logger [as alias]
# Or: from rllab.misc.logger import prefix [as alias]
def train(self):
    self.start_worker()
    for itr in range(self.current_itr, self.n_itr):
        with logger.prefix('itr #%d | ' % itr):
            logger.log('Obtaining samples...')
            paths = self.sampler.obtain_samples(itr)
            logger.log('Processing samples...')
            samples_data = self.sampler.process_samples(itr, paths)
            logger.log('Logging diagnostics...')
            self.log_diagnostics(paths)
            logger.log('Optimizing policy...')
            self.optimize_policy(itr, samples_data)
            logger.log('Saving snapshot...')
            params = self.get_itr_snapshot(itr, samples_data)
            self.current_itr = itr + 1
            params['algo'] = self
            # Save the trajectories into the param
            if self.store_paths:
                params['paths'] = samples_data['paths']
            logger.save_itr_params(itr, params)
            logger.log('Saved')
            logger.dump_tabular(with_prefix=False)
            if self.plot:
                self.update_plot()
                if self.pause_for_plot:
                    input('Plotting evaluation run: Press Enter to '
                          'continue...')
    self.shutdown_worker()
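The train loops in Examples 1-3 (and 7-8) pair logger.prefix with rllab's tabular logging: record_tabular collects key/value pairs during the iteration and dump_tabular flushes them once per iteration, passing with_prefix=False so the table rows are printed without the 'itr #N | ' label. A minimal sketch of that pattern; the function name and timing values are made up for illustration.

import time
from rllab.misc import logger

def log_one_iteration(itr, start_time):
    itr_start_time = time.time()
    with logger.prefix('itr #%d | ' % itr):
        logger.log('Optimizing policy...')  # free-form message, gets the prefix
        logger.record_tabular('Time', time.time() - start_time)
        logger.record_tabular('ItrTime', time.time() - itr_start_time)
        logger.dump_tabular(with_prefix=False)  # table rows printed without the prefix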
Example 2: train
# Module to import: from rllab.misc import logger [as alias]
# Or: from rllab.misc.logger import prefix [as alias]
def train(self):
    if self.init_pol_params is not None:
        self.policy.set_param_values(self.init_pol_params)
    if self.init_irl_params is not None:
        self.irl_model.set_params(self.init_irl_params)
    self.start_worker()
    start_time = time.time()
    returns = []
    for itr in range(self.start_itr, self.n_itr):
        itr_start_time = time.time()
        with logger.prefix('itr #%d | ' % itr):
            logger.log('Obtaining samples...')
            paths = self.sampler.obtain_samples(itr)
            logger.log('Processing samples...')
            # Update the Reward function
            paths = self.compute_irl(paths, itr=itr)
            # returns.append(self.log_avg_returns(paths))
            samples_data = self.sampler.process_samples(itr, paths)
            logger.log('Logging diagnostics...')
            self.log_diagnostics(paths)
            logger.log('Optimizing policy...')
            self.optimize_policy(itr, samples_data)
            logger.log('Saving snapshot...')
            params = self.get_itr_snapshot(itr, samples_data)  # , **kwargs)
            if self.store_paths:
                params['paths'] = samples_data['paths']
            logger.save_itr_params(itr, params)
            logger.log('Saved')
            logger.record_tabular('Time', time.time() - start_time)
            logger.record_tabular('ItrTime', time.time() - itr_start_time)
            logger.dump_tabular(with_prefix=False)
            if self.plot:
                self.update_plot()
                if self.pause_for_plot:
                    input('Plotting evaluation run: Press Enter to '
                          'continue...')
    self.shutdown_worker()
    return
Example 3: train
# Module to import: from rllab.misc import logger [as alias]
# Or: from rllab.misc.logger import prefix [as alias]
def train(self, sess=None):
    created_session = True if (sess is None) else False
    if sess is None:
        sess = tf.Session()
        sess.__enter__()
    sess.run(tf.global_variables_initializer())
    self.start_worker()
    start_time = time.time()
    for itr in range(self.start_itr, self.n_itr):
        itr_start_time = time.time()
        with logger.prefix('itr #%d | ' % itr):
            logger.log("Obtaining samples...")
            paths = self.obtain_samples(itr)
            logger.log("Processing samples...")
            samples_data = self.process_samples(itr, paths)
            logger.log("Logging diagnostics...")
            self.log_diagnostics(paths)
            logger.log("Optimizing policy...")
            self.optimize_policy(itr, samples_data)
            logger.log("Saving snapshot...")
            params = self.get_itr_snapshot(itr, samples_data)  # , **kwargs)
            if self.store_paths:
                params["paths"] = samples_data["paths"]
            logger.save_itr_params(itr, params)
            logger.log("Saved")
            logger.record_tabular('Time', time.time() - start_time)
            logger.record_tabular('ItrTime', time.time() - itr_start_time)
            logger.dump_tabular(with_prefix=False)
            if self.plot:
                rollout(self.env, self.policy, animated=True, max_path_length=self.max_path_length)
                if self.pause_for_plot:
                    input("Plotting evaluation run: Press Enter to "
                          "continue...")
    self.shutdown_worker()
    if created_session:
        sess.close()
Example 4: process_samples
# Module to import: from rllab.misc import logger [as alias]
# Or: from rllab.misc.logger import prefix [as alias]
def process_samples(self, itr, paths):
    # count visitations or whatever the bonus wants to do. This should not modify the paths
    for b_eval in self.bonus_evaluator:
        logger.log("fitting bonus evaluator before processing...")
        b_eval.fit_before_process_samples(paths)
        logger.log("fitted")
    # save real undiscounted reward before changing them
    undiscounted_returns = [sum(path["rewards"]) for path in paths]
    logger.record_tabular('TrueAverageReturn', np.mean(undiscounted_returns))
    for path in paths:
        path['true_rewards'] = list(path['rewards'])
    # If using a latent regressor (and possibly adding MI to the reward):
    if isinstance(self.latent_regressor, Latent_regressor):
        with logger.prefix(' Latent_regressor '):
            self.latent_regressor.fit(paths)
            if self.reward_regressor_mi:
                for i, path in enumerate(paths):
                    # this is for paths usually..
                    path['logli_latent_regressor'] = self.latent_regressor.predict_log_likelihood(
                        [path], [path['agent_infos']['latents']])[0]
                    # the logli of the latent is the variable of the mutual information
                    path['rewards'] += self.reward_regressor_mi * path['logli_latent_regressor']
    # for the extra bonus
    for b, b_eval in enumerate(self.bonus_evaluator):
        for i, path in enumerate(paths):
            bonuses = b_eval.predict(path)
            path['rewards'] += self.reward_coef_bonus[b] * bonuses
    real_samples = ext.extract_dict(
        BatchSampler.process_samples(self, itr, paths),
        # I don't need to process the hallucinated samples: the R, A,.. same!
        "observations", "actions", "advantages", "env_infos", "agent_infos"
    )
    real_samples["importance_weights"] = np.ones_like(real_samples["advantages"])
    return real_samples
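Example 4 labels only the latent-regressor work with logger.prefix(' Latent_regressor '). In rllab's logger, prefix is a thin context-manager wrapper around push_prefix / pop_prefix, so the same labelled block can also be written manually when a with block is inconvenient. A minimal sketch of the equivalent manual form; the regressor object and message are stand-ins.

from rllab.misc import logger

def fit_with_label(regressor, paths):
    # Roughly what `with logger.prefix(' Latent_regressor '):` expands to.
    logger.push_prefix(' Latent_regressor ')
    try:
        logger.log('fitting regressor...')
        regressor.fit(paths)
    finally:
        logger.pop_prefix()  # always restore the previous prefix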
Example 5: log_diagnostics
# Module to import: from rllab.misc import logger [as alias]
# Or: from rllab.misc.logger import prefix [as alias]
def log_diagnostics(self, paths):
    for b_eval in self.bonus_evaluator:
        b_eval.log_diagnostics(paths)
    if isinstance(self.latent_regressor, Latent_regressor):
        with logger.prefix(' Latent regressor logging | '):
            self.latent_regressor.log_diagnostics(paths)
Example 6: optimize_policy
# Module to import: from rllab.misc import logger [as alias]
# Or: from rllab.misc.logger import prefix [as alias]
def optimize_policy(self, itr, samples_data):
    # make that samples_data comes with latents: see train in batch_polopt
    all_input_values = tuple(ext.extract(  # it will be in agent_infos!!! under key "latents"
        samples_data,
        "observations", "actions", "advantages"
    ))
    agent_infos = samples_data["agent_infos"]
    # latents has already been processed and is the concat of all latents, but keeps key "latents"
    all_input_values += (agent_infos["latents"],)
    # these are the mean and var used at rollout, corresponding to old_dist_info_vars_list as symbolic var
    info_list = [agent_infos[k] for k in self.policy.distribution.dist_info_keys]
    all_input_values += tuple(info_list)
    if self.policy.recurrent:
        all_input_values += (samples_data["valids"],)
    loss_before = self.optimizer.loss(all_input_values)
    # this should always be 0. If it's not there is a problem.
    mean_kl_before = self.optimizer.constraint_val(all_input_values)
    logger.record_tabular('MeanKL_Before', mean_kl_before)
    with logger.prefix(' PolicyOptimize | '):
        self.optimizer.optimize(all_input_values)
    mean_kl = self.optimizer.constraint_val(all_input_values)
    loss_after = self.optimizer.loss(all_input_values)
    logger.record_tabular('LossAfter', loss_after)
    logger.record_tabular('MeanKL', mean_kl)
    logger.record_tabular('dLoss', loss_before - loss_after)
    return dict()
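When optimize_policy above runs inside a train loop like the earlier examples, the ' PolicyOptimize | ' prefix is pushed on top of the outer 'itr #N | ' prefix, and the logger concatenates the prefix stack, so the optimizer's internal log lines carry both labels. A small self-contained sketch of that nesting; the function name and messages are illustrative.

from rllab.misc import logger

def nested_prefix_demo(itr=0):
    with logger.prefix('itr #%d | ' % itr):
        logger.log('Optimizing policy...')  # -> "... itr #0 | Optimizing policy..."
        with logger.prefix(' PolicyOptimize | '):
            # Both prefixes are active inside this block.
            logger.log('performing update')  # -> "... itr #0 |  PolicyOptimize | performing update"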
Example 7: train
# Module to import: from rllab.misc import logger [as alias]
# Or: from rllab.misc.logger import prefix [as alias]
def train(self):
    sess = tf.get_default_session()
    sess.run(tf.global_variables_initializer())
    if self.init_pol_params is not None:
        self.policy.set_param_values(self.init_pol_params)
    if self.init_qvar_params is not None:
        self.qvar_model.set_params(self.init_qvar_params)
    if self.init_irl_params is not None:
        self.irl_model.set_params(self.init_irl_params)
    if self.init_empw_params is not None:
        self.empw.set_params(self.init_empw_params)
    self.start_worker()
    start_time = time.time()
    returns = []
    rew = []  # stores score at each step
    for itr in range(self.start_itr, self.n_itr):
        itr_start_time = time.time()
        with logger.prefix('itr #%d | ' % itr):
            logger.log("Obtaining samples...")
            paths = self.obtain_samples(itr)
            logger.log("Processing samples...")
            paths, r = self.compute_irl(paths, itr=itr)
            rew.append(r)
            returns.append(self.log_avg_returns(paths))
            self.compute_qvar(paths, itr=itr)
            self.compute_empw(paths, itr=itr)
            samples_data = self.process_samples(itr, paths)
            logger.log("Logging diagnostics...")
            self.log_diagnostics(paths)
            logger.log("Optimizing policy...")
            self.optimize_policy(itr, samples_data)
            logger.log("Saving snapshot...")
            params = self.get_itr_snapshot(itr, samples_data)  # , **kwargs)
            if self.store_paths:
                params["paths"] = samples_data["paths"]
            logger.save_itr_params(itr, params)
            logger.log("Saved")
            logger.record_tabular('Time', time.time() - start_time)
            logger.record_tabular('ItrTime', time.time() - itr_start_time)
            logger.dump_tabular(with_prefix=False)
            if self.plot:
                self.update_plot()
                if self.pause_for_plot:
                    input("Plotting evaluation run: Press Enter to "
                          "continue...")
            if itr % self.target_empw_update == 0 and self.train_empw:  # reward 5
                print('updating target empowerment parameters')
                self.tempw.set_params(self.__empw_params)
            # pickle.dump(rew, open("rewards.p", "wb"))  # uncomment to store rewards in every iteration
    self.shutdown_worker()
    return
Example 8: train
# Module to import: from rllab.misc import logger [as alias]
# Or: from rllab.misc.logger import prefix [as alias]
def train(self):
    sess = tf.get_default_session()
    sess.run(tf.global_variables_initializer())
    if self.init_pol_params is not None:
        self.policy.set_param_values(self.init_pol_params)
    if self.init_irl_params is not None:
        self.irl_model.set_params(self.init_irl_params)
    self.start_worker()
    start_time = time.time()
    returns = []
    for itr in range(self.start_itr, self.n_itr):
        itr_start_time = time.time()
        with logger.prefix('itr #%d | ' % itr):
            logger.log("Obtaining samples...")
            paths = self.obtain_samples(itr)
            logger.log("Processing samples...")
            paths = self.compute_irl(paths, itr=itr)
            returns.append(self.log_avg_returns(paths))
            samples_data = self.process_samples(itr, paths)
            logger.log("Logging diagnostics...")
            self.log_diagnostics(paths)
            logger.log("Optimizing policy...")
            self.optimize_policy(itr, samples_data)
            logger.log("Saving snapshot...")
            params = self.get_itr_snapshot(itr, samples_data)  # , **kwargs)
            if self.store_paths:
                params["paths"] = samples_data["paths"]
            logger.save_itr_params(itr, params)
            logger.log("Saved")
            logger.record_tabular('Time', time.time() - start_time)
            logger.record_tabular('ItrTime', time.time() - itr_start_time)
            logger.dump_tabular(with_prefix=False)
            if self.plot:
                self.update_plot()
                if self.pause_for_plot:
                    input("Plotting evaluation run: Press Enter to "
                          "continue...")
    self.shutdown_worker()
    return