This page collects typical usage examples of stable_baselines.DQN in Python: what stable_baselines.DQN is, how to use it, and what real code that uses it looks like. The curated examples below may help, and you can explore further usage examples from the stable_baselines module.
Six code examples of stable_baselines.DQN are shown below, sorted by popularity by default.
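Before the examples, a minimal usage sketch (not taken from the examples below) may help orient readers: it trains a DQN agent on CartPole-v1 with default hyperparameters and queries it for an action.

from stable_baselines import DQN

# Train a DQN agent with the built-in MLP policy on CartPole
model = DQN('MlpPolicy', 'CartPole-v1', verbose=1)
model.learn(total_timesteps=10000)

# Query the trained model for an action
obs = model.get_env().reset()
action, _states = model.predict(obs, deterministic=True)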
Example 1: test_deterministic_training_common
# Required imports: import stable_baselines [as alias]
# Or: from stable_baselines import DQN [as alias]
def test_deterministic_training_common(algo):
    results = [[], []]
    rewards = [[], []]
    # n_cpu_tf_sess=1 forces a single-threaded TensorFlow session,
    # which is required for fully deterministic training
    kwargs = {'n_cpu_tf_sess': 1}
    if algo in [DDPG, TD3, SAC]:
        env_id = 'Pendulum-v0'
        kwargs.update({'action_noise': NormalActionNoise(0.0, 0.1)})
    else:
        env_id = 'CartPole-v1'
        if algo == DQN:
            kwargs.update({'learning_starts': 100})

    # Train two models with the same seed: actions and rewards must match
    for i in range(2):
        model = algo('MlpPolicy', env_id, seed=SEED, **kwargs)
        model.learn(N_STEPS_TRAINING)
        env = model.get_env()
        obs = env.reset()
        for _ in range(100):
            action, _ = model.predict(obs, deterministic=False)
            obs, reward, _, _ = env.step(action)
            results[i].append(action)
            rewards[i].append(reward)
    assert sum(results[0]) == sum(results[1]), results
    assert sum(rewards[0]) == sum(rewards[1]), rewards
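The snippet relies on names defined elsewhere in its test module (SEED, N_STEPS_TRAINING, the algorithm classes, and NormalActionNoise). A hedged sketch of that surrounding context, with assumed values:

import pytest
from stable_baselines import DDPG, DQN, SAC, TD3
from stable_baselines.common.noise import NormalActionNoise

SEED = 0                 # assumed value
N_STEPS_TRAINING = 3000  # assumed value

@pytest.mark.parametrize('algo', [DDPG, DQN, SAC, TD3])  # plausible parametrization
def test_deterministic_training_common(algo):
    ...  # body as shown above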
Example 2: test_long_episode
# Required imports: import stable_baselines [as alias]
# Or: from stable_baselines import DQN [as alias]
def test_long_episode(model_class):
    """
    Check that the model does not break when the replay buffer is still empty
    after the first rollout (because the episode is not over).
    """
    # n_bits > nb_rollout_steps
    n_bits = 10
    env = BitFlippingEnv(n_bits, continuous=model_class in [DDPG, SAC, TD3],
                         max_steps=n_bits)
    kwargs = {}
    if model_class == DDPG:
        kwargs['nb_rollout_steps'] = 9  # < n_bits
    elif model_class in [DQN, SAC, TD3]:
        kwargs['batch_size'] = 8  # < n_bits
        kwargs['learning_starts'] = 0

    model = HER('MlpPolicy', env, model_class, n_sampled_goal=4, goal_selection_strategy='future',
                verbose=0, **kwargs)
    model.learn(200)
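Both BitFlippingEnv and HER ship with stable-baselines itself. The imports below and the sample invocation are a sketch of how the test is presumably driven, not part of the original snippet:

from stable_baselines import DDPG, DQN, SAC, TD3, HER
from stable_baselines.common.bit_flipping_env import BitFlippingEnv

test_long_episode(DQN)  # HER wraps the DQN model class around the goal-based env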
Example 3: test_offpolicy_normalization
# Required imports: import stable_baselines [as alias]
# Or: from stable_baselines import DQN [as alias]
def test_offpolicy_normalization(model_class):
    if model_class == DQN:
        env = DummyVecEnv([lambda: gym.make('CartPole-v1')])
    else:
        env = DummyVecEnv([make_env])

    env = VecNormalize(env, norm_obs=True, norm_reward=True, clip_obs=10., clip_reward=10.)
    model = model_class('MlpPolicy', env, verbose=1)
    model.learn(total_timesteps=1000)
    # Check getter
    assert isinstance(model.get_vec_normalize_env(), VecNormalize)
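Here `make_env` is defined elsewhere in the test module; since the non-DQN branch covers continuous-action models, a plausible stand-in (an assumption, not the original helper) is:

import gym
from stable_baselines.common.vec_env import DummyVecEnv, VecNormalize

def make_env():
    # Continuous-action environment for DDPG/SAC/TD3
    return gym.make('Pendulum-v0')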
Example 4: __init__
# Required imports: import stable_baselines [as alias]
# Or: from stable_baselines import DQN [as alias]
def __init__(self):
    super(DQNModel, self).__init__(name="deepq", model_class=DQN)
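DQNModel presumably extends a project-specific base wrapper that records the algorithm name and the stable_baselines class to instantiate later. A minimal sketch of that pattern; the base-class and attribute names here are assumptions:

from stable_baselines import DQN

class BaseRLObject(object):  # hypothetical base class
    def __init__(self, name, model_class):
        self.name = name                # identifier used for logging/CLI
        self.model_class = model_class  # stable_baselines class instantiated later

class DQNModel(BaseRLObject):
    def __init__(self):
        super(DQNModel, self).__init__(name="deepq", model_class=DQN)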
Example 5: makeEnv
# Required imports: import stable_baselines [as alias]
# Or: from stable_baselines import DQN [as alias]
@classmethod
def makeEnv(cls, args, env_kwargs=None, load_path_normalise=None):
    # Even though DQN is single core only, we need to use the pipe system to work
    if env_kwargs is not None and env_kwargs.get("use_srl", False):
        srl_model = MultiprocessSRLModel(1, args.env, env_kwargs)
        env_kwargs["state_dim"] = srl_model.state_dim
        env_kwargs["srl_pipe"] = srl_model.pipe

    env = DummyVecEnv([makeEnv(args.env, args.seed, 0, args.log_dir, env_kwargs=env_kwargs)])

    if args.srl_model != "raw_pixels":
        env = VecNormalize(env, norm_reward=False)
        env = loadRunningAverage(env, load_path_normalise=load_path_normalise)
    return env
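One subtlety worth noting: DummyVecEnv expects a list of zero-argument callables that each build an environment, so the inner makeEnv(...) call must be a different helper, imported from another module and shadowing the method name, that returns such a callable. A hypothetical factory with that contract:

import gym

def make_env_factory(env_id, seed, rank):
    def _thunk():
        # Build and seed one environment instance when the vec env starts
        env = gym.make(env_id)
        env.seed(seed + rank)
        return env
    return _thunk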
Example 6: test_callbacks
# Required imports: import stable_baselines [as alias]
# Or: from stable_baselines import DQN [as alias]
def test_callbacks(model_class):
    env_id = 'Pendulum-v0'
    if model_class in [ACER, DQN]:
        env_id = 'CartPole-v1'

    allowed_failures = []
    # The training budget is kept short here, so these algorithms may never
    # reach rollout_end; a longer budget would make the test too slow or
    # require custom parameters per algorithm
    if model_class in [PPO1, DQN, TRPO]:
        allowed_failures = ['rollout_end']

    # Create RL model
    model = model_class('MlpPolicy', env_id)
    checkpoint_callback = CheckpointCallback(save_freq=500, save_path=LOG_FOLDER)
    # For testing: use the same training env
    eval_env = model.get_env()
    # Stop training if the performance is good enough
    callback_on_best = StopTrainingOnRewardThreshold(reward_threshold=-1200, verbose=1)
    eval_callback = EvalCallback(eval_env, callback_on_new_best=callback_on_best,
                                 best_model_save_path=LOG_FOLDER,
                                 log_path=LOG_FOLDER, eval_freq=100)
    # Equivalent to the `checkpoint_callback`,
    # but here in an event-driven manner
    checkpoint_on_event = CheckpointCallback(save_freq=1, save_path=LOG_FOLDER,
                                             name_prefix='event')
    event_callback = EveryNTimesteps(n_steps=500, callback=checkpoint_on_event)
    callback = CallbackList([checkpoint_callback, eval_callback, event_callback])
    model.learn(500, callback=callback)
    model.learn(200, callback=None)
    custom_callback = CustomCallback()
    model.learn(200, callback=custom_callback)
    # Check that every callback method was executed
    custom_callback.validate(allowed_failures=allowed_failures)
    # Transform callback into a callback list automatically
    custom_callback = CustomCallback()
    model.learn(500, callback=[checkpoint_callback, eval_callback, custom_callback])
    # Check that every callback method was executed
    custom_callback.validate(allowed_failures=allowed_failures)
    # Automatic wrapping, old way of doing callbacks
    model.learn(200, callback=lambda _locals, _globals: True)
    # Cleanup
    if os.path.exists(LOG_FOLDER):
        shutil.rmtree(LOG_FOLDER)
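The callback classes used above live in stable_baselines.common.callbacks (available since stable-baselines v2.10); LOG_FOLDER and CustomCallback come from the surrounding test module. A minimal sketch of the custom callback, with the event bookkeeping behind validate() omitted and the folder path an assumed value:

import os
import shutil
from stable_baselines.common.callbacks import (BaseCallback, CallbackList,
                                               CheckpointCallback, EvalCallback,
                                               EveryNTimesteps,
                                               StopTrainingOnRewardThreshold)

LOG_FOLDER = '/tmp/log_folder/'  # assumed value

class CustomCallback(BaseCallback):
    def _on_step(self):
        return True  # returning False would stop training early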