Python tf_util.save_state Method Code Examples

This article collects typical usage examples of the tf_util.save_state method from the Python module baselines.common.tf_util. If you are unsure what tf_util.save_state does, how to call it, or what it looks like in real code, the curated examples below may help. You can also explore further usage of the containing module, baselines.common.tf_util.


Below are 10 code examples of the tf_util.save_state method, sorted by popularity by default.
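For orientation before the examples, here is a minimal save/load round trip (a hedged sketch: it assumes an older baselines API in which tf_util exposes make_session, initialize, save_state, and load_state operating on the default TensorFlow session; the checkpoint prefix /tmp/my_model/saved is an arbitrary placeholder):

import tensorflow as tf
import baselines.common.tf_util as U

with U.make_session(num_cpu=1):
    x = tf.Variable(0.0, name="x")       # any variable worth checkpointing
    U.initialize()                       # run the global variable initializer
    U.save_state("/tmp/my_model/saved")  # write TF checkpoint files under this prefix
    U.load_state("/tmp/my_model/saved")  # later: restore the same variables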

Example 1: maybe_save_model

# Required import: from baselines.common import tf_util [as an alias]
# Alternatively: from baselines.common.tf_util import save_state [as an alias]
def maybe_save_model(savedir, container, state):
    """This function checkpoints the model and state of the training algorithm."""
    if savedir is None:
        return
    start_time = time.time()
    model_dir = "model-{}".format(state["num_iters"])
    U.save_state(os.path.join(savedir, model_dir, "saved"))
    if container is not None:
        container.put(os.path.join(savedir, model_dir), model_dir)
    relatively_safe_pickle_dump(state, os.path.join(savedir, 'training_state.pkl.zip'), compression=True)
    if container is not None:
        container.put(os.path.join(savedir, 'training_state.pkl.zip'), 'training_state.pkl.zip')
    relatively_safe_pickle_dump(state["monitor_state"], os.path.join(savedir, 'monitor_state.pkl'))
    if container is not None:
        container.put(os.path.join(savedir, 'monitor_state.pkl'), 'monitor_state.pkl')
    logger.log("Saved model in {} seconds\n".format(time.time() - start_time)) 
Author: AdamStelmaszczyk | Project: learning2run | Lines: 18 | Source file: train.py
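A hedged usage sketch of the function above (the state keys num_iters and monitor_state follow its body; passing container=None simply skips the upload branches, and the directory is a placeholder):

state = {"num_iters": 500000, "monitor_state": {}}  # minimal dict with the keys the function reads
maybe_save_model("/tmp/checkpoints", container=None, state=state)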

Example 2: save

# Required import: from baselines.common import tf_util [as an alias]
# Alternatively: from baselines.common.tf_util import save_state [as an alias]
def save(self, path=None):
    """Save model to a pickle located at `path`"""
    if path is None:
        path = os.path.join(logger.get_dir(), "model.pkl")

    with tempfile.TemporaryDirectory() as td:
        U.save_state(os.path.join(td, "model"))
        arc_name = os.path.join(td, "packed.zip")
        with zipfile.ZipFile(arc_name, 'w') as zipf:
            for root, dirs, files in os.walk(td):
                for fname in files:
                    file_path = os.path.join(root, fname)
                    if file_path != arc_name:
                        zipf.write(file_path, os.path.relpath(file_path, td))
        with open(arc_name, "rb") as f:
            model_data = f.read()
    with open(path, "wb") as f:
        cloudpickle.dump(model_data, f)
Author: alexsax | Project: midlevel-reps | Lines: 20 | Source file: pposgd_simple.py
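This snippet has no matching load; a hedged sketch of the inverse operation (it mirrors the pack format above, reuses the same imports, and assumes U.load_state exists, as it does in older baselines versions):

def load(path):
    """Unpack the pickled zip written by save() and restore the checkpoint."""
    with open(path, "rb") as f:
        model_data = cloudpickle.load(f)
    with tempfile.TemporaryDirectory() as td:
        arc_path = os.path.join(td, "packed.zip")
        with open(arc_path, "wb") as f:
            f.write(model_data)
        zipfile.ZipFile(arc_path, 'r').extractall(td)
        U.load_state(os.path.join(td, "model"))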

Example 3: save

# Required import: from baselines.common import tf_util [as an alias]
# Alternatively: from baselines.common.tf_util import save_state [as an alias]
def save(self, path=None):
        """Save model to a pickle located at `path`"""
        if path is None:
            path = os.path.join(logger.get_dir(), "model.pkl")

        with tempfile.TemporaryDirectory() as td:
            U.save_state(os.path.join(td, "model"))
            arc_name = os.path.join(td, "packed.zip")
            with zipfile.ZipFile(arc_name, 'w') as zipf:
                for root, dirs, files in os.walk(td):
                    for fname in files:
                        file_path = os.path.join(root, fname)
                        if file_path != arc_name:
                            zipf.write(file_path, os.path.relpath(file_path, td))
            with open(arc_name, "rb") as f:
                model_data = f.read()
        with open(path, "wb") as f:
            cloudpickle.dump((model_data, self._act_params), f) 
Author: cxxgtxy | Project: deeprl-baselines | Lines: 20 | Source file: simple.py

Example 4: maybe_save_model

# Required import: from baselines.common import tf_util [as an alias]
# Alternatively: from baselines.common.tf_util import save_state [as an alias]
def maybe_save_model(savedir, container, state):
    """This function checkpoints the model and state of the training algorithm."""
    if savedir is None:
        return
    start_time = time.time()
    model_dir = "model-{}".format(state["num_iters"])
    U.save_state(os.path.join(savedir, model_dir, "saved"))
    if container is not None:
        container.put(os.path.join(savedir, model_dir), model_dir)

    # requires 32 GB of memory for this to work
    relatively_safe_pickle_dump(state, os.path.join(savedir, 'training_state.pkl.zip'), compression=True)
    if container is not None:
        container.put(os.path.join(savedir, 'training_state.pkl.zip'), 'training_state.pkl.zip')
    relatively_safe_pickle_dump(state["monitor_state"], os.path.join(savedir, 'monitor_state.pkl'))
    if container is not None:
        container.put(os.path.join(savedir, 'monitor_state.pkl'), 'monitor_state.pkl')
    logger.log("Saved model in {} seconds\n".format(time.time() - start_time)) 
Author: Silvicek | Project: distributional-dqn | Lines: 20 | Source file: train_atari.py

Example 5: save

# Required import: from baselines.common import tf_util [as an alias]
# Alternatively: from baselines.common.tf_util import save_state [as an alias]
def save(self, path):
    """Save model to a pickle located at `path`"""
    with tempfile.TemporaryDirectory() as td:
      U.save_state(os.path.join(td, "model"))
      arc_name = os.path.join(td, "packed.zip")
      with zipfile.ZipFile(arc_name, 'w') as zipf:
        for root, dirs, files in os.walk(td):
          for fname in files:
            file_path = os.path.join(root, fname)
            if file_path != arc_name:
              zipf.write(file_path,
                         os.path.relpath(file_path, td))
      with open(arc_name, "rb") as f:
        model_data = f.read()
    with open(path, "wb") as f:
      dill.dump(model_data, f)
Author: chris-chris | Project: pysc2-examples | Lines: 18 | Source file: dqfd.py

Example 6: learn

# Required import: from baselines.common import tf_util [as an alias]
# Alternatively: from baselines.common.tf_util import save_state [as an alias]
def learn(env, policy_func, dataset, optim_batch_size=128, max_iters=1e4,
          adam_epsilon=1e-5, optim_stepsize=3e-4,
          ckpt_dir=None, log_dir=None, task_name=None,
          verbose=False):

    val_per_iter = int(max_iters/10)
    ob_space = env.observation_space
    ac_space = env.action_space
    pi = policy_func("pi", ob_space, ac_space)  # Construct network for new policy
    # placeholder
    ob = U.get_placeholder_cached(name="ob")
    ac = pi.pdtype.sample_placeholder([None])
    stochastic = U.get_placeholder_cached(name="stochastic")
    loss = tf.reduce_mean(tf.square(ac-pi.ac))
    var_list = pi.get_trainable_variables()
    adam = MpiAdam(var_list, epsilon=adam_epsilon)
    lossandgrad = U.function([ob, ac, stochastic], [loss]+[U.flatgrad(loss, var_list)])

    U.initialize()
    adam.sync()
    logger.log("Pretraining with Behavior Cloning...")
    for iter_so_far in tqdm(range(int(max_iters))):
        ob_expert, ac_expert = dataset.get_next_batch(optim_batch_size, 'train')
        train_loss, g = lossandgrad(ob_expert, ac_expert, True)
        adam.update(g, optim_stepsize)
        if verbose and iter_so_far % val_per_iter == 0:
            ob_expert, ac_expert = dataset.get_next_batch(-1, 'val')
            val_loss, _ = lossandgrad(ob_expert, ac_expert, True)
            logger.log("Training loss: {}, Validation loss: {}".format(train_loss, val_loss))

    if ckpt_dir is None:
        savedir_fname = tempfile.mkdtemp()  # a temp path that outlives this function (TemporaryDirectory().name is removed on garbage collection)
    else:
        savedir_fname = osp.join(ckpt_dir, task_name)
    U.save_state(savedir_fname, var_list=pi.get_variables())
    return savedir_fname 
Author: Hwhitetooth | Project: lirpg | Lines: 38 | Source file: behavior_clone.py
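To reuse the pretrained weights, the returned path is typically fed back through the matching restore call; a hedged sketch (whether this fork's load_state accepts a var_list argument, mirroring its save_state above, is an assumption):

savedir_fname = learn(env, policy_func, dataset, ckpt_dir="/tmp/bc", task_name="bc_pretrain")
# later, in a fresh session, after rebuilding the same policy network as `pi`:
U.load_state(savedir_fname, var_list=pi.get_variables())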

Example 7: train

# Required import: from baselines.common import tf_util [as an alias]
# Alternatively: from baselines.common.tf_util import save_state [as an alias]
def train(num_timesteps, seed, model_path=None):
    env_id = 'Humanoid-v2'
    from baselines.ppo1 import mlp_policy, pposgd_simple
    U.make_session(num_cpu=1).__enter__()
    def policy_fn(name, ob_space, ac_space):
        return mlp_policy.MlpPolicy(name=name, ob_space=ob_space, ac_space=ac_space,
            hid_size=64, num_hid_layers=2)
    env = make_mujoco_env(env_id, seed)

    # parameters below were the best found in a simple random search
    # these are good enough to make humanoid walk, but whether those are
    # an absolute best or not is not certain
    env = RewScale(env, 0.1)
    pi = pposgd_simple.learn(env, policy_fn,
            max_timesteps=num_timesteps,
            timesteps_per_actorbatch=2048,
            clip_param=0.2, entcoeff=0.0,
            optim_epochs=10, 
            optim_stepsize=3e-4, 
            optim_batchsize=64, 
            gamma=0.99, 
            lam=0.95,
            schedule='linear',
        )
    env.close()
    if model_path:
        U.save_state(model_path)
        
    return pi 
Author: MaxSobolMark | Project: HardRLWithYoutube | Lines: 31 | Source file: run_humanoid.py
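A hedged sketch of consuming the checkpoint afterwards, modeled on how such scripts typically render the trained humanoid (the rebuild-then-restore pattern and the pi.act signature follow baselines.ppo1 conventions; treat the details as assumptions):

seed, model_path = 0, "/tmp/humanoid_policy"             # placeholders
pi = train(num_timesteps=1, seed=seed, model_path=None)  # rebuild the graph without real training
U.load_state(model_path)                                 # restore the weights saved above
env = make_mujoco_env('Humanoid-v2', seed)
ob = env.reset()
while True:
    action = pi.act(stochastic=False, ob=ob)[0]
    ob, _, done, _ = env.step(action)
    env.render()
    if done:
        ob = env.reset()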

Example 8: save

# Required import: from baselines.common import tf_util [as an alias]
# Alternatively: from baselines.common.tf_util import save_state [as an alias]
def save(self, save_path):
    tf_util.save_state(save_path, sess=self.sess)
Author: MaxSobolMark | Project: HardRLWithYoutube | Lines: 4 | Source file: policies.py
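The symmetric restore method would look like this (a hedged sketch: it assumes tf_util.load_state accepts the same sess keyword, which older baselines versions provide):

def load(self, load_path):
    tf_util.load_state(load_path, sess=self.sess)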

Example 9: save

# Required import: from baselines.common import tf_util [as an alias]
# Alternatively: from baselines.common.tf_util import save_state [as an alias]
def save(self, path):
        """Save model to a pickle located at `path`"""
        with tempfile.TemporaryDirectory() as td:
            U.save_state(os.path.join(td, "model"))
            arc_name = os.path.join(td, "packed.zip")
            with zipfile.ZipFile(arc_name, 'w') as zipf:
                for root, dirs, files in os.walk(td):
                    for fname in files:
                        file_path = os.path.join(root, fname)
                        if file_path != arc_name:
                            zipf.write(file_path, os.path.relpath(file_path, td))
            with open(arc_name, "rb") as f:
                model_data = f.read()
        with open(path, "wb") as f:
            dill.dump((model_data, self._act_params), f) 
Author: AdamStelmaszczyk | Project: learning2run | Lines: 17 | Source file: simple.py

Example 10: train

# Required import: from baselines.common import tf_util [as an alias]
# Alternatively: from baselines.common.tf_util import save_state [as an alias]
def train(num_timesteps, seed, model_path=None):
    env_id = 'Humanoid-v2'
    from baselines.ppo1 import mlp_policy, pposgd_simple
    U.make_session(num_cpu=1).__enter__()
    def policy_fn(name, ob_space, ac_space):
        return mlp_policy.MlpPolicy(name=name, ob_space=ob_space, ac_space=ac_space,
            hid_size=64, num_hid_layers=2)
    env = make_mujoco_env(env_id, seed)

    # parameters below were the best found in a simple random search
    # these are good enough to make humanoid walk, but whether those are
    # an absolute best or not is not certain
    env = RewScale(env, 0.1)
    pi = pposgd_simple.learn(env, policy_fn,
            max_timesteps=num_timesteps,
            timesteps_per_actorbatch=2048,
            clip_param=0.2, entcoeff=0.0,
            optim_epochs=10,
            optim_stepsize=3e-4,
            optim_batchsize=64,
            gamma=0.99,
            lam=0.95,
            schedule='linear',
        )
    env.close()
    if model_path:
        U.save_state(model_path)

    return pi 
Author: hiwonjoon | Project: ICML2019-TREX | Lines: 31 | Source file: run_humanoid.py


Note: the baselines.common.tf_util.save_state examples in this article were compiled by 纯净天空 from open-source code and documentation platforms such as GitHub and MSDocs. The snippets were selected from open-source projects contributed by their respective authors; copyright remains with the original authors, and distribution and use should follow each project's license. Do not reproduce without permission.