本文整理汇总了Python中tensorflow.python.training.saver.latest_checkpoint函数的典型用法代码示例。如果您正苦于以下问题：Python latest_checkpoint函数的具体用法？Python latest_checkpoint怎么用？Python latest_checkpoint使用的例子？那么恭喜您，这里精选的函数代码示例或许可以为您提供帮助。
在下文中一共展示了latest_checkpoint函数的15个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: testRecoverSession
def testRecoverSession(self):
  """Recovers a session from a checkpoint directory or an explicit file.

  Saves a checkpoint of a single variable `v`, then calls
  `_test_recovered_variable` with `checkpoint_dir` and with
  `checkpoint_filename_with_path`. It also checks that passing both
  raises ValueError.
  """
  # Create a checkpoint.
  checkpoint_dir = os.path.join(self.get_temp_dir(), "recover_session")
  try:
    gfile.DeleteRecursively(checkpoint_dir)
  except errors.OpError:
    pass  # Best-effort cleanup; the directory may not exist yet.
  gfile.MakeDirs(checkpoint_dir)

  with ops.Graph().as_default():
    v = variables.Variable(1, name="v")
    sm = session_manager.SessionManager(
        ready_op=variables.report_uninitialized_variables())
    saver = saver_lib.Saver({"v": v})
    # The directory was just (re)created, so nothing can be restored yet.
    sess, initialized = sm.recover_session(
        "", saver=saver, checkpoint_dir=checkpoint_dir)
    self.assertFalse(initialized)
    sess.run(v.initializer)
    # assertEquals is a deprecated alias that Python 3.12 removed.
    self.assertEqual(1, sess.run(v))
    saver.save(sess,
               os.path.join(checkpoint_dir, "recover_session_checkpoint"))

  self._test_recovered_variable(checkpoint_dir=checkpoint_dir)
  self._test_recovered_variable(
      checkpoint_filename_with_path=saver_lib.latest_checkpoint(
          checkpoint_dir))
  # Cannot set both checkpoint_dir and checkpoint_filename_with_path.
  with self.assertRaises(ValueError):
    self._test_recovered_variable(
        checkpoint_dir=checkpoint_dir,
        checkpoint_filename_with_path=saver_lib.latest_checkpoint(
            checkpoint_dir))
示例2: _read_config_files
def _read_config_files(self, run_paths):
  """Reads the projector config file of every run that has one.

  Args:
    run_paths: Mapping from run name to that run's log directory.

  Returns:
    A `(configs, config_fpaths)` pair of dicts keyed by run name. `configs`
    holds the parsed `ProjectorConfig`; `config_fpaths` holds the path of
    the config file it was read from. Runs are omitted when they have no
    config file. Runs whose model checkpoint cannot be located are also
    omitted, with a warning.
  """
  configs = {}
  config_fpaths = {}
  for run_name, logdir in run_paths.items():
    config_fpath = os.path.join(logdir, PROJECTOR_FILENAME)
    if not file_io.file_exists(config_fpath):
      # Skip runs that have no config file.
      continue
    # Read the config file. Depending on the TensorFlow version this may
    # return bytes or str, so decode only when it is bytes.
    file_content = file_io.read_file_to_string(config_fpath)
    if isinstance(file_content, bytes):
      file_content = file_content.decode('utf-8')
    config = ProjectorConfig()
    text_format.Merge(file_content, config)

    if not config.model_checkpoint_path:
      # See if you can find a checkpoint file in the logdir.
      ckpt_path = latest_checkpoint(logdir)
      if not ckpt_path:
        # Or in the parent of logdir. Join os.pardir *after* logdir so the
        # parent is taken relative to logdir itself. A leading '../' would
        # resolve against the CWD, and is discarded for absolute logdirs.
        ckpt_path = latest_checkpoint(os.path.join(logdir, os.pardir))
        if not ckpt_path:
          logging.warning('Cannot find model checkpoint in %s', logdir)
          continue
      config.model_checkpoint_path = ckpt_path

    # Sanity check for the checkpoint. A V2 checkpoint path is only a
    # prefix; on disk it exists as '<prefix>.index' (plus data shards).
    ckpt_prefix = config.model_checkpoint_path
    if not (file_io.file_exists(ckpt_prefix) or
            file_io.file_exists(ckpt_prefix + '.index')):
      logging.warning('Checkpoint file %s not found',
                      config.model_checkpoint_path)
      continue
    configs[run_name] = config
    config_fpaths[run_name] = config_fpath
  return configs, config_fpaths
示例3: _find_latest_checkpoint
def _find_latest_checkpoint(dir_path):
  """Looks for the newest checkpoint in `dir_path`, then in its parent.

  Args:
    dir_path: Directory to search first.

  Returns:
    The `latest_checkpoint` result for `dir_path` if it is truthy, otherwise
    the result for the parent directory. Returns None if either lookup
    raises `errors.NotFoundError`.
  """
  try:
    found = latest_checkpoint(dir_path)
    if found:
      return found
    # Nothing in dir_path itself; fall back to its parent directory.
    return latest_checkpoint(os.path.join(dir_path, os.pardir))
  except errors.NotFoundError:
    return None
示例4: testMultiEvalStepIncrements
def testMultiEvalStepIncrements(self):
  """Eval ops that also bump the eval step make the stop hook fire sooner.

  The eval ops include an extra `assign_add` on the eval step, on top of the
  step increment the original comment attributes to `_evaluate_once`. The
  test therefore expects `MyVar` to be updated only `num_evals // 2` times
  before `_StopAfterNEvalsHook(num_evals)` stops evaluation.
  """
  checkpoint_dir = os.path.join(self.get_temp_dir(),
                                'multi_eval_step_increments')

  # Train a model for a single step to get a checkpoint.
  self._train_model(checkpoint_dir, num_steps=1)
  checkpoint_path = saver.latest_checkpoint(checkpoint_dir)

  # Create the model so we have something to restore.
  inputs = constant_op.constant(self._inputs, dtype=dtypes.float32)
  logistic_classifier(inputs)

  num_evals = 6

  my_var = local_variable(0.0, name='MyVar')
  # In eval ops, we also increase the eval step one more time.
  eval_ops = [state_ops.assign_add(my_var, 1.0),
              state_ops.assign_add(
                  evaluation._get_or_create_eval_step(), 1, use_locking=True)]
  expect_eval_update_counts = num_evals // 2

  final_ops = array_ops.identity(my_var)

  final_ops_values = evaluation._evaluate_once(
      checkpoint_path=checkpoint_path,
      eval_ops=eval_ops,
      final_ops={'value': final_ops},
      hooks=[evaluation._StopAfterNEvalsHook(num_evals)])
  self.assertEqual(final_ops_values['value'], expect_eval_update_counts)
示例5: testEvalOpAndFinalOp
def testEvalOpAndFinalOp(self):
  """Checks eval/final op values and that the hooks list is left untouched."""
  temp_root = self.get_temp_dir()
  checkpoint_dir = os.path.join(temp_root, 'eval_ops_and_final_ops')

  # A single training step is enough to produce a checkpoint to restore.
  self._train_model(checkpoint_dir, num_steps=1)
  checkpoint_path = saver.latest_checkpoint(checkpoint_dir)

  # Rebuild the model graph so the checkpoint has variables to restore into.
  logistic_classifier(
      constant_op.constant(self._inputs, dtype=dtypes.float32))

  num_evals, final_increment = 5, 9.0
  counter = local_variable(0.0, name='MyVar')
  increment_op = state_ops.assign_add(counter, 1.0)
  final_value = array_ops.identity(counter) + final_increment

  hooks = [evaluation._StopAfterNEvalsHook(num_evals)]
  # Snapshot the list so a mutation by _evaluate_once would be detected.
  hooks_before = list(hooks)
  results = evaluation._evaluate_once(
      checkpoint_path=checkpoint_path,
      eval_ops=increment_op,
      final_ops={'value': final_value},
      hooks=hooks)
  self.assertEqual(results['value'], num_evals + final_increment)
  self.assertEqual(hooks_before, hooks)
示例6: testEvaluateWithFiniteInputs
def testEvaluateWithFiniteInputs(self):
  """Evaluates one epoch of finite input; checks accuracy and step count.

  The model is trained for 300 steps. It is then evaluated with
  `_StopAfterNEvalsHook(None)` over a single-epoch input producer, so
  evaluation runs until the input is exhausted.
  """
  checkpoint_dir = os.path.join(self.get_temp_dir(),
                                'evaluate_with_finite_inputs')

  # Train a Model to completion:
  self._train_model(checkpoint_dir, num_steps=300)

  # Run evaluation. Inputs are fed through input producer for one epoch.
  all_inputs = constant_op.constant(self._inputs, dtype=dtypes.float32)
  all_labels = constant_op.constant(self._labels, dtype=dtypes.float32)
  single_input, single_label = training.slice_input_producer(
      [all_inputs, all_labels], num_epochs=1)
  inputs, labels = training.batch([single_input, single_label], batch_size=6,
                                  allow_smaller_final_batch=True)

  logits = logistic_classifier(inputs)
  predictions = math_ops.round(logits)

  accuracy, update_op = metrics.accuracy(
      predictions=predictions, labels=labels)

  checkpoint_path = saver.latest_checkpoint(checkpoint_dir)

  final_ops_values = evaluation._evaluate_once(
      checkpoint_path=checkpoint_path,
      eval_ops=update_op,
      final_ops={'accuracy': accuracy,
                 'eval_steps': evaluation._get_or_create_eval_step()},
      hooks=[evaluation._StopAfterNEvalsHook(None)])
  # assertGreater reports the actual accuracy on failure; assertTrue(x > y)
  # would only report "False is not true".
  self.assertGreater(final_ops_values['accuracy'], .99)
  # Runs evaluation for 4 iterations. First 2 evaluate full batch of 6 inputs
  # each; the 3rd iter evaluates the remaining 4 inputs, and the last one
  # triggers an error which stops evaluation.
  self.assertEqual(final_ops_values['eval_steps'], 4)
示例7: test_recovery
def test_recovery(self):
  """Restarted MonitoredSessions pick up the saved global step.

  The step is advanced to 2 and checkpointed. The test then checks that a
  restart via `checkpoint_dir` restores it, and that a restart via an
  explicit `checkpoint_filename_with_path` restores it too.
  """
  logdir = _test_dir(self.get_temp_dir(), 'test_recovery')
  with ops.Graph().as_default():
    gstep = variables_lib.get_or_create_global_step()
    do_step = state_ops.assign_add(gstep, 1)
    scaffold = monitored_session.Scaffold()
    # Use a hook to save the model after every step (save_steps=1). It also
    # saves it at the end.
    hooks = [
        basic_session_run_hooks.CheckpointSaverHook(
            logdir, save_steps=1, scaffold=scaffold)
    ]
    with monitored_session.MonitoredSession(
        session_creator=monitored_session.ChiefSessionCreator(
            scaffold, checkpoint_dir=logdir),
        hooks=hooks) as session:
      self.assertEqual(0, session.run(gstep))
      self.assertEqual(1, session.run(do_step))
      self.assertEqual(2, session.run(do_step))
    # A restart will find the checkpoint in logdir and recover automatically.
    with monitored_session.MonitoredSession(
        session_creator=monitored_session.ChiefSessionCreator(
            scaffold, checkpoint_dir=logdir)) as session:
      self.assertEqual(2, session.run(gstep))
    # A restart pointed at an explicit checkpoint file (rather than a
    # directory) must recover the same state.
    with monitored_session.MonitoredSession(
        session_creator=monitored_session.ChiefSessionCreator(
            scaffold,
            checkpoint_filename_with_path=saver_lib.latest_checkpoint(
                logdir))) as session:
      self.assertEqual(2, session.run(gstep))
示例8: export_estimator
def export_estimator(estimator, export_dir, input_fn=_default_input_fn,
                     signature_fn=_generic_signature_fn, default_batch_size=1,
                     exports_to_keep=None):
  """Exports inference graph into given dir.

  Args:
    estimator: Estimator to export. The latest checkpoint found in
      `estimator._model_dir` is the one exported.
    export_dir: A string containing a directory to write the exported graph
      and checkpoints.
    input_fn: Function called as `input_fn(estimator, examples)`, where
      `examples` is a `Tensor` of `Example` strings. It parses them into
      features that are then passed to the model.
    signature_fn: Function called as
      `signature_fn(examples, features, predictions)`: the `Tensor` of
      `Example` strings, a `dict` of `Tensor`s for features and a `dict` of
      `Tensor`s for predictions. Returns the default and named exporting
      signatures.
    default_batch_size: Default batch size of the `Example` placeholder.
    exports_to_keep: Number of exports to keep. If not None, it is converted
      with `gc.largest_export_versions` before being passed on.
  """
  # NOTE(review): latest_checkpoint returns None when the model dir has no
  # checkpoint; presumably _export_graph then fails -- confirm.
  checkpoint_path = tf_saver.latest_checkpoint(estimator._model_dir)
  with ops.Graph().as_default() as g:
    contrib_variables.create_global_step(g)
    examples = array_ops.placeholder(dtype=dtypes.string,
                                     shape=[default_batch_size],
                                     name='input_example_tensor')
    features = input_fn(estimator, examples)
    predictions = estimator._get_predict_ops(features)
    default_signature, named_graph_signatures = signature_fn(
        examples, features, predictions)
    if exports_to_keep is not None:
      exports_to_keep = gc.largest_export_versions(exports_to_keep)
    # Export while `g` is still the default graph so _get_saver() sees it.
    _export_graph(g, _get_saver(), checkpoint_path, export_dir,
                  default_graph_signature=default_signature,
                  named_graph_signatures=named_graph_signatures,
                  exports_to_keep=exports_to_keep)
示例9: _restore_or_save_initial_ckpt
def _restore_or_save_initial_ckpt(self, session):
  """Restores the newest checkpoint, or saves an initial one if none exists.

  Ideally this would run in `after_create_session`. It cannot, because there
  is currently no way to order `SessionRunHooks`, so the
  `_DatasetInitializerHook` might run *after* this hook:
    1. If a checkpoint exists and this hook restores it, the initializer hook
       would then override it.
    2. If no checkpoint exists, this hook would try to save an uninitialized
       iterator, which raises an exception.

  As a temporary fix there is an implicit contract between this hook and the
  `_DatasetInitializerHook`:
    1. The `_DatasetInitializerHook` initializes the iterator in
       `after_create_session`.
    2. This hook saves the iterator on the first call to `before_run()`. That
       call is guaranteed to happen after `after_create_session()` has run
       for all hooks.

  Args:
    session: The session to restore into or to save from.
  """
  # pylint: disable=protected-access
  saver_hook = self._checkpoint_saver_hook
  restore_path = saver_lib.latest_checkpoint(
      saver_hook._checkpoint_dir, latest_filename=self._latest_filename)
  if restore_path:
    saver_hook._get_saver().restore(session, restore_path)
    return

  # No checkpoint yet: save the state as of the current global step. The
  # GraphDef / MetaGraphDef are deliberately not saved here.
  step = session.run(saver_hook._global_step_tensor)
  saver_hook._save(session, step)
  saver_hook._timer.update_last_triggered_step(step)
  # pylint: enable=protected-access
示例10: _evaluate_model
def _evaluate_model(self,
                    input_fn,
                    steps,
                    feed_fn=None,
                    metrics=None,
                    name=''):
  """Evaluates the latest checkpoint in `self._model_dir`.

  Does nothing (returns None) unless `self._config.execution_mode` is one of
  'all', 'evaluate' or 'eval_evalset'.

  Args:
    input_fn: Callable returning a `(features, targets)` pair.
    steps: Forwarded to `evaluate` as `max_steps`.
    feed_fn: Optional feed function forwarded to `evaluate`.
    metrics: Metric functions to evaluate. None selects
      `self._get_default_metric_functions()`.
    name: Optional suffix. When non-empty, results are written under
      `<model_dir>/eval_<name>` instead of `<model_dir>/eval`.

  Returns:
    The first element of the pair returned by `evaluate`, or None when the
    execution mode disables evaluation.
  """
  allowed_modes = ('all', 'evaluate', 'eval_evalset')
  if self._config.execution_mode not in allowed_modes:
    return

  checkpoint_path = saver.latest_checkpoint(self._model_dir)
  eval_subdir = 'eval_' + name if name else 'eval'
  eval_dir = os.path.join(self._model_dir, eval_subdir)

  with ops.Graph().as_default() as g:
    random_seed.set_random_seed(self._config.tf_random_seed)
    global_step = contrib_framework.create_global_step(g)
    features, targets = input_fn()
    self._check_inputs(features, targets)
    if metrics is None:
      metric_fns = self._get_default_metric_functions()
    else:
      metric_fns = metrics
    eval_dict = self._get_eval_ops(features, targets, metric_fns)
    update_op, eval_dict = self._extract_metric_update_ops(eval_dict)
    eval_results, _ = evaluate(
        graph=g,
        output_dir=eval_dir,
        checkpoint_path=checkpoint_path,
        eval_dict=eval_dict,
        update_op=update_op,
        global_step_tensor=global_step,
        supervisor_master=self._config.master,
        feed_fn=feed_fn,
        max_steps=steps)
  return eval_results
示例11: _infer_model
def _infer_model(
    self, input_fn, feed_fn=None, outputs=None, as_iterable=False):
  """Runs inference using the latest checkpoint in `self._model_dir`.

  Args:
    input_fn: Passed to `self._get_features_from_input_fn` to build features.
    feed_fn: Optional feed function forwarded to the inference helper.
    outputs: Optional collection of prediction keys to keep. Falsy means keep
      all of them.
    as_iterable: If True, delegate to `_infer_model_as_iterable`; otherwise
      delegate to `_infer_model_single`.

  Returns:
    Whatever the selected helper returns. `return_dict` tells the helper
    whether the model produced a dict or a single (wrapped) output.

  Raises:
    NotFittedError: If no checkpoint is found in `self._model_dir`.
    ValueError: If `outputs` matches none of the prediction keys.
  """
  # Check that model has been trained.
  checkpoint_path = saver.latest_checkpoint(self._model_dir)
  if not checkpoint_path:
    raise NotFittedError("Couldn't find trained model at %s."
                         % self._model_dir)

  with ops.Graph().as_default() as g:
    random_seed.set_random_seed(self._config.tf_random_seed)
    contrib_framework.create_global_step(g)
    features = self._get_features_from_input_fn(input_fn)
    predictions = self._get_predict_ops(features)
    # If predictions is single output - wrap it into dict, and remember to
    # return not a dict.
    return_dict = isinstance(predictions, dict)
    if not return_dict:
      predictions = {'predictions': predictions}

    # Filter what to run predictions on, if outputs provided.
    if outputs:
      # Materialize the keys so the error message lists them plainly instead
      # of rendering a Python 3 `dict_keys(...)` view.
      existing_keys = list(predictions)
      predictions = {
          key: value for key, value in predictions.items() if key in outputs
      }
      if not predictions:
        raise ValueError('Expected to run at least one output from %s, '
                         'provided %s.' % (existing_keys, outputs))

    if as_iterable:
      return self._infer_model_as_iterable(
          checkpoint_path, predictions, feed_fn, return_dict)
    return self._infer_model_single(
        checkpoint_path, predictions, feed_fn, return_dict)
示例12: _infer_model
def _infer_model(self,
                 x=None, input_fn=None, feed_fn=None,
                 batch_size=None, axis=None, proba=False):
  """Runs `infer` on the latest checkpoint in `self._model_dir`.

  When `x` is given, it is converted into `input_fn`/`feed_fn` with
  `_get_predict_input_fn` (a `batch_size` of None becomes -1). Without a
  `feed_fn`, `infer` runs once. Otherwise `feed_fn` is called repeatedly
  until it raises StopIteration or returns None. Each batch's outputs are
  collected per key and concatenated along axis 0. `axis` and `proba` are
  accepted but not used here.

  Returns:
    The `infer` result (no `feed_fn`), or a dict mapping each output key to
    the concatenation of its per-batch values.
  """
  if batch_size is None:
    batch_size = -1
  if x is not None:
    input_fn, feed_fn = _get_predict_input_fn(x, batch_size)

  checkpoint_path = saver.latest_checkpoint(self._model_dir)
  with ops.Graph().as_default() as g:
    random_seed.set_random_seed(self._config.tf_random_seed)
    contrib_framework.create_global_step(g)
    features, _ = input_fn()
    predictions = self._get_predict_ops(features)
    if not isinstance(predictions, dict):
      predictions = {'predictions': predictions}
    # TODO(ipolosukhin): Support batching
    if feed_fn is None:
      return infer(checkpoint_path, predictions)

    collected = {}
    while True:
      try:
        feed_dict = feed_fn()
      except StopIteration:
        break
      if feed_dict is None:
        break
      batch_outputs = infer(checkpoint_path, predictions, feed_dict=feed_dict)
      for key in batch_outputs:
        collected.setdefault(key, []).append(batch_outputs[key])
    return {key: np.concatenate(chunks, axis=0)
            for key, chunks in collected.items()}
示例13: testUsageGraph
def testUsageGraph(self):
  """Save/restore round-trips across three fresh graphs in graph mode."""
  with context.graph_mode():
    steps_per_round = 10
    ckpt_dir = self.get_temp_dir()
    ckpt_prefix = os.path.join(ckpt_dir, "ckpt")
    for round_index in range(3):
      with ops.Graph().as_default():
        network = MyNetwork()
        optimizer = adam.AdamOptimizer(0.001)
        root = checkpointable_utils.Checkpoint(
            optimizer=optimizer, network=network,
            global_step=training_util.get_or_create_global_step())
        loss = network(constant_op.constant([[3.]]))
        train_op = optimizer.minimize(loss, global_step=root.global_step)
        restore_path = core_saver.latest_checkpoint(ckpt_dir)
        graph = ops.get_default_graph()
        with self.test_session(graph=graph) as session:
          status = root.restore(save_path=restore_path)
          status.initialize_or_restore(session=session)
          if restore_path is not None:
            status.assert_consumed()
          else:
            # Only the first round may start without a checkpoint; nothing
            # was restored, so assert_consumed must fail.
            self.assertEqual(0, round_index)
            with self.assertRaises(AssertionError):
              status.assert_consumed()
          for _ in range(steps_per_round):
            session.run(train_op)
          root.save(file_prefix=ckpt_prefix, session=session)
          completed_rounds = round_index + 1
          self.assertEqual(completed_rounds * steps_per_round,
                           session.run(root.global_step))
          self.assertEqual(completed_rounds,
                           session.run(root.save_counter))
示例14: create_session
def create_session(self, checkpoint_dir):
  """Builds a MonitoredSession for this predictor.

  The session is created by a ChiefSessionCreator pointed at the newest
  checkpoint in `checkpoint_dir` and configured with `self._session_config()`.
  """
  creator = training.ChiefSessionCreator(
      checkpoint_filename_with_path=saver.latest_checkpoint(checkpoint_dir),
      config=self._session_config())
  return training.MonitoredSession(session_creator=creator)
示例15: testDeferredRestorationUsageEager
def testDeferredRestorationUsageEager(self):
  """Idiomatic eager-style training that restores state on every round.

  Each of three rounds builds a fresh network, optimizer and root. It then
  restores from the newest checkpoint in the temp dir (if any), runs
  `minimize` 10 times, saves, and checks that the global step kept
  accumulating across rounds.
  """
  steps_per_round = 10
  ckpt_dir = self.get_temp_dir()
  ckpt_prefix = os.path.join(ckpt_dir, "ckpt")
  # Object graph returned by the previous save; None until the first save.
  saved_object_graph = None
  for round_index in range(3):
    with ops.Graph().as_default():
      network = MyNetwork()
      optimizer = CheckpointableAdam(0.001)
      root = Root(optimizer=optimizer, network=network)
      restore_path = core_saver.latest_checkpoint(ckpt_dir)
      checkpointable.restore(
          save_path=restore_path,
          root_checkpointable=root,
          object_graph_proto=saved_object_graph)
      for _ in range(steps_per_round):
        # TODO(allenl): Use a Dataset and serialize/checkpoint it.
        input_value = constant_op.constant([[3.]])
        optimizer.minimize(
            lambda: network(input_value),  # pylint: disable=cell-var-from-loop
            global_step=root.global_step)
      saved_object_graph, _ = checkpointable.save(
          file_prefix=ckpt_prefix, root_checkpointable=root)
      expected_step = (round_index + 1) * steps_per_round
      self.assertEqual(expected_step, root.global_step.numpy())