This article collects typical usage examples of the Python function tensorflow.python.training.checkpoint_management.latest_checkpoint. If you are wondering what latest_checkpoint does, how to call it, or what real uses of it look like, the curated code examples below may help.
The following presents 15 code examples of the latest_checkpoint function, ordered by popularity.
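Before the individual examples, here is a minimal, self-contained sketch (mine, not from the examples below) of what the function does. checkpoint_management.latest_checkpoint is exposed publicly as tf.train.latest_checkpoint: it reads a directory's checkpoint state file and returns the path prefix of the most recent checkpoint, or None if there is none.

import os
import tempfile
import tensorflow as tf

checkpoint_dir = tempfile.mkdtemp()

# An empty directory has no checkpoint state file, so the result is None.
assert tf.train.latest_checkpoint(checkpoint_dir) is None

# After a save, the same call returns the newest checkpoint prefix.
ckpt = tf.train.Checkpoint(v=tf.Variable(1.0))
ckpt.save(os.path.join(checkpoint_dir, "ckpt"))  # writes ckpt-1.* plus a "checkpoint" state file
print(tf.train.latest_checkpoint(checkpoint_dir))  # e.g. ".../ckpt-1"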
Example 1: testRecoverSession
def testRecoverSession(self):
# Create a checkpoint.
checkpoint_dir = os.path.join(self.get_temp_dir(), "recover_session")
try:
gfile.DeleteRecursively(checkpoint_dir)
except errors.OpError:
pass # Ignore
gfile.MakeDirs(checkpoint_dir)
with ops.Graph().as_default():
v = variables.Variable(1, name="v")
sm = session_manager.SessionManager(
ready_op=variables.report_uninitialized_variables())
saver = saver_lib.Saver({"v": v})
sess, initialized = sm.recover_session(
"", saver=saver, checkpoint_dir=checkpoint_dir)
self.assertFalse(initialized)
sess.run(v.initializer)
    self.assertEqual(1, sess.run(v))
saver.save(sess,
os.path.join(checkpoint_dir, "recover_session_checkpoint"))
self._test_recovered_variable(checkpoint_dir=checkpoint_dir)
self._test_recovered_variable(
checkpoint_filename_with_path=checkpoint_management.latest_checkpoint(
checkpoint_dir))
# Cannot set both checkpoint_dir and checkpoint_filename_with_path.
with self.assertRaises(ValueError):
self._test_recovered_variable(
checkpoint_dir=checkpoint_dir,
checkpoint_filename_with_path=checkpoint_management.latest_checkpoint(
checkpoint_dir))
Example 2: _new_layer_weight_loading_test_template
def _new_layer_weight_loading_test_template(
self, first_model_fn, second_model_fn, restore_init_fn):
with self.cached_session() as session:
model = first_model_fn()
temp_dir = self.get_temp_dir()
prefix = os.path.join(temp_dir, 'ckpt')
x = constant_op.constant(np.random.random((3, 2)), dtype=dtypes.float32)
executing_eagerly = context.executing_eagerly()
ref_y_tensor = model(x)
if not executing_eagerly:
session.run([v.initializer for v in model.variables])
ref_y = self.evaluate(ref_y_tensor)
model.save_weights(prefix)
self.assertEqual(
prefix,
checkpoint_management.latest_checkpoint(temp_dir))
for v in model.variables:
self.evaluate(
v.assign(random_ops.random_normal(shape=array_ops.shape(v))))
self.addCleanup(shutil.rmtree, temp_dir)
second_model = second_model_fn()
second_model.load_weights(prefix)
second_model(x)
self.evaluate(restore_init_fn(second_model))
second_model.save_weights(prefix)
# Check that the second model's checkpoint loads into the original model
model.load_weights(prefix)
y = self.evaluate(model(x))
self.assertAllClose(ref_y, y)
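Example 2 leans on the fact that Model.save_weights(prefix) writes a TF-format checkpoint, so the prefix becomes the directory's latest checkpoint. A condensed sketch of that property (the model is illustrative; this assumes the tf.keras versions contemporary with the test, since Keras 3 later restricted save_weights to .weights.h5 files):

import os
import tempfile
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(2,))])
temp_dir = tempfile.mkdtemp()
prefix = os.path.join(temp_dir, "ckpt")
model.save_weights(prefix)  # TF-format checkpoint, not HDF5
assert tf.train.latest_checkpoint(temp_dir) == prefix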
Example 3: testUsageGraph
def testUsageGraph(self):
"""Expected usage when graph building."""
with context.graph_mode():
num_training_steps = 10
checkpoint_directory = self.get_temp_dir()
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
for training_continuation in range(3):
with ops.Graph().as_default():
model = MyModel()
optimizer = adam.AdamOptimizer(0.001)
root = util.Checkpoint(
optimizer=optimizer, model=model,
global_step=training_util.get_or_create_global_step())
input_value = constant_op.constant([[3.]])
train_op = optimizer.minimize(
model(input_value),
global_step=root.global_step)
checkpoint_path = checkpoint_management.latest_checkpoint(
checkpoint_directory)
with self.session(graph=ops.get_default_graph()) as session:
status = root.restore(save_path=checkpoint_path)
status.initialize_or_restore(session=session)
if checkpoint_path is None:
self.assertEqual(0, training_continuation)
with self.assertRaises(AssertionError):
status.assert_consumed()
else:
status.assert_consumed()
for _ in range(num_training_steps):
session.run(train_op)
root.save(file_prefix=checkpoint_prefix, session=session)
self.assertEqual((training_continuation + 1) * num_training_steps,
session.run(root.global_step))
self.assertEqual(training_continuation + 1,
session.run(root.save_counter))
Example 4: _restore_or_save_initial_ckpt
def _restore_or_save_initial_ckpt(self, session):
# Ideally this should be run in after_create_session but is not for the
# following reason:
# Currently there is no way of enforcing an order of running the
# `SessionRunHooks`. Hence it is possible that the `_DatasetInitializerHook`
# is run *after* this hook. That is troublesome because
# 1. If a checkpoint exists and this hook restores it, the initializer hook
# will override it.
# 2. If no checkpoint exists, this hook will try to save an uninitialized
# iterator which will result in an exception.
#
# As a temporary fix we enter the following implicit contract between this
# hook and the _DatasetInitializerHook.
# 1. The _DatasetInitializerHook initializes the iterator in the call to
# after_create_session.
# 2. This hook saves the iterator on the first call to `before_run()`, which
# is guaranteed to happen after `after_create_session()` of all hooks
# have been run.
# Check if there is an existing checkpoint. If so, restore from it.
# pylint: disable=protected-access
latest_checkpoint_path = checkpoint_management.latest_checkpoint(
self._checkpoint_saver_hook._checkpoint_dir,
latest_filename=self._latest_filename)
if latest_checkpoint_path:
self._checkpoint_saver_hook._get_saver().restore(session,
latest_checkpoint_path)
else:
# The checkpoint saved here is the state at step "global_step".
# Note: We do not save the GraphDef or MetaGraphDef here.
global_step = session.run(self._checkpoint_saver_hook._global_step_tensor)
self._checkpoint_saver_hook._save(session, global_step)
self._checkpoint_saver_hook._timer.update_last_triggered_step(global_step)
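The contract spelled out in the comments above can be distilled into a standalone hook. A hypothetical sketch (class and argument names are mine, TF1-style API): defer the restore-or-save decision to the first before_run() call, which is guaranteed to run after every hook's after_create_session().

import os
import tensorflow.compat.v1 as tf

class RestoreOrSaveOnFirstRunHook(tf.train.SessionRunHook):
  """Restores the latest checkpoint, or saves an initial one, on the first run."""

  def __init__(self, checkpoint_dir, saver):
    self._checkpoint_dir = checkpoint_dir
    self._saver = saver
    self._done = False

  def before_run(self, run_context):
    if self._done:
      return
    path = tf.train.latest_checkpoint(self._checkpoint_dir)
    if path:
      self._saver.restore(run_context.session, path)
    else:
      self._saver.save(run_context.session,
                       os.path.join(self._checkpoint_dir, "ckpt"))
    self._done = True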
Example 5: export_fn
def export_fn(estimator, export_dir_base, checkpoint_path, eval_result=None):
"""Exports the given Estimator as a SavedModel.
Args:
estimator: the Estimator to export.
export_dir_base: A string containing a directory to write the exported
graph and checkpoints.
checkpoint_path: The checkpoint path to export. If None (the default),
the most recent checkpoint found within the model directory is chosen.
    eval_result: placeholder argument matching the call signature of ExportStrategy.
Returns:
The string path to the exported directory.
"""
if not checkpoint_path:
# TODO(b/67425018): switch to
# checkpoint_path = estimator.latest_checkpoint()
# as soon as contrib is cleaned up and we can thus be sure that
# estimator is a tf.estimator.Estimator and not a
# tf.contrib.learn.Estimator
checkpoint_path = checkpoint_management.latest_checkpoint(
estimator.model_dir)
export_checkpoint_path, export_eval_result = best_model_selector.update(
checkpoint_path, eval_result)
if export_checkpoint_path and export_eval_result is not None:
checkpoint_base = os.path.basename(export_checkpoint_path)
export_dir = os.path.join(export_dir_base, checkpoint_base)
return best_model_export_strategy.export(
estimator, export_dir, export_checkpoint_path, export_eval_result)
else:
return ''
Example 6: testAgnosticUsage
def testAgnosticUsage(self):
"""Graph/eager agnostic usage."""
# Does create garbage when executing eagerly due to ops.Graph() creation.
num_training_steps = 10
checkpoint_directory = self.get_temp_dir()
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
for training_continuation in range(3):
with ops.Graph().as_default(), self.test_session(
graph=ops.get_default_graph()), test_util.device(use_gpu=True):
model = MyModel()
optimizer = adam.AdamOptimizer(0.001)
root = util.Checkpoint(
optimizer=optimizer, model=model,
global_step=training_util.get_or_create_global_step())
checkpoint_path = checkpoint_management.latest_checkpoint(
checkpoint_directory)
status = root.restore(save_path=checkpoint_path)
input_value = constant_op.constant([[3.]])
train_fn = functools.partial(
optimizer.minimize,
functools.partial(model, input_value),
global_step=root.global_step)
if not context.executing_eagerly():
train_fn = functools.partial(self.evaluate, train_fn())
status.initialize_or_restore()
for _ in range(num_training_steps):
train_fn()
root.save(file_prefix=checkpoint_prefix)
self.assertEqual((training_continuation + 1) * num_training_steps,
self.evaluate(root.global_step))
self.assertEqual(training_continuation + 1,
self.evaluate(root.save_counter))
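For comparison, the same restore-then-train shape in current TF2 eager code, as a sketch (layer, optimizer, loss, and directory are all illustrative). Checkpoint.restore(None) is a supported no-op, which is why the result of latest_checkpoint can be passed through unconditionally on the first run:

import os
import tensorflow as tf

ckpt_dir = "/tmp/agnostic_demo"  # hypothetical directory
net = tf.keras.layers.Dense(1)
opt = tf.keras.optimizers.Adam(0.001)
ckpt = tf.train.Checkpoint(model=net, optimizer=opt)
# First run: latest_checkpoint returns None, and restore(None) leaves initializers alone.
ckpt.restore(tf.train.latest_checkpoint(ckpt_dir))

for _ in range(10):
  with tf.GradientTape() as tape:
    loss = tf.reduce_sum(net(tf.constant([[3.]])) ** 2)
  grads = tape.gradient(loss, net.trainable_variables)
  opt.apply_gradients(zip(grads, net.trainable_variables))
ckpt.save(os.path.join(ckpt_dir, "ckpt"))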
Example 7: wait_for_new_checkpoint
def wait_for_new_checkpoint(checkpoint_dir,
last_checkpoint=None,
seconds_to_sleep=1,
timeout=None):
"""Waits until a new checkpoint file is found.
Args:
checkpoint_dir: The directory in which checkpoints are saved.
last_checkpoint: The last checkpoint path used or `None` if we're expecting
a checkpoint for the first time.
seconds_to_sleep: The number of seconds to sleep for before looking for a
new checkpoint.
timeout: The maximum number of seconds to wait. If left as `None`, then the
process will wait indefinitely.
Returns:
a new checkpoint path, or None if the timeout was reached.
"""
logging.info('Waiting for new checkpoint at %s', checkpoint_dir)
stop_time = time.time() + timeout if timeout is not None else None
while True:
checkpoint_path = checkpoint_management.latest_checkpoint(checkpoint_dir)
if checkpoint_path is None or checkpoint_path == last_checkpoint:
if stop_time is not None and time.time() + seconds_to_sleep > stop_time:
return None
time.sleep(seconds_to_sleep)
else:
logging.info('Found new checkpoint at %s', checkpoint_path)
return checkpoint_path
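A hypothetical caller of the function above: an evaluation process polling a training directory, where evaluate_checkpoint stands in for the user's own logic.

last = None
while True:
  path = wait_for_new_checkpoint("/tmp/train", last_checkpoint=last,
                                 seconds_to_sleep=2, timeout=600)
  if path is None:
    break  # timed out; the trainer has probably finished or stalled
  evaluate_checkpoint(path)  # placeholder for user code
  last = path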
Example 8: testGraphDistributionStrategy
def testGraphDistributionStrategy(self):
self.skipTest("b/121381184")
num_training_steps = 10
checkpoint_directory = self.get_temp_dir()
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
def _train_fn(optimizer, model):
input_value = constant_op.constant([[3.]])
return optimizer.minimize(
functools.partial(model, input_value),
global_step=root.optimizer_step)
for training_continuation in range(3):
with ops.Graph().as_default():
strategy = mirrored_strategy.MirroredStrategy()
with strategy.scope():
model = MyModel()
optimizer = adam.AdamOptimizer(0.001)
root = checkpointable_utils.Checkpoint(
optimizer=optimizer, model=model,
optimizer_step=training_util.get_or_create_global_step())
status = root.restore(checkpoint_management.latest_checkpoint(
checkpoint_directory))
train_op = strategy.extended.call_for_each_replica(
functools.partial(_train_fn, optimizer, model))
with self.session() as session:
if training_continuation > 0:
status.assert_consumed()
status.initialize_or_restore()
for _ in range(num_training_steps):
session.run(train_op)
root.save(file_prefix=checkpoint_prefix)
self.assertEqual((training_continuation + 1) * num_training_steps,
root.optimizer_step.numpy())
Example 9: testEagerTPUDistributionStrategy
def testEagerTPUDistributionStrategy(self):
self.skipTest("b/121387144")
num_training_steps = 10
checkpoint_directory = self.get_temp_dir()
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
def _train_fn(optimizer, model):
input_value = constant_op.constant([[3.]])
optimizer.minimize(
functools.partial(model, input_value),
global_step=root.optimizer_step)
for training_continuation in range(3):
strategy = tpu_strategy.TPUStrategy()
with strategy.scope():
model = Subclassed()
optimizer = adam_v1.AdamOptimizer(0.001)
root = checkpointable_utils.Checkpoint(
optimizer=optimizer, model=model,
optimizer_step=training_util.get_or_create_global_step())
root.restore(checkpoint_management.latest_checkpoint(
checkpoint_directory))
for _ in range(num_training_steps):
strategy.extended.call_for_each_replica(
functools.partial(_train_fn, optimizer, model))
root.save(file_prefix=checkpoint_prefix)
self.assertEqual((training_continuation + 1) * num_training_steps,
root.optimizer_step.numpy())
Example 10: _read_vars
def _read_vars(self, model_dir):
"""Returns (global_step, latest_feature)."""
with ops.Graph().as_default() as g:
ckpt_path = checkpoint_management.latest_checkpoint(model_dir)
meta_filename = ckpt_path + '.meta'
saver_lib.import_meta_graph(meta_filename)
saver = saver_lib.Saver()
with self.test_session(graph=g) as sess:
saver.restore(sess, ckpt_path)
return sess.run(ops.get_collection('my_vars'))
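The same read-back pattern with public TF1 symbols, as a sketch. Like the test above, it assumes the saved graph put the tensors of interest into a 'my_vars' collection; the model directory is illustrative.

import tensorflow.compat.v1 as tf

ckpt_path = tf.train.latest_checkpoint("/tmp/model_dir")
with tf.Graph().as_default():
  # import_meta_graph rebuilds the graph and returns a matching Saver.
  saver = tf.train.import_meta_graph(ckpt_path + ".meta")
  with tf.Session() as sess:
    saver.restore(sess, ckpt_path)
    values = sess.run(tf.get_collection("my_vars"))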
Example 11: testRestoreInReconstructedIterator
def testRestoreInReconstructedIterator(self):
checkpoint_directory = self.get_temp_dir()
checkpoint_prefix = os.path.join(checkpoint_directory, 'ckpt')
dataset = Dataset.range(10)
for i in range(5):
iterator = datasets.Iterator(dataset)
checkpoint = checkpointable_utils.Checkpoint(iterator=iterator)
checkpoint.restore(checkpoint_management.latest_checkpoint(
checkpoint_directory))
for j in range(2):
self.assertEqual(i * 2 + j, iterator.get_next().numpy())
checkpoint.save(file_prefix=checkpoint_prefix)
Example 12: testRestoreInReconstructedIteratorInitializable
def testRestoreInReconstructedIteratorInitializable(self):
checkpoint_directory = self.get_temp_dir()
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
dataset = dataset_ops.Dataset.range(10)
iterator = dataset.make_initializable_iterator()
get_next = iterator.get_next()
checkpoint = checkpointable_utils.Checkpoint(iterator=iterator)
for i in range(5):
with self.cached_session() as sess:
checkpoint.restore(checkpoint_management.latest_checkpoint(
checkpoint_directory)).initialize_or_restore(sess)
for j in range(2):
self.assertEqual(i * 2 + j, self.evaluate(get_next))
checkpoint.save(file_prefix=checkpoint_prefix)
Example 13: testRestoreInReconstructedIteratorInitializable
def testRestoreInReconstructedIteratorInitializable(self):
checkpoint_directory = self.get_temp_dir()
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
dataset = dataset_ops.Dataset.range(10)
iterator = iter(dataset) if context.executing_eagerly(
) else dataset_ops.make_initializable_iterator(dataset)
get_next = iterator.get_next
checkpoint = trackable_utils.Checkpoint(iterator=iterator)
for i in range(5):
checkpoint.restore(
checkpoint_management.latest_checkpoint(
checkpoint_directory)).initialize_or_restore()
for j in range(2):
self.assertEqual(i * 2 + j, self.evaluate(get_next()))
checkpoint.save(file_prefix=checkpoint_prefix)
Example 14: testNameCollision
def testNameCollision(self):
# Make sure we have a clean directory to work in.
with self.tempDir() as tempdir:
# Jump to that directory until this test is done.
with self.tempWorkingDir(tempdir):
# Save training snapshots to a relative path.
traindir = "train/"
os.mkdir(traindir)
# Collides with the default name of the checkpoint state file.
filepath = os.path.join(traindir, "checkpoint")
with self.test_session() as sess:
unused_a = variables.Variable(0.0) # So that Saver saves something.
variables.global_variables_initializer().run()
# Should fail.
saver = saver_module.Saver(sharded=False)
        with self.assertRaisesRegex(ValueError, "collides with"):
saver.save(sess, filepath)
# Succeeds: the file will be named "checkpoint-<step>".
saver.save(sess, filepath, global_step=1)
self.assertIsNotNone(
checkpoint_management.latest_checkpoint(traindir))
# Succeeds: the file will be named "checkpoint-<i>-of-<n>".
saver = saver_module.Saver(sharded=True)
saver.save(sess, filepath)
self.assertIsNotNone(
checkpoint_management.latest_checkpoint(traindir))
# Succeeds: the file will be named "checkpoint-<step>-<i>-of-<n>".
saver = saver_module.Saver(sharded=True)
saver.save(sess, filepath, global_step=1)
self.assertIsNotNone(
checkpoint_management.latest_checkpoint(traindir))
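The collision is not arbitrary: the state file that latest_checkpoint consults is itself named "checkpoint" by default, the same name the latest_filename argument in example 4 overrides. Spelling the default out (directory is illustrative):

path = checkpoint_management.latest_checkpoint("train/", latest_filename="checkpoint")  # same as the default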
Example 15: __init__
def __init__(self,
estimator,
prediction_input_fn,
input_alternative_key=None,
output_alternative_key=None,
graph=None,
config=None):
"""Initialize a `ContribEstimatorPredictor`.
Args:
estimator: an instance of `tf.contrib.learn.Estimator`.
prediction_input_fn: a function that takes no arguments and returns an
instance of `InputFnOps`.
input_alternative_key: Optional. Specify the input alternative used for
prediction.
output_alternative_key: Specify the output alternative used for
prediction. Not needed for single-headed models but required for
multi-headed models.
    graph: Optional. The TensorFlow `graph` in which prediction should be
done.
config: `ConfigProto` proto used to configure the session.
"""
self._graph = graph or ops.Graph()
with self._graph.as_default():
input_fn_ops = prediction_input_fn()
# pylint: disable=protected-access
model_fn_ops = estimator._get_predict_ops(input_fn_ops.features)
# pylint: enable=protected-access
checkpoint_path = checkpoint_management.latest_checkpoint(
estimator.model_dir)
self._session = monitored_session.MonitoredSession(
session_creator=monitored_session.ChiefSessionCreator(
config=config,
checkpoint_filename_with_path=checkpoint_path))
input_alternative_key = (
input_alternative_key or
saved_model_export_utils.DEFAULT_INPUT_ALTERNATIVE_KEY)
input_alternatives, _ = saved_model_export_utils.get_input_alternatives(
input_fn_ops)
self._feed_tensors = input_alternatives[input_alternative_key]
(output_alternatives,
output_alternative_key) = saved_model_export_utils.get_output_alternatives(
model_fn_ops, output_alternative_key)
_, fetch_tensors = output_alternatives[output_alternative_key]
self._fetch_tensors = fetch_tensors
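For orientation, a rough sketch of how such a predictor is driven. The real predict() comes from the Predictor base class in tf.contrib.predictor; it amounts approximately to the following (an approximation, not the original source):

def predict(self, input_dict):
  # Map user-supplied arrays onto the feed tensors, then fetch the outputs.
  feed_dict = {self._feed_tensors[key]: value
               for key, value in input_dict.items()}
  return self._session.run(self._fetch_tensors, feed_dict=feed_dict)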