本文整理汇总了 Python 中 tensorflow.contrib.framework.list_variables 方法的典型用法代码示例。如果您想了解 Python framework.list_variables 方法的具体用法、调用方式或使用示例，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块 tensorflow.contrib.framework 的其他用法示例。
在下文中一共展示了 framework.list_variables 方法的 10 个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者觉得有用的代码点赞，您的评价将有助于系统推荐出更好的 Python 代码示例。
示例1: get_weights
# 需要导入模块: from tensorflow.contrib import framework [as 别名]
# 或者: from tensorflow.contrib.framework import list_variables [as 别名]
def get_weights(self, model_dir):
    """Returns weights per feature of the linear part.

    Scans the checkpoint in ``model_dir`` and loads every variable whose name
    starts with ``self._scope + "/"``. It skips the bias variable
    (``<scope>/bias_weight``) and the optimizer slot variables.

    Args:
      model_dir: Directory where model parameters, graph and etc. are saved.

    Returns:
      The weights created by this model (without the optimizer weights).
      If exactly one variable qualifies, its value is returned directly.
      Otherwise a dict mapping variable name to loaded value is returned.
    """
    scope_prefix = self._scope + "/"
    bias_name = self._scope + "/bias_weight"
    # Optimizer slot variables end in "/<OptimizerName>", optionally followed by
    # a uniquifying "_<n>" suffix. re.escape keeps regex metacharacters in the
    # optimizer name literal. \d+ also covers multi-digit suffixes such as _10.
    # The pattern is compiled once instead of on every iteration.
    optimizer_pattern = re.compile(
        r".*/" + re.escape(self._get_optimizer().get_name()) + r"(_\d+)?$")
    values = {}
    for name, _ in list_variables(model_dir):
        if (name.startswith(scope_prefix) and name != bias_name
                and not optimizer_pattern.match(name)):
            values[name] = load_variable(model_dir, name)
    if len(values) == 1:
        # A single matching variable is returned unwrapped.
        return next(iter(values.values()))
    return values
示例2: get_variable_names
# 需要导入模块: from tensorflow.contrib import framework [as 别名]
# 或者: from tensorflow.contrib.framework import list_variables [as 别名]
def get_variable_names(self):
    """Return the name of every variable stored under ``self.model_dir``.

    Returns:
      A list of variable-name strings, in the order that ``list_variables``
      reports them.
    """
    checkpoint_entries = list_variables(self.model_dir)
    # Each entry is a (name, shape) pair; only the name is kept.
    return [entry[0] for entry in checkpoint_entries]
示例3: _create_load_init_saver
# 需要导入模块: from tensorflow.contrib import framework [as 别名]
# 或者: from tensorflow.contrib.framework import list_variables [as 别名]
def _create_load_init_saver(self, filename):
    """Build a Saver that warm-starts graph variables from checkpoint ``filename``.

    Only variables that exist both in the current graph and in the checkpoint,
    with identical shapes, are restored. Names missing on either side, and
    shape mismatches, are reported on ``log.v1``.

    Args:
      filename: Path of the initialization checkpoint.

    Returns:
      A ``tf.train.Saver`` over the loadable variables. Returns ``None`` when
      ``self.load`` is set, when a ``<model>-*.index`` checkpoint already exists
      in ``self.model_dir``, or when ``filename`` is empty, a ``.pickle`` file,
      or a ``DeepLabRGB:`` spec.
    """
    if self.load != "":
        return None
    if glob.glob(self.model_dir + self.model + "-*.index"):
        return None
    if filename == "" or filename.endswith(".pickle") or filename.startswith("DeepLabRGB:"):
        return None

    # Maps checkpoint variable name to its stored shape; global_step is never restored.
    # A dict gives O(1) membership tests, unlike the list lookups used before.
    shapes_in_file = {name: shape for name, shape in list_variables(filename)
                      if name != "global_step"}
    vars_model = tf.global_variables()
    # Names are compared with the ":0" output suffix stripped (name[:-2]).
    assert all(v.name.endswith(":0") for v in vars_model)
    model_var_names = {v.name for v in vars_model}

    vars_intersection = [v for v in vars_model if v.name[:-2] in shapes_in_file]

    # Graph variables absent from the checkpoint keep their fresh initialization.
    # Adam slots and beta power accumulators are expected to be absent, so they
    # are not reported.
    vars_not_in_file = [v for v in vars_model
                        if v.name[:-2] not in shapes_in_file
                        and "Adam" not in v.name
                        and "beta1_power" not in v.name
                        and "beta2_power" not in v.name]
    if vars_not_in_file:
        print("the following variables will not be initialized since they are not present in the initialization model",
              [v.name for v in vars_not_in_file], file=log.v1)

    # Checkpoint entries with no graph counterpart. Optimizer slots are not reported.
    names_not_in_graph = [name for name in shapes_in_file
                          if name + ":0" not in model_var_names
                          and "RMSProp" not in name and "Adam" not in name and "Momentum" not in name]
    if names_not_in_graph:
        print("the following variables will not be loaded from the file since they are not present in the graph",
              names_not_in_graph, file=log.v1)

    vars_shape_mismatch = [v for v in vars_intersection
                           if v.shape.as_list() != shapes_in_file[v.name[:-2]]]
    if vars_shape_mismatch:
        print("the following variables will not be loaded from the file since the shapes in the graph and in the file "
              "don't match:", [(v.name, v.shape) for v in vars_shape_mismatch if "Adam" not in v.name], file=log.v1)
    # Filter by name rather than with `in` on a list of Variables. `in` relies
    # on Variable.__eq__, which is elementwise under TF2 equality semantics.
    mismatched_names = {v.name for v in vars_shape_mismatch}
    vars_to_load = [v for v in vars_intersection if v.name not in mismatched_names]
    return tf.train.Saver(var_list=vars_to_load)
示例4: get_variable_names
# 需要导入模块: from tensorflow.contrib import framework [as 别名]
# 或者: from tensorflow.contrib.framework import list_variables [as 别名]
def get_variable_names(self):
    """Return the variable names found in the checkpoint under ``self._model_dir``."""
    names_and_shapes = list_variables(self._model_dir)
    # Keep only the name from each (name, shape) pair.
    return [pair[0] for pair in names_and_shapes]
示例5: weights_
# 需要导入模块: from tensorflow.contrib import framework [as 别名]
# 或者: from tensorflow.contrib.framework import list_variables [as 别名]
def weights_(self):
    """Return the learned weights of the linear part from ``self._model_dir``.

    Loads every checkpoint variable under the ``linear/`` scope. It skips the
    bias (``linear/bias_weight``) and the optimizer slot variables.

    Returns:
      The single matching value if exactly one variable qualifies. Otherwise
      a dict mapping variable name to loaded value.
    """
    # Optimizer slot variables end in "/<OptimizerName>" with an optional "_<n>"
    # uniquifier. The name is escaped so it matches literally, and the pattern
    # is compiled once.
    optimizer_pattern = re.compile(
        r".*/" + re.escape(self._optimizer.get_name()) + r"(_\d+)?$")
    values = {}
    for name, _ in list_variables(self._model_dir):
        if (name.startswith("linear/") and name != "linear/bias_weight"
                and not optimizer_pattern.match(name)):
            values[name] = load_variable(self._model_dir, name)
    if len(values) == 1:
        # A single matching variable is returned unwrapped.
        return next(iter(values.values()))
    return values
示例6: get_variable_names
# 需要导入模块: from tensorflow.contrib import framework [as 别名]
# 或者: from tensorflow.contrib.framework import list_variables [as 别名]
def get_variable_names(self):
    """List every variable name recorded in this model's checkpoint directory.

    Returns:
      List of names, one per entry reported by ``list_variables``.
    """
    entries = list_variables(self._model_dir)
    return [entry[0] for entry in entries]
示例7: available_variables_without_global_step
# 需要导入模块: from tensorflow.contrib import framework [as 别名]
# 或者: from tensorflow.contrib.framework import list_variables [as 别名]
def available_variables_without_global_step(checkpoint_dir):
    """Return the global variables restorable from ``checkpoint_dir``, skipping global_step.

    A variable qualifies when both of these hold:
      * the checkpoint has an entry with the same name (the variable name up
        to the first ':');
      * the stored shape compares equal to the variable's static shape.

    Args:
      checkpoint_dir: Checkpoint location passed to ``list_variables``.

    Returns:
      A list of variables, in ``tf.global_variables()`` order.
    """
    import tensorflow.contrib.framework as tff
    graph_vars = tf.global_variables()
    shape_by_name = dict(tff.list_variables(checkpoint_dir=checkpoint_dir))

    def _restorable(variable):
        base_name = variable.name.split(':')[0]
        if base_name == 'global_step':
            return False
        return base_name in shape_by_name and variable.get_shape() == shape_by_name[base_name]

    return [variable for variable in graph_vars if _restorable(variable)]
示例8: available_variables
# 需要导入模块: from tensorflow.contrib import framework [as 别名]
# 或者: from tensorflow.contrib.framework import list_variables [as 别名]
def available_variables(checkpoint_dir):
    """Return the global variables that can be restored from ``checkpoint_dir``.

    A variable qualifies when both of these hold:
      * the checkpoint has an entry with the same name (the variable name up
        to the first ':');
      * the stored shape compares equal to the variable's static shape.

    Args:
      checkpoint_dir: Checkpoint location passed to ``list_variables``.

    Returns:
      A list of variables, in ``tf.global_variables()`` order.
    """
    # `tff` is not guaranteed to be bound at module level. Import it locally,
    # as available_variables_without_global_step does, so this function works
    # on its own.
    import tensorflow.contrib.framework as tff
    all_vars = tf.global_variables()
    available_shapes = dict(tff.list_variables(checkpoint_dir=checkpoint_dir))
    return [v for v in all_vars
            if v.name.split(':')[0] in available_shapes
            and v.get_shape() == available_shapes[v.name.split(':')[0]]]
示例9: _create_load_init_saver
# 需要导入模块: from tensorflow.contrib import framework [as 别名]
# 或者: from tensorflow.contrib.framework import list_variables [as 别名]
def _create_load_init_saver(self, filename):
    """Build a Saver that warm-starts graph variables from checkpoint ``filename``.

    Only variables that exist both in the current graph and in the checkpoint,
    with identical shapes, are restored. Names missing on either side, and
    shape mismatches, are printed to stdout.

    Args:
      filename: Path of the initialization checkpoint.

    Returns:
      A ``tf.train.Saver`` over the loadable variables. Returns ``None`` when
      ``self.load`` is set, when a ``<model>-*.index`` checkpoint already exists
      in ``self.model_dir``, or when ``filename`` is empty or a ``.pickle`` file.
    """
    if self.load != "":
        return None
    if glob.glob(self.model_dir + self.model + "-*.index"):
        return None
    if filename == "" or filename.endswith(".pickle"):
        return None

    # Maps checkpoint variable name to its stored shape; global_step is never
    # restored. A dict gives O(1) membership tests.
    shapes_in_file = {name: shape for name, shape in list_variables(filename)
                      if name != "global_step"}
    vars_model = tf.global_variables()
    # Names are compared with the ":0" output suffix stripped (name[:-2]).
    assert all(v.name.endswith(":0") for v in vars_model)
    model_var_names = {v.name for v in vars_model}

    vars_intersection = [v for v in vars_model if v.name[:-2] in shapes_in_file]

    # Graph variables absent from the checkpoint keep their fresh initialization.
    # Adam slots and beta power accumulators are expected to be absent, so they
    # are not reported.
    vars_not_in_file = [v for v in vars_model
                        if v.name[:-2] not in shapes_in_file
                        and "Adam" not in v.name
                        and "beta1_power" not in v.name
                        and "beta2_power" not in v.name]
    if vars_not_in_file:
        print("the following variables will not be initialized since they are not present in the "
              "initialization model", [v.name for v in vars_not_in_file])

    # Checkpoint entries with no graph counterpart. Optimizer slots are not reported.
    names_not_in_graph = [name for name in shapes_in_file
                          if name + ":0" not in model_var_names
                          and "RMSProp" not in name and "Adam" not in name and "Momentum" not in name]
    if names_not_in_graph:
        print("the following variables will not be loaded from the file since they are not present in the "
              "graph", names_not_in_graph)

    vars_shape_mismatch = [v for v in vars_intersection
                           if v.shape.as_list() != shapes_in_file[v.name[:-2]]]
    if vars_shape_mismatch:
        print("the following variables will not be loaded from the file since the shapes in the graph and in"
              " the file don't match:", [(v.name, v.shape) for v in vars_shape_mismatch
                                         if "Adam" not in v.name])
    # Filter by name rather than with `in` on a list of Variables. `in` relies
    # on Variable.__eq__, which is elementwise under TF2 equality semantics.
    mismatched_names = {v.name for v in vars_shape_mismatch}
    vars_to_load = [v for v in vars_intersection if v.name not in mismatched_names]
    return tf.train.Saver(var_list=vars_to_load)
示例10: _create_load_init_saver
# 需要导入模块: from tensorflow.contrib import framework [as 别名]
# 或者: from tensorflow.contrib.framework import list_variables [as 别名]
def _create_load_init_saver(self, filename):
    """Build a Saver that warm-starts graph variables from checkpoint ``filename``.

    Only variables that exist both in the current graph and in the checkpoint,
    with identical shapes, are restored. Names missing on either side, and
    shape mismatches, are reported on ``log.v1``.

    Args:
      filename: Path of the initialization checkpoint.

    Returns:
      A ``tf.train.Saver`` over the loadable variables. Returns ``None`` when
      ``self.load`` is set, when a ``<model>-*.index`` checkpoint already exists
      in ``self.model_dir``, or when ``filename`` is empty, a ``.pickle`` file,
      or a ``DeepLabRGB:`` spec.
    """
    if self.load != "":
        return None
    if glob.glob(self.model_dir + self.model + "-*.index"):
        return None
    if filename == "" or filename.endswith(".pickle") or filename.startswith("DeepLabRGB:"):
        return None
    from tensorflow.contrib.framework import list_variables

    # Maps checkpoint variable name to its stored shape; global_step is never
    # restored. A dict gives O(1) membership tests.
    shapes_in_file = {name: shape for name, shape in list_variables(filename)
                      if name != "global_step"}
    vars_model = tf.global_variables()
    # Names are compared with the ":0" output suffix stripped (name[:-2]).
    assert all(v.name.endswith(":0") for v in vars_model)
    model_var_names = {v.name for v in vars_model}

    vars_intersection = [v for v in vars_model if v.name[:-2] in shapes_in_file]

    # Graph variables absent from the checkpoint keep their fresh initialization.
    # Adam slots and beta power accumulators are expected to be absent, so they
    # are not reported.
    vars_not_in_file = [v for v in vars_model
                        if v.name[:-2] not in shapes_in_file
                        and "Adam" not in v.name
                        and "beta1_power" not in v.name
                        and "beta2_power" not in v.name]
    if vars_not_in_file:
        print("the following variables will not be initialized since they are not present in the initialization model",
              [v.name for v in vars_not_in_file], file=log.v1)

    # Checkpoint entries with no graph counterpart. Optimizer slots are not reported.
    names_not_in_graph = [name for name in shapes_in_file
                          if name + ":0" not in model_var_names
                          and "RMSProp" not in name and "Adam" not in name and "Momentum" not in name]
    if names_not_in_graph:
        print("the following variables will not be loaded from the file since they are not present in the graph",
              names_not_in_graph, file=log.v1)

    vars_shape_mismatch = [v for v in vars_intersection
                           if v.shape.as_list() != shapes_in_file[v.name[:-2]]]
    if vars_shape_mismatch:
        print("the following variables will not be loaded from the file since the shapes in the graph and in the file "
              "don't match:", [(v.name, v.shape) for v in vars_shape_mismatch if "Adam" not in v.name], file=log.v1)
    # Filter by name rather than with `in` on a list of Variables. `in` relies
    # on Variable.__eq__, which is elementwise under TF2 equality semantics.
    mismatched_names = {v.name for v in vars_shape_mismatch}
    vars_to_load = [v for v in vars_intersection if v.name not in mismatched_names]
    return tf.train.Saver(var_list=vars_to_load)