当前位置: 首页>>代码示例>>Python>>正文


Python saver.import_meta_graph函数代码示例

本文整理汇总了Python中tensorflow.python.training.saver.import_meta_graph函数的典型用法代码示例。如果您正苦于以下问题:Python import_meta_graph函数的具体用法?Python import_meta_graph怎么用?Python import_meta_graph使用的例子?那么恭喜您, 这里精选的函数代码示例或许可以为您提供帮助。


在下文中一共展示了import_meta_graph函数的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: _read_vars

 def _read_vars(self, model_dir):
   """Restores the latest checkpoint in `model_dir` and evaluates 'my_vars'.

   The meta graph stored next to the latest checkpoint is imported into a
   fresh graph, the variables are restored, and every entry of the 'my_vars'
   graph collection is evaluated.

   Args:
     model_dir: Directory holding the checkpoints.

   Returns:
     The result of `sess.run` on the 'my_vars' collection, one value per
     collection entry (presumably (global_step, latest_feature) -- TODO
     confirm against the code that populates the collection).

   Raises:
     ValueError: If no checkpoint is found in `model_dir`.
   """
   with ops.Graph().as_default() as g:
     ckpt_path = checkpoint_management.latest_checkpoint(model_dir)
     if ckpt_path is None:
       # Fail clearly instead of with a TypeError from `None + '.meta'`.
       raise ValueError('No checkpoint found in %s' % model_dir)
     meta_filename = ckpt_path + '.meta'
     saver_lib.import_meta_graph(meta_filename)
     # A Saver over the variables now present in the imported graph.
     saver = saver_lib.Saver()
     with self.test_session(graph=g) as sess:
       saver.restore(sess, ckpt_path)
       return sess.run(ops.get_collection('my_vars'))
开发者ID:ZhangXinNan,项目名称:tensorflow,代码行数:10,代码来源:iterator_ops_test.py

示例2: testMetaGraphSaveLoad

  def testMetaGraphSaveLoad(self):
    """Saves a partitioned variable, re-imports the meta graph and restores it.

    A (10, 10) float32 variable split by a fixed-size partitioner (5 parts
    along axis 0) is initialized and checkpointed. After the meta graph is
    imported into a new graph and restored, the tensor named
    `v0.name + ":0"` must evaluate to the value it had before saving.
    """
    save_prefix = os.path.join(self.get_temp_dir(), "ckpt")
    save_graph = ops.Graph()
    # Phase 1: build, initialize and checkpoint the partitioned variable.
    with save_graph.as_default(), self.test_session(
        graph=save_graph) as session:
      partitioner = partitioned_variables.fixed_size_partitioner(5, axis=0)
      with variable_scope.variable_scope("root", partitioner=partitioner):
        v0 = variable_scope.get_variable(
            "v0", dtype=dtypes.float32, shape=(10, 10))
        v0_list = v0._get_variable_list()
        v0_part = v0._get_partitions()
        # 5 partitions along axis 0, a single one along axis 1.
        self.assertEqual(len(v0_list), 5)
        self.assertAllEqual(v0_part, (5, 1))
        variables.global_variables_initializer().run()

        save_graph.get_collection_ref("partvar").append(v0)
        saver = saver_lib.Saver()
        # Finalize the save graph before saving and fetching the value.
        save_graph.finalize()
        save_path = saver.save(sess=session, save_path=save_prefix)
        previous_value = session.run(
            save_graph.get_tensor_by_name(v0.name + ":0"))

    # Phase 2: rebuild the graph from the .meta file and restore the values.
    restore_graph = ops.Graph()
    with restore_graph.as_default(), self.test_session(
        graph=restore_graph) as session:
      saver = saver_lib.import_meta_graph(save_path + ".meta")
      saver.restore(sess=session, save_path=save_path)
      # NOTE(review): this reads the "partvar" collection of `save_graph`,
      # not `restore_graph`. The isinstance check below therefore inspects
      # the original Python object, not anything deserialized from the meta
      # graph; only `v0.name` is used against `restore_graph`. Confirm
      # whether `restore_graph` was intended (and whether that collection
      # round-trips) before changing it.
      v0, = save_graph.get_collection_ref("partvar")
      self.assertIsInstance(v0, variables.PartitionedVariable)
      self.assertAllEqual(
          previous_value,
          session.run(restore_graph.get_tensor_by_name(v0.name + ":0")))
开发者ID:AnishShah,项目名称:tensorflow,代码行数:32,代码来源:partitioned_variables_test.py

示例3: graph_def_from_checkpoint

def graph_def_from_checkpoint(checkpoint_dir, output_node_names):
  """Converts checkpoint data to GraphDef.

  Reads the latest checkpoint data and produces a GraphDef in which the
  variables have been converted to constants.

  The checkpoint's meta graph is imported into a private, freshly created
  graph. Repeated calls therefore do not accumulate (or name-clash) nodes in
  the caller's default graph.

  Args:
    checkpoint_dir: Path to the checkpoints.
    output_node_names: List of name strings for the result nodes of the graph.

  Returns:
    A GraphDef from the latest checkpoint

  Raises:
    ValueError: if no checkpoint is found
  """
  checkpoint_path = saver_lib.latest_checkpoint(checkpoint_dir)
  if checkpoint_path is None:
    raise ValueError('Could not find a checkpoint at: {0}.'
                     .format(checkpoint_dir))

  with ops.Graph().as_default() as graph:
    # Device placements are cleared so the resulting GraphDef is portable.
    saver_for_restore = saver_lib.import_meta_graph(
        checkpoint_path + '.meta', clear_devices=True)
    with session.Session(graph=graph) as sess:
      # import_meta_graph returns None when the meta graph has no variables;
      # there is then nothing to restore.
      if saver_for_restore is not None:
        saver_for_restore.restore(sess, checkpoint_path)
      graph_def = graph.as_graph_def()
      output_graph_def = graph_util.convert_variables_to_constants(
          sess, graph_def, output_node_names)

  return output_graph_def
开发者ID:Ajaycs99,项目名称:tensorflow,代码行数:30,代码来源:strip_pruning_vars_lib.py

示例4: _ExportAndImportGraph

 def _ExportAndImportGraph(self, graph):
   """Round-trips `graph` through a MetaGraphDef and returns the new graph.

   Every collection key of `graph` is passed to the export, so collections
   are serialized along with the graph.
   """
   all_keys = graph.get_all_collection_keys()
   exported = saver_lib.export_meta_graph(graph=graph,
                                          collection_list=all_keys)
   new_graph = ops.Graph()
   with new_graph.as_default():
     # The Saver (or None) returned by the import is not needed here.
     saver_lib.import_meta_graph(exported)
   return new_graph
开发者ID:JonathanRaiman,项目名称:tensorflow,代码行数:8,代码来源:moving_averages_test.py

示例5: _CopyGraph

 def _CopyGraph(self, graph):
   """Returns a new graph built by exporting and re-importing `graph`.

   All of `graph`'s collection keys are included in the exported meta graph.
   """
   meta_graph_def = saver_lib.export_meta_graph(
       graph=graph,
       collection_list=graph.get_all_collection_keys(),
   )
   copied = ops.Graph()
   with copied.as_default():
     saver_lib.import_meta_graph(meta_graph_def)
   return copied
开发者ID:Eagle732,项目名称:tensorflow,代码行数:8,代码来源:fold_batch_norms_test.py

示例6: load

def load(sess, tags, export_dir):
  """Loads the model from a SavedModel as specified by tags.

  Args:
    sess: The TensorFlow session to restore the variables.
    tags: Set of string tags to identify the required MetaGraphDef. These should
        correspond to the tags used when saving the variables using the
        SavedModel `save()` API. Any iterable of strings is accepted; it is
        converted to a set exactly once, so one-shot iterators work too.
    export_dir: Directory in which the SavedModel protocol buffer and variables
        to be loaded are located.

  Returns:
    The `MetaGraphDef` protocol buffer loaded in the provided session. This
    can be used to further extract signature-defs, collection-defs, etc.

  Raises:
    RuntimeError: MetaGraphDef associated with the tags cannot be found.
  """
  # Build the SavedModel protocol buffer and find the requested meta graph def.
  saved_model = _parse_saved_model(export_dir)

  # Convert once up front: re-evaluating set(tags) per candidate would exhaust
  # a generator after the first comparison so later candidates never match.
  requested_tags = set(tags)
  meta_graph_def_to_load = next(
      (candidate for candidate in saved_model.meta_graphs
       if set(candidate.meta_info_def.tags) == requested_tags),
      None)

  if meta_graph_def_to_load is None:
    raise RuntimeError("MetaGraphDef associated with tags " + str(tags).strip(
        "[]") + " could not be found in SavedModel")

  # Build a saver by importing the meta graph def to load. The import returns
  # None when the meta graph has no variables.
  saver = tf_saver.import_meta_graph(meta_graph_def_to_load)

  # Build the checkpoint path where the variables are located.
  variables_path = os.path.join(
      compat.as_bytes(export_dir),
      compat.as_bytes(constants.VARIABLES_DIRECTORY),
      compat.as_bytes(constants.VARIABLES_FILENAME))

  # Restore the variables using the built saver in the provided session.
  if saver is not None:
    saver.restore(sess, variables_path)

  # Get asset tensors, if any.
  asset_tensors_dictionary = _get_asset_tensors(export_dir,
                                                meta_graph_def_to_load)

  # Run the main op if present, otherwise the legacy init op (if any); either
  # is fed the asset tensors.
  main_op_tensor = _get_main_op_tensor(meta_graph_def_to_load)
  if main_op_tensor is not None:
    sess.run(fetches=[main_op_tensor], feed_dict=asset_tensors_dictionary)
  else:
    legacy_init_op_tensor = _get_legacy_init_op_tensor(meta_graph_def_to_load)
    if legacy_init_op_tensor is not None:
      sess.run(fetches=[legacy_init_op_tensor],
               feed_dict=asset_tensors_dictionary)

  return meta_graph_def_to_load
开发者ID:curtiszimmerman,项目名称:tensorflow,代码行数:57,代码来源:loader.py

示例7: testMetagraph

  def testMetagraph(self):
    """Exports a graph with a Momentum-trained resource variable, re-imports it."""
    with ops.Graph().as_default():
      with variable_scope.variable_scope("foo", use_resource=True):
        var = variable_scope.get_variable("a", initializer=10.0)

      optimizer = momentum.MomentumOptimizer(
          learning_rate=0.001, momentum=0.1)
      step = training_util.get_or_create_global_step()
      optimizer.minimize(
          var, colocate_gradients_with_ops=True, global_step=step)

      source_graph = ops.get_default_graph()
      exported = saver.export_meta_graph(graph=source_graph)

    with ops.Graph().as_default():
      saver.import_meta_graph(exported, import_scope="")
      # NOTE(review): this re-exports `source_graph` (the first graph), not
      # the graph the meta graph was just imported into, so the equality
      # below compares the original export with itself. Confirm whether the
      # imported graph was intended before changing it.
      reexported = saver.export_meta_graph(graph=source_graph)
    self.assertEqual(exported, reexported)
开发者ID:aeverall,项目名称:tensorflow,代码行数:18,代码来源:resource_variable_ops_test.py

示例8: testGradientOfDeserializedCond

  def testGradientOfDeserializedCond(self):
    """Differentiates a cond_v2 rebuilt from an exported meta graph.

    The cond yields x**3 when `pred` is True and x otherwise, with x = 3.0.
    `x`, `pred` and the cond outputs are stored in collections, exported,
    imported into a new graph, and first/second derivatives are checked for
    both branches.
    """
    with ops.Graph().as_default():
      pred_ph = array_ops.placeholder(dtypes.bool, name="pred")
      x_const = constant_op.constant(3.0, name="x")
      ops.add_to_collection("x", x_const)

      def true_fn():
        return math_ops.pow(x_const, 3)

      def false_fn():
        return x_const

      ops.add_to_collection("pred", pred_ph)
      cond_outputs = cond_v2.cond_v2(pred_ph, true_fn, false_fn, name="cond")
      for output in cond_outputs:
        ops.add_to_collection("cond", output)
      exported = saver.export_meta_graph()

    with ops.Graph().as_default() as graph:
      with self.test_session(graph=graph) as sess:
        saver.import_meta_graph(exported)
        x = ops.get_collection("x")[0]
        pred = ops.get_collection("pred")[0]
        cond = ops.get_collection("cond")
        cond_grad = gradients_impl.gradients(cond, [x], name="cond_grad")
        cond_grad_grad = gradients_impl.gradients(
            cond_grad, [x], name="cond_grad_grad")

        # (fetch, value fed to pred, expected derivative at x = 3).
        expectations = (
            (cond_grad, True, 27.0),       # d[x^3]/dx = 3x^2
            (cond_grad, False, 1.0),       # d[x]/dx = 1
            (cond_grad_grad, True, 18.0),  # d2[x^3]/dx2 = 6x
            (cond_grad_grad, False, 0.0),  # d2[x]/dx2 = 0
        )
        for fetch, pred_value, expected in expectations:
          self.assertEqual(sess.run(fetch, {pred: pred_value}), [expected])
开发者ID:clsung,项目名称:tensorflow,代码行数:40,代码来源:cond_v2_test.py

示例9: _get_default_signature

  def _get_default_signature(self, export_meta_filename):
    """Gets the default signature from the export.meta file.

    Imports `export_meta_filename`, re-exports the meta graph through the
    returned saver, and unpacks the single 'serving_signatures' entry into a
    `manifest_pb2.Signatures` message.

    Args:
      export_meta_filename: Path to the exported `.meta` file.

    Returns:
      The `default_signature` field of the unpacked `Signatures` message.
    """
    with session.Session():
      save = saver.import_meta_graph(export_meta_filename)
      meta_graph_def = save.export_meta_graph()
      collection_def = meta_graph_def.collection_def

      signatures_any = collection_def['serving_signatures'].any_list.value
      # assertEqual, not the deprecated `assertEquals` alias (removed in
      # Python 3.12).
      self.assertEqual(len(signatures_any), 1)
      signatures = manifest_pb2.Signatures()
      signatures_any[0].Unpack(signatures)
      return signatures.default_signature
开发者ID:Ajaycs99,项目名称:tensorflow,代码行数:13,代码来源:export_test.py

示例10: testNoVariables

  def testNoVariables(self):
    """A variable-free meta graph imports without a Saver and still runs."""
    test_dir = _TestDir("no_variables")
    filename = os.path.join(test_dir, "metafile")

    # Arbitrary value fed to the "input" placeholder in both graphs.
    input_feed_value = -10

    orig_graph = tf.Graph()
    with self.test_session(graph=orig_graph) as sess:
      # Minimal graph with zero variables: add_offset = input + 42.
      input_tensor = tf.placeholder(tf.float32, shape=[], name="input")
      offset = tf.constant(42, dtype=tf.float32, name="offset")
      output_tensor = tf.add(input_tensor, offset, name="add_offset")

      # Publish both endpoints through collections so they can be looked up
      # again after the import.
      for key, tensor in (("input_tensor", input_tensor),
                          ("output_tensor", output_tensor)):
        tf.add_to_collection(key, tensor)

      output_value = sess.run(output_tensor, {input_tensor: input_feed_value})
      self.assertEqual(output_value, 32)

      # This is the saver *module-level* export_meta_graph, not the
      # Saver.export_meta_graph instance method; no SaverDef is attached.
      meta_graph_def = saver_module.export_meta_graph(
          filename=filename,
          graph_def=tf.get_default_graph().as_graph_def(),
          collection_list=["input_tensor", "output_tensor"],
          saver_def=None,
      )

    # Import the exported MetaGraphDef into a clean graph.
    new_graph = tf.Graph()
    with self.test_session(graph=new_graph) as sess:
      # No variables means nothing to restore, so no Saver is returned.
      self.assertIsNone(saver_module.import_meta_graph(filename))

      # Re-export the imported state; it must match the original export.
      new_meta_graph_def = saver_module.export_meta_graph(filename + "_new")
      self.assertProtoEquals(meta_graph_def, new_meta_graph_def)

      # The collections survived the round trip and compute the same result.
      new_input_tensor = tf.get_collection("input_tensor")[0]
      new_output_tensor = tf.get_collection("output_tensor")[0]
      self.assertEqual(
          sess.run(new_output_tensor, {new_input_tensor: input_feed_value}),
          output_value)
开发者ID:2er0,项目名称:tensorflow,代码行数:51,代码来源:saver_test.py

示例11: testRestoreFromMetaGraph

 def testRestoreFromMetaGraph(self):
   """A Supervisor recovers "v0" using a Saver rebuilt from the saved meta graph."""
   logdir = self._test_dir("restore_from_meta_graph")
   with ops.Graph().as_default():
     variables.VariableV1(1, name="v0")
     sv = supervisor.Supervisor(logdir=logdir)
     sess = sv.prepare_or_wait_for_session("")
     filename = sv.saver.save(sess, sv.save_path)
     sv.stop()
   # Create a new Graph and Supervisor and recover.
   with ops.Graph().as_default():
     # The meta graph written with the checkpoint lives at "<filename>.meta".
     new_saver = saver_lib.import_meta_graph(".".join([filename, "meta"]))
     self.assertIsNotNone(new_saver)
     sv2 = supervisor.Supervisor(logdir=logdir, saver=new_saver)
     sess = sv2.prepare_or_wait_for_session("")
     # assertEqual, not the deprecated `assertEquals` alias (removed in
     # Python 3.12).
     self.assertEqual(1, sess.run("v0:0"))
     sv2.saver.save(sess, sv2.save_path)
     sv2.stop()
开发者ID:JonathanRaiman,项目名称:tensorflow,代码行数:17,代码来源:supervisor_test.py

示例12: _testSaveRestoreUtility

  def _testSaveRestoreUtility(self, start, break_range, stop):
    """Consumes a concatenated dataset, checkpoints mid-way, then resumes.

    Elements [start, break_range) are read before saving at step 0. After the
    meta graph is imported and the checkpoint restored, elements
    [break_range, stop) are read and the iterator must then be exhausted.
    """
    path = self._iterator_checkpoint_prefix()
    step = 0
    meta_filename = path + "-%d.meta" % step

    input_components = (np.tile(np.array([[1], [2], [3], [4]]), 20), np.tile(
        np.array([[12], [13], [14], [15]]), 4))
    to_concatenate_components = (np.tile(
        np.array([[5], [6], [7], [8], [9]]), 20), np.tile(
            np.array([[16], [17], [18], [19], [20]]), 15))

    def check_element(index, result):
      # Indices below 4 come from `input_components`, later ones from
      # `to_concatenate_components` (shifted back by 4).
      if index < 4:
        expected_components, shift = input_components, 0
      else:
        expected_components, shift = to_concatenate_components, 4
      for expected, actual in zip(expected_components, result):
        self.assertAllEqual(expected[index - shift], actual)

    with ops.Graph().as_default() as g:
      init_op, get_next = self._build_graph(input_components,
                                            to_concatenate_components)
      writer = saver_lib.Saver()
      with self.test_session(graph=g) as sess:
        sess.run(init_op)
        for index in range(start, break_range):
          check_element(index, sess.run(get_next))
        writer.save(sess, path, step)

    with ops.Graph().as_default() as g:
      restorer = saver_lib.import_meta_graph(meta_filename)
      with self.test_session(graph=g) as sess:
        get_next = nest.pack_sequence_as(("a", "b"),
                                         ops.get_collection("get_next"))
        restorer.restore(sess, saver_lib.latest_checkpoint(self.get_temp_dir()))
        for index in range(break_range, stop):
          check_element(index, sess.run(get_next))
        with self.assertRaises(errors.OutOfRangeError):
          sess.run(get_next)
开发者ID:AbhinavJain13,项目名称:tensorflow,代码行数:45,代码来源:concatenate_dataset_op_test.py

示例13: testSaveRestoreUsingSaverFromMetaGraph

  def testSaveRestoreUsingSaverFromMetaGraph(self):
    """The Saver rebuilt by import_meta_graph restores the iterator state."""

    def _build_graph(start, stop):
      # Returns (init_op, get_next, saver) for a range iterator whose state is
      # registered as a SaveableObject.
      iterator = dataset_ops.Dataset.range(
          start, stop).make_initializable_iterator()
      init_op = iterator.initializer
      get_next = iterator.get_next()
      for iterator_op in (init_op, get_next):
        ops.add_to_collection("iterator_ops", iterator_op)
      # Entries in `SAVEABLE_OBJECTS` are picked up automatically by the Saver.
      ops.add_to_collection(
          ops.GraphKeys.SAVEABLE_OBJECTS,
          contrib_iterator_ops.make_saveable_from_iterator(iterator))
      return init_op, get_next, saver_lib.Saver()

    start, stop, break_point = 2, 10, 5
    path = self._iterator_checkpoint_prefix()
    meta_filename = path + ".meta"

    # Consume [start, break_point) and checkpoint the iterator state.
    with ops.Graph().as_default() as g:
      init_op, get_next, saver = _build_graph(start, stop)
      with self.test_session(graph=g) as sess:
        sess.run(variables.global_variables_initializer())
        sess.run(init_op)
        for expected in range(start, break_point):
          self.assertEqual(expected, sess.run(get_next))
        saver.save(sess, path)

    # Rebuild the graph and Saver from the MetaGraph; the restored iterator
    # must continue at break_point and then be exhausted.
    with ops.Graph().as_default() as g:
      saver = saver_lib.import_meta_graph(meta_filename)
      _, get_next = ops.get_collection("iterator_ops")
      with self.test_session(graph=g) as sess:
        saver.restore(sess, saver_lib.latest_checkpoint(self.get_temp_dir()))
        for expected in range(break_point, stop):
          self.assertEqual(expected, sess.run(get_next))
        with self.assertRaises(errors.OutOfRangeError):
          sess.run(get_next)
开发者ID:SylChan,项目名称:tensorflow,代码行数:43,代码来源:range_dataset_op_test.py

示例14: freeze_graph_with_def_protos

def freeze_graph_with_def_protos(input_graph_def,
                                 input_saver_def,
                                 input_checkpoint,
                                 output_node_names,
                                 restore_op_name,
                                 filename_tensor_name,
                                 output_graph,
                                 clear_devices,
                                 initializer_nodes,
                                 variable_names_whitelist="",
                                 variable_names_blacklist="",
                                 input_meta_graph_def=None,
                                 input_saved_model_dir=None,
                                 saved_model_tags=None,
                                 checkpoint_version=saver_pb2.SaverDef.V2):
  """Converts all variables in a graph and checkpoint into constants.

  Inputs are validated first. Variable values are then restored into a new
  `session.Session` from the first available source, in this order:
  `input_saver_def`, `input_meta_graph_def`, `input_saved_model_dir`.
  If none of these is given, a Saver is built from the graph tensors whose
  names match the keys stored in `input_checkpoint`.

  Args:
    input_graph_def: GraphDef imported (with name="") into the current default
      graph before restoring, if given.
    input_saver_def: SaverDef used to build the restoring Saver, if given.
    input_checkpoint: Checkpoint path or prefix to restore from.
    output_node_names: Names of the result nodes; must be non-empty.
    restore_op_name: Unused.
    filename_tensor_name: Unused.
    clear_devices: If true, clears `node.device` on every node of
      `input_meta_graph_def.graph_def` (or of `input_graph_def`). NOTE: this
      mutates the caller's proto in place.
    initializer_nodes: Comma-separated node names (spaces ignored) run after
      restoring in the meta-graph and plain-checkpoint paths.
    input_meta_graph_def: MetaGraphDef to import and restore from, if given.
    input_saved_model_dir: SavedModel directory to load from, if given.
    saved_model_tags: Tags for the SavedModel load; None means `[]`.
    checkpoint_version: SaverDef write version passed to constructed Savers.

  Returns:
    -1 (after printing a message) when `input_checkpoint` does not exist and
    no `input_saved_model_dir` is given, when `output_node_names` is empty,
    or when a plain checkpoint with partitioned variables cannot be loaded.
  """
  del restore_op_name, filename_tensor_name  # Unused by updated loading code.

  # 'input_checkpoint' may be a prefix if we're using Saver V2 format
  if (not input_saved_model_dir and
      not checkpoint_management.checkpoint_exists(input_checkpoint)):
    print("Input checkpoint '" + input_checkpoint + "' doesn't exist!")
    return -1

  if not output_node_names:
    print("You need to supply the name of a node to --output_node_names.")
    return -1

  # Remove all the explicit device specifications for this node. This helps to
  # make the graph more portable.
  if clear_devices:
    if input_meta_graph_def:
      for node in input_meta_graph_def.graph_def.node:
        node.device = ""
    elif input_graph_def:
      for node in input_graph_def.node:
        node.device = ""

  # Imported into the current default graph without a name prefix.
  if input_graph_def:
    _ = importer.import_graph_def(input_graph_def, name="")
  with session.Session() as sess:
    if input_saver_def:
      saver = saver_lib.Saver(
          saver_def=input_saver_def, write_version=checkpoint_version)
      saver.restore(sess, input_checkpoint)
    elif input_meta_graph_def:
      # NOTE(review): clear_devices=True is passed here regardless of the
      # `clear_devices` argument -- confirm this is intended.
      restorer = saver_lib.import_meta_graph(
          input_meta_graph_def, clear_devices=True)
      restorer.restore(sess, input_checkpoint)
      if initializer_nodes:
        sess.run(initializer_nodes.replace(" ", "").split(","))
    elif input_saved_model_dir:
      if saved_model_tags is None:
        saved_model_tags = []
      loader.load(sess, saved_model_tags, input_saved_model_dir)
    else:
      # Fallback: map every checkpoint key to the same-named tensor in the
      # graph and restore through a Saver built from that mapping.
      var_list = {}
      reader = pywrap_tensorflow.NewCheckpointReader(input_checkpoint)
      var_to_shape_map = reader.get_variable_to_shape_map()

      # List of all partition variables. Because the condition is heuristic
      # based, the list could include false positives.
      # (A tensor counts as a partition if its name contains "/part_<digits>/".)
      all_parition_variable_names = [
          tensor.name.split(":")[0]
          for op in sess.graph.get_operations()
          for tensor in op.values()
          if re.search(r"/part_\d+/", tensor.name)
      ]
      has_partition_var = False

      for key in var_to_shape_map:
        try:
          tensor = sess.graph.get_tensor_by_name(key + ":0")
          # Substring match against the partition names -- also heuristic.
          if any(key in name for name in all_parition_variable_names):
            has_partition_var = True
        except KeyError:
          # This tensor doesn't exist in the graph (for example it's
          # 'global_step' or a similar housekeeping element) so skip it.
          continue
        var_list[key] = tensor

      try:
        saver = saver_lib.Saver(
            var_list=var_list, write_version=checkpoint_version)
      except TypeError as e:
        # `var_list` is required to be a map of variable names to Variable
        # tensors. Partition variables are Identity tensors that cannot be
        # handled by Saver.
        if has_partition_var:
          print("Models containing partition variables cannot be converted "
                "from checkpoint files. Please pass in a SavedModel using "
                "the flag --input_saved_model_dir.")
          return -1
        else:
          raise e

      saver.restore(sess, input_checkpoint)
      if initializer_nodes:
        sess.run(initializer_nodes.replace(" ", "").split(","))

#.........这里部分代码省略.........
开发者ID:sonnyhu,项目名称:tensorflow,代码行数:101,代码来源:freeze_graph.py

示例15: _import_meta_graph

 def _import_meta_graph(self):
   """Imports the meta graph stored beside the checkpoint at `self._ckpt_path()`.

   Returns:
     Whatever `saver_lib.import_meta_graph` returns for "<ckpt_path>.meta".
   """
   ckpt_prefix = self._ckpt_path()
   return saver_lib.import_meta_graph(ckpt_prefix + ".meta")
开发者ID:AnddyWang,项目名称:tensorflow,代码行数:3,代码来源:dataset_serialization_test_base.py


注:本文中的tensorflow.python.training.saver.import_meta_graph函数示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。