当前位置: 首页>>代码示例>>Python>>正文


Python graph_util.remove_training_nodes方法代码示例

本文整理汇总了Python中tensorflow.python.framework.graph_util.remove_training_nodes方法的典型用法代码示例。如果您正苦于以下问题：Python graph_util.remove_training_nodes方法的具体用法？Python graph_util.remove_training_nodes怎么用？Python graph_util.remove_training_nodes使用的例子？那么恭喜您，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块tensorflow.python.framework.graph_util的用法示例。


在下文中一共展示了graph_util.remove_training_nodes方法的11个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: export_cnn

# 需要导入模块: from tensorflow.python.framework import graph_util [as 别名]
# 或者: from tensorflow.python.framework.graph_util import remove_training_nodes [as 别名]
def export_cnn(filename: str = "cnn.pb") -> None:
    """Build a tiny conv -> sigmoid -> relu graph and save it frozen.

    The graph takes a float32 placeholder of shape (1, 1, 3, 3) in NCHW
    layout, convolves it with an all-ones filter of shape (3, 3, 1, 1)
    ("SAME" padding, unit strides), then applies sigmoid and relu. An
    identity node marks the result; the graph is frozen, stripped of
    training-only nodes and written as a binary protobuf to ``filename``
    in the current directory.

    Args:
        filename: Output file name. Defaults to "cnn.pb", the name that was
            previously hard-coded.
    """
    # Named so as not to shadow the builtins ``input`` and ``filter``.
    inputs = tf.placeholder(tf.float32, shape=(1, 1, 3, 3))
    kernel = tf.constant(np.ones((3, 3, 1, 1)), dtype=tf.float32)
    x = tf.nn.conv2d(inputs, kernel, (1, 1, 1, 1), "SAME", data_format="NCHW")
    x = tf.nn.sigmoid(x)
    x = tf.nn.relu(x)

    # TensorFlow uniquifies a requested op name that is already taken
    # (e.g. "output" -> "output_1" on a second call in the same default
    # graph), so freeze using the name the node actually received.
    output_node = tf.identity(x, name="output")
    pred_node_names = [output_node.op.name]

    with tf.Session() as sess:
        constant_graph = graph_util.convert_variables_to_constants(
            sess, sess.graph.as_graph_def(), pred_node_names
        )

    frozen = graph_util.remove_training_nodes(constant_graph)
    graph_io.write_graph(frozen, ".", filename, as_text=False)
开发者ID:tf-encrypted,项目名称:tf-encrypted,代码行数:21,代码来源:convert.py

示例2: export

# 需要导入模块: from tensorflow.python.framework import graph_util [as 别名]
# 或者: from tensorflow.python.framework.graph_util import remove_training_nodes [as 别名]
def export(x: tf.Tensor, filename: str, sess=None):
    """Freeze the graph producing ``x`` and write it to ``filename``.

    An identity node named "output" is attached to ``x``. If that name is
    already taken, the uniquified name TensorFlow assigns is used instead.
    Variables are converted to constants with ``sess``, training-only nodes
    are removed, and the result is written as a binary protobuf relative to
    the current directory.

    Args:
        x: Tensor whose value is the graph output.
        filename: Name of the .pb file to write.
        sess: Session holding the variable values. If None, a temporary
            ``tf.Session()`` is created and always closed before returning,
            even when freezing or writing fails.

    Returns:
        The path returned by ``graph_io.write_graph``.
    """
    owns_session = sess is None
    if owns_session:
        sess = tf.Session()

    try:
        # Use the name the identity op actually received; TensorFlow renames
        # it (e.g. "output_1") when "output" already exists in the graph.
        output_node = tf.identity(x, name="output")
        pred_node_names = [output_node.op.name]

        graph = graph_util.convert_variables_to_constants(
            sess, sess.graph.as_graph_def(), pred_node_names
        )
        graph = graph_util.remove_training_nodes(graph)
        path = graph_io.write_graph(graph, ".", filename, as_text=False)
    finally:
        # Only close sessions created here; a caller-supplied one stays open.
        if owns_session:
            sess.close()

    return path
开发者ID:tf-encrypted,项目名称:tf-encrypted,代码行数:22,代码来源:convert_test.py

示例3: _load_saved_model

# 需要导入模块: from tensorflow.python.framework import graph_util [as 别名]
# 或者: from tensorflow.python.framework.graph_util import remove_training_nodes [as 别名]
def _load_saved_model(self):
        """Freeze the TensorFlow SavedModel and return it as a GraphDef.

        Steps:
        1. ``freeze_graph`` freezes the SavedModel in ``self._model_dir``,
           loaded with the tags from ``self._get_tag_set()``. It keeps the
           outputs named by ``self._get_output_names()`` and clears device
           placements.
        2. The frozen graph is written to ``tf_frozen_model.pb`` under
           ``self._tmp_dir``.
        3. That file is parsed back, and training-only nodes are removed,
           protecting the nodes in ``self._outputs``.

        Returns:
            The frozen ``GraphDef`` with training-only nodes removed.

        Raises:
            ImportError: If the required TensorFlow modules cannot be
                imported. The original import error is chained as the cause.
        """
        try:
            from tensorflow.python.tools import freeze_graph
            from tensorflow.python.framework import ops
            from tensorflow.python.framework import graph_util
            from tensorflow.core.framework import graph_pb2
        except ImportError as err:
            # Chain the original error so the real cause (missing package,
            # broken install, ...) is not hidden behind the generic message.
            raise ImportError(
                "InputConfiguration: Unable to import tensorflow which is "
                "required to restore from saved model.") from err

        saved_model_dir = self._model_dir
        output_graph_filename = self._tmp_dir.relpath("tf_frozen_model.pb")
        input_saved_model_dir = saved_model_dir
        output_node_names = self._get_output_names()

        # The graph is read from the SavedModel directory, so the
        # graph-file / checkpoint / saver inputs of freeze_graph are unused.
        input_binary = False
        input_saver_def_path = False
        restore_op_name = None
        filename_tensor_name = None
        clear_devices = True
        input_meta_graph = False
        checkpoint_path = None
        input_graph_filename = None
        saved_model_tags = ",".join(self._get_tag_set())

        # Positional call: the three "" arguments fill the initializer-nodes
        # and variable-name filter slots of freeze_graph.
        freeze_graph.freeze_graph(input_graph_filename, input_saver_def_path,
                                  input_binary, checkpoint_path, output_node_names,
                                  restore_op_name, filename_tensor_name,
                                  output_graph_filename, clear_devices, "", "", "",
                                  input_meta_graph, input_saved_model_dir,
                                  saved_model_tags)

        with ops.Graph().as_default():
            output_graph_def = graph_pb2.GraphDef()
            with open(output_graph_filename, "rb") as f:
                output_graph_def.ParseFromString(f.read())
            output_graph_def = graph_util.remove_training_nodes(output_graph_def,
                                                                protected_nodes=self._outputs)
            return output_graph_def
开发者ID:apache,项目名称:incubator-tvm,代码行数:43,代码来源:tensorflow_parser.py

示例4: optimize_for_inference

# 需要导入模块: from tensorflow.python.framework import graph_util [as 别名]
# 或者: from tensorflow.python.framework.graph_util import remove_training_nodes [as 别名]
def optimize_for_inference(input_graph_def, input_node_names, output_node_names,
                           placeholder_type_enum):
  """Run the inference-time rewrite pipeline over ``input_graph_def``.

  Pipeline, in order:
  1. Validate the incoming graph.
  2. Strip nodes not needed between the given inputs and outputs.
  3. Remove training-only nodes.
  4. Fold batch norms.
  5. Fuse resize+conv patterns.
  6. Validate the result.

  Args:
    input_graph_def: A GraphDef containing a training model.
    input_node_names: Names of the nodes that are fed inputs during inference.
    output_node_names: Names of the nodes that produce the final results.
    placeholder_type_enum: The AttrValue enum for the placeholder data type,
      or a list that specifies one value per input node name.

  Returns:
    The optimized GraphDef.
  """
  ensure_graph_is_valid(input_graph_def)

  stripped = strip_unused_lib.strip_unused(
      input_graph_def, input_node_names, output_node_names,
      placeholder_type_enum)
  # NOTE(review): output_node_names is not passed as protected nodes here,
  # so an output that is itself an Identity/CheckNumerics node could be
  # removed by this pass -- confirm against the TensorFlow version in use.
  inference_only = graph_util.remove_training_nodes(stripped)
  folded = fold_batch_norms(inference_only)
  fused = fuse_resize_and_conv(folded, output_node_names)

  ensure_graph_is_valid(fused)
  return fused
开发者ID:ryfeus,项目名称:lambda-packs,代码行数:30,代码来源:optimize_for_inference_lib.py

示例5: optimize_for_inference

# 需要导入模块: from tensorflow.python.framework import graph_util [as 别名]
# 或者: from tensorflow.python.framework.graph_util import remove_training_nodes [as 别名]
def optimize_for_inference(input_graph_def, input_node_names,
                           output_node_names, placeholder_type_enum):
  """Apply the standard inference optimizations to a trained graph.

  The graph is validated, pruned to what is needed between the input and
  output nodes, stripped of training-only nodes, batch-norm folded,
  resize+conv fused, and validated again before being returned.

  Args:
    input_graph_def: A GraphDef containing a training model.
    input_node_names: Names of the nodes that are fed inputs during inference.
    output_node_names: Names of the nodes that produce the final results.
    placeholder_type_enum: Data type of the placeholders used for inputs.

  Returns:
    The optimized GraphDef.
  """
  ensure_graph_is_valid(input_graph_def)

  pruned = strip_unused_lib.strip_unused(input_graph_def,
                                         input_node_names,
                                         output_node_names,
                                         placeholder_type_enum)
  pruned = graph_util.remove_training_nodes(pruned)
  pruned = fold_batch_norms(pruned)
  result = fuse_resize_and_conv(pruned, output_node_names)

  ensure_graph_is_valid(result)
  return result
开发者ID:tobegit3hub,项目名称:deep_image_model,代码行数:29,代码来源:optimize_for_inference_lib.py

示例6: export_to_pb

# 需要导入模块: from tensorflow.python.framework import graph_util [as 别名]
# 或者: from tensorflow.python.framework.graph_util import remove_training_nodes [as 别名]
def export_to_pb(sess, x, filename):
    """Freeze the graph computing ``x`` and write it as a binary protobuf.

    An identity node named "output" is attached to ``x`` (TensorFlow may
    uniquify that name if it is taken). Variables are converted to constants
    using ``sess``, training-only nodes are stripped, and the result is
    written to ``filename`` relative to the current directory.

    Args:
        sess: Session holding the trained variable values.
        x: Tensor to export as the graph output.
        filename: Name of the .pb file to write.

    Returns:
        The path returned by ``graph_io.write_graph``. Previously this was
        None; returning the path is backward-compatible.
    """
    # Freeze using the name the identity op actually received, in case
    # "output" already exists in the graph and was renamed.
    output_node = tf.identity(x, name="output")
    pred_names = [output_node.op.name]

    graph = graph_util.convert_variables_to_constants(
        sess, sess.graph.as_graph_def(), pred_names
    )

    graph = graph_util.remove_training_nodes(graph)
    path = graph_io.write_graph(graph, ".", filename, as_text=False)
    print("saved the frozen graph (ready for inference) at: ", path)

    return path
开发者ID:tf-encrypted,项目名称:tf-encrypted,代码行数:13,代码来源:main.py

示例7: export_to_pb

# 需要导入模块: from tensorflow.python.framework import graph_util [as 别名]
# 或者: from tensorflow.python.framework.graph_util import remove_training_nodes [as 别名]
def export_to_pb(sess, x, filename):
    """Freeze the graph computing ``x`` and write it as a binary protobuf.

    An identity node named "output" is attached to ``x`` (TensorFlow may
    uniquify that name if it is taken). Variables are converted to constants
    using ``sess``, training-only nodes are stripped, and the result is
    written to ``filename`` relative to the current directory.

    Args:
        sess: Session holding the trained variable values.
        x: Tensor to export as the graph output.
        filename: Name of the .pb file to write.

    Returns:
        The path returned by ``graph_io.write_graph``.
    """
    # Freeze using the name the identity op actually received, in case
    # "output" already exists in the graph and was renamed.
    output_node = tf.identity(x, name="output")
    pred_names = [output_node.op.name]

    graph = graph_util.convert_variables_to_constants(
        sess, sess.graph.as_graph_def(), pred_names
    )

    graph = graph_util.remove_training_nodes(graph)
    path = graph_io.write_graph(graph, ".", filename, as_text=False)
    # Report the path actually written, not just the bare file name.
    print("saved the frozen graph (ready for inference) at: ", path)

    return path
开发者ID:tf-encrypted,项目名称:tf-encrypted,代码行数:15,代码来源:mnist_deep_cnn.py

示例8: optimize_for_inference

# 需要导入模块: from tensorflow.python.framework import graph_util [as 别名]
# 或者: from tensorflow.python.framework.graph_util import remove_training_nodes [as 别名]
def optimize_for_inference(input_graph_def, input_node_names, output_node_names,
                           placeholder_type_enum):
  """Optimize a trained GraphDef for inference.

  The graph is validated, pruned between the given inputs and outputs,
  stripped of training-only nodes (keeping the output nodes), batch-norm
  folded, resize+conv fused, and validated again.

  Args:
    input_graph_def: A GraphDef containing a training model.
    input_node_names: Names of the nodes that are fed inputs during inference.
    output_node_names: Names of the nodes that produce the final results.
      They are also passed to remove_training_nodes as protected nodes.
    placeholder_type_enum: The AttrValue enum for the placeholder data type,
      or a list that specifies one value per input node name.

  Returns:
    The optimized GraphDef.
  """
  ensure_graph_is_valid(input_graph_def)

  pruned = strip_unused_lib.strip_unused(input_graph_def,
                                         input_node_names,
                                         output_node_names,
                                         placeholder_type_enum)
  # Second positional argument: nodes that must survive training-node removal.
  inference_graph = graph_util.remove_training_nodes(pruned,
                                                     output_node_names)
  folded = fold_batch_norms(inference_graph)
  fused = fuse_resize_and_conv(folded, output_node_names)

  ensure_graph_is_valid(fused)
  return fused
开发者ID:PacktPublishing,项目名称:Serverless-Deep-Learning-with-TensorFlow-and-AWS-Lambda,代码行数:30,代码来源:optimize_for_inference_lib.py

示例9: __init__

# 需要导入模块: from tensorflow.python.framework import graph_util [as 别名]
# 或者: from tensorflow.python.framework.graph_util import remove_training_nodes [as 别名]
def __init__(self, model, session = None):
		"""
		Freeze a TensorFlow/Keras model and strip it down to its inference nodes.

		The output node(s) are determined from ``model``. The graph is then
		processed in two steps:
		1. graph_util.convert_variables_to_constants, using ``session``.
		2. graph_util.remove_training_nodes.
		The result is stored in ``self.graph_def``. After this, the graph
		should only contain tf.Operations of the types [Const, MatMul, Add,
		BiasAdd, Conv2D, Reshape, MaxPool, AveragePool, Placeholder, Relu,
		Sigmoid, Tanh]. For Keras models, operations of type Pack, Shape,
		StridedSlice and Prod are ignored downstream, so that the Flatten
		layer can be used.

		Arguments
		---------
		model : tensorflow.Tensor or tensorflow.Operation or tensorflow.python.keras.engine.sequential.Sequential or keras.engine.sequential.Sequential
		    if tensorflow.Tensor: model.op is the output node; output_node_names is [model.op.name].
		    if tensorflow.Operation: model itself is the output node; output_node_names is [model.name].
		    if a Sequential (tf.keras or keras): x = model.layers[-1].output.op.inputs[0].op is the
		                          output node; output_node_names is [x.name].
		    In every case, make sure the graph only contains supported operations after the two
		    graph_util calls above.
		session : tf.Session
		    Session holding the trained variables. If None, tf.get_default_session() is used.
		    For Keras models the Keras backend session is always used, replacing any session
		    passed in.

		Raises
		------
		AssertionError
		    If ``model`` is an onnx.ModelProto or of any other unrecognized type. The error is
		    raised explicitly, so the check also runs under ``python -O``.
		"""
		output_names = None
		if isinstance(model, tf.Tensor):
			output_names = [model.op.name]
		elif isinstance(model, tf.Operation):
			output_names = [model.name]
		elif isinstance(model, Sequential):
			session      = tf.keras.backend.get_session()
			output_names = [model.layers[-1].output.op.inputs[0].op.name]
			# Continue with the op feeding the last layer's output op.
			model        = model.layers[-1].output.op
		elif isinstance(model, onnx.ModelProto):
			# Explicit raise instead of ``assert 0`` so this is not stripped by -O;
			# the exception type is unchanged for existing callers.
			raise AssertionError('not tensorflow model')
		else:
			# Standalone keras is imported lazily; it is only needed for this case.
			import keras
			if isinstance(model, keras.engine.sequential.Sequential):
				session      = keras.backend.get_session()
				output_names = [model.layers[-1].output.op.inputs[0].op.name]
				model        = model.layers[-1].output.op
			else:
				raise AssertionError("ERAN can't recognize this input")

		if session is None:
			session = tf.get_default_session()

		tmp = graph_util.convert_variables_to_constants(session, model.graph.as_graph_def(), output_names)
		self.graph_def = graph_util.remove_training_nodes(tmp)
开发者ID:eth-sri,项目名称:eran,代码行数:52,代码来源:tensorflow_translator.py

示例10: testRemoveTrainingNodes

# 需要导入模块: from tensorflow.python.framework import graph_util [as 别名]
# 或者: from tensorflow.python.framework.graph_util import remove_training_nodes [as 别名]
def testRemoveTrainingNodes(self):
    """remove_training_nodes drops CheckNumerics/Identity chains.

    Input graph: for each of two float constants, a CheckNumerics node reads
    the constant, and an Identity reads the constant with a control
    dependency on that check. An Add consumes both Identity outputs.

    Expected output: only the two constants and an Add that reads the
    constants directly.
    """
    add_name = "add"
    # (constant, check, identity) node names for each of the two Add inputs.
    branches = (
        ("a_constant", "a_check", "a_identity"),
        ("b_constant", "b_check", "b_identity"),
    )

    graph_def = tf.GraphDef()
    for const_name, check_name, identity_name in branches:
        graph_def.node.extend([
            self.create_constant_node_def(const_name,
                                          value=1,
                                          dtype=tf.float32,
                                          shape=[])])
        graph_def.node.extend([
            self.create_node_def("CheckNumerics", check_name, [const_name])])
        graph_def.node.extend([
            self.create_node_def("Identity", identity_name,
                                 [const_name, "^" + check_name])])
    training_add = self.create_node_def(
        "Add", add_name, [names[2] for names in branches])
    self.set_attr_dtype(training_add, "T", tf.float32)
    graph_def.node.extend([training_add])

    expected_output = tf.GraphDef()
    for names in branches:
        expected_output.node.extend([
            self.create_constant_node_def(names[0],
                                          value=1,
                                          dtype=tf.float32,
                                          shape=[])])
    expected_add = self.create_node_def(
        "Add", add_name, [names[0] for names in branches])
    self.set_attr_dtype(expected_add, "T", tf.float32)
    expected_output.node.extend([expected_add])

    output = graph_util.remove_training_nodes(graph_def)
    self.assertProtoEquals(expected_output, output)
开发者ID:tobegit3hub,项目名称:deep_image_model,代码行数:60,代码来源:graph_util_test.py

示例11: rewrite

# 需要导入模块: from tensorflow.python.framework import graph_util [as 别名]
# 或者: from tensorflow.python.framework.graph_util import remove_training_nodes [as 别名]
def rewrite(self, output_node_names):
    """Triggers rewriting of the float graph.

    Dispatches on ``self.mode``:
      * "round": calls round_nodes_recursively on each output node.
      * "quantize": calls quantize_nodes_recursively on each output node.
      * "eightbit":
          1. Replaces the input graph with a copy stripped of training-only
             nodes.
          2. Calls eightbitize_nodes_recursively on each output node.
          3. Adds min/max constant nodes when input_range or
             fallback_quantization_range is set.
          4. Optionally (FLAGS.strip_redundant_quantization) removes
             redundant quantization and dead nodes.
          5. Applies the final node renames.
      * "weights" / "weights_rounded": replaces the output graph with
        quantize_weights(...) of the input graph, then calls
        remove_dead_nodes.
    Any other mode only prints a message and returns the GraphDef created
    at the top of this method, which nothing has populated.

    Args:
      output_node_names: A list of names of the nodes that produce the final
        results.

    Returns:
      A quantized version of the float graph (``self.output_graph``).
    """
    # Start from an empty output graph; each mode fills or replaces it.
    self.output_graph = tf.GraphDef()
    # Resolve output names against the current input graph's node map.
    output_nodes = [self.nodes_map[output_node_name]
                    for output_node_name in output_node_names]
    if self.mode == "round":
      self.already_visited = {}
      for output_node in output_nodes:
        self.round_nodes_recursively(output_node)
    elif self.mode == "quantize":
      self.already_visited = {}
      self.already_quantized = {}
      for output_node in output_nodes:
        self.quantize_nodes_recursively(output_node)
    elif self.mode == "eightbit":
      # Training-only nodes are removed first. output_nodes is re-resolved
      # afterwards, presumably because set_input_graph rebuilds
      # self.nodes_map (TODO confirm).
      self.set_input_graph(graph_util.remove_training_nodes(self.input_graph))
      output_nodes = [self.nodes_map[output_node_name]
                      for output_node_name in output_node_names]

      # Bookkeeping shared across the recursive eightbitize calls; it is
      # reset to None once the traversal finishes.
      self.state = EightbitizeRecursionState(already_visited={},
                                             output_node_stack=[],
                                             merged_with_fake_quant={})
      for output_node in output_nodes:
        self.eightbitize_nodes_recursively(output_node)
      self.state = None
      if self.input_range:
        self.add_output_graph_node(create_constant_node(
            "quantized_input_min_value", self.input_range[0], tf.float32, []))
        self.add_output_graph_node(create_constant_node(
            "quantized_input_max_value", self.input_range[1], tf.float32, []))
      if self.fallback_quantization_range:
        self.add_output_graph_node(create_constant_node(
            "fallback_quantization_min_value",
            self.fallback_quantization_range[0], tf.float32, []))
        self.add_output_graph_node(create_constant_node(
            "fallback_quantization_max_value",
            self.fallback_quantization_range[1], tf.float32, []))
      if FLAGS.strip_redundant_quantization:
        self.output_graph = self.remove_redundant_quantization(
            self.output_graph)
        self.remove_dead_nodes(output_node_names)
      self.apply_final_node_renames()
    elif self.mode == "weights":
      self.output_graph = self.quantize_weights(self.input_graph,
                                                b"MIN_COMBINED")
      self.remove_dead_nodes(output_node_names)
    elif self.mode == "weights_rounded":
      self.output_graph = self.quantize_weights(self.input_graph, self.mode)
      self.remove_dead_nodes(output_node_names)
    else:
      # NOTE(review): an unknown mode is only reported via print and the
      # empty GraphDef is returned; consider raising ValueError instead.
      print("Bad mode - " + self.mode + ".")
    return self.output_graph
开发者ID:tobegit3hub,项目名称:deep_image_model,代码行数:62,代码来源:quantize_graph.py


注:本文中的tensorflow.python.framework.graph_util.remove_training_nodes方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。