本文整理汇总了Python中tensorflow.python.training.basic_session_run_hooks.StepCounterHook类的典型用法代码示例。如果您正苦于以下问题：Python basic_session_run_hooks.StepCounterHook类的具体用法？Python basic_session_run_hooks.StepCounterHook怎么用？Python basic_session_run_hooks.StepCounterHook使用的例子？那么，这里精选的代码示例或许可以为您提供帮助。您也可以进一步了解该类所在模块tensorflow.python.training.basic_session_run_hooks的用法示例。
在下文中一共展示了basic_session_run_hooks.StepCounterHook类的2个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: MonitoredTrainingSession
# 需要导入模块: from tensorflow.python.training import basic_session_run_hooks [as 别名]
# 或者: from tensorflow.python.training.basic_session_run_hooks import StepCounterHook [as 别名]
def MonitoredTrainingSession(master='',  # pylint: disable=invalid-name
                             is_chief=True,
                             checkpoint_dir=None,
                             hooks=None,
                             scaffold=None,
                             config=None):
  """Creates a `MonitoredSession` for training.

  For a chief, this utility sets proper session initializer/restorer. It also
  creates hooks related to step counting, summary saving and checkpoint saving
  when `checkpoint_dir` is given. For workers, this utility sets a session
  creator which waits for the chief to initialize/restore.

  Args:
    master: `String` the TensorFlow master to use.
    is_chief: If `True`, it will take care of initialization and recovery of
      the underlying TensorFlow session. If `False`, it will wait on a chief to
      initialize or recover the TensorFlow session.
    checkpoint_dir: A string. Optional path to a directory where to restore
      variables. On the chief, the step-counter, summary-saver and
      checkpoint-saver hooks are only added when this is set, because they all
      write into this directory.
    hooks: Optional sequence of `SessionRunHook` objects. It is copied, so the
      caller's list is never modified.
    scaffold: A `Scaffold` used for gathering or building supportive ops. If
      not specified, a default one is created. It's used to finalize the graph.
    config: `ConfigProto` proto used to configure the session.

  Returns:
    A `MonitoredSession` object.
  """
  # Copy rather than alias: the chief branch appends to this list, and doing
  # that on the caller's own list would leak hooks into later calls.
  hooks = list(hooks or [])
  scaffold = scaffold or Scaffold()
  if not is_chief:
    session_creator = WorkerSessionCreator(
        scaffold=scaffold, master=master, config=config)
  else:
    session_creator = ChiefSessionCreator(
        scaffold=scaffold,
        checkpoint_dir=checkpoint_dir,
        master=master,
        config=config)
    # These hooks all take `checkpoint_dir` as their output location;
    # CheckpointSaverHook in particular has nowhere to save with a `None`
    # directory, so only install them when a directory was provided.
    if checkpoint_dir:
      hooks.extend([
          basic_session_run_hooks.StepCounterHook(output_dir=checkpoint_dir),
          basic_session_run_hooks.SummarySaverHook(
              scaffold=scaffold, save_steps=100, output_dir=checkpoint_dir),
          basic_session_run_hooks.CheckpointSaverHook(
              checkpoint_dir, save_secs=600, scaffold=scaffold),
      ])
  return MonitoredSession(session_creator=session_creator, hooks=hooks)
示例2: test_estimator_with_strategy_hooks
# 需要导入模块: from tensorflow.python.training import basic_session_run_hooks [as 别名]
# 或者: from tensorflow.python.training.basic_session_run_hooks import StepCounterHook [as 别名]
def test_estimator_with_strategy_hooks(self, distribution,
                                       use_train_and_evaluate):
  """Smoke-tests Estimator train/eval with custom session-run hooks.

  The `model_fn` attaches the same pair of hooks in both TRAIN and EVAL mode:
  a real `StepCounterHook` that writes to the test's temp dir, and a
  `MagicMock` wrapping a no-op `SessionRunHook`. `distribution` is used as the
  eval-time distribution strategy. No explicit assertions are made; the test
  passes when training and evaluation complete without raising.

  Args:
    distribution: Strategy passed as `RunConfig(eval_distribute=...)`.
    use_train_and_evaluate: If truthy, run via `train_and_evaluate`; otherwise
      call `train` and then `evaluate` directly.
  """
  config = run_config.RunConfig(eval_distribute=distribution)

  def _input_map_fn(tensor):
    # The same tensor serves as both the feature dict entry and the label.
    return {'feature': tensor}, tensor

  def input_fn():
    # Ten copies of [1.], grouped into batches of five.
    dataset = dataset_ops.Dataset.from_tensors([1.])
    dataset = dataset.repeat(10).batch(5)
    return dataset.map(_input_map_fn)

  def _build_hooks():
    # Returns [StepCounterHook, MagicMock wrapping a no-op SessionRunHook].
    counter_hook = basic_session_run_hooks.StepCounterHook(
        every_n_steps=1, output_dir=self.get_temp_dir())
    mock_hook = tf.compat.v1.test.mock.MagicMock(
        wraps=tf.compat.v1.train.SessionRunHook(),
        spec=tf.compat.v1.train.SessionRunHook)
    return [counter_hook, mock_hook]

  def model_fn(features, labels, mode):
    del features, labels  # The model ignores its inputs entirely.
    global_step = training_util.get_global_step()
    if mode == model_fn_lib.ModeKeys.TRAIN:
      train_hooks = _build_hooks()
      return model_fn_lib.EstimatorSpec(
          mode,
          loss=tf.constant(1.),
          train_op=global_step.assign_add(1),
          training_hooks=train_hooks)
    if mode == model_fn_lib.ModeKeys.EVAL:
      eval_hooks = _build_hooks()
      return model_fn_lib.EstimatorSpec(
          mode=mode,
          loss=tf.constant(1.),
          evaluation_hooks=eval_hooks)
    # Any other mode falls through and yields None.
    return None

  num_steps = 10
  estimator = estimator_lib.EstimatorV2(
      model_fn=model_fn, model_dir=self.get_temp_dir(), config=config)
  if not use_train_and_evaluate:
    estimator.train(input_fn, steps=num_steps)
    estimator.evaluate(input_fn, steps=num_steps)
    return
  training.train_and_evaluate(
      estimator,
      training.TrainSpec(input_fn, max_steps=num_steps),
      training.EvalSpec(input_fn))