Python stable_baselines.A2C Code Examples

This article collects typical usage examples of stable_baselines.A2C, the synchronous Advantage Actor-Critic (A2C) implementation in the stable-baselines library. If you are unsure what stable_baselines.A2C does, how to call it, or what it looks like in real code, the curated examples below should help. You can also explore the broader stable_baselines package for more usage examples.


The following presents 3 code examples of stable_baselines.A2C, sorted by popularity by default.
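
Before diving into the examples, here is a minimal, self-contained sketch of the typical A2C workflow. The environment id and the timestep budget below are arbitrary illustrative choices, not values taken from the examples:

from stable_baselines import A2C

# Passing an env id string makes stable-baselines wrap the
# environment in a DummyVecEnv automatically.
model = A2C('MlpPolicy', 'CartPole-v1', verbose=1)
model.learn(total_timesteps=10000)

# Run the trained policy for up to one episode
env = model.get_env()
obs = env.reset()
for _ in range(1000):
    action, _states = model.predict(obs, deterministic=True)
    obs, rewards, dones, infos = env.step(action)
    if dones[0]:
        break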

Example 1: test_evaluate_policy

# Required imports for this example:
import pytest

from stable_baselines import A2C
from stable_baselines.common.evaluation import evaluate_policy
def test_evaluate_policy():
    model = A2C('MlpPolicy', 'Pendulum-v0', seed=0)
    # Pendulum-v0 episodes always run for exactly 200 steps
    n_steps_per_episode, n_eval_episodes = 200, 2
    # Counter incremented once per evaluation step by the callback below
    model.n_callback_calls = 0

    def dummy_callback(locals_, _globals):
        locals_['model'].n_callback_calls += 1

    _, episode_lengths = evaluate_policy(model, model.get_env(), n_eval_episodes, deterministic=True,
                                         render=False, callback=dummy_callback, reward_threshold=None,
                                         return_episode_rewards=True)

    n_steps = sum(episode_lengths)
    assert n_steps == n_steps_per_episode * n_eval_episodes
    assert n_steps == model.n_callback_calls

    # Reaching a mean reward of zero is impossible with the Pendulum env
    with pytest.raises(AssertionError):
        evaluate_policy(model, model.get_env(), n_eval_episodes, reward_threshold=0.0)

    episode_rewards, _ = evaluate_policy(model, model.get_env(), n_eval_episodes, return_episode_rewards=True)
    assert len(episode_rewards) == n_eval_episodes 
Author: Stable-Baselines-Team, Project: stable-baselines, Lines: 24, Source file: test_utils.py
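
The example above passes return_episode_rewards=True to get per-episode lists. In its default mode, evaluate_policy instead returns the mean and standard deviation of the episode rewards; a minimal sketch, reusing the same Pendulum setup:

from stable_baselines import A2C
from stable_baselines.common.evaluation import evaluate_policy

model = A2C('MlpPolicy', 'Pendulum-v0', seed=0)
# Default mode: aggregate statistics rather than per-episode lists
mean_reward, std_reward = evaluate_policy(model, model.get_env(), n_eval_episodes=10)
print("mean reward: {:.2f} +/- {:.2f}".format(mean_reward, std_reward))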

Example 2: train

# Required imports for this example:
from stable_baselines import A2C
from stable_baselines.common.cmd_util import make_atari_env
from stable_baselines.common.policies import CnnPolicy, CnnLstmPolicy, CnnLnLstmPolicy
from stable_baselines.common.vec_env import VecFrameStack
def train(env_id, num_timesteps, seed, policy, lr_schedule, num_env):
    """
    Train an A2C model on an Atari environment, for testing purposes

    :param env_id: (str) Environment ID
    :param num_timesteps: (int) The total number of samples
    :param seed: (int) The initial seed for training
    :param policy: (A2CPolicy) The policy model to use (MLP, CNN, LSTM, ...)
    :param lr_schedule: (str) The type of scheduler for the learning rate update ('linear', 'constant',
                                 'double_linear_con', 'middle_drop' or 'double_middle_drop')
    :param num_env: (int) The number of environments
    """
    policy_fn = None
    if policy == 'cnn':
        policy_fn = CnnPolicy
    elif policy == 'lstm':
        policy_fn = CnnLstmPolicy
    elif policy == 'lnlstm':
        policy_fn = CnnLnLstmPolicy
    if policy_fn is None:
        raise ValueError("Error: policy {} not implemented".format(policy))

    # Stack 4 consecutive frames, the standard Atari preprocessing
    env = VecFrameStack(make_atari_env(env_id, num_env, seed), 4)

    model = A2C(policy_fn, env, lr_schedule=lr_schedule, seed=seed)
    model.learn(total_timesteps=int(num_timesteps * 1.1))  # train a 10% margin past the nominal budget
    env.close() 
Author: Stable-Baselines-Team, Project: stable-baselines, Lines: 29, Source file: run_atari.py
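
A call to this helper might look as follows; the environment id, timestep budget, and number of parallel environments are illustrative assumptions, not values from the source:

# Illustrative invocation with assumed parameter values
train('BreakoutNoFrameskip-v4', num_timesteps=int(1e6), seed=0,
      policy='cnn', lr_schedule='constant', num_env=8)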

Example 3: __init__

# Required imports for this example:
from stable_baselines import A2C

# Constructor of the A2CModel wrapper class in robotics-rl-srl; the base
# class (not shown in this snippet) stores the algorithm name and the
# stable-baselines class to instantiate.
    def __init__(self):
        super(A2CModel, self).__init__(name="a2c", model_class=A2C)
Author: araffin, Project: robotics-rl-srl, Lines: 4, Source file: a2c.py
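
The snippet above shows only the subclass constructor; the base class it calls lives elsewhere in the robotics-rl-srl project and is not reproduced here. As a rough sketch of the wrapper pattern it implies (the base-class name and its interface below are assumptions for illustration, not the project's actual API):

from stable_baselines import A2C

class BaseRLModel:  # hypothetical stand-in for the project's real base class
    def __init__(self, name, model_class):
        self.name = name                # algorithm identifier, e.g. "a2c"
        self.model_class = model_class  # stable-baselines class to instantiate
        self.model = None

    def train(self, env, total_timesteps):
        # Instantiate the wrapped algorithm and train it
        self.model = self.model_class('MlpPolicy', env)
        self.model.learn(total_timesteps=total_timesteps)

class A2CModel(BaseRLModel):
    def __init__(self):
        super(A2CModel, self).__init__(name="a2c", model_class=A2C)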


Note: The stable_baselines.A2C examples in this article were compiled by 純淨天空 from open-source code and documentation platforms such as GitHub and MSDocs. The snippets are drawn from open-source projects contributed by their respective authors, and copyright remains with them; refer to each project's license before redistributing or reusing the code. Do not reproduce this article without permission.