本文整理汇总了 Python 中 tensorflow.contrib.learn.python.learn.utils.checkpoints.list_variables 函数的典型用法代码示例。如果您正苦于以下问题:Python list_variables 函数的具体用法?Python list_variables 怎么用?Python list_variables 使用的例子?那么恭喜您,这里精选的函数代码示例或许可以为您提供帮助。
下文共展示了 list_variables 函数的 6 个代码示例,这些示例默认按受欢迎程度排序。您可以为喜欢或者觉得有用的代码点赞,您的评价将有助于系统推荐出更棒的 Python 代码示例。
示例1: get_variable_names
def get_variable_names(self):
    """Collect the names of every variable stored for this model.

    The names come from the entries that `checkpoints.list_variables`
    reports for `self.model_dir`; the second element of each entry is
    ignored.

    Returns:
      List of names.
    """
    names = []
    for var_name, _ in checkpoints.list_variables(self.model_dir):
        names.append(var_name)
    return names
示例2: print_tensors_in_checkpoint_file
def print_tensors_in_checkpoint_file(file_name, tensor_name):
    """Print either a checkpoint's variable listing or one tensor's content.

    With a falsy `tensor_name`, each (name, shape) pair reported by
    `checkpoints.list_variables` is printed on its own tab-separated line.
    Otherwise the tensor name is echoed, followed by whatever
    `checkpoints.load_variable` returns for it.

    Any exception raised while reading is caught and its message is printed
    instead of being propagated. An extra hint about SNAPPY compression is
    printed when the message mentions corrupted compressed block contents.

    Args:
      file_name: Name of the checkpoint file.
      tensor_name: Name of the tensor in the checkpoint file to print, or a
        falsy value to list all tensors instead.
    """
    try:
        if tensor_name:
            print("tensor_name: ", tensor_name)
            print(checkpoints.load_variable(file_name, tensor_name))
        else:
            for var_name, var_shape in checkpoints.list_variables(file_name):
                print("%s\t%s" % (var_name, str(var_shape)))
    except Exception as err:  # pylint: disable=broad-except
        # Best-effort diagnostic tool: report the failure, do not re-raise.
        message = str(err)
        print(message)
        if "corrupted compressed block contents" in message:
            print("It's likely that your checkpoint file has been compressed "
                  "with SNAPPY.")
示例3: testGetAllVariables
def testGetAllVariables(self):
    """Check that `list_variables` reports every checkpointed variable.

    Checkpoints are written into a temporary directory through
    `_create_checkpoints`, and the listing for that directory is compared
    with the expected (name, shape) pairs.
    """
    checkpoint_dir = self.get_temp_dir()
    with self.test_session() as session:
        _create_checkpoints(session, checkpoint_dir)
    # NOTE(review): the source's indentation was lost; the assertion is
    # placed after the session block — confirm against the original file.
    listed = checkpoints.list_variables(checkpoint_dir)
    expected = [
        ("useful_scope/var4", [9, 9]),
        ("var1", [1, 10]),
        ("var2", [10, 10]),
        ("var3", [100, 100]),
    ]
    self.assertEqual(listed, expected)
示例4: weights_
def weights_(self):
    """Return the linear weights stored in the model's checkpoint.

    Loads every checkpoint variable whose name starts with "linear/",
    skipping "linear/bias_weight" and any variable whose name ends with
    "/<optimizer name>", optionally followed by "_<digit>".

    Returns:
      The single loaded value when exactly one variable qualifies; otherwise
      a dict mapping variable name to loaded value (empty when none qualify).
    """
    # Escape the optimizer name so that regex metacharacters it may contain
    # (e.g. ".") are matched literally, and compile once outside the loop.
    optimizer_pattern = re.compile(
        r".*/" + re.escape(self._optimizer.get_name()) + r"(_\d)?$")
    values = {}
    for name, _ in checkpoints.list_variables(self._model_dir):
        if not name.startswith("linear/") or name == "linear/bias_weight":
            continue
        if optimizer_pattern.match(name):
            # Optimizer-owned variable, not a model weight.
            continue
        values[name] = checkpoints.load_variable(self._model_dir, name)
    if len(values) == 1:
        # A lone weight is returned unwrapped rather than as a 1-item dict.
        return next(iter(values.values()))
    return values
示例5: get_weights
def get_weights(self, model_dir):
    """Returns weights per feature of the linear part.

    Loads every checkpoint variable under the `self._scope + "/"` prefix,
    skipping the scope's "bias_weight" variable and any variable whose name
    ends with "/<optimizer name>", optionally followed by "_<digit>".

    Args:
      model_dir: Directory where model parameters, graph and etc. are saved.

    Returns:
      The weights created by this model (without the optimizer weights):
      the single loaded value when exactly one variable qualifies, otherwise
      a dict mapping variable name to loaded value (empty when none qualify).
    """
    # Escape the optimizer name so that regex metacharacters it may contain
    # (e.g. ".") are matched literally, and compile once outside the loop.
    optimizer_pattern = re.compile(
        r".*/" + re.escape(self._get_optimizer().get_name()) + r"(_\d)?$")
    scope_prefix = self._scope + "/"
    bias_name = self._scope + "/bias_weight"
    values = {}
    for name, _ in checkpoints.list_variables(model_dir):
        if not name.startswith(scope_prefix) or name == bias_name:
            continue
        if optimizer_pattern.match(name):
            # Optimizer-owned variable, not a model weight.
            continue
        values[name] = checkpoints.load_variable(model_dir, name)
    if len(values) == 1:
        # A lone weight is returned unwrapped rather than as a 1-item dict.
        return next(iter(values.values()))
    return values
示例6: get_variable_names
def get_variable_names(self):
    """Return the variable names listed for `self._model_dir`.

    Each entry from `checkpoints.list_variables` is unpacked as a pair and
    only its first element, the name, is kept.

    Returns:
      List of names.
    """
    listing = checkpoints.list_variables(self._model_dir)
    return [var_name for var_name, _ in listing]