This article collects typical usage examples of the Python function tensorflow.contrib.framework.python.framework.checkpoint_utils.load_variable. If you are unsure how load_variable is used in practice, or are looking for concrete examples of it, the curated code samples below should help.
The following shows 15 code examples of the load_variable function, sorted by popularity by default.
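Before the examples, here is a minimal sketch of how load_variable is typically called together with list_variables. The checkpoint directory "/tmp/my_model" and the variable name "global_step" are placeholders for illustration only, not names taken from the examples below; substitute your own model_dir and variable names.

from tensorflow.contrib.framework.python.framework import checkpoint_utils

# Hypothetical checkpoint directory; replace with your own model_dir.
checkpoint_dir = "/tmp/my_model"

# List every variable stored in the checkpoint along with its shape.
for name, shape in checkpoint_utils.list_variables(checkpoint_dir):
  print(name, shape)

# Load a single variable (here the global step, assuming it was saved under
# that name) as a numpy array.
global_step = checkpoint_utils.load_variable(checkpoint_dir, "global_step")
print("global_step =", global_step)

load_variable returns the stored value as a numpy array, so a checkpoint can be inspected or compared without building a graph or starting a session.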
Example 1: test_save_steps_saves_periodically
def test_save_steps_saves_periodically(self):
  with self.graph.as_default():
    monitor = learn.monitors.CheckpointSaver(
        self.model_dir, save_steps=2, scaffold=self.scaffold)
    monitor.begin()
    self.scaffold.finalize()
    with session_lib.Session() as sess:
      sess.run(self.scaffold.init_op)
      self._run(monitor, 1, self.train_op, sess)
      self._run(monitor, 2, self.train_op, sess)
      # Not saved
      self.assertEqual(1,
                       checkpoint_utils.load_variable(self.model_dir,
                                                      self.global_step.name))
      self._run(monitor, 3, self.train_op, sess)
      # Saved
      self.assertEqual(3,
                       checkpoint_utils.load_variable(self.model_dir,
                                                      self.global_step.name))
      self._run(monitor, 4, self.train_op, sess)
      # Not saved
      self.assertEqual(3,
                       checkpoint_utils.load_variable(self.model_dir,
                                                      self.global_step.name))
      self._run(monitor, 5, self.train_op, sess)
      # Saved
      self.assertEqual(5,
                       checkpoint_utils.load_variable(self.model_dir,
                                                      self.global_step.name))
Example 2: test_save_steps_saves_periodically
def test_save_steps_saves_periodically(self):
  with self.graph.as_default():
    hook = basic_session_run_hooks.CheckpointSaverHook(
        self.model_dir, save_steps=2, scaffold=self.scaffold)
    hook.begin()
    self.scaffold.finalize()
    with session_lib.Session() as sess:
      sess.run(self.scaffold.init_op)
      mon_sess = monitored_session._HookedSession(sess, [hook])
      mon_sess.run(self.train_op)
      mon_sess.run(self.train_op)
      # Not saved
      self.assertEqual(1,
                       checkpoint_utils.load_variable(self.model_dir,
                                                      self.global_step.name))
      mon_sess.run(self.train_op)
      # Saved
      self.assertEqual(3,
                       checkpoint_utils.load_variable(self.model_dir,
                                                      self.global_step.name))
      mon_sess.run(self.train_op)
      # Not saved
      self.assertEqual(3,
                       checkpoint_utils.load_variable(self.model_dir,
                                                      self.global_step.name))
      mon_sess.run(self.train_op)
      # Saved
      self.assertEqual(5,
                       checkpoint_utils.load_variable(self.model_dir,
                                                      self.global_step.name))
Example 3: test_train_max_steps_is_not_incremental
def test_train_max_steps_is_not_incremental(self):
  with ops.Graph().as_default() as g, self.test_session(g):
    with ops.control_dependencies(self._build_inference_graph()):
      train_op = state_ops.assign_add(variables_lib.get_global_step(), 1)
    learn.graph_actions.train(
        g,
        output_dir=self._output_dir,
        train_op=train_op,
        loss_op=constant_op.constant(2.0),
        max_steps=10)
    step = checkpoint_utils.load_variable(
        self._output_dir, variables_lib.get_global_step().name)
    self.assertEqual(10, step)
  with ops.Graph().as_default() as g, self.test_session(g):
    with ops.control_dependencies(self._build_inference_graph()):
      train_op = state_ops.assign_add(variables_lib.get_global_step(), 1)
    learn.graph_actions.train(
        g,
        output_dir=self._output_dir,
        train_op=train_op,
        loss_op=constant_op.constant(2.0),
        max_steps=15)
    step = checkpoint_utils.load_variable(
        self._output_dir, variables_lib.get_global_step().name)
    self.assertEqual(15, step)
Example 4: test_train_skip_train_if_max_step_already_saved
def test_train_skip_train_if_max_step_already_saved(self):
  with ops.Graph().as_default() as g, self.test_session(g):
    with ops.control_dependencies(self._build_inference_graph()):
      train_op = state_ops.assign_add(variables_lib.get_global_step(), 1)
    learn.graph_actions._monitored_train(  # pylint: disable=protected-access
        g,
        output_dir=self._output_dir,
        train_op=train_op,
        loss_op=constant_op.constant(2.0),
        max_steps=10)
    step = checkpoint_utils.load_variable(
        self._output_dir, variables_lib.get_global_step().name)
    self.assertEqual(10, step)
  with ops.Graph().as_default() as g, self.test_session(g):
    with ops.control_dependencies(self._build_inference_graph()):
      train_op = state_ops.assign_add(variables_lib.get_global_step(), 1)
    learn.graph_actions._monitored_train(  # pylint: disable=protected-access
        g,
        output_dir=self._output_dir,
        train_op=train_op,
        loss_op=constant_op.constant(2.0),
        max_steps=10)
    step = checkpoint_utils.load_variable(
        self._output_dir, variables_lib.get_global_step().name)
    self.assertEqual(10, step)
Example 5: testGetTensor
def testGetTensor(self):
  checkpoint_dir = self.get_temp_dir()
  with self.test_session() as session:
    v1, v2, v3, v4 = _create_checkpoints(session, checkpoint_dir)
  self.assertAllEqual(
      checkpoint_utils.load_variable(checkpoint_dir, "var1"), v1)
  self.assertAllEqual(
      checkpoint_utils.load_variable(checkpoint_dir, "var2"), v2)
  self.assertAllEqual(
      checkpoint_utils.load_variable(checkpoint_dir, "var3"), v3)
  self.assertAllEqual(
      checkpoint_utils.load_variable(checkpoint_dir, "useful_scope/var4"), v4)
Example 6: print_tensors_in_checkpoint_file
def print_tensors_in_checkpoint_file(file_name, tensor_name):
  """Prints tensors in a checkpoint file.

  If no `tensor_name` is provided, prints the tensor names and shapes
  in the checkpoint file.
  If `tensor_name` is provided, prints the content of the tensor.

  Args:
    file_name: Name of the checkpoint file.
    tensor_name: Name of the tensor in the checkpoint file to print.
  """
  try:
    if not tensor_name:
      variables = checkpoint_utils.list_variables(file_name)
      for name, shape in variables:
        print("%s\t%s" % (name, str(shape)))
    else:
      print("tensor_name: ", tensor_name)
      print(checkpoint_utils.load_variable(file_name, tensor_name))
  except Exception as e:  # pylint: disable=broad-except
    print(str(e))
    if "corrupted compressed block contents" in str(e):
      print("It's likely that your checkpoint file has been compressed "
            "with SNAPPY.")
Example 7: testNoTensor
def testNoTensor(self):
  checkpoint_dir = self.get_temp_dir()
  with self.test_session() as session:
    _, _, _, _ = _create_checkpoints(session, checkpoint_dir)
  with self.assertRaises(errors_impl.OpError):
    self.assertAllEqual(
        checkpoint_utils.load_variable(checkpoint_dir, "var5"), [])
Example 8: test_save_secs_saves_periodically
def test_save_secs_saves_periodically(self, mock_time):
  # Let's have a realistic start time.
  current_time = 1484695987.209386
  with self.graph.as_default():
    mock_time.return_value = current_time
    hook = basic_session_run_hooks.CheckpointSaverHook(
        self.model_dir, save_secs=2, scaffold=self.scaffold)
    hook.begin()
    self.scaffold.finalize()
    with session_lib.Session() as sess:
      sess.run(self.scaffold.init_op)
      mon_sess = monitored_session._HookedSession(sess, [hook])
      mock_time.return_value = current_time
      mon_sess.run(self.train_op)  # Saved.
      mock_time.return_value = current_time + 0.5
      mon_sess.run(self.train_op)  # Not saved.
      self.assertEqual(1,
                       checkpoint_utils.load_variable(self.model_dir,
                                                      self.global_step.name))
      # Simulate 2.5 seconds of sleep.
      mock_time.return_value = current_time + 2.5
      mon_sess.run(self.train_op)  # Saved.
      mock_time.return_value = current_time + 2.6
      mon_sess.run(self.train_op)  # Not saved.
      mock_time.return_value = current_time + 2.7
      mon_sess.run(self.train_op)  # Not saved.
      self.assertEqual(3,
                       checkpoint_utils.load_variable(self.model_dir,
                                                      self.global_step.name))
      # Simulate 7.5 more seconds of sleep (10 seconds from start).
      mock_time.return_value = current_time + 10
      mon_sess.run(self.train_op)  # Saved.
      self.assertEqual(6,
                       checkpoint_utils.load_variable(self.model_dir,
                                                      self.global_step.name))
Example 9: test_saves_when_saver_and_scaffold_both_missing
def test_saves_when_saver_and_scaffold_both_missing(self):
  with self.graph.as_default():
    hook = basic_session_run_hooks.CheckpointSaverHook(
        self.model_dir, save_steps=1)
    hook.begin()
    self.scaffold.finalize()
    with session_lib.Session() as sess:
      sess.run(self.scaffold.init_op)
      mon_sess = monitored_session._HookedSession(sess, [hook])
      mon_sess.run(self.train_op)
      self.assertEqual(1,
                       checkpoint_utils.load_variable(self.model_dir,
                                                      self.global_step.name))
Example 10: test_save_saves_at_end
def test_save_saves_at_end(self):
  with self.graph.as_default():
    monitor = learn.monitors.CheckpointSaver(
        self.model_dir, save_secs=2, scaffold=self.scaffold)
    monitor.begin()
    self.scaffold.finalize()
    with session_lib.Session() as sess:
      sess.run(self.scaffold.init_op)
      self._run(monitor, 1, self.train_op, sess)
      self._run(monitor, 2, self.train_op, sess)
      monitor.end(sess)
      self.assertEqual(2,
                       checkpoint_utils.load_variable(self.model_dir,
                                                      self.global_step.name))
Example 11: covariances
def covariances(self):
  """Returns the covariances."""
  return checkpoint_utils.load_variable(
      self.model_dir, gmm_ops.GmmAlgorithm.CLUSTERS_COVS_VARIABLE)
Example 12: clusters
def clusters(self):
  """Returns cluster centers."""
  clusters = checkpoint_utils.load_variable(
      self.model_dir, gmm_ops.GmmAlgorithm.CLUSTERS_VARIABLE)
  return np.squeeze(clusters, 1)
Example 13: weights
def weights(self):
  """Returns the cluster weights."""
  return checkpoint_utils.load_variable(
      self.model_dir, gmm_ops.GmmAlgorithm.CLUSTERS_WEIGHT)
Example 14: load_variable
def load_variable(checkpoint_dir, name):
  """See `tf.contrib.framework.load_variable`."""
  return checkpoint_utils.load_variable(checkpoint_dir, name)
Example 15: testNoCheckpoints
def testNoCheckpoints(self):
  checkpoint_dir = self.get_temp_dir() + "/no_checkpoints"
  with self.assertRaises(errors_impl.OpError):
    self.assertAllEqual(
        checkpoint_utils.load_variable(checkpoint_dir, "var1"), [])