本文整理汇总了Python中tensorflow.python.framework.graph_util.extract_sub_graph函数的典型用法代码示例。如果您正苦于以下问题：Python extract_sub_graph函数的具体用法？Python extract_sub_graph怎么用？Python extract_sub_graph使用的例子？那么恭喜您，这里精选的函数代码示例或许可以为您提供帮助。
在下文中一共展示了extract_sub_graph函数的12个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: strip_unused
def strip_unused(input_graph_def, input_node_names, output_node_names,
                 placeholder_type_enum):
    """Removes unused nodes from a GraphDef.

    Every node named in ``input_node_names`` is replaced by a ``Placeholder``
    node of the same name (keeping its ``_output_shapes`` attr if present).
    Anything that only fed those nodes becomes unreachable and is pruned by
    ``graph_util.extract_sub_graph``.

    Args:
        input_graph_def: A graph with nodes we want to prune.
        input_node_names: A list of the nodes we use as inputs.
        output_node_names: A list of the output nodes.
        placeholder_type_enum: The AttrValue enum for the placeholder data
            type, or a list/tuple that specifies one enum per entry of
            ``input_node_names`` (same order).

    Returns:
        A GraphDef with all unnecessary ops removed.
    """
    # A sequence means "one dtype per input node"; anything else is a single
    # dtype shared by every placeholder (the original behavior).
    per_input_types = isinstance(placeholder_type_enum, (list, tuple))

    # Here we replace the nodes we're going to override as inputs with
    # placeholders so that any unused nodes that are inputs to them are
    # automatically stripped out by extract_sub_graph().
    inputs_replaced_graph_def = tf.GraphDef()
    for node in input_graph_def.node:
        if node.name in input_node_names:
            placeholder_node = tf.NodeDef()
            placeholder_node.op = "Placeholder"
            placeholder_node.name = node.name
            if per_input_types:
                dtype_enum = placeholder_type_enum[
                    input_node_names.index(node.name)]
            else:
                dtype_enum = placeholder_type_enum
            placeholder_node.attr["dtype"].CopyFrom(
                tf.AttrValue(type=dtype_enum))
            if "_output_shapes" in node.attr:
                placeholder_node.attr["_output_shapes"].CopyFrom(
                    node.attr["_output_shapes"])
            inputs_replaced_graph_def.node.extend([placeholder_node])
        else:
            inputs_replaced_graph_def.node.extend([copy.deepcopy(node)])
    output_graph_def = graph_util.extract_sub_graph(inputs_replaced_graph_def,
                                                    output_node_names)
    return output_graph_def
示例2: testExtractSubGraph
def testExtractSubGraph(self):
    """extract_sub_graph keeps exactly the nodes the destination depends on.

    The graph has n5 -> n1 -> n2, a control edge n2 -> n3, an edge n1 -> n5
    (forming a cycle with n5 -> n1) and an isolated n4. Extracting for n3
    must keep n1, n2, n3 and n5 in their original order and drop n4.
    """
    # (name, inputs) pairs, in the order the nodes are added to the GraphDef.
    node_specs = [
        ("n1", ["n5"]),
        # n2 consumes output 0 of n1.
        ("n2", ["n1:0"]),
        # The "^" prefix marks a control input: it only enforces execution
        # order and carries no data.
        ("n3", ["^n2"]),
        # n4 has no inputs and no consumers, so it should be pruned.
        ("n4", []),
        # n5 both feeds and is fed by n1; cycles in the graph are fine.
        ("n5", ["n1"]),
    ]
    graph_def = graph_pb2.GraphDef()
    for node_name, node_inputs in node_specs:
        new_node = graph_def.node.add()
        new_node.name = node_name
        new_node.input.extend(node_inputs)

    sub_graph = graph_util.extract_sub_graph(graph_def, ["n3"])
    for position, expected_name in enumerate(["n1", "n2", "n3", "n5"]):
        self.assertEqual(expected_name, sub_graph.node[position].name)
示例3: test_remove_unneeded_nodes
def test_remove_unneeded_nodes(self):
    """Checks that remove_unneeded_nodes drops the CheckNumerics/Identity chain.

    Input graph: each of two scalar float constants feeds a CheckNumerics
    node and an Identity node (the Identity also takes a "^<check>" control
    input), and the two Identity nodes feed an ``add`` node. The test expects
    only the two constants and an ``add`` that reads them directly to remain.
    This is checked after ``GraphRewriter.remove_unneeded_nodes`` followed by
    ``graph_util.extract_sub_graph``.
    """
    a_constant_name = "a_constant"
    b_constant_name = "b_constant"
    a_check_name = "a_check"
    b_check_name = "b_check"
    a_identity_name = "a_identity"
    b_identity_name = "b_identity"
    add_name = "add"
    # --- Input graph ---
    graph_def = tf.GraphDef()
    a_constant = quantize_graph.create_constant_node(a_constant_name,
                                                     value=1,
                                                     dtype=tf.float32,
                                                     shape=[])
    graph_def.node.extend([a_constant])
    a_check_node = quantize_graph.create_node("CheckNumerics", a_check_name,
                                              [a_constant_name])
    graph_def.node.extend([a_check_node])
    # "^" marks a control input: a_identity is ordered after a_check.
    a_identity_node = quantize_graph.create_node("Identity", a_identity_name,
                                                 [a_constant_name,
                                                  "^" + a_check_name])
    graph_def.node.extend([a_identity_node])
    b_constant = quantize_graph.create_constant_node(b_constant_name,
                                                     value=1,
                                                     dtype=tf.float32,
                                                     shape=[])
    graph_def.node.extend([b_constant])
    b_check_node = quantize_graph.create_node("CheckNumerics", b_check_name,
                                              [b_constant_name])
    graph_def.node.extend([b_check_node])
    b_identity_node = quantize_graph.create_node("Identity", b_identity_name,
                                                 [b_constant_name,
                                                  "^" + b_check_name])
    graph_def.node.extend([b_identity_node])
    add_node = quantize_graph.create_node("Add", add_name,
                                          [a_identity_name,
                                           b_identity_name])
    quantize_graph.set_attr_dtype(add_node, "T", tf.float32)
    graph_def.node.extend([add_node])
    # --- Expected output: no CheckNumerics/Identity, add reads constants ---
    # Keep this node order: assertProtoEquals compares the repeated `node`
    # field element by element.
    expected_output = tf.GraphDef()
    a_constant = quantize_graph.create_constant_node(a_constant_name,
                                                     value=1,
                                                     dtype=tf.float32,
                                                     shape=[])
    expected_output.node.extend([a_constant])
    b_constant = quantize_graph.create_constant_node(b_constant_name,
                                                     value=1,
                                                     dtype=tf.float32,
                                                     shape=[])
    expected_output.node.extend([b_constant])
    add_node = quantize_graph.create_node("Add", add_name,
                                          [a_constant_name,
                                           b_constant_name])
    quantize_graph.set_attr_dtype(add_node, "T", tf.float32)
    expected_output.node.extend([add_node])
    # Rewrite, then prune anything no longer reachable from `add`.
    rewriter = quantize_graph.GraphRewriter(graph_def, [add_name])
    output = rewriter.remove_unneeded_nodes(graph_def)
    stripped_output = graph_util.extract_sub_graph(output, [add_name])
    self.assertProtoEquals(expected_output, stripped_output)
示例4: test_keep_control_edges
def test_keep_control_edges(self):
    """Checks that remove_training_nodes preserves control edges to other ops.

    a_identity has control inputs from both a_check (CheckNumerics) and a
    NoOp; b_identity only has a control input from b_check. The expected
    output after ``graph_util.remove_training_nodes`` plus
    ``extract_sub_graph`` is as follows:
      * no_op is kept.
      * a_identity is kept, with only its "^no_op" control input.
      * Both CheckNumerics nodes and b_identity are dropped.
      * add reads a_identity and b_constant.
    """
    no_op_name = "no_op"
    a_constant_name = "a_constant"
    b_constant_name = "b_constant"
    a_check_name = "a_check"
    b_check_name = "b_check"
    a_identity_name = "a_identity"
    b_identity_name = "b_identity"
    add_name = "add"
    # --- Input graph ---
    graph_def = graph_pb2.GraphDef()
    no_op = quantize_graph.create_node("NoOp", no_op_name, [])
    graph_def.node.extend([no_op])
    a_constant = quantize_graph.create_constant_node(
        a_constant_name, value=1, dtype=dtypes.float32, shape=[])
    graph_def.node.extend([a_constant])
    a_check_node = quantize_graph.create_node("CheckNumerics", a_check_name,
                                              [a_constant_name])
    graph_def.node.extend([a_check_node])
    # Two control inputs ("^" prefix): one on a_check and one on no_op.
    a_identity_node = quantize_graph.create_node(
        "Identity", a_identity_name,
        [a_constant_name, "^" + a_check_name, "^" + no_op_name])
    graph_def.node.extend([a_identity_node])
    b_constant = quantize_graph.create_constant_node(
        b_constant_name, value=1, dtype=dtypes.float32, shape=[])
    graph_def.node.extend([b_constant])
    b_check_node = quantize_graph.create_node("CheckNumerics", b_check_name,
                                              [b_constant_name])
    graph_def.node.extend([b_check_node])
    # b_identity's only control input is on b_check.
    b_identity_node = quantize_graph.create_node(
        "Identity", b_identity_name, [b_constant_name, "^" + b_check_name])
    graph_def.node.extend([b_identity_node])
    add_node = quantize_graph.create_node("Add", add_name,
                                          [a_identity_name, b_identity_name])
    quantize_graph.set_attr_dtype(add_node, "T", dtypes.float32)
    graph_def.node.extend([add_node])
    # --- Expected output ---
    # Keep this node order: assertProtoEquals compares the repeated `node`
    # field element by element.
    expected_output = graph_pb2.GraphDef()
    no_op = quantize_graph.create_node("NoOp", no_op_name, [])
    expected_output.node.extend([no_op])
    a_constant = quantize_graph.create_constant_node(
        a_constant_name, value=1, dtype=dtypes.float32, shape=[])
    expected_output.node.extend([a_constant])
    a_identity_node = quantize_graph.create_node(
        "Identity", a_identity_name, [a_constant_name, "^" + no_op_name])
    expected_output.node.extend([a_identity_node])
    b_constant = quantize_graph.create_constant_node(
        b_constant_name, value=1, dtype=dtypes.float32, shape=[])
    expected_output.node.extend([b_constant])
    add_node = quantize_graph.create_node("Add", add_name,
                                          [a_identity_name, b_constant_name])
    quantize_graph.set_attr_dtype(add_node, "T", dtypes.float32)
    expected_output.node.extend([add_node])
    # The comparison is on whole GraphDefs, so versions/library must match too.
    expected_output.versions.CopyFrom(graph_def.versions)
    expected_output.library.CopyFrom(graph_def.library)
    output = graph_util.remove_training_nodes(graph_def)
    stripped_output = graph_util.extract_sub_graph(output, [add_name])
    self.assertProtoEquals(expected_output, stripped_output)
示例5: strip_unused
def strip_unused(input_graph_def, input_node_names, output_node_names,
                 placeholder_type_enum):
    """Removes unused nodes from a GraphDef.

    Each node named in ``input_node_names`` is replaced by a ``Placeholder``
    of the same name (keeping its ``_output_shapes`` attr if present). Nodes
    that only fed those inputs become unreachable and are pruned by
    ``graph_util.extract_sub_graph``.

    Args:
        input_graph_def: A graph with nodes we want to prune.
        input_node_names: A list of the nodes we use as inputs.
        output_node_names: A list of the output nodes.
        placeholder_type_enum: The AttrValue enum for the placeholder data
            type, or a list/tuple that specifies one value per input node
            name (in the same order as ``input_node_names``).

    Returns:
        A `GraphDef` with all unnecessary ops removed.

    Raises:
        ValueError: If any element in `input_node_names` refers to a tensor
            instead of an operation.
        KeyError: If any element in `input_node_names` is not found in the
            graph.
    """
    for name in input_node_names:
        if ":" in name:
            raise ValueError("Name '%s' appears to refer to a Tensor, "
                             "not a Operation." % name)

    per_input_types = isinstance(placeholder_type_enum, (list, tuple))
    # First position of each input name. This gives an O(1) per-input dtype
    # lookup instead of a list.index() scan for every replaced node.
    # setdefault keeps the first occurrence, matching list.index() semantics.
    input_positions = {}
    for position, name in enumerate(input_node_names):
        input_positions.setdefault(name, position)

    # Here we replace the nodes we're going to override as inputs with
    # placeholders so that any unused nodes that are inputs to them are
    # automatically stripped out by extract_sub_graph().
    not_found = set(input_node_names)
    inputs_replaced_graph_def = graph_pb2.GraphDef()
    for node in input_graph_def.node:
        if node.name in input_positions:
            not_found.remove(node.name)
            placeholder_node = node_def_pb2.NodeDef()
            placeholder_node.op = "Placeholder"
            placeholder_node.name = node.name
            if per_input_types:
                dtype_enum = placeholder_type_enum[input_positions[node.name]]
            else:
                dtype_enum = placeholder_type_enum
            placeholder_node.attr["dtype"].CopyFrom(
                attr_value_pb2.AttrValue(type=dtype_enum))
            if "_output_shapes" in node.attr:
                placeholder_node.attr["_output_shapes"].CopyFrom(
                    node.attr["_output_shapes"])
            inputs_replaced_graph_def.node.extend([placeholder_node])
        else:
            inputs_replaced_graph_def.node.extend([copy.deepcopy(node)])
    if not_found:
        raise KeyError("The following input nodes were not found: %s\n" %
                       not_found)
    output_graph_def = graph_util.extract_sub_graph(inputs_replaced_graph_def,
                                                    output_node_names)
    return output_graph_def
示例6: __init__
def __init__(self, meta_file, checkpoint_file, frozen_file, dest_nodes=None):
    """Load a TensorFlow model from a meta graph and build its network graph.

    Args:
        meta_file: Path passed to ``TensorflowParser._load_meta``. Required:
            it is the only graph source this constructor reads.
        checkpoint_file: Optional path passed to
            ``TensorflowParser._load_weights``. When given, the result is
            stored in ``self.ckpt_data`` and ``self.weight_loaded`` is set
            to True.
        frozen_file: Accepted for interface compatibility; not read here.
        dest_nodes: Optional comma-separated node names. When given, the
            graph is reduced with ``extract_sub_graph`` to what they need.

    Raises:
        ValueError: If ``meta_file`` is not provided.
    """
    super(TensorflowParser, self).__init__()

    # Without a meta file `model` would never be assigned and the
    # TensorflowGraph call below would fail with an UnboundLocalError.
    if not meta_file:
        raise ValueError("meta_file is required to build the graph "
                         "(frozen_file is not supported by this constructor)")
    # NOTE(review): frozen_file is ignored entirely; confirm whether loading
    # a frozen graph is meant to be handled here or elsewhere.

    # load model files into TensorFlow graph
    model = TensorflowParser._load_meta(meta_file)
    if checkpoint_file:
        self.ckpt_data = TensorflowParser._load_weights(checkpoint_file)
        self.weight_loaded = True

    if dest_nodes is not None:
        from tensorflow.python.framework.graph_util import extract_sub_graph
        model = extract_sub_graph(model, dest_nodes.split(','))

    # Build network graph
    self.tf_graph = TensorflowGraph(model)
    self.tf_graph.build()
示例7: __init__
def __init__(self, input_args, dest_nodes=None):
    """Load a TensorFlow model and build its network graph.

    Args:
        input_args: Either a path string passed to
            ``TensorflowParser._load_meta``, or a tuple whose element 0 is
            that path and whose element 1 is passed to
            ``TensorflowParser._load_weights``. In the tuple case the result
            is stored in ``self.ckpt_data`` and ``self.weight_loaded`` is set
            to True.
        dest_nodes: Optional comma-separated node names. When given, the
            graph is reduced with ``extract_sub_graph`` to what they need.

    Raises:
        TypeError: If ``input_args`` is neither a string nor a tuple.
    """
    super(TensorflowParser, self).__init__()

    # load model files into TensorFlow graph
    from six import string_types as _string_types
    if isinstance(input_args, _string_types):
        model = TensorflowParser._load_meta(input_args)
    elif isinstance(input_args, tuple):
        model = TensorflowParser._load_meta(input_args[0])
        self.ckpt_data = TensorflowParser._load_weights(input_args[1])
        self.weight_loaded = True
    else:
        # Previously `model` stayed unbound here and the TensorflowGraph call
        # below failed with an unhelpful UnboundLocalError.
        raise TypeError(
            "input_args must be a path string or a (meta path, checkpoint "
            "path) tuple, got %s" % type(input_args).__name__)

    if dest_nodes is not None:
        from tensorflow.python.framework.graph_util import extract_sub_graph
        model = extract_sub_graph(model, dest_nodes.split(','))

    # Build network graph
    self.tf_graph = TensorflowGraph(model)
    self.tf_graph.build()
示例8: strip_unused
def strip_unused(input_graph, input_binary, output_graph, input_node_names,
output_node_names, placeholder_type_enum):
"""Removes unused nodes from a graph."""
if not tf.gfile.Exists(input_graph):
print("Input graph file '" + input_graph + "' does not exist!")
return -1
if not output_node_names:
print("You need to supply the name of a node to --output_node_names.")
return -1
input_graph_def = tf.GraphDef()
mode = "rb" if input_binary else "r"
with tf.gfile.FastGFile(input_graph, mode) as f:
if input_binary:
input_graph_def.ParseFromString(f.read())
else:
text_format.Merge(f.read(), input_graph_def)
# Here we replace the nodes we're going to override as inputs with
# placeholders so that any unused nodes that are inputs to them are
# automatically stripped out by extract_sub_graph().
input_node_names_list = input_node_names.split(",")
inputs_replaced_graph_def = tf.GraphDef()
for node in input_graph_def.node:
if node.name in input_node_names_list:
placeholder_node = tf.NodeDef()
placeholder_node.op = "Placeholder"
placeholder_node.name = node.name
placeholder_node.attr["dtype"].CopyFrom(tf.AttrValue(
type=placeholder_type_enum))
inputs_replaced_graph_def.node.extend([placeholder_node])
else:
inputs_replaced_graph_def.node.extend([copy.deepcopy(node)])
output_graph_def = graph_util.extract_sub_graph(inputs_replaced_graph_def,
output_node_names.split(","))
with tf.gfile.GFile(output_graph, "wb") as f:
f.write(output_graph_def.SerializeToString())
print("%d ops in the final graph." % len(output_graph_def.node))
示例9: strip_pruning_vars_fn
def strip_pruning_vars_fn(input_graph_def, output_node_names):
    """Fold pruning masks into constant weights and drop the mask variables.

    Each node whose name contains ``masked_weight`` is swapped for a ``Const``
    node with the same name. Its value comes from ``_get_masked_weights``
    (the element-wise product of mask and weight), and its ``dtype`` attr is
    copied from the original node's ``T`` attr. All other nodes are copied
    unchanged. Nodes no longer reachable from ``output_node_names`` (the
    stranded masks and weights) are then removed with ``extract_sub_graph``.

    Args:
        input_graph_def: A GraphDef in which the variables have already been
            converted to constants, typically the output of
            ``tf.graph_util.convert_variables_to_constants()``.
        output_node_names: List of name strings for the result nodes of the
            graph.

    Returns:
        A GraphDef in which pruning-related variables have been removed.
    """
    folded_values = _get_masked_weights(input_graph_def)
    pruned_graph_def = graph_pb2.GraphDef()

    for source_node in input_graph_def.node:
        replacement = node_def_pb2.NodeDef()
        if 'masked_weight' not in source_node.name:
            replacement.CopyFrom(source_node)
        else:
            # Const holding the precomputed tf.multiply(mask, weight) result.
            replacement.op = 'Const'
            replacement.name = source_node.name
            replacement.attr['dtype'].CopyFrom(source_node.attr['T'])
            value_proto = tensor_util.make_tensor_proto(
                folded_values[source_node.name])
            replacement.attr['value'].CopyFrom(
                attr_value_pb2.AttrValue(tensor=value_proto))
        pruned_graph_def.node.extend([replacement])

    # Drop what is now unreachable from the outputs: masks and raw weights.
    return graph_util.extract_sub_graph(pruned_graph_def, output_node_names)
示例10: test_remove_redundant_quantization
def test_remove_redundant_quantization(self):
    """Checks that remove_redundant_quantization collapses Dequantize->QuantizeV2.

    Two quint8 constants each pass through a Dequantize node and then a
    QuantizeV2 node before feeding a QuantizedMatMul. The expected result
    after ``GraphRewriter.remove_redundant_quantization`` plus
    ``extract_sub_graph`` contains no Dequantize/QuantizeV2 nodes. mat_mul
    reads each constant directly, and the "<quantize>:1" / "<quantize>:2"
    inputs are replaced by the matching *_min / *_max constant names.
    """
    a_constant_name = "a_constant"
    a_constant_min_name = "a_constant_min"
    a_constant_max_name = "a_constant_max"
    a_dequantize_name = "a_dequantize"
    a_quantize_name = "a_quantize"
    b_constant_name = "b_constant"
    b_constant_min_name = "b_constant_min"
    b_constant_max_name = "b_constant_max"
    b_dequantize_name = "b_dequantize"
    b_quantize_name = "b_quantize"
    mat_mul_name = "mat_mul"
    # --- Input graph: "a" branch ---
    graph_def = graph_pb2.GraphDef()
    a_constant = quantize_graph.create_constant_node(
        a_constant_name, value=(0,), dtype=dtypes.quint8, shape=[])
    graph_def.node.extend([a_constant])
    a_constant_min = quantize_graph.create_constant_node(
        a_constant_min_name, value=2, dtype=dtypes.float32, shape=[])
    graph_def.node.extend([a_constant_min])
    a_constant_max = quantize_graph.create_constant_node(
        a_constant_max_name, value=2, dtype=dtypes.float32, shape=[])
    graph_def.node.extend([a_constant_max])
    # Dequantize is fed by the constant and its min/max constants.
    a_dequantize_node = quantize_graph.create_node(
        "Dequantize", a_dequantize_name,
        [a_constant_name, a_constant_min_name, a_constant_max_name])
    quantize_graph.set_attr_dtype(a_dequantize_node, "T", dtypes.uint8)
    graph_def.node.extend([a_dequantize_node])
    # QuantizeV2 reads outputs 0, 1 and 2 of the Dequantize node.
    a_quantize_node = quantize_graph.create_node(
        "QuantizeV2", a_quantize_name,
        [a_dequantize_name, a_dequantize_name + ":1", a_dequantize_name + ":2"])
    quantize_graph.set_attr_dtype(a_quantize_node, "T", dtypes.uint8)
    graph_def.node.extend([a_quantize_node])
    # --- Input graph: "b" branch (same structure, different min/max) ---
    b_constant = quantize_graph.create_constant_node(
        b_constant_name, value=(0,), dtype=dtypes.quint8, shape=[])
    graph_def.node.extend([b_constant])
    b_constant_min = quantize_graph.create_constant_node(
        b_constant_min_name, value=3, dtype=dtypes.float32, shape=[])
    graph_def.node.extend([b_constant_min])
    b_constant_max = quantize_graph.create_constant_node(
        b_constant_max_name, value=3, dtype=dtypes.float32, shape=[])
    graph_def.node.extend([b_constant_max])
    b_dequantize_node = quantize_graph.create_node(
        "Dequantize", b_dequantize_name,
        [b_constant_name, b_constant_min_name, b_constant_max_name])
    quantize_graph.set_attr_dtype(b_dequantize_node, "T", dtypes.uint8)
    graph_def.node.extend([b_dequantize_node])
    b_quantize_node = quantize_graph.create_node(
        "QuantizeV2", b_quantize_name,
        [b_dequantize_name, b_dequantize_name + ":1", b_dequantize_name + ":2"])
    quantize_graph.set_attr_dtype(b_quantize_node, "T", dtypes.uint8)
    graph_def.node.extend([b_quantize_node])
    # Inputs: both quantized values, then a's output 1/2, then b's output 1/2.
    mat_mul_node = quantize_graph.create_node("QuantizedMatMul", mat_mul_name, [
        a_quantize_name, b_quantize_name, a_quantize_name + ":1",
        a_quantize_name + ":2", b_quantize_name + ":1", b_quantize_name + ":2"
    ])
    quantize_graph.set_attr_dtype(mat_mul_node, "T1", dtypes.uint8)
    quantize_graph.set_attr_dtype(mat_mul_node, "T2", dtypes.int32)
    graph_def.node.extend([mat_mul_node])
    # --- Expected output: constants only, mat_mul wired to them directly ---
    # Keep this node order: assertProtoEquals compares the repeated `node`
    # field element by element.
    expected_output = graph_pb2.GraphDef()
    a_constant = quantize_graph.create_constant_node(
        a_constant_name, value=(0,), dtype=dtypes.quint8, shape=[])
    expected_output.node.extend([a_constant])
    a_constant_min = quantize_graph.create_constant_node(
        a_constant_min_name, value=2, dtype=dtypes.float32, shape=[])
    expected_output.node.extend([a_constant_min])
    a_constant_max = quantize_graph.create_constant_node(
        a_constant_max_name, value=2, dtype=dtypes.float32, shape=[])
    expected_output.node.extend([a_constant_max])
    b_constant = quantize_graph.create_constant_node(
        b_constant_name, value=(0,), dtype=dtypes.quint8, shape=[])
    expected_output.node.extend([b_constant])
    b_constant_min = quantize_graph.create_constant_node(
        b_constant_min_name, value=3, dtype=dtypes.float32, shape=[])
    expected_output.node.extend([b_constant_min])
    b_constant_max = quantize_graph.create_constant_node(
        b_constant_max_name, value=3, dtype=dtypes.float32, shape=[])
    expected_output.node.extend([b_constant_max])
    mat_mul_node = quantize_graph.create_node("QuantizedMatMul", mat_mul_name, [
        a_constant_name, b_constant_name, a_constant_min_name,
        a_constant_max_name, b_constant_min_name, b_constant_max_name
    ])
    quantize_graph.set_attr_dtype(mat_mul_node, "T1", dtypes.uint8)
    quantize_graph.set_attr_dtype(mat_mul_node, "T2", dtypes.int32)
    expected_output.node.extend([mat_mul_node])
    # The comparison is on whole GraphDefs, so versions/library must match too.
    expected_output.versions.CopyFrom(graph_def.versions)
    expected_output.library.CopyFrom(graph_def.library)
    rewriter = quantize_graph.GraphRewriter(
        graph_def, [mat_mul_name], quantized_input_range=None)
    output = rewriter.remove_redundant_quantization(graph_def)
    stripped_output = graph_util.extract_sub_graph(output, [mat_mul_name])
    self.assertProtoEquals(expected_output, stripped_output)
示例11: testExtractSubGraphWithInvalidDestNodes
def testExtractSubGraphWithInvalidDestNodes(self):
    """extract_sub_graph must reject a bare string for dest_nodes.

    Passing "n1" instead of ["n1"] is expected to raise a TypeError whose
    message contains "must be a list".
    """
    graph_def = graph_pb2.GraphDef()
    n1 = graph_def.node.add()
    n1.name = "n1"
    # assertRaisesRegex replaces the deprecated assertRaisesRegexp alias,
    # which was removed from unittest in Python 3.12.
    with self.assertRaisesRegex(TypeError, "must be a list"):
        graph_util.extract_sub_graph(graph_def, "n1")
示例12: remove_dead_nodes
def remove_dead_nodes(self, output_names):
    """Prune ``self.output_graph`` to the nodes ``output_names`` depend on.

    ``self.output_graph`` is replaced by the subgraph that
    ``graph_util.extract_sub_graph`` returns for the given output names.
    Everything unreachable from those outputs is dropped.
    """
    self.output_graph = graph_util.extract_sub_graph(
        self.output_graph, output_names)