This article collects typical usage examples of the Python function tensorflow.contrib.framework.python.ops.variables.get_or_create_global_step. If you are wondering what exactly get_or_create_global_step does and how to use it, the curated code examples below may help.
The following presents 15 code examples of the get_or_create_global_step function, sorted by popularity by default. You can upvote the examples you like or find useful; your votes help the system recommend better Python code examples.
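Before the examples, here is a minimal sketch of what the function does: it looks up the graph's global step variable and creates one if none exists yet. This sketch is not one of the collected examples; it assumes a TF 1.x environment where the contrib modules used below are importable.

# Minimal usage sketch (assumes TF 1.x with tf.contrib available).
from tensorflow.contrib.framework.python.ops import variables
from tensorflow.python.framework import ops
from tensorflow.python.ops import state_ops

with ops.Graph().as_default():
  # Creates the global step variable if the graph has none; otherwise
  # returns the existing one.
  global_step = variables.get_or_create_global_step()
  # A second call should return the same variable, not create another.
  assert global_step is variables.get_or_create_global_step()
  # A typical training loop increments it once per step.
  train_op = state_ops.assign_add(global_step, 1)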
Example 1: testEvaluationLoopTimeout
def testEvaluationLoopTimeout(self):
  checkpoint_dir = os.path.join(self.get_temp_dir(),
                                'evaluation_loop_timeout')
  if not gfile.Exists(checkpoint_dir):
    gfile.MakeDirs(checkpoint_dir)

  # We need a variable that the saver will try to restore.
  variables.get_or_create_global_step()

  # Run with placeholders. If we actually tried to evaluate this, we'd fail
  # since we're not using a feed_dict.
  cant_run_op = array_ops.placeholder(dtype=dtypes.float32)

  start = time.time()
  final_values = evaluation.evaluate_repeatedly(
      checkpoint_dir=checkpoint_dir,
      eval_ops=cant_run_op,
      hooks=[evaluation.StopAfterNEvalsHook(10)],
      timeout=6)
  end = time.time()

  self.assertFalse(final_values)

  # Assert that we've waited for the duration of the timeout (minus the
  # sleep time).
  self.assertGreater(end - start, 5.0)

  # Then assert that the timeout kicked in and stopped the loop.
  self.assertLess(end - start, 7)
Example 2: testEvaluateWithEvalFeedDict
def testEvaluateWithEvalFeedDict(self):
  # Create a checkpoint.
  checkpoint_dir = os.path.join(self.get_temp_dir(),
                                'evaluate_with_eval_feed_dict')
  self._train_model(checkpoint_dir, num_steps=1)

  # We need a variable that the saver will try to restore.
  variables.get_or_create_global_step()

  # Create a variable and an eval op that increments it with a placeholder.
  my_var = variables.local_variable(0.0, name='my_var')
  increment = array_ops.placeholder(dtype=dtypes.float32)
  eval_ops = state_ops.assign_add(my_var, increment)

  increment_value = 3
  num_evals = 5
  expected_value = increment_value * num_evals
  final_values = evaluation.evaluate_repeatedly(
      checkpoint_dir=checkpoint_dir,
      eval_ops=eval_ops,
      feed_dict={increment: increment_value},
      final_ops={'my_var': array_ops.identity(my_var)},
      hooks=[evaluation.StopAfterNEvalsHook(num_evals)],
      max_number_of_evaluations=1)
  self.assertEqual(final_values['my_var'], expected_value)
Example 3: test_two_listeners_with_default_saver
def test_two_listeners_with_default_saver(self):
  with ops.Graph().as_default():
    global_step = variables.get_or_create_global_step()
    train_op = state_ops.assign_add(global_step, 1)
    listener1 = MockCheckpointSaverListener()
    listener2 = MockCheckpointSaverListener()
    hook = basic_session_run_hooks.CheckpointSaverHook(
        self.model_dir,
        save_steps=1,
        listeners=[listener1, listener2])
    with monitored_session.SingularMonitoredSession(
        hooks=[hook],
        checkpoint_dir=self.model_dir) as sess:
      sess.run(train_op)
      sess.run(train_op)
      global_step_val = sess.run(global_step)
    listener1_counts = listener1.get_counts()
    listener2_counts = listener2.get_counts()
    self.assertEqual(2, global_step_val)
    self.assertEqual({
        'begin': 1,
        'before_save': 2,
        'after_save': 2,
        'end': 1
    }, listener1_counts)
    self.assertEqual(listener1_counts, listener2_counts)

  with ops.Graph().as_default():
    global_step = variables.get_or_create_global_step()
    with monitored_session.SingularMonitoredSession(
        checkpoint_dir=self.model_dir) as sess2:
      global_step_saved_val = sess2.run(global_step)
  self.assertEqual(2, global_step_saved_val)
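Note that MockCheckpointSaverListener is defined elsewhere in the test file and not shown in this snippet. A minimal stand-in, assuming the standard CheckpointSaverListener callback interface, might look like this:

# Hypothetical stand-in for the listener used above; it simply counts how
# often each CheckpointSaverListener callback fires.
from tensorflow.python.training import basic_session_run_hooks

class MockCheckpointSaverListener(
    basic_session_run_hooks.CheckpointSaverListener):

  def __init__(self):
    self.begin_count = 0
    self.before_save_count = 0
    self.after_save_count = 0
    self.end_count = 0

  def begin(self):
    self.begin_count += 1

  def before_save(self, session, global_step_value):
    self.before_save_count += 1

  def after_save(self, session, global_step_value):
    self.after_save_count += 1

  def end(self, session, global_step_value):
    self.end_count += 1

  def get_counts(self):
    return {
        'begin': self.begin_count,
        'before_save': self.before_save_count,
        'after_save': self.after_save_count,
        'end': self.end_count
    }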
Example 4: test_step_counter_every_n_secs
def test_step_counter_every_n_secs(self):
  with ops.Graph().as_default() as g, session_lib.Session() as sess:
    variables.get_or_create_global_step()
    train_op = training_util._increment_global_step(1)
    summary_writer = fake_summary_writer.FakeSummaryWriter(self.log_dir, g)
    hook = basic_session_run_hooks.StepCounterHook(
        summary_writer=summary_writer, every_n_steps=None, every_n_secs=0.1)

    hook.begin()
    sess.run(variables_lib.global_variables_initializer())
    mon_sess = monitored_session._HookedSession(sess, [hook])
    mon_sess.run(train_op)
    time.sleep(0.2)
    mon_sess.run(train_op)
    time.sleep(0.2)
    mon_sess.run(train_op)
    hook.end(sess)

    summary_writer.assert_summaries(
        test_case=self,
        expected_logdir=self.log_dir,
        expected_graph=g,
        expected_summaries={})
    self.assertTrue(summary_writer.summaries, 'No summaries were created.')
    self.assertItemsEqual([2, 3], summary_writer.summaries.keys())
    for summary in summary_writer.summaries.values():
      summary_value = summary[0].value[0]
      self.assertEqual('global_step/sec', summary_value.tag)
      self.assertGreater(summary_value.simple_value, 0)
Example 5: test_step_counter_every_n_steps
def test_step_counter_every_n_steps(self):
  with ops.Graph().as_default() as g, session_lib.Session() as sess:
    variables.get_or_create_global_step()
    train_op = training_util._increment_global_step(1)
    summary_writer = fake_summary_writer.FakeSummaryWriter(self.log_dir, g)
    hook = basic_session_run_hooks.StepCounterHook(
        summary_writer=summary_writer, every_n_steps=10)
    hook.begin()
    sess.run(variables_lib.global_variables_initializer())
    mon_sess = monitored_session._HookedSession(sess, [hook])
    with test.mock.patch.object(tf_logging, 'warning') as mock_log:
      for _ in range(30):
        time.sleep(0.01)
        mon_sess.run(train_op)
      # logging.warning should not be called.
      self.assertIsNone(mock_log.call_args)
    hook.end(sess)

    summary_writer.assert_summaries(
        test_case=self,
        expected_logdir=self.log_dir,
        expected_graph=g,
        expected_summaries={})
    self.assertItemsEqual([11, 21], summary_writer.summaries.keys())
    for step in [11, 21]:
      summary_value = summary_writer.summaries[step][0].value[0]
      self.assertEqual('global_step/sec', summary_value.tag)
      self.assertGreater(summary_value.simple_value, 0)
Example 6: test_not_wait_for_step_zero
def test_not_wait_for_step_zero(self):
  with ops.Graph().as_default():
    variables.get_or_create_global_step()
    hook = basic_session_run_hooks.GlobalStepWaiterHook(wait_until_step=0)
    hook.begin()
    with session_lib.Session() as sess:
      # before_run should return without waiting for the global step to
      # increment.
      hook.before_run(
          session_run_hook.SessionRunContext(
              original_args=None, session=sess))
Example 7: setUp
def setUp(self):
  test.TestCase.setUp(self)
  self.log_dir = 'log/dir'
  self.summary_writer = fake_summary_writer.FakeSummaryWriter(self.log_dir)

  var = variables_lib.Variable(0.0)
  tensor = state_ops.assign_add(var, 1.0)
  tensor2 = tensor * 2
  self.summary_op = summary_lib.scalar('my_summary', tensor)
  self.summary_op2 = summary_lib.scalar('my_summary2', tensor2)

  variables.get_or_create_global_step()
  self.train_op = training_util._increment_global_step(1)
Example 8: test_recover_and_retry_on_aborted_error
def test_recover_and_retry_on_aborted_error(self):
  # Tests that we silently retry and recover on abort. This test uses
  # a CheckpointSaver to have something to recover from.
  logdir = _test_dir(self.get_temp_dir(),
                     'test_recover_and_retry_on_aborted_error')
  with ops.Graph().as_default():
    gstep = variables_lib.get_or_create_global_step()
    do_step = state_ops.assign_add(gstep, 1)
    scaffold = monitored_session.Scaffold()
    abort_hook = RaiseOnceAtCountN(
        4, errors_impl.AbortedError(None, None, 'Abort'))
    # Save after each step.
    ckpt_hook = basic_session_run_hooks.CheckpointSaverHook(
        logdir, save_steps=1, scaffold=scaffold)
    hooks = [abort_hook, ckpt_hook]
    with monitored_session.MonitoredSession(
        session_creator=monitored_session.ChiefSessionCreator(
            scaffold, checkpoint_dir=logdir),
        hooks=hooks) as session:
      self.assertEqual(0, session.run(gstep))
      self.assertEqual(1, session.run(do_step))
      self.assertEqual(2, session.run(do_step))
      self.assertFalse(session.should_stop())
      # Here at step 3, the hook triggers and raises AbortedError. The
      # MonitoredSession automatically restores and retries.
      self.assertEqual(3, session.run(do_step))
      self.assertTrue(abort_hook.raised)
      self.assertFalse(session.should_stop())
      self.assertEqual(4, session.run(do_step))
      self.assertFalse(session.should_stop())
Example 9: test_recovery
def test_recovery(self):
  logdir = _test_dir(self.get_temp_dir(), 'test_recovery')
  with ops.Graph().as_default():
    gstep = variables_lib.get_or_create_global_step()
    do_step = state_ops.assign_add(gstep, 1)
    scaffold = monitored_session.Scaffold()
    # Use a hook to save the model every step (save_steps=1). It also saves
    # it at the end.
    hooks = [
        basic_session_run_hooks.CheckpointSaverHook(
            logdir, save_steps=1, scaffold=scaffold)
    ]
    with monitored_session.MonitoredSession(
        session_creator=monitored_session.ChiefSessionCreator(
            scaffold, checkpoint_dir=logdir),
        hooks=hooks) as session:
      self.assertEqual(0, session.run(gstep))
      self.assertEqual(1, session.run(do_step))
      self.assertEqual(2, session.run(do_step))
    # A restart will find the checkpoint and recover automatically.
    with monitored_session.MonitoredSession(
        session_creator=monitored_session.ChiefSessionCreator(
            scaffold, checkpoint_dir=logdir)) as session:
      self.assertEqual(2, session.run(gstep))
    # A restart will also recover when given an explicit checkpoint path.
    with monitored_session.MonitoredSession(
        session_creator=monitored_session.ChiefSessionCreator(
            scaffold,
            checkpoint_filename_with_path=saver_lib.latest_checkpoint(
                logdir))) as session:
      self.assertEqual(2, session.run(gstep))
Example 10: test_num_steps
def test_num_steps(self):
  logdir = _test_dir(self.get_temp_dir(), 'test_num_steps')
  with ops.Graph().as_default():
    gstep = variables_lib.get_or_create_global_step()
    do_step = state_ops.assign_add(gstep, 1)
    # Do 3 steps and save.
    hooks = [basic_session_run_hooks.StopAtStepHook(num_steps=3)]
    scaffold = monitored_session.Scaffold().finalize()
    with monitored_session.MonitoredSession(hooks=hooks) as session:
      session.run(do_step)
      self.assertFalse(session.should_stop())
      session.run(do_step)
      self.assertFalse(session.should_stop())
      session.run(do_step)
      self.assertTrue(session.should_stop())
      save_path = scaffold.saver.save(session._coordinated_creator.tf_sess,
                                      os.path.join(logdir, 'step-3'))

    # Restore and do 4 steps.
    def load_ckpt(scaffold, sess):
      scaffold.saver.restore(sess, save_path)

    session_creator = monitored_session.ChiefSessionCreator(
        scaffold=monitored_session.Scaffold(init_fn=load_ckpt))
    hooks = [basic_session_run_hooks.StopAtStepHook(num_steps=4)]
    with monitored_session.MonitoredSession(
        hooks=hooks, session_creator=session_creator) as session:
      self.assertEqual(4, session.run(do_step))
      self.assertFalse(session.should_stop())
      session.run(do_step)
      self.assertFalse(session.should_stop())
      session.run(do_step)
      self.assertFalse(session.should_stop())
      session.run(do_step)
      self.assertTrue(session.should_stop())
Example 11: __init__
def __init__(self,
             global_step_tensor=None,
             init_op=None,
             init_feed_dict=None,
             init_fn=None,
             ready_op=None,
             local_init_op=None,
             summary_op=None,
             saver=None,
             keep_checkpoint_max=5):
  """Create a scaffold.

  Args:
    global_step_tensor: Optional tensor to use as the global step counter.
    init_op: Optional op for initializing variables.
    init_feed_dict: Optional session feed dictionary to use when running the
      init_op.
    init_fn: Optional function to use to initialize the model after running
      the init_op. Will be called as `init_fn(scaffold, session)`.
    ready_op: Optional op to verify that the variables are initialized. Must
      return an empty scalar string tensor when the variables are
      initialized, or a non-empty one listing the names of the
      non-initialized variables.
    local_init_op: Optional op to initialize local variables.
    summary_op: Optional op to gather all summaries. Must return a scalar
      string tensor containing a serialized `Summary` proto.
    saver: Optional `tf.Saver` object to use to save and restore variables.
    keep_checkpoint_max: Optional parameter to use to construct a saver if
      none is already there in the graph.
  """
  if global_step_tensor is None:
    global_step_tensor = contrib_variables.get_or_create_global_step()
  self.global_step_tensor = global_step_tensor
  if init_op is None:
    init_op = Scaffold._get_or_default(ops.GraphKeys.INIT_OP,
                                       variables.initialize_all_variables)
  self.init_op = init_op
  self.init_feed_dict = init_feed_dict
  # NOTE(touts): modifying the init function to be passed the scaffold is a
  # hack to make it easy to find the saver. Is there a better way?
  if init_fn:
    self.init_fn = lambda sess: init_fn(self, sess)
  else:
    self.init_fn = None
  if ready_op is None:
    ready_op = Scaffold._get_or_default(
        ops.GraphKeys.READY_OP, variables.report_uninitialized_variables)
  self.ready_op = ready_op
  if local_init_op is None:
    local_init_op = Scaffold._get_or_default(ops.GraphKeys.LOCAL_INIT_OP,
                                             Scaffold._default_local_init_op)
  self.local_init_op = local_init_op
  if summary_op is None:
    summary_op = Scaffold._get_or_default(ops.GraphKeys.SUMMARY_OP,
                                          logging_ops.merge_all_summaries)
  self.summary_op = summary_op
  # pylint: disable=g-long-lambda
  if saver is None:
    saver = Scaffold._get_or_default(
        ops.GraphKeys.SAVERS,
        lambda: training_saver.Saver(sharded=True,
                                     max_to_keep=keep_checkpoint_max))
  # pylint: enable=g-long-lambda
  self.saver = saver
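The docstring above explains the constructor arguments. A hedged usage sketch, using the monitored_session.Scaffold seen in Examples 8 through 10 (the checkpoint directory '/tmp/my_model' is hypothetical):

# Illustrative only: build a Scaffold whose init_fn restores from a
# hypothetical checkpoint directory, then hand it to a MonitoredSession.
def restore_from_checkpoint(scaffold, session):
  # init_fn receives the scaffold, which makes the saver easy to reach.
  scaffold.saver.restore(session,
                         saver_lib.latest_checkpoint('/tmp/my_model'))

scaffold = monitored_session.Scaffold(init_fn=restore_from_checkpoint)
with monitored_session.MonitoredSession(
    session_creator=monitored_session.ChiefSessionCreator(
        scaffold=scaffold)) as session:
  pass  # Run training steps here.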
Example 12: _get_train_ops
def _get_train_ops(self, features, targets):
  """See base class."""
  if not isinstance(self._linear_optimizer, sdca_optimizer.SDCAOptimizer):
    return super(LinearRegressor, self)._get_train_ops(features, targets)
  assert not self._joint_weights, ("_joint_weights is incompatible with"
                                   " SDCAOptimizer.")
  global_step = contrib_variables.get_or_create_global_step()

  logits, columns_to_variables, bias = (
      layers.weighted_sum_from_feature_columns(
          columns_to_tensors=features,
          feature_columns=self._linear_feature_columns,
          num_outputs=self._target_column.num_label_columns,
          weight_collections=[self._linear_model.get_scope_name()],
          scope=self._linear_model.get_scope_name()))
  with ops.control_dependencies([self._centered_bias()]):
    loss = self._target_column.loss(logits, targets, features)
  logging_ops.scalar_summary("loss", loss)

  _add_bias_column(self._linear_feature_columns, features, bias, targets,
                   columns_to_variables)

  train_op = self._linear_optimizer.get_train_step(
      columns_to_variables, self._target_column.weight_column_name,
      self._loss_type(), features, targets, global_step)
  return train_op, loss
Example 13: __init__
def __init__(self,
             log_dir=None,
             summary_writer=None,
             summary_op=None,
             feed_dict=None):
  """Constructs the Summary Hook.

  Args:
    log_dir: The directory where the summary events are saved to. Used only
      when `summary_writer` is not specified.
    summary_writer: A `tf.summary.FileWriter` to write summary events with.
    summary_op: The summary op to run. If left as `None`, then all summaries
      in the tf.GraphKeys.SUMMARIES collection are used.
    feed_dict: An optional feed dictionary to use when evaluating the
      summaries.

  Raises:
    ValueError: If both `log_dir` and `summary_writer` are `None`.
  """
  self._summary_op = summary_op
  self._feed_dict = feed_dict
  self._summary_writer = summary_writer
  self._log_dir = log_dir
  if self._log_dir is None and self._summary_writer is None:
    raise ValueError('One of log_dir or summary_writer should be used.')
  self._global_step = variables.get_or_create_global_step()
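This constructor resembles the SummaryAtEndHook in tf.contrib.training. Assuming that is the class it belongs to, a hedged usage sketch alongside the evaluation loop from Examples 1 and 2 (the class name, paths, and eval_ops are assumptions, not taken from the snippet above):

# Hypothetical usage of the hook with evaluate_repeatedly.
hook = SummaryAtEndHook(log_dir='/tmp/eval_logs')
final_values = evaluation.evaluate_repeatedly(
    checkpoint_dir='/tmp/train_logs',
    eval_ops=eval_ops,  # e.g. defined as in Example 2
    hooks=[evaluation.StopAfterNEvalsHook(1), hook],
    max_number_of_evaluations=1)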
Example 14: finalize
def finalize(self):
  """Creates operations if needed and finalizes the graph."""
  if self._global_step_tensor is None:
    self._global_step_tensor = contrib_variables.get_or_create_global_step()
  if self._init_op is None:
    self._init_op = Scaffold._get_or_default(
        'init_op', ops.GraphKeys.INIT_OP, variables.initialize_all_variables)
  if self._ready_op is None:
    self._ready_op = Scaffold._get_or_default(
        'ready_op', ops.GraphKeys.READY_OP,
        variables.report_uninitialized_variables)
  if self._local_init_op is None:
    self._local_init_op = Scaffold._get_or_default(
        'local_init_op', ops.GraphKeys.LOCAL_INIT_OP,
        Scaffold._default_local_init_op)
  if self._summary_op is None:
    self._summary_op = Scaffold._get_or_default(
        'summary_op', ops.GraphKeys.SUMMARY_OP,
        logging_ops.merge_all_summaries)
  # pylint: disable=g-long-lambda
  if self._saver is None:
    self._saver = Scaffold._get_or_default(
        'saver',
        ops.GraphKeys.SAVERS,
        lambda: training_saver.Saver(sharded=True,
                                     max_to_keep=self._keep_checkpoint_max))
  # pylint: enable=g-long-lambda

  ops.get_default_graph().finalize()
Example 15: testNoneGlobalStep
def testNoneGlobalStep(self):
  with ops.Graph().as_default():
    random_seed.set_random_seed(0)
    tf_inputs = constant_op.constant(self._inputs, dtype=dtypes.float32)
    tf_labels = constant_op.constant(self._labels, dtype=dtypes.float32)

    tf_predictions = batchnorm_classifier(tf_inputs)
    loss_ops.log_loss(tf_predictions, tf_labels)
    total_loss = loss_ops.get_total_loss()
    optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=1.0)

    train_op = training.create_train_op(
        total_loss, optimizer, global_step=None)
    global_step = variables_lib.get_or_create_global_step()

    with session_lib.Session() as sess:
      # Initialize all variables.
      sess.run(variables_lib2.global_variables_initializer())

      for _ in range(10):
        sess.run([train_op])
      global_step = global_step.eval()
      # Since train_op doesn't use global_step, it shouldn't change.
      self.assertAllClose(global_step, 0)
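For contrast with testNoneGlobalStep above, a hedged sketch of the default behavior: when the global step is passed to create_train_op (or left at its default, which picks up the graph's global step), each train step increments it. The names reuse those defined in Example 15.

# Counterpart sketch: let create_train_op drive the global step, so ten
# train steps advance it to 10. total_loss and optimizer are as defined
# in testNoneGlobalStep.
global_step = variables_lib.get_or_create_global_step()
train_op = training.create_train_op(
    total_loss, optimizer, global_step=global_step)
with session_lib.Session() as sess:
  sess.run(variables_lib2.global_variables_initializer())
  for _ in range(10):
    sess.run(train_op)
  # create_train_op attached a global_step increment to each train step.
  assert global_step.eval() == 10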