本文整理汇总了Python中baselines.gail.dataset.mujoco_dset.Mujoco_Dset类的典型用法代码示例。如果您正苦于以下问题:Python中mujoco_dset.Mujoco_Dset的具体用法是什么?mujoco_dset.Mujoco_Dset怎么用?有哪些mujoco_dset.Mujoco_Dset的使用例子?那么恭喜您,这里精选的代码示例或许可以为您提供帮助。您也可以进一步了解该类所在模块baselines.gail.dataset.mujoco_dset的用法示例。
在下文中一共展示了mujoco_dset.Mujoco_Dset类的3个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: load_dataset
# 需要导入模块: from baselines.gail.dataset import mujoco_dset [as 别名]
# 或者: from baselines.gail.dataset.mujoco_dset import Mujoco_Dset [as 别名]
def load_dataset(expert_path):
    """Construct a ``Mujoco_Dset`` for the given expert data location.

    Args:
        expert_path: Forwarded unchanged as the ``expert_path`` keyword of
            ``Mujoco_Dset``.

    Returns:
        The newly created ``Mujoco_Dset`` instance.
    """
    return Mujoco_Dset(expert_path=expert_path)
示例2: main
# 需要导入模块: from baselines.gail.dataset import mujoco_dset [as 别名]
# 或者: from baselines.gail.dataset.mujoco_dset import Mujoco_Dset [as 别名]
def main(args):
    """Fit a policy on an expert dataset with ``learn`` and roll it out.

    The function builds a monitored gym environment, loads the expert
    trajectories into a ``Mujoco_Dset``, hands both to ``learn`` and then
    passes the returned checkpoint location to ``runner``.

    Args:
        args: Parsed options object. Reads ``seed``, ``env_id``,
            ``policy_hidden_size``, ``checkpoint_dir``, ``log_dir``,
            ``expert_path``, ``traj_limitation``, ``BC_max_iter``,
            ``stochastic_policy`` and ``save_sample``.

    Side effects:
        ``args.checkpoint_dir`` and ``args.log_dir`` are rewritten in place
        to include the task name as a trailing path component.
    """
    # NOTE(review): the session context is entered but never exited here,
    # presumably so it stays active for the whole run -- confirm.
    U.make_session(num_cpu=1).__enter__()
    set_global_seeds(args.seed)

    def policy_fn(name, ob_space, ac_space, reuse=False):
        """Create the two-hidden-layer MLP policy used for training and rollout."""
        return mlp_policy.MlpPolicy(name=name, ob_space=ob_space, ac_space=ac_space,
                                    reuse=reuse, hid_size=args.policy_hidden_size,
                                    num_hid_layers=2)

    env = gym.make(args.env_id)
    try:
        # When the logger has no directory, the falsy value is passed to the
        # monitor instead of a file path.
        env = bench.Monitor(env, logger.get_dir() and
                            osp.join(logger.get_dir(), "monitor.json"))
        env.seed(args.seed)
        gym.logger.setLevel(logging.WARN)

        task_name = get_task_name(args)
        args.checkpoint_dir = osp.join(args.checkpoint_dir, task_name)
        args.log_dir = osp.join(args.log_dir, task_name)

        dataset = Mujoco_Dset(expert_path=args.expert_path,
                              traj_limitation=args.traj_limitation)
        savedir_fname = learn(env,
                              policy_fn,
                              dataset,
                              max_iters=args.BC_max_iter,
                              ckpt_dir=args.checkpoint_dir,
                              log_dir=args.log_dir,
                              task_name=task_name,
                              verbose=True)
        # The rollout statistics returned by ``runner`` are not used here.
        runner(env,
               policy_fn,
               savedir_fname,
               timesteps_per_batch=1024,
               number_trajs=10,
               stochastic_policy=args.stochastic_policy,
               save=args.save_sample,
               reuse=True)
    finally:
        # Release the environment even when training or rollout fails.
        env.close()
示例3: main
# 需要导入模块: from baselines.gail.dataset import mujoco_dset [as 别名]
# 或者: from baselines.gail.dataset.mujoco_dset import Mujoco_Dset [as 别名]
def main(args):
    """Run either the ``train`` or the ``evaluate`` task selected by ``args.task``.

    For ``'train'`` the expert data is loaded into a ``Mujoco_Dset``, a
    ``TransitionClassifier`` is built as the reward giver, and both are passed
    to ``train``. For ``'evaluate'`` the model at ``args.load_model_path`` is
    rolled out with ``runner``.

    Args:
        args: Parsed options object. Always reads ``seed``, ``env_id``,
            ``policy_hidden_size``, ``checkpoint_dir``, ``log_dir`` and
            ``task``; the remaining attributes are read only by the selected
            task branch.

    Raises:
        NotImplementedError: If ``args.task`` is neither ``'train'`` nor
            ``'evaluate'``.

    Side effects:
        ``args.checkpoint_dir`` and ``args.log_dir`` are rewritten in place
        to include the task name as a trailing path component.
    """
    # NOTE(review): the session context is entered but never exited here,
    # presumably so it stays active for the whole run -- confirm.
    U.make_session(num_cpu=1).__enter__()
    set_global_seeds(args.seed)

    def policy_fn(name, ob_space, ac_space, reuse=False):
        """Create the two-hidden-layer MLP policy used by both tasks."""
        return mlp_policy.MlpPolicy(name=name, ob_space=ob_space, ac_space=ac_space,
                                    reuse=reuse, hid_size=args.policy_hidden_size,
                                    num_hid_layers=2)

    env = gym.make(args.env_id)
    try:
        # When the logger has no directory, the falsy value is passed to the
        # monitor instead of a file path.
        env = bench.Monitor(env, logger.get_dir() and
                            osp.join(logger.get_dir(), "monitor.json"))
        env.seed(args.seed)
        gym.logger.setLevel(logging.WARN)

        task_name = get_task_name(args)
        args.checkpoint_dir = osp.join(args.checkpoint_dir, task_name)
        args.log_dir = osp.join(args.log_dir, task_name)

        if args.task == 'train':
            dataset = Mujoco_Dset(expert_path=args.expert_path,
                                  traj_limitation=args.traj_limitation)
            reward_giver = TransitionClassifier(env, args.adversary_hidden_size,
                                                entcoeff=args.adversary_entcoeff)
            train(env,
                  args.seed,
                  policy_fn,
                  reward_giver,
                  dataset,
                  args.algo,
                  args.g_step,
                  args.d_step,
                  args.policy_entcoeff,
                  args.num_timesteps,
                  args.save_per_iter,
                  args.checkpoint_dir,
                  args.log_dir,
                  args.pretrained,
                  args.BC_max_iter,
                  task_name)
        elif args.task == 'evaluate':
            runner(env,
                   policy_fn,
                   args.load_model_path,
                   timesteps_per_batch=1024,
                   number_trajs=10,
                   stochastic_policy=args.stochastic_policy,
                   save=args.save_sample)
        else:
            raise NotImplementedError(
                "Unsupported task: {!r}".format(args.task))
    finally:
        # Close the environment on every exit path, including failures and
        # the unsupported-task error above.
        env.close()