本文整理匯總了Python中stable_baselines.A2C類的典型用法代碼示例。如果您正苦於以下問題：Python stable_baselines.A2C類的具體用法？Python stable_baselines.A2C怎麼用？Python stable_baselines.A2C使用的例子？那麼，這裡精選的代碼示例或許可以為您提供幫助。您也可以進一步了解該類所在模塊stable_baselines的用法示例。
在下文中一共展示了stable_baselines.A2C類的3個代碼示例，這些例子默認根據受歡迎程度排序。您可以為喜歡或者感覺有用的代碼點讚，您的評價將有助於系統推薦出更棒的Python代碼示例。
示例1: test_evaluate_policy
# 需要導入模塊: import stable_baselines [as 別名]
# 或者: from stable_baselines import A2C [as 別名]
def test_evaluate_policy():
    """Exercise ``evaluate_policy`` on an A2C agent trained on Pendulum-v0.

    Checks that:
    * the summed episode lengths equal 200 steps x 2 episodes,
    * the callback was invoked exactly once per counted step,
    * a reward threshold of 0.0 makes ``evaluate_policy`` raise AssertionError,
    * ``return_episode_rewards=True`` yields one reward entry per episode.
    """
    model = A2C('MlpPolicy', 'Pendulum-v0', seed=0)
    episode_len = 200
    num_episodes = 2

    # The counter lives on the model so the callback can reach it through
    # the locals dict it receives (via its 'model' entry).
    model.n_callback_calls = 0

    def count_step(locals_, _globals):
        locals_['model'].n_callback_calls += 1

    _, lengths = evaluate_policy(
        model,
        model.get_env(),
        num_episodes,
        render=False,
        deterministic=True,
        reward_threshold=None,
        callback=count_step,
        return_episode_rewards=True,
    )
    total_steps = sum(lengths)
    assert total_steps == episode_len * num_episodes
    assert total_steps == model.n_callback_calls

    # A mean reward of zero is unreachable in Pendulum, so the threshold
    # check inside evaluate_policy is expected to fail.
    with pytest.raises(AssertionError):
        evaluate_policy(model, model.get_env(), num_episodes, reward_threshold=0.0)

    rewards, _ = evaluate_policy(
        model,
        model.get_env(),
        num_episodes,
        return_episode_rewards=True,
    )
    assert len(rewards) == num_episodes
示例2: train
# 需要導入模塊: import stable_baselines [as 別名]
# 或者: from stable_baselines import A2C [as 別名]
def train(env_id, num_timesteps, seed, policy, lr_schedule, num_env):
    """
    Train A2C model for atari environment, for testing purposes

    :param env_id: (str) Environment ID
    :param num_timesteps: (int) The total number of samples
    :param seed: (int) The initial seed for training
    :param policy: (str) Name of the policy architecture to use:
        'cnn' (CnnPolicy), 'lstm' (CnnLstmPolicy) or 'lnlstm' (CnnLnLstmPolicy)
    :param lr_schedule: (str) The type of scheduler for the learning rate update ('linear', 'constant',
                                 'double_linear_con', 'middle_drop' or 'double_middle_drop')
    :param num_env: (int) The number of environments
    :raises ValueError: if ``policy`` is not one of the supported names
    """
    # Validate the policy name before any environment is created, so a bad
    # argument never leaves environments open.
    if policy == 'cnn':
        policy_fn = CnnPolicy
    elif policy == 'lstm':
        policy_fn = CnnLstmPolicy
    elif policy == 'lnlstm':
        policy_fn = CnnLnLstmPolicy
    else:
        raise ValueError("Error: policy {} not implemented".format(policy))

    # Wrap the Atari environments built by make_atari_env in VecFrameStack
    # with a stack size of 4.
    env = VecFrameStack(make_atari_env(env_id, num_env, seed), 4)
    try:
        model = A2C(policy_fn, env, lr_schedule=lr_schedule, seed=seed)
        # Trains for 10% more steps than num_timesteps.
        model.learn(total_timesteps=int(num_timesteps * 1.1))
    finally:
        # Always release the environments, even if model construction or
        # training raises; otherwise they would be left open.
        env.close()
示例3: __init__
# 需要導入模塊: import stable_baselines [as 別名]
# 或者: from stable_baselines import A2C [as 別名]
def __init__(self):
    """Initialise the parent class with ``model_class=A2C`` and ``name="a2c"``."""
    super(A2CModel, self).__init__(
        model_class=A2C,
        name="a2c",
    )