当前位置: 首页>>代码示例>>Python>>正文


Python tf_util.load_state方法代码示例

本文整理汇总了Python中baselines.common.tf_util.load_state方法的典型用法代码示例。如果您正苦于以下问题:Python tf_util.load_state方法的具体用法?Python tf_util.load_state怎么用?Python tf_util.load_state使用的例子?那么,这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块baselines.common.tf_util的用法示例。


在下文中一共展示了tf_util.load_state方法的10个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: main

# 需要导入模块: from baselines.common import tf_util [as 别名]
# 或者: from baselines.common.tf_util import load_state [as 别名]
def main():
    """Train a Humanoid policy, or render a saved one when ``args.play`` is set.

    Arguments come from ``mujoco_arg_parser()`` extended with a
    ``--model-path`` option (default: ``humanoid_policy`` inside the logger
    directory); ``num_timesteps`` defaults to ``int(2e7)``.
    """
    logger.configure()

    parser = mujoco_arg_parser()
    default_model_path = os.path.join(logger.get_dir(), 'humanoid_policy')
    parser.add_argument('--model-path', default=default_model_path)
    parser.set_defaults(num_timesteps=int(2e7))
    args = parser.parse_args()

    if not args.play:
        # Training mode: run the full schedule with the requested model path.
        train(num_timesteps=args.num_timesteps, seed=args.seed, model_path=args.model_path)
        return

    # Play mode: a one-timestep train() call is used to obtain the policy
    # object, after which the saved state is loaded from --model-path.
    pi = train(num_timesteps=1, seed=args.seed)
    U.load_state(args.model_path)
    env = make_mujoco_env('Humanoid-v2', seed=0)

    # Render non-stochastic actions forever, resetting when an episode ends.
    ob = env.reset()
    while True:
        action = pi.act(stochastic=False, ob=ob)[0]
        ob, _reward, done, _info = env.step(action)
        env.render()
        if done:
            ob = env.reset()
开发者ID:MaxSobolMark,项目名称:HardRLWithYoutube,代码行数:26,代码来源:run_humanoid.py

示例2: maybe_load_model

# 需要导入模块: from baselines.common import tf_util [as 别名]
# 或者: from baselines.common.tf_util import load_state [as 别名]
def maybe_load_model(savedir, container):
    """Load the checkpoint described by ``training_state.pkl.zip`` in *savedir*.

    Args:
        savedir: Directory holding the training state file and the
            ``model-<num_iters>`` sub-directory, or ``None`` to skip loading.
        container: Optional object exposing ``get(savedir, name)``; when
            given, it is asked for the state file and the model directory
            before they are read.

    Returns:
        The unpickled training state when a model was found and loaded,
        otherwise ``None``.
    """
    if savedir is None:
        return None

    state_name = 'training_state.pkl.zip'
    state_path = os.path.join(savedir, state_name)

    if container is None:
        found_model = os.path.exists(state_path)
    else:
        logger.log("Attempting to download model from Azure")
        found_model = container.get(savedir, state_name)

    if not found_model:
        return None

    state = pickle_load(state_path, compression=True)
    num_iters = state["num_iters"]
    model_dir = "model-{}".format(num_iters)
    if container is not None:
        # Fetch the model directory through the container before loading it.
        container.get(savedir, model_dir)
    U.load_state(os.path.join(savedir, model_dir, "saved"))
    logger.log("Loaded models checkpoint at {} iterations".format(num_iters))
    return state
开发者ID:AdamStelmaszczyk,项目名称:learning2run,代码行数:21,代码来源:train.py

示例3: main

# 需要导入模块: from baselines.common import tf_util [as 别名]
# 或者: from baselines.common.tf_util import load_state [as 别名]
def main():
    """Entry point: train a Humanoid policy or render a previously saved one.

    The parser from ``mujoco_arg_parser()`` gains a ``--model-path`` option
    (default: ``humanoid_policy`` under the logger directory) and a
    ``num_timesteps`` default of ``int(2e7)``. ``args.play`` selects between
    training and rendering.
    """
    logger.configure()
    parser = mujoco_arg_parser()
    parser.add_argument(
        '--model-path',
        default=os.path.join(logger.get_dir(), 'humanoid_policy'),
    )
    parser.set_defaults(num_timesteps=int(2e7))
    args = parser.parse_args()

    if args.play:
        # Build the policy object via a one-timestep train() call, then
        # load the saved state into it.
        pi = train(num_timesteps=1, seed=args.seed)
        U.load_state(args.model_path)
        env = make_mujoco_env('Humanoid-v2', seed=0)

        ob = env.reset()
        while True:
            act_result = pi.act(stochastic=False, ob=ob)
            ob, _, done, _ = env.step(act_result[0])
            env.render()
            if done:
                # Episode finished: start a new one and keep rendering.
                ob = env.reset()
    else:
        # Train the model.
        train(num_timesteps=args.num_timesteps, seed=args.seed, model_path=args.model_path)
开发者ID:hiwonjoon,项目名称:ICML2019-TREX,代码行数:26,代码来源:run_humanoid.py

示例4: main

# 需要导入模块: from baselines.common import tf_util [as 别名]
# 或者: from baselines.common.tf_util import load_state [as 别名]
def main():
    """Run Humanoid training, or render a saved policy if ``args.play`` is set.

    Uses ``mujoco_arg_parser()`` with an added ``--model-path`` option
    (default: ``humanoid_policy`` in the logger directory) and a
    ``num_timesteps`` default of ``int(5e7)``.
    """
    logger.configure()

    parser = mujoco_arg_parser()
    model_path_default = os.path.join(logger.get_dir(), 'humanoid_policy')
    parser.add_argument('--model-path', default=model_path_default)
    parser.set_defaults(num_timesteps=int(5e7))
    args = parser.parse_args()

    if not args.play:
        # Train the model and stop; nothing is rendered in this mode.
        train(num_timesteps=args.num_timesteps, seed=args.seed, model_path=args.model_path)
        return

    # Construct the model object, load the pre-trained state and render.
    pi = train(num_timesteps=1, seed=args.seed)
    U.load_state(args.model_path)
    env = make_mujoco_env('Humanoid-v2', seed=0)

    ob = env.reset()
    while True:
        ob, _, done, _ = env.step(pi.act(stochastic=False, ob=ob)[0])
        env.render()
        # Start a fresh episode once the current one reports done.
        ob = env.reset() if done else ob
开发者ID:openai,项目名称:baselines,代码行数:26,代码来源:run_humanoid.py

示例5: main

# 需要导入模块: from baselines.common import tf_util [as 别名]
# 或者: from baselines.common.tf_util import load_state [as 别名]
def main():
    """Evaluate a saved distributional DQN model with ``wang2015_eval``.

    Seeds are fixed with ``set_global_seeds(1)``. The distribution parameters
    (``vmin``, ``vmax``, ``nb_atoms``) are read from ``args.json`` in the
    directory returned by ``distdeepq.parent_path(args.model_dir)``, the act
    function is rebuilt with them, and the state is loaded from
    ``<args.model_dir>/saved`` before evaluation.
    """
    set_global_seeds(1)
    args = parse_args()

    with U.make_session(4) as sess:  # noqa
        _, env = make_env(args.env)
        model_parent_path = distdeepq.parent_path(args.model_dir)
        # Read the stored arguments through a context manager so the file
        # handle is closed promptly instead of being left to the garbage
        # collector.
        with open(model_parent_path + '/args.json') as args_file:
            old_args = json.load(args_file)

        act = distdeepq.build_act(
            make_obs_ph=lambda name: U.Uint8Input(env.observation_space.shape, name=name),
            p_dist_func=distdeepq.models.atari_model(),
            num_actions=env.action_space.n,
            dist_params={'Vmin': old_args['vmin'],
                         'Vmax': old_args['vmax'],
                         'nb_atoms': old_args['nb_atoms']})
        # The act function is built first, then the saved state is loaded.
        U.load_state(os.path.join(args.model_dir, "saved"))
        wang2015_eval(args.env, act, stochastic=args.stochastic)
开发者ID:Silvicek,项目名称:distributional-dqn,代码行数:20,代码来源:wang2015_eval.py

示例6: runner

# 需要导入模块: from baselines.common import tf_util [as 别名]
# 或者: from baselines.common.tf_util import load_state [as 别名]
def runner(env, policy_func, load_model_path, timesteps_per_batch, number_trajs,
           stochastic_policy, save=False, reuse=False):
    """Roll out a saved policy for ``number_trajs`` trajectories and report averages.

    Args:
        env: Environment providing ``observation_space``, ``action_space``
            and (when ``save`` is true) ``spec.id``.
        policy_func: Callable invoked as
            ``policy_func("pi", ob_space, ac_space, reuse=reuse)``.
        load_model_path: Path handed to ``U.load_state``; its last
            ``/``-separated component also names the saved ``.npz`` file.
        timesteps_per_batch: Forwarded to ``traj_1_generator``.
        number_trajs: Number of trajectories to generate.
        stochastic_policy: Forwarded to ``traj_1_generator`` as ``stochastic``.
        save: When true, store observations, actions, lengths and returns
            with ``np.savez``.
        reuse: Forwarded to ``policy_func``.

    Returns:
        Tuple ``(avg_len, avg_ret)`` of mean episode length and mean return.

    Raises:
        ZeroDivisionError: If no trajectories were collected.
    """
    # Setup network
    # ----------------------------------------
    pi = policy_func("pi", env.observation_space, env.action_space, reuse=reuse)
    U.initialize()
    # Prepare for rollouts
    # ----------------------------------------
    U.load_state(load_model_path)

    obs_list, acs_list, len_list, ret_list = [], [], [], []
    collectors = (
        ('ob', obs_list),
        ('ac', acs_list),
        ('ep_len', len_list),
        ('ep_ret', ret_list),
    )
    for _ in tqdm(range(number_trajs)):
        traj = traj_1_generator(pi, env, timesteps_per_batch, stochastic=stochastic_policy)
        # Pull each field out right away, one entry per trajectory.
        for key, bucket in collectors:
            bucket.append(traj[key])

    print('stochastic policy:' if stochastic_policy else 'deterministic policy:')

    if save:
        model_name = load_model_path.split('/')[-1]
        filename = model_name + '.' + env.spec.id
        arrays = {
            'obs': np.array(obs_list),
            'acs': np.array(acs_list),
            'lens': np.array(len_list),
            'rets': np.array(ret_list),
        }
        np.savez(filename, **arrays)

    avg_len = sum(len_list) / len(len_list)
    avg_ret = sum(ret_list) / len(ret_list)
    print("Average length:", avg_len)
    print("Average return:", avg_ret)
    return avg_len, avg_ret


# Sample one trajectory (until trajectory end) 
开发者ID:Hwhitetooth,项目名称:lirpg,代码行数:42,代码来源:run_mujoco.py

示例7: load

# 需要导入模块: from baselines.common import tf_util [as 别名]
# 或者: from baselines.common.tf_util import load_state [as 别名]
def load(self, load_path):
    """Load saved state from *load_path* into this object's session.

    Delegates to ``tf_util.load_state`` with ``self.sess`` as the session.
    """
    session = self.sess
    tf_util.load_state(load_path, sess=session)
开发者ID:MaxSobolMark,项目名称:HardRLWithYoutube,代码行数:4,代码来源:policies.py

示例8: main

# 需要导入模块: from baselines.common import tf_util [as 别名]
# 或者: from baselines.common.tf_util import load_state [as 别名]
def main():
    """Evaluate a saved DQN model with ``wang2015_eval``.

    Seeds are fixed with ``set_global_seeds(1)``. The act function is built
    with ``dueling_model`` when ``args.dueling`` is truthy and ``model``
    otherwise, then the state is loaded from ``<args.model_dir>/saved``.
    """
    set_global_seeds(1)
    args = parse_args()

    with U.make_session(4) as sess:  # noqa
        _, env = make_env(args.env)

        def make_obs_ph(name):
            # Placeholder shaped like the environment's observations.
            return U.Uint8Input(env.observation_space.shape, name=name)

        if args.dueling:
            q_func = dueling_model
        else:
            q_func = model

        act = deepq.build_act(
            make_obs_ph=make_obs_ph,
            q_func=q_func,
            num_actions=env.action_space.n,
        )

        saved_path = os.path.join(args.model_dir, "saved")
        U.load_state(saved_path)
        wang2015_eval(args.env, act, stochastic=args.stochastic)
开发者ID:AdamStelmaszczyk,项目名称:learning2run,代码行数:14,代码来源:wang2015_eval.py

示例9: load

# 需要导入模块: from baselines.common import tf_util [as 别名]
# 或者: from baselines.common.tf_util import load_state [as 别名]
def load(path, num_cpu=16):
    """Rebuild an act function from a file holding ``(model_data, act_params)``.

    Args:
        path: File containing a dill-serialised ``(model_data, act_params)``
            pair, where ``model_data`` is the bytes of a zip archive.
        num_cpu: Forwarded to ``U.make_session``.

    Returns:
        ``ActWrapper(act, act_params)`` for the restored act function.
    """
    # NOTE(review): dill.load, like pickle, can execute arbitrary code from
    # the file; only call this on checkpoints from a trusted source.
    with open(path, "rb") as f:
        model_data, act_params = dill.load(f)
    act = deepq.build_act(**act_params)

    # The session is entered here with no matching exit in this function,
    # presumably so it is still in effect for the returned wrapper.
    sess = U.make_session(num_cpu=num_cpu)
    sess.__enter__()

    with tempfile.TemporaryDirectory() as td:
        arc_path = os.path.join(td, "packed.zip")
        with open(arc_path, "wb") as f:
            f.write(model_data)

        # Close the archive explicitly before the temporary directory is
        # removed, instead of relying on garbage collection.
        with zipfile.ZipFile(arc_path, 'r', zipfile.ZIP_DEFLATED) as archive:
            archive.extractall(td)
        U.load_state(os.path.join(td, "model"))

    return ActWrapper(act, act_params)
开发者ID:AdamStelmaszczyk,项目名称:learning2run,代码行数:17,代码来源:simple.py

示例10: load

# 需要导入模块: from baselines.common import tf_util [as 别名]
# 或者: from baselines.common.tf_util import load_state [as 别名]
def load(path, act_params, num_cpu=16):
    """Rebuild an act function from dill-serialised model data.

    Args:
        path: File containing the dill-serialised bytes of a zip archive.
        act_params: Keyword arguments passed to ``deepq.build_act``.
        num_cpu: Forwarded to ``U.make_session``.

    Returns:
        ``ActWrapper(act)`` for the restored act function.
    """
    # NOTE(review): dill.load, like pickle, can execute arbitrary code from
    # the file; only call this on checkpoints from a trusted source.
    with open(path, "rb") as f:
        model_data = dill.load(f)
    act = deepq.build_act(**act_params)

    # The session is entered here with no matching exit in this function,
    # presumably so it is still in effect for the returned wrapper.
    sess = U.make_session(num_cpu=num_cpu)
    sess.__enter__()

    with tempfile.TemporaryDirectory() as td:
        arc_path = os.path.join(td, "packed.zip")
        with open(arc_path, "wb") as f:
            f.write(model_data)

        # Close the archive explicitly before the temporary directory is
        # removed, instead of relying on garbage collection.
        with zipfile.ZipFile(arc_path, 'r', zipfile.ZIP_DEFLATED) as archive:
            archive.extractall(td)
        U.load_state(os.path.join(td, "model"))

    return ActWrapper(act)
开发者ID:llSourcell,项目名称:A-Guide-to-DeepMinds-StarCraft-AI-Environment,代码行数:17,代码来源:dqfd.py


注:本文中的baselines.common.tf_util.load_state方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。