本文整理匯總了Python中tensorflow.python.framework.graph_util.extract_sub_graph方法的典型用法代碼示例。如果您正苦於以下問題：Python graph_util.extract_sub_graph方法的具體用法？Python graph_util.extract_sub_graph怎麼用？Python graph_util.extract_sub_graph使用的例子？那麼，這裏精選的方法代碼示例或許可以為您提供幫助。您也可以進一步了解該方法所在模塊tensorflow.python.framework.graph_util的用法示例。
在下文中一共展示了graph_util.extract_sub_graph方法的7個代碼示例，這些例子默認根據受歡迎程度排序。您可以為喜歡或者感覺有用的代碼點讚，您的評價將有助於系統推薦出更棒的Python代碼示例。
示例1: strip_unused
# 需要導入模塊: from tensorflow.python.framework import graph_util [as 別名]
# 或者: from tensorflow.python.framework.graph_util import extract_sub_graph [as 別名]
def strip_unused(input_graph_def, input_node_names, output_node_names,
                 placeholder_type_enum):
  """Removes unused nodes from a GraphDef.

  Every node whose name is in `input_node_names` is replaced by a
  `Placeholder` node with the same name. Nodes that only fed those inputs
  therefore become unreachable from the outputs, and
  `graph_util.extract_sub_graph()` strips them.

  Args:
    input_graph_def: A graph with nodes we want to prune.
    input_node_names: A list of the nodes we use as inputs.
    output_node_names: A list of the output nodes.
    placeholder_type_enum: The AttrValue enum for the placeholder data type, or
      a list (or tuple) that specifies one value per input node name.

  Returns:
    A GraphDef with all unnecessary ops removed.
  """
  # Tuples are accepted as well as lists. Previously a tuple fell through to
  # the scalar branch and made AttrValue(type=...) fail.
  per_input_types = isinstance(placeholder_type_enum, (list, tuple))

  # Resolve each input name to its position once. Keeping the first occurrence
  # matches what `list.index` returned, and avoids an O(n) search per node.
  input_index = {}
  for position, name in enumerate(input_node_names):
    input_index.setdefault(name, position)

  # Here we replace the nodes we're going to override as inputs with
  # placeholders so that any unused nodes that are inputs to them are
  # automatically stripped out by extract_sub_graph().
  inputs_replaced_graph_def = graph_pb2.GraphDef()
  for node in input_graph_def.node:
    if node.name not in input_index:
      inputs_replaced_graph_def.node.extend([copy.deepcopy(node)])
      continue
    if per_input_types:
      dtype = placeholder_type_enum[input_index[node.name]]
    else:
      dtype = placeholder_type_enum
    placeholder_node = node_def_pb2.NodeDef()
    placeholder_node.op = "Placeholder"
    placeholder_node.name = node.name
    placeholder_node.attr["dtype"].CopyFrom(attr_value_pb2.AttrValue(type=dtype))
    # Carry over the replaced node's `_output_shapes` attr, if it has one.
    if "_output_shapes" in node.attr:
      placeholder_node.attr["_output_shapes"].CopyFrom(
          node.attr["_output_shapes"])
    inputs_replaced_graph_def.node.extend([placeholder_node])
  return graph_util.extract_sub_graph(inputs_replaced_graph_def,
                                      output_node_names)
示例2: strip_unused
# 需要導入模塊: from tensorflow.python.framework import graph_util [as 別名]
# 或者: from tensorflow.python.framework.graph_util import extract_sub_graph [as 別名]
def strip_unused(input_graph_def, input_node_names, output_node_names,
                 placeholder_type_enum):
  """Prunes a GraphDef down to what the outputs need, cutting it at the inputs.

  Nodes named in `input_node_names` are swapped for `Placeholder` ops whose
  dtype is `placeholder_type_enum`. All other nodes are deep-copied unchanged.
  The rewritten graph is then passed, with `output_node_names`, to
  `graph_util.extract_sub_graph()`, which drops anything not needed to compute
  the outputs.

  Args:
    input_graph_def: A graph with nodes we want to prune.
    input_node_names: A list of the nodes we use as inputs.
    output_node_names: A list of the output nodes.
    placeholder_type_enum: The AttrValue enum for the placeholder data type.

  Returns:
    A GraphDef with all unnecessary ops removed.
  """
  rewritten_graph_def = tf.GraphDef()
  for source_node in input_graph_def.node:
    if source_node.name not in input_node_names:
      # Not an input being overridden: keep an independent copy as-is.
      rewritten_graph_def.node.extend([copy.deepcopy(source_node)])
      continue
    stand_in = tf.NodeDef()
    stand_in.op = "Placeholder"
    stand_in.name = source_node.name
    stand_in.attr["dtype"].CopyFrom(tf.AttrValue(type=placeholder_type_enum))
    # Keep the original node's `_output_shapes` attr when present.
    if "_output_shapes" in source_node.attr:
      stand_in.attr["_output_shapes"].CopyFrom(
          source_node.attr["_output_shapes"])
    rewritten_graph_def.node.extend([stand_in])
  return graph_util.extract_sub_graph(rewritten_graph_def, output_node_names)
示例3: testExtractSubGraph
# 需要導入模塊: from tensorflow.python.framework import graph_util [as 別名]
# 或者: from tensorflow.python.framework.graph_util import extract_sub_graph [as 別名]
def testExtractSubGraph(self):
  """extract_sub_graph(["n3"]) yields n1, n2, n3, n5 as its leading nodes.

  The graph contains:
    - a data input with an explicit output index ("n1:0");
    - a control input ("^n2");
    - a cycle between n1 and n5;
    - a node (n4) that nothing depends on.
  """
  graph_def = tf.GraphDef()
  # (name, inputs) in insertion order.
  # "n1:0" takes the first output of n1.
  # "^n2" is a control input: it enforces execution order rather than feeding
  # data to the kernel.
  # n1 -> n5 -> n1 forms a loop, which is fine to have in the graph.
  node_specs = [
      ("n1", ["n5"]),
      ("n2", ["n1:0"]),
      ("n3", ["^n2"]),
      ("n4", []),
      ("n5", ["n1"]),
  ]
  for node_name, node_inputs in node_specs:
    new_node = graph_def.node.add()
    new_node.name = node_name
    new_node.input.extend(node_inputs)

  sub_graph = graph_util.extract_sub_graph(graph_def, ["n3"])
  for position, expected_name in enumerate(["n1", "n2", "n3", "n5"]):
    self.assertEqual(expected_name, sub_graph.node[position].name)
示例4: remove_dead_nodes
# 需要導入模塊: from tensorflow.python.framework import graph_util [as 別名]
# 或者: from tensorflow.python.framework.graph_util import extract_sub_graph [as 別名]
def remove_dead_nodes(self, output_names):
  """Prunes `self.output_graph` to the nodes `output_names` depend on.

  The current graph is passed to `graph_util.extract_sub_graph()` together
  with `output_names`, and the returned sub-graph replaces
  `self.output_graph`.
  """
  self.output_graph = graph_util.extract_sub_graph(self.output_graph,
                                                   output_names)
示例5: strip_meta_graph
# 需要導入模塊: from tensorflow.python.framework import graph_util [as 別名]
# 或者: from tensorflow.python.framework.graph_util import extract_sub_graph [as 別名]
def strip_meta_graph(meta_graph_def, node_names, var_names):
  """Prunes `meta_graph_def` in place down to the requested nodes.

  The kept set is built as follows:
    - start from `node_names`;
    - add the initializer of every variable in the "variables" and
      "trainable_variables" collections whose name is in `var_names`;
    - delete every serialized entry of the "while_context" collection whose
      pivot name is not a prefix of some kept node name.
  Finally, `meta_graph_def.graph_def` is replaced by the sub-graph that
  `graph_util.extract_sub_graph()` extracts for the kept names.

  Args:
    meta_graph_def: The MetaGraphDef to modify in place.
    node_names: Names of nodes to keep. The caller's sequence is not mutated.
    var_names: Variable names whose initializers should also be kept.
  """
  # Copy into a list so tuples work too and the caller's sequence is untouched.
  node_names = list(node_names)
  collections = meta_graph_def.collection_def

  # Look for matching variable names and initializers and keep them too.
  # Check membership before indexing: `map[key]` on a protobuf message map
  # inserts an empty entry for a missing key, which would add spurious empty
  # collections to `meta_graph_def`.
  var_def = variable_pb2.VariableDef()
  for var_col_name in ("variables", "trainable_variables"):
    if var_col_name not in collections:
      continue
    for var_def_b in collections[var_col_name].bytes_list.value:
      var_def.ParseFromString(var_def_b)
      if var_def.variable_name not in var_names:
        # TODO(adamb) Should remove variable from collection.
        continue
      node_names.append(var_def.initializer_name)

  if "while_context" in collections:
    wc_def = control_flow_pb2.WhileContextDef()
    wc_values = collections["while_context"].bytes_list.value
    # Walk from the end so deleting an entry does not shift the indices that
    # are still to be visited.
    for wc_ix in reversed(range(len(wc_values))):
      wc_def.ParseFromString(wc_values[wc_ix])
      wc_pivot_name = wc_def.pivot_name
      if not any(name.startswith(wc_pivot_name) for name in node_names):
        del wc_values[wc_ix]

  graph_def = meta_graph_def.graph_def
  eprint("only keeping", node_names, "from", [n.name for n in graph_def.node])
  graph_def = graph_util.extract_sub_graph(graph_def, node_names)
  meta_graph_def.graph_def.CopyFrom(graph_def)
示例6: strip_unused
# 需要導入模塊: from tensorflow.python.framework import graph_util [as 別名]
# 或者: from tensorflow.python.framework.graph_util import extract_sub_graph [as 別名]
def strip_unused(input_graph_def, input_node_names, output_node_names,
                 placeholder_type_enum):
  """Removes unused nodes from a GraphDef.

  Every node whose name is in `input_node_names` is replaced by a
  `Placeholder` node with the same name. Nodes that only fed those inputs
  therefore become unreachable from the outputs, and
  `graph_util.extract_sub_graph()` strips them.

  Args:
    input_graph_def: A graph with nodes we want to prune.
    input_node_names: A list of the nodes we use as inputs.
    output_node_names: A list of the output nodes.
    placeholder_type_enum: The AttrValue enum for the placeholder data type, or
      a list (or tuple) that specifies one value per input node name.

  Returns:
    A `GraphDef` with all unnecessary ops removed.

  Raises:
    ValueError: If any element in `input_node_names` refers to a tensor instead
      of an operation.
    KeyError: If any element in `input_node_names` is not found in the graph.
  """
  for name in input_node_names:
    if ":" in name:
      raise ValueError("Name '%s' appears to refer to a Tensor, "
                       "not a Operation." % name)

  # Tuples are accepted as well as lists. Previously a tuple fell through to
  # the scalar branch and made AttrValue(type=...) fail.
  per_input_types = isinstance(placeholder_type_enum, (list, tuple))

  # Resolve each input name to its position once. Keeping the first occurrence
  # matches what `list.index` returned, and avoids an O(n) search per node.
  input_index = {}
  for position, name in enumerate(input_node_names):
    input_index.setdefault(name, position)

  # Here we replace the nodes we're going to override as inputs with
  # placeholders so that any unused nodes that are inputs to them are
  # automatically stripped out by extract_sub_graph().
  not_found = set(input_index)
  inputs_replaced_graph_def = graph_pb2.GraphDef()
  for node in input_graph_def.node:
    if node.name not in input_index:
      inputs_replaced_graph_def.node.extend([copy.deepcopy(node)])
      continue
    # discard (not remove): a GraphDef that repeats a node name would
    # otherwise raise a bare KeyError naming a node that *was* found.
    not_found.discard(node.name)
    if per_input_types:
      dtype = placeholder_type_enum[input_index[node.name]]
    else:
      dtype = placeholder_type_enum
    placeholder_node = node_def_pb2.NodeDef()
    placeholder_node.op = "Placeholder"
    placeholder_node.name = node.name
    placeholder_node.attr["dtype"].CopyFrom(attr_value_pb2.AttrValue(type=dtype))
    # Carry over the replaced node's `_output_shapes` attr, if it has one.
    if "_output_shapes" in node.attr:
      placeholder_node.attr["_output_shapes"].CopyFrom(
          node.attr["_output_shapes"])
    inputs_replaced_graph_def.node.extend([placeholder_node])

  if not_found:
    # Sorted so the message is stable across runs; set order depends on
    # string hash randomization.
    raise KeyError("The following input nodes were not found: %s\n" %
                   sorted(not_found))
  return graph_util.extract_sub_graph(inputs_replaced_graph_def,
                                      output_node_names)
示例7: strip_unused
# 需要導入模塊: from tensorflow.python.framework import graph_util [as 別名]
# 或者: from tensorflow.python.framework.graph_util import extract_sub_graph [as 別名]
def strip_unused(input_graph, input_binary, output_graph, input_node_names,
                 output_node_names, placeholder_type_enum):
  """Removes unused nodes from a graph file and writes the pruned graph.

  Steps:
    1. Read a GraphDef from `input_graph` (binary protobuf if `input_binary`,
       otherwise text format).
    2. Replace every node named in the comma-separated `input_node_names`
       with a `Placeholder` of type `placeholder_type_enum`.
    3. Keep only what the comma-separated `output_node_names` need.
    4. Write the result as a binary GraphDef to `output_graph`.

  Args:
    input_graph: Path of the GraphDef file to read.
    input_binary: True if `input_graph` is a binary protobuf, False for text.
    output_graph: Path the pruned binary GraphDef is written to.
    input_node_names: Comma-separated names of nodes to turn into placeholders.
    output_node_names: Comma-separated names of the output nodes.
    placeholder_type_enum: The AttrValue enum for the placeholder data type.

  Returns:
    -1 if `input_graph` does not exist or `output_node_names` is empty;
    otherwise None.
  """
  if not tf.gfile.Exists(input_graph):
    print("Input graph file '" + input_graph + "' does not exist!")
    return -1
  if not output_node_names:
    print("You need to supply the name of a node to --output_node_names.")
    return -1

  input_graph_def = tf.GraphDef()
  mode = "rb" if input_binary else "r"
  with tf.gfile.FastGFile(input_graph, mode) as f:
    if input_binary:
      input_graph_def.ParseFromString(f.read())
    else:
      text_format.Merge(f.read(), input_graph_def)

  # Strip whitespace around commas so that "a, b" matches node "b".
  # Valid TensorFlow op names do not contain whitespace.
  input_node_names_set = {name.strip() for name in input_node_names.split(",")}
  output_node_names_list = [
      name.strip() for name in output_node_names.split(",")
  ]

  # Here we replace the nodes we're going to override as inputs with
  # placeholders so that any unused nodes that are inputs to them are
  # automatically stripped out by extract_sub_graph().
  inputs_replaced_graph_def = tf.GraphDef()
  for node in input_graph_def.node:
    if node.name not in input_node_names_set:
      inputs_replaced_graph_def.node.extend([copy.deepcopy(node)])
      continue
    placeholder_node = tf.NodeDef()
    placeholder_node.op = "Placeholder"
    placeholder_node.name = node.name
    placeholder_node.attr["dtype"].CopyFrom(tf.AttrValue(
        type=placeholder_type_enum))
    # Carry over `_output_shapes` like the in-memory strip_unused variants do.
    if "_output_shapes" in node.attr:
      placeholder_node.attr["_output_shapes"].CopyFrom(
          node.attr["_output_shapes"])
    inputs_replaced_graph_def.node.extend([placeholder_node])

  output_graph_def = graph_util.extract_sub_graph(inputs_replaced_graph_def,
                                                  output_node_names_list)
  with tf.gfile.GFile(output_graph, "wb") as f:
    f.write(output_graph_def.SerializeToString())
  print("%d ops in the final graph." % len(output_graph_def.node))