

Python session_run_hook.SessionRunHook Class Code Examples

This article collects typical usage examples of the tensorflow.python.training.session_run_hook.SessionRunHook class in Python. If you are wondering how session_run_hook.SessionRunHook is used in practice, or are looking for working examples of it, the selected code samples below may help. You can also look into further usage examples of its containing module, tensorflow.python.training.session_run_hook.


The 14 code examples below all involve session_run_hook.SessionRunHook and are sorted by popularity by default.

Example 1: after_run

# Required import: from tensorflow.python.training import session_run_hook [as alias]
# Or: from tensorflow.python.training.session_run_hook import SessionRunHook [as alias]
def after_run(self, run_context, run_values):
    _ = run_context
    scalar_stopping_signal = run_values.results
    if _StopSignals.should_stop(scalar_stopping_signal):
      # NOTE(xiejw): In prediction, stopping signals are inserted for each
      # batch. And we append one more batch to signal the system it should stop.
      # The data flow might look like
      #
      #  batch   0: images, labels, stop = 0  (user provided)
      #  batch   1: images, labels, stop = 0  (user provided)
      #  ...
      #  batch  99: images, labels, stop = 0  (user provided)
      #  batch 100: images, labels, stop = 1  (TPUEstimator appended)
      #
      # where the final batch (id = 100) is appended by TPUEstimator, so we
      # should drop it before returning the predictions to user.
      # To achieve that, we throw the OutOfRangeError in after_run. Once
      # Monitored Session sees this error in SessionRunHook.after_run, the
      # "current" prediction, i.e., batch with id=100, will be discarded
      # immediately
      raise errors.OutOfRangeError(None, None, 'Stopped by stopping signal.') 
Author: ymcui, Project: Chinese-XLNet, Lines of code: 23, Source file: tpu_estimator.py
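
The comment in `after_run` describes dropping the appended sentinel batch by raising an error from the hook. The control flow can be sketched in pure Python as follows. `OutOfRange`, `predict`, and the `(data, stop)` tuples are hypothetical stand-ins rather than TensorFlow APIs, and the real `MonitoredSession` handling is more involved.

```python
class OutOfRange(Exception):
    """Hypothetical stand-in for errors.OutOfRangeError."""


def after_run(stop_signal):
    # Mirrors the hook above: raise as soon as the stop signal is set.
    if bool(stop_signal):
        raise OutOfRange('Stopped by stopping signal.')


def predict(batches):
    """Collects batch data until a batch carries stop=1; that batch is dropped."""
    results = []
    for data, stop in batches:
        try:
            after_run(stop)
        except OutOfRange:
            # The "current" batch (the appended sentinel) is discarded.
            break
        results.append(data)
    return results


# Two user-provided batches followed by one appended sentinel batch.
batches = [('batch0', 0), ('batch1', 0), ('sentinel', 1)]
kept = predict(batches)
print(kept)
```

Raising from the hook, instead of returning a flag, lets the loop abandon the in-flight batch before its result is ever handed to the caller.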

Example 2: fit

# Required import: from tensorflow.python.training import session_run_hook [as alias]
# Or: from tensorflow.python.training.session_run_hook import SessionRunHook [as alias]
def fit(self, x=None, y=None, input_fn=None, steps=None, batch_size=None,
          monitors=None, max_steps=None):
    """See trainable.Trainable."""
    # TODO(roumposg): Remove when deprecated monitors are removed.
    if monitors is None:
      monitors = []
    deprecated_monitors = [
        m for m in monitors
        if not isinstance(m, session_run_hook.SessionRunHook)
    ]
    for monitor in deprecated_monitors:
      monitor.set_estimator(self)
      monitor._lock_estimator()  # pylint: disable=protected-access

    if self._additional_run_hook:
      monitors.append(self._additional_run_hook)
    result = self._estimator.fit(x=x, y=y, input_fn=input_fn, steps=steps,
                                 batch_size=batch_size, monitors=monitors,
                                 max_steps=max_steps)

    for monitor in deprecated_monitors:
      monitor._unlock_estimator()  # pylint: disable=protected-access

    return result 
Author: tobegit3hub, Project: deep_image_model, Lines of code: 26, Source file: linear.py

Example 3: should_stop

# Required import: from tensorflow.python.training import session_run_hook [as alias]
# Or: from tensorflow.python.training.session_run_hook import SessionRunHook [as alias]
def should_stop(scalar_stopping_signal):
    """Detects whether scalar_stopping_signal indicates stopping."""
    if isinstance(scalar_stopping_signal, ops.Tensor):
      # STOPPING_SIGNAL is a constant True. Here, the logical_and is just the TF
      # way to express the bool check whether scalar_stopping_signal is True.
      return math_ops.logical_and(
          scalar_stopping_signal, _StopSignals.STOPPING_SIGNAL)
    else:
      # For non Tensor case, it is used in SessionRunHook. So, we cannot modify
      # the graph anymore. Here, we use pure Python.
      return bool(scalar_stopping_signal) 
Author: kimiyoung, Project: transformer-xl, Lines of code: 13, Source file: tpu_estimator.py

Example 4: should_stop

# Required import: from tensorflow.python.training import session_run_hook [as alias]
# Or: from tensorflow.python.training.session_run_hook import SessionRunHook [as alias]
def should_stop(scalar_stopping_signal):
    """Detects whether scalar_stopping_signal indicates stopping."""
    if isinstance(scalar_stopping_signal, ops.Tensor):
      # STOPPING_SIGNAL is a constant True. Here, the logical_and is just the TF
      # way to express the bool check whether scalar_stopping_signal is True.
      return math_ops.logical_and(scalar_stopping_signal,
                                  _StopSignals.STOPPING_SIGNAL)
    else:
      # For non Tensor case, it is used in SessionRunHook. So, we cannot modify
      # the graph anymore. Here, we use pure Python.
      return bool(scalar_stopping_signal) 
Author: ymcui, Project: Chinese-XLNet, Lines of code: 13, Source file: tpu_estimator.py

Example 5: dataset_initializer_hook

# Required import: from tensorflow.python.training import session_run_hook [as alias]
# Or: from tensorflow.python.training.session_run_hook import SessionRunHook [as alias]
def dataset_initializer_hook(self):
    """Returns a `SessionRunHook` to initialize this dataset.

    This must be called before `features_and_labels`.
    """
    iterator = self._dataset.make_initializable_iterator()
    # pylint: disable=protected-access
    hook = estimator_util._DatasetInitializerHook(iterator)
    # pylint: enable=protected-access
    self._iterator = iterator
    return hook 
Author: kimiyoung, Project: transformer-xl, Lines of code: 13, Source file: tpu_estimator.py

Example 6: _validate_hooks

# Required import: from tensorflow.python.training import session_run_hook [as alias]
# Or: from tensorflow.python.training.session_run_hook import SessionRunHook [as alias]
def _validate_hooks(hooks):
  """Validates the `hooks`."""
  hooks = tuple(hooks or [])
  for hook in hooks:
    if not isinstance(hook, session_run_hook.SessionRunHook):
      raise TypeError(
          'All hooks must be `SessionRunHook` instances, given: {}'.format(
              hook))
  return hooks 
Author: PacktPublishing, Project: Serverless-Deep-Learning-with-TensorFlow-and-AWS-Lambda, Lines of code: 11, Source file: training.py
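
To see how `_validate_hooks` behaves without installing TensorFlow, the sketch below swaps in a hypothetical stand-in base class. `SessionRunHook` here is an empty placeholder, and `LoggingHook` and `validate_hooks` are illustrative names; only the validation logic follows the example above.

```python
class SessionRunHook(object):
    """Hypothetical stand-in for session_run_hook.SessionRunHook."""


class LoggingHook(SessionRunHook):
    """Illustrative subclass; does nothing."""


def validate_hooks(hooks):
    # None becomes an empty tuple; any iterable is frozen into a tuple.
    hooks = tuple(hooks or [])
    for hook in hooks:
        if not isinstance(hook, SessionRunHook):
            raise TypeError(
                'All hooks must be `SessionRunHook` instances, given: {}'.format(
                    hook))
    return hooks


ok = validate_hooks([LoggingHook()])
empty = validate_hooks(None)

try:
    validate_hooks(['not a hook'])
    rejected = False
except TypeError:
    rejected = True

print(len(ok), empty, rejected)
```

Converting to a tuple up front means the validated collection cannot be mutated later by the caller.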

Example 7: __new__

# Required import: from tensorflow.python.training import session_run_hook [as alias]
# Or: from tensorflow.python.training.session_run_hook import SessionRunHook [as alias]
def __new__(cls,
              input_fn,
              max_steps=None,
              hooks=None):
    """Creates a validated `TrainSpec` instance.

    Args:
      input_fn: Training input function returning a tuple of:
          features - `Tensor` or dictionary of string feature name to `Tensor`.
          labels - `Tensor` or dictionary of `Tensor` with labels.
      max_steps: Int. Positive number of total steps for which to train model.
        If `None`, train forever. The training `input_fn` is not expected to
        generate `OutOfRangeError` or `StopIteration` exceptions. See the
        `train_and_evaluate` stop condition section for details.
      hooks: Iterable of `tf.train.SessionRunHook` objects to run
        on all workers (including chief) during training.

    Returns:
      A validated `TrainSpec` object.

    Raises:
      ValueError: If any of the input arguments is invalid.
      TypeError: If any of the arguments is not of the expected type.
    """
    # Validate input_fn.
    _validate_input_fn(input_fn)

    # Validate max_steps.
    if max_steps is not None and max_steps <= 0:
      raise ValueError(
          'Must specify max_steps > 0, given: {}'.format(max_steps))

    # Validate hooks.
    hooks = _validate_hooks(hooks)

    return super(TrainSpec, cls).__new__(
        cls,
        input_fn=input_fn,
        max_steps=max_steps,
        hooks=hooks) 
Author: PacktPublishing, Project: Serverless-Deep-Learning-with-TensorFlow-and-AWS-Lambda, Lines of code: 42, Source file: training.py
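
The `__new__` above validates its arguments before building an immutable record. A minimal sketch of that pattern, assuming a `collections.namedtuple` base (which the keyword-argument `super().__new__` call suggests), might look like this. `TrainSpecSketch` is a hypothetical name, and the `callable` check is an assumed simplification of `_validate_input_fn`, whose body is not shown in the example.

```python
import collections


class TrainSpecSketch(
        collections.namedtuple('TrainSpecSketch',
                               ['input_fn', 'max_steps', 'hooks'])):
    """Immutable spec whose fields are validated at construction time."""

    def __new__(cls, input_fn, max_steps=None, hooks=None):
        # Assumed stand-in for _validate_input_fn.
        if not callable(input_fn):
            raise TypeError(
                'input_fn must be callable, given: {}'.format(input_fn))
        # Same check as in the example above.
        if max_steps is not None and max_steps <= 0:
            raise ValueError(
                'Must specify max_steps > 0, given: {}'.format(max_steps))
        return super(TrainSpecSketch, cls).__new__(
            cls,
            input_fn=input_fn,
            max_steps=max_steps,
            hooks=tuple(hooks or []))


spec = TrainSpecSketch(lambda: None, max_steps=10)

try:
    TrainSpecSketch(lambda: None, max_steps=0)
    bad_steps_rejected = False
except ValueError:
    bad_steps_rejected = True

print(spec.max_steps, spec.hooks, bad_steps_rejected)
```

Validation goes in `__new__` rather than `__init__` because a tuple's contents are fixed once the object has been created.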

Example 8: __new__

# Required import: from tensorflow.python.training import session_run_hook [as alias]
# Or: from tensorflow.python.training.session_run_hook import SessionRunHook [as alias]
def __new__(cls,
              mode,
              predictions=None,
              loss=None,
              train_op=None,
              eval_metrics=None,
              export_outputs=None,
              scaffold_fn=None,
              host_call=None,
              training_hooks=None,
              evaluation_hooks=None,
              prediction_hooks=None):
    """Creates a validated `TPUEstimatorSpec` instance."""
    host_calls = {}
    if eval_metrics is not None:
      host_calls['eval_metrics'] = eval_metrics
    if host_call is not None:
      host_calls['host_call'] = host_call
    _OutfeedHostCall.validate(host_calls)

    training_hooks = tuple(training_hooks or [])
    evaluation_hooks = tuple(evaluation_hooks or [])
    prediction_hooks = tuple(prediction_hooks or [])

    for hook in training_hooks + evaluation_hooks + prediction_hooks:
      if not isinstance(hook, session_run_hook.SessionRunHook):
        raise TypeError('All hooks must be SessionRunHook instances, given: {}'
                        .format(hook))

    return super(TPUEstimatorSpec, cls).__new__(
        cls,
        mode=mode,
        predictions=predictions,
        loss=loss,
        train_op=train_op,
        eval_metrics=eval_metrics,
        export_outputs=export_outputs,
        scaffold_fn=scaffold_fn,
        host_call=host_call,
        training_hooks=training_hooks,
        evaluation_hooks=evaluation_hooks,
        prediction_hooks=prediction_hooks) 
Author: ymcui, Project: Chinese-XLNet, Lines of code: 44, Source file: tpu_estimator.py

Example 9: testFinalOpsOnEvaluationLoop

# Required import: from tensorflow.python.training import session_run_hook [as alias]
# Or: from tensorflow.python.training.session_run_hook import SessionRunHook [as alias]
def testFinalOpsOnEvaluationLoop(self):
    value_op, update_op = metrics.accuracy(
        labels=self._labels, predictions=self._predictions)
    init_op = control_flow_ops.group(variables.global_variables_initializer(),
                                     variables.local_variables_initializer())
    # Create checkpoint and log directories:
    chkpt_dir = tempfile.mkdtemp('tmp_logs')
    logdir = tempfile.mkdtemp('tmp_logs2')

    # Save initialized variables to a checkpoint directory:
    saver = saver_lib.Saver()
    with self.cached_session() as sess:
      init_op.run()
      saver.save(sess, os.path.join(chkpt_dir, 'chkpt'))

    class Object(object):

      def __init__(self):
        self.hook_was_run = False

    obj = Object()

    # Create a custom session run hook.
    class CustomHook(session_run_hook.SessionRunHook):

      def __init__(self, obj):
        self.obj = obj

      def end(self, session):
        self.obj.hook_was_run = True

    # Now, run the evaluation loop:
    accuracy_value = evaluation.evaluation_loop(
        '',
        chkpt_dir,
        logdir,
        eval_op=update_op,
        final_op=value_op,
        hooks=[CustomHook(obj)],
        max_number_of_evaluations=1)
    self.assertAlmostEqual(accuracy_value, self._expected_accuracy)

    # Validate that custom hook ran.
    self.assertTrue(obj.hook_was_run) 
Author: google-research, Project: tf-slim, Lines of code: 46, Source file: evaluation_test.py
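
The test above relies on the loop calling the hook's `end` callback once it finishes. The rough sketch below shows the order in which such callbacks fire, using a hypothetical `HookSketch` base and `run_loop` driver. The signatures are deliberately simplified: the real `SessionRunHook` methods take a run context, run values, or a session instead of a step number.

```python
class HookSketch(object):
    """Hypothetical, simplified hook interface (not the TensorFlow API)."""

    def begin(self):
        pass

    def before_run(self, step):
        pass

    def after_run(self, step, result):
        pass

    def end(self):
        pass


class RecordingHook(HookSketch):
    """Records the name of every callback it receives, in order."""

    def __init__(self):
        self.events = []

    def begin(self):
        self.events.append('begin')

    def before_run(self, step):
        self.events.append('before_run')

    def after_run(self, step, result):
        self.events.append('after_run')

    def end(self):
        self.events.append('end')


def run_loop(step_fn, num_steps, hooks):
    """Runs step_fn num_steps times, notifying hooks around each step."""
    for hook in hooks:
        hook.begin()
    for step in range(num_steps):
        for hook in hooks:
            hook.before_run(step)
        result = step_fn(step)
        for hook in hooks:
            hook.after_run(step, result)
    for hook in hooks:
        hook.end()


hook = RecordingHook()
run_loop(lambda step: step * 2, num_steps=2, hooks=[hook])
print(hook.events)
```

A hook that only overrides `end`, like `CustomHook` in the test, inherits no-op behavior for every other callback.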

Example 10: __new__

# Required import: from tensorflow.python.training import session_run_hook [as alias]
# Or: from tensorflow.python.training.session_run_hook import SessionRunHook [as alias]
def __new__(cls,
              mode,
              predictions=None,
              loss=None,
              train_op=None,
              eval_metrics=None,
              export_outputs=None,
              scaffold_fn=None,
              host_call=None,
              training_hooks=None,
              evaluation_hooks=None,
              prediction_hooks=None):
    """Creates a validated `TPUEstimatorSpec` instance."""
    host_calls = {}
    if eval_metrics is not None:
      host_calls['eval_metrics'] = eval_metrics
    if host_call is not None:
      host_calls['host_call'] = host_call
    _OutfeedHostCall.validate(host_calls)

    training_hooks = list(training_hooks or [])
    evaluation_hooks = list(evaluation_hooks or [])
    prediction_hooks = list(prediction_hooks or [])

    for hook in training_hooks + evaluation_hooks + prediction_hooks:
      if not isinstance(hook, session_run_hook.SessionRunHook):
        raise TypeError(
            'All hooks must be SessionRunHook instances, given: {}'.format(
                hook))

    return super(TPUEstimatorSpec, cls).__new__(
        cls,
        mode=mode,
        predictions=predictions,
        loss=loss,
        train_op=train_op,
        eval_metrics=eval_metrics,
        export_outputs=export_outputs,
        scaffold_fn=scaffold_fn,
        host_call=host_call,
        training_hooks=training_hooks,
        evaluation_hooks=evaluation_hooks,
        prediction_hooks=prediction_hooks) 
Author: kimiyoung, Project: transformer-xl, Lines of code: 45, Source file: tpu_estimator.py

Example 11: stop_if_higher_hook

# Required import: from tensorflow.python.training import session_run_hook [as alias]
# Or: from tensorflow.python.training.session_run_hook import SessionRunHook [as alias]
def stop_if_higher_hook(estimator,
                        metric_name,
                        threshold,
                        eval_dir=None,
                        min_steps=0,
                        run_every_secs=60,
                        run_every_steps=None):
  """Creates hook to stop if the given metric is higher than the threshold.

  Usage example:

  ```python
  estimator = ...
  # Hook to stop training if accuracy becomes higher than 0.9.
  hook = early_stopping.stop_if_higher_hook(estimator, "accuracy", 0.9)
  train_spec = tf.estimator.TrainSpec(..., hooks=[hook])
  tf.estimator.train_and_evaluate(estimator, train_spec, ...)
  ```

  Caveat: Current implementation supports early-stopping both training and
  evaluation in local mode. In distributed mode, training can be stopped but
  evaluation (where it's a separate job) will indefinitely wait for new model
  checkpoints to evaluate, so you will need other means to detect and stop it.
  Early-stopping evaluation in distributed mode requires changes in
  `train_and_evaluate` API and will be addressed in a future revision.

  Args:
    estimator: A `tf.estimator.Estimator` instance.
    metric_name: `str`, metric to track. "loss", "accuracy", etc.
    threshold: Numeric threshold for the given metric.
    eval_dir: If set, directory containing summary files with eval metrics. By
      default, `estimator.eval_dir()` will be used.
    min_steps: `int`, stop is never requested if global step is less than this
      value. Defaults to 0.
    run_every_secs: If specified, calls `should_stop_fn` at an interval of
      `run_every_secs` seconds. Defaults to 60 seconds. Either this or
      `run_every_steps` must be set.
    run_every_steps: If specified, calls `should_stop_fn` every
      `run_every_steps` steps. Either this or `run_every_secs` must be set.

  Returns:
    An early-stopping hook of type `SessionRunHook` that periodically checks
    if the given metric is higher than specified threshold and initiates
    early stopping if true.
  """
  return _stop_if_threshold_crossed_hook(
      estimator=estimator,
      metric_name=metric_name,
      threshold=threshold,
      higher_is_better=True,
      eval_dir=eval_dir,
      min_steps=min_steps,
      run_every_secs=run_every_secs,
      run_every_steps=run_every_steps) 
Author: tensorflow, Project: estimator, Lines of code: 56, Source file: early_stopping.py
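
The body of `_stop_if_threshold_crossed_hook` is not shown above, so the following is only a guess at the decision it makes, written as plain Python. `threshold_crossed`, `should_stop`, and the `eval_results` layout (a dict mapping global step to a dict of metrics) are all assumptions made for illustration.

```python
import operator


def threshold_crossed(value, threshold, higher_is_better):
    # Pick the comparison direction from the flag.
    compare = operator.gt if higher_is_better else operator.lt
    return compare(value, threshold)


def should_stop(eval_results, metric_name, threshold, higher_is_better):
    """Returns True if any recorded evaluation crosses the threshold.

    eval_results is assumed to map global step -> {metric name: value}.
    """
    for step in sorted(eval_results):
        value = eval_results[step][metric_name]
        if threshold_crossed(value, threshold, higher_is_better):
            return True
    return False


results = {
    100: {'accuracy': 0.80, 'loss': 150.0},
    200: {'accuracy': 0.93, 'loss': 90.0},
}

stop_on_accuracy = should_stop(results, 'accuracy', 0.9, higher_is_better=True)
stop_on_loss = should_stop(results, 'loss', 50, higher_is_better=False)
print(stop_on_accuracy, stop_on_loss)
```

With a single `higher_is_better` flag, one helper can serve both the "higher than" and "lower than" variants of the hook.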

Example 12: stop_if_lower_hook

# Required import: from tensorflow.python.training import session_run_hook [as alias]
# Or: from tensorflow.python.training.session_run_hook import SessionRunHook [as alias]
def stop_if_lower_hook(estimator,
                       metric_name,
                       threshold,
                       eval_dir=None,
                       min_steps=0,
                       run_every_secs=60,
                       run_every_steps=None):
  """Creates hook to stop if the given metric is lower than the threshold.

  Usage example:

  ```python
  estimator = ...
  # Hook to stop training if loss becomes lower than 100.
  hook = early_stopping.stop_if_lower_hook(estimator, "loss", 100)
  train_spec = tf.estimator.TrainSpec(..., hooks=[hook])
  tf.estimator.train_and_evaluate(estimator, train_spec, ...)
  ```

  Caveat: Current implementation supports early-stopping both training and
  evaluation in local mode. In distributed mode, training can be stopped but
  evaluation (where it's a separate job) will indefinitely wait for new model
  checkpoints to evaluate, so you will need other means to detect and stop it.
  Early-stopping evaluation in distributed mode requires changes in
  `train_and_evaluate` API and will be addressed in a future revision.

  Args:
    estimator: A `tf.estimator.Estimator` instance.
    metric_name: `str`, metric to track. "loss", "accuracy", etc.
    threshold: Numeric threshold for the given metric.
    eval_dir: If set, directory containing summary files with eval metrics. By
      default, `estimator.eval_dir()` will be used.
    min_steps: `int`, stop is never requested if global step is less than this
      value. Defaults to 0.
    run_every_secs: If specified, calls `should_stop_fn` at an interval of
      `run_every_secs` seconds. Defaults to 60 seconds. Either this or
      `run_every_steps` must be set.
    run_every_steps: If specified, calls `should_stop_fn` every
      `run_every_steps` steps. Either this or `run_every_secs` must be set.

  Returns:
    An early-stopping hook of type `SessionRunHook` that periodically checks
    if the given metric is lower than specified threshold and initiates
    early stopping if true.
  """
  return _stop_if_threshold_crossed_hook(
      estimator=estimator,
      metric_name=metric_name,
      threshold=threshold,
      higher_is_better=False,
      eval_dir=eval_dir,
      min_steps=min_steps,
      run_every_secs=run_every_secs,
      run_every_steps=run_every_steps) 
Author: tensorflow, Project: estimator, Lines of code: 56, Source file: early_stopping.py

Example 13: stop_if_no_increase_hook

# Required import: from tensorflow.python.training import session_run_hook [as alias]
# Or: from tensorflow.python.training.session_run_hook import SessionRunHook [as alias]
def stop_if_no_increase_hook(estimator,
                             metric_name,
                             max_steps_without_increase,
                             eval_dir=None,
                             min_steps=0,
                             run_every_secs=60,
                             run_every_steps=None):
  """Creates hook to stop if metric does not increase within given max steps.

  Usage example:

  ```python
  estimator = ...
  # Hook to stop training if accuracy does not increase in over 100000 steps.
  hook = early_stopping.stop_if_no_increase_hook(estimator, "accuracy", 100000)
  train_spec = tf.estimator.TrainSpec(..., hooks=[hook])
  tf.estimator.train_and_evaluate(estimator, train_spec, ...)
  ```

  Caveat: Current implementation supports early-stopping both training and
  evaluation in local mode. In distributed mode, training can be stopped but
  evaluation (where it's a separate job) will indefinitely wait for new model
  checkpoints to evaluate, so you will need other means to detect and stop it.
  Early-stopping evaluation in distributed mode requires changes in
  `train_and_evaluate` API and will be addressed in a future revision.

  Args:
    estimator: A `tf.estimator.Estimator` instance.
    metric_name: `str`, metric to track. "loss", "accuracy", etc.
    max_steps_without_increase: `int`, maximum number of training steps with no
      increase in the given metric.
    eval_dir: If set, directory containing summary files with eval metrics. By
      default, `estimator.eval_dir()` will be used.
    min_steps: `int`, stop is never requested if global step is less than this
      value. Defaults to 0.
    run_every_secs: If specified, calls `should_stop_fn` at an interval of
      `run_every_secs` seconds. Defaults to 60 seconds. Either this or
      `run_every_steps` must be set.
    run_every_steps: If specified, calls `should_stop_fn` every
      `run_every_steps` steps. Either this or `run_every_secs` must be set.

  Returns:
    An early-stopping hook of type `SessionRunHook` that periodically checks
    if the given metric shows no increase over given maximum number of
    training steps, and initiates early stopping if true.
  """
  return _stop_if_no_metric_improvement_hook(
      estimator=estimator,
      metric_name=metric_name,
      max_steps_without_improvement=max_steps_without_increase,
      higher_is_better=True,
      eval_dir=eval_dir,
      min_steps=min_steps,
      run_every_secs=run_every_secs,
      run_every_steps=run_every_steps) 
Author: tensorflow, Project: estimator, Lines of code: 57, Source file: early_stopping.py
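
The "no improvement within N steps" rule from the docstring can be sketched as a plain Python check. This is an illustrative guess, not the body of `_stop_if_no_metric_improvement_hook`, which is not shown above. The function name `no_improvement` and the `eval_results` layout (global step mapped to a dict of metrics) are assumptions.

```python
def no_improvement(eval_results, metric_name,
                   max_steps_without_improvement, higher_is_better):
    """Returns True once the metric has gone too many steps without improving.

    eval_results is assumed to map global step -> {metric name: value}.
    """
    best_value = None
    best_step = None
    for step in sorted(eval_results):
        value = eval_results[step][metric_name]
        improved = (
            best_value is None
            or (value > best_value if higher_is_better else value < best_value))
        if improved:
            best_value, best_step = value, step
        # Steps elapsed since the best value was last updated.
        if step - best_step >= max_steps_without_improvement:
            return True
    return False


results = {
    100: {'accuracy': 0.50},
    200: {'accuracy': 0.60},
    300: {'accuracy': 0.60},
    400: {'accuracy': 0.59},
}

stalled_200 = no_improvement(results, 'accuracy', 200, higher_is_better=True)
stalled_300 = no_improvement(results, 'accuracy', 300, higher_is_better=True)
print(stalled_200, stalled_300)
```

In this sketch a tie does not count as an improvement, so the plateau at steps 200 and 300 keeps the counter running from step 200.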

Example 14: stop_if_no_decrease_hook

# Required import: from tensorflow.python.training import session_run_hook [as alias]
# Or: from tensorflow.python.training.session_run_hook import SessionRunHook [as alias]
def stop_if_no_decrease_hook(estimator,
                             metric_name,
                             max_steps_without_decrease,
                             eval_dir=None,
                             min_steps=0,
                             run_every_secs=60,
                             run_every_steps=None):
  """Creates hook to stop if metric does not decrease within given max steps.

  Usage example:

  ```python
  estimator = ...
  # Hook to stop training if loss does not decrease in over 100000 steps.
  hook = early_stopping.stop_if_no_decrease_hook(estimator, "loss", 100000)
  train_spec = tf.estimator.TrainSpec(..., hooks=[hook])
  tf.estimator.train_and_evaluate(estimator, train_spec, ...)
  ```

  Caveat: Current implementation supports early-stopping both training and
  evaluation in local mode. In distributed mode, training can be stopped but
  evaluation (where it's a separate job) will indefinitely wait for new model
  checkpoints to evaluate, so you will need other means to detect and stop it.
  Early-stopping evaluation in distributed mode requires changes in
  `train_and_evaluate` API and will be addressed in a future revision.

  Args:
    estimator: A `tf.estimator.Estimator` instance.
    metric_name: `str`, metric to track. "loss", "accuracy", etc.
    max_steps_without_decrease: `int`, maximum number of training steps with no
      decrease in the given metric.
    eval_dir: If set, directory containing summary files with eval metrics. By
      default, `estimator.eval_dir()` will be used.
    min_steps: `int`, stop is never requested if global step is less than this
      value. Defaults to 0.
    run_every_secs: If specified, calls `should_stop_fn` at an interval of
      `run_every_secs` seconds. Defaults to 60 seconds. Either this or
      `run_every_steps` must be set.
    run_every_steps: If specified, calls `should_stop_fn` every
      `run_every_steps` steps. Either this or `run_every_secs` must be set.

  Returns:
    An early-stopping hook of type `SessionRunHook` that periodically checks
    if the given metric shows no decrease over given maximum number of
    training steps, and initiates early stopping if true.
  """
  return _stop_if_no_metric_improvement_hook(
      estimator=estimator,
      metric_name=metric_name,
      max_steps_without_improvement=max_steps_without_decrease,
      higher_is_better=False,
      eval_dir=eval_dir,
      min_steps=min_steps,
      run_every_secs=run_every_secs,
      run_every_steps=run_every_steps) 
Author: tensorflow, Project: estimator, Lines of code: 57, Source file: early_stopping.py


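Note: The tensorflow.python.training.session_run_hook.SessionRunHook examples in this article were compiled by 纯净天空 from open-source code and documentation platforms such as GitHub and MSDocs. The snippets were selected from open-source projects contributed by their respective developers, and copyright in the source code belongs to the original authors. For distribution and use, refer to the License of the corresponding project. Do not reproduce without permission.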