

Python cloudpickle.load Method Code Examples

This article collects typical usage examples of the cloudpickle.load method in Python. If you are wondering what cloudpickle.load does, how to call it, or what real-world uses look like, the curated examples below should help. You can also explore other usage examples from the cloudpickle package itself.


The following section presents 15 code examples of cloudpickle.load, ordered by popularity by default.
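
Before the project-specific examples, here is a minimal, self-contained sketch of the round trip that cloudpickle.load completes. It is illustrative only and not taken from any of the projects below.

import io
import cloudpickle

# cloudpickle can serialize objects the standard pickle module often cannot,
# such as lambdas and locally defined functions.
square = lambda x: x * x

buffer = io.BytesIO()
cloudpickle.dump(square, buffer)

buffer.seek(0)
restored = cloudpickle.load(buffer)  # same call signature as pickle.load
print(restored(4))  # prints 16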

Example 1: load

# Required import: import cloudpickle [as alias]
# Or: from cloudpickle import load [as alias]
def load(cls, directory: str):
        with open(os.path.join(directory, 'elmo.pkl'), 'rb') as f:
            params = cloudpickle.load(f)

        guesser = ElmoGuesser(params['config_num'])
        guesser.class_to_i = params['class_to_i']
        guesser.i_to_class = params['i_to_class']
        guesser.random_seed = params['random_seed']
        guesser.dropout = params['dropout']
        guesser.model = ElmoModel(len(guesser.i_to_class))
        guesser.model.load_state_dict(torch.load(
            os.path.join(directory, 'elmo.pt'), map_location=lambda storage, loc: storage
        ))
        guesser.model.eval()
        if CUDA:
            guesser.model = guesser.model.cuda()
        return guesser 
Developer: Pinafore, Project: qb, Lines: 19, Source: elmo.py
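
The excerpt above only shows the loading side. A plausible counterpart save method (not shown in the source; attribute names such as self.config_num are assumptions) would dump the same keys the loader reads back:

import os
import cloudpickle
import torch

def save(self, directory: str):
        # Hypothetical counterpart to ElmoGuesser.load above: hyperparameters
        # go through cloudpickle, model weights through torch.save.
        with open(os.path.join(directory, 'elmo.pkl'), 'wb') as f:
            cloudpickle.dump({
                'config_num': self.config_num,   # attribute names assumed
                'class_to_i': self.class_to_i,
                'i_to_class': self.i_to_class,
                'random_seed': self.random_seed,
                'dropout': self.dropout,
            }, f)
        torch.save(self.model.state_dict(), os.path.join(directory, 'elmo.pt'))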

Example 2: extend_and_update

# Required import: import cloudpickle [as alias]
# Or: from cloudpickle import load [as alias]
def extend_and_update(self, **train_kwargs) -> int:
        """Extend internal batch of data and train.

        Specifically, this method will load new transitions (if necessary), train
        the model for a while, and advance the round counter. If there are no fresh
        demonstrations in the demonstration directory for the current round, then
        this will raise a `NeedsDemosException` instead of training or advancing
        the round counter. In that case, the user should call
        `.get_trajectory_collector()` and use the returned
        `InteractiveTrajectoryCollector` to produce a new set of demonstrations for
        the current interaction round.

        Arguments:
          **train_kwargs: arguments to pass to `BC.train()`.

        Returns:
          round_num: new round number after advancing the round counter.
        """
        tf.logging.info("Loading demonstrations")
        self._try_load_demos()
        tf.logging.info(f"Training at round {self.round_num}")
        self.bc_trainer.train(**train_kwargs)
        self.round_num += 1
        tf.logging.info(f"New round number is {self.round_num}")
        return self.round_num 
Developer: HumanCompatibleAI, Project: imitation, Lines: 27, Source: dagger.py
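
The docstring describes a collect-then-train cycle. A hypothetical usage loop, where trainer, expert, and total_rounds are placeholders not defined in the source, might look like this:

# Hypothetical DAgger interaction loop; `trainer`, `expert`, and `total_rounds`
# would be defined elsewhere.
for _ in range(total_rounds):
    collector = trainer.get_trajectory_collector()
    # ... roll out episodes through `collector`, e.g. feeding it actions from
    # `expert`, so fresh demonstrations are recorded for the current round ...
    round_num = trainer.extend_and_update(n_epochs=4)  # kwargs are forwarded to BC.train()
    print(f"finished round {round_num}")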

Example 3: _setup_load_operations

# Required import: import cloudpickle [as alias]
# Or: from cloudpickle import load [as alias]
def _setup_load_operations(self):
        """
        Create tensorflow operations for loading model parameters
        """
        # Assume tensorflow graphs are static -> check
        # that we only call this function once
        if self._param_load_ops is not None:
            raise RuntimeError("Parameter load operations have already been created")
        # For each loadable parameter, create appropriate
        # placeholder and an assign op, and store them to
        # self.load_param_ops as dict of variable.name -> (placeholder, assign)
        loadable_parameters = self.get_parameter_list()
        # Use OrderedDict to store order for backwards compatibility with
        # list-based params
        self._param_load_ops = OrderedDict()
        with self.graph.as_default():
            for param in loadable_parameters:
                placeholder = tf.placeholder(dtype=param.dtype, shape=param.shape)
                # param.name is unique (tensorflow variables have unique names)
                self._param_load_ops[param.name] = (placeholder, param.assign(placeholder)) 
Developer: Stable-Baselines-Team, Project: stable-baselines, Lines: 22, Source: base_class.py
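
The method above only builds the (placeholder, assign) pairs. A sketch of how they might be consumed later, assuming a `params` dict mapping variable names to NumPy arrays (this mirrors, but is not copied from, the library's own parameter loading):

# Sketch: feed new values through the stored (placeholder, assign) pairs.
# `self` and `params` (dict of variable name -> np.ndarray) are assumed.
feed_dict = {}
assign_ops = []
for param_name, param_value in params.items():
    placeholder, assign_op = self._param_load_ops[param_name]
    feed_dict[placeholder] = param_value
    assign_ops.append(assign_op)
self.sess.run(assign_ops, feed_dict=feed_dict)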

Example 4: load

# Required import: import cloudpickle [as alias]
# Or: from cloudpickle import load [as alias]
def load(cls, load_path, env=None, custom_objects=None, **kwargs):
        """
        Load the model from file

        :param load_path: (str or file-like) the saved parameter location
        :param env: (Gym Environment) the new environment to run the loaded model on
            (can be None if you only need prediction from a trained model)
        :param custom_objects: (dict) Dictionary of objects to replace
            upon loading. If a variable is present in this dictionary as a
            key, it will not be deserialized and the corresponding item
            will be used instead. Similar to custom_objects in
            `keras.models.load_model`. Useful when you have an object in
            file that can not be deserialized.
        :param kwargs: extra arguments to change the model when loading
        """
        raise NotImplementedError() 
Developer: Stable-Baselines-Team, Project: stable-baselines, Lines: 18, Source: base_class.py
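
A short hypothetical usage of the custom_objects mechanism described in the docstring, assuming a concrete subclass such as PPO2 and a previously saved file named ppo2_cartpole.zip:

import gym
from stable_baselines import PPO2

env = gym.make("CartPole-v1")
# Replace a stored value instead of deserializing it from the file,
# e.g. swap in a new learning rate at load time.
model = PPO2.load("ppo2_cartpole.zip", env=env,
                  custom_objects={"learning_rate": 1e-4})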

Example 5: _load_from_file_cloudpickle

# Required import: import cloudpickle [as alias]
# Or: from cloudpickle import load [as alias]
def _load_from_file_cloudpickle(load_path):
        """Legacy code for loading older models stored with cloudpickle

        :param load_path: (str or file-like) where from to load the file
        :return: (dict, OrderedDict) Class parameters and model parameters
        """
        if isinstance(load_path, str):
            if not os.path.exists(load_path):
                if os.path.exists(load_path + ".pkl"):
                    load_path += ".pkl"
                else:
                    raise ValueError("Error: the file {} could not be found".format(load_path))

            with open(load_path, "rb") as file_:
                data, params = cloudpickle.load(file_)
        else:
            # Here load_path is a file-like object, not a path
            data, params = cloudpickle.load(load_path)

        return data, params 
Developer: Stable-Baselines-Team, Project: stable-baselines, Lines: 22, Source: base_class.py
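
For symmetry, the legacy saver that this loader expects would need to cloudpickle a (data, params) tuple. A sketch, not copied from the source:

import os
import cloudpickle

def _save_to_file_cloudpickle(save_path, data=None, params=None):
        # Hypothetical legacy counterpart: write the (data, params) tuple so
        # that _load_from_file_cloudpickle above can read it back.
        if isinstance(save_path, str):
            _, ext = os.path.splitext(save_path)
            if ext == "":
                save_path += ".pkl"
            with open(save_path, "wb") as file_:
                cloudpickle.dump((data, params), file_)
        else:
            # Here save_path is a file-like object, not a path
            cloudpickle.dump((data, params), save_path)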

Example 6: load

# Required import: import cloudpickle [as alias]
# Or: from cloudpickle import load [as alias]
def load(path):
        with open(path, "rb") as f:
            model_data, act_params = cloudpickle.load(f)
        act = deepq.build_act(**act_params)
        tf_config = tf.ConfigProto()
        tf_config.gpu_options.allow_growth = True
        sess = tf.Session(config=tf_config)
        sess.__enter__()
        with tempfile.TemporaryDirectory() as td:
            arc_path = os.path.join(td, "packed.zip")
            with open(arc_path, "wb") as f:
                f.write(model_data)

            zipfile.ZipFile(arc_path, 'r', zipfile.ZIP_DEFLATED).extractall(td)
            load_state(os.path.join(td, "model"))

        return ActWrapper(act, act_params) 
Developer: ArztSamuel, Project: DRL_DeliveryDuel, Lines: 19, Source: simple.py
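
The loader above unpacks a (model_data, act_params) tuple. A plausible counterpart save method (save_state and self._act_params are assumptions inferred from the loader) could produce that blob as follows:

import os
import tempfile
import zipfile
import cloudpickle

def save(self, path):
        # Hypothetical counterpart to ActWrapper.load above: checkpoint the TF
        # session, zip the checkpoint, and cloudpickle (model_data, act_params).
        with tempfile.TemporaryDirectory() as td:
            save_state(os.path.join(td, "model"))  # assumed sibling of load_state
            arc_path = os.path.join(td, "packed.zip")
            with zipfile.ZipFile(arc_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
                for root, _, files in os.walk(td):
                    for fname in files:
                        file_path = os.path.join(root, fname)
                        if file_path != arc_path:
                            zipf.write(file_path, os.path.relpath(file_path, td))
            with open(arc_path, "rb") as f:
                model_data = f.read()
        with open(path, "wb") as f:
            cloudpickle.dump((model_data, self._act_params), f)  # assumed attribute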

Example 7: load

# Required import: import cloudpickle [as alias]
# Or: from cloudpickle import load [as alias]
def load(self, models_dir):
        try:
            del self.model

            tf.keras.backend.clear_session()

            self.model = tf.keras.models.load_model(
                os.path.join(models_dir, "tf_intent_model.hd5"), compile=True)

            self.graph = tf.get_default_graph()

            print("Tf model loaded")

            with open(os.path.join(models_dir, "labels.pkl"), 'rb') as f:
                self.label_encoder = cloudpickle.load(f)
                print("Labels model loaded")

        except IOError:
            return False 
Developer: alfredfrancis, Project: ai-chatbot-framework, Lines: 21, Source: tf_intent_classifer.py
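
A plausible counterpart save method (not shown in the excerpt) would write the two artifacts that load() reads back:

import os
import cloudpickle

def save(self, models_dir):
        # Hypothetical counterpart to load() above: persist the Keras model and
        # cloudpickle the label encoder under the file names the loader expects.
        self.model.save(os.path.join(models_dir, "tf_intent_model.hd5"))
        with open(os.path.join(models_dir, "labels.pkl"), 'wb') as f:
            cloudpickle.dump(self.label_encoder, f)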

Example 8: _load_model

# Required import: import cloudpickle [as alias]
# Or: from cloudpickle import load [as alias]
def _load_model(model_path, keras_module, **kwargs):
    keras_models = importlib.import_module(keras_module.__name__ + ".models")
    custom_objects = kwargs.pop("custom_objects", {})
    custom_objects_path = None
    if os.path.isdir(model_path):
        if os.path.isfile(os.path.join(model_path, _CUSTOM_OBJECTS_SAVE_PATH)):
            custom_objects_path = os.path.join(model_path, _CUSTOM_OBJECTS_SAVE_PATH)
        model_path = os.path.join(model_path, _MODEL_SAVE_PATH)
    if custom_objects_path is not None:
        import cloudpickle
        with open(custom_objects_path, "rb") as in_f:
            pickled_custom_objects = cloudpickle.load(in_f)
            pickled_custom_objects.update(custom_objects)
            custom_objects = pickled_custom_objects
    from distutils.version import StrictVersion
    if StrictVersion(keras_module.__version__.split('-')[0]) >= StrictVersion("2.2.3"):
        # NOTE: Keras 2.2.3 does not work with unicode paths in python2. Pass in h5py.File instead
        # of string to avoid issues.
        import h5py
        with h5py.File(os.path.abspath(model_path), "r") as model_path:
            return keras_models.load_model(model_path, custom_objects=custom_objects, **kwargs)
    else:
        # NOTE: Older versions of Keras only handle filepath.
        return keras_models.load_model(model_path, custom_objects=custom_objects, **kwargs) 
Developer: mlflow, Project: mlflow, Lines: 26, Source: keras.py
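
The loader only restores custom objects if a pickled file exists at _CUSTOM_OBJECTS_SAVE_PATH. A minimal sketch of the saving side, assuming the same constant and directory layout (an illustration, not mlflow's actual implementation):

import os
import cloudpickle

def _save_custom_objects(path, custom_objects):
    # Serialize user-supplied custom objects (layers, metrics, ...) so the
    # loader above can restore them before calling keras_models.load_model.
    custom_objects_path = os.path.join(path, _CUSTOM_OBJECTS_SAVE_PATH)
    with open(custom_objects_path, "wb") as out_f:
        cloudpickle.dump(custom_objects, out_f)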

Example 9: load

# Required import: import cloudpickle [as alias]
# Or: from cloudpickle import load [as alias]
def load(path):
        with open(path, "rb") as f:
            model_data, act_params = cloudpickle.load(f)
        act = deepq.build_act(**act_params)
        sess = tf.Session()
        sess.__enter__()
        with tempfile.TemporaryDirectory() as td:
            arc_path = os.path.join(td, "packed.zip")
            with open(arc_path, "wb") as f:
                f.write(model_data)

            zipfile.ZipFile(arc_path, 'r', zipfile.ZIP_DEFLATED).extractall(td)
            load_state(os.path.join(td, "model"))

        return ActWrapper(act, act_params) 
Developer: Hwhitetooth, Project: lirpg, Lines: 17, Source: simple.py

Example 10: load_act

# Required import: import cloudpickle [as alias]
# Or: from cloudpickle import load [as alias]
def load_act(path):
        with open(path, "rb") as f:
            model_data, act_params = cloudpickle.load(f)
        act = deepq.build_act(**act_params)
        sess = tf.Session()
        sess.__enter__()
        with tempfile.TemporaryDirectory() as td:
            arc_path = os.path.join(td, "packed.zip")
            with open(arc_path, "wb") as f:
                f.write(model_data)

            zipfile.ZipFile(arc_path, 'r', zipfile.ZIP_DEFLATED).extractall(td)
            load_state(os.path.join(td, "model"))

        return ActWrapper(act, act_params) 
Developer: MaxSobolMark, Project: HardRLWithYoutube, Lines: 17, Source: deepq.py

Example 11: load

# Required import: import cloudpickle [as alias]
# Or: from cloudpickle import load [as alias]
def load(cls, directory: str):
        with open(os.path.join(directory, 'rnn.pkl'), 'rb') as f:
            params = cloudpickle.load(f)

        guesser = RnnGuesser(params['config_num'])
        guesser.page_field = params['page_field']
        guesser.qanta_id_field = params['qanta_id_field']

        guesser.text_field = params['text_field']

        guesser.n_classes = params['n_classes']
        guesser.gradient_clip = params['gradient_clip']
        guesser.n_hidden_units = params['n_hidden_units']
        guesser.n_hidden_layers = params['n_hidden_layers']
        guesser.nn_dropout = params['nn_dropout']
        guesser.use_wiki = params['use_wiki']
        guesser.n_wiki_sentences = params['n_wiki_sentences']
        guesser.wiki_title_replace_token = params['wiki_title_replace_token']
        guesser.lowercase = params['lowercase']
        guesser.random_seed = params['random_seed']
        guesser.model = RnnModel(
            guesser.n_classes,
            text_field=guesser.text_field,
            init_embeddings=False, emb_dim=300,
            n_hidden_layers=guesser.n_hidden_layers,
            n_hidden_units=guesser.n_hidden_units
        )
        guesser.model.load_state_dict(torch.load(
            os.path.join(directory, 'rnn.pt'), map_location=lambda storage, loc: storage
        ))
        guesser.model.eval()
        if CUDA:
            guesser.model = guesser.model.cuda()
        return guesser 
Developer: Pinafore, Project: qb, Lines: 36, Source: rnn.py

Example 12: _load_trajectory

# Required import: import cloudpickle [as alias]
# Or: from cloudpickle import load [as alias]
def _load_trajectory(npz_path: str) -> types.Trajectory:
    """Load a single trajectory from a compressed Numpy file."""
    np_data = np.load(npz_path, allow_pickle=True)
    has_rew = "rews" in np_data
    cls = types.TrajectoryWithRew if has_rew else types.Trajectory
    return cls(**dict(np_data.items())) 
Developer: HumanCompatibleAI, Project: imitation, Lines: 8, Source: dagger.py

Example 13: reconstruct_trainer

# Required import: import cloudpickle [as alias]
# Or: from cloudpickle import load [as alias]
def reconstruct_trainer(cls, scratch_dir: str) -> "DAggerTrainer":
        """Reconstruct trainer from the latest snapshot in some working directory.

        Args:
          scratch_dir: path to the working directory created by a previous run of
            this algorithm. The directory should contain a
            `checkpoint-latest.pkl` file.

        Returns:
          trainer: a reconstructed `DAggerTrainer` with the same state as the
            previously-saved one.
        """
        checkpoint_path = os.path.join(scratch_dir, "checkpoint-latest.pkl")
        with open(checkpoint_path, "rb") as fp:
            saved_trainer = cloudpickle.load(fp)
        # reconstruct from old init args
        trainer = cls(**saved_trainer["init_args"])
        # set TF variables
        set_tf_vars(
            values=saved_trainer["variable_values"],
            tf_vars=trainer._vars,
            sess=trainer.bc_trainer.sess,
        )
        for attr_name in cls.SAVE_ATTRS:
            attr_value = saved_trainer["saved_attrs"][attr_name]
            setattr(trainer, attr_name, attr_value)
        return trainer 
Developer: HumanCompatibleAI, Project: imitation, Lines: 29, Source: dagger.py
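
The loader reads a dict with the keys "init_args", "variable_values", and "saved_attrs". A hypothetical saving counterpart (attribute names such as self._init_args and self.scratch_dir are assumptions) would be:

import os
import cloudpickle

def save_trainer(self):
        # Hypothetical counterpart to reconstruct_trainer(): dump a checkpoint
        # with exactly the keys the loader reads back.
        checkpoint = {
            "init_args": self._init_args,  # assumed: constructor kwargs recorded at __init__
            "variable_values": self.bc_trainer.sess.run(self._vars),
            "saved_attrs": {name: getattr(self, name) for name in self.SAVE_ATTRS},
        }
        with open(os.path.join(self.scratch_dir, "checkpoint-latest.pkl"), "wb") as fp:
            cloudpickle.dump(checkpoint, fp)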

Example 14: set_tf_vars

# Required import: import cloudpickle [as alias]
# Or: from cloudpickle import load [as alias]
def set_tf_vars(
    *,
    values: List[np.ndarray],
    scope: Optional[str] = None,
    tf_vars: Optional[List[tf.Variable]] = None,
    sess: Optional[tf.Session] = None,
):
    """Set a collection of variables to take the values in `values`.

    Variables can be either specified by scope or passed directly into the
    function as a list. Variables and values will be matched based on the order
    in which they appear in their respective collections, so there must be as
    many values as variables.

    Args:
        values: list of values to load into variables.
        scope: scope to collect variables from. Either this argument xor
          `tf_vars` must be given.
        tf_vars: explicit list of TF variables to write to. Mutex with `scope`.
        sess: TF session to use, if not the default.
    """
    if scope is not None:
        assert tf_vars is None, "must give either `tf_vars` xor `scope` kwargs"
        tf_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=scope)
    else:
        assert tf_vars is not None, "must give either `tf_vars` xor `scope` kwargs"
    assert len(tf_vars) == len(values), (
        f"{len(tf_vars)} tf variables but " f"{len(values)} values supplied"
    )
    sess = sess or tf.get_default_session()
    assert sess is not None, "must supply session or have one in context"
    placeholders = [tf.placeholder(shape=v.shape, dtype=v.dtype) for v in tf_vars]
    assign_ops = [tf.assign(var, ph) for var, ph in zip(tf_vars, placeholders)]
    sess.run(
        assign_ops, feed_dict={ph: value for ph, value in zip(placeholders, values)}
    ) 
Developer: HumanCompatibleAI, Project: imitation, Lines: 38, Source: bc.py
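
A minimal usage sketch for set_tf_vars, assuming TF1.x graph mode; the variable names and shapes here are made up for illustration:

import numpy as np
import tensorflow as tf

with tf.Session() as sess:
    with tf.variable_scope("demo"):
        tf.get_variable("w", initializer=tf.zeros([2, 2]))
        tf.get_variable("b", initializer=tf.zeros([2]))
    sess.run(tf.global_variables_initializer())
    # Values are matched to variables in collection order: "w" first, then "b".
    set_tf_vars(
        values=[np.ones((2, 2), dtype=np.float32), np.ones(2, dtype=np.float32)],
        scope="demo",
        sess=sess,
    )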

Example 15: reconstruct_policy

# Required import: import cloudpickle [as alias]
# Or: from cloudpickle import load [as alias]
def reconstruct_policy(
        policy_path: str, sess: Optional[tf.Session] = None,
    ) -> BasePolicy:
        """Reconstruct a saved policy.

        Args:
            policy_path: path a policy produced by `.save_policy()`.
            sess: optional session to construct policy under,
              if not the default session.

        Returns:
            policy: policy with reloaded weights.
        """
        if sess is None:
            sess = tf.get_default_session()
            assert sess is not None, "must supply session via kwarg or context mgr"

        # re-read data from dict
        with open(policy_path, "rb") as fp:
            loaded_pickle = cloudpickle.load(fp)

        # construct the policy class
        klass = loaded_pickle["class"]
        kwargs = loaded_pickle["kwargs"]
        with tf.variable_scope("reconstructed_policy"):
            rv_pol = klass(sess=sess, **kwargs)
            inner_scope = tf.get_variable_scope().name

        # set values for the new policy's parameters
        param_values = loaded_pickle["params"]
        set_tf_vars(values=param_values, scope=inner_scope, sess=sess)

        return rv_pol 
Developer: HumanCompatibleAI, Project: imitation, Lines: 35, Source: bc.py
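
reconstruct_policy expects a pickled dict with "class", "kwargs", and "params" keys. A hypothetical saving counterpart (how the constructor kwargs and the variable scope are obtained is assumed, not taken from the source):

import cloudpickle
import tensorflow as tf

def save_policy(policy, policy_kwargs: dict, scope: str, policy_path: str, sess: tf.Session):
    # Hypothetical counterpart to reconstruct_policy(): `policy_kwargs` are the
    # constructor arguments minus `sess` (the loader re-supplies it), and
    # `scope` is the variable scope the policy was built under.
    policy_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=scope)
    param_values = sess.run(policy_vars)  # same ordering that set_tf_vars relies on
    payload = {"class": type(policy), "kwargs": policy_kwargs, "params": param_values}
    with open(policy_path, "wb") as fp:
        cloudpickle.dump(payload, fp)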


Note: the cloudpickle.load examples in this article were compiled by 纯净天空 from open-source code and documentation platforms such as GitHub and MSDocs. The snippets are taken from open-source projects contributed by their authors; copyright remains with the original authors, and distribution and use should follow each project's License. Do not repost without permission.