当前位置: 首页>>代码示例>>Python>>正文


Python framework.list_variables方法代码示例

本文整理汇总了Python中tensorflow.contrib.framework.list_variables方法的典型用法代码示例。如果您正苦于以下问题：Python framework.list_variables方法的具体用法？Python framework.list_variables怎么用？Python framework.list_variables使用的例子？那么恭喜您，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块tensorflow.contrib.framework的用法示例。


在下文中一共展示了framework.list_variables方法的10个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: get_weights

# 需要导入模块: from tensorflow.contrib import framework [as 别名]
# 或者: from tensorflow.contrib.framework import list_variables [as 别名]
def get_weights(self, model_dir):
    """Returns weights per feature of the linear part.

    Every checkpoint variable under ``<self._scope>/`` is loaded, except
    ``<self._scope>/bias_weight`` and optimizer slot variables (names ending
    in ``/<optimizer name>`` optionally followed by ``_<digits>``).

    Args:
      model_dir: Directory where model parameters, graph and etc. are saved.

    Returns:
      The weights created by this model (without the optimizer weights). If
      exactly one variable matches, its value is returned directly; otherwise
      a dict mapping variable name to loaded value (empty if none match).
    """
    all_variables = [name for name, _ in list_variables(model_dir)]
    prefix = self._scope + "/"
    bias_name = prefix + "bias_weight"
    # re.escape protects against regex metacharacters in the optimizer name;
    # "_\d+" also catches slot suffixes with two or more digits (e.g. "_10"),
    # which the previous "_\d" let through as if they were model weights.
    optimizer_pattern = re.compile(
        r".*/" + re.escape(self._get_optimizer().get_name()) + r"(_\d+)?$")
    values = {}
    for name in all_variables:
      if (name.startswith(prefix) and
          name != bias_name and
          not optimizer_pattern.match(name)):
        values[name] = load_variable(model_dir, name)
    if len(values) == 1:
      # Unwrap the single entry rather than returning a one-item dict.
      return next(iter(values.values()))
    return values
开发者ID:ryfeus,项目名称:lambda-packs,代码行数:22,代码来源:composable_model.py

示例2: get_variable_names

# 需要导入模块: from tensorflow.contrib import framework [as 别名]
# 或者: from tensorflow.contrib.framework import list_variables [as 别名]
def get_variable_names(self):
    """Lists the names of every variable stored under ``self.model_dir``.

    Returns:
      A list of variable-name strings, in the order ``list_variables``
      reports them.
    """
    checkpoint_entries = list_variables(self.model_dir)
    return [var_name for var_name, _shape in checkpoint_entries]
开发者ID:ryfeus,项目名称:lambda-packs,代码行数:9,代码来源:estimator.py

示例3: _create_load_init_saver

# 需要导入模块: from tensorflow.contrib import framework [as 别名]
# 或者: from tensorflow.contrib.framework import list_variables [as 别名]
def _create_load_init_saver(self, filename):
    """Builds a Saver that initializes the graph from a checkpoint file.

    Only graph variables whose name (with the ``:0`` suffix stripped) and
    shape both match an entry in ``filename`` are restored. Variables missing
    on either side, and shared variables with mismatched shapes, are reported
    to ``log.v1`` and left out of the Saver.

    Args:
      filename: Checkpoint to initialize from.

    Returns:
      ``None`` when no initialization should happen: ``self.load`` is
      non-empty, a ``<model_dir><model>-*.index`` file already exists, or
      ``filename`` is empty, ends with ``.pickle`` or starts with
      ``DeepLabRGB:``. Otherwise a ``tf.train.Saver`` over the restorable
      variables.
    """
    if self.load != "":
      return None
    if len(glob.glob(self.model_dir + self.model + "-*.index")) > 0:
      return None
    if filename == "" or filename.endswith(".pickle") or filename.startswith("DeepLabRGB:"):
      return None

    # Checkpoint entries as (name, shape) pairs; global_step is never restored here.
    vars_and_shapes_file = [x for x in list_variables(filename) if x[0] != "global_step"]
    vars_file = [name for name, _ in vars_and_shapes_file]
    vars_file_set = set(vars_file)  # O(1) membership tests below
    vars_to_shapes_file = {name: shape for name, shape in vars_and_shapes_file}
    vars_model = tf.global_variables()
    # Graph names are matched against checkpoint keys by stripping the last two
    # characters, so every graph name must really end in ":0".
    assert all(x.name.endswith(":0") for x in vars_model)
    vars_intersection = [x for x in vars_model if x.name[:-2] in vars_file_set]

    # Graph variables the file cannot initialize; names that look like Adam
    # optimizer state are not reported.
    graph_vars_not_in_file = [x for x in vars_model if x.name[:-2] not in vars_file_set and "Adam" not in x.name and
                              "beta1_power" not in x.name and "beta2_power" not in x.name]
    if len(graph_vars_not_in_file) > 0:
      print("the following variables will not be initialized since they are not present in the initialization model",
            [v.name for v in graph_vars_not_in_file], file=log.v1)

    # File entries with no graph counterpart; names that look like optimizer
    # state are not reported.
    var_names_model = {x.name for x in vars_model}
    file_vars_not_in_graph = [x for x in vars_file if x + ":0" not in var_names_model
                              and "RMSProp" not in x and "Adam" not in x and "Momentum" not in x]
    if len(file_vars_not_in_graph) > 0:
      print("the following variables will not be loaded from the file since they are not present in the graph",
            file_vars_not_in_graph, file=log.v1)

    # Partition the shared variables by shape agreement in a single pass. This
    # avoids a quadratic `x not in list` filter that also depends on how
    # tf.Variable implements `==`.
    vars_to_restore = []
    vars_shape_mismatch = []
    for x in vars_intersection:
      if x.shape.as_list() == vars_to_shapes_file[x.name[:-2]]:
        vars_to_restore.append(x)
      else:
        vars_shape_mismatch.append(x)
    if len(vars_shape_mismatch) > 0:
      print("the following variables will not be loaded from the file since the shapes in the graph and in the file "
            "don't match:", [(x.name, x.shape) for x in vars_shape_mismatch if "Adam" not in x.name], file=log.v1)
    return tf.train.Saver(var_list=vars_to_restore)
开发者ID:tobiasfshr,项目名称:MOTSFusion,代码行数:35,代码来源:Saver.py

示例4: get_variable_names

# 需要导入模块: from tensorflow.contrib import framework [as 别名]
# 或者: from tensorflow.contrib.framework import list_variables [as 别名]
def get_variable_names(self):
    """Returns the names of all variables saved under ``self._model_dir``."""
    names = []
    for var_name, _shape in list_variables(self._model_dir):
      names.append(var_name)
    return names
开发者ID:tobegit3hub,项目名称:deep_image_model,代码行数:4,代码来源:svm.py

示例5: weights_

# 需要导入模块: from tensorflow.contrib import framework [as 别名]
# 或者: from tensorflow.contrib.framework import list_variables [as 别名]
def weights_(self):
    """Returns the per-feature weights of the linear part of the model.

    Every checkpoint variable under ``linear/`` in ``self._model_dir`` is
    loaded, except ``linear/bias_weight`` and optimizer slot variables (names
    ending in ``/<optimizer name>`` optionally followed by ``_<digits>``).

    Returns:
      The single matching value if exactly one variable matches, otherwise a
      dict mapping variable name to loaded value (empty if none match).
    """
    # re.escape protects against regex metacharacters in the optimizer name;
    # "_\d+" also catches slot suffixes with two or more digits (e.g. "_10"),
    # which the previous "_\d" let through as if they were model weights.
    optimizer_pattern = re.compile(
        r".*/" + re.escape(self._optimizer.get_name()) + r"(_\d+)?$")
    values = {}
    for name, _ in list_variables(self._model_dir):
      if (name.startswith("linear/") and
          name != "linear/bias_weight" and
          not optimizer_pattern.match(name)):
        values[name] = load_variable(self._model_dir, name)
    if len(values) == 1:
      # Unwrap the single entry rather than returning a one-item dict.
      return next(iter(values.values()))
    return values
开发者ID:tobegit3hub,项目名称:deep_image_model,代码行数:13,代码来源:svm.py

示例6: get_variable_names

# 需要导入模块: from tensorflow.contrib import framework [as 别名]
# 或者: from tensorflow.contrib.framework import list_variables [as 别名]
def get_variable_names(self):
    """Lists the names of all variables in this model.

    Returns:
      List of variable names read from the checkpoint data in
      ``self._model_dir``.
    """
    pairs = list_variables(self._model_dir)
    return [var_name for var_name, _unused_shape in pairs]
开发者ID:tobegit3hub,项目名称:deep_image_model,代码行数:9,代码来源:dnn.py

示例7: available_variables_without_global_step

# 需要导入模块: from tensorflow.contrib import framework [as 别名]
# 或者: from tensorflow.contrib.framework import list_variables [as 别名]
def available_variables_without_global_step(checkpoint_dir):
    """Return the global variables that can be restored from ``checkpoint_dir``.

    A variable qualifies when the checkpoint has an entry with the same name
    (the part of ``var.name`` before ``':'``) whose shape compares equal to the
    variable's static shape. ``global_step`` is always skipped.

    Args:
        checkpoint_dir: Checkpoint location passed to ``tff.list_variables``.

    Returns:
        List of matching variables, in ``tf.global_variables()`` order.
    """
    import tensorflow.contrib.framework as tff
    graph_vars = tf.global_variables()
    ckpt_shapes = dict(tff.list_variables(checkpoint_dir=checkpoint_dir))
    restorable = []
    for var in graph_vars:
        base_name = var.name.split(':')[0]
        if base_name == 'global_step' or base_name not in ckpt_shapes:
            continue
        if var.get_shape() == ckpt_shapes[base_name]:
            restorable.append(var)
    return restorable
开发者ID:shiyuzh2007,项目名称:ASR,代码行数:15,代码来源:pretrain_layerblock.py

示例8: available_variables

# 需要导入模块: from tensorflow.contrib import framework [as 别名]
# 或者: from tensorflow.contrib.framework import list_variables [as 别名]
def available_variables(checkpoint_dir):
    """Return the global variables that can be restored from ``checkpoint_dir``.

    A variable qualifies when the checkpoint has an entry with the same name
    (the part of ``var.name`` before ``':'``) whose shape compares equal to the
    variable's static shape.

    Args:
        checkpoint_dir: Checkpoint location passed to ``tff.list_variables``.

    Returns:
        List of matching variables, in ``tf.global_variables()`` order.
    """
    graph_vars = tf.global_variables()
    ckpt_shapes = dict(tff.list_variables(checkpoint_dir=checkpoint_dir))

    def _restorable(var):
        # Checkpoint keys omit the ":<index>" output suffix of graph names.
        name = var.name.split(':')[0]
        return name in ckpt_shapes and var.get_shape() == ckpt_shapes[name]

    return [var for var in graph_vars if _restorable(var)]
开发者ID:shiyuzh2007,项目名称:ASR,代码行数:12,代码来源:utils.py

示例9: _create_load_init_saver

# 需要导入模块: from tensorflow.contrib import framework [as 别名]
# 或者: from tensorflow.contrib.framework import list_variables [as 别名]
def _create_load_init_saver(self, filename):
    """Builds a Saver that initializes the graph from a checkpoint file.

    Only graph variables whose name (with the ``:0`` suffix stripped) and
    shape both match an entry in ``filename`` are restored. Variables missing
    on either side, and shared variables with mismatched shapes, are reported
    via ``print`` and left out of the Saver.

    Args:
      filename: Checkpoint to initialize from.

    Returns:
      ``None`` when no initialization should happen: ``self.load`` is
      non-empty, a ``<model_dir><model>-*.index`` file already exists, or
      ``filename`` is empty or ends with ``.pickle``. Otherwise a
      ``tf.train.Saver`` over the restorable variables.
    """
    if self.load != "":
      return None
    if len(glob.glob(self.model_dir + self.model + "-*.index")) > 0:
      return None
    if filename == "" or filename.endswith(".pickle"):
      return None

    # Checkpoint entries as (name, shape) pairs; global_step is never restored here.
    vars_and_shapes_file = [x for x in list_variables(filename) if x[0] != "global_step"]
    vars_file = [name for name, _ in vars_and_shapes_file]
    vars_file_set = set(vars_file)  # O(1) membership tests below
    vars_to_shapes_file = {name: shape for name, shape in vars_and_shapes_file}
    vars_model = tf.global_variables()
    # Graph names are matched against checkpoint keys by stripping the last two
    # characters, so every graph name must really end in ":0".
    assert all(x.name.endswith(":0") for x in vars_model)
    vars_intersection = [x for x in vars_model if x.name[:-2] in vars_file_set]

    # Graph variables the file cannot initialize; names that look like Adam
    # optimizer state are not reported.
    graph_vars_not_in_file = [x for x in vars_model if x.name[:-2] not in vars_file_set and "Adam" not in x.name and
                              "beta1_power" not in x.name and "beta2_power" not in x.name]
    if len(graph_vars_not_in_file) > 0:
      print("the following variables will not be initialized since they are not present in the "
            "initialization model", [v.name for v in graph_vars_not_in_file])

    # File entries with no graph counterpart; names that look like optimizer
    # state are not reported.
    var_names_model = {x.name for x in vars_model}
    file_vars_not_in_graph = [x for x in vars_file if x + ":0" not in var_names_model
                              and "RMSProp" not in x and "Adam" not in x and "Momentum" not in x]
    if len(file_vars_not_in_graph) > 0:
      print("the following variables will not be loaded from the file since they are not present in the "
            "graph", file_vars_not_in_graph)

    # Partition the shared variables by shape agreement in a single pass. This
    # avoids a quadratic `x not in list` filter that also depends on how
    # tf.Variable implements `==`.
    vars_to_restore = []
    vars_shape_mismatch = []
    for x in vars_intersection:
      if x.shape.as_list() == vars_to_shapes_file[x.name[:-2]]:
        vars_to_restore.append(x)
      else:
        vars_shape_mismatch.append(x)
    if len(vars_shape_mismatch) > 0:
      print("the following variables will not be loaded from the file since the shapes in the graph and in"
            " the file don't match:", [(x.name, x.shape) for x in vars_shape_mismatch
                                       if "Adam" not in x.name])
    return tf.train.Saver(var_list=vars_to_restore)
开发者ID:JonathonLuiten,项目名称:PReMVOS,代码行数:36,代码来源:Engine.py

示例10: _create_load_init_saver

# 需要导入模块: from tensorflow.contrib import framework [as 别名]
# 或者: from tensorflow.contrib.framework import list_variables [as 别名]
def _create_load_init_saver(self, filename):
    """Builds a Saver that initializes the graph from a checkpoint file.

    Only graph variables whose name (with the ``:0`` suffix stripped) and
    shape both match an entry in ``filename`` are restored. Variables missing
    on either side, and shared variables with mismatched shapes, are reported
    to ``log.v1`` and left out of the Saver.

    Args:
      filename: Checkpoint to initialize from.

    Returns:
      ``None`` when no initialization should happen: ``self.load`` is
      non-empty, a ``<model_dir><model>-*.index`` file already exists, or
      ``filename`` is empty, ends with ``.pickle`` or starts with
      ``DeepLabRGB:``. Otherwise a ``tf.train.Saver`` over the restorable
      variables.
    """
    if self.load != "":
      return None
    if len(glob.glob(self.model_dir + self.model + "-*.index")) > 0:
      return None
    if filename == "" or filename.endswith(".pickle") or filename.startswith("DeepLabRGB:"):
      return None
    from tensorflow.contrib.framework import list_variables
    # Checkpoint entries as (name, shape) pairs; global_step is never restored here.
    vars_and_shapes_file = [x for x in list_variables(filename) if x[0] != "global_step"]
    vars_file = [name for name, _ in vars_and_shapes_file]
    vars_file_set = set(vars_file)  # O(1) membership tests below
    vars_to_shapes_file = {name: shape for name, shape in vars_and_shapes_file}
    vars_model = tf.global_variables()
    # Graph names are matched against checkpoint keys by stripping the last two
    # characters, so every graph name must really end in ":0".
    assert all(x.name.endswith(":0") for x in vars_model)
    vars_intersection = [x for x in vars_model if x.name[:-2] in vars_file_set]

    # Graph variables the file cannot initialize; names that look like Adam
    # optimizer state are not reported.
    graph_vars_not_in_file = [x for x in vars_model if x.name[:-2] not in vars_file_set and "Adam" not in x.name and
                              "beta1_power" not in x.name and "beta2_power" not in x.name]
    if len(graph_vars_not_in_file) > 0:
      print("the following variables will not be initialized since they are not present in the initialization model",
            [v.name for v in graph_vars_not_in_file], file=log.v1)

    # File entries with no graph counterpart; names that look like optimizer
    # state are not reported.
    var_names_model = {x.name for x in vars_model}
    file_vars_not_in_graph = [x for x in vars_file if x + ":0" not in var_names_model
                              and "RMSProp" not in x and "Adam" not in x and "Momentum" not in x]
    if len(file_vars_not_in_graph) > 0:
      print("the following variables will not be loaded from the file since they are not present in the graph",
            file_vars_not_in_graph, file=log.v1)

    # Partition the shared variables by shape agreement in a single pass. This
    # avoids a quadratic `x not in list` filter that also depends on how
    # tf.Variable implements `==`.
    vars_to_restore = []
    vars_shape_mismatch = []
    for x in vars_intersection:
      if x.shape.as_list() == vars_to_shapes_file[x.name[:-2]]:
        vars_to_restore.append(x)
      else:
        vars_shape_mismatch.append(x)
    if len(vars_shape_mismatch) > 0:
      print("the following variables will not be loaded from the file since the shapes in the graph and in the file "
            "don't match:", [(x.name, x.shape) for x in vars_shape_mismatch if "Adam" not in x.name], file=log.v1)
    return tf.train.Saver(var_list=vars_to_restore)
开发者ID:VisualComputingInstitute,项目名称:TrackR-CNN,代码行数:35,代码来源:Saver.py


注:本文中的tensorflow.contrib.framework.list_variables方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。