当前位置: 首页>>代码示例>>Python>>正文


Python graph_util.extract_sub_graph方法代码示例

本文整理汇总了Python中tensorflow.python.framework.graph_util.extract_sub_graph方法的典型用法代码示例。如果您正苦于以下问题：Python graph_util.extract_sub_graph方法的具体用法？Python graph_util.extract_sub_graph怎么用？Python graph_util.extract_sub_graph使用的例子？那么恭喜您，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在tensorflow.python.framework.graph_util的用法示例。


在下文中一共展示了graph_util.extract_sub_graph方法的7个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: strip_unused

# 需要导入模块: from tensorflow.python.framework import graph_util [as 别名]
# 或者: from tensorflow.python.framework.graph_util import extract_sub_graph [as 别名]
def strip_unused(input_graph_def, input_node_names, output_node_names,
                 placeholder_type_enum):
  """Removes unused nodes from a GraphDef.

  Every node whose name appears in `input_node_names` is replaced by a
  "Placeholder" node with the requested dtype (its "_output_shapes" attr is
  carried over when present). All other nodes are deep-copied. The result is
  then pruned with `graph_util.extract_sub_graph()` down to what
  `output_node_names` need.

  Args:
    input_graph_def: A graph with nodes we want to prune.
    input_node_names: A list of the nodes we use as inputs.
    output_node_names: A list of the output nodes.
    placeholder_type_enum: The AttrValue enum for the placeholder data type, or
        a list (or tuple) that specifies one value per input node name.

  Returns:
    A GraphDef with all unnecessary ops removed.
  """
  # Map each input name to the position of its first occurrence (the same
  # position list.index() reports). Membership tests and per-input dtype
  # lookups are then O(1), instead of a scan of input_node_names per node.
  first_index_by_name = {}
  for index, name in enumerate(input_node_names):
    first_index_by_name.setdefault(name, index)
  # Tuples are accepted as well as lists. Before, a tuple fell through to the
  # scalar branch and was handed straight to AttrValue(type=...).
  per_input_types = isinstance(placeholder_type_enum, (list, tuple))

  # Here we replace the nodes we're going to override as inputs with
  # placeholders so that any unused nodes that are inputs to them are
  # automatically stripped out by extract_sub_graph().
  inputs_replaced_graph_def = graph_pb2.GraphDef()
  for node in input_graph_def.node:
    input_index = first_index_by_name.get(node.name)
    if input_index is None:
      inputs_replaced_graph_def.node.extend([copy.deepcopy(node)])
      continue

    if per_input_types:
      dtype = placeholder_type_enum[input_index]
    else:
      dtype = placeholder_type_enum
    placeholder_node = node_def_pb2.NodeDef()
    placeholder_node.op = "Placeholder"
    placeholder_node.name = node.name
    placeholder_node.attr["dtype"].CopyFrom(
        attr_value_pb2.AttrValue(type=dtype))
    if "_output_shapes" in node.attr:
      placeholder_node.attr["_output_shapes"].CopyFrom(
          node.attr["_output_shapes"])
    inputs_replaced_graph_def.node.extend([placeholder_node])

  output_graph_def = graph_util.extract_sub_graph(inputs_replaced_graph_def,
                                                  output_node_names)
  return output_graph_def
开发者ID:abhisuri97,项目名称:auto-alt-text-lambda-api,代码行数:43,代码来源:strip_unused_lib.py

示例2: strip_unused

# 需要导入模块: from tensorflow.python.framework import graph_util [as 别名]
# 或者: from tensorflow.python.framework.graph_util import extract_sub_graph [as 别名]
def strip_unused(input_graph_def, input_node_names, output_node_names,
                 placeholder_type_enum):
  """Removes unused nodes from a GraphDef.

  Nodes named in `input_node_names` are swapped for "Placeholder" nodes with
  the requested dtype (keeping any "_output_shapes" attr). Every other node is
  deep-copied. The graph is then cut down with
  `graph_util.extract_sub_graph()` to what `output_node_names` depend on.

  Args:
    input_graph_def: A graph with nodes we want to prune.
    input_node_names: A list of the nodes we use as inputs.
    output_node_names: A list of the output nodes.
    placeholder_type_enum: The AttrValue enum for the placeholder data type, or
        a list (or tuple) that specifies one value per input node name.

  Returns:
    A GraphDef with all unnecessary ops removed.
  """
  # A list/tuple gives one dtype per input name. It used to be passed whole to
  # tf.AttrValue(type=...), which cannot accept a sequence.
  per_input_types = isinstance(placeholder_type_enum, (list, tuple))
  # First-occurrence position of each input name, built once so the loop
  # below does O(1) lookups.
  first_index_by_name = {}
  for index, name in enumerate(input_node_names):
    first_index_by_name.setdefault(name, index)

  # Here we replace the nodes we're going to override as inputs with
  # placeholders so that any unused nodes that are inputs to them are
  # automatically stripped out by extract_sub_graph().
  inputs_replaced_graph_def = tf.GraphDef()
  for node in input_graph_def.node:
    input_index = first_index_by_name.get(node.name)
    if input_index is None:
      inputs_replaced_graph_def.node.extend([copy.deepcopy(node)])
      continue

    if per_input_types:
      dtype = placeholder_type_enum[input_index]
    else:
      dtype = placeholder_type_enum
    placeholder_node = tf.NodeDef()
    placeholder_node.op = "Placeholder"
    placeholder_node.name = node.name
    placeholder_node.attr["dtype"].CopyFrom(tf.AttrValue(type=dtype))
    if "_output_shapes" in node.attr:
      placeholder_node.attr["_output_shapes"].CopyFrom(
          node.attr["_output_shapes"])
    inputs_replaced_graph_def.node.extend([placeholder_node])

  output_graph_def = graph_util.extract_sub_graph(inputs_replaced_graph_def,
                                                  output_node_names)
  return output_graph_def
开发者ID:tobegit3hub,项目名称:deep_image_model,代码行数:36,代码来源:strip_unused_lib.py

示例3: testExtractSubGraph

# 需要导入模块: from tensorflow.python.framework import graph_util [as 别名]
# 或者: from tensorflow.python.framework.graph_util import extract_sub_graph [as 别名]
def testExtractSubGraph(self):
    """extract_sub_graph keeps only what the destination node depends on.

    n3 depends on n2 via a control input, and n2 on n1 via a data input. n1
    and n5 feed each other (a cycle). n4 is connected to nothing, so it must
    not appear in the sub-graph extracted for ["n3"].
    """
    graph_def = tf.GraphDef()
    n1 = graph_def.node.add()
    n1.name = "n1"
    n1.input.extend(["n5"])
    n2 = graph_def.node.add()
    n2.name = "n2"
    # Take the first output of the n1 node as the input.
    n2.input.extend(["n1:0"])
    n3 = graph_def.node.add()
    n3.name = "n3"
    # Add a control input (which isn't really needed by the kernel, but
    # rather to enforce execution order between nodes).
    n3.input.extend(["^n2"])
    # n4 is disconnected from n3 and is expected to be pruned.
    n4 = graph_def.node.add()
    n4.name = "n4"

    # It is fine to have loops in the graph as well.
    n5 = graph_def.node.add()
    n5.name = "n5"
    n5.input.extend(["n1"])

    sub_graph = graph_util.extract_sub_graph(graph_def, ["n3"])
    # Exactly the four reachable nodes remain; checking the count makes the
    # removal of n4 explicit rather than implied by the index checks below.
    self.assertEqual(4, len(sub_graph.node))
    self.assertEqual("n1", sub_graph.node[0].name)
    self.assertEqual("n2", sub_graph.node[1].name)
    self.assertEqual("n3", sub_graph.node[2].name)
    self.assertEqual("n5", sub_graph.node[3].name)
开发者ID:tobegit3hub,项目名称:deep_image_model,代码行数:29,代码来源:graph_util_test.py

示例4: remove_dead_nodes

# 需要导入模块: from tensorflow.python.framework import graph_util [as 别名]
# 或者: from tensorflow.python.framework.graph_util import extract_sub_graph [as 别名]
def remove_dead_nodes(self, output_names):
    """Prune self.output_graph down to the nodes `output_names` require.

    The attribute is rebound to whatever graph_util.extract_sub_graph()
    returns for the current graph and the given output names.
    """
    pruned_graph = graph_util.extract_sub_graph(self.output_graph, output_names)
    self.output_graph = pruned_graph
开发者ID:tobegit3hub,项目名称:deep_image_model,代码行数:7,代码来源:quantize_graph.py

示例5: strip_meta_graph

# 需要导入模块: from tensorflow.python.framework import graph_util [as 别名]
# 或者: from tensorflow.python.framework.graph_util import extract_sub_graph [as 别名]
def strip_meta_graph(meta_graph_def, node_names, var_names):
  """Prunes the graph inside `meta_graph_def` down to `node_names`, in place.

  The kept set is `node_names` plus the initializer of every variable (found
  in the "variables" / "trainable_variables" collections) whose name is in
  `var_names`. Entries in the "while_context" collection whose pivot name is
  not a prefix of any kept name are deleted. Finally, `meta_graph_def.graph_def`
  is overwritten with the sub-graph extracted for the kept names.

  Args:
    meta_graph_def: The MetaGraphDef to modify in place.
    node_names: Names of nodes to keep. The caller's sequence is not mutated.
    var_names: Variable names whose initializer nodes should also be kept.
  """
  # list() rather than [:] so a tuple argument can still be appended to.
  node_names = list(node_names)
  collections = meta_graph_def.collection_def

  # Look for matching variable names and initializers and keep them too.
  # Membership is checked first because indexing a protobuf message map with
  # a missing key inserts an empty entry. The old code silently added empty
  # collections to the MetaGraphDef.
  var_def = variable_pb2.VariableDef()
  for var_col_name in ("variables", "trainable_variables"):
    if var_col_name not in collections:
      continue
    for var_def_b in collections[var_col_name].bytes_list.value:
      var_def.ParseFromString(var_def_b)
      if var_def.variable_name not in var_names:
        # TODO(adamb) Should remove variable from collection.
        continue
      node_names.append(var_def.initializer_name)

  # Drop while-loop contexts that no kept node name starts with. Walk the
  # indices backwards so a deletion does not shift entries not yet visited.
  if "while_context" in collections:
    wc_def = control_flow_pb2.WhileContextDef()
    wc_values = collections["while_context"].bytes_list.value
    for wc_ix in reversed(range(len(wc_values))):
      wc_def.ParseFromString(wc_values[wc_ix])
      wc_pivot_name = wc_def.pivot_name
      if not any(name.startswith(wc_pivot_name) for name in node_names):
        del wc_values[wc_ix]

  graph_def = meta_graph_def.graph_def
  eprint("only keeping", node_names, "from", [n.name for n in graph_def.node])
  graph_def = graph_util.extract_sub_graph(graph_def, node_names)
  meta_graph_def.graph_def.CopyFrom(graph_def)
开发者ID:tensorlang,项目名称:tensorlang,代码行数:36,代码来源:graph_xform.py

示例6: strip_unused

# 需要导入模块: from tensorflow.python.framework import graph_util [as 别名]
# 或者: from tensorflow.python.framework.graph_util import extract_sub_graph [as 别名]
def strip_unused(input_graph_def, input_node_names, output_node_names,
                 placeholder_type_enum):
  """Removes unused nodes from a GraphDef.

  Each node named in `input_node_names` is replaced by a "Placeholder" node
  with the requested dtype (its "_output_shapes" attr is preserved when
  present). All other nodes are deep-copied. The result is then pruned with
  `graph_util.extract_sub_graph()` to what `output_node_names` depend on.

  Args:
    input_graph_def: A graph with nodes we want to prune.
    input_node_names: A list of the nodes we use as inputs.
    output_node_names: A list of the output nodes.
    placeholder_type_enum: The AttrValue enum for the placeholder data type, or
        a list (or tuple) that specifies one value per input node name.

  Returns:
    A `GraphDef` with all unnecessary ops removed.

  Raises:
    ValueError: If any element in `input_node_names` refers to a tensor instead
      of an operation.
    KeyError: If any element in `input_node_names` is not found in the graph.
  """
  for name in input_node_names:
    if ":" in name:
      raise ValueError("Name '%s' appears to refer to a Tensor, "
                       "not a Operation." % name)

  # First-occurrence position of each input name (what list.index() would
  # return), so the per-node lookups below are O(1) instead of O(len(inputs)).
  first_index_by_name = {}
  for index, name in enumerate(input_node_names):
    first_index_by_name.setdefault(name, index)
  # Accept tuples as well as lists. Before, a tuple was passed unchanged to
  # AttrValue(type=...), which fails.
  per_input_types = isinstance(placeholder_type_enum, (list, tuple))

  # Here we replace the nodes we're going to override as inputs with
  # placeholders so that any unused nodes that are inputs to them are
  # automatically stripped out by extract_sub_graph().
  not_found = set(first_index_by_name)
  inputs_replaced_graph_def = graph_pb2.GraphDef()
  for node in input_graph_def.node:
    input_index = first_index_by_name.get(node.name)
    if input_index is None:
      inputs_replaced_graph_def.node.extend([copy.deepcopy(node)])
      continue

    not_found.remove(node.name)
    if per_input_types:
      dtype = placeholder_type_enum[input_index]
    else:
      dtype = placeholder_type_enum
    placeholder_node = node_def_pb2.NodeDef()
    placeholder_node.op = "Placeholder"
    placeholder_node.name = node.name
    placeholder_node.attr["dtype"].CopyFrom(
        attr_value_pb2.AttrValue(type=dtype))
    if "_output_shapes" in node.attr:
      placeholder_node.attr["_output_shapes"].CopyFrom(
          node.attr["_output_shapes"])
    inputs_replaced_graph_def.node.extend([placeholder_node])

  if not_found:
    raise KeyError("The following input nodes were not found: %s\n" % not_found)

  output_graph_def = graph_util.extract_sub_graph(inputs_replaced_graph_def,
                                                  output_node_names)
  return output_graph_def
开发者ID:ryfeus,项目名称:lambda-packs,代码行数:58,代码来源:strip_unused_lib.py

示例7: strip_unused

# 需要导入模块: from tensorflow.python.framework import graph_util [as 别名]
# 或者: from tensorflow.python.framework.graph_util import extract_sub_graph [as 别名]
def strip_unused(input_graph, input_binary, output_graph, input_node_names,
                 output_node_names, placeholder_type_enum):
  """Removes unused nodes from a graph file and writes the pruned graph.

  Args:
    input_graph: Path of the GraphDef file to read.
    input_binary: True to parse `input_graph` as a binary protobuf, False to
        parse it with `text_format.Merge`.
    output_graph: Path the pruned GraphDef is written to, serialized as binary.
    input_node_names: Comma-separated names of nodes to replace with
        placeholders.
    output_node_names: Comma-separated names of the nodes to keep, together
        with everything they depend on.
    placeholder_type_enum: The AttrValue enum for the placeholders' dtype.

  Returns:
    -1 if `input_graph` does not exist or no output node name was supplied;
    otherwise None, after the pruned graph has been written.
  """

  if not tf.gfile.Exists(input_graph):
    print("Input graph file '" + input_graph + "' does not exist!")
    return -1

  # Split the comma-separated flags. Surrounding spaces are stripped and empty
  # entries (e.g. a trailing comma) are skipped. Those used to yield names that
  # never match, or an empty output name. Assumes real node names never start
  # or end with spaces.
  output_node_name_list = [
      name.strip() for name in (output_node_names or "").split(",")
      if name.strip()]
  if not output_node_name_list:
    print("You need to supply the name of a node to --output_node_names.")
    return -1
  input_node_name_set = {
      name.strip() for name in (input_node_names or "").split(",")
      if name.strip()}

  input_graph_def = tf.GraphDef()
  mode = "rb" if input_binary else "r"
  with tf.gfile.FastGFile(input_graph, mode) as f:
    if input_binary:
      input_graph_def.ParseFromString(f.read())
    else:
      text_format.Merge(f.read(), input_graph_def)

  # Here we replace the nodes we're going to override as inputs with
  # placeholders so that any unused nodes that are inputs to them are
  # automatically stripped out by extract_sub_graph().
  inputs_replaced_graph_def = tf.GraphDef()
  for node in input_graph_def.node:
    if node.name not in input_node_name_set:
      inputs_replaced_graph_def.node.extend([copy.deepcopy(node)])
      continue
    placeholder_node = tf.NodeDef()
    placeholder_node.op = "Placeholder"
    placeholder_node.name = node.name
    placeholder_node.attr["dtype"].CopyFrom(tf.AttrValue(
        type=placeholder_type_enum))
    # Carry over any recorded output shapes so the placeholder does not lose
    # the shape information the replaced node had.
    if "_output_shapes" in node.attr:
      placeholder_node.attr["_output_shapes"].CopyFrom(
          node.attr["_output_shapes"])
    inputs_replaced_graph_def.node.extend([placeholder_node])

  output_graph_def = graph_util.extract_sub_graph(inputs_replaced_graph_def,
                                                  output_node_name_list)

  with tf.gfile.GFile(output_graph, "wb") as f:
    f.write(output_graph_def.SerializeToString())
  print("%d ops in the final graph." % len(output_graph_def.node))
开发者ID:yselivonchyk,项目名称:TensorFlow_DCIGN,代码行数:44,代码来源:strip_unused.py


注:本文中的tensorflow.python.framework.graph_util.extract_sub_graph方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。