當前位置: 首頁>>代碼示例>>Python>>正文


Python checkpoint_utils.list_variables方法代碼示例

本文整理匯總了Python中tensorflow.contrib.framework.python.framework.checkpoint_utils.list_variables方法的典型用法代碼示例。如果您正苦於以下問題:Python checkpoint_utils.list_variables方法的具體用法?Python checkpoint_utils.list_variables怎麽用?Python checkpoint_utils.list_variables使用的例子?那麽, 這裏精選的方法代碼示例或許可以為您提供幫助。您也可以進一步了解該方法所在tensorflow.contrib.framework.python.framework.checkpoint_utils的用法示例。


在下文中一共展示了checkpoint_utils.list_variables方法的4個代碼示例,這些例子默認根據受歡迎程度排序。您可以為喜歡或者感覺有用的代碼點讚,您的評價將有助於係統推薦出更棒的Python代碼示例。

示例1: print_tensors_in_checkpoint_file

# 需要導入模塊: from tensorflow.contrib.framework.python.framework import checkpoint_utils [as 別名]
# 或者: from tensorflow.contrib.framework.python.framework.checkpoint_utils import list_variables [as 別名]
def print_tensors_in_checkpoint_file(file_name, tensor_name):
  """Prints tensors in a checkpoint file.

  If no `tensor_name` is provided, prints the tensor names and shapes
  in the checkpoint file.

  If `tensor_name` is provided, prints the content of the tensor.

  Args:
    file_name: Name of the checkpoint file.
    tensor_name: Name of the tensor in the checkpoint file to print.
  """
  try:
    if not tensor_name:
      variables = checkpoint_utils.list_variables(file_name)
      for name, shape in variables:
        print("%s\t%s" % (name, str(shape)))
    else:
      print("tensor_name: ", tensor_name)
      print(checkpoint_utils.load_variable(file_name, tensor_name))
  except Exception as e:  # pylint: disable=broad-except
    print(str(e))
    if "corrupted compressed block contents" in str(e):
      print("It's likely that your checkpoint file has been compressed "
            "with SNAPPY.") 
開發者ID:ryfeus,項目名稱:lambda-packs,代碼行數:27,代碼來源:inspect_checkpoint.py

示例2: main

# 需要導入模塊: from tensorflow.contrib.framework.python.framework import checkpoint_utils [as 別名]
# 或者: from tensorflow.contrib.framework.python.framework.checkpoint_utils import list_variables [as 別名]
def main(argv):
    var_list = checkpoint_utils.list_variables(FLAGS.path)
    with open(FLAGS.json + ".json", 'w', encoding="utf-8") as f:
        json.dump(var_list, f, ensure_ascii=False, indent=4)
    for v in var_list:
        print(v) 
開發者ID:uizard-technologies,項目名稱:realmix,代碼行數:8,代碼來源:inspect_variables.py

示例3: list_variables

# 需要導入模塊: from tensorflow.contrib.framework.python.framework import checkpoint_utils [as 別名]
# 或者: from tensorflow.contrib.framework.python.framework.checkpoint_utils import list_variables [as 別名]
def list_variables(checkpoint_dir):
  """See `tf.contrib.framework.list_variables`."""
  return checkpoint_utils.list_variables(checkpoint_dir) 
開發者ID:tobegit3hub,項目名稱:deep_image_model,代碼行數:5,代碼來源:checkpoints.py

示例4: load_checkpoints

# 需要導入模塊: from tensorflow.contrib.framework.python.framework import checkpoint_utils [as 別名]
# 或者: from tensorflow.contrib.framework.python.framework.checkpoint_utils import list_variables [as 別名]
def load_checkpoints(sess, var_scopes = ('encoder', 'decoder', 'dense')):

  checkpoint_path =  config.lip_model_path
  if checkpoint_path:
    if os.path.isdir(checkpoint_path):
      checkpoint = tf.train.latest_checkpoint(checkpoint_path)
    else:
      checkpoint = checkpoint_path

  if config.featurizer:

    if checkpoint_path:
      from tensorflow.contrib.framework.python.framework import checkpoint_utils
      var_list = checkpoint_utils.list_variables(checkpoint)
      for var in var_list:
        if 'visual_frontend' in var[0]:
          var_scopes = var_scopes + ('visual_frontend',)
          break

    if not 'visual_frontend' in var_scopes:
      featurizer_vars = tf.global_variables(scope='visual_frontend')
      featurizer_ckpt = tf.train.get_checkpoint_state(config.featurizer_model_path)
      featurizer_vars = [var for var in featurizer_vars if not 'Adam' in var.name]
      tf.train.Saver(featurizer_vars).restore(sess, featurizer_ckpt.model_checkpoint_path)

  all_variables = []
  for scope in var_scopes:
    all_variables += [var for var in tf.global_variables(scope=scope)
                      if not 'Adam' in var.name ]
  if checkpoint_path:
    tf.train.Saver(all_variables).restore(sess, checkpoint)

    print("Restored saved model {}!".format(checkpoint)) 
開發者ID:afourast,項目名稱:deep_lip_reading,代碼行數:35,代碼來源:main.py


注:本文中的tensorflow.contrib.framework.python.framework.checkpoint_utils.list_variables方法示例由純淨天空整理自Github/MSDocs等開源代碼及文檔管理平台,相關代碼片段篩選自各路編程大神貢獻的開源項目,源碼版權歸原作者所有,傳播和使用請參考對應項目的License;未經允許,請勿轉載。