当前位置: 首页>>代码示例>>Python>>正文


Python graph_util.extract_sub_graph函数代码示例

本文整理汇总了Python中tensorflow.python.framework.graph_util.extract_sub_graph函数的典型用法代码示例。如果您正苦于以下问题：Python extract_sub_graph函数的具体用法是什么？Python extract_sub_graph怎么用？有哪些Python extract_sub_graph的使用例子？那么恭喜您，这里精选的函数代码示例或许可以为您提供帮助。


在下文中一共展示了extract_sub_graph函数的12个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: strip_unused

def strip_unused(input_graph_def, input_node_names, output_node_names,
                 placeholder_type_enum):
  """Prunes a GraphDef down to what the requested outputs depend on.

  Every node whose name appears in `input_node_names` is swapped for a
  Placeholder node. Nodes that only fed those inputs therefore become
  unreachable and are dropped by `graph_util.extract_sub_graph()`.

  Args:
    input_graph_def: The GraphDef to prune.
    input_node_names: Names of the nodes to turn into Placeholders.
    output_node_names: Names of the nodes whose dependencies are kept.
    placeholder_type_enum: AttrValue type enum used as each Placeholder's
      "dtype".

  Returns:
    A new GraphDef containing only the nodes needed by the outputs.
  """

  def _as_placeholder(original):
    # Build a Placeholder with the same name, carrying over any recorded
    # "_output_shapes" attribute.
    replacement = tf.NodeDef()
    replacement.op = "Placeholder"
    replacement.name = original.name
    replacement.attr["dtype"].CopyFrom(
        tf.AttrValue(type=placeholder_type_enum))
    if "_output_shapes" in original.attr:
      replacement.attr["_output_shapes"].CopyFrom(
          original.attr["_output_shapes"])
    return replacement

  rewritten = tf.GraphDef()
  for original in input_graph_def.node:
    new_node = (_as_placeholder(original)
                if original.name in input_node_names
                else copy.deepcopy(original))
    rewritten.node.extend([new_node])

  return graph_util.extract_sub_graph(rewritten, output_node_names)
开发者ID:marevol,项目名称:tensorflow,代码行数:34,代码来源:strip_unused_lib.py

示例2: testExtractSubGraph

  def testExtractSubGraph(self):
    """extract_sub_graph keeps only nodes reachable from the destinations.

    "n4" has no path to "n3" and must be dropped. The n1 <-> n5 cycle must
    not prevent extraction.
    """
    graph_def = graph_pb2.GraphDef()
    # (node name, inputs), added to the graph in exactly this order.
    node_specs = [
        ("n1", ["n5"]),
        ("n2", ["n1:0"]),  # Consumes output 0 of n1 explicitly.
        ("n3", ["^n2"]),   # Control-only edge: enforces ordering on n2.
        ("n4", []),        # Unconnected to n3.
        ("n5", ["n1"]),    # Closes a cycle with n1.
    ]
    for node_name, node_inputs in node_specs:
      node = graph_def.node.add()
      node.name = node_name
      node.input.extend(node_inputs)

    sub_graph = graph_util.extract_sub_graph(graph_def, ["n3"])
    for position, expected_name in enumerate(["n1", "n2", "n3", "n5"]):
      self.assertEqual(expected_name, sub_graph.node[position].name)
开发者ID:1000sprites,项目名称:tensorflow,代码行数:27,代码来源:graph_util_test.py

示例3: test_remove_unneeded_nodes

  def test_remove_unneeded_nodes(self):
    """Checks GraphRewriter.remove_unneeded_nodes on a check/identity graph.

    In the input graph, each constant feeds a CheckNumerics node. Each
    constant also feeds an Identity node that has a control input on that
    CheckNumerics node. The two Identity nodes feed an Add. The expected
    output has only the two constants and an Add that reads them directly.
    """
    a_constant_name = "a_constant"
    b_constant_name = "b_constant"
    a_check_name = "a_check"
    b_check_name = "b_check"
    a_identity_name = "a_identity"
    b_identity_name = "b_identity"
    add_name = "add"
    graph_def = tf.GraphDef()
    # Branch "a": a_constant -> a_check (CheckNumerics), and
    # a_identity = Identity(a_constant) with a control input on a_check.
    a_constant = quantize_graph.create_constant_node(a_constant_name,
                                                     value=1,
                                                     dtype=tf.float32,
                                                     shape=[])
    graph_def.node.extend([a_constant])
    a_check_node = quantize_graph.create_node("CheckNumerics", a_check_name,
                                              [a_constant_name])
    graph_def.node.extend([a_check_node])
    a_identity_node = quantize_graph.create_node("Identity", a_identity_name,
                                                 [a_constant_name,
                                                  "^" + a_check_name])
    graph_def.node.extend([a_identity_node])
    # Branch "b": same structure as branch "a".
    b_constant = quantize_graph.create_constant_node(b_constant_name,
                                                     value=1,
                                                     dtype=tf.float32,
                                                     shape=[])
    graph_def.node.extend([b_constant])
    b_check_node = quantize_graph.create_node("CheckNumerics", b_check_name,
                                              [b_constant_name])
    graph_def.node.extend([b_check_node])
    b_identity_node = quantize_graph.create_node("Identity", b_identity_name,
                                                 [b_constant_name,
                                                  "^" + b_check_name])
    graph_def.node.extend([b_identity_node])
    # The Add combines the two Identity outputs.
    add_node = quantize_graph.create_node("Add", add_name,
                                          [a_identity_name,
                                           b_identity_name])
    quantize_graph.set_attr_dtype(add_node, "T", tf.float32)
    graph_def.node.extend([add_node])

    # Expected result: the CheckNumerics and Identity nodes are gone, and the
    # Add consumes the constants directly. Node order matters for
    # assertProtoEquals.
    expected_output = tf.GraphDef()
    a_constant = quantize_graph.create_constant_node(a_constant_name,
                                                     value=1,
                                                     dtype=tf.float32,
                                                     shape=[])
    expected_output.node.extend([a_constant])
    b_constant = quantize_graph.create_constant_node(b_constant_name,
                                                     value=1,
                                                     dtype=tf.float32,
                                                     shape=[])
    expected_output.node.extend([b_constant])
    add_node = quantize_graph.create_node("Add", add_name,
                                          [a_constant_name,
                                           b_constant_name])
    quantize_graph.set_attr_dtype(add_node, "T", tf.float32)
    expected_output.node.extend([add_node])

    # Rewrite the graph, then keep only what "add" still depends on before
    # comparing against the expected GraphDef.
    rewriter = quantize_graph.GraphRewriter(graph_def, [add_name])
    output = rewriter.remove_unneeded_nodes(graph_def)
    stripped_output = graph_util.extract_sub_graph(output, [add_name])
    self.assertProtoEquals(expected_output, stripped_output)
开发者ID:2020zyc,项目名称:tensorflow,代码行数:60,代码来源:quantize_graph_test.py

示例4: test_keep_control_edges

  def test_keep_control_edges(self):
    """Checks graph_util.remove_training_nodes against a control-edge graph.

    Branch "a" has an Identity node with control inputs on both a
    CheckNumerics node and a NoOp. Branch "b" has an Identity node with a
    control input only on its CheckNumerics node. In the expected output,
    the CheckNumerics nodes are gone. a_identity is kept with its control
    edge on the NoOp, and b_identity is gone, so the Add reads b_constant
    directly.
    """
    no_op_name = "no_op"
    a_constant_name = "a_constant"
    b_constant_name = "b_constant"
    a_check_name = "a_check"
    b_check_name = "b_check"
    a_identity_name = "a_identity"
    b_identity_name = "b_identity"
    add_name = "add"
    graph_def = graph_pb2.GraphDef()
    no_op = quantize_graph.create_node("NoOp", no_op_name, [])
    graph_def.node.extend([no_op])
    # Branch "a": a_identity carries control inputs on a_check and no_op.
    a_constant = quantize_graph.create_constant_node(
        a_constant_name, value=1, dtype=dtypes.float32, shape=[])
    graph_def.node.extend([a_constant])
    a_check_node = quantize_graph.create_node("CheckNumerics", a_check_name,
                                              [a_constant_name])
    graph_def.node.extend([a_check_node])
    a_identity_node = quantize_graph.create_node(
        "Identity", a_identity_name,
        [a_constant_name, "^" + a_check_name, "^" + no_op_name])
    graph_def.node.extend([a_identity_node])
    # Branch "b": b_identity carries a control input on b_check only.
    b_constant = quantize_graph.create_constant_node(
        b_constant_name, value=1, dtype=dtypes.float32, shape=[])
    graph_def.node.extend([b_constant])
    b_check_node = quantize_graph.create_node("CheckNumerics", b_check_name,
                                              [b_constant_name])
    graph_def.node.extend([b_check_node])
    b_identity_node = quantize_graph.create_node(
        "Identity", b_identity_name, [b_constant_name, "^" + b_check_name])
    graph_def.node.extend([b_identity_node])
    add_node = quantize_graph.create_node("Add", add_name,
                                          [a_identity_name, b_identity_name])
    quantize_graph.set_attr_dtype(add_node, "T", dtypes.float32)
    graph_def.node.extend([add_node])

    # Expected result. a_identity survives with only its "^no_op" control
    # edge; presumably it is kept because of that edge (per the test name).
    # b_identity is gone, so the Add reads b_constant directly.
    expected_output = graph_pb2.GraphDef()
    no_op = quantize_graph.create_node("NoOp", no_op_name, [])
    expected_output.node.extend([no_op])
    a_constant = quantize_graph.create_constant_node(
        a_constant_name, value=1, dtype=dtypes.float32, shape=[])
    expected_output.node.extend([a_constant])
    a_identity_node = quantize_graph.create_node(
        "Identity", a_identity_name, [a_constant_name, "^" + no_op_name])
    expected_output.node.extend([a_identity_node])
    b_constant = quantize_graph.create_constant_node(
        b_constant_name, value=1, dtype=dtypes.float32, shape=[])
    expected_output.node.extend([b_constant])
    add_node = quantize_graph.create_node("Add", add_name,
                                          [a_identity_name, b_constant_name])
    quantize_graph.set_attr_dtype(add_node, "T", dtypes.float32)
    expected_output.node.extend([add_node])
    # Give the expected GraphDef the same versions/library as the input.
    expected_output.versions.CopyFrom(graph_def.versions)
    expected_output.library.CopyFrom(graph_def.library)

    # Strip training-only nodes, keep what "add" depends on, then compare.
    output = graph_util.remove_training_nodes(graph_def)
    stripped_output = graph_util.extract_sub_graph(output, [add_name])
    self.assertProtoEquals(expected_output, stripped_output)
开发者ID:AlbertXiebnu,项目名称:tensorflow,代码行数:58,代码来源:quantize_graph_test.py

示例5: strip_unused

def strip_unused(input_graph_def, input_node_names, output_node_names,
                 placeholder_type_enum):
  """Removes unused nodes from a GraphDef.

  Every node named in `input_node_names` is replaced by a Placeholder node.
  Nodes that only fed those inputs therefore become unreachable and are
  dropped by `graph_util.extract_sub_graph()`.

  Args:
    input_graph_def: A graph with nodes we want to prune.
    input_node_names: A list of the nodes we use as inputs.
    output_node_names: A list of the output nodes.
    placeholder_type_enum: The AttrValue enum for the placeholder data type, or
        a list (or tuple) that specifies one value per input node name, in the
        same order as `input_node_names`.

  Returns:
    A `GraphDef` with all unnecessary ops removed.

  Raises:
    ValueError: If any element in `input_node_names` refers to a tensor instead
      of an operation.
    KeyError: If any element in `input_node_names` is not found in the graph.
  """
  for name in input_node_names:
    if ":" in name:
      raise ValueError("Name '%s' appears to refer to a Tensor, "
                       "not a Operation." % name)

  # Decide once whether a dtype is given per input node or shared by all.
  per_input_types = isinstance(placeholder_type_enum, (list, tuple))
  input_name_set = set(input_node_names)
  not_found = set(input_name_set)

  inputs_replaced_graph_def = graph_pb2.GraphDef()
  for node in input_graph_def.node:
    if node.name not in input_name_set:
      inputs_replaced_graph_def.node.extend([copy.deepcopy(node)])
      continue

    # Use discard rather than remove. If the graph repeats a node name, this
    # must not raise a KeyError that looks like a "not found" error.
    not_found.discard(node.name)

    if per_input_types:
      dtype = placeholder_type_enum[input_node_names.index(node.name)]
    else:
      dtype = placeholder_type_enum

    placeholder_node = node_def_pb2.NodeDef()
    placeholder_node.op = "Placeholder"
    placeholder_node.name = node.name
    placeholder_node.attr["dtype"].CopyFrom(attr_value_pb2.AttrValue(type=dtype))
    # Carry over any recorded output shapes from the node being replaced.
    if "_output_shapes" in node.attr:
      placeholder_node.attr["_output_shapes"].CopyFrom(node.attr[
          "_output_shapes"])
    inputs_replaced_graph_def.node.extend([placeholder_node])

  if not_found:
    raise KeyError("The following input nodes were not found: %s\n" % not_found)

  output_graph_def = graph_util.extract_sub_graph(inputs_replaced_graph_def,
                                                  output_node_names)
  return output_graph_def
开发者ID:1000sprites,项目名称:tensorflow,代码行数:56,代码来源:strip_unused_lib.py

示例6: __init__

    def __init__(self, meta_file, checkpoint_file, frozen_file, dest_nodes = None):
        """Load a TensorFlow model and build its network graph.

        Args:
            meta_file: Path passed to `TensorflowParser._load_meta`; required.
            checkpoint_file: Optional path passed to
                `TensorflowParser._load_weights`. When given, the result is
                stored in `self.ckpt_data` and `self.weight_loaded` is set.
            frozen_file: Accepted but not used by this constructor.
                NOTE(review): confirm whether frozen graphs should be loaded.
            dest_nodes: Optional comma-separated node names. When given, the
                graph is reduced with `extract_sub_graph` to what those nodes
                need. Whitespace around each name is ignored.

        Raises:
            ValueError: If `meta_file` is empty, because no model can be loaded.
        """
        super(TensorflowParser, self).__init__()

        # Without this check, `model` would be unbound and the constructor
        # would fail later with an UnboundLocalError.
        if not meta_file:
            raise ValueError(
                "A meta_file is required to load the TensorFlow model.")

        # load model files into TensorFlow graph
        model = TensorflowParser._load_meta(meta_file)

        if checkpoint_file:
            self.ckpt_data = TensorflowParser._load_weights(checkpoint_file)
            self.weight_loaded = True

        if dest_nodes is not None:
            from tensorflow.python.framework.graph_util import extract_sub_graph
            # TF node names cannot contain spaces, so stripping is safe; it
            # lets callers write "a, b".
            dest_node_names = [name.strip() for name in dest_nodes.split(',')]
            model = extract_sub_graph(model, dest_node_names)

        # Build network graph
        self.tf_graph = TensorflowGraph(model)
        self.tf_graph.build()
开发者ID:ZhongxingPeng,项目名称:MMdnn,代码行数:18,代码来源:tensorflow_parser.py

示例7: __init__

    def __init__(self, input_args, dest_nodes = None):
        """Load a TensorFlow model (and optionally weights) and build its graph.

        Args:
            input_args: Either a meta-graph path (string), or a tuple whose
                first item is the meta-graph path and whose second item is the
                checkpoint path. For a tuple, the weights are stored in
                `self.ckpt_data` and `self.weight_loaded` is set.
            dest_nodes: Optional comma-separated node names. When given, the
                graph is reduced with `extract_sub_graph` to what those nodes
                need.

        Raises:
            TypeError: If `input_args` is neither a string nor a tuple.
        """
        super(TensorflowParser, self).__init__()

        # load model files into TensorFlow graph
        from six import string_types as _string_types
        if isinstance(input_args, _string_types):
            model = TensorflowParser._load_meta(input_args)
        elif isinstance(input_args, tuple):
            model = TensorflowParser._load_meta(input_args[0])
            self.ckpt_data = TensorflowParser._load_weights(input_args[1])
            self.weight_loaded = True
        else:
            # Previously `model` stayed unbound here and the constructor
            # failed later with an UnboundLocalError.
            raise TypeError(
                "input_args must be a meta file path or a "
                "(meta_file, checkpoint_file) tuple, got %r" % (type(input_args),))

        if dest_nodes is not None:
            from tensorflow.python.framework.graph_util import extract_sub_graph
            model = extract_sub_graph(model, dest_nodes.split(','))

        # Build network graph
        self.tf_graph = TensorflowGraph(model)
        self.tf_graph.build()
开发者ID:zbxzc35,项目名称:MMdnn,代码行数:19,代码来源:tensorflow_parser.py

示例8: strip_unused

def strip_unused(input_graph, input_binary, output_graph, input_node_names,
                 output_node_names, placeholder_type_enum):
  """Removes unused nodes from a graph file and writes the pruned graph.

  Args:
    input_graph: Path of the GraphDef file to read.
    input_binary: True to parse `input_graph` as a binary protobuf, False to
      parse it as text format.
    output_graph: Path the pruned GraphDef is written to, serialized in
      binary form.
    input_node_names: Comma-separated names of nodes to replace with
      Placeholders; may be empty.
    output_node_names: Comma-separated names of the nodes whose dependencies
      are kept.
    placeholder_type_enum: AttrValue type enum used as each Placeholder's
      "dtype".

  Returns:
    -1 if `input_graph` does not exist or `output_node_names` is empty;
    otherwise None, after writing `output_graph`.
  """

  if not tf.gfile.Exists(input_graph):
    print("Input graph file '" + input_graph + "' does not exist!")
    return -1

  if not output_node_names:
    print("You need to supply the name of a node to --output_node_names.")
    return -1

  input_graph_def = tf.GraphDef()
  mode = "rb" if input_binary else "r"
  with tf.gfile.FastGFile(input_graph, mode) as f:
    if input_binary:
      input_graph_def.ParseFromString(f.read())
    else:
      text_format.Merge(f.read(), input_graph_def)

  # Here we replace the nodes we're going to override as inputs with
  # placeholders so that any unused nodes that are inputs to them are
  # automatically stripped out by extract_sub_graph(). A set gives O(1)
  # membership checks per node; an empty or None value means "no inputs".
  input_node_names_set = (set(input_node_names.split(","))
                          if input_node_names else set())
  inputs_replaced_graph_def = tf.GraphDef()
  for node in input_graph_def.node:
    if node.name in input_node_names_set:
      placeholder_node = tf.NodeDef()
      placeholder_node.op = "Placeholder"
      placeholder_node.name = node.name
      placeholder_node.attr["dtype"].CopyFrom(tf.AttrValue(
          type=placeholder_type_enum))
      # Preserve any recorded output shapes on the replacement placeholder.
      if "_output_shapes" in node.attr:
        placeholder_node.attr["_output_shapes"].CopyFrom(
            node.attr["_output_shapes"])
      inputs_replaced_graph_def.node.extend([placeholder_node])
    else:
      inputs_replaced_graph_def.node.extend([copy.deepcopy(node)])

  output_graph_def = graph_util.extract_sub_graph(inputs_replaced_graph_def,
                                                  output_node_names.split(","))

  with tf.gfile.GFile(output_graph, "wb") as f:
    f.write(output_graph_def.SerializeToString())
  print("%d ops in the final graph." % len(output_graph_def.node))
开发者ID:2020zyc,项目名称:tensorflow,代码行数:42,代码来源:strip_unused.py

示例9: strip_pruning_vars_fn

def strip_pruning_vars_fn(input_graph_def, output_node_names):
  """Folds pruning masks into constants and drops the mask variables.

  Each node whose name contains 'masked_weight' is rebuilt as a 'Const' node.
  Its 'dtype' is taken from the original node's 'T' attribute, and its value
  is the entry for that node name returned by `_get_masked_weights`. All
  other nodes are copied unchanged. Finally, the graph is reduced to what
  `output_node_names` needs, which removes the stranded mask and weight
  nodes.

  Args:
    input_graph_def: A GraphDef whose variables were already converted to
      constants, typically by tf.graph_util.convert_variables_to_constants().
    output_node_names: Names of the result nodes of the graph.

  Returns:
    A GraphDef without the pruning-related variables.
  """
  masked_values = _get_masked_weights(input_graph_def)
  rebuilt_graph = graph_pb2.GraphDef()

  for source in input_graph_def.node:
    rebuilt = node_def_pb2.NodeDef()
    if 'masked_weight' not in source.name:
      rebuilt.CopyFrom(source)
    else:
      rebuilt.op = 'Const'
      rebuilt.name = source.name
      type_attr = source.attr['T']
      folded = masked_values[source.name]
      rebuilt.attr['dtype'].CopyFrom(type_attr)
      folded_tensor = tensor_util.make_tensor_proto(folded)
      rebuilt.attr['value'].CopyFrom(attr_value_pb2.AttrValue(tensor=folded_tensor))
    rebuilt_graph.node.extend([rebuilt])

  # Anything no longer reachable from the outputs (mask, raw weight) is dropped.
  return graph_util.extract_sub_graph(rebuilt_graph, output_node_names)
开发者ID:Ajaycs99,项目名称:tensorflow,代码行数:37,代码来源:strip_pruning_vars_lib.py

示例10: test_remove_redundant_quantization

  def test_remove_redundant_quantization(self):
    """Checks GraphRewriter.remove_redundant_quantization on Dequantize/QuantizeV2 pairs.

    Each quint8 constant (with min/max constants) goes through Dequantize and
    then straight back through QuantizeV2 before feeding a QuantizedMatMul.
    In the expected output, both Dequantize/QuantizeV2 pairs are gone and the
    QuantizedMatMul reads the original constants and their min/max constants
    directly.
    """
    a_constant_name = "a_constant"
    a_constant_min_name = "a_constant_min"
    a_constant_max_name = "a_constant_max"
    a_dequantize_name = "a_dequantize"
    a_quantize_name = "a_quantize"
    b_constant_name = "b_constant"
    b_constant_min_name = "b_constant_min"
    b_constant_max_name = "b_constant_max"
    b_dequantize_name = "b_dequantize"
    b_quantize_name = "b_quantize"
    mat_mul_name = "mat_mul"
    graph_def = graph_pb2.GraphDef()
    # Branch "a": quint8 constant plus float32 min/max constants.
    a_constant = quantize_graph.create_constant_node(
        a_constant_name, value=(0,), dtype=dtypes.quint8, shape=[])
    graph_def.node.extend([a_constant])
    a_constant_min = quantize_graph.create_constant_node(
        a_constant_min_name, value=2, dtype=dtypes.float32, shape=[])
    graph_def.node.extend([a_constant_min])
    a_constant_max = quantize_graph.create_constant_node(
        a_constant_max_name, value=2, dtype=dtypes.float32, shape=[])
    graph_def.node.extend([a_constant_max])
    # Dequantize(a_constant, min, max) ...
    a_dequantize_node = quantize_graph.create_node(
        "Dequantize", a_dequantize_name,
        [a_constant_name, a_constant_min_name, a_constant_max_name])
    quantize_graph.set_attr_dtype(a_dequantize_node, "T", dtypes.uint8)
    graph_def.node.extend([a_dequantize_node])
    # ... immediately re-quantized by QuantizeV2, which reads outputs 0, 1 and
    # 2 of the Dequantize node.
    a_quantize_node = quantize_graph.create_node(
        "QuantizeV2", a_quantize_name,
        [a_dequantize_name, a_dequantize_name + ":1", a_dequantize_name + ":2"])
    quantize_graph.set_attr_dtype(a_quantize_node, "T", dtypes.uint8)
    graph_def.node.extend([a_quantize_node])
    # Branch "b": same structure as branch "a" with different min/max values.
    b_constant = quantize_graph.create_constant_node(
        b_constant_name, value=(0,), dtype=dtypes.quint8, shape=[])
    graph_def.node.extend([b_constant])
    b_constant_min = quantize_graph.create_constant_node(
        b_constant_min_name, value=3, dtype=dtypes.float32, shape=[])
    graph_def.node.extend([b_constant_min])
    b_constant_max = quantize_graph.create_constant_node(
        b_constant_max_name, value=3, dtype=dtypes.float32, shape=[])
    graph_def.node.extend([b_constant_max])
    b_dequantize_node = quantize_graph.create_node(
        "Dequantize", b_dequantize_name,
        [b_constant_name, b_constant_min_name, b_constant_max_name])
    quantize_graph.set_attr_dtype(b_dequantize_node, "T", dtypes.uint8)
    graph_def.node.extend([b_dequantize_node])
    b_quantize_node = quantize_graph.create_node(
        "QuantizeV2", b_quantize_name,
        [b_dequantize_name, b_dequantize_name + ":1", b_dequantize_name + ":2"])
    quantize_graph.set_attr_dtype(b_quantize_node, "T", dtypes.uint8)
    graph_def.node.extend([b_quantize_node])
    # QuantizedMatMul input order: a, b, a_min, a_max, b_min, b_max (taken
    # from the QuantizeV2 outputs here).
    mat_mul_node = quantize_graph.create_node("QuantizedMatMul", mat_mul_name, [
        a_quantize_name, b_quantize_name, a_quantize_name + ":1",
        a_quantize_name + ":2", b_quantize_name + ":1", b_quantize_name + ":2"
    ])
    quantize_graph.set_attr_dtype(mat_mul_node, "T1", dtypes.uint8)
    quantize_graph.set_attr_dtype(mat_mul_node, "T2", dtypes.int32)
    graph_def.node.extend([mat_mul_node])

    # Expected result: the six constants plus a QuantizedMatMul wired directly
    # to them, in the same input order as above.
    expected_output = graph_pb2.GraphDef()
    a_constant = quantize_graph.create_constant_node(
        a_constant_name, value=(0,), dtype=dtypes.quint8, shape=[])
    expected_output.node.extend([a_constant])
    a_constant_min = quantize_graph.create_constant_node(
        a_constant_min_name, value=2, dtype=dtypes.float32, shape=[])
    expected_output.node.extend([a_constant_min])
    a_constant_max = quantize_graph.create_constant_node(
        a_constant_max_name, value=2, dtype=dtypes.float32, shape=[])
    expected_output.node.extend([a_constant_max])
    b_constant = quantize_graph.create_constant_node(
        b_constant_name, value=(0,), dtype=dtypes.quint8, shape=[])
    expected_output.node.extend([b_constant])
    b_constant_min = quantize_graph.create_constant_node(
        b_constant_min_name, value=3, dtype=dtypes.float32, shape=[])
    expected_output.node.extend([b_constant_min])
    b_constant_max = quantize_graph.create_constant_node(
        b_constant_max_name, value=3, dtype=dtypes.float32, shape=[])
    expected_output.node.extend([b_constant_max])
    mat_mul_node = quantize_graph.create_node("QuantizedMatMul", mat_mul_name, [
        a_constant_name, b_constant_name, a_constant_min_name,
        a_constant_max_name, b_constant_min_name, b_constant_max_name
    ])
    quantize_graph.set_attr_dtype(mat_mul_node, "T1", dtypes.uint8)
    quantize_graph.set_attr_dtype(mat_mul_node, "T2", dtypes.int32)
    expected_output.node.extend([mat_mul_node])
    # Give the expected GraphDef the same versions/library as the input.
    expected_output.versions.CopyFrom(graph_def.versions)
    expected_output.library.CopyFrom(graph_def.library)

    # Rewrite, keep only what mat_mul depends on, then compare protos.
    rewriter = quantize_graph.GraphRewriter(
        graph_def, [mat_mul_name], quantized_input_range=None)
    output = rewriter.remove_redundant_quantization(graph_def)
    stripped_output = graph_util.extract_sub_graph(output, [mat_mul_name])
    self.assertProtoEquals(expected_output, stripped_output)
开发者ID:AlbertXiebnu,项目名称:tensorflow,代码行数:93,代码来源:quantize_graph_test.py

示例11: testExtractSubGraphWithInvalidDestNodes

 def testExtractSubGraphWithInvalidDestNodes(self):
   """extract_sub_graph must reject a bare string as dest_nodes.

   A TypeError whose message matches "must be a list" is expected.
   """
   graph_def = graph_pb2.GraphDef()
   n1 = graph_def.node.add()
   n1.name = "n1"
   # assertRaisesRegexp is a deprecated alias that was removed in Python 3.12;
   # assertRaisesRegex is the canonical name.
   with self.assertRaisesRegex(TypeError, "must be a list"):
     graph_util.extract_sub_graph(graph_def, "n1")
开发者ID:ThunderQi,项目名称:tensorflow,代码行数:6,代码来源:graph_util_test.py

示例12: remove_dead_nodes

 def remove_dead_nodes(self, output_names):
   """Prunes self.output_graph to the nodes that output_names depend on.

   The GraphDef returned by graph_util.extract_sub_graph replaces
   self.output_graph.
   """
   pruned_graph = graph_util.extract_sub_graph(self.output_graph, output_names)
   self.output_graph = pruned_graph
开发者ID:frankfqchen,项目名称:tensorflow,代码行数:5,代码来源:quantize_graph.py


注:本文中的tensorflow.python.framework.graph_util.extract_sub_graph函数示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。