This article collects typical usage examples of the Python method baselines.gail.trpo_mpi.learn. If you are wondering what trpo_mpi.learn does, how to call it, or what it looks like in real code, the curated examples here may help. You can also explore further usage examples from the module that contains this method, baselines.gail.trpo_mpi.
The section below presents 1 code example of the trpo_mpi.learn method, sorted by popularity by default. You can upvote examples you like or find useful; your feedback helps the site recommend better Python code examples.
Example 1: train
# Required import: from baselines.gail import trpo_mpi [as alias]
# Or: from baselines.gail.trpo_mpi import learn [as alias]
# Module-level imports used inside train (not shown in the snippet header)
from mpi4py import MPI
from baselines import logger
from baselines.common import set_global_seeds

def train(env, seed, policy_fn, reward_giver, dataset, algo,
          g_step, d_step, policy_entcoeff, num_timesteps, save_per_iter,
          checkpoint_dir, log_dir, pretrained, BC_max_iter, task_name=None):

    pretrained_weight = None
    if pretrained and (BC_max_iter > 0):
        # Pretrain the policy with behavior cloning on the expert dataset
        from baselines.gail import behavior_clone
        pretrained_weight = behavior_clone.learn(env, policy_fn, dataset,
                                                 max_iters=BC_max_iter)

    if algo == 'trpo':
        from baselines.gail import trpo_mpi
        # Set up a per-worker MPI seed so each rank samples differently
        rank = MPI.COMM_WORLD.Get_rank()
        if rank != 0:
            logger.set_level(logger.DISABLED)
        workerseed = seed + 10000 * MPI.COMM_WORLD.Get_rank()
        set_global_seeds(workerseed)
        env.seed(workerseed)

        # Run GAIL: g_step TRPO generator updates alternate with d_step
        # discriminator updates on reward_giver
        trpo_mpi.learn(env, policy_fn, reward_giver, dataset, rank,
                       pretrained=pretrained, pretrained_weight=pretrained_weight,
                       g_step=g_step, d_step=d_step,
                       entcoeff=policy_entcoeff,
                       max_timesteps=num_timesteps,
                       ckpt_dir=checkpoint_dir, log_dir=log_dir,
                       save_per_iter=save_per_iter,
                       timesteps_per_batch=1024,
                       max_kl=0.01, cg_iters=10, cg_damping=0.1,
                       gamma=0.995, lam=0.97,
                       vf_iters=5, vf_stepsize=1e-3,
                       task_name=task_name)
    else:
        raise NotImplementedError
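
For context, this train function comes from the GAIL example script in the baselines repository (baselines/gail/run_mujoco.py). The sketch below shows how it might be invoked; the environment id, expert data path, hidden sizes, and hyperparameter values are illustrative assumptions rather than part of the example above.

import gym

from baselines.common import tf_util as U
from baselines.gail import mlp_policy
from baselines.gail.adversary import TransitionClassifier
from baselines.gail.dataset.mujoco_dset import Mujoco_Dset


def main():
    # Single-CPU TensorFlow session for the policy and discriminator graphs
    U.make_session(num_cpu=1).__enter__()

    env = gym.make('Hopper-v2')  # assumed environment id

    # Policy constructor in the form expected by trpo_mpi.learn
    def policy_fn(name, ob_space, ac_space, reuse=False):
        return mlp_policy.MlpPolicy(name=name, ob_space=ob_space, ac_space=ac_space,
                                    reuse=reuse, hid_size=100, num_hid_layers=2)

    # Discriminator that provides the learned (adversarial) reward
    reward_giver = TransitionClassifier(env, hidden_size=100, entcoeff=1e-3)

    # Expert demonstrations; the .npz path is a placeholder
    dataset = Mujoco_Dset(expert_path='data/expert.Hopper.npz', traj_limitation=-1)

    train(env, seed=0, policy_fn=policy_fn, reward_giver=reward_giver,
          dataset=dataset, algo='trpo',
          g_step=3, d_step=1, policy_entcoeff=0.0,
          num_timesteps=int(5e6), save_per_iter=100,
          checkpoint_dir='checkpoint', log_dir='log',
          pretrained=False, BC_max_iter=10000,
          task_name='gail.Hopper')

    env.close()


if __name__ == '__main__':
    main()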