当前位置: 首页>>代码示例>>Python>>正文


Python variables.get_or_create_global_step函数代码示例

本文整理汇总了Python中tensorflow.contrib.framework.python.ops.variables.get_or_create_global_step函数的典型用法代码示例。如果您正苦于以下问题：Python get_or_create_global_step函数的具体用法？Python get_or_create_global_step怎么用？Python get_or_create_global_step使用的例子？那么恭喜您，这里精选的函数代码示例或许可以为您提供帮助。


在下文中一共展示了get_or_create_global_step函数的15个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: testEvaluationLoopTimeout

  def testEvaluationLoopTimeout(self):
    # Nothing writes a checkpoint into this directory here, so the evaluation
    # loop is expected to keep waiting until `timeout` expires.
    ckpt_dir = os.path.join(self.get_temp_dir(), 'evaluation_loop_timeout')
    if not gfile.Exists(ckpt_dir):
      gfile.MakeDirs(ckpt_dir)

    # The saver needs at least one variable to try to restore.
    variables.get_or_create_global_step()

    # A bare placeholder: evaluating it without a feed_dict would fail, so the
    # test only passes if no evaluation is ever attempted.
    unrunnable_op = array_ops.placeholder(dtype=dtypes.float32)

    started_at = time.time()
    results = evaluation.evaluate_repeatedly(
        checkpoint_dir=ckpt_dir,
        eval_ops=unrunnable_op,
        hooks=[evaluation.StopAfterNEvalsHook(10)],
        timeout=6)
    elapsed = time.time() - started_at
    self.assertFalse(results)

    # The loop must have waited for (roughly) the whole timeout, allowing for
    # the polling sleep interval...
    self.assertGreater(elapsed, 5.0)
    # ...and then stopped once the timeout kicked in.
    self.assertLess(elapsed, 7)
开发者ID:Immexxx,项目名称:tensorflow,代码行数:28,代码来源:evaluation_test.py

示例2: testEvaluateWithEvalFeedDict

  def testEvaluateWithEvalFeedDict(self):
    # Create a checkpoint for evaluate_repeatedly to pick up.
    checkpoint_dir = os.path.join(self.get_temp_dir(),
                                  'evaluate_with_eval_feed_dict')
    self._train_model(checkpoint_dir, num_steps=1)

    # We need a variable that the saver will try to restore.
    variables.get_or_create_global_step()

    # Create a local variable and an eval op that adds a fed value to it.
    my_var = variables.local_variable(0.0, name='my_var')
    increment = array_ops.placeholder(dtype=dtypes.float32)
    eval_ops = state_ops.assign_add(my_var, increment)

    increment_value = 3
    num_evals = 5
    expected_value = increment_value * num_evals
    final_values = evaluation.evaluate_repeatedly(
        checkpoint_dir=checkpoint_dir,
        eval_ops=eval_ops,
        # Feed the same constant the expectation is derived from, so the fed
        # value and `expected_value` cannot drift apart.
        feed_dict={increment: increment_value},
        final_ops={'my_var': array_ops.identity(my_var)},
        hooks=[evaluation.StopAfterNEvalsHook(num_evals)],
        max_number_of_evaluations=1)
    self.assertEqual(final_values['my_var'], expected_value)
开发者ID:Immexxx,项目名称:tensorflow,代码行数:25,代码来源:evaluation_test.py

示例3: test_two_listeners_with_default_saver

  def test_two_listeners_with_default_saver(self):
    # Each listener should see one begin/end and one save per training step.
    expected_counts = {
        'begin': 1,
        'before_save': 2,
        'after_save': 2,
        'end': 1
    }

    with ops.Graph().as_default():
      step = variables.get_or_create_global_step()
      increment_op = state_ops.assign_add(step, 1)
      first_listener = MockCheckpointSaverListener()
      second_listener = MockCheckpointSaverListener()
      saver_hook = basic_session_run_hooks.CheckpointSaverHook(
          self.model_dir,
          save_steps=1,
          listeners=[first_listener, second_listener])
      with monitored_session.SingularMonitoredSession(
          hooks=[saver_hook],
          checkpoint_dir=self.model_dir) as sess:
        for _ in range(2):
          sess.run(increment_op)
        final_step = sess.run(step)
      first_counts = first_listener.get_counts()
      second_counts = second_listener.get_counts()
    self.assertEqual(2, final_step)
    self.assertEqual(expected_counts, first_counts)
    self.assertEqual(first_counts, second_counts)

    # Reload in a fresh graph to confirm the step value was checkpointed.
    with ops.Graph().as_default():
      restored_step = variables.get_or_create_global_step()
      with monitored_session.SingularMonitoredSession(
          checkpoint_dir=self.model_dir) as restore_sess:
        restored_value = restore_sess.run(restored_step)
    self.assertEqual(2, restored_value)
开发者ID:AutumnQYN,项目名称:tensorflow,代码行数:33,代码来源:basic_session_run_hooks_test.py

示例4: test_step_counter_every_n_secs

  def test_step_counter_every_n_secs(self):
    with ops.Graph().as_default() as graph, session_lib.Session() as sess:
      variables.get_or_create_global_step()
      train_op = training_util._increment_global_step(1)
      writer = fake_summary_writer.FakeSummaryWriter(self.log_dir, graph)
      hook = basic_session_run_hooks.StepCounterHook(
          summary_writer=writer, every_n_steps=None, every_n_secs=0.1)

      hook.begin()
      sess.run(variables_lib.global_variables_initializer())
      hooked_sess = monitored_session._HookedSession(sess, [hook])
      # Three training steps, separated by sleeps longer than every_n_secs;
      # the assertions below expect summaries at steps 2 and 3.
      for step_index in range(3):
        if step_index:
          time.sleep(0.2)
        hooked_sess.run(train_op)
      hook.end(sess)

      writer.assert_summaries(
          test_case=self,
          expected_logdir=self.log_dir,
          expected_graph=graph,
          expected_summaries={})
      self.assertTrue(writer.summaries, 'No summaries were created.')
      self.assertItemsEqual([2, 3], writer.summaries.keys())
      for step_summaries in writer.summaries.values():
        value = step_summaries[0].value[0]
        self.assertEqual('global_step/sec', value.tag)
        self.assertGreater(value.simple_value, 0)
开发者ID:Mazecreator,项目名称:tensorflow,代码行数:29,代码来源:basic_session_run_hooks_test.py

示例5: test_step_counter_every_n_steps

 def test_step_counter_every_n_steps(self):
   # With every_n_steps=10 over 30 steps, summaries are expected at these steps.
   expected_steps = [11, 21]
   with ops.Graph().as_default() as graph, session_lib.Session() as sess:
     variables.get_or_create_global_step()
     train_op = training_util._increment_global_step(1)
     writer = fake_summary_writer.FakeSummaryWriter(self.log_dir, graph)
     hook = basic_session_run_hooks.StepCounterHook(
         summary_writer=writer, every_n_steps=10)
     hook.begin()
     sess.run(variables_lib.global_variables_initializer())
     hooked_sess = monitored_session._HookedSession(sess, [hook])
     with test.mock.patch.object(tf_logging, 'warning') as mock_warning:
       for _ in range(30):
         time.sleep(0.01)
         hooked_sess.run(train_op)
       # The hook must not have logged any warning while stepping.
       self.assertIsNone(mock_warning.call_args)
     hook.end(sess)
     writer.assert_summaries(
         test_case=self,
         expected_logdir=self.log_dir,
         expected_graph=graph,
         expected_summaries={})
     self.assertItemsEqual(expected_steps, writer.summaries.keys())
     for step in expected_steps:
       value = writer.summaries[step][0].value[0]
       self.assertEqual('global_step/sec', value.tag)
       self.assertGreater(value.simple_value, 0)
开发者ID:moses-sun,项目名称:tensorflow,代码行数:27,代码来源:basic_session_run_hooks_test.py

示例6: test_not_wait_for_step_zero

 def test_not_wait_for_step_zero(self):
   with ops.Graph().as_default():
     variables.get_or_create_global_step()
     waiter = basic_session_run_hooks.GlobalStepWaiterHook(wait_until_step=0)
     waiter.begin()
     with session_lib.Session() as sess:
       run_context = session_run_hook.SessionRunContext(
           original_args=None, session=sess)
       # With wait_until_step=0, before_run should return without waiting for
       # the global step to be incremented; the test passes by not hanging.
       waiter.before_run(run_context)
开发者ID:AutumnQYN,项目名称:tensorflow,代码行数:10,代码来源:basic_session_run_hooks_test.py

示例7: setUp

  def setUp(self):
    test.TestCase.setUp(self)

    self.log_dir = 'log/dir'
    self.summary_writer = fake_summary_writer.FakeSummaryWriter(self.log_dir)

    # A variable bumped by assign_add, plus a doubled copy of the result, each
    # exported through its own scalar summary.
    counter = variables_lib.Variable(0.0)
    incremented = state_ops.assign_add(counter, 1.0)
    doubled = incremented * 2
    self.summary_op = summary_lib.scalar('my_summary', incremented)
    self.summary_op2 = summary_lib.scalar('my_summary2', doubled)

    # Create the global step and a train op that increments it by 1.
    variables.get_or_create_global_step()
    self.train_op = training_util._increment_global_step(1)
开发者ID:Mazecreator,项目名称:tensorflow,代码行数:14,代码来源:basic_session_run_hooks_test.py

示例8: test_recover_and_retry_on_aborted_error

 def test_recover_and_retry_on_aborted_error(self):
   # An AbortedError raised mid-session should be retried silently; the
   # CheckpointSaverHook provides a checkpoint to recover from.
   logdir = _test_dir(self.get_temp_dir(),
                      'test_recover_and_retry_on_aborted_error')
   with ops.Graph().as_default():
     step = variables_lib.get_or_create_global_step()
     increment = state_ops.assign_add(step, 1)
     scaffold = monitored_session.Scaffold()
     abort_once = RaiseOnceAtCountN(
         4, errors_impl.AbortedError(None, None, 'Abort'))
     # Checkpoint after every step so there is always something to restore.
     saver_hook = basic_session_run_hooks.CheckpointSaverHook(
         logdir, save_steps=1, scaffold=scaffold)
     all_hooks = [abort_once, saver_hook]
     creator = monitored_session.ChiefSessionCreator(
         scaffold, checkpoint_dir=logdir)
     with monitored_session.MonitoredSession(
         session_creator=creator, hooks=all_hooks) as session:
       self.assertEqual(0, session.run(step))
       for expected in (1, 2):
         self.assertEqual(expected, session.run(increment))
       self.assertFalse(session.should_stop())
       # At step 3 the hook raises AbortedError once; MonitoredSession
       # restores and retries automatically.
       self.assertEqual(3, session.run(increment))
       self.assertTrue(abort_once.raised)
       self.assertFalse(session.should_stop())
       self.assertEqual(4, session.run(increment))
       self.assertFalse(session.should_stop())
开发者ID:kadeng,项目名称:tensorflow,代码行数:30,代码来源:monitored_session_test.py

示例9: test_recovery

 def test_recovery(self):
   logdir = _test_dir(self.get_temp_dir(), 'test_recovery')
   with ops.Graph().as_default():
     gstep = variables_lib.get_or_create_global_step()
     do_step = state_ops.assign_add(gstep, 1)
     scaffold = monitored_session.Scaffold()
     # Use a hook to save the model after every step (save_steps=1), so the
     # restarts below have a step-2 checkpoint to recover from.
     hooks = [
         basic_session_run_hooks.CheckpointSaverHook(
             logdir, save_steps=1, scaffold=scaffold)
     ]
     with monitored_session.MonitoredSession(
         session_creator=monitored_session.ChiefSessionCreator(
             scaffold, checkpoint_dir=logdir),
         hooks=hooks) as session:
       self.assertEqual(0, session.run(gstep))
       self.assertEqual(1, session.run(do_step))
       self.assertEqual(2, session.run(do_step))
     # A restart pointed at the checkpoint directory finds the checkpoint and
     # recovers automatically.
     with monitored_session.MonitoredSession(
         session_creator=monitored_session.ChiefSessionCreator(
             scaffold, checkpoint_dir=logdir)) as session:
       self.assertEqual(2, session.run(gstep))
     # A restart given an explicit checkpoint file path (the latest one in
     # logdir) recovers as well.
     with monitored_session.MonitoredSession(
         session_creator=monitored_session.ChiefSessionCreator(
             scaffold,
             checkpoint_filename_with_path=saver_lib.latest_checkpoint(
                 logdir))) as session:
       self.assertEqual(2, session.run(gstep))
开发者ID:kadeng,项目名称:tensorflow,代码行数:31,代码来源:monitored_session_test.py

示例10: test_num_steps

  def test_num_steps(self):
    logdir = _test_dir(self.get_temp_dir(), 'test_num_steps')
    with ops.Graph().as_default():
      step = variables_lib.get_or_create_global_step()
      increment = state_ops.assign_add(step, 1)
      # Phase 1: run exactly 3 steps, then save a checkpoint by hand.
      first_hooks = [basic_session_run_hooks.StopAtStepHook(num_steps=3)]
      scaffold = monitored_session.Scaffold().finalize()
      with monitored_session.MonitoredSession(hooks=first_hooks) as session:
        for _ in range(2):
          session.run(increment)
          self.assertFalse(session.should_stop())
        session.run(increment)
        self.assertTrue(session.should_stop())
        save_path = scaffold.saver.save(session._coordinated_creator.tf_sess,
                                        os.path.join(logdir, 'step-3'))

      # Phase 2: restore that checkpoint and run 4 more steps. The first run
      # yields step 4 and the session should stop after the fourth run.
      def restore_from_save_path(restore_scaffold, sess):
        restore_scaffold.saver.restore(sess, save_path)

      resume_creator = monitored_session.ChiefSessionCreator(
          scaffold=monitored_session.Scaffold(init_fn=restore_from_save_path))
      second_hooks = [basic_session_run_hooks.StopAtStepHook(num_steps=4)]
      with monitored_session.MonitoredSession(
          hooks=second_hooks, session_creator=resume_creator) as session:
        self.assertEqual(4, session.run(increment))
        self.assertFalse(session.should_stop())
        for _ in range(2):
          session.run(increment)
          self.assertFalse(session.should_stop())
        session.run(increment)
        self.assertTrue(session.should_stop())
开发者ID:kadeng,项目名称:tensorflow,代码行数:34,代码来源:monitored_session_test.py

示例11: __init__

    def __init__(
        self,
        global_step_tensor=None,
        init_op=None,
        init_feed_dict=None,
        init_fn=None,
        ready_op=None,
        local_init_op=None,
        summary_op=None,
        saver=None,
        keep_checkpoint_max=5,
    ):
        """Create a scaffold.

        Arguments left as `None` are resolved through
        `Scaffold._get_or_default`, keyed by the matching `ops.GraphKeys`
        collection, with the listed factory as the presumed fallback. The
        resolved values are stored as `global_step_tensor`, `init_op`,
        `init_feed_dict`, `init_fn`, `ready_op`, `local_init_op`,
        `summary_op` and `saver`.

        Args:
          global_step_tensor: Optional tensor to use as the global step counter.
          init_op: Optional op for initializing variables.
          init_feed_dict: Optional session feed dictionary to use when running the
            init_op.
          init_fn: Optional function to use to initialize the model after running
            the init_op.  Will be called as `init_fn(scaffold, session)`.
          ready_op: Optional op to verify that the variables are initialized.  Must
            return an empty scalar string tensor when the variables are
            initialized, or a non-empty one listing the names of the
            non-initialized variables.
          local_init_op: Optional op to initialize local variables.
          summary_op: Optional op to gather all summaries.  Must return a scalar
            string tensor containing a serialized `Summary` proto.
          saver: Optional `tf.Saver` object to use to save and restore variables.
          keep_checkpoint_max: Optional parameter to use to construct a saver if
            none is already there in the graph.
        """
        if global_step_tensor is None:
            global_step_tensor = contrib_variables.get_or_create_global_step()
        self.global_step_tensor = global_step_tensor
        if init_op is None:
            init_op = Scaffold._get_or_default(ops.GraphKeys.INIT_OP, variables.initialize_all_variables)
        self.init_op = init_op
        self.init_feed_dict = init_feed_dict
        # NOTE(touts): modifying the init function to be passed the scaffold is a
        # hack to make it easy to find the saver.  Is there a better way?
        if init_fn:
            self.init_fn = lambda sess: init_fn(self, sess)
        else:
            self.init_fn = None
        if ready_op is None:
            ready_op = Scaffold._get_or_default(ops.GraphKeys.READY_OP, variables.report_uninitialized_variables)
        self.ready_op = ready_op
        if local_init_op is None:
            local_init_op = Scaffold._get_or_default(ops.GraphKeys.LOCAL_INIT_OP, Scaffold._default_local_init_op)
        self.local_init_op = local_init_op
        if summary_op is None:
            summary_op = Scaffold._get_or_default(ops.GraphKeys.SUMMARY_OP, logging_ops.merge_all_summaries)
        # Store the resolved summary op like every other component; without this
        # the caller's `summary_op` (or its default) would be silently discarded.
        self.summary_op = summary_op
        # pylint: disable=g-long-lambda
        if saver is None:
            saver = Scaffold._get_or_default(
                ops.GraphKeys.SAVERS, lambda: training_saver.Saver(sharded=True, max_to_keep=keep_checkpoint_max)
            )
        # pylint: enable=g-long-lambda
        self.saver = saver
开发者ID:sathishreddy,项目名称:tensorflow,代码行数:60,代码来源:supervised_session.py

示例12: _get_train_ops

  def _get_train_ops(self, features, targets):
    """See base class.

    Only an `SDCAOptimizer` takes the custom path below; any other optimizer
    defers to the parent implementation.

    Returns:
      A `(train_op, loss)` tuple.
    """
    uses_sdca = isinstance(self._linear_optimizer, sdca_optimizer.SDCAOptimizer)
    if not uses_sdca:
      return super(LinearRegressor, self)._get_train_ops(features, targets)
    # NOTE(review): this is an `assert`, so the check disappears under -O.
    assert not self._joint_weights, ("_joint_weights is incompatible with"
                                     " SDCAOptimizer.")
    step = contrib_variables.get_or_create_global_step()

    logits, column_vars, bias_var = (
        layers.weighted_sum_from_feature_columns(
            columns_to_tensors=features,
            feature_columns=self._linear_feature_columns,
            num_outputs=self._target_column.num_label_columns,
            weight_collections=[self._linear_model.get_scope_name()],
            scope=self._linear_model.get_scope_name()))
    # The loss, its summary and the bias-column bookkeeping are all built
    # under a control dependency on the centered-bias op.
    with ops.control_dependencies([self._centered_bias()]):
      loss = self._target_column.loss(logits, targets, features)
      logging_ops.scalar_summary("loss", loss)
      _add_bias_column(self._linear_feature_columns, features, bias_var,
                       targets, column_vars)

    sdca_train_op = self._linear_optimizer.get_train_step(
        column_vars, self._target_column.weight_column_name,
        self._loss_type(), features, targets, step)
    return sdca_train_op, loss
开发者ID:KalraA,项目名称:tensorflow,代码行数:26,代码来源:linear.py

示例13: __init__

  def __init__(self,
               log_dir=None,
               summary_writer=None,
               summary_op=None,
               feed_dict=None):
    """Constructs the Summary Hook.

    Args:
      log_dir: The directory where the summary events are saved to.  Used only
        when `summary_writer` is not specified.
      summary_writer: A `tf.summary.FileWriter` to write summary events with.
      summary_op: The summary op to run. If left as `None`, then all summaries
        in the tf.GraphKeys.SUMMARIES collection are used.
      feed_dict: An optional feed dictionary to use when evaluating the
        summaries.

    Raises:
      ValueError: If both `log_dir` and `summary_writer` are `None`.
    """
    # Validate first so nothing is stored on a hook that cannot be used.
    if log_dir is None and summary_writer is None:
      raise ValueError('One of log_dir or summary_writer should be used.')
    self._summary_op = summary_op
    self._feed_dict = feed_dict
    # Assigned once; the original code stored this attribute twice.
    self._summary_writer = summary_writer
    self._log_dir = log_dir
    self._global_step = variables.get_or_create_global_step()
开发者ID:LUTAN,项目名称:tensorflow,代码行数:27,代码来源:evaluation.py

示例14: finalize

  def finalize(self):
    """Creates operations if needed and finalizes the graph.

    The global step is created first if unset. Each remaining component that
    is still `None` is then resolved through `Scaffold._get_or_default`, in a
    fixed order: init op, ready op, local init op, summary op, saver. Finally
    the default graph is finalized.
    """
    if self._global_step_tensor is None:
      self._global_step_tensor = contrib_variables.get_or_create_global_step()

    # Entries are (attribute, name for _get_or_default, collection key,
    # factory). They are processed in order, which fixes op creation order.
    # pylint: disable=g-long-lambda
    component_specs = (
        ('_init_op', 'init_op', ops.GraphKeys.INIT_OP,
         variables.initialize_all_variables),
        ('_ready_op', 'ready_op', ops.GraphKeys.READY_OP,
         variables.report_uninitialized_variables),
        ('_local_init_op', 'local_init_op', ops.GraphKeys.LOCAL_INIT_OP,
         Scaffold._default_local_init_op),
        ('_summary_op', 'summary_op', ops.GraphKeys.SUMMARY_OP,
         logging_ops.merge_all_summaries),
        ('_saver', 'saver', ops.GraphKeys.SAVERS,
         lambda: training_saver.Saver(
             sharded=True, max_to_keep=self._keep_checkpoint_max)),
    )
    # pylint: enable=g-long-lambda
    for attr_name, arg_name, collection_key, factory in component_specs:
      if getattr(self, attr_name) is None:
        setattr(self, attr_name,
                Scaffold._get_or_default(arg_name, collection_key, factory))

    ops.get_default_graph().finalize()
开发者ID:10imaging,项目名称:tensorflow,代码行数:29,代码来源:supervised_session.py

示例15: testNoneGlobalStep

  def testNoneGlobalStep(self):
    with ops.Graph().as_default():
      random_seed.set_random_seed(0)
      inputs = constant_op.constant(self._inputs, dtype=dtypes.float32)
      labels = constant_op.constant(self._labels, dtype=dtypes.float32)

      predictions = batchnorm_classifier(inputs)
      loss_ops.log_loss(predictions, labels)
      total_loss = loss_ops.get_total_loss()
      optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=1.0)

      # The train op is built explicitly without a global step...
      train_op = training.create_train_op(
          total_loss, optimizer, global_step=None)

      # ...and the global step is only created afterwards.
      global_step_tensor = variables_lib.get_or_create_global_step()

      with session_lib.Session() as sess:
        sess.run(variables_lib2.global_variables_initializer())

        for _ in range(10):
          sess.run([train_op])
        step_value = global_step_tensor.eval()
        # train_op doesn't use the global step, so it must still be 0.
        self.assertAllClose(step_value, 0)
开发者ID:Jackhuang945,项目名称:tensorflow,代码行数:25,代码来源:training_test.py


注：本文中的tensorflow.contrib.framework.python.ops.variables.get_or_create_global_step函数示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台，相关代码片段筛选自各路编程大神贡献的开源项目，源码版权归原作者所有，传播和使用请参考对应项目的License；未经允许，请勿转载。