本文整理匯總了Python中tensorflow.python.framework.graph_util.remove_training_nodes方法的典型用法代碼示例。如果您正苦於以下問題:Python graph_util.remove_training_nodes方法的具體用法?Python graph_util.remove_training_nodes怎麽用?Python graph_util.remove_training_nodes使用的例子?那麽, 這裏精選的方法代碼示例或許可以為您提供幫助。您也可以進一步了解該方法所在模塊tensorflow.python.framework.graph_util的用法示例。
在下文中一共展示了graph_util.remove_training_nodes方法的11個代碼示例,這些例子默認根據受歡迎程度排序。您可以為喜歡或者感覺有用的代碼點讚,您的評價將有助於系統推薦出更棒的Python代碼示例。
示例1: export_cnn
# 需要導入模塊: from tensorflow.python.framework import graph_util [as 別名]
# 或者: from tensorflow.python.framework.graph_util import remove_training_nodes [as 別名]
def export_cnn() -> None:
    """Build a small conv -> sigmoid -> relu graph, freeze it and save ``cnn.pb``.

    The graph is:
    - a float32 placeholder of shape (1, 1, 3, 3);
    - an all-ones 3x3x1x1 kernel applied with stride 1, "SAME" padding and
      ``data_format="NCHW"``;
    - sigmoid, then relu;
    - an Identity node named ``"output"``.

    It is frozen with ``convert_variables_to_constants``, stripped of
    training-only nodes, and written as a binary GraphDef to ``./cnn.pb``.
    """
    # Named to avoid shadowing the ``input`` / ``filter`` builtins.
    inputs = tf.placeholder(tf.float32, shape=(1, 1, 3, 3))
    kernel = tf.constant(np.ones((3, 3, 1, 1)), dtype=tf.float32)
    x = tf.nn.conv2d(inputs, kernel, (1, 1, 1, 1), "SAME", data_format="NCHW")
    x = tf.nn.sigmoid(x)
    x = tf.nn.relu(x)
    pred_node_names = ["output"]
    tf.identity(x, name=pred_node_names[0])
    with tf.Session() as sess:
        constant_graph = graph_util.convert_variables_to_constants(
            sess, sess.graph.as_graph_def(), pred_node_names
        )
    # remove_training_nodes splices out Identity nodes that are not
    # protected; without protection the "output" node would be dropped
    # from the exported graph.
    frozen = graph_util.remove_training_nodes(
        constant_graph, protected_nodes=pred_node_names
    )
    output = "cnn.pb"
    graph_io.write_graph(frozen, ".", output, as_text=False)
示例2: export
# 需要導入模塊: from tensorflow.python.framework import graph_util [as 別名]
# 或者: from tensorflow.python.framework.graph_util import remove_training_nodes [as 別名]
def export(x: tf.Tensor, filename: str, sess=None):
    """Freeze the session graph with *x* as output and write ``./<filename>``.

    An Identity node named ``"output"`` is attached to *x*, so the exported
    graph exposes the result under that name.

    Args:
        x: Tensor whose value becomes the graph output.
        filename: Name of the binary GraphDef written into the current
            directory.
        sess: Session holding the variable values. If None, a new
            ``tf.Session()`` is created and always closed before returning,
            even when the export fails.

    Returns:
        The path returned by ``graph_io.write_graph``.
    """
    should_close = sess is None
    if should_close:
        sess = tf.Session()
    try:
        pred_node_names = ["output"]
        tf.identity(x, name=pred_node_names[0])
        graph = graph_util.convert_variables_to_constants(
            sess, sess.graph.as_graph_def(), pred_node_names
        )
        # Protect "output": otherwise remove_training_nodes splices the
        # Identity node away and the exported graph no longer contains it.
        graph = graph_util.remove_training_nodes(
            graph, protected_nodes=pred_node_names
        )
        return graph_io.write_graph(graph, ".", filename, as_text=False)
    finally:
        # Only close a session this function created itself.
        if should_close:
            sess.close()
示例3: _load_saved_model
# 需要導入模塊: from tensorflow.python.framework import graph_util [as 別名]
# 或者: from tensorflow.python.framework.graph_util import remove_training_nodes [as 別名]
def _load_saved_model(self):
"""Load the tensorflow saved model."""
try:
from tensorflow.python.tools import freeze_graph
from tensorflow.python.framework import ops
from tensorflow.python.framework import graph_util
from tensorflow.core.framework import graph_pb2
except ImportError:
raise ImportError(
"InputConfiguration: Unable to import tensorflow which is "
"required to restore from saved model.")
saved_model_dir = self._model_dir
output_graph_filename = self._tmp_dir.relpath("tf_frozen_model.pb")
input_saved_model_dir = saved_model_dir
output_node_names = self._get_output_names()
input_binary = False
input_saver_def_path = False
restore_op_name = None
filename_tensor_name = None
clear_devices = True
input_meta_graph = False
checkpoint_path = None
input_graph_filename = None
saved_model_tags = ",".join(self._get_tag_set())
freeze_graph.freeze_graph(input_graph_filename, input_saver_def_path,
input_binary, checkpoint_path, output_node_names,
restore_op_name, filename_tensor_name,
output_graph_filename, clear_devices, "", "", "",
input_meta_graph, input_saved_model_dir,
saved_model_tags)
with ops.Graph().as_default():
output_graph_def = graph_pb2.GraphDef()
with open(output_graph_filename, "rb") as f:
output_graph_def.ParseFromString(f.read())
output_graph_def = graph_util.remove_training_nodes(output_graph_def,
protected_nodes=self._outputs)
return output_graph_def
示例4: optimize_for_inference
# 需要導入模塊: from tensorflow.python.framework import graph_util [as 別名]
# 或者: from tensorflow.python.framework.graph_util import remove_training_nodes [as 別名]
def optimize_for_inference(input_graph_def, input_node_names, output_node_names,
                           placeholder_type_enum):
    """Applies a series of inference optimizations on the input graph.

    Pipeline:
    1. ``ensure_graph_is_valid``.
    2. ``strip_unused_lib.strip_unused``.
    3. ``graph_util.remove_training_nodes``, with the output nodes protected.
    4. ``fold_batch_norms``.
    5. ``fuse_resize_and_conv``.
    6. ``ensure_graph_is_valid`` again.

    Args:
        input_graph_def: A GraphDef containing a training model.
        input_node_names: A list of names of the nodes that are fed inputs
            during inference.
        output_node_names: A list of names of the nodes that produce the final
            results. They are passed as ``protected_nodes`` so that an Identity
            output node is not spliced out of the graph.
        placeholder_type_enum: The AttrValue enum for the placeholder data
            type, or a list that specifies one value per input node name.

    Returns:
        An optimized version of the input graph.
    """
    ensure_graph_is_valid(input_graph_def)
    optimized_graph_def = input_graph_def
    optimized_graph_def = strip_unused_lib.strip_unused(optimized_graph_def,
                                                        input_node_names,
                                                        output_node_names,
                                                        placeholder_type_enum)
    # Without protection, Identity output nodes would be removed and the
    # later fuse_resize_and_conv step could no longer find the outputs.
    optimized_graph_def = graph_util.remove_training_nodes(
        optimized_graph_def, protected_nodes=output_node_names)
    optimized_graph_def = fold_batch_norms(optimized_graph_def)
    optimized_graph_def = fuse_resize_and_conv(optimized_graph_def,
                                               output_node_names)
    ensure_graph_is_valid(optimized_graph_def)
    return optimized_graph_def
示例5: optimize_for_inference
# 需要導入模塊: from tensorflow.python.framework import graph_util [as 別名]
# 或者: from tensorflow.python.framework.graph_util import remove_training_nodes [as 別名]
def optimize_for_inference(input_graph_def, input_node_names,
                           output_node_names, placeholder_type_enum):
    """Applies a series of inference optimizations on the input graph.

    Pipeline:
    1. ``ensure_graph_is_valid``.
    2. ``strip_unused_lib.strip_unused``.
    3. ``graph_util.remove_training_nodes``, with the output nodes protected.
    4. ``fold_batch_norms``.
    5. ``fuse_resize_and_conv``.
    6. ``ensure_graph_is_valid`` again.

    Args:
        input_graph_def: A GraphDef containing a training model.
        input_node_names: A list of names of the nodes that are fed inputs
            during inference.
        output_node_names: A list of names of the nodes that produce the final
            results. They are passed as ``protected_nodes`` so that an Identity
            output node is not spliced out of the graph.
        placeholder_type_enum: Data type of the placeholders used for inputs.

    Returns:
        An optimized version of the input graph.
    """
    ensure_graph_is_valid(input_graph_def)
    optimized_graph_def = input_graph_def
    optimized_graph_def = strip_unused_lib.strip_unused(optimized_graph_def,
                                                        input_node_names,
                                                        output_node_names,
                                                        placeholder_type_enum)
    # Without protection, Identity output nodes would be removed and the
    # later fuse_resize_and_conv step could no longer find the outputs.
    optimized_graph_def = graph_util.remove_training_nodes(
        optimized_graph_def, protected_nodes=output_node_names)
    optimized_graph_def = fold_batch_norms(optimized_graph_def)
    optimized_graph_def = fuse_resize_and_conv(optimized_graph_def,
                                               output_node_names)
    ensure_graph_is_valid(optimized_graph_def)
    return optimized_graph_def
示例6: export_to_pb
# 需要導入模塊: from tensorflow.python.framework import graph_util [as 別名]
# 或者: from tensorflow.python.framework.graph_util import remove_training_nodes [as 別名]
def export_to_pb(sess, x, filename):
    """Freeze ``sess.graph`` with *x* as output and write ``./<filename>``.

    An Identity node named ``"output"`` is attached to *x*. The graph is
    frozen with ``convert_variables_to_constants``, stripped of training-only
    nodes (keeping ``"output"``), and written as a binary GraphDef.

    Args:
        sess: Session holding the variable values.
        x: Tensor whose value becomes the graph output.
        filename: Name of the file written into the current directory.

    Returns:
        The path returned by ``graph_io.write_graph``.
    """
    pred_names = ["output"]
    tf.identity(x, name=pred_names[0])
    graph = graph_util.convert_variables_to_constants(
        sess, sess.graph.as_graph_def(), pred_names
    )
    # Protect "output" so the Identity node is not spliced out.
    graph = graph_util.remove_training_nodes(graph, protected_nodes=pred_names)
    path = graph_io.write_graph(graph, ".", filename, as_text=False)
    print("saved the frozen graph (ready for inference) at: ", path)
    return path
示例7: export_to_pb
# 需要導入模塊: from tensorflow.python.framework import graph_util [as 別名]
# 或者: from tensorflow.python.framework.graph_util import remove_training_nodes [as 別名]
def export_to_pb(sess, x, filename):
    """Freeze ``sess.graph`` with *x* as output and write ``./<filename>``.

    An Identity node named ``"output"`` is attached to *x*. The graph is
    frozen with ``convert_variables_to_constants``, stripped of training-only
    nodes (keeping ``"output"``), and written as a binary GraphDef.

    Args:
        sess: Session holding the variable values.
        x: Tensor whose value becomes the graph output.
        filename: Name of the file written into the current directory.

    Returns:
        The path returned by ``graph_io.write_graph``.
    """
    pred_names = ["output"]
    tf.identity(x, name=pred_names[0])
    graph = graph_util.convert_variables_to_constants(
        sess, sess.graph.as_graph_def(), pred_names
    )
    # Protect "output" so the Identity node is not spliced out.
    graph = graph_util.remove_training_nodes(graph, protected_nodes=pred_names)
    path = graph_io.write_graph(graph, ".", filename, as_text=False)
    # Report the location actually written, not just the bare filename.
    print("saved the frozen graph (ready for inference) at: ", path)
    return path
示例8: optimize_for_inference
# 需要導入模塊: from tensorflow.python.framework import graph_util [as 別名]
# 或者: from tensorflow.python.framework.graph_util import remove_training_nodes [as 別名]
def optimize_for_inference(input_graph_def, input_node_names, output_node_names,
                           placeholder_type_enum):
    """Run the inference optimization pipeline over *input_graph_def*.

    Each stage receives the previous stage's result:
    1. ``ensure_graph_is_valid`` on the input.
    2. ``strip_unused_lib.strip_unused``.
    3. ``graph_util.remove_training_nodes``, with the outputs protected.
    4. ``fold_batch_norms``.
    5. ``fuse_resize_and_conv``.
    6. ``ensure_graph_is_valid`` on the result.

    Args:
        input_graph_def: A GraphDef containing a training model.
        input_node_names: Names of the nodes fed with inputs during inference.
        output_node_names: Names of the nodes that produce the final results.
        placeholder_type_enum: The AttrValue enum for the placeholder data
            type, or a list that specifies one value per input node name.

    Returns:
        An optimized version of the input graph.
    """
    ensure_graph_is_valid(input_graph_def)
    stripped = strip_unused_lib.strip_unused(
        input_graph_def, input_node_names, output_node_names,
        placeholder_type_enum)
    # Output nodes are passed as protected_nodes so they are never removed.
    inference_only = graph_util.remove_training_nodes(
        stripped, protected_nodes=output_node_names)
    folded = fold_batch_norms(inference_only)
    result = fuse_resize_and_conv(folded, output_node_names)
    ensure_graph_is_valid(result)
    return result
開發者ID:PacktPublishing,項目名稱:Serverless-Deep-Learning-with-TensorFlow-and-AWS-Lambda,代碼行數:30,代碼來源:optimize_for_inference_lib.py
示例9: __init__
# 需要導入模塊: from tensorflow.python.framework import graph_util [as 別名]
# 或者: from tensorflow.python.framework.graph_util import remove_training_nodes [as 別名]
def __init__(self, model, session=None):
    """Freeze *model* into ``self.graph_def``, keeping only inference nodes.

    The model's graph is passed through
    ``graph_util.convert_variables_to_constants`` and then
    ``graph_util.remove_training_nodes``. The output nodes are protected so
    they survive even when they are Identity ops.

    Callers must make sure that afterwards only these operation types remain:
    Const, MatMul, Add, BiasAdd, Conv2D, Reshape, MaxPool, AveragePool,
    Placeholder, Relu, Sigmoid, Tanh. For Keras models, operations of type
    Pack, Shape, StridedSlice and Prod are meant to be ignored so that a
    Flatten layer can be used.

    Arguments
    ---------
    model : tensorflow.Tensor or tensorflow.Operation or
            tensorflow.python.keras.engine.sequential.Sequential or
            keras.engine.sequential.Sequential
        - tf.Tensor: ``model.op`` is the output node (output name
          ``[model.op.name]``).
        - tf.Operation: ``model`` itself is the output node (output name
          ``[model.name]``).
        - Sequential (tf.keras or keras):
          ``x = model.layers[-1].output.op.inputs[0].op`` is the output node
          (output name ``[x.name]``).
    session : tf.Session
        Session holding the trained variables. If None,
        ``tf.get_default_session()`` is used. For Keras models the Keras
        backend session is always used, overriding any session passed in.

    Raises
    ------
    AssertionError
        If *model* is an ``onnx.ModelProto`` or of an unrecognized type.
        Raised explicitly, so it still fires under ``python -O``.
    """
    if isinstance(model, tf.Tensor):
        output_names = [model.op.name]
    elif isinstance(model, tf.Operation):
        output_names = [model.name]
    elif isinstance(model, Sequential):
        session = tf.keras.backend.get_session()
        output_names = [model.layers[-1].output.op.inputs[0].op.name]
        model = model.layers[-1].output.op
    elif isinstance(model, onnx.ModelProto):
        raise AssertionError('not tensorflow model')
    else:
        import keras
        if isinstance(model, keras.engine.sequential.Sequential):
            session = keras.backend.get_session()
            output_names = [model.layers[-1].output.op.inputs[0].op.name]
            model = model.layers[-1].output.op
        else:
            raise AssertionError("ERAN can't recognize this input")
    if session is None:
        session = tf.get_default_session()
    tmp = graph_util.convert_variables_to_constants(
        session, model.graph.as_graph_def(), output_names)
    # Protect the outputs so an Identity output node is not spliced away.
    self.graph_def = graph_util.remove_training_nodes(
        tmp, protected_nodes=output_names)
示例10: testRemoveTrainingNodes
# 需要導入模塊: from tensorflow.python.framework import graph_util [as 別名]
# 或者: from tensorflow.python.framework.graph_util import remove_training_nodes [as 別名]
def testRemoveTrainingNodes(self):
    """remove_training_nodes drops CheckNumerics and splices out Identity.

    Input graph, built identically for prefix "a" and "b":
        <p>_constant                              (Const)
        <p>_check    <- <p>_constant              (CheckNumerics)
        <p>_identity <- <p>_constant, ^<p>_check  (Identity)
    followed by add <- a_identity, b_identity (Add).

    The expected output keeps only the two constants, wired straight
    into Add.
    """
    add_name = "add"

    def scalar_constant(name):
        # Scalar float32 constant with value 1.
        return self.create_constant_node_def(name,
                                             value=1,
                                             dtype=tf.float32,
                                             shape=[])

    def float_add(inputs):
        # The dtype attr must be set before the node is copied into a graph.
        node = self.create_node_def("Add", add_name, inputs)
        self.set_attr_dtype(node, "T", tf.float32)
        return node

    graph_def = tf.GraphDef()
    for prefix in ("a", "b"):
        constant_name = prefix + "_constant"
        check_name = prefix + "_check"
        graph_def.node.extend([
            scalar_constant(constant_name),
            self.create_node_def("CheckNumerics", check_name,
                                 [constant_name]),
            self.create_node_def("Identity", prefix + "_identity",
                                 [constant_name, "^" + check_name]),
        ])
    graph_def.node.extend([float_add(["a_identity", "b_identity"])])

    expected_output = tf.GraphDef()
    expected_output.node.extend([
        scalar_constant("a_constant"),
        scalar_constant("b_constant"),
        float_add(["a_constant", "b_constant"]),
    ])

    output = graph_util.remove_training_nodes(graph_def)
    self.assertProtoEquals(expected_output, output)
示例11: rewrite
# 需要導入模塊: from tensorflow.python.framework import graph_util [as 別名]
# 或者: from tensorflow.python.framework.graph_util import remove_training_nodes [as 別名]
def rewrite(self, output_node_names):
    """Triggers rewriting of the float graph.

    The strategy is selected by ``self.mode``:

    - "round" / "quantize": call ``round_nodes_recursively`` /
      ``quantize_nodes_recursively`` on each output node.
    - "eightbit":
      * strip training-only nodes from ``self.input_graph``, keeping the
        requested outputs;
      * call ``eightbitize_nodes_recursively`` on each output node;
      * optionally add min/max constant nodes for ``self.input_range`` and
        ``self.fallback_quantization_range``;
      * optionally remove redundant quantization
        (``FLAGS.strip_redundant_quantization``);
      * remove dead nodes and apply the final node renames.
    - "weights" / "weights_rounded": quantize the weights of
      ``self.input_graph``, then remove dead nodes.
    - Any other mode: print an error message and return the fresh, empty
      ``tf.GraphDef``.

    Args:
        output_node_names: A list of names of the nodes that produce the
            final results.

    Returns:
        A quantized version of the float graph (``self.output_graph``).
    """
    self.output_graph = tf.GraphDef()
    output_nodes = [self.nodes_map[output_node_name]
                    for output_node_name in output_node_names]
    if self.mode == "round":
        self.already_visited = {}
        for output_node in output_nodes:
            self.round_nodes_recursively(output_node)
    elif self.mode == "quantize":
        self.already_visited = {}
        self.already_quantized = {}
        for output_node in output_nodes:
            self.quantize_nodes_recursively(output_node)
    elif self.mode == "eightbit":
        # Protect the outputs: remove_training_nodes would otherwise splice
        # away an Identity output node, and the nodes_map lookup below would
        # fail (assuming set_input_graph rebuilds nodes_map).
        self.set_input_graph(graph_util.remove_training_nodes(
            self.input_graph, protected_nodes=output_node_names))
        # Resolve the outputs again against the replaced input graph.
        output_nodes = [self.nodes_map[output_node_name]
                        for output_node_name in output_node_names]
        self.state = EightbitizeRecursionState(already_visited={},
                                               output_node_stack=[],
                                               merged_with_fake_quant={})
        for output_node in output_nodes:
            self.eightbitize_nodes_recursively(output_node)
        self.state = None
        if self.input_range:
            self.add_output_graph_node(create_constant_node(
                "quantized_input_min_value", self.input_range[0], tf.float32, []))
            self.add_output_graph_node(create_constant_node(
                "quantized_input_max_value", self.input_range[1], tf.float32, []))
        if self.fallback_quantization_range:
            self.add_output_graph_node(create_constant_node(
                "fallback_quantization_min_value",
                self.fallback_quantization_range[0], tf.float32, []))
            self.add_output_graph_node(create_constant_node(
                "fallback_quantization_max_value",
                self.fallback_quantization_range[1], tf.float32, []))
        if FLAGS.strip_redundant_quantization:
            self.output_graph = self.remove_redundant_quantization(
                self.output_graph)
            self.remove_dead_nodes(output_node_names)
        self.apply_final_node_renames()
    elif self.mode == "weights":
        self.output_graph = self.quantize_weights(self.input_graph,
                                                  b"MIN_COMBINED")
        self.remove_dead_nodes(output_node_names)
    elif self.mode == "weights_rounded":
        self.output_graph = self.quantize_weights(self.input_graph, self.mode)
        self.remove_dead_nodes(output_node_names)
    else:
        print("Bad mode - " + self.mode + ".")
    return self.output_graph