

Python tf_util.load_state Method Code Examples

This article collects typical usage examples of the Python method baselines.common.tf_util.load_state. If you are wondering what tf_util.load_state does, how to call it, or what it looks like in real code, the selected examples below should help. You can also browse further usage examples for the containing module, baselines.common.tf_util.


The following presents 10 code examples of the tf_util.load_state method, sorted by popularity by default.

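Before the examples, here is a minimal sketch of the checkpoint round trip that load_state participates in. This is an illustrative assumption based on the older (now deprecated) helpers in baselines.common.tf_util, which wrap a tf.train.Saver around the default TensorFlow 1.x session; the "/tmp/tf_util_demo/model" path and the "counter" variable are hypothetical.

import os
import tensorflow as tf
from baselines.common import tf_util as U

with U.make_session(num_cpu=1):
    # Build a toy graph whose single variable will be checkpointed.
    counter = tf.get_variable("counter", shape=(), initializer=tf.zeros_initializer())
    U.initialize()

    model_path = "/tmp/tf_util_demo/model"  # hypothetical checkpoint prefix
    U.save_state(model_path)  # writes a tf.train.Saver checkpoint for the current session
    U.load_state(model_path)  # restores those variables back into the current session

Most of the examples below follow the same pattern: rebuild the policy or Q-function graph first, then call U.load_state (or tf_util.load_state) with the checkpoint path prefix to restore the trained weights into the active session.
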
Example 1: main

# Required import: from baselines.common import tf_util [as alias]
# Or: from baselines.common.tf_util import load_state [as alias]
def main():
    logger.configure()
    parser = mujoco_arg_parser()
    parser.add_argument('--model-path', default=os.path.join(logger.get_dir(), 'humanoid_policy'))
    parser.set_defaults(num_timesteps=int(2e7))
   
    args = parser.parse_args()
    
    if not args.play:
        # train the model
        train(num_timesteps=args.num_timesteps, seed=args.seed, model_path=args.model_path)
    else:       
        # construct the model object, load pre-trained model and render
        pi = train(num_timesteps=1, seed=args.seed)
        U.load_state(args.model_path)
        env = make_mujoco_env('Humanoid-v2', seed=0)

        ob = env.reset()        
        while True:
            action = pi.act(stochastic=False, ob=ob)[0]
            ob, _, done, _ = env.step(action)
            env.render()
            if done:
                ob = env.reset() 
Developer: MaxSobolMark, Project: HardRLWithYoutube, Lines: 26, Source: run_humanoid.py

Example 2: maybe_load_model

# Required import: from baselines.common import tf_util [as alias]
# Or: from baselines.common.tf_util import load_state [as alias]
def maybe_load_model(savedir, container):
    """Load model if present at the specified path."""
    if savedir is None:
        return

    state_path = os.path.join(savedir, 'training_state.pkl.zip')
    if container is not None:
        logger.log("Attempting to download model from Azure")
        found_model = container.get(savedir, 'training_state.pkl.zip')
    else:
        found_model = os.path.exists(state_path)
    if found_model:
        state = pickle_load(state_path, compression=True)
        model_dir = "model-{}".format(state["num_iters"])
        if container is not None:
            container.get(savedir, model_dir)
        U.load_state(os.path.join(savedir, model_dir, "saved"))
        logger.log("Loaded models checkpoint at {} iterations".format(state["num_iters"]))
        return state 
Developer: AdamStelmaszczyk, Project: learning2run, Lines: 21, Source: train.py

Example 3: main

# Required import: from baselines.common import tf_util [as alias]
# Or: from baselines.common.tf_util import load_state [as alias]
def main():
    logger.configure()
    parser = mujoco_arg_parser()
    parser.add_argument('--model-path', default=os.path.join(logger.get_dir(), 'humanoid_policy'))
    parser.set_defaults(num_timesteps=int(2e7))

    args = parser.parse_args()

    if not args.play:
        # train the model
        train(num_timesteps=args.num_timesteps, seed=args.seed, model_path=args.model_path)
    else:
        # construct the model object, load pre-trained model and render
        pi = train(num_timesteps=1, seed=args.seed)
        U.load_state(args.model_path)
        env = make_mujoco_env('Humanoid-v2', seed=0)

        ob = env.reset()
        while True:
            action = pi.act(stochastic=False, ob=ob)[0]
            ob, _, done, _ = env.step(action)
            env.render()
            if done:
                ob = env.reset() 
Developer: hiwonjoon, Project: ICML2019-TREX, Lines: 26, Source: run_humanoid.py

Example 4: main

# Required import: from baselines.common import tf_util [as alias]
# Or: from baselines.common.tf_util import load_state [as alias]
def main():
    logger.configure()
    parser = mujoco_arg_parser()
    parser.add_argument('--model-path', default=os.path.join(logger.get_dir(), 'humanoid_policy'))
    parser.set_defaults(num_timesteps=int(5e7))

    args = parser.parse_args()

    if not args.play:
        # train the model
        train(num_timesteps=args.num_timesteps, seed=args.seed, model_path=args.model_path)
    else:
        # construct the model object, load pre-trained model and render
        pi = train(num_timesteps=1, seed=args.seed)
        U.load_state(args.model_path)
        env = make_mujoco_env('Humanoid-v2', seed=0)

        ob = env.reset()
        while True:
            action = pi.act(stochastic=False, ob=ob)[0]
            ob, _, done, _ = env.step(action)
            env.render()
            if done:
                ob = env.reset() 
Developer: openai, Project: baselines, Lines: 26, Source: run_humanoid.py

Example 5: main

# Required import: from baselines.common import tf_util [as alias]
# Or: from baselines.common.tf_util import load_state [as alias]
def main():
    set_global_seeds(1)
    args = parse_args()

    with U.make_session(4) as sess:  # noqa
        _, env = make_env(args.env)
        model_parent_path = distdeepq.parent_path(args.model_dir)
        old_args = json.load(open(model_parent_path + '/args.json'))

        act = distdeepq.build_act(
            make_obs_ph=lambda name: U.Uint8Input(env.observation_space.shape, name=name),
            p_dist_func=distdeepq.models.atari_model(),
            num_actions=env.action_space.n,
            dist_params={'Vmin': old_args['vmin'],
                         'Vmax': old_args['vmax'],
                         'nb_atoms': old_args['nb_atoms']})
        U.load_state(os.path.join(args.model_dir, "saved"))
        wang2015_eval(args.env, act, stochastic=args.stochastic) 
Developer: Silvicek, Project: distributional-dqn, Lines: 20, Source: wang2015_eval.py

Example 6: runner

# Required import: from baselines.common import tf_util [as alias]
# Or: from baselines.common.tf_util import load_state [as alias]
def runner(env, policy_func, load_model_path, timesteps_per_batch, number_trajs,
           stochastic_policy, save=False, reuse=False):

    # Setup network
    # ----------------------------------------
    ob_space = env.observation_space
    ac_space = env.action_space
    pi = policy_func("pi", ob_space, ac_space, reuse=reuse)
    U.initialize()
    # Prepare for rollouts
    # ----------------------------------------
    U.load_state(load_model_path)

    obs_list = []
    acs_list = []
    len_list = []
    ret_list = []
    for _ in tqdm(range(number_trajs)):
        traj = traj_1_generator(pi, env, timesteps_per_batch, stochastic=stochastic_policy)
        obs, acs, ep_len, ep_ret = traj['ob'], traj['ac'], traj['ep_len'], traj['ep_ret']
        obs_list.append(obs)
        acs_list.append(acs)
        len_list.append(ep_len)
        ret_list.append(ep_ret)
    if stochastic_policy:
        print('stochastic policy:')
    else:
        print('deterministic policy:')
    if save:
        filename = load_model_path.split('/')[-1] + '.' + env.spec.id
        np.savez(filename, obs=np.array(obs_list), acs=np.array(acs_list),
                 lens=np.array(len_list), rets=np.array(ret_list))
    avg_len = sum(len_list)/len(len_list)
    avg_ret = sum(ret_list)/len(ret_list)
    print("Average length:", avg_len)
    print("Average return:", avg_ret)
    return avg_len, avg_ret


# Sample one trajectory (until trajectory end) 
Developer: Hwhitetooth, Project: lirpg, Lines: 42, Source: run_mujoco.py

Example 7: load

# Required import: from baselines.common import tf_util [as alias]
# Or: from baselines.common.tf_util import load_state [as alias]
def load(self, load_path):
        tf_util.load_state(load_path, sess=self.sess) 
Developer: MaxSobolMark, Project: HardRLWithYoutube, Lines: 4, Source: policies.py

Example 8: main

# Required import: from baselines.common import tf_util [as alias]
# Or: from baselines.common.tf_util import load_state [as alias]
def main():
    set_global_seeds(1)
    args = parse_args()
    with U.make_session(4) as sess:  # noqa
        _, env = make_env(args.env)
        act = deepq.build_act(
            make_obs_ph=lambda name: U.Uint8Input(env.observation_space.shape, name=name),
            q_func=dueling_model if args.dueling else model,
            num_actions=env.action_space.n)

        U.load_state(os.path.join(args.model_dir, "saved"))
        wang2015_eval(args.env, act, stochastic=args.stochastic) 
Developer: AdamStelmaszczyk, Project: learning2run, Lines: 14, Source: wang2015_eval.py

Example 9: load

# Required import: from baselines.common import tf_util [as alias]
# Or: from baselines.common.tf_util import load_state [as alias]
def load(path, num_cpu=16):
        with open(path, "rb") as f:
            model_data, act_params = dill.load(f)
        act = deepq.build_act(**act_params)
        sess = U.make_session(num_cpu=num_cpu)
        sess.__enter__()
        with tempfile.TemporaryDirectory() as td:
            arc_path = os.path.join(td, "packed.zip")
            with open(arc_path, "wb") as f:
                f.write(model_data)

            zipfile.ZipFile(arc_path, 'r', zipfile.ZIP_DEFLATED).extractall(td)
            U.load_state(os.path.join(td, "model"))

        return ActWrapper(act, act_params) 
Developer: AdamStelmaszczyk, Project: learning2run, Lines: 17, Source: simple.py

Example 10: load

# Required import: from baselines.common import tf_util [as alias]
# Or: from baselines.common.tf_util import load_state [as alias]
def load(path, act_params, num_cpu=16):
    with open(path, "rb") as f:
      model_data = dill.load(f)
    act = deepq.build_act(**act_params)
    sess = U.make_session(num_cpu=num_cpu)
    sess.__enter__()
    with tempfile.TemporaryDirectory() as td:
      arc_path = os.path.join(td, "packed.zip")
      with open(arc_path, "wb") as f:
        f.write(model_data)

      zipfile.ZipFile(arc_path, 'r', zipfile.ZIP_DEFLATED).extractall(td)
      U.load_state(os.path.join(td, "model"))

    return ActWrapper(act) 
Developer: llSourcell, Project: A-Guide-to-DeepMinds-StarCraft-AI-Environment, Lines: 17, Source: dqfd.py


Note: The baselines.common.tf_util.load_state examples in this article were compiled by 純淨天空 from open-source code and documentation platforms such as GitHub and MSDocs. The snippets are drawn from open-source projects contributed by their respective authors; copyright in the source code remains with the original authors, and any distribution or use should follow the corresponding project's License. Do not reproduce without permission.