This article collects typical usage examples of the baselines.common.tf_util.load_state method in Python. If you are wondering what tf_util.load_state is for, how to call it, or what real-world uses of it look like, the curated code examples below may help. You can also explore further usage examples from the module it belongs to, baselines.common.tf_util.
The following presents 10 code examples of the tf_util.load_state method, sorted by popularity by default.
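Before the per-repository examples, here is a minimal sketch of the usual round trip: load_state restores variables from a tf.train.Saver checkpoint into the current TensorFlow session, so it is normally paired with an earlier save_state call on the same graph. This sketch assumes the older TF1-era baselines API, in which tf_util also exposes make_session, initialize, and save_state; the toy variable and the checkpoint path are invented for illustration only.

import os
import tensorflow as tf
from baselines.common import tf_util as U

# Hypothetical checkpoint location, used only for this demo.
checkpoint_path = os.path.join("/tmp", "demo_model", "saved")

with U.make_session(num_cpu=1):
    # Build a trivial graph so there is something to checkpoint.
    x = tf.get_variable("x", shape=(), initializer=tf.zeros_initializer())
    U.initialize()                 # run initializers for the new variables

    U.save_state(checkpoint_path)  # write a tf.train.Saver checkpoint
    U.load_state(checkpoint_path)  # restore the same variables later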
Example 1: main
# Required import: from baselines.common import tf_util [as alias]
# Or: from baselines.common.tf_util import load_state [as alias]
def main():
    logger.configure()
    parser = mujoco_arg_parser()
    parser.add_argument('--model-path', default=os.path.join(logger.get_dir(), 'humanoid_policy'))
    parser.set_defaults(num_timesteps=int(2e7))
    args = parser.parse_args()

    if not args.play:
        # train the model
        train(num_timesteps=args.num_timesteps, seed=args.seed, model_path=args.model_path)
    else:
        # construct the model object, load pre-trained model and render
        pi = train(num_timesteps=1, seed=args.seed)
        U.load_state(args.model_path)
        env = make_mujoco_env('Humanoid-v2', seed=0)

        ob = env.reset()
        while True:
            action = pi.act(stochastic=False, ob=ob)[0]
            ob, _, done, _ = env.step(action)
            env.render()
            if done:
                ob = env.reset()
Example 2: maybe_load_model
# Required import: from baselines.common import tf_util [as alias]
# Or: from baselines.common.tf_util import load_state [as alias]
def maybe_load_model(savedir, container):
    """Load model if present at the specified path."""
    if savedir is None:
        return
    state_path = os.path.join(savedir, 'training_state.pkl.zip')
    if container is not None:
        logger.log("Attempting to download model from Azure")
        found_model = container.get(savedir, 'training_state.pkl.zip')
    else:
        found_model = os.path.exists(state_path)

    if found_model:
        state = pickle_load(state_path, compression=True)
        model_dir = "model-{}".format(state["num_iters"])
        if container is not None:
            container.get(savedir, model_dir)
        U.load_state(os.path.join(savedir, model_dir, "saved"))
        logger.log("Loaded models checkpoint at {} iterations".format(state["num_iters"]))
        return state
Example 3: main
# Required import: from baselines.common import tf_util [as alias]
# Or: from baselines.common.tf_util import load_state [as alias]
def main():
    logger.configure()
    parser = mujoco_arg_parser()
    parser.add_argument('--model-path', default=os.path.join(logger.get_dir(), 'humanoid_policy'))
    parser.set_defaults(num_timesteps=int(2e7))
    args = parser.parse_args()

    if not args.play:
        # train the model
        train(num_timesteps=args.num_timesteps, seed=args.seed, model_path=args.model_path)
    else:
        # construct the model object, load pre-trained model and render
        pi = train(num_timesteps=1, seed=args.seed)
        U.load_state(args.model_path)
        env = make_mujoco_env('Humanoid-v2', seed=0)

        ob = env.reset()
        while True:
            action = pi.act(stochastic=False, ob=ob)[0]
            ob, _, done, _ = env.step(action)
            env.render()
            if done:
                ob = env.reset()
Example 4: main
# Required import: from baselines.common import tf_util [as alias]
# Or: from baselines.common.tf_util import load_state [as alias]
def main():
    logger.configure()
    parser = mujoco_arg_parser()
    parser.add_argument('--model-path', default=os.path.join(logger.get_dir(), 'humanoid_policy'))
    parser.set_defaults(num_timesteps=int(5e7))
    args = parser.parse_args()

    if not args.play:
        # train the model
        train(num_timesteps=args.num_timesteps, seed=args.seed, model_path=args.model_path)
    else:
        # construct the model object, load pre-trained model and render
        pi = train(num_timesteps=1, seed=args.seed)
        U.load_state(args.model_path)
        env = make_mujoco_env('Humanoid-v2', seed=0)

        ob = env.reset()
        while True:
            action = pi.act(stochastic=False, ob=ob)[0]
            ob, _, done, _ = env.step(action)
            env.render()
            if done:
                ob = env.reset()
Example 5: main
# Required import: from baselines.common import tf_util [as alias]
# Or: from baselines.common.tf_util import load_state [as alias]
def main():
    set_global_seeds(1)
    args = parse_args()
    with U.make_session(4) as sess:  # noqa
        _, env = make_env(args.env)
        model_parent_path = distdeepq.parent_path(args.model_dir)
        old_args = json.load(open(model_parent_path + '/args.json'))

        act = distdeepq.build_act(
            make_obs_ph=lambda name: U.Uint8Input(env.observation_space.shape, name=name),
            p_dist_func=distdeepq.models.atari_model(),
            num_actions=env.action_space.n,
            dist_params={'Vmin': old_args['vmin'],
                         'Vmax': old_args['vmax'],
                         'nb_atoms': old_args['nb_atoms']})
        U.load_state(os.path.join(args.model_dir, "saved"))

        wang2015_eval(args.env, act, stochastic=args.stochastic)
Example 6: runner
# Required import: from baselines.common import tf_util [as alias]
# Or: from baselines.common.tf_util import load_state [as alias]
def runner(env, policy_func, load_model_path, timesteps_per_batch, number_trajs,
           stochastic_policy, save=False, reuse=False):

    # Setup network
    # ----------------------------------------
    ob_space = env.observation_space
    ac_space = env.action_space
    pi = policy_func("pi", ob_space, ac_space, reuse=reuse)
    U.initialize()
    # Prepare for rollouts
    # ----------------------------------------
    U.load_state(load_model_path)

    obs_list = []
    acs_list = []
    len_list = []
    ret_list = []
    for _ in tqdm(range(number_trajs)):
        traj = traj_1_generator(pi, env, timesteps_per_batch, stochastic=stochastic_policy)
        obs, acs, ep_len, ep_ret = traj['ob'], traj['ac'], traj['ep_len'], traj['ep_ret']
        obs_list.append(obs)
        acs_list.append(acs)
        len_list.append(ep_len)
        ret_list.append(ep_ret)
    if stochastic_policy:
        print('stochastic policy:')
    else:
        print('deterministic policy:')
    if save:
        filename = load_model_path.split('/')[-1] + '.' + env.spec.id
        np.savez(filename, obs=np.array(obs_list), acs=np.array(acs_list),
                 lens=np.array(len_list), rets=np.array(ret_list))
    avg_len = sum(len_list) / len(len_list)
    avg_ret = sum(ret_list) / len(ret_list)
    print("Average length:", avg_len)
    print("Average return:", avg_ret)
    return avg_len, avg_ret


# Sample one trajectory (until trajectory end)
Example 7: load
# Required import: from baselines.common import tf_util [as alias]
# Or: from baselines.common.tf_util import load_state [as alias]
def load(self, load_path):
    tf_util.load_state(load_path, sess=self.sess)
Example 8: main
# Required import: from baselines.common import tf_util [as alias]
# Or: from baselines.common.tf_util import load_state [as alias]
def main():
    set_global_seeds(1)
    args = parse_args()
    with U.make_session(4) as sess:  # noqa
        _, env = make_env(args.env)
        act = deepq.build_act(
            make_obs_ph=lambda name: U.Uint8Input(env.observation_space.shape, name=name),
            q_func=dueling_model if args.dueling else model,
            num_actions=env.action_space.n)
        U.load_state(os.path.join(args.model_dir, "saved"))

        wang2015_eval(args.env, act, stochastic=args.stochastic)
Example 9: load
# Required import: from baselines.common import tf_util [as alias]
# Or: from baselines.common.tf_util import load_state [as alias]
def load(path, num_cpu=16):
    with open(path, "rb") as f:
        model_data, act_params = dill.load(f)
    act = deepq.build_act(**act_params)
    sess = U.make_session(num_cpu=num_cpu)
    sess.__enter__()
    with tempfile.TemporaryDirectory() as td:
        arc_path = os.path.join(td, "packed.zip")
        with open(arc_path, "wb") as f:
            f.write(model_data)
        zipfile.ZipFile(arc_path, 'r', zipfile.ZIP_DEFLATED).extractall(td)
        U.load_state(os.path.join(td, "model"))

    return ActWrapper(act, act_params)
Example 10: load
# Required import: from baselines.common import tf_util [as alias]
# Or: from baselines.common.tf_util import load_state [as alias]
def load(path, act_params, num_cpu=16):
    with open(path, "rb") as f:
        model_data = dill.load(f)
    act = deepq.build_act(**act_params)
    sess = U.make_session(num_cpu=num_cpu)
    sess.__enter__()
    with tempfile.TemporaryDirectory() as td:
        arc_path = os.path.join(td, "packed.zip")
        with open(arc_path, "wb") as f:
            f.write(model_data)
        zipfile.ZipFile(arc_path, 'r', zipfile.ZIP_DEFLATED).extractall(td)
        U.load_state(os.path.join(td, "model"))

    return ActWrapper(act)