This article collects typical usage examples of baselines.her.rollout.RolloutWorker in Python. If you are wondering how rollout.RolloutWorker is used in practice, or are looking for a working example of it, the code selected here may help. You can also read further about the module that contains the class, baselines.her.rollout.
Examples are ordered by popularity by default.
Example 1: main
# Required import: from baselines.her import rollout [as alias]
# Or: from baselines.her.rollout import RolloutWorker [as alias]
import pickle

import numpy as np

from baselines import logger
from baselines.common import set_global_seeds
import baselines.her.experiment.config as config
from baselines.her.rollout import RolloutWorker


def main(policy_file, seed, n_test_rollouts, render):
    set_global_seeds(seed)

    # Load policy.
    with open(policy_file, 'rb') as f:
        policy = pickle.load(f)
    env_name = policy.info['env_name']

    # Prepare params. Copy the defaults so the module-level dict is not mutated.
    params = dict(config.DEFAULT_PARAMS)
    if env_name in config.DEFAULT_ENV_PARAMS:
        params.update(config.DEFAULT_ENV_PARAMS[env_name])  # merge env-specific parameters in
    params['env_name'] = env_name
    params = config.prepare_params(params)
    config.log_params(params, logger=logger)

    dims = config.configure_dims(params)

    # Evaluation settings: act greedily, one environment, optional rendering.
    eval_params = {
        'exploit': True,
        'use_target_net': params['test_with_polyak'],
        'compute_Q': True,
        'rollout_batch_size': 1,
        'render': bool(render),
    }

    # Copy the remaining settings the worker needs from the prepared params.
    for name in ['T', 'gamma', 'noise_eps', 'random_eps']:
        eval_params[name] = params[name]

    evaluator = RolloutWorker(params['make_env'], policy, dims, logger, **eval_params)
    evaluator.seed(seed)

    # Run evaluation.
    evaluator.clear_history()
    for _ in range(n_test_rollouts):
        evaluator.generate_rollouts()

    # Record logs: average each logged quantity and print the table.
    for key, val in evaluator.logs('test'):
        logger.record_tabular(key, np.mean(val))
    logger.dump_tabular()
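The parameter-preparation step in main merges environment-specific overrides into a dictionary of defaults, then picks out the settings the evaluation worker needs. The following is a minimal stand-alone sketch of that pattern. The dictionaries, their values, the environment name 'ExampleEnv-v0', and the helper names build_params and build_eval_params are all hypothetical stand-ins, not the contents or API of the real baselines config module.

```python
import copy

# Hypothetical defaults and per-environment overrides. They stand in for
# config.DEFAULT_PARAMS and config.DEFAULT_ENV_PARAMS; the values are made up.
DEFAULT_PARAMS = {'T': 50, 'gamma': 0.98, 'noise_eps': 0.2, 'random_eps': 0.3}
DEFAULT_ENV_PARAMS = {'ExampleEnv-v0': {'T': 100}}


def build_params(env_name):
    """Return the defaults merged with any overrides registered for env_name."""
    params = copy.deepcopy(DEFAULT_PARAMS)  # leave the module-level defaults untouched
    params.update(DEFAULT_ENV_PARAMS.get(env_name, {}))
    params['env_name'] = env_name
    return params


def build_eval_params(params, render=False):
    """Select the subset of settings an evaluation worker would be given."""
    eval_params = {'exploit': True, 'rollout_batch_size': 1, 'render': bool(render)}
    for name in ['T', 'gamma', 'noise_eps', 'random_eps']:
        eval_params[name] = params[name]
    return eval_params


params = build_params('ExampleEnv-v0')
eval_params = build_eval_params(params, render=False)
```

Working on a copy means repeated calls for different environments never see each other's overrides, which matters when several evaluations run in one process.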
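The evaluation loop in main clears the worker's history, runs a number of rollouts, and then averages each logged quantity. A rough sketch of that bookkeeping is shown here with a hypothetical EvalHistory class; it is not the RolloutWorker implementation, and the 'test/success_rate' key format is an assumption made for illustration.

```python
from collections import deque
from statistics import mean


class EvalHistory:
    """Hypothetical stand-in for the statistics an evaluation worker keeps."""

    def __init__(self, maxlen=100):
        self.success = deque(maxlen=maxlen)  # most recent rollout outcomes
        self.n_episodes = 0

    def clear(self):
        """Forget earlier outcomes so old rollouts do not skew the average."""
        self.success.clear()

    def record(self, succeeded):
        """Store the outcome of one rollout as 1.0 (success) or 0.0 (failure)."""
        self.success.append(1.0 if succeeded else 0.0)
        self.n_episodes += 1

    def logs(self, prefix):
        """Return (key, value) pairs with keys prefixed by the given name."""
        rate = mean(self.success) if self.success else 0.0
        return [(prefix + '/success_rate', rate),
                (prefix + '/episode', self.n_episodes)]


history = EvalHistory()
history.clear()
for outcome in [True, False, True, True]:  # made-up rollout outcomes
    history.record(outcome)
summary = dict(history.logs('test'))
```

Clearing before the loop mirrors the clear_history() call in main, so the reported average covers only the rollouts of the current evaluation.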