本文整理汇总了Python中baselines.trpo_mpi.trpo_mpi.learn方法的典型用法代码示例。如果您正苦于以下问题：Python trpo_mpi.learn方法的具体用法？Python trpo_mpi.learn怎么用？Python trpo_mpi.learn使用的例子？那么恭喜您，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块baselines.trpo_mpi.trpo_mpi的用法示例。
在下文中一共展示了trpo_mpi.learn方法的9个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: train
# 需要导入模块: from baselines.trpo_mpi import trpo_mpi [as 别名]
# 或者: from baselines.trpo_mpi.trpo_mpi import learn [as 别名]
def train(env_id, num_timesteps, seed):
    """Run TRPO with a two-layer MLP policy on the environment `env_id`.

    Args:
        env_id: Environment identifier handed to `make_mujoco_env`.
        num_timesteps: Forwarded to `trpo_mpi.learn` as `max_timesteps`.
        seed: Base seed; `seed + 10000 * rank` is passed to
            `make_mujoco_env` so each MPI rank uses a different value.
    """
    import baselines.common.tf_util as U

    # The session context is entered by hand and is not exited here.
    tf_session = U.single_threaded_session()
    tf_session.__enter__()

    mpi_rank = MPI.COMM_WORLD.Get_rank()
    if mpi_rank != 0:
        # Non-root ranks: no output formats and logging disabled.
        logger.configure(format_strs=[])
        logger.set_level(logger.DISABLED)
    else:
        logger.configure()
    worker_seed = seed + 10000 * mpi_rank

    def build_policy(name, ob_space, ac_space):
        return MlpPolicy(
            name=name,
            ob_space=ob_space,
            ac_space=ac_space,
            hid_size=32,
            num_hid_layers=2,
        )

    env = make_mujoco_env(env_id, worker_seed)
    trpo_options = dict(
        timesteps_per_batch=1024,
        max_kl=0.01,
        cg_iters=10,
        cg_damping=0.1,
        max_timesteps=num_timesteps,
        gamma=0.99,
        lam=0.98,
        vf_iters=5,
        vf_stepsize=1e-3,
    )
    trpo_mpi.learn(env, build_policy, **trpo_options)
    env.close()
示例2: train
# 需要导入模块: from baselines.trpo_mpi import trpo_mpi [as 别名]
# 或者: from baselines.trpo_mpi.trpo_mpi import learn [as 别名]
def train(env_id, num_timesteps, seed):
    """Fork MPI workers, then run TRPO with an MLP policy on `env_id`.

    The call returns immediately in the process for which `mpi_fork`
    reports "parent"; all remaining work happens in the other processes.

    Args:
        env_id: Environment identifier passed to `gym.make`.
        num_timesteps: Forwarded to `trpo_mpi.learn` as `max_timesteps`.
        seed: Base seed; each MPI rank uses `seed + 10000 * rank`.
    """
    # NOTE(review): `num_cpu` is neither a parameter nor a local of this
    # function, so it has to come from the enclosing module -- confirm it
    # is defined there, otherwise this line raises NameError.
    if mpi_fork(num_cpu) == "parent":
        return

    import baselines.common.tf_util as U

    # Both contexts are entered by hand and are not exited here.
    logger.session().__enter__()
    tf_session = U.single_threaded_session()
    tf_session.__enter__()

    mpi_rank = MPI.COMM_WORLD.Get_rank()
    if mpi_rank != 0:
        logger.set_level(logger.DISABLED)

    worker_seed = seed + 10000 * mpi_rank
    set_global_seeds(worker_seed)
    env = gym.make(env_id)

    def build_policy(name, ob_space, ac_space):
        # Reads the spaces from the enclosing `env` at call time and ignores
        # the `ob_space` / `ac_space` arguments.
        return MlpPolicy(
            name=name,
            ob_space=env.observation_space,
            ac_space=env.action_space,
            hid_size=32,
            num_hid_layers=2,
        )

    monitor_path = osp.join(logger.get_dir(), "%i.monitor.json" % mpi_rank)
    env = bench.Monitor(env, monitor_path)
    env.seed(worker_seed)
    gym.logger.setLevel(logging.WARN)

    trpo_options = dict(
        timesteps_per_batch=1024,
        max_kl=0.01,
        cg_iters=10,
        cg_damping=0.1,
        max_timesteps=num_timesteps,
        gamma=0.99,
        lam=0.98,
        vf_iters=5,
        vf_stepsize=1e-3,
    )
    trpo_mpi.learn(env, build_policy, **trpo_options)
    env.close()
示例3: train
# 需要导入模块: from baselines.trpo_mpi import trpo_mpi [as 别名]
# 或者: from baselines.trpo_mpi.trpo_mpi import learn [as 别名]
def train(env_id, num_timesteps, seed):
    """Run TRPO with an MLP policy on a monitored gym environment.

    Args:
        env_id: Environment identifier passed to `gym.make`.
        num_timesteps: Forwarded to `trpo_mpi.learn` as `max_timesteps`.
        seed: Base seed; each MPI rank uses `seed + 10000 * rank`.
    """
    import baselines.common.tf_util as U

    # The session context is entered by hand and is not exited here.
    tf_session = U.single_threaded_session()
    tf_session.__enter__()

    mpi_rank = MPI.COMM_WORLD.Get_rank()
    if mpi_rank != 0:
        logger.set_level(logger.DISABLED)

    worker_seed = seed + 10000 * mpi_rank
    set_global_seeds(worker_seed)
    env = gym.make(env_id)

    def build_policy(name, ob_space, ac_space):
        # Reads the spaces from the enclosing `env` at call time and ignores
        # the `ob_space` / `ac_space` arguments.
        return MlpPolicy(
            name=name,
            ob_space=env.observation_space,
            ac_space=env.action_space,
            hid_size=32,
            num_hid_layers=2,
        )

    # When the logger has no directory, that falsy value is handed to the
    # monitor unchanged instead of a joined path.
    log_dir = logger.get_dir()
    monitor_path = log_dir and osp.join(log_dir, str(mpi_rank))
    env = bench.Monitor(env, monitor_path)
    env.seed(worker_seed)
    gym.logger.setLevel(logging.WARN)

    trpo_options = dict(
        timesteps_per_batch=1024,
        max_kl=0.01,
        cg_iters=10,
        cg_damping=0.1,
        max_timesteps=num_timesteps,
        gamma=0.99,
        lam=0.98,
        vf_iters=5,
        vf_stepsize=1e-3,
    )
    trpo_mpi.learn(env, build_policy, **trpo_options)
    env.close()
示例4: train
# 需要导入模块: from baselines.trpo_mpi import trpo_mpi [as 别名]
# 或者: from baselines.trpo_mpi.trpo_mpi import learn [as 别名]
def train(env_id, num_timesteps, seed):
    """Run TRPO with a CNN policy on an Atari environment.

    Args:
        env_id: Environment identifier passed to `make_atari`.
        num_timesteps: Scaled by 1.1, truncated to int, and forwarded to
            `trpo_mpi.learn` as `max_timesteps`.
        seed: Base seed; each MPI rank uses `seed + 10000 * rank`.
    """
    from baselines.trpo_mpi.nosharing_cnn_policy import CnnPolicy
    from baselines.trpo_mpi import trpo_mpi
    import baselines.common.tf_util as U

    mpi_rank = MPI.COMM_WORLD.Get_rank()

    # The session context is entered by hand and is not exited here.
    tf_session = U.single_threaded_session()
    tf_session.__enter__()

    if mpi_rank != 0:
        # Non-root ranks configure the logger with no output formats.
        logger.configure(format_strs=[])
    else:
        logger.configure()

    worker_seed = seed + 10000 * mpi_rank
    set_global_seeds(worker_seed)
    env = make_atari(env_id)

    def build_policy(name, ob_space, ac_space):  # pylint: disable=W0613
        # Reads the spaces from the enclosing `env` at call time and ignores
        # the `ob_space` / `ac_space` arguments.
        return CnnPolicy(
            name=name,
            ob_space=env.observation_space,
            ac_space=env.action_space,
        )

    log_dir = logger.get_dir()
    env = bench.Monitor(env, log_dir and osp.join(log_dir, str(mpi_rank)))
    env.seed(worker_seed)

    # The seed is applied again after wrapping, as in the original sequence.
    env = wrap_deepmind(env)
    env.seed(worker_seed)

    trpo_options = dict(
        timesteps_per_batch=512,
        max_kl=0.001,
        cg_iters=10,
        cg_damping=1e-3,
        max_timesteps=int(num_timesteps * 1.1),
        gamma=0.98,
        lam=1.0,
        vf_iters=3,
        vf_stepsize=1e-4,
        entcoeff=0.00,
    )
    trpo_mpi.learn(env, build_policy, **trpo_options)
    env.close()
示例5: train
# 需要导入模块: from baselines.trpo_mpi import trpo_mpi [as 别名]
# 或者: from baselines.trpo_mpi.trpo_mpi import learn [as 别名]
def train(env_id, num_timesteps, seed, num_cpu):
    """Fork MPI workers, then run TRPO with a CNN policy on `env_id`.

    The call returns immediately in the process for which `mpi_fork`
    reports "parent"; all remaining work happens in the other processes.

    Args:
        env_id: Environment identifier passed to `gym.make`.
        num_timesteps: Divided by 4 before being forwarded to
            `trpo_mpi.learn` as `max_timesteps`.
        seed: Base seed; each MPI rank uses `seed + 10000 * rank`.
        num_cpu: Passed to `mpi_fork`.
    """
    from baselines.trpo_mpi.nosharing_cnn_policy import CnnPolicy
    from baselines.trpo_mpi import trpo_mpi
    import baselines.common.tf_util as U

    if mpi_fork(num_cpu) == "parent":
        return

    mpi_rank = MPI.COMM_WORLD.Get_rank()

    # Both contexts are entered by hand and are not exited here.
    tf_session = U.single_threaded_session()
    tf_session.__enter__()
    logger.session().__enter__()
    if mpi_rank != 0:
        logger.set_level(logger.DISABLED)

    worker_seed = seed + 10000 * mpi_rank
    set_global_seeds(worker_seed)
    env = gym.make(env_id)

    def build_policy(name, ob_space, ac_space):  # pylint: disable=W0613
        # Reads the spaces from the enclosing `env` at call time and ignores
        # the `ob_space` / `ac_space` arguments.
        return CnnPolicy(
            name=name,
            ob_space=env.observation_space,
            ac_space=env.action_space,
        )

    monitor_path = osp.join(logger.get_dir(), "%i.monitor.json" % mpi_rank)
    env = bench.Monitor(env, monitor_path)
    env.seed(worker_seed)
    gym.logger.setLevel(logging.WARN)

    env = wrap_train(env)
    # because we're wrapping the envs to do frame skip
    num_timesteps = num_timesteps / 4
    env.seed(worker_seed)

    trpo_options = dict(
        timesteps_per_batch=512,
        max_kl=0.001,
        cg_iters=10,
        cg_damping=1e-3,
        max_timesteps=num_timesteps,
        gamma=0.98,
        lam=1.0,
        vf_iters=3,
        vf_stepsize=1e-4,
        entcoeff=0.00,
    )
    trpo_mpi.learn(env, build_policy, **trpo_options)
    env.close()
示例6: train
# 需要导入模块: from baselines.trpo_mpi import trpo_mpi [as 别名]
# 或者: from baselines.trpo_mpi.trpo_mpi import learn [as 别名]
def _require_env_var(name):
    """Return the value of environment variable `name`, exiting if unset.

    Prints a diagnostic naming the missing variable and terminates the
    process with exit status 1 when the variable is not defined.
    """
    import sys

    value = os.getenv(name)
    if value is None:
        print('Environment variable {} is not defined'.format(name))
        # The `os` module has no `exit()` function (only `os._exit`), so the
        # previous `os.exit()` call raised AttributeError instead of exiting.
        sys.exit(1)
    return value


def train(env_id, num_timesteps, seed):
    """Run TRPO with a two-layer MLP policy on an EnergyPlus environment.

    Creates a timestamped log directory (with an `output` subdirectory)
    under `energyplus_logbase_dir()`, exports it as `ENERGYPLUS_LOG`, and
    requires `ENERGYPLUS_MODEL` and `ENERGYPLUS_WEATHER` to be set.

    Args:
        env_id: Environment identifier passed to `make_energyplus_env`.
        num_timesteps: Forwarded to `trpo_mpi.learn` as `max_timesteps`.
        seed: Base seed; `seed + 10000 * rank` is passed to
            `make_energyplus_env`.

    Raises:
        SystemExit: If `ENERGYPLUS_MODEL` or `ENERGYPLUS_WEATHER` is unset.
    """
    import baselines.common.tf_util as U

    # The session context is entered by hand and is not exited here.
    sess = U.single_threaded_session()
    sess.__enter__()

    rank = MPI.COMM_WORLD.Get_rank()
    workerseed = seed + 10000 * rank

    def policy_fn(name, ob_space, ac_space):
        return MlpPolicy(name=name, ob_space=ob_space, ac_space=ac_space,
                         hid_size=32, num_hid_layers=2)

    # Create a new base directory like /tmp/openai-2018-05-21-12-27-22-552435
    log_dir = os.path.join(
        energyplus_logbase_dir(),
        datetime.datetime.now().strftime("openai-%Y-%m-%d-%H-%M-%S-%f"),
    )
    # exist_ok avoids the check-then-create race of the exists()/makedirs() pair.
    os.makedirs(os.path.join(log_dir, 'output'), exist_ok=True)
    os.environ["ENERGYPLUS_LOG"] = log_dir

    # Only presence is checked here; the values themselves are not used in
    # this function.
    _require_env_var('ENERGYPLUS_MODEL')
    _require_env_var('ENERGYPLUS_WEATHER')

    if rank == 0:
        print('train: init logger with dir={}'.format(log_dir))
        logger.configure(log_dir)
    else:
        # Non-root ranks: no output formats and logging disabled.
        logger.configure(format_strs=[])
        logger.set_level(logger.DISABLED)

    env = make_energyplus_env(env_id, workerseed)
    trpo_mpi.learn(env, policy_fn,
                   max_timesteps=num_timesteps,
                   timesteps_per_batch=16*1024, max_kl=0.01, cg_iters=10, cg_damping=0.1,
                   gamma=0.99, lam=0.98, vf_iters=5, vf_stepsize=1e-3)
    env.close()
示例7: train
# 需要导入模块: from baselines.trpo_mpi import trpo_mpi [as 别名]
# 或者: from baselines.trpo_mpi.trpo_mpi import learn [as 别名]
def train(env_id, num_timesteps, seed):
    """Run TRPO with a CNN policy on an Atari environment.

    Args:
        env_id: Environment identifier passed to `make_atari`.
        num_timesteps: Scaled by 1.1, truncated to int, and forwarded to
            `trpo_mpi.learn` as `max_timesteps`.
        seed: Base seed; each MPI rank uses `seed + 10000 * rank`.
    """
    from baselines.trpo_mpi.nosharing_cnn_policy import CnnPolicy
    from baselines.trpo_mpi import trpo_mpi
    import baselines.common.tf_util as U

    mpi_rank = MPI.COMM_WORLD.Get_rank()

    # The session context is entered by hand and is not exited here.
    tf_session = U.single_threaded_session()
    tf_session.__enter__()

    if mpi_rank != 0:
        # Non-root ranks configure the logger with no output formats.
        logger.configure(format_strs=[])
    else:
        logger.configure()

    worker_seed = seed + 10000 * mpi_rank
    set_global_seeds(worker_seed)
    env = make_atari(env_id)

    def build_policy(name, ob_space, ac_space):  # pylint: disable=W0613
        # Reads the spaces from the enclosing `env` at call time and ignores
        # the `ob_space` / `ac_space` arguments.
        return CnnPolicy(
            name=name,
            ob_space=env.observation_space,
            ac_space=env.action_space,
        )

    log_dir = logger.get_dir()
    env = bench.Monitor(env, log_dir and osp.join(log_dir, str(mpi_rank)))
    env.seed(worker_seed)
    gym.logger.setLevel(logging.WARN)

    # The seed is applied again after wrapping, as in the original sequence.
    env = wrap_deepmind(env)
    env.seed(worker_seed)

    trpo_options = dict(
        timesteps_per_batch=512,
        max_kl=0.001,
        cg_iters=10,
        cg_damping=1e-3,
        max_timesteps=int(num_timesteps * 1.1),
        gamma=0.98,
        lam=1.0,
        vf_iters=3,
        vf_stepsize=1e-4,
        entcoeff=0.00,
    )
    trpo_mpi.learn(env, build_policy, **trpo_options)
    env.close()
示例8: main
# 需要导入模块: from baselines.trpo_mpi import trpo_mpi [as 别名]
# 或者: from baselines.trpo_mpi.trpo_mpi import learn [as 别名]
def main():
    """Train TRPO on the Create2 docker environment with a live plot process.

    Seeds NumPy and TensorFlow from a fixed random state, starts the
    environment, spawns a plotting process fed through a shared dict, runs
    `learn`, and then shuts the plotter and the environment down. Shutdown
    runs in `finally` blocks so it also happens when training raises.
    """
    # use fixed random state
    rand_state = np.random.RandomState(1).get_state()
    np.random.set_state(rand_state)
    tf_set_seeds(np.random.randint(1, 2**31 - 1))

    # Create the Create2 docker environment
    env = Create2DockerEnv(30, port='/dev/ttyUSB0', ir_window=20, ir_history=1,
                           obs_history=1, dt=0.045, random_state=rand_state)
    env = NormalizedEnv(env)

    # Start environment processes
    env.start()
    try:
        # Create baselines TRPO policy function.
        # The session context is entered by hand and is not exited here.
        sess = U.single_threaded_session()
        sess.__enter__()

        def policy_fn(name, ob_space, ac_space):
            return MlpPolicy(name=name, ob_space=ob_space, ac_space=ac_space,
                             hid_size=32, num_hid_layers=2)

        # Create and start plotting process
        plot_running = Value('i', 1)
        shared_returns = Manager().dict({"write_lock": False,
                                         "episodic_returns": [],
                                         "episodic_lengths": [], })
        # Spawn plotting process
        pp = Process(target=plot_create2_docker, args=(env, 2048, shared_returns, plot_running))
        pp.start()
        try:
            # Create callback function for logging data from baselines TRPO learn
            kindred_callback = create_callback(shared_returns)

            # Train baselines TRPO
            learn(env, policy_fn,
                  max_timesteps=40000,
                  timesteps_per_batch=2048,
                  max_kl=0.05,
                  cg_iters=10,
                  cg_damping=0.1,
                  vf_iters=5,
                  vf_stepsize=0.001,
                  gamma=0.995,
                  lam=0.995,
                  callback=kindred_callback
                  )
        finally:
            # Safely terminate plotter process. Without this, a failure in
            # training would leave the flag set and the child never joined.
            plot_running.value = 0  # shutdown plotting process
            # NOTE(review): presumably gives the plotter time to observe the
            # cleared flag before joining -- confirm against the plot function.
            time.sleep(2)
            pp.join()
    finally:
        env.close()
示例9: main
# 需要导入模块: from baselines.trpo_mpi import trpo_mpi [as 别名]
# 或者: from baselines.trpo_mpi.trpo_mpi import learn [as 别名]
def main():
    """Train TRPO on the Create2 mover environment with a live plot process.

    Seeds NumPy and TensorFlow from a fixed random state, starts the
    environment, spawns a plotting process fed through a shared dict, runs
    `learn`, and then shuts the plotter and the environment down. Shutdown
    runs in `finally` blocks so it also happens when training raises.
    """
    # use fixed random state
    rand_state = np.random.RandomState(1).get_state()
    np.random.set_state(rand_state)
    tf_set_seeds(np.random.randint(1, 2**31 - 1))

    # Create the Create2 mover environment
    env = Create2MoverEnv(90, port='/dev/ttyUSB0', obs_history=1, dt=0.15, random_state=rand_state)
    env = NormalizedEnv(env)

    # Start environment processes
    env.start()
    try:
        # Create baselines TRPO policy function.
        # The session context is entered by hand and is not exited here.
        sess = U.single_threaded_session()
        sess.__enter__()

        def policy_fn(name, ob_space, ac_space):
            return MlpPolicy(name=name, ob_space=ob_space, ac_space=ac_space,
                             hid_size=32, num_hid_layers=2)

        # Create and start plotting process
        plot_running = Value('i', 1)
        shared_returns = Manager().dict({"write_lock": False,
                                         "episodic_returns": [],
                                         "episodic_lengths": [], })
        # Spawn plotting process
        pp = Process(target=plot_create2_mover, args=(env, 2048, shared_returns, plot_running))
        pp.start()
        try:
            # Create callback function for logging data from baselines TRPO learn
            kindred_callback = create_callback(shared_returns)

            # Train baselines TRPO
            learn(env, policy_fn,
                  max_timesteps=40000,
                  timesteps_per_batch=2048,
                  max_kl=0.05,
                  cg_iters=10,
                  cg_damping=0.1,
                  vf_iters=5,
                  vf_stepsize=0.001,
                  gamma=0.995,
                  lam=0.995,
                  callback=kindred_callback
                  )
        finally:
            # Safely terminate plotter process. Without this, a failure in
            # training would leave the flag set and the child never joined.
            plot_running.value = 0  # shutdown plotting process
            # NOTE(review): presumably gives the plotter time to observe the
            # cleared flag before joining -- confirm against the plot function.
            time.sleep(2)
            pp.join()
    finally:
        env.close()