当前位置: 首页>>代码示例>>Python>>正文


Python importer.import_graph_def函数代码示例

本文整理汇总了Python中tensorflow.python.framework.importer.import_graph_def函数的典型用法代码示例。如果您正苦于以下问题:Python import_graph_def函数的具体用法?Python import_graph_def怎么用?Python import_graph_def使用的例子?那么恭喜您, 这里精选的函数代码示例或许可以为您提供帮助。


在下文中一共展示了import_graph_def函数的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: main

def main(_):
  """Builds a MetaGraphDef from the input flags and prints its cost report.

  When `FLAGS.metagraphdef` is set, that file is parsed as a serialized
  MetaGraphDef. Otherwise `FLAGS.graphdef` is loaded (text format for paths
  ending in ".pbtxt", serialized bytes for anything else), imported into the
  default graph, and exported as a MetaGraphDef whose "train_op" collection
  holds the operation named by `FLAGS.fetch`. If `FLAGS.rewriter_config` is
  set, the graph is replaced by its optimized version before the report is
  generated.

  Args:
    _: Unused.
  """
  if FLAGS.metagraphdef:
    # ParseFromString needs the raw serialized bytes, so the file is opened in
    # binary mode rather than the default text mode.
    with gfile.GFile(FLAGS.metagraphdef, "rb") as meta_file:
      metagraph = meta_graph_pb2.MetaGraphDef()
      metagraph.ParseFromString(meta_file.read())
  else:
    # Text protos are read as text; serialized protos are read as bytes.
    is_text = FLAGS.graphdef.endswith(".pbtxt")
    with gfile.GFile(FLAGS.graphdef, "r" if is_text else "rb") as graph_file:
      graph_def = graph_pb2.GraphDef()
      if is_text:
        text_format.Merge(graph_file.read(), graph_def)
      else:
        graph_def.ParseFromString(graph_file.read())
    # The file is only needed for parsing, so it is closed before the import.
    importer.import_graph_def(graph_def, name="")
    graph = ops.get_default_graph()
    fetch = graph.get_operation_by_name(FLAGS.fetch)
    graph.add_to_collection("train_op", fetch)
    metagraph = saver.export_meta_graph(
        graph_def=graph.as_graph_def(), graph=graph)

  if FLAGS.rewriter_config is not None:
    rewriter_config = rewriter_config_pb2.RewriterConfig()
    text_format.Merge(FLAGS.rewriter_config, rewriter_config)
    optimized_graph = tf_optimizer.OptimizeGraph(rewriter_config, metagraph)
    metagraph.graph_def.CopyFrom(optimized_graph)

  report = cost_analyzer.GenerateCostReport(metagraph, FLAGS.per_node_report)
  print(report)
开发者ID:ChengYuXiang,项目名称:tensorflow,代码行数:27,代码来源:cost_analyzer_tool.py

示例2: testDefaultAttrsRemoved

  def testDefaultAttrsRemoved(self):
    """Checks import of an attr that only the producer op list declares."""
    op_list_text = """
      op {
        name: 'OpWithFutureDefaultAttr'
        attr { name: 'default_int' type: 'int' default_value { i: 456 } }
      }
    """
    producer_op_list = op_def_pb2.OpList()
    text_format.Merge(op_list_text, producer_op_list)

    # Case 1: the node carries the default value (456), so the attr is expected
    # to be stripped and looking it up must fail.
    with ops.Graph().as_default():
      default_valued = self._MakeGraphDef("""
          node { name: 'A' op: 'OpWithFutureDefaultAttr'
                 attr { key: 'default_int' value { i: 456 } } }
          """)
      imported = importer.import_graph_def(
          default_valued,
          return_elements=["A"],
          producer_op_list=producer_op_list)
      with self.assertRaisesRegexp(ValueError, "No attr named 'default_int'"):
        imported[0].get_attr("default_int")

    # Case 2: the node carries a non-default value (987), so the attr is
    # expected to survive the import unchanged.
    with ops.Graph().as_default():
      custom_valued = self._MakeGraphDef("""
          node { name: 'A' op: 'OpWithFutureDefaultAttr'
                 attr { key: 'default_int' value { i: 987 } } }
          """)
      imported = importer.import_graph_def(
          custom_valued,
          return_elements=["A"],
          producer_op_list=producer_op_list)
      self.assertEqual(987, imported[0].get_attr("default_int"))
开发者ID:pcm17,项目名称:tensorflow,代码行数:30,代码来源:importer_test.py

示例3: get_metagraph

def get_metagraph():
  """Constructs and returns a MetaGraphDef from the input file.

  If `FLAGS.metagraphdef` is set, that file is parsed as a MetaGraphDef and,
  when `FLAGS.fetch` is given, its "train_op" collection is replaced by the
  comma-separated names in `FLAGS.fetch`. Otherwise `FLAGS.graphdef` is parsed
  as a GraphDef, imported into the default graph, and exported as a
  MetaGraphDef with each fetch operation added to the "train_op" collection.
  In both cases a path ending in ".pbtxt" is parsed as text format and any
  other path as a serialized proto.

  Returns:
    The resulting MetaGraphDef.
  """
  if FLAGS.metagraphdef:
    # Text protos are read as text; serialized protos must be read as raw
    # bytes for ParseFromString, so the open mode follows the file extension.
    is_text = FLAGS.metagraphdef.endswith(".pbtxt")
    mode = "r" if is_text else "rb"
    with gfile.GFile(FLAGS.metagraphdef, mode) as meta_file:
      metagraph = meta_graph_pb2.MetaGraphDef()
      if is_text:
        text_format.Merge(meta_file.read(), metagraph)
      else:
        metagraph.ParseFromString(meta_file.read())
    if FLAGS.fetch is not None:
      fetch_collection = meta_graph_pb2.CollectionDef()
      for fetch in FLAGS.fetch.split(","):
        fetch_collection.node_list.value.append(fetch)
      metagraph.collection_def["train_op"].CopyFrom(fetch_collection)
  else:
    is_text = FLAGS.graphdef.endswith(".pbtxt")
    mode = "r" if is_text else "rb"
    with gfile.GFile(FLAGS.graphdef, mode) as graph_file:
      graph_def = graph_pb2.GraphDef()
      if is_text:
        text_format.Merge(graph_file.read(), graph_def)
      else:
        graph_def.ParseFromString(graph_file.read())
    # The file is only needed for parsing, so it is closed before the import.
    importer.import_graph_def(graph_def, name="")
    graph = ops.get_default_graph()
    # NOTE(review): unlike the MetaGraphDef branch, FLAGS.fetch is not checked
    # against None here, so a missing fetch flag fails on `.split` - confirm
    # whether the flag is meant to be mandatory for GraphDef input.
    for fetch in FLAGS.fetch.split(","):
      fetch_op = graph.get_operation_by_name(fetch)
      graph.add_to_collection("train_op", fetch_op)
    metagraph = saver.export_meta_graph(
        graph_def=graph.as_graph_def(), graph=graph)
  return metagraph
开发者ID:AndrewTwinz,项目名称:tensorflow,代码行数:29,代码来源:cost_analyzer_tool.py

示例4: testWithDeviceFunctionDependingOnInputs

  def testWithDeviceFunctionDependingOnInputs(self):
    """Checks that a device function can inspect op inputs during import."""
    # TODO(skyewm): make this work with C API
    if ops._USE_C_API:
      return

    # Build a source graph holding two constants plus an add, a subtract and
    # an identity derived from them.
    with ops.Graph().as_default() as source_graph:
      with ops.device("/job:ps"):
        lhs = constant_op.constant(1.0)
        rhs = constant_op.constant(1.0)
      _ = lhs + rhs
      _ = lhs - rhs
      _ = array_ops.identity(lhs)
    gdef = source_graph.as_graph_def()

    # The device function below records every op it sees with two inputs.
    two_input_ops = []

    def InputCounter(op):
      if len(op.inputs) == 2:
        two_input_ops.append(op)
      return ""

    with ops.Graph().as_default():
      with ops.device(InputCounter):
        importer.import_graph_def(gdef)

    # Only the add and the subtract take two inputs; the identity takes one.
    self.assertEqual(2, len(two_input_ops))
开发者ID:dansbecker,项目名称:tensorflow,代码行数:26,代码来源:importer_test.py

示例5: run_graph_def

def run_graph_def(graph_def, input_map, outputs):
  """Imports `graph_def` into a fresh graph and evaluates `outputs` in it.

  Args:
    graph_def: Graph definition handed to `importer.import_graph_def`.
    input_map: Passed to `Session.run` as `feed_dict`. The import itself is
      always given an empty input map.
    outputs: Fetches passed to `Session.run`.

  Returns:
    The value returned by `Session.run` for `outputs`.
  """
  fresh_graph = ops_lib.Graph()
  with fresh_graph.as_default():
    importer.import_graph_def(graph_def, input_map={}, name="")
  with session.Session(graph=fresh_graph) as sess:
    return sess.run(outputs, feed_dict=input_map)
开发者ID:AlbertXiebnu,项目名称:tensorflow,代码行数:7,代码来源:quantize_graph_test.py

示例6: testInvalidInputForInputMap

 def testInvalidInputForInputMap(self):
   """Exercises rejected and accepted `input_map` arguments."""
   # A list instead of a dict must be rejected with a TypeError.
   expected_type_error = ("input_map must be a dictionary mapping strings to "
                          "Tensor objects.")
   with ops.Graph().as_default():
     with self.assertRaises(TypeError) as type_error:
       empty_graph_def = self._MakeGraphDef("")
       list_input_map = [constant_op.constant(5.0)]
       importer.import_graph_def(empty_graph_def, input_map=list_input_map)
     self.assertEqual(expected_type_error, str(type_error.exception))

   graph_def = self._MakeGraphDef("""
        node { name: 'a' op: 'Placeholder'
               attr { key: 'dtype' value { type: DT_FLOAT } }}
        node { name: 'id' op: 'Identity' input: 'a:0'
               attr { key: 'T' value { type: DT_FLOAT } }}""")

   # Mapping to a Variable together with an empty `name` must be rejected
   # with a ValueError.
   expected_value_error_prefix = (
       "tf.import_graph_def() requires a non-empty `name` "
       "if `input_map` contains non-Tensor values.")
   with ops.Graph().as_default():
     with self.assertRaises(ValueError) as value_error:
       variable_input_map = {"a:0": variables.Variable(5.0)}
       importer.import_graph_def(
           graph_def, input_map=variable_input_map, name="")
     self.assertStartsWith(str(value_error.exception),
                           expected_value_error_prefix)

   # Mapping to a constant tensor is accepted, and the identity output then
   # evaluates to the constant's value.
   with ops.Graph().as_default():
     tensor_input_map = {"a:0": constant_op.constant(5.0)}
     [identity_out] = importer.import_graph_def(
         graph_def,
         input_map=tensor_input_map,
         name="",
         return_elements=["id:0"])
     with self.test_session():
       self.assertEqual(5.0, identity_out.eval())
开发者ID:pcm17,项目名称:tensorflow,代码行数:29,代码来源:importer_test.py

示例7: testImportGraphWithFunctionTwice

  def testImportGraphWithFunctionTwice(self):
    """Imports a function-calling graph twice and compares both outputs."""
    source_graph = ops.Graph()
    with source_graph.as_default():

      @function.Defun()
      def Add2(x, y):
        return math_ops.add(x, y)

      placeholder_x = array_ops.placeholder(dtype=dtypes.float32, name="x")
      placeholder_y = array_ops.placeholder(dtype=dtypes.float32, name="y")
      # pylint: disable=unexpected-keyword-arg
      _ = Add2(placeholder_x, placeholder_y, name="z")
      # pylint: enable=unexpected-keyword-arg

    gdef = source_graph.as_graph_def()

    # Both imports are fed by the same pair of random scalars.
    input_map = {
        "x:0": random_ops.random_uniform(dtype=dtypes.float32, shape=()),
        "y:0": random_ops.random_uniform(dtype=dtypes.float32, shape=()),
    }

    # Import the same GraphDef under two different name scopes.
    imported_outputs = []
    for scope_name in ("first", "second"):
      with ops.name_scope(scope_name):
        elements = importer.import_graph_def(
            gdef, return_elements=["z:0"], input_map=input_map)
        imported_outputs.append(elements[0])

    with self.test_session() as sess:
      first_val, second_val = sess.run(tuple(imported_outputs))
      self.assertAllEqual(first_val, second_val)
开发者ID:clsung,项目名称:tensorflow,代码行数:29,代码来源:importer_test.py

示例8: testNamePrefixColocationAttrsMultipleImport

  def testNamePrefixColocationAttrsMultipleImport(self):
    """Imports a colocated pair of nodes twice with an empty name prefix."""
    # TODO(skyewm): set uniquify_names
    if ops._USE_C_API:
      return

    original_graph_def = self._MakeGraphDef("""
          node { name: 'A' op: 'None' }
          node { name: 'B' op: 'None'  attr {
            key: '_class'
            value { list { s: 'loc:@A' } }
          } }""")

    # After the second import the nodes are expected under the names A_1/B_1,
    # with B_1's colocation attr pointing at A_1.
    expected_graph_text = """
          node { name: 'A' op: 'None' }
          node { name: 'B' op: 'None'  attr {
            key: '_class'
            value { list { s: 'loc:@A' } }
          } }
          node { name: 'A_1' op: 'None' }
          node { name: 'B_1' op: 'None'  attr {
            key: '_class'
            value { list { s: 'loc:@A_1' } }
          } }"""

    with ops.Graph().as_default():
      [first_b] = importer.import_graph_def(
          original_graph_def, return_elements=["B"], name="")
      [_] = importer.import_graph_def(
          original_graph_def, return_elements=["B"], name="")
      self.assertProtoEqualsVersion(expected_graph_text,
                                    first_b.graph.as_graph_def())
开发者ID:dansbecker,项目名称:tensorflow,代码行数:26,代码来源:importer_test.py

示例9: testMissingInputOpInGraphDef

 def testMissingInputOpInGraphDef(self):
   """Import fails with ValueError when input tensor 'A:0' is missing."""
   with ops.Graph().as_default():
     with self.assertRaises(ValueError) as e:
       importer.import_graph_def(
           self._MakeGraphDef("""
           node { name: 'B' op: 'If' input: 'A:0' }
           """))
     # assertIn shows the expected substring and the actual message when the
     # check fails, unlike assertTrue on an `in` expression.
     self.assertIn("Input tensor 'A:0' not found", str(e.exception))
开发者ID:pcm17,项目名称:tensorflow,代码行数:8,代码来源:importer_test.py

示例10: testInvalidTensorNameInGraphDef

 def testInvalidTensorNameInGraphDef(self):
   """Import fails with ValueError for the input name 'A:B:0'."""
   expected_error = "Node 'B': Unknown input node 'A:B:0'"
   with ops.Graph().as_default():
     with self.assertRaisesRegexp(ValueError, expected_error):
       bad_graph_def = self._MakeGraphDef("""
           node { name: 'B' op: 'None' input: 'A:B:0' }
           """)
       importer.import_graph_def(bad_graph_def)
开发者ID:clsung,项目名称:tensorflow,代码行数:8,代码来源:importer_test.py

示例11: testMissingControlInputInGraphDef

 def testMissingControlInputInGraphDef(self):
   """Import fails with ValueError when control input '^A' is undefined."""
   # The caret is escaped because the expected message is used as a regex.
   expected_error = r"Node 'B': Unknown input node '\^A'"
   with ops.Graph().as_default():
     with self.assertRaisesRegexp(ValueError, expected_error):
       graph_with_dangling_control = self._MakeGraphDef("""
           node { name: 'B' op: 'None' input: '^A' }
           """)
       importer.import_graph_def(graph_with_dangling_control)
开发者ID:clsung,项目名称:tensorflow,代码行数:8,代码来源:importer_test.py

示例12: testMissingInputOpInGraphDef

 def testMissingInputOpInGraphDef(self):
   """Import fails with ValueError when input 'A:0' has no producing node."""
   expected_error = "Node 'B': Unknown input node 'A:0'"
   with ops.Graph().as_default():
     with self.assertRaisesRegexp(ValueError, expected_error):
       graph_with_dangling_input = self._MakeGraphDef("""
           node { name: 'B' op: 'FloatInput' input: 'A:0' }
           """)
       importer.import_graph_def(graph_with_dangling_input)
开发者ID:clsung,项目名称:tensorflow,代码行数:8,代码来源:importer_test.py

示例13: testVersionHigh

 def testVersionHigh(self):
   """Import fails when the GraphDef's min consumer version is 1 << 30."""
   min_consumer = 1 << 30
   expected_regex = (
       r"GraphDef min consumer version %d above current version %d "
       r"for TensorFlow \S+\.  Please upgrade TensorFlow\.$" %
       (min_consumer, versions.GRAPH_DEF_VERSION))
   with ops.Graph().as_default():
     with self.assertRaisesRegexp(ValueError, expected_regex):
       too_new_graph_def = self._MakeGraphDef("", min_consumer=min_consumer)
       importer.import_graph_def(too_new_graph_def)
开发者ID:clsung,项目名称:tensorflow,代码行数:8,代码来源:importer_test.py

示例14: testVersionLow

 def testVersionLow(self):
   """Import fails when the GraphDef's producer version is -1."""
   expected_regex = (
       r"GraphDef producer version -1 below min producer %d supported "
       r"by TensorFlow \S+\.  Please regenerate your graph.$" %
       versions.GRAPH_DEF_VERSION_MIN_PRODUCER)
   with ops.Graph().as_default():
     with self.assertRaisesRegexp(Exception, expected_regex):
       too_old_graph_def = self._MakeGraphDef("", producer=-1)
       importer.import_graph_def(too_old_graph_def)
开发者ID:clsung,项目名称:tensorflow,代码行数:8,代码来源:importer_test.py

示例15: testMissingControlInputInGraphDef

 def testMissingControlInputInGraphDef(self):
   """Import fails with ValueError when control input '^A' is missing."""
   with ops.Graph().as_default():
     with self.assertRaises(ValueError) as e:
       importer.import_graph_def(
           self._MakeGraphDef("""
           node { name: 'B' op: 'None' input: '^A' }
           """))
     # assertIn shows the expected substring and the actual message when the
     # check fails, unlike assertTrue on an `in` expression.
     self.assertIn("Control input '^A' not found", str(e.exception))
开发者ID:pcm17,项目名称:tensorflow,代码行数:8,代码来源:importer_test.py


注:本文中的tensorflow.python.framework.importer.import_graph_def函数示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。