本文整理汇总了Python中tensorflow.python.training.saver.export_meta_graph函数的典型用法代码示例。如果您正苦于以下问题：Python export_meta_graph函数的具体用法？Python export_meta_graph怎么用？Python export_meta_graph使用的例子？那么恭喜您，这里精选的函数代码示例或许可以为您提供帮助。
在下文中一共展示了export_meta_graph函数的15个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: testNoVariables
def testNoVariables(self):
    """Round-trips a variable-free graph through a MetaGraphDef file.

    Builds `output = input + 42`, exports it with the *module-level*
    `saver_module.export_meta_graph` (not the Saver instance method),
    re-imports it into a fresh graph, and checks that:
      * no Saver is returned by the import,
      * re-exporting yields an identical proto,
      * the imported collections still compute the same value.
    """
    metafile_path = os.path.join(_TestDir("no_variables"), "metafile")
    feed_value = -10  # Arbitrary input value for feed_dict.

    source_graph = tf.Graph()
    with self.test_session(graph=source_graph) as sess:
        # Minimal graph with zero variables.
        placeholder = tf.placeholder(tf.float32, shape=[], name="input")
        forty_two = tf.constant(42, dtype=tf.float32, name="offset")
        summed = tf.add(placeholder, forty_two, name="add_offset")

        # Register the endpoints so they survive export/import.
        tf.add_to_collection("input_tensor", placeholder)
        tf.add_to_collection("output_tensor", summed)

        expected = sess.run(summed, {placeholder: feed_value})
        self.assertEqual(expected, 32)

        exported = saver_module.export_meta_graph(
            filename=metafile_path,
            graph_def=tf.get_default_graph().as_graph_def(),
            collection_list=["input_tensor", "output_tensor"],
            saver_def=None,
        )

    fresh_graph = tf.Graph()
    with self.test_session(graph=fresh_graph) as sess:
        # Import the previously exported meta graph. With no variables there
        # is nothing to restore, so no Saver instance is expected.
        self.assertIsNone(saver_module.import_meta_graph(metafile_path))

        # Re-export the imported graph and compare it with the first export.
        reexported = saver_module.export_meta_graph(metafile_path + "_new")
        self.assertProtoEquals(exported, reexported)

        # The collections must still resolve to the imported tensors, and the
        # imported graph must compute the same result as the original.
        imported_in = tf.get_collection("input_tensor")[0]
        imported_out = tf.get_collection("output_tensor")[0]
        self.assertEqual(
            sess.run(imported_out, {imported_in: feed_value}), expected)
示例2: main
def main(_):
    """Loads a graph, optionally rewrites it with Grappler, prints a cost report.

    The input is either a serialized MetaGraphDef (``--metagraphdef``) or a
    GraphDef (``--graphdef``; binary, or text when the path ends in
    ``.pbtxt``). In the GraphDef case the op(s) named by ``--fetch``
    (comma-separated) are put in the ``train_op`` collection before export.
    If ``--rewriter_config`` is given (text-format RewriterConfig), the graph
    is optimized with it before the report is generated.
    """
    if FLAGS.metagraphdef:
        # Serialized protos are bytes; reading them in text mode hands
        # ParseFromString a decoded str on Python 3, which it rejects.
        with gfile.GFile(FLAGS.metagraphdef, "rb") as meta_file:
            metagraph = meta_graph_pb2.MetaGraphDef()
            metagraph.ParseFromString(meta_file.read())
    else:
        is_text = FLAGS.graphdef.endswith(".pbtxt")
        mode = "r" if is_text else "rb"
        with gfile.GFile(FLAGS.graphdef, mode) as graph_file:
            graph_def = graph_pb2.GraphDef()
            if is_text:
                text_format.Merge(graph_file.read(), graph_def)
            else:
                graph_def.ParseFromString(graph_file.read())
        importer.import_graph_def(graph_def, name="")
        graph = ops.get_default_graph()
        # A single name behaves exactly as before; several comma-separated
        # names are all registered, mirroring get_metagraph().
        for fetch_name in FLAGS.fetch.split(","):
            graph.add_to_collection(
                "train_op", graph.get_operation_by_name(fetch_name))
        metagraph = saver.export_meta_graph(
            graph_def=graph.as_graph_def(), graph=graph)

    if FLAGS.rewriter_config is not None:
        rewriter_config = rewriter_config_pb2.RewriterConfig()
        text_format.Merge(FLAGS.rewriter_config, rewriter_config)
        optimized_graph = tf_optimizer.OptimizeGraph(rewriter_config, metagraph)
        metagraph.graph_def.CopyFrom(optimized_graph)

    report = cost_analyzer.GenerateCostReport(metagraph, FLAGS.per_node_report)
    print(report)
示例3: get_metagraph
def get_metagraph():
    """Constructs and returns a MetaGraphDef from the input file.

    Reads ``--metagraphdef`` if set, otherwise ``--graphdef``. Either file is
    parsed as text when its path ends in ``.pbtxt`` and as a binary proto
    otherwise. The comma-separated op names in ``--fetch`` are stored in the
    ``train_op`` collection. For a MetaGraphDef input this is optional and
    replaces any existing ``train_op`` collection; for a GraphDef input it is
    required.

    Returns:
      A ``meta_graph_pb2.MetaGraphDef``.

    Raises:
      ValueError: if ``--graphdef`` is used without ``--fetch``.
    """
    if FLAGS.metagraphdef:
        is_text = FLAGS.metagraphdef.endswith(".pbtxt")
        # Binary protos must be read as bytes; text mode would yield str on
        # Python 3, which ParseFromString rejects.
        mode = "r" if is_text else "rb"
        with gfile.GFile(FLAGS.metagraphdef, mode) as meta_file:
            metagraph = meta_graph_pb2.MetaGraphDef()
            if is_text:
                text_format.Merge(meta_file.read(), metagraph)
            else:
                metagraph.ParseFromString(meta_file.read())
        if FLAGS.fetch is not None:
            fetch_collection = meta_graph_pb2.CollectionDef()
            for fetch in FLAGS.fetch.split(","):
                fetch_collection.node_list.value.append(fetch)
            metagraph.collection_def["train_op"].CopyFrom(fetch_collection)
        return metagraph

    # Previously this crashed with an opaque AttributeError on None.split().
    if FLAGS.fetch is None:
        raise ValueError("--fetch is required when --graphdef is used.")

    is_text = FLAGS.graphdef.endswith(".pbtxt")
    mode = "r" if is_text else "rb"
    with gfile.GFile(FLAGS.graphdef, mode) as graph_file:
        graph_def = graph_pb2.GraphDef()
        if is_text:
            text_format.Merge(graph_file.read(), graph_def)
        else:
            graph_def.ParseFromString(graph_file.read())
    importer.import_graph_def(graph_def, name="")
    graph = ops.get_default_graph()
    for fetch in FLAGS.fetch.split(","):
        fetch_op = graph.get_operation_by_name(fetch)
        graph.add_to_collection("train_op", fetch_op)
    return saver.export_meta_graph(graph_def=graph.as_graph_def(), graph=graph)
示例4: testGradient
def testGradient(self):
    """Checks the layout optimizer rewrites convolutions and their grads to NCHW.

    Builds two conv2d layers plus a gradient-descent step in the default
    graph, exports it with the train op in the 'train_op' collection, runs
    Grappler with `optimize_tensor_layout=True`, and asserts that every
    Conv2D / Conv2DBackpropFilter / Conv2DBackpropInput node ends up NCHW.
    """
    if not test.is_gpu_available(cuda_only=True):
        self.skipTest('GPU required')

    random_seed.set_random_seed(0)
    x = random_ops.truncated_normal([1, 200, 200, 3], seed=0)
    y = conv_layers.conv2d(x, 32, [3, 3])
    z = conv_layers.conv2d(y, 32, [3, 3])
    optimizer = gradient_descent.GradientDescentOptimizer(1e-4)
    loss = math_ops.reduce_mean(z)
    train_op = optimizer.minimize(loss)

    graph = ops.get_default_graph()
    # Grappler reads its fetch nodes from the 'train_op' collection.
    graph.add_to_collection('train_op', train_op)
    meta_graph = saver_lib.export_meta_graph(graph_def=graph.as_graph_def())

    rewrite_options = rewriter_config_pb2.RewriterConfig(
        optimize_tensor_layout=True)
    optimized_graph = tf_optimizer.OptimizeGraph(rewrite_options, meta_graph)

    conv_ops = ('Conv2D', 'Conv2DBackpropFilter', 'Conv2DBackpropInput')
    found = 0
    for node in optimized_graph.node:
        if node.op in conv_ops:
            found += 1
            # AttrValue.s is a proto `bytes` field, so on Python 3 it must be
            # compared against bytes; a str literal would never be equal.
            self.assertEqual(node.attr['data_format'].s, b'NCHW')
    # Presumably 2 forward convs + 2 filter grads + 1 input grad (x is not a
    # variable, so the first conv needs no input gradient) -- TODO confirm.
    self.assertEqual(found, 5)
示例5: _run_inline_graph_optimization
def _run_inline_graph_optimization(func):
    """Apply function inline optimization to the graph.

    Returns the GraphDef after Grappler's function inlining optimization is
    applied. This optimization does not work on models with control flow.

    Args:
      func: ConcreteFunction.

    Returns:
      GraphDef
    """
    meta_graph = export_meta_graph(
        graph_def=func.graph.as_graph_def(), graph=func.graph)

    # Add a collection 'train_op' so that Grappler knows the outputs. Inputs
    # are included too so they are not pruned away.
    fetch_collection = meta_graph_pb2.CollectionDef()
    for array in func.inputs + func.outputs:
        fetch_collection.node_list.value.append(array.name)
    meta_graph.collection_def["train_op"].CopyFrom(fetch_collection)

    # Initialize RewriterConfig with only the function optimizer enabled.
    config = config_pb2.ConfigProto()
    rewrite_options = config.graph_options.rewrite_options
    # By default Grappler skips graphs below a minimum node count, which would
    # silently leave small functions un-inlined; a negative value disables
    # that skip.
    rewrite_options.min_graph_nodes = -1
    rewrite_options.optimizers.append("function")
    return tf_optimizer.OptimizeGraph(config, meta_graph)
示例6: _ExportAndImportGraph
def _ExportAndImportGraph(self, graph):
    """Export and import graph into a new graph.

    `graph`, including every collection it holds, is serialized to a
    MetaGraphDef and imported into a brand-new `ops.Graph`.

    Args:
      graph: The graph to copy.

    Returns:
      The new graph holding the imported copy.
    """
    all_keys = graph.get_all_collection_keys()
    exported = saver_lib.export_meta_graph(graph=graph, collection_list=all_keys)
    copied = ops.Graph()
    with copied.as_default():
        # The Saver (or None) returned by the import is not needed here.
        saver_lib.import_meta_graph(exported)
    return copied
示例7: _CopyGraph
def _CopyGraph(self, graph):
    """Return a copy of graph.

    The copy is made by a MetaGraphDef round trip that carries over all of
    `graph`'s collections.

    Args:
      graph: Source graph to duplicate.

    Returns:
      A new `ops.Graph` populated from the exported MetaGraphDef.
    """
    serialized = saver_lib.export_meta_graph(
        graph=graph,
        collection_list=graph.get_all_collection_keys(),
    )
    duplicate = ops.Graph()
    with duplicate.as_default():
        saver_lib.import_meta_graph(serialized)  # return value unused
    return duplicate
示例8: testMetagraph
def testMetagraph(self):
    """Exports a graph with a resource-variable Momentum step and re-imports it.

    NOTE(review): the second export is taken from the *original* `graph`, not
    from the freshly imported default graph, so the final equality is
    trivially true. What this really exercises is that import_meta_graph
    succeeds on this graph. Exporting ops.get_default_graph() instead may need
    adjustment, because the import can add nodes (e.g. a default Saver) --
    confirm before changing.
    """
    with ops.Graph().as_default():
        with variable_scope.variable_scope("foo", use_resource=True):
            a = variable_scope.get_variable("a", initializer=10.0)

        optimizer = momentum.MomentumOptimizer(learning_rate=0.001, momentum=0.1)
        optimizer.minimize(
            a,
            colocate_gradients_with_ops=True,
            global_step=training_util.get_or_create_global_step())

        graph = ops.get_default_graph()
        meta_graph_def = saver.export_meta_graph(graph=graph)

    with ops.Graph().as_default():
        saver.import_meta_graph(meta_graph_def, import_scope="")
        meta_graph_two = saver.export_meta_graph(graph=graph)

    self.assertEqual(meta_graph_def, meta_graph_two)
示例9: _convert_graph_def
def _convert_graph_def(self):
    """Convert the input GraphDef.

    Imports `self._input_graph_def` into a private graph and exports it, with
    shape attributes added, as `self._grappler_meta_graph_def`. It then calls
    `self._add_nodes_blacklist()` and `self._run_conversion()`.
    """
    input_graph = ops.Graph()
    with input_graph.as_default():
        importer.import_graph_def(self._input_graph_def, name="")

    shaped_graph_def = input_graph.as_graph_def(add_shapes=True)
    self._grappler_meta_graph_def = saver.export_meta_graph(
        graph_def=shaped_graph_def, graph=input_graph)

    self._add_nodes_blacklist()
    self._run_conversion()
示例10: setUp
def setUp(self):
    """Writes a variable-free MetaGraphDef into a temp "no_vars" directory.

    Builds `y = w * x - 7` with a single variable `w` (initialized to 3.0).
    It freezes the variables into constants with
    graph_util.convert_variables_to_constants, keeping output "y". The frozen
    GraphDef is then exported together with the "meta" collection to
    `<base_path>/constants.META_GRAPH_DEF_FILENAME`.
    """
    self.base_path = os.path.join(test.get_temp_dir(), "no_vars")
    if not os.path.exists(self.base_path):
        os.mkdir(self.base_path)

    with ops.Graph().as_default() as graph:
        x = array_ops.placeholder(dtypes.float32, name="x")
        w = variables.Variable(3.0)
        # Only needed for the node named "y" that it adds to the graph.
        math_ops.subtract(w * x, 7.0, name="y")
        ops.add_to_collection("meta", "this is meta")

        with self.session(graph=graph) as session:
            variables.global_variables_initializer().run()
            frozen_graph_def = graph_util.convert_variables_to_constants(
                session, graph.as_graph_def(), ["y"])

        # Exported while `graph` is still the default so that the "meta"
        # collection is picked up.
        meta_path = os.path.join(
            self.base_path, constants.META_GRAPH_DEF_FILENAME)
        saver.export_meta_graph(
            meta_path, graph_def=frozen_graph_def, collection_list=["meta"])
示例11: _simple_metagraph
def _simple_metagraph(depthwise=False):
    """Builds a two-layer conv training graph in the default graph and exports it.

    Args:
      depthwise: If True, use `conv_layers.separable_conv2d` for both layers;
        otherwise use `conv_layers.conv2d`.

    Returns:
      A MetaGraphDef of the default graph, whose 'train_op' collection holds
      the gradient-descent training op.
    """
    random_seed.set_random_seed(0)
    inputs = variables.Variable(
        random_ops.truncated_normal([1, 200, 200, 3], seed=0))
    if depthwise:
        layer = conv_layers.separable_conv2d
    else:
        layer = conv_layers.conv2d
    hidden = layer(inputs, 32, [3, 3])
    output = layer(hidden, 32, [3, 3])
    sgd = gradient_descent.GradientDescentOptimizer(1e-4)
    step = sgd.minimize(math_ops.reduce_mean(output))

    default_graph = ops.get_default_graph()
    # Grappler reads its fetch nodes from the 'train_op' collection.
    default_graph.add_to_collection('train_op', step)
    return saver_lib.export_meta_graph(graph_def=default_graph.as_graph_def())
示例12: test_meta_graph_transform
def test_meta_graph_transform(self):
    """strip_unused_nodes keeps only the `a * b` subgraph.

    The base graph has placeholders a, b, c and the products a*b and b*c.
    Transforming it with inputs ['a', 'b'], output 'mul:0' and tag 'tag_ab'
    must equal the MetaGraphDef of a graph built from just a, b and a*b. The
    expected proto is adjusted for the rewriter's known field changes.
    """
    with ops.Graph().as_default():
        with tf_session.Session(''):
            a = array_ops.placeholder(dtypes.int64, [1], name='a')
            b = array_ops.placeholder(dtypes.int64, [1], name='b')
            c = array_ops.placeholder(dtypes.int64, [1], name='c')
            _ = a * b
            _ = b * c
            base_meta_graph_def = saver.export_meta_graph()

    with ops.Graph().as_default():
        with tf_session.Session(''):
            a = array_ops.placeholder(dtypes.int64, [1], name='a')
            b = array_ops.placeholder(dtypes.int64, [1], name='b')
            _ = a * b
            meta_info_def = meta_graph_pb2.MetaGraphDef.MetaInfoDef()
            meta_info_def.tags.append('tag_ab')

            expected_meta_graph_def = saver.export_meta_graph(
                meta_info_def=meta_info_def)
            # Graph rewriter clears versions field, so we expect that.
            expected_meta_graph_def.graph_def.ClearField('versions')
            # Graph rewriter adds an empty library field, so we expect that.
            expected_meta_graph_def.graph_def.library.CopyFrom(
                function_pb2.FunctionDefLibrary())

    input_names = ['a', 'b']
    output_names = ['mul:0']
    transforms = ['strip_unused_nodes']
    tags = ['tag_ab']
    transformed_meta_graph_def = meta_graph_transform.meta_graph_transform(
        base_meta_graph_def, input_names, output_names, transforms, tags)
    self.assertEqual(expected_meta_graph_def, transformed_meta_graph_def)
示例13: _convert_saved_model_v2
def _convert_saved_model_v2(self):
    """Convert the input SavedModel in 2.0 format.

    Steps:
      1. Load the SavedModel and pick the signature named by
         `self._input_saved_model_signature_key`.
      2. Run `convert_to_constants.convert_variables_to_constants_v2` on it
         and export the resulting graph as `self._grappler_meta_graph_def`.
      3. Mark the signature's inputs/outputs as Grappler fetches and call
         `self._run_conversion()`.
      4. Import `self._converted_graph_def` (presumably produced by
         `_run_conversion()` -- TODO confirm) into a new FuncGraph. Wrap that
         graph in a ConcreteFunction stored in `self._converted_func`,
         copying calling-convention details from the original signature.
    """
    self._saved_model = load.load(self._input_saved_model_dir,
                                  self._input_saved_model_tags)
    func = self._saved_model.signatures[self._input_saved_model_signature_key]
    frozen_func = convert_to_constants.convert_variables_to_constants_v2(func)
    self._grappler_meta_graph_def = saver.export_meta_graph(
        graph_def=frozen_func.graph.as_graph_def(), graph=frozen_func.graph)

    # Add a collection 'train_op' so that Grappler knows the outputs.
    # NOTE(review): this loop, and the input/output/captured-input wiring
    # below, uses the original `func` rather than `frozen_func`. Confirm that
    # every name referenced here (including captured variable inputs) still
    # means the right thing in the frozen graph.
    fetch_collection = meta_graph_pb2.CollectionDef()
    for array in func.inputs + func.outputs:
        fetch_collection.node_list.value.append(array.name)
    self._grappler_meta_graph_def.collection_def["train_op"].CopyFrom(
        fetch_collection)

    # Run TRT optimizer in Grappler to convert the graph.
    self._run_conversion()

    def _get_tensor(graph, tensors):
        """Looks up same-named tensors in `graph` and copies static shapes.

        Args:
          graph: Graph to look tensors up in (by `tensor.name`).
          tensors: Tensors whose names and shapes are mirrored.

        Returns:
          A list of tensors from `graph`, in the same order as `tensors`.
        """
        new_tensors = []
        for tensor in tensors:
            new_tensor = graph.get_tensor_by_name(tensor.name)
            new_tensor.set_shape(tensor.shape)
            new_tensors.append(new_tensor)
        return new_tensors

    # TODO(laigd): do we need to use different name e.g. "trt_func_graph"?
    converted_graph = func_graph.FuncGraph(func.graph.name)
    with converted_graph.as_default():
        importer.import_graph_def(self._converted_graph_def, name="")

    # Mirror the original function's inputs/outputs and structured signature
    # onto the converted graph so it can be called the same way.
    converted_graph.inputs = _get_tensor(converted_graph, func.graph.inputs)
    converted_graph.outputs = _get_tensor(converted_graph, func.graph.outputs)
    converted_graph.structured_outputs = func.graph.structured_outputs
    converted_graph.structured_input_signature = (
        func.graph.structured_input_signature)

    # pylint: disable=protected-access
    # TODO(laigd): should we set up the signature as well?
    self._converted_func = function.ConcreteFunction(
        converted_graph, attrs=None, signature=None)
    self._converted_func.add_to_graph()
    # Copy private calling-convention state from the original signature; this
    # depends on ConcreteFunction internals and may break across TF versions.
    self._converted_func._arg_keywords = func._arg_keywords
    self._converted_func._num_positional_args = func._num_positional_args
    self._converted_func._captured_inputs = func._captured_inputs
    self._converted_func.graph.variables = func.graph.variables
示例14: grappler_optimize
def grappler_optimize(graph, fetches=None, rewriter_config=None):
    """Tries to optimize the provided graph using grappler.

    Args:
      graph: A `tf.Graph` instance containing the graph to optimize.
      fetches: An optional list of `Tensor`s to fetch (i.e. not optimize
        away). Grappler uses the 'train_op' collection to look for fetches,
        so if this is not provided that collection should be non-empty.
        When provided, each fetch is added to `graph`'s 'train_op'
        collection, so `graph` itself is modified.
      rewriter_config: An optional `RewriterConfig` proto to use when
        rewriting the graph. Defaults to an empty `RewriterConfig()`.

    Returns:
      A `GraphDef` containing the rewritten graph.
    """
    config = (rewriter_config_pb2.RewriterConfig()
              if rewriter_config is None else rewriter_config)
    if fetches is not None:
        for tensor in fetches:
            graph.add_to_collection('train_op', tensor)
    exported = saver.export_meta_graph(graph_def=graph.as_graph_def())
    return tf_optimizer.OptimizeGraph(config, exported)
示例15: _convert_graph_def
def _convert_graph_def(self):
    """Convert the input GraphDef.

    Imports `self._input_graph_def` into a fresh graph and exports it, with
    shape attributes added, as `self._grappler_meta_graph_def`. If
    `self._nodes_blacklist` is non-empty, its entries are recorded by name in
    the 'train_op' collection. Finally `self._run_conversion()` is called.

    Entries of `self._nodes_blacklist` may be node-name strings, Tensors
    (stored by tensor name, e.g. "x:0"), or Operations (stored by op name).
    """
    graph = ops.Graph()
    with graph.as_default():
        importer.import_graph_def(self._input_graph_def, name="")
    self._grappler_meta_graph_def = saver.export_meta_graph(
        graph_def=graph.as_graph_def(add_shapes=True), graph=graph)

    if self._nodes_blacklist:
        output_collection = meta_graph_pb2.CollectionDef()
        output_list = output_collection.node_list.value
        for entry in self._nodes_blacklist:
            # Operations previously fell through to _to_bytes(op) and failed;
            # both Tensors and Operations expose a usable `.name`.
            if isinstance(entry, (ops.Tensor, ops.Operation)):
                output_list.append(_to_bytes(entry.name))
            else:
                output_list.append(_to_bytes(entry))
        # TODO(laigd): use another key as the self._nodes_blacklist are really
        # not train_op.
        self._grappler_meta_graph_def.collection_def["train_op"].CopyFrom(
            output_collection)

    self._run_conversion()