本文整理汇总了Python中baselines.common.atari_wrappers.wrap_deepmind方法的典型用法代码示例。如果您正苦于以下问题：Python atari_wrappers.wrap_deepmind方法的具体用法？Python atari_wrappers.wrap_deepmind怎么用？Python atari_wrappers.wrap_deepmind使用的例子？那么恭喜您，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块baselines.common.atari_wrappers的用法示例。
在下文中一共展示了atari_wrappers.wrap_deepmind方法的13个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: make_env
# 需要导入模块: from baselines.common import atari_wrappers [as 别名]
# 或者: from baselines.common.atari_wrappers import wrap_deepmind [as 别名]
def make_env(env_id, env_type, subrank=0, seed=None, reward_scale=1.0, gamestate=None, wrapper_kwargs=None):
    """Create one monitored environment, applying the DeepMind wrappers for Atari/Retro.

    Args:
        env_id: id passed to ``make_atari`` (Atari), ``retro_wrappers.make_retro``
            (Retro) or ``gym.make`` (any other ``env_type``).
        env_type: ``'atari'``, ``'retro'``, or any other value for a plain ``gym.make``.
        subrank: index of this env inside the current MPI process; added to ``seed``
            and used in the Monitor path ``"<mpi_rank>.<subrank>"``.
        seed: base seed; when ``None``, ``env.seed(None)`` is called.
        reward_scale: when not 1, the env is wrapped in ``retro_wrappers.RewardScaler``
            (for every ``env_type``).
        gamestate: Retro start state; falls back to ``retro.State.DEFAULT``.
        wrapper_kwargs: keyword arguments forwarded to ``wrap_deepmind`` /
            ``wrap_deepmind_retro``. ``None`` (the default) means no extra arguments.

    Returns:
        The wrapped environment.
    """
    # ``None`` instead of a ``{}`` default: a dict default is created once at
    # definition time and shared by every call, so any mutation would leak
    # between environments.
    if wrapper_kwargs is None:
        wrapper_kwargs = {}
    mpi_rank = MPI.COMM_WORLD.Get_rank() if MPI else 0
    if env_type == 'atari':
        env = make_atari(env_id)
    elif env_type == 'retro':
        import retro
        gamestate = gamestate or retro.State.DEFAULT
        env = retro_wrappers.make_retro(game=env_id, max_episode_steps=10000, use_restricted_actions=retro.Actions.DISCRETE, state=gamestate)
    else:
        env = gym.make(env_id)
    env.seed(seed + subrank if seed is not None else None)
    env = Monitor(env,
                  logger.get_dir() and os.path.join(logger.get_dir(), str(mpi_rank) + '.' + str(subrank)),
                  allow_early_resets=True)
    if env_type == 'atari':
        env = wrap_deepmind(env, **wrapper_kwargs)
    elif env_type == 'retro':
        env = retro_wrappers.wrap_deepmind_retro(env, **wrapper_kwargs)
    if reward_scale != 1:
        env = retro_wrappers.RewardScaler(env, reward_scale)
    return env
示例2: make_atari_env
# 需要导入模块: from baselines.common import atari_wrappers [as 别名]
# 或者: from baselines.common.atari_wrappers import wrap_deepmind [as 别名]
def make_atari_env(env_id, num_env, seed, hparams=None, wrapper_kwargs=None, start_index=0, nsteps=5, **kwargs):
    """
    Create a wrapped, monitored SubprocVecEnv for Atari.

    The new vector env is stored in the module-level global ``my_subproc_vec_env``
    and returned; the function refuses to run if that global is already set.

    Args:
        env_id: Atari id passed to ``make_atari``.
        num_env: number of subprocess environments.
        seed: base seed; sub-environment ``rank`` is seeded with ``seed + rank``.
        hparams, nsteps: forwarded to ``VideoLogMonitor``.
        wrapper_kwargs: keyword arguments for ``wrap_deepmind`` (``None`` -> ``{}``).
        start_index: rank of the first sub-environment.
        **kwargs: if ``video_log_dir`` is present, the env with rank ``start_index``
            is wrapped in ``VideoLogMonitor`` writing to ``video_log_dir + '_rgb'``;
            ``write_attention_video`` must then also be supplied.

    Raises:
        AssertionError: if ``my_subproc_vec_env`` is already set.
    """
    global my_subproc_vec_env
    # Guard first, before reseeding or building thunks. Raised explicitly rather
    # than via ``assert`` so it still fires under ``python -O``; the exception
    # type matches what the former ``assert`` raised. ``is`` (not ``==``) is the
    # correct test for None.
    if my_subproc_vec_env is not None:
        raise AssertionError("make_atari_env() was already called: my_subproc_vec_env is set")
    if wrapper_kwargs is None:
        wrapper_kwargs = {}

    def make_env(rank):  # pylint: disable=C0111
        def _thunk():
            env = make_atari(env_id)
            env.seed(seed + rank)
            env = Monitor(env, logger.get_dir() and os.path.join(logger.get_dir(), str(rank)))
            if rank == start_index and 'video_log_dir' in kwargs:
                env = VideoLogMonitor(env, kwargs['video_log_dir'] + '_rgb', write_attention_video=kwargs['write_attention_video'], hparams=hparams, nsteps=nsteps)
            return wrap_deepmind(env, **wrapper_kwargs)
        return _thunk

    set_global_seeds(seed)
    env_fns = [make_env(i + start_index) for i in range(num_env)]
    my_subproc_vec_env = SubprocVecEnv(env_fns)
    return my_subproc_vec_env
示例3: train
# 需要导入模块: from baselines.common import atari_wrappers [as 别名]
# 或者: from baselines.common.atari_wrappers import wrap_deepmind [as 别名]
def train(env_id, num_timesteps, seed, policy, lrschedule, num_cpu):
    """Train ``policy`` with ``learn`` on ``num_cpu`` subprocess Atari environments.

    Args:
        env_id: Atari id passed to ``make_atari``.
        num_timesteps: training budget; ``learn`` receives ``int(num_timesteps * 1.1)``.
        seed: global seed; subprocess env ``rank`` is seeded with ``seed + rank``.
        policy: ``'cnn'``, ``'lstm'`` or ``'lnlstm'``.
        lrschedule: learning-rate schedule forwarded to ``learn``.
        num_cpu: number of subprocess environments.

    Raises:
        ValueError: if ``policy`` is not one of the supported names.
    """
    # Validate before seeding or spawning subprocesses. An unknown name used to
    # leave ``policy_fn`` unbound, raising UnboundLocalError only after the
    # SubprocVecEnv had been started (and that env was never closed).
    supported = ('cnn', 'lstm', 'lnlstm')
    if policy not in supported:
        raise ValueError("Unknown policy {!r}; expected one of {}".format(policy, supported))

    def make_env(rank):
        def _thunk():
            env = make_atari(env_id)
            env.seed(seed + rank)
            env = bench.Monitor(env, logger.get_dir() and os.path.join(logger.get_dir(), str(rank)))
            gym.logger.setLevel(logging.WARN)
            return wrap_deepmind(env)
        return _thunk

    set_global_seeds(seed)
    policy_fn = {'cnn': CnnPolicy, 'lstm': LstmPolicy, 'lnlstm': LnLstmPolicy}[policy]
    env = SubprocVecEnv([make_env(i) for i in range(num_cpu)])
    try:
        # Train slightly past the nominal budget (x1.1).
        learn(policy_fn, env, seed, total_timesteps=int(num_timesteps * 1.1), lrschedule=lrschedule)
    finally:
        # Always shut the worker subprocesses down, even if training fails.
        env.close()
示例4: train
# 需要导入模块: from baselines.common import atari_wrappers [as 别名]
# 或者: from baselines.common.atari_wrappers import wrap_deepmind [as 别名]
def train(env_id, num_timesteps, seed, policy, lrschedule, num_cpu):
    """Train with ``learn`` using ``AcerCnnPolicy`` or ``AcerLstmPolicy`` on subprocess Atari envs.

    Args:
        env_id: Atari id passed to ``make_atari``.
        num_timesteps: training budget; ``learn`` receives ``int(num_timesteps * 1.1)``.
        seed: global seed; subprocess env ``rank`` is seeded with ``seed + rank``.
        policy: ``'cnn'`` or ``'lstm'``.
        lrschedule: learning-rate schedule forwarded to ``learn``.
        num_cpu: number of subprocess environments.

    Returns:
        None. For an unsupported ``policy`` a message is printed and nothing is trained.
    """
    # Check the policy before seeding or spawning subprocesses: previously the
    # unsupported path returned after creating the SubprocVecEnv and never
    # closed it, leaking its worker processes.
    if policy not in ('cnn', 'lstm'):
        print("Policy {} not implemented".format(policy))
        return
    policy_fn = AcerCnnPolicy if policy == 'cnn' else AcerLstmPolicy

    def make_env(rank):
        def _thunk():
            env = make_atari(env_id)
            env.seed(seed + rank)
            env = bench.Monitor(env, logger.get_dir() and os.path.join(logger.get_dir(), str(rank)))
            gym.logger.setLevel(logging.WARN)
            return wrap_deepmind(env)
        return _thunk

    set_global_seeds(seed)
    env = SubprocVecEnv([make_env(i) for i in range(num_cpu)])
    try:
        # Train slightly past the nominal budget (x1.1).
        learn(policy_fn, env, seed, total_timesteps=int(num_timesteps * 1.1), lrschedule=lrschedule)
    finally:
        env.close()
示例5: train
# 需要导入模块: from baselines.common import atari_wrappers [as 别名]
# 或者: from baselines.common.atari_wrappers import wrap_deepmind [as 别名]
def train(env_id, num_frames, seed, policy, lrschedule, num_cpu):
    """Train ``policy`` with ``learn`` on ``num_cpu`` subprocess environments, budgeted in frames.

    Args:
        env_id: id passed to ``gym.make``.
        num_frames: frame budget, converted to ``int(num_frames / 4 * 1.1)`` timesteps.
        seed: global seed; subprocess env ``rank`` is seeded with ``seed + rank``.
        policy: ``'cnn'``, ``'linear'``, ``'lstm'`` or ``'lnlstm'``.
        lrschedule: learning-rate schedule forwarded to ``learn``.
        num_cpu: number of subprocess environments.

    Raises:
        ValueError: if ``policy`` is not one of the supported names.
    """
    # Validate before seeding or spawning subprocesses; an unknown name used to
    # leave ``policy_fn`` unbound (UnboundLocalError) after the envs were started.
    supported = ('cnn', 'linear', 'lstm', 'lnlstm')
    if policy not in supported:
        raise ValueError("Unknown policy {!r}; expected one of {}".format(policy, supported))

    # Divide by 4 due to frameskip, then add a little extra so episodes can end.
    num_timesteps = int(num_frames / 4 * 1.1)

    def make_env(rank):
        def _thunk():
            env = gym.make(env_id)
            env.seed(seed + rank)
            env = bench.Monitor(env, logger.get_dir() and
                                os.path.join(logger.get_dir(), "{}.monitor.json".format(rank)))
            gym.logger.setLevel(logging.WARN)
            return wrap_deepmind(env)
        return _thunk

    set_global_seeds(seed)
    policy_fn = {
        'cnn': CnnPolicy,
        'linear': LinearPolicy,
        'lstm': LstmPolicy,
        'lnlstm': LnLstmPolicy,
    }[policy]
    env = SubprocVecEnv([make_env(i) for i in range(num_cpu)])
    try:
        learn(policy_fn, env, seed, total_timesteps=num_timesteps, lrschedule=lrschedule)
    finally:
        env.close()
示例6: wrap_atari_dqn
# 需要导入模块: from baselines.common import atari_wrappers [as 别名]
# 或者: from baselines.common.atari_wrappers import wrap_deepmind [as 别名]
def wrap_atari_dqn(env):
    """Wrap ``env`` with ``wrap_deepmind`` for DQN-style training.

    ``frame_stack`` and ``scale`` are switched on explicitly; every other
    ``wrap_deepmind`` option keeps its library default.
    """
    # Imported lazily so the module can be loaded without baselines' Atari wrappers.
    from baselines.common import atari_wrappers
    return atari_wrappers.wrap_deepmind(env, scale=True, frame_stack=True)
示例7: train
# 需要导入模块: from baselines.common import atari_wrappers [as 别名]
# 或者: from baselines.common.atari_wrappers import wrap_deepmind [as 别名]
def train(env_id, num_timesteps, seed):
    """Train a ``nosharing_cnn_policy.CnnPolicy`` with ``trpo_mpi`` on one Atari env per MPI worker.

    Rank 0 configures the logger with its default formats; every other rank
    calls ``logger.configure(format_strs=[])``. Each worker is seeded with
    ``seed + 10000 * rank``.

    Args:
        env_id: Atari id passed to ``make_atari``.
        num_timesteps: training budget; ``trpo_mpi.learn`` receives ``int(num_timesteps * 1.1)``.
        seed: base seed.
    """
    from baselines.trpo_mpi.nosharing_cnn_policy import CnnPolicy
    from baselines.trpo_mpi import trpo_mpi
    import baselines.common.tf_util as U
    rank = MPI.COMM_WORLD.Get_rank()
    # The session is entered here and not exited within this function.
    sess = U.single_threaded_session()
    sess.__enter__()
    if rank == 0:
        logger.configure()
    else:
        logger.configure(format_strs=[])
    workerseed = seed + 10000 * rank
    set_global_seeds(workerseed)
    env = make_atari(env_id)

    def policy_fn(name, ob_space, ac_space):
        # Build from the spaces the learner passes in rather than reading ``env``
        # through the closure: ``env`` is rebound below, so the closure's value
        # depends on *when* the learner calls this function.
        return CnnPolicy(name=name, ob_space=ob_space, ac_space=ac_space)

    env = bench.Monitor(env, logger.get_dir() and osp.join(logger.get_dir(), str(rank)))
    env.seed(workerseed)
    env = wrap_deepmind(env)
    env.seed(workerseed)
    try:
        trpo_mpi.learn(env, policy_fn, timesteps_per_batch=512, max_kl=0.001, cg_iters=10, cg_damping=1e-3,
                       max_timesteps=int(num_timesteps * 1.1), gamma=0.98, lam=1.0, vf_iters=3, vf_stepsize=1e-4, entcoeff=0.00)
    finally:
        env.close()
示例8: train
# 需要导入模块: from baselines.common import atari_wrappers [as 别名]
# 或者: from baselines.common.atari_wrappers import wrap_deepmind [as 别名]
def train(env_id, num_timesteps, seed):
    """Run ``pposgd_simple`` with ``cnn_policy.CnnPolicy`` on one Atari env per MPI worker.

    Rank 0 configures the logger with its default formats; other ranks call
    ``logger.configure(format_strs=[])``. Each worker is seeded with
    ``seed + 10000 * rank``, both before and after ``wrap_deepmind``.
    """
    from baselines.ppo1 import pposgd_simple, cnn_policy
    import baselines.common.tf_util as U

    rank = MPI.COMM_WORLD.Get_rank()
    # Entered here and not exited within this function.
    sess = U.single_threaded_session()
    sess.__enter__()

    if rank != 0:
        logger.configure(format_strs=[])
    else:
        logger.configure()

    workerseed = seed + 10000 * rank
    set_global_seeds(workerseed)

    env = make_atari(env_id)

    def policy_fn(name, ob_space, ac_space):  # pylint: disable=W0613
        return cnn_policy.CnnPolicy(name=name, ob_space=ob_space, ac_space=ac_space)

    log_dir = logger.get_dir()
    env = bench.Monitor(env, log_dir and osp.join(log_dir, str(rank)))
    env.seed(workerseed)
    env = wrap_deepmind(env)
    env.seed(workerseed)

    ppo_hyperparams = dict(
        max_timesteps=int(num_timesteps * 1.1),
        timesteps_per_actorbatch=256,
        clip_param=0.2,
        entcoeff=0.01,
        optim_epochs=4,
        optim_stepsize=1e-3,
        optim_batchsize=64,
        gamma=0.99,
        lam=0.95,
        schedule='linear',
    )
    pposgd_simple.learn(env, policy_fn, **ppo_hyperparams)
    env.close()
示例9: make_atari_env
# 需要导入模块: from baselines.common import atari_wrappers [as 别名]
# 或者: from baselines.common.atari_wrappers import wrap_deepmind [as 别名]
def make_atari_env(env_id, num_env, seed, wrapper_kwargs=None, start_index=0):
    """
    Create a wrapped, monitored SubprocVecEnv for Atari.

    Sub-environment ``rank`` (from ``start_index`` to ``start_index + num_env - 1``)
    is seeded with ``seed + rank`` and monitored under ``str(rank)`` in the logger
    directory. ``wrapper_kwargs`` (``None`` -> ``{}``) is forwarded to ``wrap_deepmind``.
    """
    deepmind_kwargs = {} if wrapper_kwargs is None else wrapper_kwargs

    def env_factory(rank):  # pylint: disable=C0111
        def _init():
            env = make_atari(env_id)
            env.seed(seed + rank)
            log_dir = logger.get_dir()
            env = Monitor(env, log_dir and os.path.join(log_dir, str(rank)))
            return wrap_deepmind(env, **deepmind_kwargs)
        return _init

    set_global_seeds(seed)
    ranks = range(start_index, start_index + num_env)
    return SubprocVecEnv([env_factory(rank) for rank in ranks])
示例10: make_env
# 需要导入模块: from baselines.common import atari_wrappers [as 别名]
# 或者: from baselines.common.atari_wrappers import wrap_deepmind [as 别名]
def make_env(env_name, rank, seed):
    """Build one Atari env, seed it with ``seed + rank`` and apply ``wrap_deepmind`` defaults."""
    raw_env = make_atari(env_name)
    raw_env.seed(seed + rank)
    return wrap_deepmind(raw_env)
示例11: train
# 需要导入模块: from baselines.common import atari_wrappers [as 别名]
# 或者: from baselines.common.atari_wrappers import wrap_deepmind [as 别名]
def train(env_id, num_timesteps, seed):
    """Run ``pposgd_simple`` with ``cnn_policy.CnnPolicy`` on one Atari env per MPI worker.

    Rank 0 configures the logger with its defaults; other ranks call
    ``logger.configure(format_strs=[])``. Worker seeds are ``seed + 10000 * rank``,
    or ``None`` everywhere when ``seed`` is None.
    """
    from baselines.ppo1 import pposgd_simple, cnn_policy
    import baselines.common.tf_util as U

    rank = MPI.COMM_WORLD.Get_rank()
    # Entered here and not exited within this function.
    sess = U.single_threaded_session()
    sess.__enter__()

    # Non-zero ranks get no logger output formats.
    configure_kwargs = {} if rank == 0 else {'format_strs': []}
    logger.configure(**configure_kwargs)

    workerseed = None if seed is None else seed + 10000 * rank
    set_global_seeds(workerseed)

    env = make_atari(env_id)

    def policy_fn(name, ob_space, ac_space):  # pylint: disable=W0613
        return cnn_policy.CnnPolicy(name=name, ob_space=ob_space, ac_space=ac_space)

    log_dir = logger.get_dir()
    monitor_path = log_dir and osp.join(log_dir, str(rank))
    env = bench.Monitor(env, monitor_path)
    env.seed(workerseed)
    env = wrap_deepmind(env)
    env.seed(workerseed)

    pposgd_simple.learn(
        env,
        policy_fn,
        max_timesteps=int(num_timesteps * 1.1),
        timesteps_per_actorbatch=256,
        clip_param=0.2,
        entcoeff=0.01,
        optim_epochs=4,
        optim_stepsize=1e-3,
        optim_batchsize=64,
        gamma=0.99,
        lam=0.95,
        schedule='linear',
    )
    env.close()
示例12: make_atari_env
# 需要导入模块: from baselines.common import atari_wrappers [as 别名]
# 或者: from baselines.common.atari_wrappers import wrap_deepmind [as 别名]
def make_atari_env(env_id, num_env, seed, wrapper_kwargs=None, start_index=0):
    """
    Create a wrapped, monitored SubprocVecEnv for Atari.

    Sub-environment ``rank`` (from ``start_index`` to ``start_index + num_env - 1``)
    is seeded with ``seed + 10000 * mpi_rank + rank`` (``None`` when ``seed`` is None)
    and monitored under ``"<mpi_rank>.<rank>"`` in the logger directory.
    ``wrapper_kwargs`` (``None`` -> ``{}``) is forwarded to ``wrap_deepmind``.
    """
    deepmind_kwargs = wrapper_kwargs if wrapper_kwargs is not None else {}
    mpi_rank = MPI.COMM_WORLD.Get_rank() if MPI else 0

    def env_factory(rank):  # pylint: disable=C0111
        def _init():
            env = make_atari(env_id)
            if seed is None:
                env.seed(None)
            else:
                env.seed(seed + 10000 * mpi_rank + rank)
            log_dir = logger.get_dir()
            monitor_path = log_dir and os.path.join(log_dir, '{}.{}'.format(mpi_rank, rank))
            env = Monitor(env, monitor_path)
            return wrap_deepmind(env, **deepmind_kwargs)
        return _init

    set_global_seeds(seed)
    factories = [env_factory(rank) for rank in range(start_index, start_index + num_env)]
    return SubprocVecEnv(factories)
示例13: make_env
# 需要导入模块: from baselines.common import atari_wrappers [as 别名]
# 或者: from baselines.common.atari_wrappers import wrap_deepmind [as 别名]
def make_env(env_id, seed, rank, episode_life=True):
    """Return a thunk that builds one environment, with DeepMind preprocessing for Atari.

    Args:
        env_id: ids starting with ``"dm"`` are split as ``"<prefix>-<domain>-<task>"``
            and built with ``dm_control2gym``; anything else goes through ``gym.make``.
        seed: passed to ``random_seed``; the env itself is seeded with ``seed + rank``.
        rank: offset added to ``seed`` for this env.
        episode_life: forwarded to ``wrap_deepmind`` for Atari envs.

    Returns:
        A zero-argument callable that creates and returns the wrapped env.
    """
    def _thunk():
        random_seed(seed)
        if env_id.startswith("dm"):
            import dm_control2gym
            _, domain, task = env_id.split('-')
            env = dm_control2gym.make(domain_name=domain, task_name=task)
        else:
            env = gym.make(env_id)
        is_atari = hasattr(gym.envs, 'atari') and isinstance(
            env.unwrapped, gym.envs.atari.atari_env.AtariEnv)
        if is_atari:
            # The env built above only served to detect Atari; close it explicitly
            # before replacing it instead of relying on garbage collection.
            env.close()
            env = make_atari(env_id)
        env.seed(seed + rank)
        env = OriginalReturnWrapper(env)
        if is_atari:
            env = wrap_deepmind(env,
                                episode_life=episode_life,
                                clip_rewards=False,
                                frame_stack=False,
                                scale=False)
            # NOTE(review): the source this came from had lost its indentation;
            # here FrameStack is applied to every Atari env, whether or not the
            # observation was transposed — confirm against upstream.
            obs_shape = env.observation_space.shape
            if len(obs_shape) == 3:
                env = TransposeImage(env)
            env = FrameStack(env, 4)
        return env
    return _thunk