本文整理匯總了Python中tensorflow.python.training.saver.latest_checkpoint方法的典型用法代碼示例。如果您正苦於以下問題:Python saver.latest_checkpoint方法的具體用法?Python saver.latest_checkpoint怎麼用?Python saver.latest_checkpoint使用的例子?那麼,這裡精選的方法代碼示例或許可以為您提供幫助。您也可以進一步了解該方法所在模塊tensorflow.python.training.saver的用法示例。
在下文中一共展示了saver.latest_checkpoint方法的15個代碼示例,這些例子默認根據受歡迎程度排序。您可以為喜歡或者感覺有用的代碼點讚,您的評價將有助於系統推薦出更棒的Python代碼示例。
示例1: _latest_checkpoints_changed
# 需要導入模塊: from tensorflow.python.training import saver [as 別名]
# 或者: from tensorflow.python.training.saver import latest_checkpoint [as 別名]
def _latest_checkpoints_changed(configs, run_path_pairs):
  """Returns true if the latest checkpoint has changed in any of the runs.

  Args:
    configs: Mapping from run name to an already-loaded config object; each
      config is read for its `model_checkpoint_path` attribute.
    run_path_pairs: Iterable of `(run_name, logdir)` pairs.

  Returns:
    True as soon as one run's newest checkpoint path differs from the path
    recorded in that run's config, otherwise False. Runs for which no
    checkpoint is found in `logdir` or in its parent are skipped.
  """
  for run_name, logdir in run_path_pairs:
    if run_name not in configs:
      # Run not seen before: start from an empty config and merge in the
      # projector config file from the run directory, if that file exists.
      config = ProjectorConfig()
      config_fpath = os.path.join(logdir, PROJECTOR_FILENAME)
      if file_io.file_exists(config_fpath):
        file_content = file_io.read_file_to_string(config_fpath).decode('utf-8')
        text_format.Merge(file_content, config)
    else:
      config = configs[run_name]

    # See if you can find a checkpoint file in the logdir.
    ckpt_path = latest_checkpoint(logdir)
    if not ckpt_path:
      # See if you can find a checkpoint in the parent of logdir.
      ckpt_path = latest_checkpoint(os.path.join(logdir, os.pardir))
      if not ckpt_path:
        continue
    if config.model_checkpoint_path != ckpt_path:
      return True
  return False
示例2: end
# 需要導入模塊: from tensorflow.python.training import saver [as 別名]
# 或者: from tensorflow.python.training.saver import latest_checkpoint [as 別名]
def end(self, session=None):
  """Runs the parent `end` hook, then exports the model one final time.

  The export is skipped when no checkpoint is found in the estimator's
  `model_dir`. A `RuntimeError` raised by the export call is logged and
  suppressed rather than propagated.

  Args:
    session: Forwarded unchanged to the parent class's `end`.
  """
  super(ExportMonitor, self).end(session=session)
  latest_path = saver_lib.latest_checkpoint(self._estimator.model_dir)
  if latest_path is None:
    logging.info("Skipping export at the end since model has not been saved "
                 "yet.")
    return
  try:
    self._last_export_dir = self._estimator.export(
        self.export_dir,
        exports_to_keep=self.exports_to_keep,
        signature_fn=self.signature_fn,
        input_fn=self._input_fn,
        default_batch_size=self._default_batch_size,
        input_feature_key=self._input_feature_key,
        use_deprecated_input_fn=self._use_deprecated_input_fn)
  except RuntimeError:
    logging.info("Skipping exporting for the same step.")
示例3: _latest_checkpoints_changed
# 需要導入模塊: from tensorflow.python.training import saver [as 別名]
# 或者: from tensorflow.python.training.saver import latest_checkpoint [as 別名]
def _latest_checkpoints_changed(configs, run_path_pairs):
  """Returns true if the latest checkpoint has changed in any of the runs.

  Args:
    configs: Mapping from run name to a config object exposing
      `model_checkpoint_path`.
    run_path_pairs: Iterable of `(run_name, logdir)` pairs.

  Returns:
    True as soon as one known run's newest checkpoint path differs from the
    path recorded in its config, otherwise False. Runs that are absent from
    `configs`, have no recorded checkpoint path, or have no checkpoint in
    `logdir` or its parent are skipped.
  """
  for run_name, logdir in run_path_pairs:
    if run_name not in configs:
      continue
    config = configs[run_name]
    if not config.model_checkpoint_path:
      continue

    # See if you can find a checkpoint file in the logdir.
    ckpt_path = latest_checkpoint(logdir)
    if not ckpt_path:
      # See if you can find a checkpoint in the parent of logdir.
      # `os.path.join('../', logdir)` would prepend '../' (or be discarded
      # entirely for an absolute logdir) instead of pointing at the parent.
      ckpt_path = latest_checkpoint(os.path.join(logdir, os.pardir))
      if not ckpt_path:
        continue
    if config.model_checkpoint_path != ckpt_path:
      return True
  return False
示例4: _load_global_step_from_checkpoint_dir
# 需要導入模塊: from tensorflow.python.training import saver [as 別名]
# 或者: from tensorflow.python.training.saver import latest_checkpoint [as 別名]
def _load_global_step_from_checkpoint_dir(checkpoint_dir):
  """Reads the global step stored in the latest checkpoint of a directory.

  Args:
    checkpoint_dir: Directory passed to `training.latest_checkpoint`.

  Returns:
    The value stored under `ops.GraphKeys.GLOBAL_STEP` in that checkpoint,
    or 0 if the checkpoint cannot be located or read.
  """
  try:
    checkpoint_reader = training.NewCheckpointReader(
        training.latest_checkpoint(checkpoint_dir))
    return checkpoint_reader.get_tensor(ops.GraphKeys.GLOBAL_STEP)
  except Exception:  # pylint: disable=broad-except
    # Deliberately best-effort: any failure to read the step means "start
    # from 0". `Exception` (rather than a bare `except`) still lets
    # KeyboardInterrupt and SystemExit propagate.
    return 0
示例5: _get_checkpoint_filename
# 需要導入模塊: from tensorflow.python.training import saver [as 別名]
# 或者: from tensorflow.python.training.saver import latest_checkpoint [as 別名]
def _get_checkpoint_filename(ckpt_dir_or_file):
  """Returns checkpoint filename given directory or specific checkpoint file.

  Args:
    ckpt_dir_or_file: Either a directory or a checkpoint path.

  Returns:
    The result of `saver.latest_checkpoint` when the argument is a directory
    according to `gfile.IsDirectory`; otherwise the argument unchanged.
  """
  if gfile.IsDirectory(ckpt_dir_or_file):
    return saver.latest_checkpoint(ckpt_dir_or_file)
  return ckpt_dir_or_file
示例6: _find_latest_checkpoint
# 需要導入模塊: from tensorflow.python.training import saver [as 別名]
# 或者: from tensorflow.python.training.saver import latest_checkpoint [as 別名]
def _find_latest_checkpoint(dir_path):
  """Finds the latest checkpoint in a directory or in its parent.

  Args:
    dir_path: Directory to look in first.

  Returns:
    The path reported by `latest_checkpoint` for `dir_path`; if that is
    falsy, whatever `latest_checkpoint` reports for the parent directory.
    None if the lookup raises `errors.NotFoundError`.
  """
  try:
    ckpt_path = latest_checkpoint(dir_path)
    if not ckpt_path:
      # Check the parent directory.
      ckpt_path = latest_checkpoint(os.path.join(dir_path, os.pardir))
    return ckpt_path
  except errors.NotFoundError:
    return None
示例7: _infer_model
# 需要導入模塊: from tensorflow.python.training import saver [as 別名]
# 或者: from tensorflow.python.training.saver import latest_checkpoint [as 別名]
def _infer_model(self,
                 input_fn,
                 feed_fn=None,
                 outputs=None,
                 as_iterable=True,
                 iterate_batches=False):
  """Builds a prediction graph and runs it from the latest checkpoint.

  Args:
    input_fn: Passed to `self._get_features_from_input_fn` to obtain features.
    feed_fn: Optional callable; when given, its result is passed as the feed
      argument of the single-shot `run` call, and the callable itself is
      forwarded to `self._predict_generator` in iterable mode.
    outputs: Passed to `self._filter_predictions` to select predictions.
    as_iterable: If true, return the result of `self._predict_generator`;
      otherwise run the predictions once and return that result.
    iterate_batches: Forwarded to `self._predict_generator`.

  Returns:
    In iterable mode, whatever `self._predict_generator` returns. Otherwise
    the result of a single `run`, or None if the session already reports
    `should_stop()`.

  Raises:
    NotFittedError: If no checkpoint is found in `self._model_dir`.
  """
  # Check that model has been trained.
  checkpoint_path = saver.latest_checkpoint(self._model_dir)
  if not checkpoint_path:
    raise NotFittedError("Couldn't find trained model at %s."
                         % self._model_dir)

  with ops.Graph().as_default() as g:
    random_seed.set_random_seed(self._config.tf_random_seed)
    contrib_framework.create_global_step(g)
    features = self._get_features_from_input_fn(input_fn)
    infer_ops = self._get_predict_ops(features)
    predictions = self._filter_predictions(infer_ops.predictions, outputs)
    mon_sess = monitored_session.MonitoredSession(
        session_creator=monitored_session.ChiefSessionCreator(
            checkpoint_filename_with_path=checkpoint_path,
            scaffold=infer_ops.scaffold,
            config=self._session_config))
    if not as_iterable:
      # Single-shot mode: the `with` block closes the session on exit.
      with mon_sess:
        if not mon_sess.should_stop():
          return mon_sess.run(predictions, feed_fn() if feed_fn else None)
    else:
      # Iterable mode: the session is handed over to the generator helper.
      return self._predict_generator(mon_sess, predictions, feed_fn,
                                     iterate_batches)
示例8: wait_for_new_checkpoint
# 需要導入模塊: from tensorflow.python.training import saver [as 別名]
# 或者: from tensorflow.python.training.saver import latest_checkpoint [as 別名]
def wait_for_new_checkpoint(checkpoint_dir,
                            last_checkpoint=None,
                            seconds_to_sleep=1,
                            timeout=None):
  """Waits until a new checkpoint file is found.

  Args:
    checkpoint_dir: The directory in which checkpoints are saved.
    last_checkpoint: The last checkpoint path used or `None` if we're expecting
      a checkpoint for the first time.
    seconds_to_sleep: The number of seconds to sleep for before looking for a
      new checkpoint.
    timeout: The maximum amount of time to wait. If left as `None`, then the
      process will wait indefinitely.

  Returns:
    a new checkpoint path, or None if the timeout was reached.
  """
  logging.info('Waiting for new checkpoint at %s', checkpoint_dir)
  # Absolute deadline, or None when waiting without a limit.
  stop_time = (time.time() + timeout) if timeout is not None else None
  while True:
    checkpoint_path = tf_saver.latest_checkpoint(checkpoint_dir)
    if checkpoint_path is None or checkpoint_path == last_checkpoint:
      # Give up early if the next sleep would end past the deadline.
      if stop_time is not None and time.time() + seconds_to_sleep > stop_time:
        return None
      time.sleep(seconds_to_sleep)
    else:
      logging.info('Found new checkpoint at %s', checkpoint_path)
      return checkpoint_path
示例9: _get_checkpoint_filename
# 需要導入模塊: from tensorflow.python.training import saver [as 別名]
# 或者: from tensorflow.python.training.saver import latest_checkpoint [as 別名]
def _get_checkpoint_filename(filepattern):
  """Returns checkpoint filename given directory or specific filepattern.

  Args:
    filepattern: Either a directory or a checkpoint file pattern.

  Returns:
    The result of `saver.latest_checkpoint` when the argument is a directory
    according to `gfile.IsDirectory`; otherwise the argument unchanged.
  """
  if gfile.IsDirectory(filepattern):
    return saver.latest_checkpoint(filepattern)
  return filepattern
示例10: _read_latest_config_files
# 需要導入模塊: from tensorflow.python.training import saver [as 別名]
# 或者: from tensorflow.python.training.saver import latest_checkpoint [as 別名]
def _read_latest_config_files(self, run_path_pairs):
  """Reads and returns the projector config files in every run directory.

  Args:
    run_path_pairs: Iterable of `(run_name, logdir)` pairs.

  Returns:
    A `(configs, config_fpaths)` tuple of dicts keyed by run name: the config
    built for each run and the config file path computed for it. A run is
    omitted when it has neither a checkpoint nor any embedding with a
    `tensor_path`, or when its checkpoint path fails `checkpoint_exists`.
  """
  configs = {}
  config_fpaths = {}
  for run_name, logdir in run_path_pairs:
    config = ProjectorConfig()
    config_fpath = os.path.join(logdir, PROJECTOR_FILENAME)
    if file_io.file_exists(config_fpath):
      file_content = file_io.read_file_to_string(config_fpath).decode('utf-8')
      text_format.Merge(file_content, config)

    has_tensor_files = any(
        embedding.tensor_path for embedding in config.embeddings)

    if not config.model_checkpoint_path:
      # See if you can find a checkpoint file in the logdir.
      ckpt_path = latest_checkpoint(logdir)
      if not ckpt_path:
        # Or in the parent of logdir.
        ckpt_path = latest_checkpoint(os.path.join(logdir, os.pardir))
        if not ckpt_path and not has_tensor_files:
          continue
      if ckpt_path:
        config.model_checkpoint_path = ckpt_path

    # Sanity check for the checkpoint file.
    if (config.model_checkpoint_path and
        not checkpoint_exists(config.model_checkpoint_path)):
      logging.warning('Checkpoint file %s not found',
                      config.model_checkpoint_path)
      continue
    configs[run_name] = config
    config_fpaths[run_name] = config_fpath
  return configs, config_fpaths
示例11: _infer_model
# 需要導入模塊: from tensorflow.python.training import saver [as 別名]
# 或者: from tensorflow.python.training.saver import latest_checkpoint [as 別名]
def _infer_model(self,
                 input_fn,
                 feed_fn=None,
                 outputs=None,
                 as_iterable=True,
                 iterate_batches=False):
  """Builds a prediction graph and runs it from the latest checkpoint.

  Args:
    input_fn: Passed to `self._get_features_from_input_fn` to obtain features.
    feed_fn: Optional callable; when given, its result is passed as the feed
      argument of the single-shot `run` call, and the callable itself is
      forwarded to `self._predict_generator` in iterable mode.
    outputs: Passed to `self._filter_predictions` to select predictions.
    as_iterable: If true, return the result of `self._predict_generator`;
      otherwise run the predictions once and return that result.
    iterate_batches: Forwarded to `self._predict_generator`.

  Returns:
    In iterable mode, whatever `self._predict_generator` returns. Otherwise
    the result of a single `run`, or None if the session already reports
    `should_stop()`.

  Raises:
    NotFittedError: If no checkpoint is found in `self._model_dir`.
  """
  # Check that model has been trained.
  checkpoint_path = saver.latest_checkpoint(self._model_dir)
  if not checkpoint_path:
    raise NotFittedError("Couldn't find trained model at %s."
                         % self._model_dir)

  with ops.Graph().as_default() as g:
    random_seed.set_random_seed(self._config.tf_random_seed)
    contrib_framework.create_global_step(g)
    features = self._get_features_from_input_fn(input_fn)
    infer_ops = self._call_legacy_get_predict_ops(features)
    predictions = self._filter_predictions(infer_ops.predictions, outputs)
    mon_sess = monitored_session.MonitoredSession(
        session_creator=monitored_session.ChiefSessionCreator(
            checkpoint_filename_with_path=checkpoint_path))
    if not as_iterable:
      # Single-shot mode: the `with` block closes the session on exit.
      with mon_sess:
        if not mon_sess.should_stop():
          return mon_sess.run(predictions, feed_fn() if feed_fn else None)
    else:
      # Iterable mode: the session is handed over to the generator helper.
      return self._predict_generator(mon_sess, predictions, feed_fn,
                                     iterate_batches)
示例12: _read_latest_config_files
# 需要導入模塊: from tensorflow.python.training import saver [as 別名]
# 或者: from tensorflow.python.training.saver import latest_checkpoint [as 別名]
def _read_latest_config_files(self, run_path_pairs):
  """Reads and returns the projector config files in every run directory.

  Args:
    run_path_pairs: Iterable of `(run_name, logdir)` pairs.

  Returns:
    A `(configs, config_fpaths)` tuple of dicts keyed by run name: the config
    built for each run and the config file path computed for it. A run is
    omitted when it has neither a checkpoint nor any embedding with a
    `tensor_path`, or when its checkpoint path fails `checkpoint_exists`.
  """
  configs = {}
  config_fpaths = {}
  for run_name, logdir in run_path_pairs:
    config = ProjectorConfig()
    config_fpath = os.path.join(logdir, PROJECTOR_FILENAME)
    if file_io.file_exists(config_fpath):
      file_content = file_io.read_file_to_string(config_fpath).decode('utf-8')
      text_format.Merge(file_content, config)

    has_tensor_files = any(
        embedding.tensor_path for embedding in config.embeddings)

    if not config.model_checkpoint_path:
      # See if you can find a checkpoint file in the logdir.
      ckpt_path = latest_checkpoint(logdir)
      if not ckpt_path:
        # Or in the parent of logdir. `os.path.join('../', logdir)` would
        # prepend '../' (or be discarded entirely for an absolute logdir)
        # instead of pointing at the parent.
        ckpt_path = latest_checkpoint(os.path.join(logdir, os.pardir))
        if not ckpt_path and not has_tensor_files:
          continue
      if ckpt_path:
        config.model_checkpoint_path = ckpt_path

    # Sanity check for the checkpoint file.
    if (config.model_checkpoint_path and
        not checkpoint_exists(config.model_checkpoint_path)):
      logging.warning('Checkpoint file %s not found',
                      config.model_checkpoint_path)
      continue
    configs[run_name] = config
    config_fpaths[run_name] = config_fpath
  return configs, config_fpaths
示例13: wait_for_new_checkpoint
# 需要導入模塊: from tensorflow.python.training import saver [as 別名]
# 或者: from tensorflow.python.training.saver import latest_checkpoint [as 別名]
def wait_for_new_checkpoint(checkpoint_dir,
                            last_checkpoint,
                            seconds_to_sleep=1,
                            timeout=None):
  """Waits until a new checkpoint file is found.

  Args:
    checkpoint_dir: The directory in which checkpoints are saved.
    last_checkpoint: The last checkpoint path used.
    seconds_to_sleep: The number of seconds to sleep for before looking for a
      new checkpoint.
    timeout: The maximum amount of time to wait. If left as `None`, then the
      process will wait indefinitely.

  Returns:
    a new checkpoint path, or None if the timeout was reached.
  """
  logging.info('Waiting for new checkpoint at %s', checkpoint_dir)
  # Absolute deadline, or None when waiting without a limit.
  stop_time = (time.time() + timeout) if timeout is not None else None
  while True:
    checkpoint_path = tf_saver.latest_checkpoint(checkpoint_dir)
    if checkpoint_path is None or checkpoint_path == last_checkpoint:
      # Give up early if the next sleep would end past the deadline.
      if stop_time is not None and time.time() + seconds_to_sleep > stop_time:
        return None
      time.sleep(seconds_to_sleep)
    else:
      logging.info('Found new checkpoint at %s', checkpoint_path)
      return checkpoint_path
示例14: _infer_model
# 需要導入模塊: from tensorflow.python.training import saver [as 別名]
# 或者: from tensorflow.python.training.saver import latest_checkpoint [as 別名]
def _infer_model(
    self, input_fn, feed_fn=None, outputs=None, as_iterable=True):
  """Builds a prediction graph and dispatches to an inference helper.

  Args:
    input_fn: Passed to `self._get_features_from_input_fn` to obtain features.
    feed_fn: Forwarded to the inference helper.
    outputs: Optional collection of prediction keys to keep; when truthy,
      predictions whose key is not in it are dropped.
    as_iterable: If true, use `self._infer_model_as_iterable`; otherwise use
      `self._infer_model_single`.

  Returns:
    Whatever the selected inference helper returns.

  Raises:
    NotFittedError: If no checkpoint is found in `self._model_dir`.
    ValueError: If `outputs` filters out every available prediction.
  """
  # Check that model has been trained.
  checkpoint_path = saver.latest_checkpoint(self._model_dir)
  if not checkpoint_path:
    raise NotFittedError("Couldn't find trained model at %s."
                         % self._model_dir)

  with ops.Graph().as_default() as g:
    random_seed.set_random_seed(self._config.tf_random_seed)
    contrib_framework.create_global_step(g)
    features = self._get_features_from_input_fn(input_fn)
    predictions = self._get_predict_ops(features)
    # If predictions is single output - wrap it into dict, and remember to
    # return not a dict.
    return_dict = isinstance(predictions, dict)
    if not return_dict:
      predictions = {'predictions': predictions}

    # Filter what to run predictions on, if outputs provided.
    if outputs:
      # Captured before filtering so the error message can list what existed.
      existing_keys = predictions.keys()
      predictions = {
          key: value for key, value in predictions.items() if key in outputs
      }
      if not predictions:
        raise ValueError('Expected to run at least one output from %s, '
                         'provided %s.' % (existing_keys, outputs))

    if as_iterable:
      return self._infer_model_as_iterable(
          checkpoint_path, predictions, feed_fn, return_dict)
    else:
      return self._infer_model_single(
          checkpoint_path, predictions, feed_fn, return_dict)
示例15: after_create_session
# 需要導入模塊: from tensorflow.python.training import saver [as 別名]
# 或者: from tensorflow.python.training.saver import latest_checkpoint [as 別名]
def after_create_session(self, session, coord):
  """Initializes all global variables when no checkpoint is available.

  Looks for a checkpoint in `self._checkpoint_dir`; only when none is found
  does it run `self._global_var_initop` in the given session.

  Args:
    session: A TensorFlow Session that has been created.
    coord: A Coordinator object which keeps track of all threads (unused
      here).
  """
  checkpoint_path = saver_lib.latest_checkpoint(self._checkpoint_dir)
  if not checkpoint_path:
    tf.logging.info(
        "InitVariablesHook (after_create_sess): initializing all global variables...")
    # NOTE(review): the source's indentation was lost; the init op is placed
    # inside the no-checkpoint branch so existing checkpoint state is not
    # overwritten - confirm against the upstream file.
    session.run(self._global_var_initop)