当前位置: 首页>>代码示例>>Python>>正文


Python saver.export_meta_graph函数代码示例

本文整理汇总了Python中tensorflow.python.training.saver.export_meta_graph函数的典型用法代码示例。如果您正苦于以下问题：Python export_meta_graph函数的具体用法？Python export_meta_graph怎么用？Python export_meta_graph使用的例子？那么恭喜您，这里精选的函数代码示例或许可以为您提供帮助。


在下文中一共展示了export_meta_graph函数的15个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: testNoVariables

  def testNoVariables(self):
    """Round-trips a graph with no variables through a MetaGraphDef file.

    Builds `input + 42`, exports it with the module-level
    `saver_module.export_meta_graph`, imports the file into a fresh graph, and
    checks that no Saver is returned, that re-exporting yields an equal proto,
    and that the imported graph computes the same output.
    """
    test_dir = _TestDir("no_variables")
    filename = os.path.join(test_dir, "metafile")

    input_feed_value = -10  # Arbitrary input value for feed_dict.

    orig_graph = tf.Graph()
    with self.test_session(graph=orig_graph) as sess:
      # Create a minimal graph with zero variables.
      input_tensor = tf.placeholder(tf.float32, shape=[], name="input")
      offset = tf.constant(42, dtype=tf.float32, name="offset")
      output_tensor = tf.add(input_tensor, offset, name="add_offset")

      # Add input and output tensors to graph collections so they can be
      # looked up again after the import below.
      tf.add_to_collection("input_tensor", input_tensor)
      tf.add_to_collection("output_tensor", output_tensor)

      # Expected value: -10 + 42 == 32.
      output_value = sess.run(output_tensor, {input_tensor: input_feed_value})
      self.assertEqual(output_value, 32)

      # Generates MetaGraphDef.
      #
      # Note that this is calling the saver *module-level* export_meta_graph and
      # not the Saver.export_meta_graph instance-level method. Only the two
      # named collections are exported, and no SaverDef is attached.
      meta_graph_def = saver_module.export_meta_graph(
          filename=filename,
          graph_def=tf.get_default_graph().as_graph_def(),
          collection_list=["input_tensor", "output_tensor"],
          saver_def=None,
      )

    # Create a clean graph and import the MetaGraphDef nodes.
    new_graph = tf.Graph()
    with self.test_session(graph=new_graph) as sess:
      # Import the previously exported meta graph.
      saver_instance = saver_module.import_meta_graph(filename)
      # The saver instance should be None since there are no graph variables
      # to be restored in this case.
      self.assertIsNone(saver_instance)

      # Re-exports the current graph state for comparison to the original.
      new_meta_graph_def = saver_module.export_meta_graph(filename + "_new")
      self.assertProtoEquals(meta_graph_def, new_meta_graph_def)

      # Ensures that we can still get a reference to our graph collections.
      new_input_tensor = tf.get_collection("input_tensor")[0]
      new_output_tensor = tf.get_collection("output_tensor")[0]
      # Verifies that the new graph computes the same result as the original.
      new_output_value = sess.run(
          new_output_tensor, {new_input_tensor: input_feed_value})
      self.assertEqual(new_output_value, output_value)
开发者ID:2er0,项目名称:tensorflow,代码行数:51,代码来源:saver_test.py

示例2: main

def main(_):
  """Prints a Grappler cost report for the input graph.

  The graph comes from `FLAGS.metagraphdef` when set, otherwise from
  `FLAGS.graphdef`. In the GraphDef case, the file is imported into the default
  graph. The op(s) named by `FLAGS.fetch` (a comma-separated list; a single name
  also works) are then added to the "train_op" collection, and the graph is
  wrapped in a MetaGraphDef. Input files ending in ".pbtxt" are parsed as text
  protos; all others are parsed as binary protos.

  If `FLAGS.rewriter_config` is set, it is parsed as a text-format
  RewriterConfig and used to optimize the graph before it is costed.

  Raises:
    ValueError: if a GraphDef is given without `FLAGS.fetch`.
  """
  if FLAGS.metagraphdef:
    metagraph = meta_graph_pb2.MetaGraphDef()
    if FLAGS.metagraphdef.endswith(".pbtxt"):
      with gfile.GFile(FLAGS.metagraphdef) as meta_file:
        text_format.Merge(meta_file.read(), metagraph)
    else:
      # Serialized protos are raw bytes. In text mode GFile returns decoded
      # text, which ParseFromString rejects on Python 3, so read in binary mode.
      with gfile.GFile(FLAGS.metagraphdef, "rb") as meta_file:
        metagraph.ParseFromString(meta_file.read())
  else:
    if not FLAGS.fetch:
      raise ValueError("A fetch op name is required when reading a GraphDef.")
    graph_def = graph_pb2.GraphDef()
    if FLAGS.graphdef.endswith(".pbtxt"):
      with gfile.GFile(FLAGS.graphdef) as graph_file:
        text_format.Merge(graph_file.read(), graph_def)
    else:
      with gfile.GFile(FLAGS.graphdef, "rb") as graph_file:
        graph_def.ParseFromString(graph_file.read())
    importer.import_graph_def(graph_def, name="")
    graph = ops.get_default_graph()
    # Grappler treats the "train_op" collection as its list of fetch nodes.
    # Op names cannot contain commas, so splitting on "," is unambiguous.
    for fetch_name in FLAGS.fetch.split(","):
      fetch_op = graph.get_operation_by_name(fetch_name)
      graph.add_to_collection("train_op", fetch_op)
    metagraph = saver.export_meta_graph(
        graph_def=graph.as_graph_def(), graph=graph)

  if FLAGS.rewriter_config is not None:
    rewriter_config = rewriter_config_pb2.RewriterConfig()
    text_format.Merge(FLAGS.rewriter_config, rewriter_config)
    optimized_graph = tf_optimizer.OptimizeGraph(rewriter_config, metagraph)
    metagraph.graph_def.CopyFrom(optimized_graph)

  report = cost_analyzer.GenerateCostReport(metagraph, FLAGS.per_node_report)
  print(report)
开发者ID:ChengYuXiang,项目名称:tensorflow,代码行数:27,代码来源:cost_analyzer_tool.py

示例3: get_metagraph

def get_metagraph():
  """Constructs and returns a MetaGraphDef from the input file.

  If `FLAGS.metagraphdef` is set, that file is parsed. When `FLAGS.fetch` is
  also given, its comma-separated node names replace the "train_op"
  collection. Otherwise `FLAGS.graphdef` is parsed, imported into the default
  graph, and the ops named in `FLAGS.fetch` are added to "train_op" before the
  graph is exported. Files ending in ".pbtxt" are read as text protos; all
  others are read as binary protos.

  Returns:
    A `meta_graph_pb2.MetaGraphDef`.

  Raises:
    ValueError: if a GraphDef input is used without `FLAGS.fetch`.
  """
  if FLAGS.metagraphdef:
    metagraph = meta_graph_pb2.MetaGraphDef()
    if FLAGS.metagraphdef.endswith(".pbtxt"):
      with gfile.GFile(FLAGS.metagraphdef) as meta_file:
        text_format.Merge(meta_file.read(), metagraph)
    else:
      # Serialized protos are raw bytes. In text mode GFile returns decoded
      # text, which ParseFromString rejects on Python 3, so read in binary mode.
      with gfile.GFile(FLAGS.metagraphdef, "rb") as meta_file:
        metagraph.ParseFromString(meta_file.read())
    if FLAGS.fetch is not None:
      # Grappler uses the "train_op" collection as its fetch list.
      fetch_collection = meta_graph_pb2.CollectionDef()
      for fetch in FLAGS.fetch.split(","):
        fetch_collection.node_list.value.append(fetch)
      metagraph.collection_def["train_op"].CopyFrom(fetch_collection)
  else:
    # A bare GraphDef carries no collections, so the fetches must come from
    # the flag; fail clearly instead of with an AttributeError on None.
    if FLAGS.fetch is None:
      raise ValueError("--fetch is required when reading a GraphDef.")
    graph_def = graph_pb2.GraphDef()
    if FLAGS.graphdef.endswith(".pbtxt"):
      with gfile.GFile(FLAGS.graphdef) as graph_file:
        text_format.Merge(graph_file.read(), graph_def)
    else:
      with gfile.GFile(FLAGS.graphdef, "rb") as graph_file:
        graph_def.ParseFromString(graph_file.read())
    importer.import_graph_def(graph_def, name="")
    graph = ops.get_default_graph()
    for fetch in FLAGS.fetch.split(","):
      fetch_op = graph.get_operation_by_name(fetch)
      graph.add_to_collection("train_op", fetch_op)
    metagraph = saver.export_meta_graph(
        graph_def=graph.as_graph_def(), graph=graph)
  return metagraph
开发者ID:AndrewTwinz,项目名称:tensorflow,代码行数:29,代码来源:cost_analyzer_tool.py

示例4: testGradient

  def testGradient(self):
    """Checks that layout optimization rewrites conv ops to NCHW on GPU.

    Builds two stacked conv2d layers trained with gradient descent, runs
    Grappler with `optimize_tensor_layout=True`, and expects every Conv2D /
    Conv2DBackpropFilter / Conv2DBackpropInput node (5 in total) to use NCHW.
    Skipped when no CUDA GPU is available.
    """
    if not test.is_gpu_available(cuda_only=True):
      self.skipTest('GPU required')

    # Fixed seeds make the random input reproducible.
    random_seed.set_random_seed(0)
    x = random_ops.truncated_normal([1, 200, 200, 3], seed=0)
    y = conv_layers.conv2d(x, 32, [3, 3])
    z = conv_layers.conv2d(y, 32, [3, 3])
    optimizer = gradient_descent.GradientDescentOptimizer(1e-4)
    loss = math_ops.reduce_mean(z)
    train_op = optimizer.minimize(loss)
    graph = ops.get_default_graph()
    # Grappler reads its fetch nodes from the 'train_op' collection.
    graph.add_to_collection('train_op', train_op)
    meta_graph = saver_lib.export_meta_graph(graph_def=graph.as_graph_def())

    rewrite_options = rewriter_config_pb2.RewriterConfig(
        optimize_tensor_layout=True)
    optimized_graph = tf_optimizer.OptimizeGraph(rewrite_options, meta_graph)

    found = 0
    for node in optimized_graph.node:
      if node.op in ['Conv2D', 'Conv2DBackpropFilter', 'Conv2DBackpropInput']:
        found += 1
        self.assertEqual(node.attr['data_format'].s, 'NCHW')
    # Presumably 2 Conv2D + 2 Conv2DBackpropFilter + 1 Conv2DBackpropInput:
    # `x` is not a variable, so the first layer needs no input gradient.
    self.assertEqual(found, 5)
开发者ID:SylChan,项目名称:tensorflow,代码行数:25,代码来源:layout_optimizer_test.py

示例5: _run_inline_graph_optimization

def _run_inline_graph_optimization(func):
  """Apply function inline optimization to the graph.

  Returns the GraphDef after Grappler's function inlining optimization is
  applied. This optimization does not work on models with control flow.

  The inputs and outputs of `func` are registered in the MetaGraphDef's
  "train_op" collection so that Grappler keeps them. Only the "function"
  optimizer is enabled, and Grappler's small-graph cutoff is disabled.

  Args:
    func: ConcreteFunction.

  Returns:
    GraphDef
  """
  meta_graph = export_meta_graph(
      graph_def=func.graph.as_graph_def(), graph=func.graph)

  # Add a collection 'train_op' so that Grappler knows the outputs.
  fetch_collection = meta_graph_pb2.CollectionDef()
  for array in func.inputs + func.outputs:
    fetch_collection.node_list.value.append(array.name)
  meta_graph.collection_def["train_op"].CopyFrom(fetch_collection)

  # Initialize RewriterConfig with everything disabled except function
  # inlining: a non-empty `optimizers` list restricts Grappler to those passes.
  config = config_pb2.ConfigProto()
  rewrite_options = config.graph_options.rewrite_options
  # By default Grappler skips graphs with fewer than `min_graph_nodes` nodes.
  # A negative value disables that cutoff, so small functions still get
  # inlined.
  rewrite_options.min_graph_nodes = -1
  rewrite_options.optimizers.append("function")
  return tf_optimizer.OptimizeGraph(config, meta_graph)
开发者ID:kylin9872,项目名称:tensorflow,代码行数:26,代码来源:convert_to_constants.py

示例6: _ExportAndImportGraph

 def _ExportAndImportGraph(self, graph):
   """Export and import graph into a new graph.

   Every collection key of `graph` is included in the exported MetaGraphDef,
   which is then imported into a freshly created `ops.Graph`.

   Returns:
     The new graph populated by the import.
   """
   collection_keys = graph.get_all_collection_keys()
   exported = saver_lib.export_meta_graph(
       graph=graph, collection_list=collection_keys)
   imported_graph = ops.Graph()
   with imported_graph.as_default():
     saver_lib.import_meta_graph(exported)
   return imported_graph
开发者ID:JonathanRaiman,项目名称:tensorflow,代码行数:8,代码来源:moving_averages_test.py

示例7: _CopyGraph

 def _CopyGraph(self, graph):
   """Return a copy of graph.

   The copy is made by exporting `graph` (with all of its collection keys) to
   a MetaGraphDef and importing that into a new, empty graph.
   """
   meta_graph_def = saver_lib.export_meta_graph(
       graph=graph,
       collection_list=graph.get_all_collection_keys(),
   )
   target_graph = ops.Graph()
   with target_graph.as_default():
     saver_lib.import_meta_graph(meta_graph_def)
   return target_graph
开发者ID:Eagle732,项目名称:tensorflow,代码行数:8,代码来源:fold_batch_norms_test.py

示例8: testMetagraph

  def testMetagraph(self):
    """Checks that a graph with resource variables survives a MetaGraph round trip.

    The graph holds a resource variable minimized by a Momentum optimizer
    with `colocate_gradients_with_ops=True`. The graph is exported, imported
    into a fresh graph, and re-exported, and the two MetaGraphDefs must match.
    """
    with ops.Graph().as_default():
      with variable_scope.variable_scope("foo", use_resource=True):
        a = variable_scope.get_variable("a", initializer=10.0)

      momentum.MomentumOptimizer(
          learning_rate=0.001, momentum=0.1).minimize(
              a,
              colocate_gradients_with_ops=True,
              global_step=training_util.get_or_create_global_step())

      graph = ops.get_default_graph()
      meta_graph_def = saver.export_meta_graph(graph=graph)

    with ops.Graph().as_default() as imported_graph:
      saver.import_meta_graph(meta_graph_def, import_scope="")
      # Re-export the graph that was just populated by the import. Exporting
      # `graph` here would compare the original graph against itself and make
      # the assertion below vacuous.
      meta_graph_two = saver.export_meta_graph(graph=imported_graph)
    self.assertEqual(meta_graph_def, meta_graph_two)
开发者ID:aeverall,项目名称:tensorflow,代码行数:18,代码来源:resource_variable_ops_test.py

示例9: _convert_graph_def

  def _convert_graph_def(self):
    """Convert the input GraphDef.

    Imports `self._input_graph_def` into a new graph and exports it, via
    `as_graph_def(add_shapes=True)`, as the MetaGraphDef for Grappler. It then
    calls `_add_nodes_blacklist()` followed by `_run_conversion()`.
    """
    source_graph = ops.Graph()
    with source_graph.as_default():
      importer.import_graph_def(self._input_graph_def, name="")
    shaped_graph_def = source_graph.as_graph_def(add_shapes=True)
    self._grappler_meta_graph_def = saver.export_meta_graph(
        graph_def=shaped_graph_def, graph=source_graph)
    self._add_nodes_blacklist()

    self._run_conversion()
开发者ID:aritratony,项目名称:tensorflow,代码行数:10,代码来源:trt_convert.py

示例10: setUp

  def setUp(self):
    """Writes a MetaGraphDef of a frozen `y = w * x - 7` graph to a temp dir.

    The variable `w` (initialized to 3.0) is converted to a constant before
    export, and the string-valued "meta" collection is exported alongside the
    graph into `<temp>/no_vars/constants.META_GRAPH_DEF_FILENAME`.
    """
    self.base_path = os.path.join(test.get_temp_dir(), "no_vars")
    if not os.path.exists(self.base_path):
      os.mkdir(self.base_path)

    # Create a simple graph with a variable, then convert variables to
    # constants and export the graph.
    with ops.Graph().as_default() as g:
      x = array_ops.placeholder(dtypes.float32, name="x")
      w = variables.Variable(3.0)
      y = math_ops.subtract(w * x, 7.0, name="y")  # pylint: disable=unused-variable
      ops.add_to_collection("meta", "this is meta")

      with self.session(graph=g) as session:
        variables.global_variables_initializer().run()
        # Freeze the variable values into constants; "y" is the output node.
        new_graph_def = graph_util.convert_variables_to_constants(
            session, g.as_graph_def(), ["y"])

      # Still inside `g`'s default-graph context. No `graph` argument is
      # passed, so the "meta" collection is presumably taken from `g`, while
      # the nodes come from the frozen `new_graph_def`.
      filename = os.path.join(self.base_path, constants.META_GRAPH_DEF_FILENAME)
      saver.export_meta_graph(
          filename, graph_def=new_graph_def, collection_list=["meta"])
开发者ID:Ajaycs99,项目名称:tensorflow,代码行数:21,代码来源:session_bundle_test.py

示例11: _simple_metagraph

def _simple_metagraph(depthwise=False):
  """Build a two-layer conv training graph and return it as a MetaGraphDef.

  Args:
    depthwise: when True, use `separable_conv2d` layers instead of `conv2d`.

  Returns:
    A MetaGraphDef built from the default graph. Before export, the
    gradient-descent step is added to the graph's 'train_op' collection.
  """
  random_seed.set_random_seed(0)
  inputs = variables.Variable(
      random_ops.truncated_normal([1, 200, 200, 3], seed=0))
  layer_fn = conv_layers.separable_conv2d if depthwise else conv_layers.conv2d
  # Two identical conv layers, created in the same order as before.
  hidden = inputs
  for _ in range(2):
    hidden = layer_fn(hidden, 32, [3, 3])
  sgd = gradient_descent.GradientDescentOptimizer(1e-4)
  loss = math_ops.reduce_mean(hidden)
  default_graph = ops.get_default_graph()
  default_graph.add_to_collection('train_op', sgd.minimize(loss))
  return saver_lib.export_meta_graph(graph_def=default_graph.as_graph_def())
开发者ID:ChengYuXiang,项目名称:tensorflow,代码行数:13,代码来源:layout_optimizer_test.py

示例12: test_meta_graph_transform

  def test_meta_graph_transform(self):
    """Checks that strip_unused_nodes removes the unused `b * c` branch.

    The base graph computes `a * b` and `b * c`. With inputs ['a', 'b'] and
    output 'mul:0', the transformed MetaGraphDef should equal one built from
    `a * b` alone and tagged 'tag_ab'.
    """

    with ops.Graph().as_default():
      with tf_session.Session(''):
        a = array_ops.placeholder(dtypes.int64, [1], name='a')
        b = array_ops.placeholder(dtypes.int64, [1], name='b')
        c = array_ops.placeholder(dtypes.int64, [1], name='c')
        _ = a * b
        _ = b * c
        base_meta_graph_def = saver.export_meta_graph()

    with ops.Graph().as_default():
      with tf_session.Session(''):
        a = array_ops.placeholder(dtypes.int64, [1], name='a')
        b = array_ops.placeholder(dtypes.int64, [1], name='b')
        _ = a * b
        meta_info_def = meta_graph_pb2.MetaGraphDef.MetaInfoDef()
        meta_info_def.tags.append('tag_ab')

        expected_meta_graph_def = saver.export_meta_graph(
            meta_info_def=meta_info_def)
        # Graph rewriter clears versions field, so we expect that.
        expected_meta_graph_def.graph_def.ClearField('versions')
        # Graph rewriter adds an empty library field, so we expect that.
        expected_meta_graph_def.graph_def.library.CopyFrom(
            function_pb2.FunctionDefLibrary())

    input_names = ['a', 'b']
    output_names = ['mul:0']
    transforms = ['strip_unused_nodes']
    tags = ['tag_ab']
    transformed_meta_graph_def = meta_graph_transform.meta_graph_transform(
        base_meta_graph_def, input_names, output_names, transforms, tags)

    self.assertEqual(expected_meta_graph_def, transformed_meta_graph_def)
开发者ID:Dr4KK,项目名称:tensorflow,代码行数:36,代码来源:meta_graph_transform_test.py

示例13: _convert_saved_model_v2

  def _convert_saved_model_v2(self):
    """Convert the input SavedModel in 2.0 format.

    The steps are:
      1. Load the SavedModel and take the signature function for
         `self._input_saved_model_signature_key`.
      2. Freeze it with `convert_variables_to_constants_v2`.
      3. Export the frozen graph as the Grappler MetaGraphDef and run the
         conversion.
      4. Wrap the converted GraphDef in a new ConcreteFunction stored in
         `self._converted_func`.
    """
    self._saved_model = load.load(self._input_saved_model_dir,
                                  self._input_saved_model_tags)
    func = self._saved_model.signatures[self._input_saved_model_signature_key]
    frozen_func = convert_to_constants.convert_variables_to_constants_v2(func)
    self._grappler_meta_graph_def = saver.export_meta_graph(
        graph_def=frozen_func.graph.as_graph_def(), graph=frozen_func.graph)

    # Add a collection 'train_op' so that Grappler knows the outputs.
    # NOTE(review): the names come from `func` (the unfrozen function), but
    # the exported graph is `frozen_func.graph`. This relies on every input and
    # output name of `func` also existing in the frozen graph; confirm.
    fetch_collection = meta_graph_pb2.CollectionDef()
    for array in func.inputs + func.outputs:
      fetch_collection.node_list.value.append(array.name)
    self._grappler_meta_graph_def.collection_def["train_op"].CopyFrom(
        fetch_collection)

    # Run TRT optimizer in Grappler to convert the graph.
    # NOTE(review): presumably this sets self._converted_graph_def, which is
    # read below; verify.
    self._run_conversion()

    def _get_tensor(graph, tensors):
      """Returns the tensors of `graph` with the same names as `tensors`.

      Each returned tensor has the static shape of its original copied onto
      it via `set_shape`.
      """
      new_tensors = []
      for tensor in tensors:
        new_tensor = graph.get_tensor_by_name(tensor.name)
        new_tensor.set_shape(tensor.shape)
        new_tensors.append(new_tensor)
      return new_tensors

    # TODO(laigd): do we need to use different name e.g. "trt_func_graph"?
    converted_graph = func_graph.FuncGraph(func.graph.name)
    with converted_graph.as_default():
      importer.import_graph_def(self._converted_graph_def, name="")

    # Mirror the original function's inputs/outputs and structured signatures
    # onto the converted graph so it can stand in for `func`.
    converted_graph.inputs = _get_tensor(converted_graph, func.graph.inputs)
    converted_graph.outputs = _get_tensor(converted_graph, func.graph.outputs)
    converted_graph.structured_outputs = func.graph.structured_outputs
    converted_graph.structured_input_signature = (
        func.graph.structured_input_signature)

    # pylint: disable=protected-access
    # TODO(laigd): should we set up the signature as well?
    self._converted_func = function.ConcreteFunction(
        converted_graph, attrs=None, signature=None)
    self._converted_func.add_to_graph()
    # Copy private calling-convention state from the original function.
    # NOTE(review): these are private ConcreteFunction attributes and may
    # change between TF versions.
    self._converted_func._arg_keywords = func._arg_keywords
    self._converted_func._num_positional_args = func._num_positional_args
    self._converted_func._captured_inputs = func._captured_inputs
    self._converted_func.graph.variables = func.graph.variables
开发者ID:perfmjs,项目名称:tensorflow,代码行数:47,代码来源:trt_convert.py

示例14: grappler_optimize

def grappler_optimize(graph, fetches=None, rewriter_config=None):
  """Tries to optimize the provided graph using grappler.

  Args:
    graph: A @{tf.Graph} instance containing the graph to optimize.
    fetches: An optional list of `Tensor`s to fetch (i.e. not optimize away).
      Grappler uses the 'train_op' collection to look for fetches, so if not
      provided this collection should be non-empty. The fetches are added to
      `graph`'s 'train_op' collection, so `graph` is modified.
    rewriter_config: An optional @{tf.RewriterConfig} to use when rewriting the
      graph.

  Returns:
    A @{tf.GraphDef} containing the rewritten graph.
  """
  if rewriter_config is None:
    rewriter_config = rewriter_config_pb2.RewriterConfig()
  if fetches is not None:
    for fetch in fetches:
      graph.add_to_collection('train_op', fetch)
  # Pass `graph` explicitly. Without it, export_meta_graph reads collections
  # (including 'train_op') from the *default* graph, so fetches registered on
  # a non-default `graph` would be silently dropped.
  metagraph = saver.export_meta_graph(
      graph_def=graph.as_graph_def(), graph=graph)
  return tf_optimizer.OptimizeGraph(rewriter_config, metagraph)
开发者ID:AbhinavJain13,项目名称:tensorflow,代码行数:21,代码来源:test_util.py

示例15: _convert_graph_def

  def _convert_graph_def(self):
    """Convert the input GraphDef.

    The input GraphDef is imported into a new graph and exported (with
    `add_shapes=True`) as the MetaGraphDef handed to Grappler. When a node
    blacklist is set, its entries replace the MetaGraphDef's "train_op"
    collection before conversion runs. Tensors are added by name; any other
    entry is added as given.
    """
    imported_graph = ops.Graph()
    with imported_graph.as_default():
      importer.import_graph_def(self._input_graph_def, name="")
    self._grappler_meta_graph_def = saver.export_meta_graph(
        graph_def=imported_graph.as_graph_def(add_shapes=True),
        graph=imported_graph)

    if self._nodes_blacklist:
      blacklist_names = [
          _to_bytes(node.name if isinstance(node, ops.Tensor) else node)
          for node in self._nodes_blacklist
      ]
      blacklist_collection = meta_graph_pb2.CollectionDef()
      blacklist_collection.node_list.value.extend(blacklist_names)
      # TODO(laigd): use another key as the self._nodes_blacklist are really
      # not train_op.
      self._grappler_meta_graph_def.collection_def["train_op"].CopyFrom(
          blacklist_collection)

    self._run_conversion()
开发者ID:kylin9872,项目名称:tensorflow,代码行数:21,代码来源:trt_convert.py


注：本文中的tensorflow.python.training.saver.export_meta_graph函数示例由纯净天空整理自GitHub/MSDocs等开源代码及文档管理平台，相关代码片段筛选自各路编程大神贡献的开源项目，源码版权归原作者所有，传播和使用请参考对应项目的License；未经允许，请勿转载。