本文整理汇总了Python中tensorflow.python.framework.importer.import_graph_def函数的典型用法代码示例。如果您正苦于以下问题：Python import_graph_def函数的具体用法？Python import_graph_def怎么用？Python import_graph_def使用的例子？那么恭喜您，这里精选的函数代码示例或许可以为您提供帮助。
在下文中一共展示了import_graph_def函数的15个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: main
def main(_):
    """Build a MetaGraphDef from the flag-selected input and print its cost report.

    When ``FLAGS.metagraphdef`` is set, the MetaGraphDef is parsed from that
    file. Otherwise a GraphDef is read from ``FLAGS.graphdef`` (text format
    for ``.pbtxt`` files, serialized binary for anything else), imported into
    the default graph, and the operation named by ``FLAGS.fetch`` is added to
    the "train_op" collection before the graph is exported as a MetaGraphDef.
    If ``FLAGS.rewriter_config`` is given, the graph is optimized with it
    before the report is generated.

    Args:
        _: Unused positional argument.
    """
    if FLAGS.metagraphdef:
        # A serialized proto is binary data: read it as bytes so that
        # ParseFromString receives the raw payload instead of decoded text.
        with gfile.GFile(FLAGS.metagraphdef, "rb") as meta_file:
            metagraph = meta_graph_pb2.MetaGraphDef()
            metagraph.ParseFromString(meta_file.read())
    else:
        is_text_proto = FLAGS.graphdef.endswith(".pbtxt")
        graph_def = graph_pb2.GraphDef()
        # Text protos keep the default text mode; binary protos need bytes.
        with gfile.GFile(FLAGS.graphdef, "r" if is_text_proto else "rb") as graph_file:
            if is_text_proto:
                text_format.Merge(graph_file.read(), graph_def)
            else:
                graph_def.ParseFromString(graph_file.read())
        importer.import_graph_def(graph_def, name="")
        graph = ops.get_default_graph()
        fetch = graph.get_operation_by_name(FLAGS.fetch)
        graph.add_to_collection("train_op", fetch)
        metagraph = saver.export_meta_graph(
            graph_def=graph.as_graph_def(), graph=graph)

    if FLAGS.rewriter_config is not None:
        rewriter_config = rewriter_config_pb2.RewriterConfig()
        text_format.Merge(FLAGS.rewriter_config, rewriter_config)
        optimized_graph = tf_optimizer.OptimizeGraph(rewriter_config, metagraph)
        # Replace the graph inside the MetaGraphDef with the optimized one.
        metagraph.graph_def.CopyFrom(optimized_graph)

    report = cost_analyzer.GenerateCostReport(metagraph, FLAGS.per_node_report)
    print(report)
示例2: testDefaultAttrsRemoved
def testDefaultAttrsRemoved(self):
    """Exercise import of an attr that is declared only in the producer op list.

    The same node is imported twice: once with the attr set to the default
    declared in the producer op list, and once with a different value.
    """
    producer_ops = op_def_pb2.OpList()
    text_format.Merge("""
op {
name: 'OpWithFutureDefaultAttr'
attr { name: 'default_int' type: 'int' default_value { i: 456 } }
}
""", producer_ops)

    # Case 1: the attr carries the producer's default value, so it is dropped
    # on import and looking it up afterwards must fail.
    with ops.Graph().as_default():
        default_valued_def = self._MakeGraphDef("""
node { name: 'A' op: 'OpWithFutureDefaultAttr'
attr { key: 'default_int' value { i: 456 } } }
""")
        imported = importer.import_graph_def(
            default_valued_def,
            return_elements=["A"],
            producer_op_list=producer_ops)
        node_a = imported[0]
        with self.assertRaisesRegexp(ValueError, "No attr named 'default_int'"):
            node_a.get_attr("default_int")

    # Case 2: the attr carries a non-default value, so it survives the import.
    with ops.Graph().as_default():
        custom_valued_def = self._MakeGraphDef("""
node { name: 'A' op: 'OpWithFutureDefaultAttr'
attr { key: 'default_int' value { i: 987 } } }
""")
        imported = importer.import_graph_def(
            custom_valued_def,
            return_elements=["A"],
            producer_op_list=producer_ops)
        node_a = imported[0]
        self.assertEqual(987, node_a.get_attr("default_int"))
示例3: get_metagraph
def get_metagraph():
    """Constructs and returns a MetaGraphDef from the input file.

    When ``FLAGS.metagraphdef`` is set, the MetaGraphDef is read from that
    file (text format for ``.pbtxt``, serialized binary otherwise) and, if
    ``FLAGS.fetch`` is given, its comma-separated node names replace the
    "train_op" collection. Otherwise a GraphDef is read from
    ``FLAGS.graphdef``, imported into the default graph, every operation named
    in ``FLAGS.fetch`` is added to the "train_op" collection, and the graph is
    exported as a MetaGraphDef.

    Returns:
        The resulting MetaGraphDef.
    """
    if FLAGS.metagraphdef:
        is_text_proto = FLAGS.metagraphdef.endswith(".pbtxt")
        metagraph = meta_graph_pb2.MetaGraphDef()
        # Text protos keep the default text mode; serialized protos are binary
        # and must be read as bytes for ParseFromString.
        with gfile.GFile(FLAGS.metagraphdef,
                         "r" if is_text_proto else "rb") as meta_file:
            if is_text_proto:
                text_format.Merge(meta_file.read(), metagraph)
            else:
                metagraph.ParseFromString(meta_file.read())
        if FLAGS.fetch is not None:
            fetch_collection = meta_graph_pb2.CollectionDef()
            fetch_collection.node_list.value.extend(FLAGS.fetch.split(","))
            metagraph.collection_def["train_op"].CopyFrom(fetch_collection)
    else:
        is_text_proto = FLAGS.graphdef.endswith(".pbtxt")
        graph_def = graph_pb2.GraphDef()
        with gfile.GFile(FLAGS.graphdef,
                         "r" if is_text_proto else "rb") as graph_file:
            if is_text_proto:
                text_format.Merge(graph_file.read(), graph_def)
            else:
                graph_def.ParseFromString(graph_file.read())
        importer.import_graph_def(graph_def, name="")
        graph = ops.get_default_graph()
        # NOTE(review): unlike the MetaGraphDef branch, this path does not
        # guard against FLAGS.fetch being None -- confirm the flag is
        # mandatory when a GraphDef is supplied.
        for fetch in FLAGS.fetch.split(","):
            fetch_op = graph.get_operation_by_name(fetch)
            graph.add_to_collection("train_op", fetch_op)
        metagraph = saver.export_meta_graph(
            graph_def=graph.as_graph_def(), graph=graph)
    return metagraph
示例4: testWithDeviceFunctionDependingOnInputs
def testWithDeviceFunctionDependingOnInputs(self):
    """Import a graph under a device function that inspects each op's inputs."""
    if ops._USE_C_API:
        return  # TODO(skyewm): make this work with C API

    # Build a source graph holding two binary ops and one unary op.
    with ops.Graph().as_default() as source_graph:
        with ops.device("/job:ps"):
            lhs = constant_op.constant(1.0)
            rhs = constant_op.constant(1.0)
            _ = lhs + rhs
            _ = lhs - rhs
            _ = array_ops.identity(lhs)
    source_def = source_graph.as_graph_def()

    # The device function below records every op that has exactly two inputs.
    binary_ops = []

    def InputCounter(op):
        if len(op.inputs) == 2:
            binary_ops.append(op)
        return ""

    with ops.Graph().as_default():
        with ops.device(InputCounter):
            importer.import_graph_def(source_def)

    # The add and the subtract should be recorded; the identity should not.
    self.assertEqual(2, len(binary_ops))
示例5: run_graph_def
def run_graph_def(graph_def, input_map, outputs):
    """Import `graph_def` into a fresh graph and evaluate `outputs` in a session.

    Args:
        graph_def: Graph definition handed to `importer.import_graph_def`.
        input_map: Passed to `Session.run` as `feed_dict`. It is not forwarded
            to the import call, which always receives an empty input map.
        outputs: Fetches handed to `Session.run`.

    Returns:
        Whatever `Session.run` returns for `outputs`.
    """
    target_graph = ops_lib.Graph()
    with target_graph.as_default():
        importer.import_graph_def(graph_def, input_map={}, name="")
    with session.Session(graph=target_graph) as sess:
        return sess.run(outputs, feed_dict=input_map)
示例6: testInvalidInputForInputMap
def testInvalidInputForInputMap(self):
    """Cover rejected and accepted `input_map` arguments of import_graph_def."""
    # A list instead of a dict must be rejected with a TypeError.
    with ops.Graph().as_default():
        with self.assertRaises(TypeError) as type_ctx:
            importer.import_graph_def(
                self._MakeGraphDef(""), input_map=[constant_op.constant(5.0)])
        expected_type_error = ("input_map must be a dictionary mapping strings to "
                               "Tensor objects.")
        self.assertEqual(expected_type_error, str(type_ctx.exception))

    placeholder_identity_def = self._MakeGraphDef("""
node { name: 'a' op: 'Placeholder'
attr { key: 'dtype' value { type: DT_FLOAT } }}
node { name: 'id' op: 'Identity' input: 'a:0'
attr { key: 'T' value { type: DT_FLOAT } }}""")

    # Mapping to a Variable together with an empty `name` must raise ValueError.
    with ops.Graph().as_default():
        with self.assertRaises(ValueError) as value_ctx:
            importer.import_graph_def(
                placeholder_identity_def,
                input_map={"a:0": variables.Variable(5.0)},
                name="")
        expected_prefix = ("tf.import_graph_def() requires a non-empty `name` "
                           "if `input_map` contains non-Tensor values.")
        self.assertStartsWith(str(value_ctx.exception), expected_prefix)

    # Mapping to a constant tensor with an empty `name` works, and the mapped
    # value flows through the identity op.
    with ops.Graph().as_default():
        [identity_out] = importer.import_graph_def(
            placeholder_identity_def,
            input_map={"a:0": constant_op.constant(5.0)},
            name="",
            return_elements=["id:0"])
        with self.test_session():
            self.assertEqual(5.0, identity_out.eval())
示例7: testImportGraphWithFunctionTwice
def testImportGraphWithFunctionTwice(self):
    """Import one function-bearing GraphDef twice and compare both outputs."""
    source_graph = ops.Graph()
    with source_graph.as_default():

        @function.Defun()
        def Add2(x, y):
            return math_ops.add(x, y)

        x_in = array_ops.placeholder(dtype=dtypes.float32, name="x")
        y_in = array_ops.placeholder(dtype=dtypes.float32, name="y")
        _ = Add2(x_in, y_in, name="z")  # pylint: disable=unexpected-keyword-arg

    source_def = source_graph.as_graph_def()

    # Both imports share the same replacement tensors for the placeholders.
    rand_x = random_ops.random_uniform(dtype=dtypes.float32, shape=())
    rand_y = random_ops.random_uniform(dtype=dtypes.float32, shape=())
    replacements = {"x:0": rand_x, "y:0": rand_y}

    with ops.name_scope("first"):
        first_results = importer.import_graph_def(
            source_def, return_elements=["z:0"], input_map=replacements)
        z1 = first_results[0]

    with ops.name_scope("second"):
        second_results = importer.import_graph_def(
            source_def, return_elements=["z:0"], input_map=replacements)
        z2 = second_results[0]

    with self.test_session() as sess:
        first_val, second_val = sess.run((z1, z2))
        self.assertAllEqual(first_val, second_val)
示例8: testNamePrefixColocationAttrsMultipleImport
def testNamePrefixColocationAttrsMultipleImport(self):
    """Import a colocated pair of nodes twice with an empty name prefix.

    The expected proto shows the second copy renamed to 'A_1'/'B_1', with the
    colocation attr of 'B_1' pointing at 'A_1'.
    """
    if ops._USE_C_API:
        return  # TODO(skyewm): set uniquify_names

    colocated_def = self._MakeGraphDef("""
node { name: 'A' op: 'None' }
node { name: 'B' op: 'None' attr {
key: '_class'
value { list { s: 'loc:@A' } }
} }""")

    expected_after_two_imports = """
node { name: 'A' op: 'None' }
node { name: 'B' op: 'None' attr {
key: '_class'
value { list { s: 'loc:@A' } }
} }
node { name: 'A_1' op: 'None' }
node { name: 'B_1' op: 'None' attr {
key: '_class'
value { list { s: 'loc:@A_1' } }
} }"""

    with ops.Graph().as_default():
        [first_b] = importer.import_graph_def(
            colocated_def, return_elements=["B"], name="")
        [_] = importer.import_graph_def(
            colocated_def, return_elements=["B"], name="")
        self.assertProtoEqualsVersion(
            expected_after_two_imports, first_b.graph.as_graph_def())
示例9: testMissingInputOpInGraphDef
def testMissingInputOpInGraphDef(self):
    """Importing a node whose input tensor 'A:0' is undefined raises ValueError."""
    with ops.Graph().as_default():
        with self.assertRaises(ValueError) as e:
            importer.import_graph_def(
                self._MakeGraphDef("""
node { name: 'B' op: 'If' input: 'A:0' }
"""))
        # assertIn reports both the expected fragment and the actual message
        # on failure, which assertTrue(x in y) does not.
        self.assertIn("Input tensor 'A:0' not found", str(e.exception))
示例10: testInvalidTensorNameInGraphDef
def testInvalidTensorNameInGraphDef(self):
    """A node with the malformed input name 'A:B:0' is rejected with ValueError."""
    expected_error = "Node 'B': Unknown input node 'A:B:0'"
    with ops.Graph().as_default():
        with self.assertRaisesRegexp(ValueError, expected_error):
            malformed_input_def = self._MakeGraphDef("""
node { name: 'B' op: 'None' input: 'A:B:0' }
""")
            importer.import_graph_def(malformed_input_def)
示例11: testMissingControlInputInGraphDef
def testMissingControlInputInGraphDef(self):
    """A control input '^A' naming an undefined node is rejected with ValueError."""
    expected_error = r"Node 'B': Unknown input node '\^A'"
    with ops.Graph().as_default():
        with self.assertRaisesRegexp(ValueError, expected_error):
            dangling_control_def = self._MakeGraphDef("""
node { name: 'B' op: 'None' input: '^A' }
""")
            importer.import_graph_def(dangling_control_def)
示例12: testMissingInputOpInGraphDef
def testMissingInputOpInGraphDef(self):
    """A data input 'A:0' naming an undefined node is rejected with ValueError."""
    expected_error = "Node 'B': Unknown input node 'A:0'"
    with ops.Graph().as_default():
        with self.assertRaisesRegexp(ValueError, expected_error):
            dangling_input_def = self._MakeGraphDef("""
node { name: 'B' op: 'FloatInput' input: 'A:0' }
""")
            importer.import_graph_def(dangling_input_def)
示例13: testVersionHigh
def testVersionHigh(self):
    """A GraphDef whose min consumer version is 1 << 30 is rejected."""
    required_consumer = 1 << 30
    expected_error = (
        r"GraphDef min consumer version %d above current version %d "
        r"for TensorFlow \S+\. Please upgrade TensorFlow\.$"
    ) % (required_consumer, versions.GRAPH_DEF_VERSION)
    with ops.Graph().as_default():
        with self.assertRaisesRegexp(ValueError, expected_error):
            too_new_def = self._MakeGraphDef("", min_consumer=required_consumer)
            importer.import_graph_def(too_new_def)
示例14: testVersionLow
def testVersionLow(self):
    """A GraphDef with producer version -1 is rejected."""
    expected_error = (
        r"GraphDef producer version -1 below min producer %d supported "
        r"by TensorFlow \S+\. Please regenerate your graph.$"
    ) % versions.GRAPH_DEF_VERSION_MIN_PRODUCER
    with ops.Graph().as_default():
        with self.assertRaisesRegexp(Exception, expected_error):
            too_old_def = self._MakeGraphDef("", producer=-1)
            importer.import_graph_def(too_old_def)
示例15: testMissingControlInputInGraphDef
def testMissingControlInputInGraphDef(self):
    """Importing a node whose control input '^A' is undefined raises ValueError."""
    with ops.Graph().as_default():
        with self.assertRaises(ValueError) as e:
            importer.import_graph_def(
                self._MakeGraphDef("""
node { name: 'B' op: 'None' input: '^A' }
"""))
        # assertIn reports both the expected fragment and the actual message
        # on failure, which assertTrue(x in y) does not.
        self.assertIn("Control input '^A' not found", str(e.exception))