当前位置: 首页>>代码示例>>Python>>正文


Python training_util.get_global_step函数代码示例

本文整理汇总了 Python 中 tensorflow.python.training.training_util.get_global_step 函数的典型用法代码示例。如果您正苦于以下问题:Python get_global_step 函数的具体用法?Python get_global_step 怎么用?Python get_global_step 使用的例子?那么恭喜您,这里精选的函数代码示例或许可以为您提供帮助。


在下文中一共展示了 get_global_step 函数的 15 个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的 Python 代码示例。

示例1: test_get_global_step

 def test_get_global_step(self):
   """Checks `get_global_step` before and after a global step variable exists.

   Inside a fresh default graph the lookup must first yield None. Once an
   int32 variable named `GLOBAL_STEP` has been created, both the implicit
   (default graph) and the explicit (graph argument) lookups are validated
   through `_assert_global_step`.
   """
   with ops.Graph().as_default() as graph:
     # Nothing has been created in this graph yet.
     self.assertIsNone(training_util.get_global_step())
     variables.VariableV1(
         0, trainable=False, dtype=dtypes.int32,
         name=ops.GraphKeys.GLOBAL_STEP)
     step_from_default = training_util.get_global_step()
     self._assert_global_step(step_from_default, expected_dtype=dtypes.int32)
   # After leaving the `with` block the graph is passed explicitly.
   step_from_graph = training_util.get_global_step(graph)
   self._assert_global_step(step_from_graph, expected_dtype=dtypes.int32)
开发者ID:aeverall,项目名称:tensorflow,代码行数:12,代码来源:training_util_test.py

示例2: _ModelFn

    def _ModelFn(features, labels, mode):
      """Builds the loss, accuracy metric and per-mode EstimatorSpec.

      Logits come from `self._BuildGraph` when `is_training` is truthy;
      otherwise a graph def is obtained from `self._GetGraphDef` and imported
      with `features` mapped onto the input node.

      Args:
        features: Input passed to the graph builder / mapped to the input node.
        labels: Labels used for the loss and the accuracy metric.
        mode: Compared against `ModeKeys.EVAL` and `ModeKeys.TRAIN`.

      Returns:
        An `EstimatorSpec` for EVAL or TRAIN mode; None for any other mode.
      """
      if not is_training:
        imported_graph_def = self._GetGraphDef(use_trt, batch_size, model_dir)
        imported_outputs = importer.import_graph_def(
            imported_graph_def,
            input_map={INPUT_NODE_NAME: features},
            return_elements=[OUTPUT_NODE_NAME + ':0'],
            name='')
        # Only one element was requested, so take the first entry.
        logits_out = imported_outputs[0]
      else:
        logits_out = self._BuildGraph(features)

      loss = losses.sparse_softmax_cross_entropy(
          labels=labels, logits=logits_out)
      summary.scalar('loss', loss)

      predicted_classes = math_ops.argmax(
          logits_out, axis=1, name='classes_out')
      accuracy = metrics.accuracy(
          labels=labels, predictions=predicted_classes, name='acc_op')
      # The second element of the metric pair is what gets summarised.
      summary.scalar('accuracy', accuracy[1])

      if mode == ModeKeys.EVAL:
        eval_metrics = {'accuracy': accuracy}
        return EstimatorSpec(mode, loss=loss, eval_metric_ops=eval_metrics)

      if mode == ModeKeys.TRAIN:
        adam = AdamOptimizer(learning_rate=1e-2)
        minimize_op = adam.minimize(loss, global_step=get_global_step())
        return EstimatorSpec(mode, loss=loss, train_op=minimize_op)
开发者ID:kylin9872,项目名称:tensorflow,代码行数:27,代码来源:quantization_mnist_test.py

示例3: _train_op_fn

  def _train_op_fn(loss):
    """Returns the op to optimize the loss.

    Args:
      loss: The loss handed to each optimizer's `minimize` call.

    Returns:
      The `assign_add` op that increments the global step by 1, created under
      a control dependency on the grouped optimizer ops.
    """
    train_ops = []
    global_step = training_util.get_global_step()
    # Each sub-model is trained only when its logits exist, and only over the
    # trainable variables collected under its own parent scope. Neither
    # `minimize` call receives `global_step`; the step is incremented
    # explicitly below instead.
    if dnn_logits is not None:
      train_ops.append(
          dnn_optimizer.minimize(
              loss,
              var_list=ops.get_collection(
                  ops.GraphKeys.TRAINABLE_VARIABLES,
                  scope=dnn_parent_scope)))
    if linear_logits is not None:
      train_ops.append(
          linear_optimizer.minimize(
              loss,
              var_list=ops.get_collection(
                  ops.GraphKeys.TRAINABLE_VARIABLES,
                  scope=linear_parent_scope)))

    train_op = control_flow_ops.group(*train_ops)
    # The increment is created under a control dependency on the grouped
    # minimize ops and is colocated with `global_step`.
    with ops.control_dependencies([train_op]):
      with ops.colocate_with(global_step):
        return state_ops.assign_add(global_step, 1)

    # NOTE(review): at this indentation the statement below is unreachable,
    # because the `with` block above always returns. It presumably belongs to
    # the enclosing function, one indentation level out -- confirm against
    # the original source file.
    return head.create_estimator_spec(
        features=features,
        mode=mode,
        labels=labels,
        train_op_fn=_train_op_fn,
        logits=logits)
开发者ID:m-colombo,项目名称:tensorflow,代码行数:30,代码来源:dnn_linear_combined.py

示例4: record_summaries_every_n_global_steps

def record_summaries_every_n_global_steps(n):
  """Sets the should_record_summaries Tensor to true if global_step % n == 0.

  This is a generator with a single `yield`. The decorator is not visible
  here; it is presumably wrapped as a context manager (TODO confirm). While
  suspended at the `yield`, the `_SHOULD_RECORD_SUMMARIES_NAME` collection
  holds exactly one entry: `global_step % n == 0`. The previous contents of
  the collection are put back afterwards.

  Args:
    n: Divisor applied to the global step.

  Yields:
    Nothing (a bare `yield`).
  """
  collection_ref = ops.get_collection_ref(_SHOULD_RECORD_SUMMARIES_NAME)
  # Snapshot the current contents so they can be restored on exit.
  old = collection_ref[:]
  try:
    collection_ref[:] = [training_util.get_global_step() % n == 0]
    yield
  finally:
    # Restore even when an exception is raised at the `yield`, so that a
    # failing body does not leave the temporary condition installed.
    collection_ref[:] = old
开发者ID:benoitsteiner,项目名称:tensorflow-opencl,代码行数:7,代码来源:summary_ops.py

示例5: before_run

 def before_run(self, run_context):
   """Requests the global step and the current loss for this run call.

   Args:
     run_context: Its session's graph is searched for the operation named
       `LOSS_NAME` when `self.loss_op` is None.

   Returns:
     A `SessionRunArgs` whose fetches are a dict with the keys
     'global_step' and 'current_loss'.
   """
   if self.loss_op is not None:
     loss_tensor = self.loss_op
   else:
     # No explicit loss op was configured: look the loss up by name and use
     # the first output of that operation.
     graph = run_context.session.graph
     loss_tensor = graph.get_operation_by_name(LOSS_NAME).outputs[0]

   fetches = {
       'global_step': training_util.get_global_step(),
       'current_loss': loss_tensor,
   }
   return session_run_hook.SessionRunArgs(fetches)
开发者ID:AnishShah,项目名称:tensorflow,代码行数:7,代码来源:random_forest.py

示例6: __init__

 def __init__(self):
   """Creates the op that increments the global step, if one exists.

   `self._global_step_incr_op` is set to the `.op` of an
   `assign_add(global_step, 1)` named "global_step_incr", or to None when
   `get_global_step()` returns None.
   """
   global_step = training_util.get_global_step()
   # Test for presence explicitly instead of relying on the truthiness of the
   # returned object: only "no global step" (None) should take the else
   # branch, and evaluating a graph tensor as a Python bool is not a
   # meaningful presence check.
   if global_step is not None:
     self._global_step_incr_op = state_ops.assign_add(
         global_step, 1, name="global_step_incr").op
   else:
     self._global_step_incr_op = None
开发者ID:Ajaycs99,项目名称:tensorflow,代码行数:7,代码来源:wals.py

示例7: begin

 def begin(self):
   """Resets the reporting state and looks up the global step tensor.

   Raises:
     RuntimeError: if `get_global_step()` returns None.
   """
   self._last_reported_time = self._last_reported_step = None
   step_tensor = training_util.get_global_step()
   self._global_step_tensor = step_tensor
   if step_tensor is None:
     raise RuntimeError(
         "Global step should be created to use StepCounterHook.")
开发者ID:821760408-sp,项目名称:tensorflow,代码行数:7,代码来源:basic_session_run_hooks.py

示例8: function

 def function(tag, scope):
   """Writes an image summary for `tensor` at the current global step.

   `bad_color`, `tensor` and `max_images` are taken from the enclosing scope.

   Args:
     tag: Tag forwarded to `write_image_summary`.
     scope: Used as the name of the created op.
   """
   # Bind `bad_color_` on both paths. Previously it was assigned only when
   # `bad_color` was None, so a caller-supplied colour led to an
   # UnboundLocalError at the call below.
   if bad_color is None:
     bad_color_ = constant_op.constant([255, 0, 0, 255], dtype=dtypes.uint8)
   else:
     bad_color_ = bad_color
   gen_summary_ops.write_image_summary(
       context.context().summary_writer_resource,
       training_util.get_global_step(), tag, tensor, bad_color_, max_images,
       name=scope)
开发者ID:1000sprites,项目名称:tensorflow,代码行数:7,代码来源:summary_ops.py

示例9: get_updates

  def get_updates(self, loss, params):
    """Builds the list of update ops for `loss` and stores it on `self.updates`.

    Args:
      loss: Loss passed to the wrapped optimizer's `compute_gradients`.
      params: Variables to compute gradients for; may be empty.

    Returns:
      `self.updates`, the list of update ops.
    """
    if distribute_lib.has_distribution_strategy():
      self.updates = []

      # After the model vars have been created, the second call to get_updates
      # is made with params as an empty list. In that case compute_gradients
      # is called without a variable list (i.e. params=None).
      if params:
        grads_and_vars = self.optimizer.compute_gradients(loss, params)
      else:
        grads_and_vars = self.optimizer.compute_gradients(loss)
      step = training_util.get_global_step()
      apply_op = self.optimizer.apply_gradients(grads_and_vars, step)
    elif not params:
      # Nothing to optimize: the only update is the iteration counter bump.
      self.updates = [state_ops.assign_add(self.iterations, 1)]
      return self.updates
    else:
      # The list starts out empty because the iterations variable is
      # incremented inside optimizer.apply_gradients().
      self.updates = []
      grads_and_vars = self.optimizer.compute_gradients(loss, params)
      apply_op = self.optimizer.apply_gradients(
          grads_and_vars, global_step=self.iterations)

    self.updates.append(apply_op)
    return self.updates
开发者ID:sonnyhu,项目名称:tensorflow,代码行数:27,代码来源:optimizers.py

示例10: __init__

    def __init__(self,
                 checkpoint_dir,
                 display_steps=100,
                 maximum_train_steps=None,
                 do_summary=True,
                 is_chief=True):
        """Stores the hook configuration and collects the values to display.

        Args:
            checkpoint_dir: A string, base directory for the checkpoint files.
            display_steps: A python integer, display every N steps.
            maximum_train_steps: A python integer, the maximum training steps.
            do_summary: Whether to save summaries when displaying.
            is_chief: Whether this is the chief process (stored, not used now).
        """
        tf.logging.info("Create DisplayHook.")

        # Configuration, stored exactly as given.
        self._checkpoint_dir = checkpoint_dir
        self._display_steps = display_steps
        self._maximum_train_steps = maximum_train_steps
        self._do_summary = do_summary
        self._is_chief = is_chief  # not used now

        # Values to display: the key and value collections are paired up
        # positionally, then the global step is added under "global_step".
        step_tensor = training_util.get_global_step()
        keys = ops.get_collection(Constants.DISPLAY_KEY_COLLECTION_NAME)
        values = ops.get_collection(Constants.DISPLAY_VALUE_COLLECTION_NAME)
        display_args = dict(zip(keys, values))
        display_args["global_step"] = step_tensor
        self._display_args = display_args

        # Timers and the summary writer all start out as None.
        self._timer = None
        self._logging_timer = None
        self._summary_writer = None
开发者ID:KIngpon,项目名称:NJUNMT-tf,代码行数:34,代码来源:hooks.py

示例11: begin

 def begin(self):
   """Looks up the global step tensor and notifies every listener.

   Raises:
     RuntimeError: if `get_global_step()` returns None.
   """
   step_tensor = training_util.get_global_step()
   self._global_step_tensor = step_tensor
   if step_tensor is None:
     raise RuntimeError(
         "Global step should be created to use CheckpointSaverHook.")
   # Listeners are started only after the global step has been validated.
   for listener in self._listeners:
     listener.begin()
开发者ID:kadeng,项目名称:tensorflow,代码行数:7,代码来源:basic_session_run_hooks.py

示例12: begin

 def begin(self):
   """Resets the saving state and looks up the global step tensor.

   Raises:
     RuntimeError: if `get_global_step()` returns None.
   """
   self._last_saved_step = None
   self._request_summary = True
   step_tensor = training_util.get_global_step()
   self._global_step_tensor = step_tensor
   if step_tensor is None:
     raise RuntimeError(
         "Global step should be created to use SummarySaverHook.")
开发者ID:KalraA,项目名称:tensorflow,代码行数:7,代码来源:basic_session_run_hooks.py

示例13: after_create_session

  def after_create_session(self, training_session, coord):  # pylint: disable=unused-argument
    """Checks checkpointing prerequisites and installs the shutdown hook.

    Args:
      training_session: Session that is cloned onto `self._graph`.
      coord: Unused.

    Raises:
      ValueError: if a saver is defined but there is no global step.
    """
    # N.B. The global step has to be pulled here so that it is not
    # unavailable at checkpoint time; the graph has been frozen by then.
    global_step_missing = training_util.get_global_step() is None
    if global_step_missing and self.saver() is not None:
      raise ValueError(
          'Saver defined but no global step.  Run `get_or_create_global_step()`'
          ' in your model definition to allow checkpointing.')

    with self._graph.as_default():
      logging.info('Installing graceful shutdown hook.')
      self._session = _clone_session(training_session, self._graph)
      self._workers = WorkerHeartbeatManager.from_devices(
          self._session, all_worker_devices(self._session))
      self._heartbeat_supported = self._workers.num_workers() > 0

      # Without any heartbeat-capable worker there is nothing to configure.
      if not self._heartbeat_supported:
        logging.warn(
            'No workers support hearbeats. Failure handling will be disabled.')
        return

      try:
        self._workers.configure(
            event_pb2.WorkerHeartbeatRequest(
                shutdown_mode=event_pb2.WAIT_FOR_COORDINATOR))
      except errors.InvalidArgumentError:
        # Configuration was rejected: record that heartbeats are unavailable.
        logging.warn(
            'TPU device does not support heartbeats. Failure '
            'handling will be disabled.')
        self._heartbeat_supported = False
开发者ID:aritratony,项目名称:tensorflow,代码行数:27,代码来源:session_support.py

示例14: _train_op_fn

 def _train_op_fn(unused_loss):
   """Returns the train op obtained from `optimizer.get_train_step`.

   Args:
     unused_loss: Ignored.

   Returns:
     The train op. When `update_weights_hook` is set, the hook is first given
     the model and the train op through `set_parameters`.
   """
   step = training_util.get_global_step()
   model, sdca_train_op = optimizer.get_train_step(
       columns_to_variables, weight_column_name, loss_type, features, labels,
       step)
   if update_weights_hook is None:
     return sdca_train_op
   update_weights_hook.set_parameters(model, sdca_train_op)
   return sdca_train_op
开发者ID:Ajaycs99,项目名称:tensorflow,代码行数:8,代码来源:sdca_estimator.py

示例15: record_summaries_every_n_global_steps

def record_summaries_every_n_global_steps(n):
  """Sets the should_record_summaries Tensor to true if global_step % n == 0.

  This is a generator with a single `yield`. The decorator is not visible
  here; it is presumably wrapped as a context manager (TODO confirm). While
  suspended at the `yield`, the `_SHOULD_RECORD_SUMMARIES_NAME` collection
  holds exactly one entry: `equal(global_step % n, 0)`, created under the
  "cpu:0" device scope. The previous contents of the collection are put back
  afterwards.

  Args:
    n: Divisor applied to the global step.

  Yields:
    Nothing (a bare `yield`).
  """
  collection_ref = ops.get_collection_ref(_SHOULD_RECORD_SUMMARIES_NAME)
  # Snapshot the current contents so they can be restored on exit.
  old = collection_ref[:]
  try:
    with ops.device("cpu:0"):
      collection_ref[:] = [
          math_ops.equal(training_util.get_global_step() % n, 0)]
    yield
  finally:
    # Restore even when an exception is raised at the `yield`, so that a
    # failing body does not leave the temporary condition installed.
    collection_ref[:] = old
开发者ID:dyoung418,项目名称:tensorflow,代码行数:8,代码来源:summary_ops.py


注:本文中的tensorflow.python.training.training_util.get_global_step函数示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。