本文整理汇总了Python中tensorflow.python.training.monitored_session._HookedSession类的典型用法代码示例。如果您正苦于以下问题：Python _HookedSession类的具体用法是什么？Python _HookedSession怎么用？Python _HookedSession有哪些使用示例？那么恭喜您，这里精选的代码示例或许可以为您提供帮助。（注意：_HookedSession以下划线开头，属于TensorFlow的内部实现，不是公共API；下面的示例都来自单元测试代码。）
在下文中一共展示了_HookedSession类的15个代码示例，这些例子默认按受欢迎程度排序。您可以为喜欢或者觉得有用的代码点赞，您的评价将有助于系统推荐出更好的Python代码示例。
示例1: _validate_print_every_n_secs
def _validate_print_every_n_secs(self, sess, at_end):
    """Checks time-based logging of ``LoggingMetricHook``.

    The hook is configured with ``every_n_secs=1.0``. The tensor name must be
    logged on the first run, must not be logged on an immediate second run,
    must be logged again after sleeping for one second, and must be logged by
    ``hook.end`` only when ``at_end`` is true.

    Substring checks (``assertIn``/``assertNotIn``) are used instead of the
    deprecated ``assertRegexpMatches`` alias (removed in Python 3.12) so the
    tensor name is matched literally rather than as a regular expression.

    Args:
      sess: Session used to initialize variables and wrapped in a
        ``_HookedSession``.
      at_end: Forwarded to ``LoggingMetricHook``; controls whether
        ``hook.end`` logs the tensor.
    """
    t = tf.constant(42.0, name='foo')
    train_op = tf.constant(3)
    hook = metric_hook.LoggingMetricHook(
        tensors=[t.name], every_n_secs=1.0, at_end=at_end,
        metric_logger=self._logger)
    hook.begin()
    mon_sess = monitored_session._HookedSession(sess, [hook])  # pylint: disable=protected-access
    sess.run(tf.global_variables_initializer())

    # The first run is expected to log.
    mon_sess.run(train_op)
    self.assertIn(t.name, str(self._logger.logged_metric))

    # An immediate second run falls inside the 1s window: nothing is logged.
    self._logger.logged_metric = []
    mon_sess.run(train_op)
    self.assertNotIn(t.name, str(self._logger.logged_metric))

    # After the window has elapsed the tensor is logged again.
    time.sleep(1.0)
    self._logger.logged_metric = []
    mon_sess.run(train_op)
    self.assertIn(t.name, str(self._logger.logged_metric))

    self._logger.logged_metric = []
    hook.end(sess)
    if at_end:
        self.assertIn(t.name, str(self._logger.logged_metric))
    else:
        self.assertNotIn(t.name, str(self._logger.logged_metric))
示例2: test_log_tensors
def test_log_tensors(self):
    """``LoggingMetricHook(at_end=True)`` logs its tensors only in ``end``.

    Nothing may be logged during the three training runs. ``hook.end`` must
    then log exactly two metrics, 'foo' = 42.0 and 'bar' = 43.0, each with no
    unit and at global step 0 (``train_op`` is a constant, so the global step
    is never incremented).
    """
    with tf.Graph().as_default(), tf.Session() as sess:
        tf.train.get_or_create_global_step()
        t1 = tf.constant(42.0, name='foo')
        t2 = tf.constant(43.0, name='bar')
        train_op = tf.constant(3)
        hook = metric_hook.LoggingMetricHook(
            tensors=[t1, t2], at_end=True, metric_logger=self._logger)
        hook.begin()
        mon_sess = monitored_session._HookedSession(sess, [hook])  # pylint: disable=protected-access
        sess.run(tf.global_variables_initializer())

        for _ in range(3):
            mon_sess.run(train_op)
            # Neither every_n_iter nor every_n_secs was given, so nothing
            # should be logged before end().
            self.assertEqual(self._logger.logged_metric, [])

        hook.end(sess)
        self.assertEqual(len(self._logger.logged_metric), 2)
        expected = [("foo", 42.0), ("bar", 43.0)]
        for metric, (name, value) in zip(self._logger.logged_metric, expected):
            # Literal substring check instead of the deprecated
            # assertRegexpMatches alias (removed in Python 3.12).
            self.assertIn(name, str(metric["name"]))
            self.assertEqual(metric["value"], value)
            self.assertIsNone(metric["unit"])
            self.assertEqual(metric["global_step"], 0)
示例3: test_step_counter_every_n_secs
def test_step_counter_every_n_secs(self):
    """``StepCounterHook(every_n_secs=0.1)`` writes 'global_step/sec' summaries.

    Three training steps are separated by 0.2s sleeps. Summaries are expected
    at global steps 2 and 3, each tagged 'global_step/sec' with a positive
    value.
    """
    with tf.Graph().as_default() as g, tf.Session() as sess:
        step = tf.contrib.framework.get_or_create_global_step()
        increment = tf.assign_add(step, 1)
        writer = testing.FakeSummaryWriter(self.log_dir, g)
        hook = tf.train.StepCounterHook(
            summary_writer=writer, every_n_steps=None, every_n_secs=0.1)
        hook.begin()
        sess.run(tf.initialize_all_variables())
        hooked = monitored_session._HookedSession(sess, [hook])  # pylint: disable=protected-access

        # Sleep longer than every_n_secs between steps so the hook can fire.
        hooked.run(increment)
        for _ in range(2):
            time.sleep(0.2)
            hooked.run(increment)
        hook.end(sess)

        writer.assert_summaries(
            test_case=self,
            expected_logdir=self.log_dir,
            expected_graph=g,
            expected_summaries={})
        self.assertTrue(writer.summaries, "No summaries were created.")
        self.assertItemsEqual([2, 3], writer.summaries.keys())
        for recorded in writer.summaries.values():
            value = recorded[0].value[0]
            self.assertEqual("global_step/sec", value.tag)
            self.assertGreater(value.simple_value, 0)
示例4: _validate_print_every_n_steps
def _validate_print_every_n_steps(self, sess, at_end):
    """Checks step-based logging of ``LoggingMetricHook``.

    With ``every_n_iter=10``:
      * the tensor name must be logged on the first run;
      * it must be logged again on the 10th run after each logged run;
      * it must not be logged on the 9 runs in between;
      * ``hook.end`` must log it only when ``at_end`` is true.

    Substring checks (``assertIn``/``assertNotIn``) replace the deprecated
    ``assertRegexpMatches`` alias (removed in Python 3.12) and the indirect
    ``str.find(...) == -1`` idiom.

    Args:
      sess: Session used to initialize variables and wrapped in a
        ``_HookedSession``.
      at_end: Forwarded to ``LoggingMetricHook``; controls whether
        ``hook.end`` logs the tensor.
    """
    t = tf.constant(42.0, name="foo")
    train_op = tf.constant(3)
    hook = metric_hook.LoggingMetricHook(
        tensors=[t.name], every_n_iter=10, at_end=at_end,
        metric_logger=self._logger)
    hook.begin()
    mon_sess = monitored_session._HookedSession(sess, [hook])  # pylint: disable=protected-access
    sess.run(tf.compat.v1.global_variables_initializer())

    # The very first run is expected to log.
    mon_sess.run(train_op)
    self.assertIn(t.name, str(self._logger.logged_metric))
    for _ in range(3):
        self._logger.logged_metric = []
        # The 9 runs following a logged run must stay silent ...
        for _ in range(9):
            mon_sess.run(train_op)
            self.assertNotIn(t.name, str(self._logger.logged_metric))
        # ... and the 10th one logs again.
        mon_sess.run(train_op)
        self.assertIn(t.name, str(self._logger.logged_metric))

    # Add additional run to verify proper reset when called multiple times.
    self._logger.logged_metric = []
    mon_sess.run(train_op)
    self.assertNotIn(t.name, str(self._logger.logged_metric))

    self._logger.logged_metric = []
    hook.end(sess)
    if at_end:
        self.assertIn(t.name, str(self._logger.logged_metric))
    else:
        self.assertNotIn(t.name, str(self._logger.logged_metric))
示例5: test_save_secs_saving_once_every_three_steps
def test_save_secs_saving_once_every_three_steps(self, mock_time):
    """``SummarySaverHook(save_secs=9)`` saves on steps 1, 4 and 7.

    ``mock_time`` is a mock whose ``return_value`` stands in for the clock.
    The patch decorator is not shown here; presumably it patches
    ``time.time``, so confirm against the enclosing test class. The clock is
    advanced by 3.1s after every step.
    """
    mock_time.return_value = 1484695987.209386
    hook = basic_session_run_hooks.SummarySaverHook(
        save_secs=9.,
        summary_writer=self.summary_writer,
        summary_op=self.summary_op)

    with self.test_session() as sess:
        hook.begin()
        sess.run(variables_lib.global_variables_initializer())
        hooked = monitored_session._HookedSession(sess, [hook])  # pylint: disable=protected-access
        step_count = 8
        clock_tick = 3.1
        for _ in range(step_count):
            hooked.run(self.train_op)
            mock_time.return_value += clock_tick
        hook.end(sess)

    # 24.8 seconds elapse in total (3.1 * 8). A save happens on the first
    # step and then every 9 seconds.
    expected = {
        1: {'my_summary': 1.0},
        4: {'my_summary': 2.0},
        7: {'my_summary': 3.0},
    }
    self.summary_writer.assert_summaries(
        test_case=self,
        expected_logdir=self.log_dir,
        expected_summaries=expected)
示例6: test_save_steps_saves_periodically
def test_save_steps_saves_periodically(self):
    """``CheckpointSaverHook(save_steps=2)`` checkpoints global steps 1, 3 and 5.

    After training steps the global-step value read back from
    ``self.model_dir`` is compared against the expected last-saved step.
    """
    with self.graph.as_default():
        hook = tf.train.CheckpointSaverHook(
            self.model_dir, save_steps=2, scaffold=self.scaffold)
        hook.begin()
        self.scaffold.finalize()

        def checkpointed_step():
            # Global step as loaded from the checkpoint in model_dir.
            return tf.contrib.framework.load_variable(
                self.model_dir, self.global_step.name)

        with tf.Session() as sess:
            sess.run(self.scaffold.init_op)
            hooked = monitored_session._HookedSession(sess, [hook])  # pylint: disable=protected-access
            hooked.run(self.train_op)
            hooked.run(self.train_op)
            self.assertEqual(1, checkpointed_step())  # step 2 not saved
            hooked.run(self.train_op)
            self.assertEqual(3, checkpointed_step())  # step 3 saved
            hooked.run(self.train_op)
            self.assertEqual(3, checkpointed_step())  # step 4 not saved
            hooked.run(self.train_op)
            self.assertEqual(5, checkpointed_step())  # step 5 saved
示例7: test_capture
def test_capture(self):
    """``MetadataCaptureHook(params={"step": 5})`` writes its files after step 5.

    ``model_dir`` must remain empty at step 0 and at step 5. It must contain
    exactly run_meta, tfprof_log and timeline.json after the following run.
    """
    global_step = tf.contrib.framework.get_or_create_global_step()
    # A small computation for the hooked session to run.
    weights = tf.get_variable("weigths", [2, 128])
    softmax = tf.nn.softmax(weights)

    hook = hooks.MetadataCaptureHook(
        params={"step": 5}, model_dir=self.model_dir,
        run_config=tf.contrib.learn.RunConfig())
    hook.begin()

    expected_files = {"run_meta", "tfprof_log", "timeline.json"}
    with self.test_session() as sess:
        sess.run(tf.global_variables_initializer())
        hooked = monitored_session._HookedSession(sess, [hook])  # pylint: disable=protected-access

        # Nothing is captured at step 0.
        sess.run(tf.assign(global_step, 0))
        hooked.run(softmax)
        self.assertEqual(gfile.ListDirectory(self.model_dir), [])

        # Capture happens *after* step 5, i.e. on the next run.
        sess.run(tf.assign(global_step, 5))
        hooked.run(softmax)
        self.assertEqual(gfile.ListDirectory(self.model_dir), [])
        hooked.run(softmax)
        self.assertEqual(set(gfile.ListDirectory(self.model_dir)), expected_files)
示例8: _validate_print_every_n_steps
def _validate_print_every_n_steps(self, sess, at_end):
    """Checks step-based logging of ``LoggingTensorHook``.

    With ``every_n_iter=10``:
      * the tensor name must appear in ``self.logged_message`` on the first run;
      * it must appear again on the 10th run after each logged run;
      * it must not appear on the 9 runs in between;
      * ``hook.end`` must log it only when ``at_end`` is true.

    Literal substring checks (``assertIn``/``assertNotIn``) replace the
    deprecated ``assertRegexpMatches`` alias (removed in Python 3.12) and the
    indirect ``str.find(...) == -1`` idiom.

    Args:
      sess: Session used to initialize variables and wrapped in a
        ``_HookedSession``.
      at_end: Forwarded to ``LoggingTensorHook``; controls whether
        ``hook.end`` logs the tensor.
    """
    t = constant_op.constant(42.0, name='foo')
    train_op = constant_op.constant(3)
    hook = basic_session_run_hooks.LoggingTensorHook(
        tensors=[t.name], every_n_iter=10, at_end=at_end)
    hook.begin()
    mon_sess = monitored_session._HookedSession(sess, [hook])  # pylint: disable=protected-access
    sess.run(variables_lib.global_variables_initializer())

    # The very first run is expected to log.
    mon_sess.run(train_op)
    self.assertIn(t.name, str(self.logged_message))
    for _ in range(3):
        self.logged_message = ''
        # The 9 runs following a logged run must stay silent ...
        for _ in range(9):
            mon_sess.run(train_op)
            self.assertNotIn(t.name, str(self.logged_message))
        # ... and the 10th one logs again.
        mon_sess.run(train_op)
        self.assertIn(t.name, str(self.logged_message))

    # Add additional run to verify proper reset when called multiple times.
    self.logged_message = ''
    mon_sess.run(train_op)
    self.assertNotIn(t.name, str(self.logged_message))

    self.logged_message = ''
    hook.end(sess)
    if at_end:
        self.assertIn(t.name, str(self.logged_message))
    else:
        self.assertNotIn(t.name, str(self.logged_message))
示例9: test_global_step_name
def test_global_step_name(self):
    """``StepCounterHook`` names its summary after a custom global-step variable.

    The global step is the variable 'bar/foo'. After two steps with
    ``every_n_steps=1``, a single summary is expected at step 2, tagged
    'bar/foo/sec'.
    """
    with ops.Graph().as_default() as g, session_lib.Session() as sess:
        with variable_scope.variable_scope('bar'):
            # Putting the variable in GraphKeys.GLOBAL_STEP is what makes the
            # hook treat it as the global step.
            foo_step = variable_scope.get_variable(
                'foo',
                initializer=0,
                trainable=False,
                collections=[
                    ops.GraphKeys.GLOBAL_STEP, ops.GraphKeys.GLOBAL_VARIABLES
                ])
        train_op = state_ops.assign_add(foo_step, 1)
        writer = fake_summary_writer.FakeSummaryWriter(self.log_dir, g)
        hook = basic_session_run_hooks.StepCounterHook(
            summary_writer=writer, every_n_steps=1, every_n_secs=None)
        hook.begin()
        sess.run(variables_lib.global_variables_initializer())
        hooked = monitored_session._HookedSession(sess, [hook])  # pylint: disable=protected-access
        for _ in range(2):
            hooked.run(train_op)
        hook.end(sess)

        writer.assert_summaries(
            test_case=self,
            expected_logdir=self.log_dir,
            expected_graph=g,
            expected_summaries={})
        self.assertTrue(writer.summaries, 'No summaries were created.')
        self.assertItemsEqual([2], writer.summaries.keys())
        first_value = writer.summaries[2][0].value[0]
        self.assertEqual('bar/foo/sec', first_value.tag)
示例10: test_summary_saver
def test_summary_saver(self):
    """``SummarySaverHook(save_steps=8)`` writes summaries on steps 1, 9, 17, 25.

    ``summary_op`` summarizes an ``assign_add`` on a counter variable, so each
    saved summary increments it. The recorded 'my_summary' values are
    therefore 1.0, 2.0, 3.0 and 4.0.
    """
    with tf.Graph().as_default() as g, tf.Session() as sess:
        log_dir = 'log/dir'
        summary_writer = testing.FakeSummaryWriter(log_dir, g)
        counter = tf.Variable(0.0)
        incremented = tf.assign_add(counter, 1.0)
        summary_op = tf.scalar_summary('my_summary', incremented)
        global_step = tf.contrib.framework.get_or_create_global_step()
        train_op = tf.assign_add(global_step, 1)
        hook = tf.train.SummarySaverHook(
            summary_op=summary_op, save_steps=8, summary_writer=summary_writer)
        hook.begin()
        sess.run(tf.initialize_all_variables())
        hooked = monitored_session._HookedSession(sess, [hook])  # pylint: disable=protected-access
        for _ in range(30):
            hooked.run(train_op)
        hook.end(sess)

        saved_steps = (1, 9, 17, 25)
        expected = {
            step: {'my_summary': float(nth)}
            for nth, step in enumerate(saved_steps, start=1)
        }
        summary_writer.assert_summaries(
            test_case=self,
            expected_logdir=log_dir,
            expected_graph=g,
            expected_summaries=expected)
示例11: test_stop_based_on_num_step
def test_stop_based_on_num_step(self):
    """``StopAtStepHook(num_steps=10)`` counts from the step at session creation.

    The global step is 5 when ``after_create_session`` is called. The
    assertions expect no stop request at steps 5, 13 and 14, and a stop
    request at step 15. Stopping is requested again at step 16 even after
    ``_should_stop`` is manually cleared.
    """
    hook = basic_session_run_hooks.StopAtStepHook(num_steps=10)
    with ops.Graph().as_default():
        global_step = variables.get_or_create_global_step()
        no_op = control_flow_ops.no_op()
        hook.begin()
        with session_lib.Session() as sess:
            hooked = monitored_session._HookedSession(sess, [hook])  # pylint: disable=protected-access
            sess.run(state_ops.assign(global_step, 5))
            hook.after_create_session(sess, None)

            def run_at(step):
                # Force the global step, run once, and report the stop flag.
                sess.run(state_ops.assign(global_step, step))
                hooked.run(no_op)
                return hooked.should_stop()

            hooked.run(no_op)
            self.assertFalse(hooked.should_stop())
            self.assertFalse(run_at(13))
            self.assertFalse(run_at(14))
            self.assertTrue(run_at(15))
            # Clear the latched stop flag; the hook must request it again.
            hooked._should_stop = False  # pylint: disable=protected-access
            self.assertTrue(run_at(16))
示例12: test_multiple_summaries
def test_multiple_summaries(self):
    """``SummarySaverHook`` accepts a list of summary ops and saves all of them.

    With ``save_steps=8`` over 10 steps, both summaries are expected at steps
    1 and 9.
    """
    hook = basic_session_run_hooks.SummarySaverHook(
        save_steps=8,
        summary_writer=self.summary_writer,
        summary_op=[self.summary_op, self.summary_op2])
    with self.test_session() as sess:
        hook.begin()
        sess.run(variables_lib.global_variables_initializer())
        hooked = monitored_session._HookedSession(sess, [hook])  # pylint: disable=protected-access
        for _ in range(10):
            hooked.run(self.train_op)
        hook.end(sess)

    expected = {
        1: {'my_summary': 1.0, 'my_summary2': 2.0},
        9: {'my_summary': 2.0, 'my_summary2': 4.0},
    }
    self.summary_writer.assert_summaries(
        test_case=self,
        expected_logdir=self.log_dir,
        expected_summaries=expected)
示例13: testDumpingDebugHookWithStatefulLegacyWatchFnWorks
def testDumpingDebugHookWithStatefulLegacyWatchFnWorks(self):
    """``DumpingDebugHook`` honors a stateful, tuple-returning (legacy) watch_fn.

    The watch function watches everything on odd-numbered runs (1-based) and
    nothing on even-numbered runs. Of the four run dumps, only the 1st and
    3rd are expected to contain the watched value of ``v``; the others must
    be empty.
    """
    state = {"run_counter": 0}

    def counting_watch_fn(fetches, feed_dict):
        del fetches, feed_dict  # unused by this watch function
        state["run_counter"] += 1
        is_odd_run = state["run_counter"] % 2 == 1
        # r".*" matches every name; r"$^" matches no non-empty name.
        pattern = r".*" if is_odd_run else r"$^"
        return "DebugIdentity", pattern, pattern

    dumping_hook = hooks.DumpingDebugHook(
        self.session_root, watch_fn=counting_watch_fn, log_usage=False)
    mon_sess = monitored_session._HookedSession(self.sess, [dumping_hook])  # pylint: disable=protected-access
    for _ in range(4):
        mon_sess.run(self.inc_v)

    def run_index(path):
        # Sort key: the integer after the first '_' in the directory basename.
        return int(os.path.basename(path).split("_")[1])

    dump_dirs = sorted(
        glob.glob(os.path.join(self.session_root, "run_*")), key=run_index)
    self.assertEqual(4, len(dump_dirs))

    for i, dump_dir in enumerate(dump_dirs):
        self._assert_correct_run_subdir_naming(os.path.basename(dump_dir))
        dump = debug_data.DebugDumpDir(dump_dir)
        if i % 2:
            self.assertEqual(0, dump.size)
        else:
            self.assertAllClose([10.0 + 1.0 * i],
                                dump.get_tensors("v", 0, "DebugIdentity"))
        self.assertEqual(repr(self.inc_v), dump.run_fetches_info)
        self.assertEqual(repr(None), dump.run_feed_keys_info)
示例14: DISABLED_test_save_steps_saves_periodically
def DISABLED_test_save_steps_saves_periodically(self):
    """``CheckpointSaverHook(save_steps=2)`` checkpoints global steps 1, 3 and 5.

    The ``DISABLED_`` prefix means the default ``test*`` method discovery
    does not run this test.
    """
    with self.graph.as_default():
        hook = basic_session_run_hooks.CheckpointSaverHook(
            self.model_dir, save_steps=2, scaffold=self.scaffold)
        hook.begin()
        self.scaffold.finalize()
        with session_lib.Session() as sess:
            sess.run(self.scaffold.init_op)
            hooked = monitored_session._HookedSession(sess, [hook])  # pylint: disable=protected-access
            # Each entry: (train_op runs to perform, global step expected in
            # the checkpoint afterwards). The values alternate between
            # "not saved" (1, then 3 again) and "saved" (3, then 5).
            schedule = ((2, 1), (1, 3), (1, 3), (1, 5))
            for num_runs, expected_step in schedule:
                for _ in range(num_runs):
                    hooked.run(self.train_op)
                self.assertEqual(
                    expected_step,
                    checkpoint_utils.load_variable(self.model_dir,
                                                   self.global_step.name))
示例15: DISABLED_test_save_secs_calls_listeners_periodically
def DISABLED_test_save_secs_calls_listeners_periodically(self):
    """``CheckpointSaverHook(save_secs=2)`` notifies its listener on every save.

    Training runs in three bursts (2, 3 and 2 steps) separated by 2.5s
    sleeps. The first step of each burst is expected to save. The last step
    does not, so ``end`` saves once more. That gives four
    before_save/after_save callbacks and a single begin/end.

    The ``DISABLED_`` prefix means the default ``test*`` method discovery
    does not run this test.
    """
    with self.graph.as_default():
        listener = MockCheckpointSaverListener()
        hook = basic_session_run_hooks.CheckpointSaverHook(
            self.model_dir,
            save_secs=2,
            scaffold=self.scaffold,
            listeners=[listener])
        hook.begin()
        self.scaffold.finalize()
        with session_lib.Session() as sess:
            sess.run(self.scaffold.init_op)
            hooked = monitored_session._HookedSession(sess, [hook])  # pylint: disable=protected-access
            for burst_index, burst_len in enumerate((2, 3, 2)):
                if burst_index:
                    # Wait past save_secs so the next step triggers a save.
                    time.sleep(2.5)
                for _ in range(burst_len):
                    hooked.run(self.train_op)
            hook.end(sess)  # final save happens here
        expected_counts = {
            'begin': 1,
            'before_save': 4,
            'after_save': 4,
            'end': 1
        }
        self.assertEqual(expected_counts, listener.get_counts())