

Python stable_baselines.PPO2 code examples

This article collects typical usage examples of the stable_baselines.PPO2 class in Python. If you are wondering what stable_baselines.PPO2 does and how to use it, the curated code examples below may help. You can also explore further how the containing stable_baselines package is used.


Below are 9 code examples of stable_baselines.PPO2, sorted by popularity by default. You can upvote the examples you like or find useful; your feedback helps the system recommend better Python code examples.
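
Before diving into the examples, here is a minimal, self-contained sketch of creating and training a stable_baselines.PPO2 model. The environment id and hyperparameters are illustrative choices, not taken from any example below.

import gym
from stable_baselines import PPO2
from stable_baselines.common.policies import MlpPolicy

# Create a single Gym environment; PPO2 wraps it in a DummyVecEnv internally.
env = gym.make('CartPole-v1')

# Instantiate PPO2 with a feed-forward policy and train briefly.
model = PPO2(MlpPolicy, env, verbose=1)
model.learn(total_timesteps=10000)

# Query the trained policy for an action.
obs = env.reset()
action, _states = model.predict(obs)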

Example 1: init_rl

# Required module: import stable_baselines [as alias]
# Or: from stable_baselines import PPO2 [as alias]
from typing import Type, Union

import gym
import stable_baselines
from stable_baselines.common.base_class import BaseRLModel
from stable_baselines.common.policies import BasePolicy, MlpPolicy
from stable_baselines.common.vec_env import VecEnv

def init_rl(
    env: Union[gym.Env, VecEnv],
    model_class: Type[BaseRLModel] = stable_baselines.PPO2,
    policy_class: Type[BasePolicy] = MlpPolicy,
    **model_kwargs,
):
    """Instantiates a policy for the provided environment.

    Args:
        env: The (vector) environment.
        model_class: A Stable Baselines RL algorithm.
        policy_class: A Stable Baselines compatible policy network class.
        model_kwargs (dict): kwargs passed through to the algorithm.
          Note: anything specified in `policy_kwargs` is passed through by the
          algorithm to the policy network.

    Returns:
      An RL algorithm.
    """
    return model_class(
        policy_class, env, **model_kwargs
    )  # pytype: disable=not-instantiable 
Developer: HumanCompatibleAI, Project: imitation, Lines: 24, Source: util.py
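
A hypothetical usage sketch for init_rl (the environment id and the verbose kwarg are illustrative; extra kwargs are forwarded to the algorithm):

import gym

# Defaults to PPO2 with MlpPolicy; verbose is passed through as a model kwarg.
env = gym.make('CartPole-v1')
model = init_rl(env, verbose=1)
model.learn(total_timesteps=1000)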

Example 2: test_lstm_policy

# Required module: import stable_baselines [as alias]
# Or: from stable_baselines import PPO2 [as alias]
import os

from stable_baselines import PPO2
from stable_baselines.common.evaluation import evaluate_policy

def test_lstm_policy(request, model_class, policy):
    model_fname = './test_model_{}.zip'.format(request.node.name)

    try:
        # create and train
        if model_class == PPO2:
            model = model_class(policy, 'CartPole-v1', nminibatches=1)
        else:
            model = model_class(policy, 'CartPole-v1')
        model.learn(total_timesteps=100)

        env = model.get_env()
        evaluate_policy(model, env, n_eval_episodes=10)
        # saving
        model.save(model_fname)
        del model, env
        # loading
        _ = model_class.load(model_fname, policy=policy)

    finally:
        if os.path.exists(model_fname):
            os.remove(model_fname) 
Developer: Stable-Baselines-Team, Project: stable-baselines, Lines: 24, Source: test_lstm_policy.py
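
For context: with recurrent policies, PPO2 requires the number of parallel environments to be divisible by nminibatches, and passing an environment id string creates a single environment, which is why the test sets nminibatches=1 for PPO2. A minimal sketch of that constraint (hyperparameters illustrative):

from stable_baselines import PPO2
from stable_baselines.common.policies import MlpLstmPolicy

# One environment with a recurrent policy forces nminibatches=1; the default
# of 4 would fail PPO2's divisibility check between n_envs and nminibatches.
model = PPO2(MlpLstmPolicy, 'CartPole-v1', nminibatches=1)
model.learn(total_timesteps=100)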

Example 3: __init__

# Required module: import stable_baselines [as alias]
# Or: from stable_baselines import PPO2 [as alias]
def __init__(self,
             model: BaseRLModel = PPO2,
             policy: BasePolicy = MlpLnLstmPolicy,
             reward_strategy: BaseRewardStrategy = IncrementalProfit,
             exchange_args: Dict = {},
             **kwargs):
    self.logger = kwargs.get('logger', init_logger(__name__, show_debug=kwargs.get('show_debug', True)))

    self.Model = model
    self.Policy = policy
    self.Reward_Strategy = reward_strategy
    self.exchange_args = exchange_args
    self.tensorboard_path = kwargs.get('tensorboard_path', None)
    self.input_data_path = kwargs.get('input_data_path', 'data/input/coinbase-1h-btc-usd.csv')
    self.params_db_path = kwargs.get('params_db_path', 'sqlite:///data/params.db')

    self.date_format = kwargs.get('date_format', ProviderDateFormat.DATETIME_HOUR_24)

    self.model_verbose = kwargs.get('model_verbose', 1)
    self.n_envs = kwargs.get('n_envs', os.cpu_count())
    self.n_minibatches = kwargs.get('n_minibatches', self.n_envs)
    self.train_split_percentage = kwargs.get('train_split_percentage', 0.8)
    self.data_provider = kwargs.get('data_provider', 'static')

    self.initialize_data()
    self.initialize_optuna()

    self.logger.debug(f'Initialize RLTrader: {self.study_name}')
Developer: notadamking, Project: RLTrader, Lines: 30, Source: RLTrader.py
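
A hypothetical instantiation sketch for this constructor; the import path for RLTrader is assumed, and every keyword falls back to the defaults shown above:

from stable_baselines import PPO2
from stable_baselines.common.policies import MlpLnLstmPolicy

# Hypothetical: adjust the import to wherever RLTrader lives in the project.
from lib.RLTrader import RLTrader

# Note: __init__ calls initialize_data(), so the default input_data_path CSV
# must exist (or be overridden) for this to run.
trader = RLTrader(model=PPO2,
                  policy=MlpLnLstmPolicy,
                  n_envs=4,
                  tensorboard_path='data/tensorboard')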

Example 4: optimize_agent_params

# Required module: import stable_baselines [as alias]
# Or: from stable_baselines import PPO2 [as alias]
def optimize_agent_params(self, trial):
    if self.Model != PPO2:
        return {'learning_rate': trial.suggest_loguniform('learning_rate', 1e-5, 1.)}

    return {
        'n_steps': int(trial.suggest_loguniform('n_steps', 16, 2048)),
        'gamma': trial.suggest_loguniform('gamma', 0.9, 0.9999),
        'learning_rate': trial.suggest_loguniform('learning_rate', 1e-5, 1.),
        'ent_coef': trial.suggest_loguniform('ent_coef', 1e-8, 1e-1),
        'cliprange': trial.suggest_uniform('cliprange', 0.1, 0.4),
        'noptepochs': int(trial.suggest_loguniform('noptepochs', 1, 48)),
        'lam': trial.suggest_uniform('lam', 0.8, 1.)
    }
Developer: notadamking, Project: RLTrader, Lines: 15, Source: RLTrader.py
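
For context, a sketch of how Optuna might drive optimize_agent_params during a study. The objective below is hypothetical glue (the real project builds environments and evaluates reward elsewhere); only the trial-to-kwargs flow is the point, and `trader` is the instance from the previous sketch:

import optuna

def objective(trial):
    # Sample PPO2 hyperparameters via the method above and train briefly.
    model_params = trader.optimize_agent_params(trial)
    model = trader.Model(trader.Policy, 'CartPole-v1', nminibatches=1, **model_params)
    model.learn(total_timesteps=1000)
    return 0.0  # placeholder score; a real objective returns validation reward

study = optuna.create_study(direction='maximize')
study.optimize(objective, n_trials=10)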

Example 5: ppo2

# Required module: import stable_baselines [as alias]
# Or: from stable_baselines import PPO2 [as alias]
def ppo2(batch_size, num_env, learning_rate, **kwargs):
    return _stable(
        stable_baselines.PPO2,
        our_type="ppo2",
        callback_key="update",
        callback_mul=batch_size,
        n_steps=batch_size // num_env,
        learning_rate=learning_rate,
        **kwargs,
    ) 
Developer: HumanCompatibleAI, Project: adversarial-policies, Lines: 12, Source: train.py

Example 6: train

# Required module: import stable_baselines [as alias]
# Or: from stable_baselines import PPO2 [as alias]
from stable_baselines import PPO2
from stable_baselines.common.cmd_util import make_atari_env
from stable_baselines.common.policies import CnnPolicy, CnnLstmPolicy, CnnLnLstmPolicy, MlpPolicy
from stable_baselines.common.vec_env import VecFrameStack

def train(env_id, num_timesteps, seed, policy,
          n_envs=8, nminibatches=4, n_steps=128):
    """
    Train PPO2 model for atari environment, for testing purposes

    :param env_id: (str) the environment id string
    :param num_timesteps: (int) the number of timesteps to run
    :param seed: (int) Used to seed the random generator.
    :param policy: (Object) The policy model to use (MLP, CNN, LSTM, ...)
    :param n_envs: (int) Number of parallel environments
    :param nminibatches: (int) Number of training minibatches per update. For recurrent policies,
        the number of environments run in parallel should be a multiple of nminibatches.
    :param n_steps: (int) The number of steps to run for each environment per update
        (i.e. batch size is n_steps * n_env where n_env is number of environment copies running in parallel)
    """

    env = VecFrameStack(make_atari_env(env_id, n_envs, seed), 4)
    policy = {'cnn': CnnPolicy, 'lstm': CnnLstmPolicy, 'lnlstm': CnnLnLstmPolicy, 'mlp': MlpPolicy}[policy]
    model = PPO2(policy=policy, env=env, n_steps=n_steps, nminibatches=nminibatches,
                 lam=0.95, gamma=0.99, noptepochs=4, ent_coef=.01,
                 learning_rate=lambda f: f * 2.5e-4, cliprange=lambda f: f * 0.1, verbose=1)
    model.learn(total_timesteps=num_timesteps)

    env.close()
    # Free memory
    del model 
Developer: Stable-Baselines-Team, Project: stable-baselines, Lines: 28, Source: run_atari.py
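
A hypothetical invocation of this helper; the environment id and timestep budget are illustrative:

# Train a CNN policy on an Atari game as a short smoke test.
train('BreakoutNoFrameskip-v4', num_timesteps=10000, seed=0, policy='cnn')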

Example 7: test_lstm_train

# Required module: import stable_baselines [as alias]
# Or: from stable_baselines import PPO2 [as alias]
def test_lstm_train():
    """Test that LSTM models are able to achieve >=150 (out of 500) reward on CartPoleNoVelEnv.

    This environment requires memory to perform well in."""
    def make_env(i):
        env = CartPoleNoVelEnv()
        env = TimeLimit(env, max_episode_steps=500)
        env = bench.Monitor(env, None, allow_early_resets=True)
        env.seed(i)
        return env

    env = SubprocVecEnv([lambda: make_env(i) for i in range(NUM_ENVS)])
    env = VecNormalize(env)
    model = PPO2(MlpLstmPolicy, env, n_steps=128, nminibatches=NUM_ENVS, lam=0.95, gamma=0.99,
                 noptepochs=10, ent_coef=0.0, learning_rate=3e-4, cliprange=0.2, verbose=1)

    eprewmeans = []
    def reward_callback(local, _):
        nonlocal eprewmeans
        eprewmeans.append(safe_mean([ep_info['r'] for ep_info in local['ep_info_buf']]))

    model.learn(total_timesteps=100000, callback=reward_callback)

    # Maximum episode reward is 500.
    # In CartPole-v1, a non-recurrent policy can easily get >= 450.
    # In CartPoleNoVelEnv, a non-recurrent policy doesn't get more than ~50.
    # LSTM policies can reach above 400, but it varies a lot between runs; consistently get >=150.
    # See PR #244 for more detailed benchmarks.

    average_reward = sum(eprewmeans[-NUM_EPISODES_FOR_SCORE:]) / NUM_EPISODES_FOR_SCORE
    assert average_reward >= 150, "Mean reward below 150; got mean reward {}".format(average_reward)
Developer: Stable-Baselines-Team, Project: stable-baselines, Lines: 33, Source: test_lstm_policy.py

Example 8: __init__

# Required module: import stable_baselines [as alias]
# Or: from stable_baselines import PPO2 [as alias]
def __init__(self):
    super(PPO2Model, self).__init__(name="ppo2", model_class=PPO2)
Developer: araffin, Project: robotics-rl-srl, Lines: 4, Source: ppo2.py

Example 9: load_old_ppo2

# Required module: import stable_baselines [as alias]
# Or: from stable_baselines import PPO2 [as alias]
def load_old_ppo2(root_dir, env, env_name, index, transparent_params):
    try:
        from baselines.ppo2 import ppo2 as ppo2_old
    except ImportError as e:
        msg = "{}. HINT: you need to install (OpenAI) Baselines to use old_ppo2".format(e)
        raise ImportError(msg)

    denv = FakeSingleSpacesVec(env, agent_id=index)
    possible_fnames = ["model.pkl", "final_model.pkl"]
    model_path = None
    for fname in possible_fnames:
        candidate_path = os.path.join(root_dir, fname)
        if os.path.exists(candidate_path):
            model_path = candidate_path
    if model_path is None:
        raise FileNotFoundError(
            f"Could not find model at '{root_dir}' " f"under any filename '{possible_fnames}'"
        )

    graph = tf.Graph()
    sess = tf.Session(graph=graph)
    with sess.as_default():
        with graph.as_default():
            pylog.info(f"Loading Baselines PPO2 policy from '{model_path}'")
            policy = ppo2_old.learn(
                network="mlp",
                env=denv,
                total_timesteps=1,
                seed=0,
                nminibatches=4,
                log_interval=1,
                save_interval=1,
                load_path=model_path,
            )
    stable_policy = OpenAIToStablePolicy(
        policy, ob_space=denv.observation_space, ac_space=denv.action_space
    )
    model = PolicyToModel(stable_policy)

    try:
        normalize_path = os.path.join(root_dir, "normalize.pkl")
        with open(normalize_path, "rb") as f:
            old_vec_normalize = pickle.load(f)
        vec_normalize = vec_env.VecNormalize(denv, training=False)
        vec_normalize.obs_rms = old_vec_normalize.ob_rms
        vec_normalize.ret_rms = old_vec_normalize.ret_rms
        model = NormalizeModel(model, vec_normalize)
        pylog.info(f"Loaded normalization statistics from '{normalize_path}'")
    except FileNotFoundError:
        # We did not use VecNormalize during training, skip
        pass

    return model 
Developer: HumanCompatibleAI, Project: adversarial-policies, Lines: 55, Source: loader.py


Note: the stable_baselines.PPO2 examples in this article were collected by 純淨天空 from open-source code and documentation platforms such as GitHub and MSDocs. The snippets are selected from open-source projects contributed by their respective authors, and copyright in the source code remains with the original authors. Consult each project's license before redistributing or using the code; do not reproduce without permission.