This article collects typical usage examples of the Python method tensorflow.contrib.framework.python.framework.checkpoint_utils.list_variables. If you are unsure what checkpoint_utils.list_variables does or how to call it, the curated code examples below may help. You can also explore further usage examples for the containing module, tensorflow.contrib.framework.python.framework.checkpoint_utils.
The following shows 4 code examples of the checkpoint_utils.list_variables method, sorted by popularity by default. You can upvote examples you like or find useful; your feedback helps the site recommend better Python code examples.
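Before the individual examples, here is a minimal, self-contained sketch of the call itself: list_variables accepts a checkpoint directory or checkpoint file prefix and returns a list of (name, shape) pairs. The path below is a placeholder, not taken from the examples.
# Minimal sketch: list the (name, shape) pairs stored in a checkpoint.
# "/tmp/model.ckpt" is a placeholder path.
from tensorflow.contrib.framework.python.framework import checkpoint_utils

for name, shape in checkpoint_utils.list_variables("/tmp/model.ckpt"):
    print(name, shape)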
Example 1: print_tensors_in_checkpoint_file
# Required import: from tensorflow.contrib.framework.python.framework import checkpoint_utils [as alias]
# Or: from tensorflow.contrib.framework.python.framework.checkpoint_utils import list_variables [as alias]
def print_tensors_in_checkpoint_file(file_name, tensor_name):
    """Prints tensors in a checkpoint file.

    If no `tensor_name` is provided, prints the tensor names and shapes
    in the checkpoint file.

    If `tensor_name` is provided, prints the content of the tensor.

    Args:
      file_name: Name of the checkpoint file.
      tensor_name: Name of the tensor in the checkpoint file to print.
    """
    try:
        if not tensor_name:
            variables = checkpoint_utils.list_variables(file_name)
            for name, shape in variables:
                print("%s\t%s" % (name, str(shape)))
        else:
            print("tensor_name: ", tensor_name)
            print(checkpoint_utils.load_variable(file_name, tensor_name))
    except Exception as e:  # pylint: disable=broad-except
        print(str(e))
        if "corrupted compressed block contents" in str(e):
            print("It's likely that your checkpoint file has been compressed "
                  "with SNAPPY.")
Example 2: main
# Required import: from tensorflow.contrib.framework.python.framework import checkpoint_utils [as alias]
# Or: from tensorflow.contrib.framework.python.framework.checkpoint_utils import list_variables [as alias]
def main(argv):
    # Dump the (name, shape) pairs of all checkpoint variables to a JSON file,
    # then print them to stdout.
    var_list = checkpoint_utils.list_variables(FLAGS.path)
    with open(FLAGS.json + ".json", 'w', encoding="utf-8") as f:
        json.dump(var_list, f, ensure_ascii=False, indent=4)
    for v in var_list:
        print(v)
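This snippet assumes that json has been imported and that FLAGS.path and FLAGS.json are defined elsewhere in the script, for example via tf.app.flags. A minimal sketch of that surrounding setup (the flag descriptions and defaults here are assumptions, not part of the original example) could be:
# Hypothetical setup assumed by the snippet above.
import json

import tensorflow as tf
from tensorflow.contrib.framework.python.framework import checkpoint_utils

FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_string("path", "", "Checkpoint file prefix or directory.")
tf.app.flags.DEFINE_string("json", "variables", "Output file name, without the .json suffix.")

if __name__ == "__main__":
    tf.app.run()  # parses flags and calls main(argv)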
Example 3: list_variables
# Required import: from tensorflow.contrib.framework.python.framework import checkpoint_utils [as alias]
# Or: from tensorflow.contrib.framework.python.framework.checkpoint_utils import list_variables [as alias]
def list_variables(checkpoint_dir):
    """See `tf.contrib.framework.list_variables`."""
    return checkpoint_utils.list_variables(checkpoint_dir)
Example 4: load_checkpoints
# Required import: from tensorflow.contrib.framework.python.framework import checkpoint_utils [as alias]
# Or: from tensorflow.contrib.framework.python.framework.checkpoint_utils import list_variables [as alias]
def load_checkpoints(sess, var_scopes=('encoder', 'decoder', 'dense')):
    # `config`, `tf` and `os` are provided by the surrounding module.
    checkpoint_path = config.lip_model_path
    if checkpoint_path:
        if os.path.isdir(checkpoint_path):
            checkpoint = tf.train.latest_checkpoint(checkpoint_path)
        else:
            checkpoint = checkpoint_path

    if config.featurizer:
        if checkpoint_path:
            from tensorflow.contrib.framework.python.framework import checkpoint_utils
            # Restore the visual frontend from the main checkpoint only if it
            # actually contains 'visual_frontend' variables.
            var_list = checkpoint_utils.list_variables(checkpoint)
            for var in var_list:
                if 'visual_frontend' in var[0]:
                    var_scopes = var_scopes + ('visual_frontend',)
                    break
        if 'visual_frontend' not in var_scopes:
            # Otherwise restore the frontend from its own checkpoint, skipping
            # Adam optimizer slots.
            featurizer_vars = tf.global_variables(scope='visual_frontend')
            featurizer_ckpt = tf.train.get_checkpoint_state(config.featurizer_model_path)
            featurizer_vars = [var for var in featurizer_vars if 'Adam' not in var.name]
            tf.train.Saver(featurizer_vars).restore(sess, featurizer_ckpt.model_checkpoint_path)

    all_variables = []
    for scope in var_scopes:
        all_variables += [var for var in tf.global_variables(scope=scope)
                          if 'Adam' not in var.name]
    if checkpoint_path:
        tf.train.Saver(all_variables).restore(sess, checkpoint)
        print("Restored saved model {}!".format(checkpoint))