当前位置: 首页>>代码示例>>Python>>正文


Python checkpoint_management.latest_checkpoint函数代码示例

本文整理汇总了 Python 中 tensorflow.python.training.checkpoint_management.latest_checkpoint 函数的典型用法代码示例。该函数返回指定目录中最新检查点（checkpoint）的路径；若目录中尚无检查点，则返回 None。如果您想了解 latest_checkpoint 函数的具体用法、调用方式或使用示例，这里精选的代码示例或许能为您提供帮助。


下文共展示了 latest_checkpoint 函数的 15 个代码示例，这些例子默认按受欢迎程度排序。您可以为喜欢或觉得有用的代码点赞，您的评价将有助于系统推荐更优质的 Python 代码示例。

示例1: testRecoverSession

  def testRecoverSession(self):
    """Checks that SessionManager.recover_session restores a saved variable.

    The first `recover_session` call finds no checkpoint and reports that
    nothing was initialized; the variable is then initialized and saved.
    Recovery is verified both via `checkpoint_dir` and via an explicit
    checkpoint path, and passing both arguments must raise ValueError.
    """
    # Start from an empty checkpoint directory.
    checkpoint_dir = os.path.join(self.get_temp_dir(), "recover_session")
    try:
      gfile.DeleteRecursively(checkpoint_dir)
    except errors.OpError:
      pass  # Best-effort cleanup; presumably the directory did not exist.
    gfile.MakeDirs(checkpoint_dir)

    with ops.Graph().as_default():
      v = variables.Variable(1, name="v")
      sm = session_manager.SessionManager(
          ready_op=variables.report_uninitialized_variables())
      saver = saver_lib.Saver({"v": v})
      sess, initialized = sm.recover_session(
          "", saver=saver, checkpoint_dir=checkpoint_dir)
      # The directory is empty, so nothing could be recovered yet.
      self.assertFalse(initialized)
      sess.run(v.initializer)
      # `assertEquals` is a deprecated alias (removed in Python 3.12).
      self.assertEqual(1, sess.run(v))
      saver.save(sess,
                 os.path.join(checkpoint_dir, "recover_session_checkpoint"))
    self._test_recovered_variable(checkpoint_dir=checkpoint_dir)
    self._test_recovered_variable(
        checkpoint_filename_with_path=checkpoint_management.latest_checkpoint(
            checkpoint_dir))
    # Cannot set both checkpoint_dir and checkpoint_filename_with_path.
    with self.assertRaises(ValueError):
      self._test_recovered_variable(
          checkpoint_dir=checkpoint_dir,
          checkpoint_filename_with_path=checkpoint_management.latest_checkpoint(
              checkpoint_dir))
开发者ID:AnishShah,项目名称:tensorflow,代码行数:31,代码来源:session_manager_test.py

示例2: _new_layer_weight_loading_test_template

  def _new_layer_weight_loading_test_template(
      self, first_model_fn, second_model_fn, restore_init_fn):
    """Checks that weights survive a save/load round trip between two models.

    The first model's weights are saved and its variables then randomized.
    The second model loads that checkpoint and saves it back. Finally the
    first model reloads the second model's checkpoint and must reproduce its
    original output.

    Args:
      first_model_fn: zero-argument callable returning the reference model.
      second_model_fn: zero-argument callable returning the model that loads
        the first model's weights.
      restore_init_fn: callable taking the second model and returning
        something passed to `self.evaluate` after its weights are loaded
        (presumably initializers for state the checkpoint does not cover --
        confirm against callers).
    """
    with self.cached_session() as session:
      model = first_model_fn()
      temp_dir = self.get_temp_dir()
      # Register cleanup before any assertion can fail, so the checkpoint
      # files are removed even when the test aborts early.
      self.addCleanup(shutil.rmtree, temp_dir)
      prefix = os.path.join(temp_dir, 'ckpt')

      x = constant_op.constant(np.random.random((3, 2)), dtype=dtypes.float32)
      executing_eagerly = context.executing_eagerly()
      ref_y_tensor = model(x)
      if not executing_eagerly:
        session.run([v.initializer for v in model.variables])
      ref_y = self.evaluate(ref_y_tensor)
      model.save_weights(prefix)
      self.assertEqual(
          prefix,
          checkpoint_management.latest_checkpoint(temp_dir))
      # Scramble the first model's weights so the final reload is meaningful.
      for v in model.variables:
        self.evaluate(
            v.assign(random_ops.random_normal(shape=array_ops.shape(v))))

      second_model = second_model_fn()
      second_model.load_weights(prefix)
      second_model(x)
      self.evaluate(restore_init_fn(second_model))
      second_model.save_weights(prefix)
      # Check that the second model's checkpoint loads into the original model
      model.load_weights(prefix)
      y = self.evaluate(model(x))
      self.assertAllClose(ref_y, y)
开发者ID:terrytangyuan,项目名称:tensorflow,代码行数:32,代码来源:hdf5_format_test.py

示例3: testUsageGraph

 def testUsageGraph(self):
   """Trains over three graph-mode runs, checkpointing after each one.

   Every run rebuilds the graph, restores the newest checkpoint (if any),
   trains for a fixed number of steps and saves again; the global step and
   save counter must therefore advance by one run's worth each time.
   """
   with context.graph_mode():
     steps_per_run = 10
     ckpt_dir = self.get_temp_dir()
     ckpt_prefix = os.path.join(ckpt_dir, "ckpt")
     for run_idx in range(3):
       with ops.Graph().as_default():
         net = MyModel()
         opt = adam.AdamOptimizer(0.001)
         step = training_util.get_or_create_global_step()
         root = util.Checkpoint(optimizer=opt, model=net, global_step=step)
         train_op = opt.minimize(
             net(constant_op.constant([[3.]])), global_step=step)
         latest = checkpoint_management.latest_checkpoint(ckpt_dir)
         with self.session(graph=ops.get_default_graph()) as sess:
           restore_status = root.restore(save_path=latest)
           restore_status.initialize_or_restore(session=sess)
           if latest is not None:
             restore_status.assert_consumed()
           else:
             # Only the first run starts without a checkpoint, and there
             # assert_consumed() is expected to fail.
             self.assertEqual(0, run_idx)
             with self.assertRaises(AssertionError):
               restore_status.assert_consumed()
           for _ in range(steps_per_run):
             sess.run(train_op)
           root.save(file_prefix=ckpt_prefix, session=sess)
           completed_runs = run_idx + 1
           self.assertEqual(completed_runs * steps_per_run,
                            sess.run(root.global_step))
           self.assertEqual(completed_runs, sess.run(root.save_counter))
开发者ID:jackd,项目名称:tensorflow,代码行数:35,代码来源:checkpointable_utils_test.py

示例4: _restore_or_save_initial_ckpt

  def _restore_or_save_initial_ckpt(self, session):
    """Restores the latest checkpoint, or saves an initial one if none exists.

    This is deliberately not done in `after_create_session`. Hooks run in no
    guaranteed order, so the `_DatasetInitializerHook` could run *after* this
    hook, which is troublesome in two ways:
      1. If a checkpoint exists and is restored here, the initializer hook
         would override it.
      2. If no checkpoint exists, this hook would try to save an
         uninitialized iterator, which raises an exception.
    As a temporary fix, this hook and `_DatasetInitializerHook` rely on an
    implicit contract:
      1. `_DatasetInitializerHook` initializes the iterator in
         `after_create_session`.
      2. This hook saves the iterator on the first call to `before_run()`,
         which is guaranteed to happen after `after_create_session()` of all
         hooks has run.

    Args:
      session: the session to restore into, or to save from.
    """
    # pylint: disable=protected-access
    saver_hook = self._checkpoint_saver_hook
    ckpt_path = checkpoint_management.latest_checkpoint(
        saver_hook._checkpoint_dir, latest_filename=self._latest_filename)
    if ckpt_path:
      saver_hook._get_saver().restore(session, ckpt_path)
      return

    # No checkpoint yet: save the state at the current global step.
    # Note: the GraphDef / MetaGraphDef is intentionally not saved here.
    step = session.run(saver_hook._global_step_tensor)
    saver_hook._save(session, step)
    saver_hook._timer.update_last_triggered_step(step)
开发者ID:JonathanRaiman,项目名称:tensorflow,代码行数:33,代码来源:iterator_ops.py

示例5: export_fn

  def export_fn(estimator, export_dir_base, checkpoint_path, eval_result=None):
    """Exports the given Estimator as a SavedModel if the selector picks it.

    The checkpoint and evaluation result are passed to
    `best_model_selector.update`. An export happens only when the selector
    returns a checkpoint path together with a non-None evaluation result.

    Args:
      estimator: the Estimator to export.
      export_dir_base: A string containing a directory to write the exported
        graph and checkpoints.
      checkpoint_path: The checkpoint path to export. This argument is
        required, but when it is falsy (e.g. None or '') the most recent
        checkpoint found in `estimator.model_dir` is used instead.
      eval_result: the evaluation result for `checkpoint_path`, matching the
        call signature of ExportStrategy; forwarded to
        `best_model_selector.update`.

    Returns:
      The string path to the exported directory, or '' when nothing was
      exported.
    """
    if not checkpoint_path:
      # TODO(b/67425018): switch to
      #    checkpoint_path = estimator.latest_checkpoint()
      #  as soon as contrib is cleaned up and we can thus be sure that
      #  estimator is a tf.estimator.Estimator and not a
      #  tf.contrib.learn.Estimator
      checkpoint_path = checkpoint_management.latest_checkpoint(
          estimator.model_dir)
    export_checkpoint_path, export_eval_result = best_model_selector.update(
        checkpoint_path, eval_result)

    if export_checkpoint_path and export_eval_result is not None:
      # Each export goes into a subdirectory named after the checkpoint file.
      checkpoint_base = os.path.basename(export_checkpoint_path)
      export_dir = os.path.join(export_dir_base, checkpoint_base)
      return best_model_export_strategy.export(
          estimator, export_dir, export_checkpoint_path, export_eval_result)
    else:
      return ''
开发者ID:Ajaycs99,项目名称:tensorflow,代码行数:32,代码来源:saved_model_export_utils.py

示例6: testAgnosticUsage

 def testAgnosticUsage(self):
   """Graph/eager agnostic usage.

   Runs three training continuations. Each one restores the latest
   checkpoint (None on the first run), trains, saves, and checks that the
   global step and save counter advanced accordingly.
   """
   # Does create garbage when executing eagerly due to ops.Graph() creation.
   num_training_steps = 10
   checkpoint_directory = self.get_temp_dir()
   checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
   for training_continuation in range(3):
     # `test_session` is deprecated; with an explicit graph it simply
     # delegates to `session`, so call that directly.
     with ops.Graph().as_default(), self.session(
         graph=ops.get_default_graph()), test_util.device(use_gpu=True):
       model = MyModel()
       optimizer = adam.AdamOptimizer(0.001)
       root = util.Checkpoint(
           optimizer=optimizer, model=model,
           global_step=training_util.get_or_create_global_step())
       checkpoint_path = checkpoint_management.latest_checkpoint(
           checkpoint_directory)
       status = root.restore(save_path=checkpoint_path)
       input_value = constant_op.constant([[3.]])
       train_fn = functools.partial(
           optimizer.minimize,
           functools.partial(model, input_value),
           global_step=root.global_step)
       if not context.executing_eagerly():
         # In graph mode, build the train op once and run it on each call.
         train_fn = functools.partial(self.evaluate, train_fn())
       status.initialize_or_restore()
       for _ in range(num_training_steps):
         train_fn()
       root.save(file_prefix=checkpoint_prefix)
       self.assertEqual((training_continuation + 1) * num_training_steps,
                        self.evaluate(root.global_step))
       self.assertEqual(training_continuation + 1,
                        self.evaluate(root.save_counter))
开发者ID:jackd,项目名称:tensorflow,代码行数:32,代码来源:checkpointable_utils_test.py

示例7: wait_for_new_checkpoint

def wait_for_new_checkpoint(checkpoint_dir,
                            last_checkpoint=None,
                            seconds_to_sleep=1,
                            timeout=None):
  """Polls `checkpoint_dir` until a checkpoint other than `last_checkpoint` shows up.

  Args:
    checkpoint_dir: Directory that checkpoints are written to.
    last_checkpoint: Path of the checkpoint already seen, or `None` when no
      checkpoint has been seen yet.
    seconds_to_sleep: Seconds to sleep between consecutive polls.
    timeout: Maximum number of seconds to wait. `None` means wait forever.

  Returns:
    The path of the new checkpoint, or None when sleeping once more would
    exceed `timeout`.
  """
  logging.info('Waiting for new checkpoint at %s', checkpoint_dir)
  deadline = None if timeout is None else time.time() + timeout
  while True:
    candidate = checkpoint_management.latest_checkpoint(checkpoint_dir)
    if candidate is not None and candidate != last_checkpoint:
      logging.info('Found new checkpoint at %s', candidate)
      return candidate
    # Give up rather than sleep past the deadline.
    if deadline is not None and time.time() + seconds_to_sleep > deadline:
      return None
    time.sleep(seconds_to_sleep)
开发者ID:ahmedsaiduk,项目名称:tensorflow,代码行数:29,代码来源:evaluation.py

示例8: testGraphDistributionStrategy

  def testGraphDistributionStrategy(self):
    """Resumes mirrored-strategy graph training from the latest checkpoint.

    Each of three continuations rebuilds the graph, restores the newest
    checkpoint, trains for `num_training_steps` and saves, so the optimizer
    step must grow by `num_training_steps` per continuation.
    """
    self.skipTest("b/121381184")
    num_training_steps = 10
    checkpoint_directory = self.get_temp_dir()
    checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")

    def _train_fn(optimizer, model, global_step):
      # The step variable is passed explicitly instead of being read from a
      # late-bound `root` in the enclosing scope.
      input_value = constant_op.constant([[3.]])
      return optimizer.minimize(
          functools.partial(model, input_value),
          global_step=global_step)

    for training_continuation in range(3):
      with ops.Graph().as_default():
        strategy = mirrored_strategy.MirroredStrategy()
        with strategy.scope():
          model = MyModel()
          optimizer = adam.AdamOptimizer(0.001)
          root = checkpointable_utils.Checkpoint(
              optimizer=optimizer, model=model,
              optimizer_step=training_util.get_or_create_global_step())
          status = root.restore(checkpoint_management.latest_checkpoint(
              checkpoint_directory))
          train_op = strategy.extended.call_for_each_replica(
              functools.partial(
                  _train_fn, optimizer, model, root.optimizer_step))
          with self.session() as session:
            if training_continuation > 0:
              status.assert_consumed()
            status.initialize_or_restore()
            for _ in range(num_training_steps):
              session.run(train_op)
            root.save(file_prefix=checkpoint_prefix)
            # Graph-mode variables have no `.numpy()`; read the step through
            # the session while the graph and session are still live.
            self.assertEqual((training_continuation + 1) * num_training_steps,
                             session.run(root.optimizer_step))
开发者ID:Wajih-O,项目名称:tensorflow,代码行数:34,代码来源:util_with_v1_optimizers_test.py

示例9: testEagerTPUDistributionStrategy

  def testEagerTPUDistributionStrategy(self):
    """Resumes eager TPU-strategy training three times from the latest checkpoint."""
    self.skipTest("b/121387144")
    steps_per_run = 10
    ckpt_dir = self.get_temp_dir()
    ckpt_prefix = os.path.join(ckpt_dir, "ckpt")

    def _replica_step(opt, net):
      # `ckpt` is looked up in the enclosing scope when this runs, i.e. it
      # is the checkpoint built in the current loop iteration.
      opt.minimize(
          functools.partial(net, constant_op.constant([[3.]])),
          global_step=ckpt.optimizer_step)

    for run_idx in range(3):
      strategy = tpu_strategy.TPUStrategy()
      with strategy.scope():
        net = Subclassed()
        opt = adam_v1.AdamOptimizer(0.001)
        ckpt = checkpointable_utils.Checkpoint(
            optimizer=opt, model=net,
            optimizer_step=training_util.get_or_create_global_step())
        ckpt.restore(checkpoint_management.latest_checkpoint(ckpt_dir))

        for _ in range(steps_per_run):
          strategy.extended.call_for_each_replica(
              functools.partial(_replica_step, opt, net))
        ckpt.save(file_prefix=ckpt_prefix)
        expected_step = (run_idx + 1) * steps_per_run
        self.assertEqual(expected_step, ckpt.optimizer_step.numpy())
开发者ID:AndreasGocht,项目名称:tensorflow,代码行数:29,代码来源:checkpointing_test.py

示例10: _read_vars

 def _read_vars(self, model_dir):
   """Returns (global_step, latest_feature).

   Imports the meta graph of the newest checkpoint in `model_dir` into a
   fresh graph, restores it, and evaluates the 'my_vars' collection
   (presumably holding those two variables -- confirm with the test model).
   """
   with ops.Graph().as_default() as g:
     ckpt_path = checkpoint_management.latest_checkpoint(model_dir)
     # latest_checkpoint returns None when nothing has been saved; fail with
     # a readable message instead of a TypeError on `None + '.meta'`.
     self.assertIsNotNone(ckpt_path,
                          'No checkpoint found in %s' % model_dir)
     meta_filename = ckpt_path + '.meta'
     saver_lib.import_meta_graph(meta_filename)
     saver = saver_lib.Saver()
     # `test_session` is deprecated; with an explicit graph it delegates to
     # `session`.
     with self.session(graph=g) as sess:
       saver.restore(sess, ckpt_path)
       return sess.run(ops.get_collection('my_vars'))
开发者ID:ZhangXinNan,项目名称:tensorflow,代码行数:10,代码来源:iterator_ops_test.py

示例11: testRestoreInReconstructedIterator

 def testRestoreInReconstructedIterator(self):
   """Iterator position survives re-creating the iterator from a checkpoint."""
   ckpt_dir = self.get_temp_dir()
   ckpt_prefix = os.path.join(ckpt_dir, 'ckpt')
   dataset = Dataset.range(10)
   for round_idx in range(5):
     it = datasets.Iterator(dataset)
     ckpt = checkpointable_utils.Checkpoint(iterator=it)
     ckpt.restore(checkpoint_management.latest_checkpoint(ckpt_dir))
     # Each round reads two elements, continuing where the last one stopped.
     for expected in (2 * round_idx, 2 * round_idx + 1):
       self.assertEqual(expected, it.get_next().numpy())
     ckpt.save(file_prefix=ckpt_prefix)
开发者ID:aeverall,项目名称:tensorflow,代码行数:12,代码来源:datasets_test.py

示例12: testRestoreInReconstructedIteratorInitializable

 def testRestoreInReconstructedIteratorInitializable(self):
   """An initializable iterator resumes from its checkpoint in a new session."""
   ckpt_dir = self.get_temp_dir()
   ckpt_prefix = os.path.join(ckpt_dir, "ckpt")
   dataset = dataset_ops.Dataset.range(10)
   iterator = dataset.make_initializable_iterator()
   next_element = iterator.get_next()
   checkpoint = checkpointable_utils.Checkpoint(iterator=iterator)
   for round_idx in range(5):
     with self.cached_session() as sess:
       latest = checkpoint_management.latest_checkpoint(ckpt_dir)
       checkpoint.restore(latest).initialize_or_restore(sess)
       # Two elements per round, continuing after the previous round.
       for offset in range(2):
         self.assertEqual(2 * round_idx + offset, self.evaluate(next_element))
       checkpoint.save(file_prefix=ckpt_prefix)
开发者ID:JonathanRaiman,项目名称:tensorflow,代码行数:14,代码来源:iterator_ops_test.py

示例13: testRestoreInReconstructedIteratorInitializable

 def testRestoreInReconstructedIteratorInitializable(self):
   """Iterator state is restored on every round, in graph or eager mode."""
   ckpt_dir = self.get_temp_dir()
   ckpt_prefix = os.path.join(ckpt_dir, "ckpt")
   dataset = dataset_ops.Dataset.range(10)
   if context.executing_eagerly():
     iterator = iter(dataset)
   else:
     iterator = dataset_ops.make_initializable_iterator(dataset)
   checkpoint = trackable_utils.Checkpoint(iterator=iterator)
   for round_idx in range(5):
     status = checkpoint.restore(
         checkpoint_management.latest_checkpoint(ckpt_dir))
     status.initialize_or_restore()
     # Two elements per round, continuing after the previous round.
     for offset in range(2):
       self.assertEqual(2 * round_idx + offset,
                        self.evaluate(iterator.get_next()))
     checkpoint.save(file_prefix=ckpt_prefix)
开发者ID:adit-chandra,项目名称:tensorflow,代码行数:15,代码来源:iterator_checkpoint_test.py

示例14: testNameCollision

  def testNameCollision(self):
    """Saver rejects a path that collides with the checkpoint state file.

    Saving to "<dir>/checkpoint" unsharded and without a global step must
    fail with a "collides with" ValueError. Adding a global step and/or
    sharding changes the generated file names, so those saves succeed and
    latest_checkpoint finds a checkpoint each time.
    """
    # Make sure we have a clean directory to work in.
    with self.tempDir() as tempdir:
      # Jump to that directory until this test is done.
      with self.tempWorkingDir(tempdir):
        # Save training snapshots to a relative path.
        traindir = "train/"
        os.mkdir(traindir)
        # Collides with the default name of the checkpoint state file.
        filepath = os.path.join(traindir, "checkpoint")

        # `test_session()` is deprecated in favor of `cached_session()`.
        with self.cached_session() as sess:
          unused_a = variables.Variable(0.0)  # So that Saver saves something.
          variables.global_variables_initializer().run()

          # Should fail.
          saver = saver_module.Saver(sharded=False)
          # `assertRaisesRegexp` is a deprecated alias removed in Python 3.12.
          with self.assertRaisesRegex(ValueError, "collides with"):
            saver.save(sess, filepath)

          # Succeeds: the file will be named "checkpoint-<step>".
          saver.save(sess, filepath, global_step=1)
          self.assertIsNotNone(
              checkpoint_management.latest_checkpoint(traindir))

          # Succeeds: the file will be named "checkpoint-<i>-of-<n>".
          saver = saver_module.Saver(sharded=True)
          saver.save(sess, filepath)
          self.assertIsNotNone(
              checkpoint_management.latest_checkpoint(traindir))

          # Succeeds: the file will be named "checkpoint-<step>-<i>-of-<n>".
          saver = saver_module.Saver(sharded=True)
          saver.save(sess, filepath, global_step=1)
          self.assertIsNotNone(
              checkpoint_management.latest_checkpoint(traindir))
开发者ID:clsung,项目名称:tensorflow,代码行数:36,代码来源:checkpoint_management_test.py

示例15: __init__

  def __init__(self,
               estimator,
               prediction_input_fn,
               input_alternative_key=None,
               output_alternative_key=None,
               graph=None,
               config=None):
    """Builds a predictor around a `tf.contrib.learn.Estimator`.

    The prediction graph is built from `prediction_input_fn` and the
    estimator's predict ops. A `MonitoredSession` is then opened on it,
    using the path returned by `latest_checkpoint(estimator.model_dir)`.

    Args:
      estimator: an instance of `tf.contrib.learn.Estimator`.
      prediction_input_fn: a function that takes no arguments and returns an
        instance of `InputFnOps`.
      input_alternative_key: Optional. Specify the input alternative used for
        prediction; when falsy,
        `saved_model_export_utils.DEFAULT_INPUT_ALTERNATIVE_KEY` is used.
      output_alternative_key: Specify the output alternative used for
        prediction. Not needed for single-headed models but required for
        multi-headed models.
      graph: Optional. The Tensorflow `graph` in which prediction should be
        done; a new `Graph` is created when it is not given.
      config: `ConfigProto` proto used to configure the session.
    """
    self._graph = graph or ops.Graph()
    with self._graph.as_default():
      input_ops = prediction_input_fn()
      # pylint: disable=protected-access
      predict_ops = estimator._get_predict_ops(input_ops.features)
      # pylint: enable=protected-access
      latest_ckpt = checkpoint_management.latest_checkpoint(
          estimator.model_dir)
      creator = monitored_session.ChiefSessionCreator(
          config=config, checkpoint_filename_with_path=latest_ckpt)
      self._session = monitored_session.MonitoredSession(
          session_creator=creator)

    if not input_alternative_key:
      input_alternative_key = (
          saved_model_export_utils.DEFAULT_INPUT_ALTERNATIVE_KEY)
    input_alternatives, _ = saved_model_export_utils.get_input_alternatives(
        input_ops)
    self._feed_tensors = input_alternatives[input_alternative_key]

    # get_output_alternatives also resolves the key actually used.
    output_alternatives, output_alternative_key = (
        saved_model_export_utils.get_output_alternatives(
            predict_ops, output_alternative_key))
    _, self._fetch_tensors = output_alternatives[output_alternative_key]
开发者ID:Ajaycs99,项目名称:tensorflow,代码行数:47,代码来源:contrib_estimator_predictor.py


注：本文中的 tensorflow.python.training.checkpoint_management.latest_checkpoint 函数示例由纯净天空整理自 GitHub/MSDocs 等开源代码及文档管理平台，相关代码片段筛选自各路编程大神贡献的开源项目，源码版权归原作者所有，传播和使用请参考对应项目的 License；未经允许，请勿转载。