本文整理汇总了Python中tensorflow.python.training.session_run_hook.SessionRunHook类的典型用法代码示例。如果您正苦于以下问题:Python session_run_hook.SessionRunHook的具体用法?Python session_run_hook.SessionRunHook怎么用?Python session_run_hook.SessionRunHook使用的例子?那么恭喜您, 这里精选的代码示例或许可以为您提供帮助。您也可以进一步了解该类所在模块tensorflow.python.training.session_run_hook的用法示例。
在下文中一共展示了session_run_hook.SessionRunHook类的14个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: after_run
# 需要导入模块: from tensorflow.python.training import session_run_hook [as 别名]
# 或者: from tensorflow.python.training.session_run_hook import SessionRunHook [as 别名]
def after_run(self, run_context, run_values):
    """Ends the session run once the fetched stopping signal is set.

    Args:
      run_context: Unused.
      run_values: Its `results` attribute is the fetched stopping signal,
        checked with `_StopSignals.should_stop`.

    Raises:
      errors.OutOfRangeError: When the stopping signal indicates a stop.
    """
    del run_context  # Not needed to decide whether to stop.
    stopping_signal = run_values.results
    if not _StopSignals.should_stop(stopping_signal):
        return
    # NOTE(xiejw): During prediction a stopping signal travels with every
    # batch, and TPUEstimator appends one extra batch whose signal is set:
    #
    #   batch 0:   images, labels, stop = 0 (user provided)
    #   batch 1:   images, labels, stop = 0 (user provided)
    #   ...
    #   batch 99:  images, labels, stop = 0 (user provided)
    #   batch 100: images, labels, stop = 1 (TPUEstimator appended)
    #
    # The appended batch (id = 100) must not reach the user. Raising
    # OutOfRangeError from after_run makes the MonitoredSession discard the
    # "current" prediction, i.e. that appended batch, immediately.
    raise errors.OutOfRangeError(None, None, 'Stopped by stopping signal.')
示例2: fit
# 需要导入模块: from tensorflow.python.training import session_run_hook [as 别名]
# 或者: from tensorflow.python.training.session_run_hook import SessionRunHook [as 别名]
def fit(self, x=None, y=None, input_fn=None, steps=None, batch_size=None,
        monitors=None, max_steps=None):
    """See trainable.Trainable.

    Deprecated monitors (anything that is not a `SessionRunHook`) are bound
    to this estimator and locked for the duration of the wrapped
    `self._estimator.fit` call. They are unlocked afterwards, even if
    training raises. When `self._additional_run_hook` is set, it is appended
    to a *copy* of `monitors`, so the caller's list is never modified.

    Returns:
      Whatever `self._estimator.fit` returns.
    """
    # TODO(roumposg): Remove when deprecated monitors are removed.
    # Work on a private copy. Appending to the caller's list would mutate it
    # and add the extra hook again on every subsequent call with that list.
    monitors = list(monitors) if monitors is not None else []
    deprecated_monitors = [
        m for m in monitors
        if not isinstance(m, session_run_hook.SessionRunHook)
    ]
    for monitor in deprecated_monitors:
        monitor.set_estimator(self)
        monitor._lock_estimator()  # pylint: disable=protected-access
    if self._additional_run_hook:
        monitors.append(self._additional_run_hook)
    try:
        return self._estimator.fit(x=x, y=y, input_fn=input_fn, steps=steps,
                                   batch_size=batch_size, monitors=monitors,
                                   max_steps=max_steps)
    finally:
        # Unlock even when training fails so the monitors are not left
        # permanently bound/locked to this estimator.
        for monitor in deprecated_monitors:
            monitor._unlock_estimator()  # pylint: disable=protected-access
示例3: should_stop
# 需要导入模块: from tensorflow.python.training import session_run_hook [as 别名]
# 或者: from tensorflow.python.training.session_run_hook import SessionRunHook [as 别名]
def should_stop(scalar_stopping_signal):
    """Detects whether scalar_stopping_signal indicates stopping.

    Returns a boolean `Tensor` when the signal is a `Tensor`. Otherwise it
    returns a plain Python `bool`.
    """
    if not isinstance(scalar_stopping_signal, ops.Tensor):
        # Non-Tensor signals come from SessionRunHook usage, where the graph
        # can no longer be modified, so decide in pure Python.
        return bool(scalar_stopping_signal)
    # STOPPING_SIGNAL is a constant True; and-ing with it is the in-graph way
    # of asking whether scalar_stopping_signal itself is True.
    return math_ops.logical_and(scalar_stopping_signal,
                                _StopSignals.STOPPING_SIGNAL)
示例4: should_stop
# 需要导入模块: from tensorflow.python.training import session_run_hook [as 别名]
# 或者: from tensorflow.python.training.session_run_hook import SessionRunHook [as 别名]
def should_stop(scalar_stopping_signal):
    """Detects whether scalar_stopping_signal indicates stopping.

    Args:
      scalar_stopping_signal: Either a `Tensor` or an already-fetched value.

    Returns:
      A boolean `Tensor` for `Tensor` input, otherwise a Python `bool`.
    """
    is_graph_tensor = isinstance(scalar_stopping_signal, ops.Tensor)
    if is_graph_tensor:
        # STOPPING_SIGNAL is a constant True, so this logical_and expresses
        # "is the signal True?" as a graph op.
        stop = math_ops.logical_and(scalar_stopping_signal,
                                    _StopSignals.STOPPING_SIGNAL)
    else:
        # Fetched value (SessionRunHook path): the graph cannot be changed
        # any more, so evaluate with pure Python.
        stop = bool(scalar_stopping_signal)
    return stop
示例5: dataset_initializer_hook
# 需要导入模块: from tensorflow.python.training import session_run_hook [as 别名]
# 或者: from tensorflow.python.training.session_run_hook import SessionRunHook [as 别名]
def dataset_initializer_hook(self):
    """Builds the `SessionRunHook` that initializes this dataset's iterator.

    It creates an initializable iterator from `self._dataset` and wraps it
    in `estimator_util._DatasetInitializerHook`. It then stores the iterator
    on `self._iterator` and returns the hook.

    This must be called before `features_and_labels`.
    """
    init_iterator = self._dataset.make_initializable_iterator()
    initializer_hook = estimator_util._DatasetInitializerHook(  # pylint: disable=protected-access
        init_iterator)
    self._iterator = init_iterator
    return initializer_hook
示例6: _validate_hooks
# 需要导入模块: from tensorflow.python.training import session_run_hook [as 别名]
# 或者: from tensorflow.python.training.session_run_hook import SessionRunHook [as 别名]
def _validate_hooks(hooks):
"""Validates the `hooks`."""
hooks = tuple(hooks or [])
for hook in hooks:
if not isinstance(hook, session_run_hook.SessionRunHook):
raise TypeError(
'All hooks must be `SessionRunHook` instances, given: {}'.format(
hook))
return hooks
开发者ID:PacktPublishing,项目名称:Serverless-Deep-Learning-with-TensorFlow-and-AWS-Lambda,代码行数:11,代码来源:training.py
示例7: __new__
# 需要导入模块: from tensorflow.python.training import session_run_hook [as 别名]
# 或者: from tensorflow.python.training.session_run_hook import SessionRunHook [as 别名]
def __new__(cls,
            input_fn,
            max_steps=None,
            hooks=None):
    """Creates a validated `TrainSpec` instance.

    Args:
      input_fn: Training input function returning a tuple of:
        features - `Tensor` or dictionary of string feature name to `Tensor`.
        labels - `Tensor` or dictionary of `Tensor` with labels.
      max_steps: Int. Positive number of total steps for which to train model.
        If `None`, train forever. The training `input_fn` is not expected to
        generate `OutOfRangeError` or `StopIteration` exceptions. See the
        `train_and_evaluate` stop condition section for details.
      hooks: Iterable of `tf.train.SessionRunHook` objects to run
        on all workers (including chief) during training.

    Returns:
      A validated `TrainSpec` object.

    Raises:
      ValueError: If any of the input arguments is invalid.
      TypeError: If any of the arguments is not of the expected type.
    """
    _validate_input_fn(input_fn)

    # `None` means "train forever"; any explicit step budget must be positive.
    has_bad_step_budget = max_steps is not None and max_steps <= 0
    if has_bad_step_budget:
        raise ValueError(
            'Must specify max_steps > 0, given: {}'.format(max_steps))

    validated_hooks = _validate_hooks(hooks)
    return super(TrainSpec, cls).__new__(
        cls, input_fn=input_fn, max_steps=max_steps, hooks=validated_hooks)
开发者ID:PacktPublishing,项目名称:Serverless-Deep-Learning-with-TensorFlow-and-AWS-Lambda,代码行数:42,代码来源:training.py
示例8: __new__
# 需要导入模块: from tensorflow.python.training import session_run_hook [as 别名]
# 或者: from tensorflow.python.training.session_run_hook import SessionRunHook [as 别名]
def __new__(cls, mode, predictions=None, loss=None, train_op=None,
            eval_metrics=None, export_outputs=None, scaffold_fn=None,
            host_call=None, training_hooks=None, evaluation_hooks=None,
            prediction_hooks=None):
    """Creates a validated `TPUEstimatorSpec` instance.

    Any non-`None` `eval_metrics` / `host_call` are validated together by
    `_OutfeedHostCall.validate`. The three hook arguments are normalized to
    tuples, which are empty when the argument is falsy.

    Raises:
      TypeError: If any hook is not a `SessionRunHook` instance.
    """
    named_host_calls = {
        name: call
        for name, call in (('eval_metrics', eval_metrics),
                           ('host_call', host_call))
        if call is not None
    }
    _OutfeedHostCall.validate(named_host_calls)

    training_hooks, evaluation_hooks, prediction_hooks = (
        tuple(group or [])
        for group in (training_hooks, evaluation_hooks, prediction_hooks))

    for hook in training_hooks + evaluation_hooks + prediction_hooks:
        if isinstance(hook, session_run_hook.SessionRunHook):
            continue
        raise TypeError('All hooks must be SessionRunHook instances, given: {}'
                        .format(hook))

    return super(TPUEstimatorSpec, cls).__new__(
        cls,
        mode=mode,
        predictions=predictions,
        loss=loss,
        train_op=train_op,
        eval_metrics=eval_metrics,
        export_outputs=export_outputs,
        scaffold_fn=scaffold_fn,
        host_call=host_call,
        training_hooks=training_hooks,
        evaluation_hooks=evaluation_hooks,
        prediction_hooks=prediction_hooks)
示例9: testFinalOpsOnEvaluationLoop
# 需要导入模块: from tensorflow.python.training import session_run_hook [as 别名]
# 或者: from tensorflow.python.training.session_run_hook import SessionRunHook [as 别名]
def testFinalOpsOnEvaluationLoop(self):
    """Checks `evaluation_loop` returns the final_op value and runs hooks.

    Saves a checkpoint of freshly initialized variables. It then runs
    `evaluation.evaluation_loop` with `update_op` as `eval_op`, `value_op` as
    `final_op` and a custom hook. Finally it asserts that the returned value
    matches `self._expected_accuracy` and that the hook's `end` was called.
    """
    value_op, update_op = metrics.accuracy(
        labels=self._labels, predictions=self._predictions)
    # Initialize both global and local variables. The accuracy metric's
    # running counters are presumably local variables -- TODO confirm.
    init_op = control_flow_ops.group(variables.global_variables_initializer(),
                                     variables.local_variables_initializer())
    # Create checkpoint and log directories:
    # NOTE(review): these temp directories are never removed by this test.
    chkpt_dir = tempfile.mkdtemp('tmp_logs')
    logdir = tempfile.mkdtemp('tmp_logs2')
    # Save initialized variables to a checkpoint directory:
    # The Saver is created after the ops above, so statement order matters.
    saver = saver_lib.Saver()
    with self.cached_session() as sess:
        init_op.run()
        saver.save(sess, os.path.join(chkpt_dir, 'chkpt'))

    # Mutable flag holder that lets the hook report back to the test.
    class Object(object):
        def __init__(self):
            self.hook_was_run = False

    obj = Object()

    # Create a custom session run hook.
    class CustomHook(session_run_hook.SessionRunHook):
        # Sets `obj.hook_was_run` when the hook's `end` is invoked.
        def __init__(self, obj):
            self.obj = obj

        def end(self, session):
            self.obj.hook_was_run = True

    # Now, run the evaluation loop:
    # max_number_of_evaluations=1 presumably stops after a single evaluation
    # of the checkpoint in chkpt_dir -- confirm against evaluation_loop docs.
    accuracy_value = evaluation.evaluation_loop(
        '',
        chkpt_dir,
        logdir,
        eval_op=update_op,
        final_op=value_op,
        hooks=[CustomHook(obj)],
        max_number_of_evaluations=1)
    self.assertAlmostEqual(accuracy_value, self._expected_accuracy)
    # Validate that custom hook ran.
    self.assertTrue(obj.hook_was_run)
示例10: __new__
# 需要导入模块: from tensorflow.python.training import session_run_hook [as 别名]
# 或者: from tensorflow.python.training.session_run_hook import SessionRunHook [as 别名]
def __new__(cls,
            mode,
            predictions=None,
            loss=None,
            train_op=None,
            eval_metrics=None,
            export_outputs=None,
            scaffold_fn=None,
            host_call=None,
            training_hooks=None,
            evaluation_hooks=None,
            prediction_hooks=None):
    """Creates a validated `TPUEstimatorSpec` instance.

    Non-`None` `eval_metrics` / `host_call` are checked together via
    `_OutfeedHostCall.validate`. Each hook argument is copied into a new list,
    which is empty when the argument is falsy.

    Raises:
      TypeError: If any hook is not a `SessionRunHook` instance.
    """
    host_calls = {}
    for call_name, call in (('eval_metrics', eval_metrics),
                            ('host_call', host_call)):
        if call is not None:
            host_calls[call_name] = call
    _OutfeedHostCall.validate(host_calls)

    # Lists (not tuples) are kept here so callers can extend these fields.
    training_hooks = list(training_hooks) if training_hooks else []
    evaluation_hooks = list(evaluation_hooks) if evaluation_hooks else []
    prediction_hooks = list(prediction_hooks) if prediction_hooks else []

    for candidate in training_hooks + evaluation_hooks + prediction_hooks:
        if isinstance(candidate, session_run_hook.SessionRunHook):
            continue
        raise TypeError(
            'All hooks must be SessionRunHook instances, given: {}'.format(
                candidate))

    return super(TPUEstimatorSpec, cls).__new__(
        cls,
        mode=mode,
        predictions=predictions,
        loss=loss,
        train_op=train_op,
        eval_metrics=eval_metrics,
        export_outputs=export_outputs,
        scaffold_fn=scaffold_fn,
        host_call=host_call,
        training_hooks=training_hooks,
        evaluation_hooks=evaluation_hooks,
        prediction_hooks=prediction_hooks)
示例11: stop_if_higher_hook
# 需要导入模块: from tensorflow.python.training import session_run_hook [as 别名]
# 或者: from tensorflow.python.training.session_run_hook import SessionRunHook [as 别名]
def stop_if_higher_hook(estimator,
                        metric_name,
                        threshold,
                        eval_dir=None,
                        min_steps=0,
                        run_every_secs=60,
                        run_every_steps=None):
    """Creates hook to stop if the given metric is higher than the threshold.

    Usage example:

    ```python
    estimator = ...
    # Hook to stop training if accuracy becomes higher than 0.9.
    hook = early_stopping.stop_if_higher_hook(estimator, "accuracy", 0.9)
    train_spec = tf.estimator.TrainSpec(..., hooks=[hook])
    tf.estimator.train_and_evaluate(estimator, train_spec, ...)
    ```

    Caveat: early-stopping of both training and evaluation is supported in
    local mode only. In distributed mode training can be stopped, but
    evaluation runs as a separate job and keeps waiting for new checkpoints.
    You need some other means to detect and stop it. Supporting that requires
    changes to the `train_and_evaluate` API and is planned for a future
    revision.

    Args:
      estimator: A `tf.estimator.Estimator` instance.
      metric_name: `str`, metric to track. "loss", "accuracy", etc.
      threshold: Numeric threshold for the given metric.
      eval_dir: If set, directory containing summary files with eval metrics.
        By default, `estimator.eval_dir()` will be used.
      min_steps: `int`, stop is never requested if global step is less than
        this value. Defaults to 0.
      run_every_secs: If specified, calls `should_stop_fn` at an interval of
        `run_every_secs` seconds. Defaults to 60 seconds. Either this or
        `run_every_steps` must be set.
      run_every_steps: If specified, calls `should_stop_fn` every
        `run_every_steps` steps. Either this or `run_every_secs` must be set.

    Returns:
      An early-stopping hook of type `SessionRunHook`. It periodically checks
      whether the given metric is higher than the threshold and, if so,
      initiates early stopping.
    """
    # For this metric, larger values count as "better".
    hook_kwargs = dict(
        estimator=estimator,
        metric_name=metric_name,
        threshold=threshold,
        higher_is_better=True,
        eval_dir=eval_dir,
        min_steps=min_steps,
        run_every_secs=run_every_secs,
        run_every_steps=run_every_steps,
    )
    return _stop_if_threshold_crossed_hook(**hook_kwargs)
示例12: stop_if_lower_hook
# 需要导入模块: from tensorflow.python.training import session_run_hook [as 别名]
# 或者: from tensorflow.python.training.session_run_hook import SessionRunHook [as 别名]
def stop_if_lower_hook(estimator,
                       metric_name,
                       threshold,
                       eval_dir=None,
                       min_steps=0,
                       run_every_secs=60,
                       run_every_steps=None):
    """Creates hook to stop if the given metric is lower than the threshold.

    Usage example:

    ```python
    estimator = ...
    # Hook to stop training if loss becomes lower than 100.
    hook = early_stopping.stop_if_lower_hook(estimator, "loss", 100)
    train_spec = tf.estimator.TrainSpec(..., hooks=[hook])
    tf.estimator.train_and_evaluate(estimator, train_spec, ...)
    ```

    Caveat: early-stopping of both training and evaluation is supported in
    local mode only. In distributed mode training can be stopped, but
    evaluation runs as a separate job and keeps waiting for new checkpoints.
    You need some other means to detect and stop it. Supporting that requires
    changes to the `train_and_evaluate` API and is planned for a future
    revision.

    Args:
      estimator: A `tf.estimator.Estimator` instance.
      metric_name: `str`, metric to track. "loss", "accuracy", etc.
      threshold: Numeric threshold for the given metric.
      eval_dir: If set, directory containing summary files with eval metrics.
        By default, `estimator.eval_dir()` will be used.
      min_steps: `int`, stop is never requested if global step is less than
        this value. Defaults to 0.
      run_every_secs: If specified, calls `should_stop_fn` at an interval of
        `run_every_secs` seconds. Defaults to 60 seconds. Either this or
        `run_every_steps` must be set.
      run_every_steps: If specified, calls `should_stop_fn` every
        `run_every_steps` steps. Either this or `run_every_secs` must be set.

    Returns:
      An early-stopping hook of type `SessionRunHook`. It periodically checks
      whether the given metric is lower than the threshold and, if so,
      initiates early stopping.
    """
    return _stop_if_threshold_crossed_hook(
        # For this metric, smaller values count as "better".
        higher_is_better=False,
        estimator=estimator,
        metric_name=metric_name,
        threshold=threshold,
        eval_dir=eval_dir,
        min_steps=min_steps,
        run_every_steps=run_every_steps,
        run_every_secs=run_every_secs)
示例13: stop_if_no_increase_hook
# 需要导入模块: from tensorflow.python.training import session_run_hook [as 别名]
# 或者: from tensorflow.python.training.session_run_hook import SessionRunHook [as 别名]
def stop_if_no_increase_hook(estimator,
                             metric_name,
                             max_steps_without_increase,
                             eval_dir=None,
                             min_steps=0,
                             run_every_secs=60,
                             run_every_steps=None):
    """Creates hook to stop if metric does not increase within given max steps.

    Usage example:

    ```python
    estimator = ...
    # Hook to stop training if accuracy does not increase in over 100000 steps.
    hook = early_stopping.stop_if_no_increase_hook(estimator, "accuracy", 100000)
    train_spec = tf.estimator.TrainSpec(..., hooks=[hook])
    tf.estimator.train_and_evaluate(estimator, train_spec, ...)
    ```

    Caveat: early-stopping of both training and evaluation is supported in
    local mode only. In distributed mode training can be stopped, but
    evaluation runs as a separate job and keeps waiting for new checkpoints.
    You need some other means to detect and stop it. Supporting that requires
    changes to the `train_and_evaluate` API and is planned for a future
    revision.

    Args:
      estimator: A `tf.estimator.Estimator` instance.
      metric_name: `str`, metric to track. "loss", "accuracy", etc.
      max_steps_without_increase: `int`, maximum number of training steps with
        no increase in the given metric.
      eval_dir: If set, directory containing summary files with eval metrics.
        By default, `estimator.eval_dir()` will be used.
      min_steps: `int`, stop is never requested if global step is less than
        this value. Defaults to 0.
      run_every_secs: If specified, calls `should_stop_fn` at an interval of
        `run_every_secs` seconds. Defaults to 60 seconds. Either this or
        `run_every_steps` must be set.
      run_every_steps: If specified, calls `should_stop_fn` every
        `run_every_steps` steps. Either this or `run_every_secs` must be set.

    Returns:
      An early-stopping hook of type `SessionRunHook`. It periodically checks
      whether the given metric has failed to increase over the given maximum
      number of training steps and, if so, initiates early stopping.
    """
    improvement_hook_args = {
        'estimator': estimator,
        'metric_name': metric_name,
        'max_steps_without_improvement': max_steps_without_increase,
        # An increase is what counts as an improvement for this metric.
        'higher_is_better': True,
        'eval_dir': eval_dir,
        'min_steps': min_steps,
        'run_every_secs': run_every_secs,
        'run_every_steps': run_every_steps,
    }
    return _stop_if_no_metric_improvement_hook(**improvement_hook_args)
示例14: stop_if_no_decrease_hook
# 需要导入模块: from tensorflow.python.training import session_run_hook [as 别名]
# 或者: from tensorflow.python.training.session_run_hook import SessionRunHook [as 别名]
def stop_if_no_decrease_hook(estimator,
                             metric_name,
                             max_steps_without_decrease,
                             eval_dir=None,
                             min_steps=0,
                             run_every_secs=60,
                             run_every_steps=None):
    """Creates hook to stop if metric does not decrease within given max steps.

    Usage example:

    ```python
    estimator = ...
    # Hook to stop training if loss does not decrease in over 100000 steps.
    hook = early_stopping.stop_if_no_decrease_hook(estimator, "loss", 100000)
    train_spec = tf.estimator.TrainSpec(..., hooks=[hook])
    tf.estimator.train_and_evaluate(estimator, train_spec, ...)
    ```

    Caveat: early-stopping of both training and evaluation is supported in
    local mode only. In distributed mode training can be stopped, but
    evaluation runs as a separate job and keeps waiting for new checkpoints.
    You need some other means to detect and stop it. Supporting that requires
    changes to the `train_and_evaluate` API and is planned for a future
    revision.

    Args:
      estimator: A `tf.estimator.Estimator` instance.
      metric_name: `str`, metric to track. "loss", "accuracy", etc.
      max_steps_without_decrease: `int`, maximum number of training steps with
        no decrease in the given metric.
      eval_dir: If set, directory containing summary files with eval metrics.
        By default, `estimator.eval_dir()` will be used.
      min_steps: `int`, stop is never requested if global step is less than
        this value. Defaults to 0.
      run_every_secs: If specified, calls `should_stop_fn` at an interval of
        `run_every_secs` seconds. Defaults to 60 seconds. Either this or
        `run_every_steps` must be set.
      run_every_steps: If specified, calls `should_stop_fn` every
        `run_every_steps` steps. Either this or `run_every_secs` must be set.

    Returns:
      An early-stopping hook of type `SessionRunHook`. It periodically checks
      whether the given metric has failed to decrease over the given maximum
      number of training steps and, if so, initiates early stopping.
    """
    # A decrease is what counts as an improvement for this metric.
    lower_is_better = True
    return _stop_if_no_metric_improvement_hook(
        estimator=estimator,
        metric_name=metric_name,
        max_steps_without_improvement=max_steps_without_decrease,
        higher_is_better=not lower_is_better,
        eval_dir=eval_dir,
        min_steps=min_steps,
        run_every_secs=run_every_secs,
        run_every_steps=run_every_steps)