當前位置: 首頁>>代碼示例>>Python>>正文


Python framework.list_variables方法代碼示例

本文整理匯總了Python中tensorflow.contrib.framework.list_variables方法的典型用法代碼示例。如果您正苦於以下問題:Python framework.list_variables方法的具體用法?Python framework.list_variables怎麼用?Python framework.list_variables使用的例子?那麼, 這裏精選的方法代碼示例或許可以為您提供幫助。您也可以進一步了解該方法所在模組tensorflow.contrib.framework的其他用法示例。


在下文中一共展示了framework.list_variables方法的10個代碼示例,這些例子默認根據受歡迎程度排序。您可以為喜歡或者感覺有用的代碼點讚,您的評價將有助於系統推薦出更棒的Python代碼示例。

示例1: get_weights

# 需要導入模塊: from tensorflow.contrib import framework [as 別名]
# 或者: from tensorflow.contrib.framework import list_variables [as 別名]
def get_weights(self, model_dir):
    """Returns weights per feature of the linear part.

    Reads the checkpoint in ``model_dir`` and keeps every variable whose name
    starts with ``self._scope + "/"``, except:
      * the bias variable ``<scope>/bias_weight``;
      * optimizer slot variables, i.e. names ending in
        ``/<optimizer name>`` optionally followed by ``_<digits>``.

    Args:
      model_dir: Directory where model parameters, graph and etc. are saved.

    Returns:
      The weights created by this model (without the optimizer weights).
      If exactly one variable matches, its value is returned directly;
      otherwise a dict mapping variable name to its loaded value.
    """
    all_variables = [name for name, _ in list_variables(model_dir)]
    prefix = self._scope + "/"
    bias_name = prefix + "bias_weight"
    # re.escape: match the optimizer name literally.
    # (_\d+)?: also catch multi-digit suffixes such as "_12"; the former
    # "(_\d)?" pattern let those slot variables through as weights.
    optimizer_pattern = re.compile(
        r".*/" + re.escape(self._get_optimizer().get_name()) + r"(_\d+)?$")
    values = {}
    for name in all_variables:
      if (name.startswith(prefix) and
          name != bias_name and
          not optimizer_pattern.match(name)):
        values[name] = load_variable(model_dir, name)
    if len(values) == 1:
      # Unwrap the single entry instead of returning a one-element dict.
      return next(iter(values.values()))
    return values
開發者ID:ryfeus,項目名稱:lambda-packs,代碼行數:22,代碼來源:composable_model.py

示例2: get_variable_names

# 需要導入模塊: from tensorflow.contrib import framework [as 別名]
# 或者: from tensorflow.contrib.framework import list_variables [as 別名]
def get_variable_names(self):
    """Lists the names of every variable saved under ``self.model_dir``.

    Returns:
      List of variable names, in the order ``list_variables`` reports them.
    """
    names = []
    for var_name, _shape in list_variables(self.model_dir):
        names.append(var_name)
    return names
開發者ID:ryfeus,項目名稱:lambda-packs,代碼行數:9,代碼來源:estimator.py

示例3: _create_load_init_saver

# 需要導入模塊: from tensorflow.contrib import framework [as 別名]
# 或者: from tensorflow.contrib.framework import list_variables [as 別名]
def _create_load_init_saver(self, filename):
    """Creates a Saver that initializes graph variables from the checkpoint ``filename``.

    Returns None when any of these holds:
      * ``self.load`` is non-empty;
      * files matching ``self.model_dir + self.model + "-*.index"`` already exist;
      * ``filename`` is empty, ends with ".pickle", or starts with "DeepLabRGB:".

    Otherwise the graph's global variables are matched by name (without the
    ":0" suffix) against the variables in ``filename``; ``global_step`` is
    ignored. Variables missing on either side, and shared variables whose
    shapes differ, are reported to ``log.v1``.

    Args:
      filename: checkpoint to initialize from (passed to ``list_variables``).

    Returns:
      A ``tf.train.Saver`` over the variables present in both the graph and
      the checkpoint with identical shapes, or None.
    """
    if self.load != "":
      return None
    if glob.glob(self.model_dir + self.model + "-*.index"):
      return None
    if filename == "" or filename.endswith(".pickle") or filename.startswith("DeepLabRGB:"):
      return None

    vars_and_shapes_file = [x for x in list_variables(filename) if x[0] != "global_step"]
    vars_file = [x[0] for x in vars_and_shapes_file]
    # name -> shape. It also serves as an O(1) membership index for the lookups
    # below, which previously scanned the vars_file list once per variable.
    vars_to_shapes_file = {x[0]: x[1] for x in vars_and_shapes_file}
    vars_model = tf.global_variables()
    # Graph names are compared after stripping the ":0" suffix (x.name[:-2]).
    assert all(x.name.endswith(":0") for x in vars_model)
    vars_intersection = [x for x in vars_model if x.name[:-2] in vars_to_shapes_file]
    # Graph variables absent from the checkpoint. Names containing "Adam",
    # "beta1_power" or "beta2_power" are not reported.
    vars_missing_in_graph = [x for x in vars_model if x.name[:-2] not in vars_to_shapes_file and "Adam" not in x.name and
                             "beta1_power" not in x.name and "beta2_power" not in x.name]
    if vars_missing_in_graph:
      print("the following variables will not be initialized since they are not present in the initialization model",
            [v.name for v in vars_missing_in_graph], file=log.v1)

    var_names_model = {x.name for x in vars_model}
    # Checkpoint variables absent from the graph. Names containing "RMSProp",
    # "Adam" or "Momentum" are not reported.
    vars_missing_in_file = [x for x in vars_file if x + ":0" not in var_names_model
                            and "RMSProp" not in x and "Adam" not in x and "Momentum" not in x]
    if vars_missing_in_file:
      print("the following variables will not be loaded from the file since they are not present in the graph",
            vars_missing_in_file, file=log.v1)

    # Partition the shared variables in one pass. This replaces the former
    # O(n^2) `x not in vars_shape_mismatch` filter.
    vars_to_restore = []
    vars_shape_mismatch = []
    for x in vars_intersection:
      if x.shape.as_list() != vars_to_shapes_file[x.name[:-2]]:
        vars_shape_mismatch.append(x)
      else:
        vars_to_restore.append(x)
    if vars_shape_mismatch:
      print("the following variables will not be loaded from the file since the shapes in the graph and in the file "
            "don't match:", [(x.name, x.shape) for x in vars_shape_mismatch if "Adam" not in x.name], file=log.v1)
    return tf.train.Saver(var_list=vars_to_restore)
開發者ID:tobiasfshr,項目名稱:MOTSFusion,代碼行數:35,代碼來源:Saver.py

示例4: get_variable_names

# 需要導入模塊: from tensorflow.contrib import framework [as 別名]
# 或者: from tensorflow.contrib.framework import list_variables [as 別名]
def get_variable_names(self):
    """Returns the names of all variables in the checkpoint under ``self._model_dir``."""
    checkpoint_vars = list_variables(self._model_dir)
    return [var_name for var_name, _unused_shape in checkpoint_vars]
開發者ID:tobegit3hub,項目名稱:deep_image_model,代碼行數:4,代碼來源:svm.py

示例5: weights_

# 需要導入模塊: from tensorflow.contrib import framework [as 別名]
# 或者: from tensorflow.contrib.framework import list_variables [as 別名]
def weights_(self):
    """Returns the linear weights stored in the checkpoint at ``self._model_dir``.

    Keeps every variable whose name starts with "linear/", except:
      * "linear/bias_weight";
      * optimizer slot variables, i.e. names ending in
        ``/<optimizer name>`` optionally followed by ``_<digits>``.

    Returns:
      The single value when exactly one variable matches; otherwise a dict
      mapping variable name to its loaded value.
    """
    values = {}
    # re.escape: match the optimizer name literally.
    # (_\d+)?: also catch multi-digit suffixes such as "_12"; the former
    # "(_\d)?" pattern let those slot variables through as weights.
    optimizer_pattern = re.compile(
        r".*/" + re.escape(self._optimizer.get_name()) + r"(_\d+)?$")
    for name, _ in list_variables(self._model_dir):
      if (name.startswith("linear/") and
          name != "linear/bias_weight" and
          not optimizer_pattern.match(name)):
        values[name] = load_variable(self._model_dir, name)
    if len(values) == 1:
      # Unwrap the single entry instead of returning a one-element dict.
      return next(iter(values.values()))
    return values
開發者ID:tobegit3hub,項目名稱:deep_image_model,代碼行數:13,代碼來源:svm.py

示例6: get_variable_names

# 需要導入模塊: from tensorflow.contrib import framework [as 別名]
# 或者: from tensorflow.contrib.framework import list_variables [as 別名]
def get_variable_names(self):
    """Collects every variable name recorded in ``self._model_dir``.

    Returns:
      List of variable names, in the order ``list_variables`` yields them.
    """
    collected = []
    for var_name, _shape in list_variables(self._model_dir):
        collected.append(var_name)
    return collected
開發者ID:tobegit3hub,項目名稱:deep_image_model,代碼行數:9,代碼來源:dnn.py

示例7: available_variables_without_global_step

# 需要導入模塊: from tensorflow.contrib import framework [as 別名]
# 或者: from tensorflow.contrib.framework import list_variables [as 別名]
def available_variables_without_global_step(checkpoint_dir):
    """Selects the global variables that can be restored from ``checkpoint_dir``.

    A variable qualifies when both hold:
      * its name (the part before ':') is listed in the checkpoint;
      * its shape compares equal to the shape stored there.
    The variable named 'global_step' is always excluded.

    Args:
      checkpoint_dir: value passed to ``list_variables(checkpoint_dir=...)``.

    Returns:
      List of qualifying ``tf.global_variables()`` entries, in graph order.
    """
    import tensorflow.contrib.framework as tff
    graph_vars = tf.global_variables()
    ckpt_shapes = dict(tff.list_variables(checkpoint_dir=checkpoint_dir))

    def _restorable(var):
        base_name = var.name.split(':')[0]
        if base_name == 'global_step':
            return False
        return base_name in ckpt_shapes and var.get_shape() == ckpt_shapes[base_name]

    return [var for var in graph_vars if _restorable(var)]
開發者ID:shiyuzh2007,項目名稱:ASR,代碼行數:15,代碼來源:pretrain_layerblock.py

示例8: available_variables

# 需要導入模塊: from tensorflow.contrib import framework [as 別名]
# 或者: from tensorflow.contrib.framework import list_variables [as 別名]
def available_variables(checkpoint_dir):
    """Returns the global variables that can be restored from ``checkpoint_dir``.

    A variable qualifies when both hold:
      * its name (the part before ':') is listed in the checkpoint;
      * its shape compares equal to the shape stored there.

    Args:
      checkpoint_dir: value passed to ``tff.list_variables(checkpoint_dir=...)``.

    Returns:
      List of qualifying ``tf.global_variables()`` entries, in graph order.
    """
    graph_vars = tf.global_variables()
    ckpt_shapes = dict(tff.list_variables(checkpoint_dir=checkpoint_dir))
    restorable = []
    for var in graph_vars:
        base_name = var.name.split(':')[0]
        if base_name not in ckpt_shapes:
            continue
        if var.get_shape() == ckpt_shapes[base_name]:
            restorable.append(var)
    return restorable
開發者ID:shiyuzh2007,項目名稱:ASR,代碼行數:12,代碼來源:utils.py

示例9: _create_load_init_saver

# 需要導入模塊: from tensorflow.contrib import framework [as 別名]
# 或者: from tensorflow.contrib.framework import list_variables [as 別名]
def _create_load_init_saver(self, filename):
    """Creates a Saver that initializes graph variables from the checkpoint ``filename``.

    Returns None when any of these holds:
      * ``self.load`` is non-empty;
      * files matching ``self.model_dir + self.model + "-*.index"`` already exist;
      * ``filename`` is empty or ends with ".pickle".

    Otherwise the graph's global variables are matched by name (without the
    ":0" suffix) against the variables in ``filename``; ``global_step`` is
    ignored. Variables missing on either side, and shared variables whose
    shapes differ, are printed to stdout.

    Args:
      filename: checkpoint to initialize from (passed to ``list_variables``).

    Returns:
      A ``tf.train.Saver`` over the variables present in both the graph and
      the checkpoint with identical shapes, or None.
    """
    if self.load != "":
      return None
    if glob.glob(self.model_dir + self.model + "-*.index"):
      return None
    if filename == "" or filename.endswith(".pickle"):
      return None

    vars_and_shapes_file = [x for x in list_variables(filename) if x[0] != "global_step"]
    vars_file = [x[0] for x in vars_and_shapes_file]
    # name -> shape. It also serves as an O(1) membership index for the lookups
    # below, which previously scanned the vars_file list once per variable.
    vars_to_shapes_file = {x[0]: x[1] for x in vars_and_shapes_file}
    vars_model = tf.global_variables()
    # Graph names are compared after stripping the ":0" suffix (x.name[:-2]).
    assert all(x.name.endswith(":0") for x in vars_model)
    vars_intersection = [x for x in vars_model if x.name[:-2] in vars_to_shapes_file]
    # Graph variables absent from the checkpoint. Names containing "Adam",
    # "beta1_power" or "beta2_power" are not reported.
    vars_missing_in_graph = [x for x in vars_model if x.name[:-2] not in vars_to_shapes_file and "Adam" not in x.name and
                             "beta1_power" not in x.name and "beta2_power" not in x.name]
    if vars_missing_in_graph:
      print("the following variables will not be initialized since they are not present in the "
            "initialization model", [v.name for v in vars_missing_in_graph])

    var_names_model = {x.name for x in vars_model}
    # Checkpoint variables absent from the graph. Names containing "RMSProp",
    # "Adam" or "Momentum" are not reported.
    vars_missing_in_file = [x for x in vars_file if x + ":0" not in var_names_model
                            and "RMSProp" not in x and "Adam" not in x and "Momentum" not in x]
    if vars_missing_in_file:
      print("the following variables will not be loaded from the file since they are not present in the "
            "graph", vars_missing_in_file)

    # Partition the shared variables in one pass. This replaces the former
    # O(n^2) `x not in vars_shape_mismatch` filter.
    vars_to_restore = []
    vars_shape_mismatch = []
    for x in vars_intersection:
      if x.shape.as_list() != vars_to_shapes_file[x.name[:-2]]:
        vars_shape_mismatch.append(x)
      else:
        vars_to_restore.append(x)
    if vars_shape_mismatch:
      print("the following variables will not be loaded from the file since the shapes in the graph and in"
            " the file don't match:", [(x.name, x.shape) for x in vars_shape_mismatch
                                       if "Adam" not in x.name])
    return tf.train.Saver(var_list=vars_to_restore)
開發者ID:JonathonLuiten,項目名稱:PReMVOS,代碼行數:36,代碼來源:Engine.py

示例10: _create_load_init_saver

# 需要導入模塊: from tensorflow.contrib import framework [as 別名]
# 或者: from tensorflow.contrib.framework import list_variables [as 別名]
def _create_load_init_saver(self, filename):
    """Creates a Saver that initializes graph variables from the checkpoint ``filename``.

    Returns None when any of these holds:
      * ``self.load`` is non-empty;
      * files matching ``self.model_dir + self.model + "-*.index"`` already exist;
      * ``filename`` is empty, ends with ".pickle", or starts with "DeepLabRGB:".

    Otherwise the graph's global variables are matched by name (without the
    ":0" suffix) against the variables in ``filename``; ``global_step`` is
    ignored. Variables missing on either side, and shared variables whose
    shapes differ, are reported to ``log.v1``.

    Args:
      filename: checkpoint to initialize from (passed to ``list_variables``).

    Returns:
      A ``tf.train.Saver`` over the variables present in both the graph and
      the checkpoint with identical shapes, or None.
    """
    if self.load != "":
      return None
    if glob.glob(self.model_dir + self.model + "-*.index"):
      return None
    if filename == "" or filename.endswith(".pickle") or filename.startswith("DeepLabRGB:"):
      return None
    from tensorflow.contrib.framework import list_variables
    vars_and_shapes_file = [x for x in list_variables(filename) if x[0] != "global_step"]
    vars_file = [x[0] for x in vars_and_shapes_file]
    # name -> shape. It also serves as an O(1) membership index for the lookups
    # below, which previously scanned the vars_file list once per variable.
    vars_to_shapes_file = {x[0]: x[1] for x in vars_and_shapes_file}
    vars_model = tf.global_variables()
    # Graph names are compared after stripping the ":0" suffix (x.name[:-2]).
    assert all(x.name.endswith(":0") for x in vars_model)
    vars_intersection = [x for x in vars_model if x.name[:-2] in vars_to_shapes_file]
    # Graph variables absent from the checkpoint. Names containing "Adam",
    # "beta1_power" or "beta2_power" are not reported.
    vars_missing_in_graph = [x for x in vars_model if x.name[:-2] not in vars_to_shapes_file and "Adam" not in x.name and
                             "beta1_power" not in x.name and "beta2_power" not in x.name]
    if vars_missing_in_graph:
      print("the following variables will not be initialized since they are not present in the initialization model",
            [v.name for v in vars_missing_in_graph], file=log.v1)

    var_names_model = {x.name for x in vars_model}
    # Checkpoint variables absent from the graph. Names containing "RMSProp",
    # "Adam" or "Momentum" are not reported.
    vars_missing_in_file = [x for x in vars_file if x + ":0" not in var_names_model
                            and "RMSProp" not in x and "Adam" not in x and "Momentum" not in x]
    if vars_missing_in_file:
      print("the following variables will not be loaded from the file since they are not present in the graph",
            vars_missing_in_file, file=log.v1)

    # Partition the shared variables in one pass. This replaces the former
    # O(n^2) `x not in vars_shape_mismatch` filter.
    vars_to_restore = []
    vars_shape_mismatch = []
    for x in vars_intersection:
      if x.shape.as_list() != vars_to_shapes_file[x.name[:-2]]:
        vars_shape_mismatch.append(x)
      else:
        vars_to_restore.append(x)
    if vars_shape_mismatch:
      print("the following variables will not be loaded from the file since the shapes in the graph and in the file "
            "don't match:", [(x.name, x.shape) for x in vars_shape_mismatch if "Adam" not in x.name], file=log.v1)
    return tf.train.Saver(var_list=vars_to_restore)
開發者ID:VisualComputingInstitute,項目名稱:TrackR-CNN,代碼行數:35,代碼來源:Saver.py


注:本文中的tensorflow.contrib.framework.list_variables方法示例由純淨天空整理自Github/MSDocs等開源代碼及文檔管理平台,相關代碼片段篩選自各路編程大神貢獻的開源項目,源碼版權歸原作者所有,傳播和使用請參考對應項目的License;未經允許,請勿轉載。