本文整理匯總了Python中baselines.gail.trpo_mpi.learn方法的典型用法代碼示例。如果您正苦於以下問題:Python trpo_mpi.learn方法的具體用法?Python trpo_mpi.learn怎麽用?Python trpo_mpi.learn使用的例子?那麽, 這裏精選的方法代碼示例或許可以為您提供幫助。您也可以進一步了解該方法所在類baselines.gail.trpo_mpi
的用法示例。
在下文中一共展示了trpo_mpi.learn方法的1個代碼示例,這些例子默認根據受歡迎程度排序。您可以為喜歡或者感覺有用的代碼點讚,您的評價將有助於係統推薦出更棒的Python代碼示例。
示例1: train
# 需要導入模塊: from baselines.gail import trpo_mpi [as 別名]
# 或者: from baselines.gail.trpo_mpi import learn [as 別名]
def train(env, seed, policy_fn, reward_giver, dataset, algo,
g_step, d_step, policy_entcoeff, num_timesteps, save_per_iter,
checkpoint_dir, log_dir, pretrained, BC_max_iter, task_name=None):
pretrained_weight = None
if pretrained and (BC_max_iter > 0):
# Pretrain with behavior cloning
from baselines.gail import behavior_clone
pretrained_weight = behavior_clone.learn(env, policy_fn, dataset,
max_iters=BC_max_iter)
if algo == 'trpo':
from baselines.gail import trpo_mpi
# Set up for MPI seed
rank = MPI.COMM_WORLD.Get_rank()
if rank != 0:
logger.set_level(logger.DISABLED)
workerseed = seed + 10000 * MPI.COMM_WORLD.Get_rank()
set_global_seeds(workerseed)
env.seed(workerseed)
trpo_mpi.learn(env, policy_fn, reward_giver, dataset, rank,
pretrained=pretrained, pretrained_weight=pretrained_weight,
g_step=g_step, d_step=d_step,
entcoeff=policy_entcoeff,
max_timesteps=num_timesteps,
ckpt_dir=checkpoint_dir, log_dir=log_dir,
save_per_iter=save_per_iter,
timesteps_per_batch=1024,
max_kl=0.01, cg_iters=10, cg_damping=0.1,
gamma=0.995, lam=0.97,
vf_iters=5, vf_stepsize=1e-3,
task_name=task_name)
else:
raise NotImplementedError