

Python policies.MlpPolicy Code Examples

This article collects typical usage examples of stable_baselines.common.policies.MlpPolicy in Python. If you are wondering what policies.MlpPolicy does, how to use it, or where to find worked examples, the curated snippets below should help. Note that MlpPolicy is a policy class rather than a method: it is passed to a Stable Baselines algorithm at construction time. For further context, see the other contents of the stable_baselines.common.policies module.


The following presents six code examples of policies.MlpPolicy, ordered by popularity by default.
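Before the individual examples, here is a minimal sketch of the common pattern: MlpPolicy is passed as the policy class when constructing a Stable Baselines algorithm. The environment ID and hyperparameters below are illustrative, not taken from the examples.

from stable_baselines import PPO2
from stable_baselines.common.policies import MlpPolicy

# Build a PPO2 agent whose actor-critic network is a plain MLP
# (by default, two hidden layers of 64 units each). Stable Baselines
# wraps the string environment ID in a vectorized environment internally.
model = PPO2(MlpPolicy, "CartPole-v1", verbose=1)
model.learn(total_timesteps=10000)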

Example 1: init_rl

# Required imports:
from typing import Type, Union

import gym
import stable_baselines
from stable_baselines.common.base_class import BaseRLModel
from stable_baselines.common.policies import BasePolicy, MlpPolicy
from stable_baselines.common.vec_env import VecEnv
def init_rl(
    env: Union[gym.Env, VecEnv],
    model_class: Type[BaseRLModel] = stable_baselines.PPO2,
    policy_class: Type[BasePolicy] = MlpPolicy,
    **model_kwargs,
):
    """Instantiates a policy for the provided environment.

    Args:
        env: The (vector) environment.
        model_class: A Stable Baselines RL algorithm.
        policy_class: A Stable Baselines compatible policy network class.
        model_kwargs (dict): kwargs passed through to the algorithm.
          Note: anything specified in `policy_kwargs` is passed through by the
          algorithm to the policy network.

    Returns:
      An RL algorithm.
    """
    return model_class(
        policy_class, env, **model_kwargs
    )  # pytype: disable=not-instantiable 
Developer: HumanCompatibleAI, Project: imitation, Source: util.py
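A minimal usage sketch for init_rl; the environment choice and keyword arguments are illustrative assumptions, not part of the original example.

import gym
import stable_baselines
from stable_baselines.common.vec_env import DummyVecEnv

# Wrap a Gym environment and build a PPO2 agent with the default MlpPolicy.
# Extra kwargs (here, verbose) are forwarded to the algorithm constructor.
env = DummyVecEnv([lambda: gym.make("CartPole-v1")])
model = init_rl(env, model_class=stable_baselines.PPO2, verbose=1)
model.learn(total_timesteps=10000)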

Example 2: train

# Required imports:
from mpi4py import MPI

from stable_baselines import TRPO, logger
from stable_baselines.common import tf_util
from stable_baselines.common.cmd_util import make_mujoco_env
from stable_baselines.common.policies import MlpPolicy
def train(env_id, num_timesteps, seed):
    """
    Train TRPO model for the mujoco environment, for testing purposes

    :param env_id: (str) Environment ID
    :param num_timesteps: (int) The total number of samples
    :param seed: (int) The initial seed for training
    """
    with tf_util.single_threaded_session():
        rank = MPI.COMM_WORLD.Get_rank()
        if rank == 0:
            logger.configure()
        else:
            logger.configure(format_strs=[])
            logger.set_level(logger.DISABLED)
workerseed = seed + 10000 * rank  # reuse the rank fetched above

        env = make_mujoco_env(env_id, workerseed)
        model = TRPO(MlpPolicy, env, timesteps_per_batch=1024, max_kl=0.01, cg_iters=10, cg_damping=0.1, entcoeff=0.0,
                     gamma=0.99, lam=0.98, vf_iters=5, vf_stepsize=1e-3)
        model.learn(total_timesteps=num_timesteps)
        env.close() 
Developer: Stable-Baselines-Team, Project: stable-baselines, Source: run_mujoco.py

Example 3: train

# Required imports:
import mujoco_py
from mpi4py import MPI

from stable_baselines import PPO1
from stable_baselines.common import set_global_seeds
from stable_baselines.common.cmd_util import make_robotics_env
from stable_baselines.common.policies import MlpPolicy
def train(env_id, num_timesteps, seed):
    """
    Train PPO1 model for Robotics environment, for testing purposes

    :param env_id: (str) Environment ID
    :param num_timesteps: (int) The total number of samples
    :param seed: (int) The initial seed for training
    """

    rank = MPI.COMM_WORLD.Get_rank()
    with mujoco_py.ignore_mujoco_warnings():
        workerseed = seed + 10000 * rank
        set_global_seeds(workerseed)
        env = make_robotics_env(env_id, workerseed, rank=rank)

        model = PPO1(MlpPolicy, env, timesteps_per_actorbatch=2048, clip_param=0.2, entcoeff=0.0, optim_epochs=5,
                     optim_stepsize=3e-4, optim_batchsize=256, gamma=0.99, lam=0.95, schedule='linear')
        model.learn(total_timesteps=num_timesteps)
        env.close() 
Developer: Stable-Baselines-Team, Project: stable-baselines, Source: run_robotics.py

Example 4: swimmer

# Required imports:
from stable_baselines.common import policies
def swimmer():
    # Sacred config function: bare local assignments become config values.
    # MUJOCO_SHARED_LOCALS is a dict of shared MuJoCo hyperparameters
    # defined elsewhere in train_adversarial.py.
    locals().update(**MUJOCO_SHARED_LOCALS)
    env_name = "Swimmer-v2"
    rollout_hint = "swimmer"
    total_timesteps = 2e6
    init_rl_kwargs = dict(policy_network_class=policies.MlpPolicy) 
Developer: HumanCompatibleAI, Project: imitation, Source: train_adversarial.py
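The function above reads as a Sacred named config: Sacred executes the function body and collects its local variables as experiment configuration, which is why bare assignments and the locals().update(...) trick work. A minimal sketch of the pattern, assuming a hypothetical experiment object (the real one is defined elsewhere in train_adversarial.py):

import sacred

train_ex = sacred.Experiment("train_adversarial")  # hypothetical stand-in

@train_ex.named_config
def fast():
    # Bare assignments inside a named config become config values that
    # can be activated by name on the command line.
    total_timesteps = 1e4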

Example 5: train

# Required imports:
from stable_baselines import PPO2
from stable_baselines.common.cmd_util import make_atari_env
from stable_baselines.common.policies import (CnnLnLstmPolicy, CnnLstmPolicy,
                                              CnnPolicy, MlpPolicy)
from stable_baselines.common.vec_env import VecFrameStack
def train(env_id, num_timesteps, seed, policy,
          n_envs=8, nminibatches=4, n_steps=128):
    """
    Train PPO2 model for atari environment, for testing purposes

    :param env_id: (str) the environment id string
    :param num_timesteps: (int) the number of timesteps to run
    :param seed: (int) Used to seed the random generator.
    :param policy: (str) The policy architecture to use: 'mlp', 'cnn', 'lstm' or 'lnlstm'
    :param n_envs: (int) Number of parallel environments
    :param nminibatches: (int) Number of training minibatches per update. For recurrent policies,
        the number of environments run in parallel should be a multiple of nminibatches.
    :param n_steps: (int) The number of steps to run for each environment per update
        (i.e. batch size is n_steps * n_env where n_env is number of environment copies running in parallel)
    """

    env = VecFrameStack(make_atari_env(env_id, n_envs, seed), 4)
    policy = {'cnn': CnnPolicy, 'lstm': CnnLstmPolicy, 'lnlstm': CnnLnLstmPolicy, 'mlp': MlpPolicy}[policy]
    model = PPO2(policy=policy, env=env, n_steps=n_steps, nminibatches=nminibatches,
                 lam=0.95, gamma=0.99, noptepochs=4, ent_coef=.01,
                 learning_rate=lambda f: f * 2.5e-4, cliprange=lambda f: f * 0.1, verbose=1)
    model.learn(total_timesteps=num_timesteps)

    env.close()
    # Free memory
    del model 
Developer: Stable-Baselines-Team, Project: stable-baselines, Source: run_atari.py
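A hypothetical invocation of the function above; the environment ID, timestep budget, and seed are illustrative.

# Train a CNN policy on Breakout for one million frames across
# 8 parallel environments (the function's default).
train("BreakoutNoFrameskip-v4", num_timesteps=int(1e6), seed=0, policy='cnn')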

Example 6: train

# Required imports:
from stable_baselines import PPO1
from stable_baselines.common.cmd_util import make_mujoco_env
from stable_baselines.common.policies import MlpPolicy
def train(env_id, num_timesteps, seed):
    """
    Train PPO1 model for the Mujoco environment, for testing purposes

    :param env_id: (str) Environment ID
    :param num_timesteps: (int) The total number of samples
    :param seed: (int) The initial seed for training
    """
    env = make_mujoco_env(env_id, seed)
    model = PPO1(MlpPolicy, env, timesteps_per_actorbatch=2048, clip_param=0.2, entcoeff=0.0, optim_epochs=10,
                 optim_stepsize=3e-4, optim_batchsize=64, gamma=0.99, lam=0.95, schedule='linear')
    model.learn(total_timesteps=num_timesteps)
    env.close() 
Developer: Stable-Baselines-Team, Project: stable-baselines, Source: run_mujoco.py


Note: The stable_baselines.common.policies.MlpPolicy examples in this article were compiled by 纯净天空 from open-source code and documentation platforms such as GitHub and MSDocs. The snippets are selected from open-source projects contributed by their respective authors; copyright remains with the original authors, and any use or distribution must follow the corresponding project's license. Please do not reproduce without permission.