This article collects typical usage examples of the stable_baselines.PPO2 attribute in Python. If you are wondering what stable_baselines.PPO2 does, how to use it, or what code that uses it looks like, the curated examples below may help. You can also explore further usage examples from the stable_baselines module in which this attribute is defined.
The following shows 9 code examples of stable_baselines.PPO2, sorted by popularity by default.
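Before looking at the excerpts, here is a minimal sketch of the basic PPO2 workflow that most of the examples below build on (instantiate, learn, save, reload). It assumes stable-baselines 2.x with its TensorFlow backend and the built-in CartPole-v1 Gym environment; it is an illustration, not code taken from any of the listed projects.

from stable_baselines import PPO2
from stable_baselines.common.policies import MlpPolicy

# Instantiate the algorithm directly on a Gym environment id
model = PPO2(MlpPolicy, 'CartPole-v1', verbose=1)
# Train for a small number of timesteps
model.learn(total_timesteps=10000)
# Save, then reload, the trained model
model.save('ppo2_cartpole')
model = PPO2.load('ppo2_cartpole')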
Example 1: init_rl
# Required import: import stable_baselines [as alias]
# Or: from stable_baselines import PPO2 [as alias]
def init_rl(
    env: Union[gym.Env, VecEnv],
    model_class: Type[BaseRLModel] = stable_baselines.PPO2,
    policy_class: Type[BasePolicy] = MlpPolicy,
    **model_kwargs,
):
    """Instantiates a policy for the provided environment.

    Args:
        env: The (vector) environment.
        model_class: A Stable Baselines RL algorithm.
        policy_class: A Stable Baselines compatible policy network class.
        model_kwargs (dict): kwargs passed through to the algorithm.
            Note: anything specified in `policy_kwargs` is passed through by the
            algorithm to the policy network.

    Returns:
        An RL algorithm.
    """
    return model_class(
        policy_class, env, **model_kwargs
    )  # pytype: disable=not-instantiable
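A hypothetical call to init_rl might look like the sketch below; the DummyVecEnv wrapping, the CartPole-v1 environment, and the n_steps keyword are illustrative assumptions rather than part of the excerpt.

import gym
from stable_baselines.common.vec_env import DummyVecEnv

# Wrap a single Gym environment as a (trivial) vectorized environment
venv = DummyVecEnv([lambda: gym.make('CartPole-v1')])
# Extra keyword arguments are forwarded to the PPO2 constructor
model = init_rl(venv, n_steps=128)
model.learn(total_timesteps=10000)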
Example 2: test_lstm_policy
# Required import: import stable_baselines [as alias]
# Or: from stable_baselines import PPO2 [as alias]
def test_lstm_policy(request, model_class, policy):
    model_fname = './test_model_{}.zip'.format(request.node.name)

    try:
        # create and train
        if model_class == PPO2:
            model = model_class(policy, 'CartPole-v1', nminibatches=1)
        else:
            model = model_class(policy, 'CartPole-v1')
        model.learn(total_timesteps=100)

        env = model.get_env()
        evaluate_policy(model, env, n_eval_episodes=10)
        # saving
        model.save(model_fname)
        del model, env
        # loading
        _ = model_class.load(model_fname, policy=policy)
    finally:
        if os.path.exists(model_fname):
            os.remove(model_fname)
Example 3: __init__
# Required import: import stable_baselines [as alias]
# Or: from stable_baselines import PPO2 [as alias]
def __init__(self,
             model: BaseRLModel = PPO2,
             policy: BasePolicy = MlpLnLstmPolicy,
             reward_strategy: BaseRewardStrategy = IncrementalProfit,
             exchange_args: Dict = {},
             **kwargs):
    self.logger = kwargs.get('logger', init_logger(__name__, show_debug=kwargs.get('show_debug', True)))

    self.Model = model
    self.Policy = policy
    self.Reward_Strategy = reward_strategy
    self.exchange_args = exchange_args
    self.tensorboard_path = kwargs.get('tensorboard_path', None)
    self.input_data_path = kwargs.get('input_data_path', 'data/input/coinbase-1h-btc-usd.csv')
    self.params_db_path = kwargs.get('params_db_path', 'sqlite:///data/params.db')

    self.date_format = kwargs.get('date_format', ProviderDateFormat.DATETIME_HOUR_24)

    self.model_verbose = kwargs.get('model_verbose', 1)
    self.n_envs = kwargs.get('n_envs', os.cpu_count())
    self.n_minibatches = kwargs.get('n_minibatches', self.n_envs)
    self.train_split_percentage = kwargs.get('train_split_percentage', 0.8)
    self.data_provider = kwargs.get('data_provider', 'static')

    self.initialize_data()
    self.initialize_optuna()

    self.logger.debug(f'Initialize RLTrader: {self.study_name}')
Example 4: optimize_agent_params
# Required import: import stable_baselines [as alias]
# Or: from stable_baselines import PPO2 [as alias]
def optimize_agent_params(self, trial):
    if self.Model != PPO2:
        return {'learning_rate': trial.suggest_loguniform('learning_rate', 1e-5, 1.)}

    return {
        'n_steps': int(trial.suggest_loguniform('n_steps', 16, 2048)),
        'gamma': trial.suggest_loguniform('gamma', 0.9, 0.9999),
        'learning_rate': trial.suggest_loguniform('learning_rate', 1e-5, 1.),
        'ent_coef': trial.suggest_loguniform('ent_coef', 1e-8, 1e-1),
        'cliprange': trial.suggest_uniform('cliprange', 0.1, 0.4),
        'noptepochs': int(trial.suggest_loguniform('noptepochs', 1, 48)),
        'lam': trial.suggest_uniform('lam', 0.8, 1.)
    }
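For context, a dictionary like the one returned above is typically unpacked straight into the PPO2 constructor inside an Optuna objective. The sketch below shows that pattern in isolation, assuming a CartPole-v1 environment, a reduced parameter set, and a fixed trial budget; none of these details come from the excerpt itself.

import gym
import numpy as np
import optuna
from stable_baselines import PPO2
from stable_baselines.common.policies import MlpPolicy
from stable_baselines.common.vec_env import DummyVecEnv

def objective(trial):
    # Sample a few PPO2 hyperparameters the same way optimize_agent_params does
    params = {
        'n_steps': int(trial.suggest_loguniform('n_steps', 16, 2048)),
        'learning_rate': trial.suggest_loguniform('learning_rate', 1e-5, 1.),
        'ent_coef': trial.suggest_loguniform('ent_coef', 1e-8, 1e-1),
    }
    env = DummyVecEnv([lambda: gym.make('CartPole-v1')])
    model = PPO2(MlpPolicy, env, nminibatches=1, **params)
    model.learn(total_timesteps=5000)

    # Score the trial with a short evaluation rollout
    obs, rewards = env.reset(), []
    for _ in range(500):
        action, _ = model.predict(obs)
        obs, reward, done, _ = env.step(action)
        rewards.append(float(reward[0]))
    return float(np.mean(rewards))

study = optuna.create_study(direction='maximize')
study.optimize(objective, n_trials=10)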
Example 5: ppo2
# Required import: import stable_baselines [as alias]
# Or: from stable_baselines import PPO2 [as alias]
def ppo2(batch_size, num_env, learning_rate, **kwargs):
    return _stable(
        stable_baselines.PPO2,
        our_type="ppo2",
        callback_key="update",
        callback_mul=batch_size,
        n_steps=batch_size // num_env,
        learning_rate=learning_rate,
        **kwargs,
    )
Example 6: train
# Required import: import stable_baselines [as alias]
# Or: from stable_baselines import PPO2 [as alias]
def train(env_id, num_timesteps, seed, policy,
          n_envs=8, nminibatches=4, n_steps=128):
    """
    Train a PPO2 model on an Atari environment, for testing purposes.

    :param env_id: (str) the environment id string
    :param num_timesteps: (int) the number of timesteps to run
    :param seed: (int) Used to seed the random generator.
    :param policy: (str) The policy model to use ('mlp', 'cnn', 'lstm', 'lnlstm')
    :param n_envs: (int) Number of parallel environments
    :param nminibatches: (int) Number of training minibatches per update. For recurrent policies,
        the number of environments run in parallel should be a multiple of nminibatches.
    :param n_steps: (int) The number of steps to run for each environment per update
        (i.e. the batch size is n_steps * n_envs, where n_envs is the number of environment copies running in parallel)
    """
    env = VecFrameStack(make_atari_env(env_id, n_envs, seed), 4)
    policy = {'cnn': CnnPolicy, 'lstm': CnnLstmPolicy, 'lnlstm': CnnLnLstmPolicy, 'mlp': MlpPolicy}[policy]
    model = PPO2(policy=policy, env=env, n_steps=n_steps, nminibatches=nminibatches,
                 lam=0.95, gamma=0.99, noptepochs=4, ent_coef=.01,
                 learning_rate=lambda f: f * 2.5e-4, cliprange=lambda f: f * 0.1, verbose=1)
    model.learn(total_timesteps=num_timesteps)

    env.close()
    # Free memory
    del model
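As a usage note, this helper would typically be called with a NoFrameskip Atari environment id (as expected by make_atari_env) and one of the policy keys from the dictionary above; the specific id and timestep budget here are illustrative assumptions.

# Train a CNN policy on Breakout for a short test run
train('BreakoutNoFrameskip-v4', num_timesteps=40000, seed=0, policy='cnn')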
Example 7: test_lstm_train
# Required import: import stable_baselines [as alias]
# Or: from stable_baselines import PPO2 [as alias]
def test_lstm_train():
    """Test that LSTM models are able to achieve >=150 (out of 500) reward on CartPoleNoVelEnv.

    This environment requires memory to perform well in."""
    def make_env(i):
        env = CartPoleNoVelEnv()
        env = TimeLimit(env, max_episode_steps=500)
        env = bench.Monitor(env, None, allow_early_resets=True)
        env.seed(i)
        return env

    # Bind i as a default argument so each worker gets its own seed
    # (a bare `lambda: make_env(i)` would late-bind i and seed every env identically).
    env = SubprocVecEnv([lambda i=i: make_env(i) for i in range(NUM_ENVS)])
    env = VecNormalize(env)
    model = PPO2(MlpLstmPolicy, env, n_steps=128, nminibatches=NUM_ENVS, lam=0.95, gamma=0.99,
                 noptepochs=10, ent_coef=0.0, learning_rate=3e-4, cliprange=0.2, verbose=1)

    eprewmeans = []

    def reward_callback(local, _):
        # Track the mean reward over recent episodes each time the callback fires
        nonlocal eprewmeans
        eprewmeans.append(safe_mean([ep_info['r'] for ep_info in local['ep_info_buf']]))

    model.learn(total_timesteps=100000, callback=reward_callback)

    # Maximum episode reward is 500.
    # In CartPole-v1, a non-recurrent policy can easily get >= 450.
    # In CartPoleNoVelEnv, a non-recurrent policy doesn't get more than ~50.
    # LSTM policies can reach above 400, but it varies a lot between runs; they consistently get >= 150.
    # See PR #244 for more detailed benchmarks.
    average_reward = sum(eprewmeans[-NUM_EPISODES_FOR_SCORE:]) / NUM_EPISODES_FOR_SCORE
    assert average_reward >= 150, "Mean reward {} is below 150".format(average_reward)
Example 8: __init__
# Required import: import stable_baselines [as alias]
# Or: from stable_baselines import PPO2 [as alias]
def __init__(self):
    super(PPO2Model, self).__init__(name="ppo2", model_class=PPO2)
Example 9: load_old_ppo2
# Required import: import stable_baselines [as alias]
# Or: from stable_baselines import PPO2 [as alias]
def load_old_ppo2(root_dir, env, env_name, index, transparent_params):
    try:
        from baselines.ppo2 import ppo2 as ppo2_old
    except ImportError as e:
        msg = "{}. HINT: you need to install (OpenAI) Baselines to use old_ppo2".format(e)
        raise ImportError(msg)

    denv = FakeSingleSpacesVec(env, agent_id=index)
    possible_fnames = ["model.pkl", "final_model.pkl"]
    model_path = None
    for fname in possible_fnames:
        candidate_path = os.path.join(root_dir, fname)
        if os.path.exists(candidate_path):
            model_path = candidate_path

    if model_path is None:
        raise FileNotFoundError(
            f"Could not find model at '{root_dir}' under any filename '{possible_fnames}'"
        )

    graph = tf.Graph()
    sess = tf.Session(graph=graph)
    with sess.as_default():
        with graph.as_default():
            pylog.info(f"Loading Baselines PPO2 policy from '{model_path}'")
            policy = ppo2_old.learn(
                network="mlp",
                env=denv,
                total_timesteps=1,
                seed=0,
                nminibatches=4,
                log_interval=1,
                save_interval=1,
                load_path=model_path,
            )

    stable_policy = OpenAIToStablePolicy(
        policy, ob_space=denv.observation_space, ac_space=denv.action_space
    )
    model = PolicyToModel(stable_policy)

    try:
        normalize_path = os.path.join(root_dir, "normalize.pkl")
        with open(normalize_path, "rb") as f:
            old_vec_normalize = pickle.load(f)
        vec_normalize = vec_env.VecNormalize(denv, training=False)
        vec_normalize.obs_rms = old_vec_normalize.ob_rms
        vec_normalize.ret_rms = old_vec_normalize.ret_rms
        model = NormalizeModel(model, vec_normalize)
        pylog.info(f"Loaded normalization statistics from '{normalize_path}'")
    except FileNotFoundError:
        # We did not use VecNormalize during training, skip
        pass

    return model