當前位置: 首頁>>代碼示例>>Python>>正文


Python policies.MlpPolicy方法代碼示例

本文整理匯總了Python中stable_baselines.common.policies.MlpPolicy方法的典型用法代碼示例。如果您正苦於以下問題:Python policies.MlpPolicy方法的具體用法?Python policies.MlpPolicy怎麼用?Python policies.MlpPolicy使用的例子?那麼, 這裏精選的方法代碼示例或許可以為您提供幫助。您也可以進一步了解該方法所在stable_baselines.common.policies的用法示例。


在下文中一共展示了policies.MlpPolicy方法的6個代碼示例,這些例子默認根據受歡迎程度排序。您可以為喜歡或者感覺有用的代碼點讚,您的評價將有助於系統推薦出更棒的Python代碼示例。

示例1: init_rl

# 需要導入模塊: from stable_baselines.common import policies [as 別名]
# 或者: from stable_baselines.common.policies import MlpPolicy [as 別名]
def init_rl(
    env: Union[gym.Env, VecEnv],
    model_class: Type[BaseRLModel] = stable_baselines.PPO2,
    policy_class: Type[BasePolicy] = MlpPolicy,
    **model_kwargs,
):
    """Builds an RL algorithm instance for the provided environment.

    Args:
        env: The (vector) environment the algorithm is constructed with.
        model_class: The Stable Baselines RL algorithm class to instantiate.
        policy_class: The Stable Baselines compatible policy network class,
          handed to `model_class` as its first positional argument.
        model_kwargs (dict): Extra keyword arguments forwarded unchanged to
          `model_class`.
          Note: anything specified in `policy_kwargs` is passed through by the
          algorithm to the policy network.

    Returns:
      The newly constructed `model_class` instance.
    """
    # Positional order is (policy, env); everything else goes by keyword.
    rl_model = model_class(policy_class, env, **model_kwargs)  # pytype: disable=not-instantiable
    return rl_model
開發者ID:HumanCompatibleAI,項目名稱:imitation,代碼行數:24,代碼來源:util.py

示例2: train

# 需要導入模塊: from stable_baselines.common import policies [as 別名]
# 或者: from stable_baselines.common.policies import MlpPolicy [as 別名]
def train(env_id, num_timesteps, seed):
    """
    Train TRPO model for the mujoco environment, for testing purposes

    The environment is always closed before returning, including when model
    construction or training raises.

    :param env_id: (str) Environment ID
    :param num_timesteps: (int) The total number of samples
    :param seed: (int) The initial seed for training
    """
    with tf_util.single_threaded_session():
        rank = MPI.COMM_WORLD.Get_rank()
        if rank == 0:
            logger.configure()
        else:
            # Only the rank-0 process logs; the others are silenced.
            logger.configure(format_strs=[])
            logger.set_level(logger.DISABLED)
        # Offset the seed by rank so each MPI process gets a distinct seed.
        workerseed = seed + 10000 * rank

        env = make_mujoco_env(env_id, workerseed)
        try:
            model = TRPO(MlpPolicy, env, timesteps_per_batch=1024, max_kl=0.01, cg_iters=10, cg_damping=0.1,
                         entcoeff=0.0, gamma=0.99, lam=0.98, vf_iters=5, vf_stepsize=1e-3)
            model.learn(total_timesteps=num_timesteps)
        finally:
            # Release the environment even if building or training the model fails.
            env.close()
開發者ID:Stable-Baselines-Team,項目名稱:stable-baselines,代碼行數:24,代碼來源:run_mujoco.py

示例3: train

# 需要導入模塊: from stable_baselines.common import policies [as 別名]
# 或者: from stable_baselines.common.policies import MlpPolicy [as 別名]
def train(env_id, num_timesteps, seed):
    """
    Train PPO1 model for Robotics environment, for testing purposes

    The environment is always closed before returning, including when model
    construction or training raises.

    :param env_id: (str) Environment ID
    :param num_timesteps: (int) The total number of samples
    :param seed: (int) The initial seed for training
    """

    rank = MPI.COMM_WORLD.Get_rank()
    with mujoco_py.ignore_mujoco_warnings():
        # Offset the seed by rank so each MPI process gets a distinct seed.
        workerseed = seed + 10000 * rank
        set_global_seeds(workerseed)
        env = make_robotics_env(env_id, workerseed, rank=rank)

        try:
            model = PPO1(MlpPolicy, env, timesteps_per_actorbatch=2048, clip_param=0.2, entcoeff=0.0,
                         optim_epochs=5, optim_stepsize=3e-4, optim_batchsize=256, gamma=0.99, lam=0.95,
                         schedule='linear')
            model.learn(total_timesteps=num_timesteps)
        finally:
            # Release the environment even if building or training the model fails.
            env.close()
開發者ID:Stable-Baselines-Team,項目名稱:stable-baselines,代碼行數:21,代碼來源:run_robotics.py

示例4: swimmer

# 需要導入模塊: from stable_baselines.common import policies [as 別名]
# 或者: from stable_baselines.common.policies import MlpPolicy [as 別名]
def swimmer():
    """Define the settings used for the "Swimmer-v2" environment.

    The body only assigns local names and returns nothing; none of the names
    is read again inside this function.

    NOTE(review): presumably this is a config function whose local variables
    are collected by a decorator or framework that is not visible here --
    confirm before renaming, reordering or removing any assignment.
    """
    # Merge the shared MuJoCo settings first; presumably the assignments below
    # are meant to add to or override them -- TODO confirm.
    locals().update(**MUJOCO_SHARED_LOCALS)
    env_name = "Swimmer-v2"
    rollout_hint = "swimmer"
    total_timesteps = 2e6  # float literal, i.e. 2,000,000.0
    # NOTE(review): verify that `policy_network_class` is the keyword the
    # consumer of `init_rl_kwargs` actually expects.
    init_rl_kwargs = dict(policy_network_class=policies.MlpPolicy) 
開發者ID:HumanCompatibleAI,項目名稱:imitation,代碼行數:8,代碼來源:train_adversarial.py

示例5: train

# 需要導入模塊: from stable_baselines.common import policies [as 別名]
# 或者: from stable_baselines.common.policies import MlpPolicy [as 別名]
def train(env_id, num_timesteps, seed, policy,
          n_envs=8, nminibatches=4, n_steps=128):
    """
    Train PPO2 model for atari environment, for testing purposes

    The policy name is resolved before any environment is created, so an
    unknown name fails without leaving environments open. Once created, the
    environment is always closed, including when training raises.

    :param env_id: (str) the environment id string
    :param num_timesteps: (int) the number of timesteps to run
    :param seed: (int) Used to seed the random generator.
    :param policy: (str) Name of the policy to use: one of 'cnn', 'lstm', 'lnlstm' or 'mlp'
    :param n_envs: (int) Number of parallel environments
    :param nminibatches: (int) Number of training minibatches per update. For recurrent policies,
        the number of environments run in parallel should be a multiple of nminibatches.
    :param n_steps: (int) The number of steps to run for each environment per update
        (i.e. batch size is n_steps * n_env where n_env is number of environment copies running in parallel)
    :raises KeyError: if ``policy`` is not one of the supported names
    """

    # Resolve the policy class first: a bad name must not leak environments.
    policy_class = {'cnn': CnnPolicy, 'lstm': CnnLstmPolicy, 'lnlstm': CnnLnLstmPolicy, 'mlp': MlpPolicy}[policy]

    env = VecFrameStack(make_atari_env(env_id, n_envs, seed), 4)
    try:
        # learning_rate and cliprange are given as callables of `f` that scale a base value.
        model = PPO2(policy=policy_class, env=env, n_steps=n_steps, nminibatches=nminibatches,
                     lam=0.95, gamma=0.99, noptepochs=4, ent_coef=.01,
                     learning_rate=lambda f: f * 2.5e-4, cliprange=lambda f: f * 0.1, verbose=1)
        model.learn(total_timesteps=num_timesteps)
    finally:
        # Release the environment even if building or training the model fails.
        env.close()

    # Free memory
    del model
開發者ID:Stable-Baselines-Team,項目名稱:stable-baselines,代碼行數:28,代碼來源:run_atari.py

示例6: train

# 需要導入模塊: from stable_baselines.common import policies [as 別名]
# 或者: from stable_baselines.common.policies import MlpPolicy [as 別名]
def train(env_id, num_timesteps, seed):
    """
    Train PPO1 model for the Mujoco environment, for testing purposes

    The environment is always closed before returning, including when model
    construction or training raises.

    :param env_id: (str) Environment ID
    :param num_timesteps: (int) The total number of samples
    :param seed: (int) The initial seed for training
    """
    env = make_mujoco_env(env_id, seed)
    try:
        model = PPO1(MlpPolicy, env, timesteps_per_actorbatch=2048, clip_param=0.2, entcoeff=0.0, optim_epochs=10,
                     optim_stepsize=3e-4, optim_batchsize=64, gamma=0.99, lam=0.95, schedule='linear')
        model.learn(total_timesteps=num_timesteps)
    finally:
        # Release the environment even if building or training the model fails.
        env.close()
開發者ID:Stable-Baselines-Team,項目名稱:stable-baselines,代碼行數:15,代碼來源:run_mujoco.py


注:本文中的stable_baselines.common.policies.MlpPolicy方法示例由純淨天空整理自Github/MSDocs等開源代碼及文檔管理平台,相關代碼片段篩選自各路編程大神貢獻的開源項目,源碼版權歸原作者所有,傳播和使用請參考對應項目的License;未經允許,請勿轉載。