当前位置: 首页>>代码示例>>Python>>正文


Python training_util._get_or_create_global_step_read函数代码示例

本文整理汇总了 Python 中 tensorflow.python.training.training_util._get_or_create_global_step_read 函数的典型用法代码示例。如果您正苦于以下问题：_get_or_create_global_step_read 函数的具体用法是什么？_get_or_create_global_step_read 怎么用？有哪些 _get_or_create_global_step_read 的使用例子？那么恭喜您，这里精选的函数代码示例或许可以为您提供帮助。


下文一共展示了 _get_or_create_global_step_read 函数的 7 个代码示例，这些示例默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的 Python 代码示例。

示例1: begin

 def begin(self):
   """Sets up the hook's summary writer, step marker and global step read.

   A summary writer is fetched from `SummaryWriterCache` only when the hook
   does not have one yet and an output directory is configured.

   Raises:
     RuntimeError: If no global step read tensor could be obtained.
   """
   writer_missing = self._summary_writer is None
   if writer_missing and self._output_dir:
     self._summary_writer = SummaryWriterCache.get(self._output_dir)

   self._next_step = None

   step_read = training_util._get_or_create_global_step_read()  # pylint: disable=protected-access
   self._global_step_tensor = step_read
   if step_read is None:
     raise RuntimeError(
         "Global step should be created to use SummarySaverHook.")
开发者ID:didukhle,项目名称:tensorflow,代码行数:8,代码来源:basic_session_run_hooks.py

示例2: begin

 def begin(self):
   """Fetches the summary writer and global step read, then starts listeners.

   Raises:
     RuntimeError: If no global step read tensor could be obtained.
   """
   self._summary_writer = SummaryWriterCache.get(self._checkpoint_dir)

   step_read = training_util._get_or_create_global_step_read()  # pylint: disable=protected-access
   self._global_step_tensor = step_read
   if step_read is None:
     raise RuntimeError(
         "Global step should be created to use CheckpointSaverHook.")

   # Listeners are only notified once the global step read is known to exist.
   for listener in self._listeners:
     listener.begin()
开发者ID:becster,项目名称:tensorflow,代码行数:8,代码来源:async_checkpoint.py

示例3: test_reads_before_increments

 def test_reads_before_increments(self):
   """Checks that each run reads the step value from before its increment."""
   with ops.Graph().as_default():
     training_util.create_global_step()
     step_read = training_util._get_or_create_global_step_read()
     add_one = training_util._increment_global_step(1)
     add_three = training_util._increment_global_step(3)
     # Pairs of (expected read value, increment op run alongside the read).
     scenarios = ((0, add_one), (1, add_three))
     with monitored_session.MonitoredTrainingSession() as sess:
       for expected, increment in scenarios:
         observed, _ = sess.run([step_read, increment])
         self.assertEqual(expected, observed)
       # A read with no increment sees the accumulated total of 1 + 3.
       self.assertEqual(4, sess.run(step_read))
开发者ID:aeverall,项目名称:tensorflow,代码行数:13,代码来源:training_util_test.py

示例4: _train_model

  def _train_model(self, input_fn, hooks, saving_listeners):
    """Builds the training graph and runs the train op until the session stops.

    Args:
      input_fn: Forwarded to `_get_features_and_labels_from_input_fn` to
        obtain the features and labels in TRAIN mode.
      hooks: Iterable of hooks appended to the worker hooks for the session.
      saving_listeners: Listeners appended to the first `CheckpointSaverHook`
        that is found or created; skipped when falsy.

    Returns:
      The loss from the last `mon_sess.run` call, or None if the training
      loop body never executed.

    Raises:
      ValueError: If `saving_listeners` is given but no `CheckpointSaverHook`
        exists to attach them to.
    """
    worker_hooks = []
    with ops.Graph().as_default() as g, g.device(self._device_fn):
      random_seed.set_random_seed(self._config.tf_random_seed)
      global_step_tensor = self._create_and_assert_global_step(g)
      global_step_read_tensor = training_util._get_or_create_global_step_read()  # pylint: disable=protected-access
      # The input pipeline ops are created under a control dependency on the
      # global step read.
      # NOTE(review): presumably this ties input ops to the pre-increment step
      # value -- confirm against training_util.
      with ops.control_dependencies([global_step_read_tensor]):
        features, labels = self._get_features_and_labels_from_input_fn(
            input_fn, model_fn_lib.ModeKeys.TRAIN)
      estimator_spec = self._call_model_fn(
          features, labels, model_fn_lib.ModeKeys.TRAIN, self.config)
      # Check if the user created a loss summary, and add one if they didn't.
      # We assume here that the summary is called 'loss'. If it is not, we will
      # make another one with the name 'loss' to ensure it shows up in the right
      # graph in TensorBoard.
      if not any([x.op.name == 'loss'
                  for x in ops.get_collection(ops.GraphKeys.SUMMARIES)]):
        summary.scalar('loss', estimator_spec.loss)
      ops.add_to_collection(ops.GraphKeys.LOSSES, estimator_spec.loss)
      # Worker hooks are assembled in order: caller-supplied hooks, then a
      # NanTensorHook on the loss and a LoggingTensorHook for loss/step
      # (every 100 iterations), then the hooks from the estimator spec.
      worker_hooks.extend(hooks)
      worker_hooks.extend([
          training.NanTensorHook(estimator_spec.loss),
          training.LoggingTensorHook(
              {
                  'loss': estimator_spec.loss,
                  'step': global_step_tensor
              },
              every_n_iter=100)
      ])
      worker_hooks.extend(estimator_spec.training_hooks)

      # Register a sharded Saver in the SAVERS collection only when neither the
      # scaffold nor the collection already provides one.
      if not (estimator_spec.scaffold.saver or
              ops.get_collection(ops.GraphKeys.SAVERS)):
        ops.add_to_collection(
            ops.GraphKeys.SAVERS,
            training.Saver(
                sharded=True,
                max_to_keep=self._config.keep_checkpoint_max,
                keep_checkpoint_every_n_hours=(
                    self._config.keep_checkpoint_every_n_hours),
                defer_build=True,
                save_relative_paths=True))

      # Look for a CheckpointSaverHook already present among the worker hooks
      # and the spec's chief hooks.
      chief_hooks = []
      all_hooks = worker_hooks + list(estimator_spec.training_chief_hooks)
      saver_hooks = [
          h for h in all_hooks if isinstance(h, training.CheckpointSaverHook)]
      # When checkpointing is configured and no saver hook was supplied, create
      # one and pass it to the session as a chief-only hook.
      if (self._config.save_checkpoints_secs or
          self._config.save_checkpoints_steps):
        if not saver_hooks:
          chief_hooks = [
              training.CheckpointSaverHook(
                  self._model_dir,
                  save_secs=self._config.save_checkpoints_secs,
                  save_steps=self._config.save_checkpoints_steps,
                  scaffold=estimator_spec.scaffold)
          ]
          saver_hooks = [chief_hooks[0]]
      if saving_listeners:
        if not saver_hooks:
          raise ValueError(
              'There should be a CheckpointSaverHook to use saving_listeners. '
              'Please set one of the RunConfig.save_checkpoints_steps or '
              'RunConfig.save_checkpoints_secs.')
        else:
          # It is expected to have one CheckpointSaverHook. If multiple, we pick
          # up the first one to add listener.
          saver_hooks[0]._listeners.extend(saving_listeners)  # pylint: disable=protected-access
      with training.MonitoredTrainingSession(
          master=self._config.master,
          is_chief=self._config.is_chief,
          checkpoint_dir=self._model_dir,
          scaffold=estimator_spec.scaffold,
          hooks=worker_hooks,
          chief_only_hooks=(
              tuple(chief_hooks) + tuple(estimator_spec.training_chief_hooks)),
          save_checkpoint_secs=0,  # Saving is handled by a hook.
          save_summaries_steps=self._config.save_summary_steps,
          config=self._session_config,
          log_step_count_steps=self._config.log_step_count_steps) as mon_sess:
        # Keep running the train op until the session reports it should stop;
        # `loss` holds the value from the most recent step.
        loss = None
        while not mon_sess.should_stop():
          _, loss = mon_sess.run([estimator_spec.train_op, estimator_spec.loss])
      return loss
开发者ID:ilya-edrenkin,项目名称:tensorflow,代码行数:84,代码来源:estimator.py

示例5: begin

 def begin(self):
   """Stores the global step read tensor on the hook.

   Raises:
     RuntimeError: If no global step read tensor could be obtained.
   """
   global_step_read = training_util._get_or_create_global_step_read()  # pylint: disable=protected-access
   self._global_step_tensor = global_step_read
   if global_step_read is not None:
     return
   raise RuntimeError(
       'Global step should be created to use StopAtCheckpointStepHook.')
开发者ID:Jordan1237,项目名称:tensorflow,代码行数:5,代码来源:hooks.py

示例6: test_global_step_read_is_none_if_there_is_no_global_step

 def test_global_step_read_is_none_if_there_is_no_global_step(self):
   """The read is None before a global step is created and set afterwards."""
   with ops.Graph().as_default():
     read_before_creation = training_util._get_or_create_global_step_read()
     self.assertIsNone(read_before_creation)

     training_util.create_global_step()

     read_after_creation = training_util._get_or_create_global_step_read()
     self.assertIsNotNone(read_after_creation)
开发者ID:aeverall,项目名称:tensorflow,代码行数:5,代码来源:training_util_test.py

示例7: test_reads_from_cache

 def test_reads_from_cache(self):
   """Two consecutive lookups of the global step read compare equal."""
   with ops.Graph().as_default():
     training_util.create_global_step()
     # The generator is consumed left to right, so call order is unchanged.
     initial_read, repeated_read = (
         training_util._get_or_create_global_step_read() for _ in range(2))
     self.assertEqual(initial_read, repeated_read)
开发者ID:aeverall,项目名称:tensorflow,代码行数:6,代码来源:training_util_test.py


注：本文中的 tensorflow.python.training.training_util._get_or_create_global_step_read 函数示例由纯净天空整理自 Github/MSDocs 等开源代码及文档管理平台，相关代码片段筛选自各路编程大神贡献的开源项目，源码版权归原作者所有。传播和使用请参考对应项目的 License；未经允许，请勿转载。