当前位置: 首页>>代码示例>>Python>>正文


Python pywrap_tensorflow.NewCheckpointReader方法代码示例

本文汇总了Python中tensorflow.python.pywrap_tensorflow.NewCheckpointReader方法的典型用法代码示例。如果您想了解pywrap_tensorflow.NewCheckpointReader方法的具体用法和使用场景,这里精选的代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块tensorflow.python.pywrap_tensorflow的其他用法示例。


在下文中一共展示了pywrap_tensorflow.NewCheckpointReader方法的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: print_tensors_in_checkpoint_file

# 需要导入模块: from tensorflow.python import pywrap_tensorflow [as 别名]
# 或者: from tensorflow.python.pywrap_tensorflow import NewCheckpointReader [as 别名]
def print_tensors_in_checkpoint_file(file_name, tensor_name, all_tensors):
    """Print tensors stored in a checkpoint file.

    Args:
        file_name: Path (or prefix) of the checkpoint to read.
        tensor_name: Name of a single tensor to print. When empty and
            `all_tensors` is false, the reader's debug string is printed
            instead.
        all_tensors: If true, print every tensor in the checkpoint in sorted
            name order; `tensor_name` is ignored.

    Any exception raised while reading is printed instead of propagated.
    """
    try:
        reader = pywrap_tensorflow.NewCheckpointReader(file_name)
        if all_tensors:
            var_to_shape_map = reader.get_variable_to_shape_map()
            # Sort so repeated runs list tensors in a stable, readable order.
            for key in sorted(var_to_shape_map):
                print("tensor_name: ", key)
                print(reader.get_tensor(key))
        elif not tensor_name:
            print(reader.debug_string().decode("utf-8"))
        else:
            print("tensor_name: ", tensor_name)
            print(reader.get_tensor(tensor_name))
    except Exception as e:  # pylint: disable=broad-except
        message = str(e)
        print(message)
        if "corrupted compressed block contents" in message:
            print("It's likely that your checkpoint file has been compressed "
                  "with SNAPPY.")
开发者ID:HiKapok,项目名称:tf.fashionAI,代码行数:20,代码来源:inspect_checkpoint.py

示例2: resore_form_rl_net

# 需要导入模块: from tensorflow.python import pywrap_tensorflow [as 别名]
# 或者: from tensorflow.python.pywrap_tensorflow import NewCheckpointReader [as 别名]
def resore_form_rl_net(self,ckpt_name, graph, sess):
        """Copy policy-network weights from a checkpoint into tensors of `graph`.

        Every checkpoint key is printed. Only keys under
        'pi/pi/net/' or 'pi/pi/Trajectory_follower_mlp_net/' are restored:
        the leading 'pi/pi/' scope is removed and ':0' appended to find the
        matching tensor in `graph`, which is then assigned via `sess`.
        Keys that cannot be restored are reported and skipped.
        """
        print("Restore form RL net")

        print("===== Prase data from %s =====" % ckpt_name)
        net_prefix = 'pi/pi'
        wanted_scopes = ('%s/net/' % net_prefix,
                         '%s/Trajectory_follower_mlp_net/' % net_prefix)
        reader = pywrap_tensorflow.NewCheckpointReader(ckpt_name)
        for ckpt_key in reader.get_variable_to_shape_map():
            print(ckpt_key)
            if not str(ckpt_key).startswith(wanted_scopes):
                continue
            # Everything after the second '/' (i.e. without "pi/pi/") is the
            # name of the corresponding tensor in this graph.
            graph_key = ckpt_key.split('/', 2)[2] + ":0"
            try:
                target = graph.get_tensor_by_name(graph_key)
                sess.run(tf.assign(target, reader.get_tensor(ckpt_key)))
            except Exception as err:
                print(graph_key, " can not be restored, e= ", str(err))
开发者ID:hku-mars,项目名称:crossgap_il_rl,代码行数:26,代码来源:tf_policy_network.py

示例3: resort_para_form_checkpoint

# 需要导入模块: from tensorflow.python import pywrap_tensorflow [as 别名]
# 或者: from tensorflow.python.pywrap_tensorflow import NewCheckpointReader [as 别名]
def resort_para_form_checkpoint(self, _ckpt_name_vec, graph, sess):
        """Best-effort restore of graph tensors from one or more checkpoints.

        Args:
            _ckpt_name_vec: A single checkpoint path, or a list/tuple of them.
                Checkpoints are applied in order, so later ones overwrite
                values restored from earlier ones.
            graph: Graph in which each checkpoint key `k` is looked up as the
                tensor named `k + ":0"`.
            sess: Session used to run the assignments.

        Keys that cannot be restored are skipped silently. Examples include a
        missing tensor in `graph` or an assignment that fails.
        """
        if isinstance(_ckpt_name_vec, (list, tuple)):
            ckpt_name_vec = list(_ckpt_name_vec)
        else:
            ckpt_name_vec = [_ckpt_name_vec]

        with tf.name_scope("restore"):
            for ckpt_name in ckpt_name_vec:
                print("===== Restore data from %s =====" % ckpt_name)
                reader = pywrap_tensorflow.NewCheckpointReader(ckpt_name)
                var_to_shape_map = reader.get_variable_to_shape_map()
                for key in var_to_shape_map:
                    try:
                        tensor = graph.get_tensor_by_name(key + ":0")
                        sess.run(tf.assign(tensor, reader.get_tensor(key)))
                    except Exception:  # pylint: disable=broad-except
                        # Deliberate best effort. Unlike the previous bare
                        # `except:`, this no longer swallows KeyboardInterrupt
                        # or SystemExit.
                        continue
开发者ID:hku-mars,项目名称:crossgap_il_rl,代码行数:25,代码来源:tf_policy_network.py

示例4: create_model_from_npz_file

# 需要导入模块: from tensorflow.python import pywrap_tensorflow [as 别名]
# 或者: from tensorflow.python.pywrap_tensorflow import NewCheckpointReader [as 别名]
def create_model_from_npz_file(npz, model, target):
    """Creates a tensorflow model from a given npz structure in which the variables for the desired model are stored.

    A `tf.Variable` is created for each variable from `npz` whose name
    `key_contained_in_map` accepts for the target checkpoint's variable map.
    All global variables are then initialized and saved with
    `tf.train.Saver`.

    Args:
        npz: Path to the npz structure containing files representing the
            variables in the model.
        model: Path in which the final model should be stored.
        target: A target model which contains the desired names for the
            structure.
    """
    reader = pywrap_tensorflow.NewCheckpointReader(target)
    target_map = reader.get_variable_to_shape_map()

    variables = variables_dictionary_from_npz_file(npz)
    for key in variables:
        if key_contained_in_map(key, target_map):
            # Constructing the variable registers it in the default graph,
            # which is all tf.train.Saver() needs. The previous
            # exec("varN = val") bound nothing usable and is removed.
            tf.Variable(variables[key], name=key)

    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        save_path = saver.save(sess, model)
        print("Model saved in file: %s" % save_path)
开发者ID:cgtuebingen,项目名称:will-people-like-your-image,代码行数:26,代码来源:npz_file_to_checkpoint.py

示例5: print_all_in_ckpt

# 需要导入模块: from tensorflow.python import pywrap_tensorflow [as 别名]
# 或者: from tensorflow.python.pywrap_tensorflow import NewCheckpointReader [as 别名]
def print_all_in_ckpt(ckpt_path):
    """Print the name and shape of every variable in the checkpoint at `ckpt_path`.

    Shapes are read from the checkpoint's variable-to-shape map, so the
    (possibly large) tensor data is never loaded.
    """
    reader = pywrap_tensorflow.NewCheckpointReader(ckpt_path)
    var_to_shape_map = reader.get_variable_to_shape_map()
    for key in var_to_shape_map:
        # tuple() keeps the "(d0, d1, ...)" format that ndarray.shape printed.
        print('{} {}\n'.format(key, tuple(var_to_shape_map[key])))
开发者ID:fab-jul,项目名称:imgcomp-cvpr,代码行数:8,代码来源:save_mapper.py

示例6: get_all_variable_names

# 需要导入模块: from tensorflow.python import pywrap_tensorflow [as 别名]
# 或者: from tensorflow.python.pywrap_tensorflow import NewCheckpointReader [as 别名]
def get_all_variable_names(ckpt_path):
    """Return a list of all variable names stored in the checkpoint at `ckpt_path`."""
    shape_by_name = pywrap_tensorflow.NewCheckpointReader(
        ckpt_path).get_variable_to_shape_map()
    return [name for name in shape_by_name]
开发者ID:fab-jul,项目名称:imgcomp-cvpr,代码行数:6,代码来源:save_mapper.py

示例7: getTransformEstimator

# 需要导入模块: from tensorflow.python import pywrap_tensorflow [as 别名]
# 或者: from tensorflow.python.pywrap_tensorflow import NewCheckpointReader [as 别名]
def getTransformEstimator(cls, priors: Tuple[Distribution, ...], K: int,
                              chptFile: str, dtype: tf.DType = tf.float32,
                              noiseUniformity: NoiseUniformity = HOMOGENEOUS,
                              stopCriterionInit=LlhStall(10),
                              stopCriterionEM=LlhStall(100),
                              stopCriterionBCD=LlhImprovementThreshold(1e-2),
                              path: str = "/tmp", device: str = "/cpu:0"):
        """Build an Estimator in transform mode, warm-started from `chptFile`.

        Every variable in `chptFile` is warm-started except "U/0",
        "global_step", "stop" and anything under a
        "stopCriterion<PHASE>/" scope for the INIT, EM and BCD phases.
        The selected names are joined with "|" into the `vars_to_warm_start`
        pattern of `tf.estimator.WarmStartSettings`.
        """
        checkpointVars = pywrap_tensorflow.NewCheckpointReader(
            chptFile).get_variable_to_shape_map()

        skippedNames = {"U/0", "global_step", "stop"}
        skippedScopes = tuple(f"stopCriterion{phase.name}/"
                              for phase in (Phase.INIT, Phase.EM, Phase.BCD))
        warmNames = [name for name in checkpointVars
                     if name not in skippedNames
                     and not name.startswith(skippedScopes)]
        ws = tf.estimator.WarmStartSettings(
            ckpt_to_initialize_from=chptFile,
            vars_to_warm_start="|".join(warmNames))

        def model_fn(features, labels, mode):
            return cls.__estimatorSpec(
                mode=mode, features=features, isFullyObserved=True,
                device=device, priors=priors,
                noiseUniformity=noiseUniformity,
                stopCriterionInit=stopCriterionInit,
                stopCriterionEM=stopCriterionEM,
                stopCriterionBCD=stopCriterionBCD,
                K=K, path=path, cv=None, transform=True, dtype=dtype)

        return tf.estimator.Estimator(model_fn=model_fn, model_dir=path,
                                      warm_start_from=ws)
开发者ID:bethgelab,项目名称:decompose,代码行数:38,代码来源:tensorFactorisation.py

示例8: fit_transform

# 需要导入模块: from tensorflow.python import pywrap_tensorflow [as 别名]
# 或者: from tensorflow.python.pywrap_tensorflow import NewCheckpointReader [as 别名]
def fit_transform(self, X: np.ndarray) -> np.ndarray:
        """Fit on `X`, then return the "U/0" tensor from the latest checkpoint."""
        self.fit(X)
        latestCkpt = self.__tefa.latest_checkpoint()
        return pywrap_tensorflow.NewCheckpointReader(latestCkpt).get_tensor("U/0")
开发者ID:bethgelab,项目名称:decompose,代码行数:8,代码来源:sklearn.py

示例9: transform

# 需要导入模块: from tensorflow.python import pywrap_tensorflow [as 别名]
# 或者: from tensorflow.python.pywrap_tensorflow import NewCheckpointReader [as 别名]
def transform(self, X: np.ndarray,
                  transformModelDirectory: str) -> np.ndarray:
        """Transform `X` using a model warm-started from the fitted checkpoint.

        A transform estimator is created in `transformModelDirectory`,
        warm-started from the latest checkpoint of the fitted model, and
        trained on `X` as a single full batch. The "U/0tr" tensor from the
        transform estimator's latest checkpoint is returned.
        """
        features = {"test": X.astype(self.__dtype)}
        input_fn = tf.estimator.inputs.numpy_input_fn(
            features, y=None, batch_size=X.shape[0],
            shuffle=False, num_epochs=None)

        fittedCkpt = self.__tefa.latest_checkpoint()
        transformer = TensorFactorisation.getTransformEstimator(
            priors=self.__priors,
            K=self.n_components,
            dtype=tf.as_dtype(self.__dtype),
            path=transformModelDirectory,
            chptFile=fittedCkpt,
            noiseUniformity=self.noiseUniformity,
            stopCriterionInit=self.__stopCriterionInit,
            stopCriterionEM=self.__stopCriterionEM,
            stopCriterionBCD=self.__stopCriterionBCD)
        transformer.train(input_fn=input_fn,
                          steps=self.__maxIterations,
                          hooks=[StopHook()])

        transformedCkpt = transformer.latest_checkpoint()
        return pywrap_tensorflow.NewCheckpointReader(
            transformedCkpt).get_tensor("U/0tr")
开发者ID:bethgelab,项目名称:decompose,代码行数:28,代码来源:sklearn.py

示例10: print_checkpoint

# 需要导入模块: from tensorflow.python import pywrap_tensorflow [as 别名]
# 或者: from tensorflow.python.pywrap_tensorflow import NewCheckpointReader [as 别名]
def print_checkpoint(path:str, name:str):
    """Print one tensor (when `name` is given) or every tensor of a checkpoint.

    `path` is resolved with `get_model_name` first. Without `name`, every
    variable is printed in sorted order as "key:shape" followed by its value.
    """
    reader = pywrap_tensorflow.NewCheckpointReader(get_model_name(path))
    if name:
        print(name)
        print(reader.get_tensor(name))
        return

    for key in sorted(reader.get_variable_to_shape_map()):
        value = reader.get_tensor(key)
        print('{}:{}'.format(key, value.shape))
        print(value)
开发者ID:hcmlab,项目名称:vadnet,代码行数:15,代码来源:checkpoint.py

示例11: get_var_from_checkpoint

# 需要导入模块: from tensorflow.python import pywrap_tensorflow [as 别名]
# 或者: from tensorflow.python.pywrap_tensorflow import NewCheckpointReader [as 别名]
def get_var_from_checkpoint(path:str, tensor_name:str):
    """Return the value of `tensor_name` from the checkpoint that `path` resolves to."""
    checkpoint = get_model_name(path)
    return pywrap_tensorflow.NewCheckpointReader(checkpoint).get_tensor(tensor_name)
开发者ID:hcmlab,项目名称:vadnet,代码行数:7,代码来源:checkpoint.py

示例12: get_layers

# 需要导入模块: from tensorflow.python import pywrap_tensorflow [as 别名]
# 或者: from tensorflow.python.pywrap_tensorflow import NewCheckpointReader [as 别名]
def get_layers(path):
    """Return the variable-name -> shape map of a checkpoint.

    Args:
        path: A checkpoint path/prefix, or a directory whose checkpoint
            state names the checkpoint to read.

    Raises:
        ValueError: If `path` is a directory without any checkpoint state.
    """
    if os.path.isdir(path):
        state = tf.train.get_checkpoint_state(path)
        # get_checkpoint_state returns None when the directory holds no
        # checkpoint; fail with a clear message instead of an AttributeError.
        if state is None:
            raise ValueError('No checkpoint found in directory: {}'.format(path))
        path = state.model_checkpoint_path

    reader = pywrap_tensorflow.NewCheckpointReader(path)
    return reader.get_variable_to_shape_map()
开发者ID:hcmlab,项目名称:vadnet,代码行数:9,代码来源:weights.py

示例13: get_weights

# 需要导入模块: from tensorflow.python import pywrap_tensorflow [as 别名]
# 或者: from tensorflow.python.pywrap_tensorflow import NewCheckpointReader [as 别名]
def get_weights(path, name):
    """Return the value of tensor `name` from a checkpoint.

    Args:
        path: A checkpoint path/prefix, or a directory whose checkpoint
            state names the checkpoint to read.
        name: Name of the tensor stored in the checkpoint.

    Raises:
        ValueError: If `path` is a directory without any checkpoint state.
    """
    if os.path.isdir(path):
        state = tf.train.get_checkpoint_state(path)
        # get_checkpoint_state returns None when the directory holds no
        # checkpoint; fail with a clear message instead of an AttributeError.
        if state is None:
            raise ValueError('No checkpoint found in directory: {}'.format(path))
        path = state.model_checkpoint_path

    reader = pywrap_tensorflow.NewCheckpointReader(path)
    return reader.get_tensor(name)
开发者ID:hcmlab,项目名称:vadnet,代码行数:9,代码来源:weights.py

示例14: load_inception_resnet_v2

# 需要导入模块: from tensorflow.python import pywrap_tensorflow [as 别名]
# 或者: from tensorflow.python.pywrap_tensorflow import NewCheckpointReader [as 别名]
def load_inception_resnet_v2(checkpoint_path, enable_aux=False):
    """Build an InceptionResnetV2 model and load TensorFlow checkpoint weights into it.

    Args:
        checkpoint_path: Path of the TensorFlow checkpoint to read.
        enable_aux: Forwarded to `InceptionResnetV2`.

    Returns:
        The model with parameters loaded from the 'InceptionResnetV2' scope
        of the checkpoint.

    Raises:
        RuntimeError: If tensorflow could not be imported.
    """
    # Fail fast: without tensorflow the checkpoint cannot be read, so do not
    # build the model and run a forward pass first.
    if _tf_import_error is not None:
        raise RuntimeError('could not import tensorflow; the import error as follows:\n' + str(_tf_import_error))
    model = InceptionResnetV2(enable_aux=enable_aux)
    with chainer.no_backprop_mode():
        # One dummy forward pass on a (2, 3, 299, 299) batch initializes params.
        model(np.random.randn(2, 3, 299, 299).astype('f'))
    reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
    model.load_tf_checkpoint(reader, 'InceptionResnetV2')
    return model
开发者ID:pfnet-research,项目名称:nips17-adversarial-attack,代码行数:11,代码来源:inception_resnet_v2.py

示例15: print_tensors_in_checkpoint_file

# 需要导入模块: from tensorflow.python import pywrap_tensorflow [as 别名]
# 或者: from tensorflow.python.pywrap_tensorflow import NewCheckpointReader [as 别名]
def print_tensors_in_checkpoint_file(file_name, tensor_name, all_tensors):
  """Prints tensors in a checkpoint file.

  If no `tensor_name` is provided, prints the reader's debug string (the
  tensor names and shapes in the checkpoint file).

  If `tensor_name` is provided, prints the content of the tensor.

  Any exception raised while reading is printed instead of propagated, along
  with hints for common causes.

  Args:
    file_name: Name of the checkpoint file.
    tensor_name: Name of the tensor in the checkpoint file to print.
    all_tensors: Boolean indicating whether to print all tensors.
  """
  try:
    reader = pywrap_tensorflow.NewCheckpointReader(file_name)
    if all_tensors:
      var_to_shape_map = reader.get_variable_to_shape_map()
      for key in sorted(var_to_shape_map):
        print("tensor_name: ", key)
        print(reader.get_tensor(key))
    elif not tensor_name:
      print(reader.debug_string().decode("utf-8"))
    else:
      print("tensor_name: ", tensor_name)
      print(reader.get_tensor(tensor_name))
  except Exception as e:  # pylint: disable=broad-except
    message = str(e)
    print(message)
    if "corrupted compressed block contents" in message:
      print("It's likely that your checkpoint file has been compressed "
            "with SNAPPY.")
    # A "Data loss" error on a path naming a checkpoint file with an
    # extension likely means a V2 checkpoint was given by file, not prefix.
    # (`ext` is deliberately not named `e`, which would shadow the exception.)
    if ("Data loss" in message and
        any(ext in file_name for ext in (".index", ".meta", ".data"))):
      # Drop the last ".extension" component to suggest the prefix.
      proposed_file = ".".join(file_name.split(".")[:-1])
      v2_file_error_template = """
It's likely that this is a V2 checkpoint and you need to provide the filename
*prefix*.  Try removing the '.' and extension.  Try:
inspect checkpoint --file_name={}"""
      print(v2_file_error_template.format(proposed_file))
开发者ID:ryfeus,项目名称:lambda-packs,代码行数:40,代码来源:inspect_checkpoint.py


注:本文中的tensorflow.python.pywrap_tensorflow.NewCheckpointReader方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。