Python stable_baselines.DQN Code Examples

This article collects typical usage examples of stable_baselines.DQN in Python. If you have been wondering what stable_baselines.DQN is, how to use it, or where to find working examples, the curated code samples below may help. DQN is the Deep Q-Network algorithm class exported at the top level of the stable-baselines reinforcement learning library; you can also explore further usage examples from the surrounding stable_baselines package.


The sections below present 6 code examples that use stable_baselines.DQN, ordered by popularity by default.
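Before the examples, here is a minimal sketch of the basic workflow the snippets below build on, assuming the stable-baselines v2 (TensorFlow) API; the 'CartPole-v1' environment and the hyperparameters are illustrative choices, not taken from the examples:

from stable_baselines import DQN

# Train a DQN agent with a multilayer-perceptron policy on a Gym environment.
# (Illustrative hyperparameters; not taken from the examples below.)
model = DQN('MlpPolicy', 'CartPole-v1', learning_starts=100, verbose=1)
model.learn(total_timesteps=1000)

# Query the trained model for a greedy (deterministic) action.
env = model.get_env()
obs = env.reset()
action, _states = model.predict(obs, deterministic=True)
obs, reward, done, info = env.step(action)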

Example 1: test_deterministic_training_common

# Required module: import stable_baselines [as alias]
# Or: from stable_baselines import DQN [as alias]
from stable_baselines import DDPG, DQN, SAC, TD3
from stable_baselines.common.noise import NormalActionNoise

# SEED and N_STEPS_TRAINING are module-level constants defined in the
# original test file (test_0deterministic.py).
def test_deterministic_training_common(algo):
    results = [[], []]
    rewards = [[], []]
    kwargs = {'n_cpu_tf_sess': 1}
    if algo in [DDPG, TD3, SAC]:
        env_id = 'Pendulum-v0'
        kwargs.update({'action_noise': NormalActionNoise(0.0, 0.1)})
    else:
        env_id = 'CartPole-v1'
        if algo == DQN:
            kwargs.update({'learning_starts': 100})

    for i in range(2):
        model = algo('MlpPolicy', env_id, seed=SEED, **kwargs)
        model.learn(N_STEPS_TRAINING)
        env = model.get_env()
        obs = env.reset()
        for _ in range(100):
            action, _ = model.predict(obs, deterministic=False)
            obs, reward, _, _ = env.step(action)
            results[i].append(action)
            rewards[i].append(reward)
    assert sum(results[0]) == sum(results[1]), results
    assert sum(rewards[0]) == sum(rewards[1]), rewards 
Developer: Stable-Baselines-Team, Project: stable-baselines, Lines of code: 26, Source file: test_0deterministic.py

Example 2: test_long_episode

# Required module: import stable_baselines [as alias]
# Or: from stable_baselines import DQN [as alias]
from stable_baselines import DDPG, DQN, HER, SAC, TD3
from stable_baselines.common.bit_flipping_env import BitFlippingEnv

def test_long_episode(model_class):
    """
    Check that the model does not break when the replay buffer is still empty
    after the first rollout (because the episode is not over).
    """
    # n_bits > nb_rollout_steps
    n_bits = 10
    env = BitFlippingEnv(n_bits, continuous=model_class in [DDPG, SAC, TD3],
                         max_steps=n_bits)
    kwargs = {}
    if model_class == DDPG:
        kwargs['nb_rollout_steps'] = 9  # < n_bits
    elif model_class in [DQN, SAC, TD3]:
        kwargs['batch_size'] = 8  # < n_bits
        kwargs['learning_starts'] = 0

    model = HER('MlpPolicy', env, model_class, n_sampled_goal=4, goal_selection_strategy='future',
                verbose=0, **kwargs)
    model.learn(200) 
Developer: Stable-Baselines-Team, Project: stable-baselines, Lines of code: 21, Source file: test_her.py

Example 3: test_offpolicy_normalization

# Required module: import stable_baselines [as alias]
# Or: from stable_baselines import DQN [as alias]
import gym

from stable_baselines import DQN
from stable_baselines.common.vec_env import DummyVecEnv, VecNormalize

# make_env is a helper defined in the original test file (test_vec_normalize.py).
def test_offpolicy_normalization(model_class):
    if model_class == DQN:
        env = DummyVecEnv([lambda: gym.make('CartPole-v1')])
    else:
        env = DummyVecEnv([make_env])
    env = VecNormalize(env, norm_obs=True, norm_reward=True, clip_obs=10., clip_reward=10.)

    model = model_class('MlpPolicy', env, verbose=1)
    model.learn(total_timesteps=1000)
    # Check getter
    assert isinstance(model.get_vec_normalize_env(), VecNormalize) 
Developer: Stable-Baselines-Team, Project: stable-baselines, Lines of code: 13, Source file: test_vec_normalize.py

Example 4: __init__

# Required module: import stable_baselines [as alias]
# Or: from stable_baselines import DQN [as alias]
# Constructor of the DQNModel wrapper class in robotics-rl-srl; it registers
# stable_baselines.DQN as the underlying model class under the name "deepq".
def __init__(self):
    super(DQNModel, self).__init__(name="deepq", model_class=DQN)
Developer: araffin, Project: robotics-rl-srl, Lines of code: 4, Source file: deepq.py

Example 5: makeEnv

# Required module: import stable_baselines [as alias]
# Or: from stable_baselines import DQN [as alias]
# Classmethod of the DQNModel wrapper class in the original source;
# MultiprocessSRLModel and loadRunningAverage are helpers from robotics-rl-srl.
def makeEnv(cls, args, env_kwargs=None, load_path_normalise=None):
    # Even though DQN is single core only, we need to use the pipe system to work
    if env_kwargs is not None and env_kwargs.get("use_srl", False):
        srl_model = MultiprocessSRLModel(1, args.env, env_kwargs)
        env_kwargs["state_dim"] = srl_model.state_dim
        env_kwargs["srl_pipe"] = srl_model.pipe

    # Note: this `makeEnv` resolves to the module-level environment factory
    # imported in deepq.py, not to this classmethod of the same name
    # (class attributes are not in scope inside the method body).
    env = DummyVecEnv([makeEnv(args.env, args.seed, 0, args.log_dir, env_kwargs=env_kwargs)])

    if args.srl_model != "raw_pixels":
        env = VecNormalize(env, norm_reward=False)
        env = loadRunningAverage(env, load_path_normalise=load_path_normalise)

    return env
Developer: araffin, Project: robotics-rl-srl, Lines of code: 16, Source file: deepq.py

Example 6: test_callbacks

# Required module: import stable_baselines [as alias]
# Or: from stable_baselines import DQN [as alias]
import os
import shutil

from stable_baselines import ACER, DQN, PPO1, TRPO
from stable_baselines.common.callbacks import (CallbackList, CheckpointCallback,
                                               EvalCallback, EveryNTimesteps,
                                               StopTrainingOnRewardThreshold)

# LOG_FOLDER and CustomCallback are defined in the original test file
# (test_callbacks.py).
def test_callbacks(model_class):

    env_id = 'Pendulum-v0'
    if model_class in [ACER, DQN]:
        env_id = 'CartPole-v1'

    allowed_failures = []
    # The number of training timesteps is kept short here; otherwise the
    # test would take too long or require custom parameters per algorithm
    if model_class in [PPO1, DQN, TRPO]:
        allowed_failures = ['rollout_end']

    # Create RL model
    model = model_class('MlpPolicy', env_id)

    checkpoint_callback = CheckpointCallback(save_freq=500, save_path=LOG_FOLDER)

    # For testing: use the same training env
    eval_env = model.get_env()
    # Stop training if the performance is good enough
    callback_on_best = StopTrainingOnRewardThreshold(reward_threshold=-1200, verbose=1)

    eval_callback = EvalCallback(eval_env, callback_on_new_best=callback_on_best,
                                 best_model_save_path=LOG_FOLDER,
                                 log_path=LOG_FOLDER, eval_freq=100)

    # Equivalent to the `checkpoint_callback`
    # but here in an event-driven manner
    checkpoint_on_event = CheckpointCallback(save_freq=1, save_path=LOG_FOLDER,
                                             name_prefix='event')
    event_callback = EveryNTimesteps(n_steps=500, callback=checkpoint_on_event)

    callback = CallbackList([checkpoint_callback, eval_callback, event_callback])

    model.learn(500, callback=callback)
    model.learn(200, callback=None)
    custom_callback = CustomCallback()
    model.learn(200, callback=custom_callback)
    # Check that every callback method was executed
    custom_callback.validate(allowed_failures=allowed_failures)
    # Transform callback into a callback list automatically
    custom_callback = CustomCallback()
    model.learn(500, callback=[checkpoint_callback, eval_callback, custom_callback])
    # Check that every callback method was executed
    custom_callback.validate(allowed_failures=allowed_failures)

    # Automatic wrapping, old way of doing callbacks
    model.learn(200, callback=lambda _locals, _globals: True)

    # Cleanup
    if os.path.exists(LOG_FOLDER):
        shutil.rmtree(LOG_FOLDER) 
Developer: Stable-Baselines-Team, Project: stable-baselines, Lines of code: 55, Source file: test_callbacks.py


Note: the stable_baselines.DQN examples in this article were compiled by 純淨天空 from open-source code and documentation platforms such as GitHub and MSDocs. The snippets are drawn from open-source projects contributed by their respective developers; copyright of the source code remains with the original authors. Consult the corresponding project's license before distributing or using the code, and do not reproduce this page without permission.