当前位置: 首页>>代码示例>>Python>>正文


Python saver.latest_checkpoint函数代码示例

本文整理汇总了Python中tensorflow.python.training.saver.latest_checkpoint函数的典型用法代码示例。如果您正苦于以下问题:Python latest_checkpoint函数的具体用法?Python latest_checkpoint怎么用?Python latest_checkpoint使用的例子?那么恭喜您, 这里精选的函数代码示例或许可以为您提供帮助。


在下文中一共展示了latest_checkpoint函数的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: testRecoverSession

  def testRecoverSession(self):
    # Verifies that a checkpoint written by one session can be recovered later,
    # addressed either by its directory or by its explicit file path, and that
    # supplying both at once is rejected.

    # Start from an empty checkpoint directory.
    checkpoint_dir = os.path.join(self.get_temp_dir(), "recover_session")
    try:
      gfile.DeleteRecursively(checkpoint_dir)
    except errors.OpError:
      pass  # Best effort: the directory may simply not exist yet.
    gfile.MakeDirs(checkpoint_dir)

    with ops.Graph().as_default():
      v = variables.Variable(1, name="v")
      sm = session_manager.SessionManager(
          ready_op=variables.report_uninitialized_variables())
      saver = saver_lib.Saver({"v": v})
      # Nothing has been saved yet, so recovery must report "not initialized".
      sess, initialized = sm.recover_session(
          "", saver=saver, checkpoint_dir=checkpoint_dir)
      self.assertFalse(initialized)
      sess.run(v.initializer)
      # assertEqual, not the deprecated assertEquals alias, which was removed
      # from unittest in Python 3.12.
      self.assertEqual(1, sess.run(v))
      saver.save(sess,
                 os.path.join(checkpoint_dir, "recover_session_checkpoint"))

    # Recover by directory, then by explicit checkpoint path.
    self._test_recovered_variable(checkpoint_dir=checkpoint_dir)
    self._test_recovered_variable(
        checkpoint_filename_with_path=saver_lib.latest_checkpoint(
            checkpoint_dir))
    # Cannot set both checkpoint_dir and checkpoint_filename_with_path.
    with self.assertRaises(ValueError):
      self._test_recovered_variable(
          checkpoint_dir=checkpoint_dir,
          checkpoint_filename_with_path=saver_lib.latest_checkpoint(
              checkpoint_dir))
开发者ID:AndrewTwinz,项目名称:tensorflow,代码行数:31,代码来源:session_manager_test.py

示例2: _read_config_files

  def _read_config_files(self, run_paths):
    """Loads the projector config of every run that has a usable checkpoint.

    Args:
      run_paths: Dict mapping a run name to that run's log directory.

    Returns:
      A `(configs, config_fpaths)` tuple of dicts keyed by run name: the parsed
      `ProjectorConfig` and the path of the config file it was read from. Runs
      with no config file, or whose checkpoint cannot be found, are omitted.
    """
    configs = {}
    config_fpaths = {}
    for run_name, logdir in run_paths.items():
      config_fpath = os.path.join(logdir, PROJECTOR_FILENAME)
      if not file_io.file_exists(config_fpath):
        # Skip runs that have no config file.
        continue
      # Read the config file.
      file_content = file_io.read_file_to_string(config_fpath).decode('utf-8')
      config = ProjectorConfig()
      text_format.Merge(file_content, config)

      if not config.model_checkpoint_path:
        # Look for a checkpoint in the logdir, then in the parent of logdir.
        # The parent must be built as <logdir>/.. ; joining '../' in front of
        # logdir yields an unrelated path (or logdir itself when it is
        # absolute, since os.path.join discards everything before an absolute
        # component).
        ckpt_path = (latest_checkpoint(logdir) or
                     latest_checkpoint(os.path.join(logdir, os.pardir)))
        if not ckpt_path:
          logging.warning('Cannot find model checkpoint in %s', logdir)
          continue
        config.model_checkpoint_path = ckpt_path

      # Sanity check for the checkpoint file.
      if not file_io.file_exists(config.model_checkpoint_path):
        logging.warning('Checkpoint file %s not found',
                        config.model_checkpoint_path)
        continue
      configs[run_name] = config
      config_fpaths[run_name] = config_fpath
    return configs, config_fpaths
开发者ID:KalraA,项目名称:tensorflow,代码行数:32,代码来源:plugin.py

示例3: _find_latest_checkpoint

def _find_latest_checkpoint(dir_path):
  """Returns the latest checkpoint in `dir_path`, or else in its parent.

  Args:
    dir_path: Directory to search first.

  Returns:
    Whatever `latest_checkpoint` reports for `dir_path` when that is truthy,
    otherwise its result for the parent directory. `None` if the lookup
    raises `errors.NotFoundError`.
  """
  try:
    found = latest_checkpoint(dir_path)
    if found:
      return found
    # Nothing in dir_path itself; fall back to the parent directory.
    return latest_checkpoint(os.path.join(dir_path, os.pardir))
  except errors.NotFoundError:
    return None
开发者ID:chenjun0210,项目名称:tensorflow,代码行数:9,代码来源:projector_plugin.py

示例4: testMultiEvalStepIncrements

  def testMultiEvalStepIncrements(self):
    # The eval ops bump the eval step one extra time themselves, and the test
    # expects the stop hook to allow only half of `num_evals` updates.
    ckpt_dir = os.path.join(self.get_temp_dir(), 'eval_ops_and_final_ops')

    # One training step is enough to produce a checkpoint.
    self._train_model(ckpt_dir, num_steps=1)
    latest_ckpt = saver.latest_checkpoint(ckpt_dir)

    # Rebuild the model so there is something to restore into.
    features = constant_op.constant(self._inputs, dtype=dtypes.float32)
    logistic_classifier(features)

    num_evals = 6
    expected_updates = num_evals // 2

    counter = local_variable(0.0, name='MyVar')
    bump_counter = state_ops.assign_add(counter, 1.0)
    # Extra increment of the eval step on top of the regular one.
    bump_eval_step = state_ops.assign_add(
        evaluation._get_or_create_eval_step(), 1, use_locking=True)
    counter_value = array_ops.identity(counter)

    results = evaluation._evaluate_once(
        checkpoint_path=latest_ckpt,
        eval_ops=[bump_counter, bump_eval_step],
        final_ops={'value': counter_value},
        hooks=[evaluation._StopAfterNEvalsHook(num_evals),])
    self.assertEqual(results['value'], expected_updates)
开发者ID:1000sprites,项目名称:tensorflow,代码行数:28,代码来源:evaluation_test.py

示例5: testEvalOpAndFinalOp

  def testEvalOpAndFinalOp(self):
    # Runs the eval op `num_evals` times, then checks the final op sees the
    # accumulated value and that the caller's hook list is left unchanged.
    ckpt_dir = os.path.join(self.get_temp_dir(), 'eval_ops_and_final_ops')

    # One training step is enough to produce a checkpoint.
    self._train_model(ckpt_dir, num_steps=1)
    latest_ckpt = saver.latest_checkpoint(ckpt_dir)

    # Rebuild the model so there is something to restore into.
    features = constant_op.constant(self._inputs, dtype=dtypes.float32)
    logistic_classifier(features)

    num_evals = 5
    final_increment = 9.0

    counter = local_variable(0.0, name='MyVar')
    bump_counter = state_ops.assign_add(counter, 1.0)
    shifted_counter = array_ops.identity(counter) + final_increment

    hooks = [evaluation._StopAfterNEvalsHook(num_evals),]
    # Copy taken up front so the list can be compared after evaluation.
    hooks_before = list(hooks)
    results = evaluation._evaluate_once(
        checkpoint_path=latest_ckpt,
        eval_ops=bump_counter,
        final_ops={'value': shifted_counter},
        hooks=hooks)
    self.assertEqual(results['value'], num_evals + final_increment)
    self.assertEqual(hooks_before, hooks)
开发者ID:1000sprites,项目名称:tensorflow,代码行数:27,代码来源:evaluation_test.py

示例6: testEvaluateWithFiniteInputs

  def testEvaluateWithFiniteInputs(self):
    # Evaluation over a single-epoch input pipeline, with no explicit limit on
    # the number of evals passed to the stop hook.
    ckpt_dir = os.path.join(self.get_temp_dir(),
                            'evaluate_with_finite_inputs')

    # Train a model to completion.
    self._train_model(ckpt_dir, num_steps=300)

    # Feed every example through an input producer for exactly one epoch.
    all_inputs = constant_op.constant(self._inputs, dtype=dtypes.float32)
    all_labels = constant_op.constant(self._labels, dtype=dtypes.float32)
    example, example_label = training.slice_input_producer(
        [all_inputs, all_labels], num_epochs=1)
    inputs, labels = training.batch(
        [example, example_label], batch_size=6,
        allow_smaller_final_batch=True)

    predictions = math_ops.round(logistic_classifier(inputs))
    accuracy, update_op = metrics.accuracy(
        predictions=predictions, labels=labels)

    latest_ckpt = saver.latest_checkpoint(ckpt_dir)

    requested = {'accuracy': accuracy,
                 'eval_steps': evaluation._get_or_create_eval_step()}
    results = evaluation._evaluate_once(
        checkpoint_path=latest_ckpt,
        eval_ops=update_op,
        final_ops=requested,
        hooks=[evaluation._StopAfterNEvalsHook(None),])
    self.assertGreater(results['accuracy'], .99)
    # Four iterations are expected: two full batches of 6 inputs, a third
    # with the remaining 4 inputs, and a last one that triggers the error
    # which stops evaluation.
    self.assertEqual(results['eval_steps'], 4)
开发者ID:AbhinavJain13,项目名称:tensorflow,代码行数:35,代码来源:evaluation_test.py

示例7: test_recovery

 def test_recovery(self):
   # A checkpoint written by one MonitoredSession must be picked up by later
   # sessions, whether located via the directory or via an explicit path.
   logdir = _test_dir(self.get_temp_dir(), 'test_recovery')
   with ops.Graph().as_default():
     global_step = variables_lib.get_or_create_global_step()
     increment_step = state_ops.assign_add(global_step, 1)
     scaffold = monitored_session.Scaffold()
     # Use a hook to save the model every step (save_steps=1).  It also saves
     # it at the end.
     saver_hook = basic_session_run_hooks.CheckpointSaverHook(
         logdir, save_steps=1, scaffold=scaffold)
     with monitored_session.MonitoredSession(
         session_creator=monitored_session.ChiefSessionCreator(
             scaffold, checkpoint_dir=logdir),
         hooks=[saver_hook]) as session:
       self.assertEqual(0, session.run(global_step))
       self.assertEqual(1, session.run(increment_step))
       self.assertEqual(2, session.run(increment_step))
     # A restart pointed at the directory finds the checkpoint and recovers
     # automatically.
     with monitored_session.MonitoredSession(
         session_creator=monitored_session.ChiefSessionCreator(
             scaffold, checkpoint_dir=logdir)) as session:
       self.assertEqual(2, session.run(global_step))
     # A restart given the explicit checkpoint path recovers the same state.
     latest_path = saver_lib.latest_checkpoint(logdir)
     with monitored_session.MonitoredSession(
         session_creator=monitored_session.ChiefSessionCreator(
             scaffold,
             checkpoint_filename_with_path=latest_path)) as session:
       self.assertEqual(2, session.run(global_step))
开发者ID:kadeng,项目名称:tensorflow,代码行数:31,代码来源:monitored_session_test.py

示例8: export_estimator

def export_estimator(estimator, export_dir, input_fn=_default_input_fn,
                     signature_fn=_generic_signature_fn, default_batch_size=1,
                     exports_to_keep=None):
  """Exports the estimator's inference graph into `export_dir`.

  The graph is rebuilt around a string placeholder named
  `input_example_tensor` and exported together with the latest checkpoint
  found in the estimator's model directory.

  Args:
    estimator: Estimator to export.
    export_dir: A string containing a directory to write the exported graph
      and checkpoints.
    input_fn: Function that, given the estimator and a `Tensor` of `Example`
      strings, parses it into features that are then passed to the model.
    signature_fn: Function that, given a `Tensor` of `Example` strings, a
      `dict` of `Tensor`s for features and a `dict` of `Tensor`s for
      predictions, returns the default and named exporting signatures.
    default_batch_size: Default batch size of the `Example` placeholder.
    exports_to_keep: Number of exports to keep, or `None` to pass no
      retention policy to the exporter.
  """
  latest_ckpt = tf_saver.latest_checkpoint(estimator._model_dir)
  with ops.Graph().as_default() as graph:
    contrib_variables.create_global_step(graph)
    example_placeholder = array_ops.placeholder(
        dtype=dtypes.string,
        shape=[default_batch_size],
        name='input_example_tensor')
    parsed_features = input_fn(estimator, example_placeholder)
    predict_ops = estimator._get_predict_ops(parsed_features)
    default_sig, named_sigs = signature_fn(
        example_placeholder, parsed_features, predict_ops)
    # Turn the plain count into the retention filter the exporter expects.
    keep_policy = exports_to_keep
    if keep_policy is not None:
      keep_policy = gc.largest_export_versions(keep_policy)
    _export_graph(graph, _get_saver(), latest_ckpt, export_dir,
                  default_graph_signature=default_sig,
                  named_graph_signatures=named_sigs,
                  exports_to_keep=keep_policy)
开发者ID:363158858,项目名称:tensorflow,代码行数:33,代码来源:export.py

示例9: _restore_or_save_initial_ckpt

  def _restore_or_save_initial_ckpt(self, session):
    """Restores from the latest checkpoint, or saves one if none exists.

    Args:
      session: Session used to restore into, or to read the global step and
        save from.
    """
    # Ideally this would run in after_create_session, but there is currently
    # no way to enforce an order among `SessionRunHooks`, so the
    # `_DatasetInitializerHook` could run *after* this hook. That would be
    # troublesome because:
    # 1. If a checkpoint exists and this hook restores it, the initializer
    #    hook will override it.
    # 2. If no checkpoint exists, this hook will try to save an initialized
    #    iterator, which will result in an exception.
    #    NOTE(review): "initialized" is the original wording; it presumably
    #    means a not-yet-initialized iterator - confirm.
    #
    # As a temporary fix there is an implicit contract between this hook and
    # the _DatasetInitializerHook:
    # 1. The _DatasetInitializerHook initializes the iterator in its
    #    after_create_session call.
    # 2. This hook saves the iterator on the first call to `before_run()`,
    #    which is guaranteed to happen after `after_create_session()` of all
    #    hooks has run.

    # pylint: disable=protected-access
    saver_hook = self._checkpoint_saver_hook
    existing_ckpt = saver_lib.latest_checkpoint(
        saver_hook._checkpoint_dir,
        latest_filename=self._latest_filename)
    if existing_ckpt:
      # A checkpoint already exists: restore from it and stop here.
      saver_hook._get_saver().restore(session, existing_ckpt)
      return

    # The checkpoint saved here is the state at step "global_step".
    # Note: We do not save the GraphDef or MetaGraphDef here.
    current_step = session.run(saver_hook._global_step_tensor)
    saver_hook._save(session, current_step)
    saver_hook._timer.update_last_triggered_step(current_step)
开发者ID:BhaskarNallani,项目名称:tensorflow,代码行数:33,代码来源:iterator_ops.py

示例10: _evaluate_model

  def _evaluate_model(self,
                      input_fn,
                      steps,
                      feed_fn=None,
                      metrics=None,
                      name=''):
    """Evaluates the latest checkpoint found in the model directory.

    Args:
      input_fn: Zero-argument callable returning `(features, targets)`.
      steps: Forwarded to `evaluate` as `max_steps`.
      feed_fn: Forwarded unchanged to `evaluate`.
      metrics: Metrics handed to `_get_eval_ops`; when `None`, the result of
        `_get_default_metric_functions()` is used instead.
      name: Optional suffix. The output directory passed to `evaluate` is
        `eval_<name>` under the model directory, or `eval` when empty.

    Returns:
      The first element of the `evaluate` result, or `None` when the
      configured execution mode is not one of the evaluating modes.
    """
    evaluating_modes = ('all', 'evaluate', 'eval_evalset')
    if self._config.execution_mode not in evaluating_modes:
      return

    latest_ckpt = saver.latest_checkpoint(self._model_dir)
    subdir = 'eval_' + name if name else 'eval'
    output_dir = os.path.join(self._model_dir, subdir)
    with ops.Graph().as_default() as graph:
      random_seed.set_random_seed(self._config.tf_random_seed)
      step_tensor = contrib_framework.create_global_step(graph)
      features, targets = input_fn()
      self._check_inputs(features, targets)
      if metrics is None:
        metrics = self._get_default_metric_functions()
      eval_dict = self._get_eval_ops(features, targets, metrics)
      update_op, eval_dict = self._extract_metric_update_ops(eval_dict)
      results, _ = evaluate(graph=graph,
                            output_dir=output_dir,
                            checkpoint_path=latest_ckpt,
                            eval_dict=eval_dict,
                            update_op=update_op,
                            global_step_tensor=step_tensor,
                            supervisor_master=self._config.master,
                            feed_fn=feed_fn,
                            max_steps=steps)
      return results
开发者ID:01-,项目名称:tensorflow,代码行数:31,代码来源:estimator.py

示例11: _infer_model

  def _infer_model(
      self, input_fn, feed_fn=None, outputs=None, as_iterable=False):
    """Builds the prediction graph and runs inference from the latest checkpoint.

    Args:
      input_fn: Passed to `_get_features_from_input_fn` to obtain features.
      feed_fn: Forwarded unchanged to the inference helper.
      outputs: Optional collection of prediction keys to keep; when falsy,
        every prediction is run.
      as_iterable: If true, delegates to `_infer_model_as_iterable`,
        otherwise to `_infer_model_single`.

    Returns:
      Whatever the selected inference helper returns.

    Raises:
      NotFittedError: If no checkpoint is found in the model directory.
      ValueError: If filtering by `outputs` leaves no predictions to run.
    """
    # Check that model has been trained.
    latest_ckpt = saver.latest_checkpoint(self._model_dir)
    if not latest_ckpt:
      raise NotFittedError("Couldn't find trained model at %s."
                           % self._model_dir)

    with ops.Graph().as_default() as graph:
      random_seed.set_random_seed(self._config.tf_random_seed)
      contrib_framework.create_global_step(graph)
      features = self._get_features_from_input_fn(input_fn)
      predictions = self._get_predict_ops(features)

      # A single (non-dict) output is wrapped into a dict; the flag is passed
      # along so the helper can tell the two cases apart.
      return_dict = isinstance(predictions, dict)
      if not return_dict:
        predictions = {'predictions': predictions}

      # Restrict to the requested outputs, if any were given.
      if outputs:
        available_keys = predictions.keys()
        predictions = {
            k: v for k, v in predictions.items() if k in outputs
        }
        if not predictions:
          raise ValueError('Expected to run at least one output from %s, '
                           'provided %s.' % (available_keys, outputs))

      run_inference = (self._infer_model_as_iterable if as_iterable
                       else self._infer_model_single)
      return run_inference(latest_ckpt, predictions, feed_fn, return_dict)
开发者ID:Nishant23,项目名称:tensorflow,代码行数:35,代码来源:estimator.py

示例12: _infer_model

  def _infer_model(self,
                   x=None, input_fn=None, feed_fn=None,
                   batch_size=None, axis=None, proba=False):
    """Runs predictions from the latest checkpoint in the model directory.

    Args:
      x: Optional raw input; when given, `input_fn` and `feed_fn` are replaced
        by the pair returned from `_get_predict_input_fn(x, batch_size)`.
      input_fn: Zero-argument callable returning `(features, <ignored>)`.
      feed_fn: Optional callable producing feed dicts. Inference is repeated
        until it raises `StopIteration` or returns `None`.
      batch_size: Batch size used when `x` is given; `None` is mapped to -1.
      axis: Not used by this method.
      proba: Not used by this method.

    Returns:
      The result of a single `infer` call when there is no `feed_fn`;
      otherwise a dict mapping each output key to the per-feed results
      concatenated along axis 0.
    """
    if batch_size is None:
      batch_size = -1
    if x is not None:
      input_fn, feed_fn = _get_predict_input_fn(x, batch_size)

    latest_ckpt = saver.latest_checkpoint(self._model_dir)
    with ops.Graph().as_default() as graph:
      random_seed.set_random_seed(self._config.tf_random_seed)
      contrib_framework.create_global_step(graph)
      features, _ = input_fn()
      predictions = self._get_predict_ops(features)
      if not isinstance(predictions, dict):
        predictions = {'predictions': predictions}
      # TODO(ipolosukhin): Support batching
      if feed_fn is None:
        return infer(latest_ckpt, predictions)

      # Collect one result chunk per feed dict, grouped by output key.
      chunks_by_key = {}
      while True:
        try:
          feed_dict = feed_fn()
        except StopIteration:
          break
        if feed_dict is None:
          break
        batch_outputs = infer(latest_ckpt, predictions, feed_dict=feed_dict)
        for key in batch_outputs:
          chunks_by_key.setdefault(key, []).append(batch_outputs[key])
      return {
          key: np.concatenate(chunks, axis=0)
          for key, chunks in chunks_by_key.items()
      }
开发者ID:01-,项目名称:tensorflow,代码行数:35,代码来源:estimator.py

示例13: testUsageGraph

 def testUsageGraph(self):
   """Expected usage when graph building."""
   # Runs three train-and-save rounds against the same checkpoint directory,
   # each in a fresh graph, restoring from the latest checkpoint every time.
   with context.graph_mode():
     num_training_steps = 10
     checkpoint_directory = self.get_temp_dir()
     checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
     for training_continuation in range(3):
       with ops.Graph().as_default():
         network = MyNetwork()
         optimizer = adam.AdamOptimizer(0.001)
         # The Checkpoint object groups everything that is saved/restored:
         # optimizer, network and the global step.
         root = checkpointable_utils.Checkpoint(
             optimizer=optimizer, network=network,
             global_step=training_util.get_or_create_global_step())
         input_value = constant_op.constant([[3.]])
         train_op = optimizer.minimize(
             network(input_value),
             global_step=root.global_step)
         # Looked up after the graph is built and before the session starts.
         checkpoint_path = core_saver.latest_checkpoint(checkpoint_directory)
         with self.test_session(graph=ops.get_default_graph()) as session:
           status = root.restore(save_path=checkpoint_path)
           status.initialize_or_restore(session=session)
           if checkpoint_path is None:
             # No checkpoint is only acceptable on the very first round, and
             # in that case assert_consumed() is expected to fail.
             self.assertEqual(0, training_continuation)
             with self.assertRaises(AssertionError):
               status.assert_consumed()
           else:
             status.assert_consumed()
           for _ in range(num_training_steps):
             session.run(train_op)
           root.save(file_prefix=checkpoint_prefix, session=session)
           # The global step accumulates across rounds, so it must have been
           # restored rather than re-initialized.
           self.assertEqual((training_continuation + 1) * num_training_steps,
                            session.run(root.global_step))
           # One save per round so far.
           self.assertEqual(training_continuation + 1,
                            session.run(root.save_counter))
开发者ID:neuroradiology,项目名称:tensorflow,代码行数:34,代码来源:checkpointable_utils_test.py

示例14: create_session

 def create_session(self, checkpoint_dir):
     """Builds a MonitoredSession for this predictor.

     Args:
         checkpoint_dir: Directory searched for the latest checkpoint, whose
             path is handed to the session creator.

     Returns:
         A `training.MonitoredSession` created through a
         `ChiefSessionCreator` configured with `self._session_config()`.
     """
     latest_ckpt = saver.latest_checkpoint(checkpoint_dir)
     creator = training.ChiefSessionCreator(
         checkpoint_filename_with_path=latest_ckpt,
         config=self._session_config())
     return training.MonitoredSession(session_creator=creator)
开发者ID:ucam-smt,项目名称:sgnmt,代码行数:7,代码来源:tf_nizza.py

示例15: testDeferredRestorationUsageEager

 def testDeferredRestorationUsageEager(self):
   """An idiomatic eager execution example."""
   # Three train-and-save rounds share one checkpoint directory; each round
   # restores the latest checkpoint before training further.
   steps_per_round = 10
   ckpt_dir = self.get_temp_dir()
   ckpt_prefix = os.path.join(ckpt_dir, "ckpt")
   object_graph = None  # Will be saved with the checkpoint eventually.
   for completed_rounds in range(1, 4):
     with ops.Graph().as_default():
       network = MyNetwork()
       optimizer = CheckpointableAdam(0.001)
       root = Root(optimizer=optimizer, network=network)
       checkpointable.restore(
           save_path=core_saver.latest_checkpoint(ckpt_dir),
           root_checkpointable=root,
           object_graph_proto=object_graph)
       for _ in range(steps_per_round):
         # TODO(allenl): Use a Dataset and serialize/checkpoint it.
         input_value = constant_op.constant([[3.]])
         optimizer.minimize(
             lambda: network(input_value),  # pylint: disable=cell-var-from-loop
             global_step=root.global_step)
       # Keep the object graph from this save for the next round's restore.
       object_graph, _ = checkpointable.save(
           file_prefix=ckpt_prefix,
           root_checkpointable=root)
       # The global step must keep accumulating across rounds.
       self.assertEqual(completed_rounds * steps_per_round,
                        root.global_step.numpy())
开发者ID:japrogramer,项目名称:tensorflow,代码行数:26,代码来源:checkpointable_test.py


注:本文中的tensorflow.python.training.saver.latest_checkpoint函数示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。