当前位置: 首页>>代码示例>>Python>>正文


Python checkpoint_utils.list_variables方法代码示例

本文整理汇总了Python中tensorflow.contrib.framework.python.framework.checkpoint_utils.list_variables方法的典型用法代码示例。如果您正苦于以下问题:Python checkpoint_utils.list_variables方法的具体用法?Python checkpoint_utils.list_variables怎么用?Python checkpoint_utils.list_variables使用的例子?那么恭喜您, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在tensorflow.contrib.framework.python.framework.checkpoint_utils的用法示例。


在下文中一共展示了checkpoint_utils.list_variables方法的4个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: print_tensors_in_checkpoint_file

# 需要导入模块: from tensorflow.contrib.framework.python.framework import checkpoint_utils [as 别名]
# 或者: from tensorflow.contrib.framework.python.framework.checkpoint_utils import list_variables [as 别名]
def print_tensors_in_checkpoint_file(file_name, tensor_name):
  """Prints tensors in a checkpoint file.

  If no `tensor_name` is provided, prints the tensor names and shapes
  in the checkpoint file.

  If `tensor_name` is provided, prints the content of the tensor.

  Args:
    file_name: Name of the checkpoint file.
    tensor_name: Name of the tensor in the checkpoint file to print.
  """
  try:
    if not tensor_name:
      variables = checkpoint_utils.list_variables(file_name)
      for name, shape in variables:
        print("%s\t%s" % (name, str(shape)))
    else:
      print("tensor_name: ", tensor_name)
      print(checkpoint_utils.load_variable(file_name, tensor_name))
  except Exception as e:  # pylint: disable=broad-except
    print(str(e))
    if "corrupted compressed block contents" in str(e):
      print("It's likely that your checkpoint file has been compressed "
            "with SNAPPY.") 
开发者ID:ryfeus,项目名称:lambda-packs,代码行数:27,代码来源:inspect_checkpoint.py

示例2: main

# 需要导入模块: from tensorflow.contrib.framework.python.framework import checkpoint_utils [as 别名]
# 或者: from tensorflow.contrib.framework.python.framework.checkpoint_utils import list_variables [as 别名]
def main(argv):
    var_list = checkpoint_utils.list_variables(FLAGS.path)
    with open(FLAGS.json + ".json", 'w', encoding="utf-8") as f:
        json.dump(var_list, f, ensure_ascii=False, indent=4)
    for v in var_list:
        print(v) 
开发者ID:uizard-technologies,项目名称:realmix,代码行数:8,代码来源:inspect_variables.py

示例3: list_variables

# 需要导入模块: from tensorflow.contrib.framework.python.framework import checkpoint_utils [as 别名]
# 或者: from tensorflow.contrib.framework.python.framework.checkpoint_utils import list_variables [as 别名]
def list_variables(checkpoint_dir):
  """See `tf.contrib.framework.list_variables`."""
  return checkpoint_utils.list_variables(checkpoint_dir) 
开发者ID:tobegit3hub,项目名称:deep_image_model,代码行数:5,代码来源:checkpoints.py

示例4: load_checkpoints

# 需要导入模块: from tensorflow.contrib.framework.python.framework import checkpoint_utils [as 别名]
# 或者: from tensorflow.contrib.framework.python.framework.checkpoint_utils import list_variables [as 别名]
def load_checkpoints(sess, var_scopes = ('encoder', 'decoder', 'dense')):

  checkpoint_path =  config.lip_model_path
  if checkpoint_path:
    if os.path.isdir(checkpoint_path):
      checkpoint = tf.train.latest_checkpoint(checkpoint_path)
    else:
      checkpoint = checkpoint_path

  if config.featurizer:

    if checkpoint_path:
      from tensorflow.contrib.framework.python.framework import checkpoint_utils
      var_list = checkpoint_utils.list_variables(checkpoint)
      for var in var_list:
        if 'visual_frontend' in var[0]:
          var_scopes = var_scopes + ('visual_frontend',)
          break

    if not 'visual_frontend' in var_scopes:
      featurizer_vars = tf.global_variables(scope='visual_frontend')
      featurizer_ckpt = tf.train.get_checkpoint_state(config.featurizer_model_path)
      featurizer_vars = [var for var in featurizer_vars if not 'Adam' in var.name]
      tf.train.Saver(featurizer_vars).restore(sess, featurizer_ckpt.model_checkpoint_path)

  all_variables = []
  for scope in var_scopes:
    all_variables += [var for var in tf.global_variables(scope=scope)
                      if not 'Adam' in var.name ]
  if checkpoint_path:
    tf.train.Saver(all_variables).restore(sess, checkpoint)

    print("Restored saved model {}!".format(checkpoint)) 
开发者ID:afourast,项目名称:deep_lip_reading,代码行数:35,代码来源:main.py


注:本文中的tensorflow.contrib.framework.python.framework.checkpoint_utils.list_variables方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。