This article collects typical usage examples of the Python method tensorflow.python.training.checkpoint_state_pb2.CheckpointState. If you are wondering what checkpoint_state_pb2.CheckpointState does and how to use it, the curated code examples below may help. You can also explore further usage examples of the module it belongs to, tensorflow.python.training.checkpoint_state_pb2.
Below are 12 code examples of checkpoint_state_pb2.CheckpointState, sorted by popularity by default. You can upvote the examples you like or find useful; your feedback helps the system recommend better Python code examples.
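Before the examples, here is a minimal, self-contained sketch (not taken from the examples below) of what CheckpointState is: a small protobuf message that records the path of the latest checkpoint plus the paths of all retained checkpoints, stored in text format in the "checkpoint" file next to the model. The round trip below uses only the public protobuf text_format API; the paths are illustrative.

# Minimal sketch: build a CheckpointState proto and round-trip it through the
# text format used by the "checkpoint" file. Paths here are made up.
from google.protobuf import text_format
from tensorflow.python.training import checkpoint_state_pb2

state = checkpoint_state_pb2.CheckpointState()
state.model_checkpoint_path = 'model.ckpt-1000'
state.all_model_checkpoint_paths.extend(['model.ckpt-500', 'model.ckpt-1000'])

serialized = text_format.MessageToString(state)   # same format as the "checkpoint" file
restored = checkpoint_state_pb2.CheckpointState()
text_format.Merge(serialized, restored)
assert restored.model_checkpoint_path == 'model.ckpt-1000'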
Example 1: ParseCheckpoint
# Required import: from tensorflow.python.training import checkpoint_state_pb2 [as alias]
# Or: from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState [as alias]
def ParseCheckpoint(checkpoint):
"""Parse a checkpoint file.
Args:
checkpoint: Path to checkpoint. The checkpoint is either a serialized
CheckpointState proto or an actual checkpoint file.
Returns:
The path to an actual checkpoint file.
"""
warnings.warn(
"ParseCheckpoint is deprecated. "
"Will be removed in DeepChem 1.4.", DeprecationWarning)
with open(checkpoint) as f:
try:
cp = checkpoint_state_pb2.CheckpointState()
text_format.Merge(f.read(), cp)
return cp.model_checkpoint_path
except text_format.ParseError:
return checkpoint
Example 2: testParseCheckpoint
# Required import: from tensorflow.python.training import checkpoint_state_pb2 [as alias]
# Or: from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState [as alias]
def testParseCheckpoint(self):
# parse CheckpointState proto
with tempfile.NamedTemporaryFile(mode='w+') as f:
cp = checkpoint_state_pb2.CheckpointState()
cp.model_checkpoint_path = 'my-checkpoint'
f.write(text_format.MessageToString(cp))
f.file.flush()
self.assertEqual(utils.ParseCheckpoint(f.name), 'my-checkpoint')
# parse path to actual checkpoint
with tempfile.NamedTemporaryFile(mode='w+') as f:
f.write('This is not a CheckpointState proto.')
f.file.flush()
self.assertEqual(utils.ParseCheckpoint(f.name), f.name)
Example 3: _GetState
# Required import: from tensorflow.python.training import checkpoint_state_pb2 [as alias]
# Or: from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState [as alias]
def _GetState(self):
"""Returns the latest checkpoint id."""
state = CheckpointState()
if file_io.file_exists(self._state_file):
content = file_io.read_file_to_string(self._state_file)
text_format.Merge(content, state)
return state
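The method above only reads the state file. As a hedged sketch (the directory, file name, and setup below are illustrative, not taken from the original class, which keeps its path in self._state_file), the companion file could be written like this, using the same text format and an atomic write so a concurrent reader never sees a half-written file.

# Illustrative writer for the kind of state file that _GetState reads back.
from google.protobuf import text_format
from tensorflow.python.lib.io import file_io
from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState

state_dir = '/tmp/ckpt_state_demo'            # hypothetical location
state_file = state_dir + '/state'
file_io.recursive_create_dir(state_dir)

state = CheckpointState(model_checkpoint_path='model.ckpt-42')
# Atomic write prevents a reader from observing a partially written file.
file_io.atomic_write_string_to_file(state_file, text_format.MessageToString(state))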
Example 4: _GetCheckpointFilename
# Required import: from tensorflow.python.training import checkpoint_state_pb2 [as alias]
# Or: from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState [as alias]
def _GetCheckpointFilename(save_dir, latest_filename):
"""Returns a filename for storing the CheckpointState.
Args:
save_dir: The directory for saving and restoring checkpoints.
latest_filename: Name of the file in 'save_dir' that is used
to store the CheckpointState.
Returns:
The path of the file that contains the CheckpointState proto.
"""
if latest_filename is None:
latest_filename = "checkpoint"
return os.path.join(save_dir, latest_filename)
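A quick illustration of the helper above, assuming it is in scope: with the default latest_filename it simply resolves to '<save_dir>/checkpoint'.

# Assuming _GetCheckpointFilename from the snippet above is in scope.
print(_GetCheckpointFilename('/tmp/train_logs', None))              # /tmp/train_logs/checkpoint
print(_GetCheckpointFilename('/tmp/train_logs', 'checkpoint_eval')) # /tmp/train_logs/checkpoint_eval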
Example 5: update_checkpoint_state
# Required import: from tensorflow.python.training import checkpoint_state_pb2 [as alias]
# Or: from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState [as alias]
def update_checkpoint_state(save_dir,
model_checkpoint_path,
all_model_checkpoint_paths=None,
latest_filename=None):
"""Updates the content of the 'checkpoint' file.
This updates the checkpoint file containing a CheckpointState
proto.
Args:
save_dir: Directory where the model was saved.
model_checkpoint_path: The checkpoint file.
all_model_checkpoint_paths: List of strings. Paths to all not-yet-deleted
checkpoints, sorted from oldest to newest. If this is a non-empty list,
the last element must be equal to model_checkpoint_path. These paths
are also saved in the CheckpointState proto.
    latest_filename: Optional name of the checkpoint file. Defaults to
      'checkpoint'.
Raises:
RuntimeError: If any of the model checkpoint paths conflict with the file
      containing CheckpointState.
"""
_update_checkpoint_state(
save_dir=save_dir,
model_checkpoint_path=model_checkpoint_path,
all_model_checkpoint_paths=all_model_checkpoint_paths,
latest_filename=latest_filename,
save_relative_paths=False)
Example 6: update_checkpoint_state
# Required import: from tensorflow.python.training import checkpoint_state_pb2 [as alias]
# Or: from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState [as alias]
def update_checkpoint_state(save_dir,
model_checkpoint_path,
all_model_checkpoint_paths=None,
latest_filename=None):
"""Updates the content of the 'checkpoint' file.
This updates the checkpoint file containing a CheckpointState
proto.
Args:
save_dir: Directory where the model was saved.
model_checkpoint_path: The checkpoint file.
all_model_checkpoint_paths: List of strings. Paths to all not-yet-deleted
checkpoints, sorted from oldest to newest. If this is a non-empty list,
the last element must be equal to model_checkpoint_path. These paths
are also saved in the CheckpointState proto.
    latest_filename: Optional name of the checkpoint file. Defaults to
      'checkpoint'.
Raises:
RuntimeError: If the save paths conflict.
"""
# Writes the "checkpoint" file for the coordinator for later restoration.
coord_checkpoint_filename = _GetCheckpointFilename(save_dir, latest_filename)
ckpt = generate_checkpoint_state_proto(
save_dir,
model_checkpoint_path,
all_model_checkpoint_paths=all_model_checkpoint_paths)
if coord_checkpoint_filename == ckpt.model_checkpoint_path:
raise RuntimeError("Save path '%s' conflicts with path used for "
"checkpoint state. Please use a different save path." %
model_checkpoint_path)
# Preventing potential read/write race condition by *atomically* writing to a
# file.
file_io.atomic_write_string_to_file(coord_checkpoint_filename,
text_format.MessageToString(ckpt))
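For comparison, a hedged sketch of exercising the same behaviour through the public API (tf.train.update_checkpoint_state in TF 1.x, tf.compat.v1.train.update_checkpoint_state in TF 2.x) rather than the helper above. The directory and checkpoint names are made up, and only the "checkpoint" bookkeeping file is written, no actual checkpoint data.

# Sketch: write and read back a "checkpoint" file via the public API.
import os
import tempfile
import tensorflow as tf

save_dir = tempfile.mkdtemp()
tf.compat.v1.train.update_checkpoint_state(
    save_dir,
    model_checkpoint_path=os.path.join(save_dir, 'model.ckpt-10'),
    all_model_checkpoint_paths=[os.path.join(save_dir, 'model.ckpt-5'),
                                os.path.join(save_dir, 'model.ckpt-10')])

state = tf.train.get_checkpoint_state(save_dir)   # parses the CheckpointState proto
print(state.model_checkpoint_path)                # .../model.ckpt-10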
Example 7: _get_current_checkpoint
# Required import: from tensorflow.python.training import checkpoint_state_pb2 [as alias]
# Or: from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState [as alias]
def _get_current_checkpoint(self):
    try:
        checkpoint_metadata_filepath = os.path.abspath(
            os.path.join(self.params.checkpoint_dir, CHECKPOINT_METADATA_FILENAME))
        if not os.path.exists(checkpoint_metadata_filepath):
            return None
        checkpoint = CheckpointState()
        with open(checkpoint_metadata_filepath, 'r') as f:
            contents = f.read()
        text_format.Merge(contents, checkpoint)
        return checkpoint
    except Exception as e:
        print("Got exception while reading checkpoint metadata", e)
        raise
Example 8: test_checkpoint_contains_relative_paths
# Required import: from tensorflow.python.training import checkpoint_state_pb2 [as alias]
# Or: from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState [as alias]
def test_checkpoint_contains_relative_paths(self):
tmpdir = tempfile.mkdtemp()
est = estimator.EstimatorV2(
model_dir=tmpdir, model_fn=model_fn_global_step_incrementer)
est.train(dummy_input_fn, steps=5)
checkpoint_file_content = file_io.read_file_to_string(
os.path.join(tmpdir, 'checkpoint'))
ckpt = checkpoint_state_pb2.CheckpointState()
text_format.Merge(checkpoint_file_content, ckpt)
self.assertEqual(ckpt.model_checkpoint_path, 'model.ckpt-5')
# TODO(b/78461127): Please modify tests to not directly rely on names of
# checkpoints.
self.assertAllEqual(['model.ckpt-0', 'model.ckpt-5'],
ckpt.all_model_checkpoint_paths)
Example 9: generate_checkpoint_state_proto
# Required import: from tensorflow.python.training import checkpoint_state_pb2 [as alias]
# Or: from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState [as alias]
def generate_checkpoint_state_proto(save_dir,
model_checkpoint_path,
all_model_checkpoint_paths=None):
"""Generates a checkpoint state proto.
Args:
save_dir: Directory where the model was saved.
model_checkpoint_path: The checkpoint file.
all_model_checkpoint_paths: List of strings. Paths to all not-yet-deleted
checkpoints, sorted from oldest to newest. If this is a non-empty list,
the last element must be equal to model_checkpoint_path. These paths
are also saved in the CheckpointState proto.
Returns:
CheckpointState proto with model_checkpoint_path and
all_model_checkpoint_paths updated to either absolute paths or
relative paths to the current save_dir.
"""
if all_model_checkpoint_paths is None:
all_model_checkpoint_paths = []
if (not all_model_checkpoint_paths or
all_model_checkpoint_paths[-1] != model_checkpoint_path):
logging.info("%s is not in all_model_checkpoint_paths. Manually adding it.",
model_checkpoint_path)
all_model_checkpoint_paths.append(model_checkpoint_path)
# Relative paths need to be rewritten to be relative to the "save_dir"
# if model_checkpoint_path already contains "save_dir".
if not os.path.isabs(save_dir):
if not os.path.isabs(model_checkpoint_path):
model_checkpoint_path = os.path.relpath(model_checkpoint_path, save_dir)
for i in range(len(all_model_checkpoint_paths)):
p = all_model_checkpoint_paths[i]
if not os.path.isabs(p):
all_model_checkpoint_paths[i] = os.path.relpath(p, save_dir)
coord_checkpoint_proto = CheckpointState(
model_checkpoint_path=model_checkpoint_path,
all_model_checkpoint_paths=all_model_checkpoint_paths)
return coord_checkpoint_proto
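A short usage sketch, assuming generate_checkpoint_state_proto from the snippet above is in scope (with CheckpointState and the other names it uses imported as shown): with a relative save_dir, relative path entries are rewritten relative to it before being stored in the proto.

# Assuming generate_checkpoint_state_proto from the snippet above is in scope.
ckpt = generate_checkpoint_state_proto(
    'train_logs',
    'train_logs/model.ckpt-100',
    all_model_checkpoint_paths=['train_logs/model.ckpt-50',
                                'train_logs/model.ckpt-100'])
print(ckpt.model_checkpoint_path)             # model.ckpt-100
print(list(ckpt.all_model_checkpoint_paths))  # ['model.ckpt-50', 'model.ckpt-100']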
Example 10: _update_checkpoint_state
# Required import: from tensorflow.python.training import checkpoint_state_pb2 [as alias]
# Or: from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState [as alias]
def _update_checkpoint_state(save_dir,
model_checkpoint_path,
all_model_checkpoint_paths=None,
latest_filename=None,
save_relative_paths=False):
"""Updates the content of the 'checkpoint' file.
This updates the checkpoint file containing a CheckpointState
proto.
Args:
save_dir: Directory where the model was saved.
model_checkpoint_path: The checkpoint file.
all_model_checkpoint_paths: List of strings. Paths to all not-yet-deleted
checkpoints, sorted from oldest to newest. If this is a non-empty list,
the last element must be equal to model_checkpoint_path. These paths
are also saved in the CheckpointState proto.
    latest_filename: Optional name of the checkpoint file. Defaults to
      'checkpoint'.
save_relative_paths: If `True`, will write relative paths to the checkpoint
state file.
Raises:
RuntimeError: If any of the model checkpoint paths conflict with the file
      containing CheckpointState.
"""
# Writes the "checkpoint" file for the coordinator for later restoration.
coord_checkpoint_filename = _GetCheckpointFilename(save_dir, latest_filename)
if save_relative_paths:
if os.path.isabs(model_checkpoint_path):
rel_model_checkpoint_path = os.path.relpath(
model_checkpoint_path, save_dir)
else:
rel_model_checkpoint_path = model_checkpoint_path
rel_all_model_checkpoint_paths = []
for p in all_model_checkpoint_paths:
if os.path.isabs(p):
rel_all_model_checkpoint_paths.append(os.path.relpath(p, save_dir))
else:
rel_all_model_checkpoint_paths.append(p)
ckpt = generate_checkpoint_state_proto(
save_dir,
rel_model_checkpoint_path,
all_model_checkpoint_paths=rel_all_model_checkpoint_paths)
else:
ckpt = generate_checkpoint_state_proto(
save_dir,
model_checkpoint_path,
all_model_checkpoint_paths=all_model_checkpoint_paths)
if coord_checkpoint_filename == ckpt.model_checkpoint_path:
raise RuntimeError("Save path '%s' conflicts with path used for "
"checkpoint state. Please use a different save path." %
model_checkpoint_path)
# Preventing potential read/write race condition by *atomically* writing to a
# file.
file_io.atomic_write_string_to_file(coord_checkpoint_filename,
text_format.MessageToString(ckpt))
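The save_relative_paths branch above boils down to one rewrite, sketched here in plain Python for clarity (the paths are illustrative): absolute entries are made relative to save_dir, while already-relative entries are kept as-is.

# Plain-Python illustration of the save_relative_paths rewrite shown above.
import os

save_dir = '/data/train_logs'
paths = ['/data/train_logs/model.ckpt-5', 'model.ckpt-10']
rel_paths = [os.path.relpath(p, save_dir) if os.path.isabs(p) else p for p in paths]
print(rel_paths)  # ['model.ckpt-5', 'model.ckpt-10']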
Example 11: get_checkpoint_state
# Required import: from tensorflow.python.training import checkpoint_state_pb2 [as alias]
# Or: from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState [as alias]
def get_checkpoint_state(checkpoint_dir, latest_filename=None):
"""Returns CheckpointState proto from the "checkpoint" file.
If the "checkpoint" file contains a valid CheckpointState
proto, returns it.
Args:
checkpoint_dir: The directory of checkpoints.
    latest_filename: Optional name of the checkpoint file. Defaults to
      'checkpoint'.
Returns:
A CheckpointState if the state was available, None
otherwise.
Raises:
ValueError: if the checkpoint read doesn't have model_checkpoint_path set.
"""
ckpt = None
coord_checkpoint_filename = _GetCheckpointFilename(checkpoint_dir,
latest_filename)
f = None
try:
# Check that the file exists before opening it to avoid
# many lines of errors from colossus in the logs.
if file_io.file_exists(coord_checkpoint_filename):
file_content = file_io.read_file_to_string(
coord_checkpoint_filename)
ckpt = CheckpointState()
text_format.Merge(file_content, ckpt)
if not ckpt.model_checkpoint_path:
        raise ValueError("Invalid checkpoint state loaded from %s" %
                         checkpoint_dir)
# For relative model_checkpoint_path and all_model_checkpoint_paths,
# prepend checkpoint_dir.
if not os.path.isabs(ckpt.model_checkpoint_path):
ckpt.model_checkpoint_path = os.path.join(checkpoint_dir,
ckpt.model_checkpoint_path)
for i in range(len(ckpt.all_model_checkpoint_paths)):
p = ckpt.all_model_checkpoint_paths[i]
if not os.path.isabs(p):
ckpt.all_model_checkpoint_paths[i] = os.path.join(checkpoint_dir, p)
except errors.OpError as e:
# It's ok if the file cannot be read
logging.warning("%s: %s", type(e).__name__, e)
logging.warning("%s: Checkpoint ignored", coord_checkpoint_filename)
return None
except text_format.ParseError as e:
logging.warning("%s: %s", type(e).__name__, e)
logging.warning("%s: Checkpoint ignored", coord_checkpoint_filename)
return None
finally:
if f:
f.close()
return ckpt
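A hedged usage sketch of the same logic through the public tf.train.get_checkpoint_state: a hand-written "checkpoint" file (its contents are made up for illustration) is parsed into a CheckpointState proto, and relative paths come back joined onto the directory. Note that tf.train.latest_checkpoint additionally verifies that the checkpoint data files exist, so it would return None here, since no real checkpoint is written.

# Sketch: parse a hand-written "checkpoint" file with the public API.
import os
import tempfile
import tensorflow as tf

ckpt_dir = tempfile.mkdtemp()
with open(os.path.join(ckpt_dir, 'checkpoint'), 'w') as f:
    f.write('model_checkpoint_path: "model.ckpt-7"\n'
            'all_model_checkpoint_paths: "model.ckpt-3"\n'
            'all_model_checkpoint_paths: "model.ckpt-7"\n')

state = tf.train.get_checkpoint_state(ckpt_dir)
print(state.model_checkpoint_path)             # <ckpt_dir>/model.ckpt-7 (made absolute)
print(list(state.all_model_checkpoint_paths))  # both entries, also made absolute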
Example 12: get_checkpoint_state
# Required import: from tensorflow.python.training import checkpoint_state_pb2 [as alias]
# Or: from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState [as alias]
def get_checkpoint_state(checkpoint_dir, latest_filename=None):
"""Returns CheckpointState proto from the "checkpoint" file.
If the "checkpoint" file contains a valid CheckpointState
proto, returns it.
Args:
checkpoint_dir: The directory of checkpoints.
    latest_filename: Optional name of the checkpoint file. Defaults to
      'checkpoint'.
Returns:
A CheckpointState if the state was available, None
otherwise.
Raises:
ValueError: if the checkpoint read doesn't have model_checkpoint_path set.
"""
ckpt = None
coord_checkpoint_filename = _GetCheckpointFilename(checkpoint_dir,
latest_filename)
f = None
try:
# Check that the file exists before opening it to avoid
# many lines of errors from colossus in the logs.
if file_io.file_exists(coord_checkpoint_filename):
file_content = file_io.read_file_to_string(
coord_checkpoint_filename).decode("utf-8")
ckpt = CheckpointState()
text_format.Merge(file_content, ckpt)
if not ckpt.model_checkpoint_path:
        raise ValueError("Invalid checkpoint state loaded from %s" %
                         checkpoint_dir)
# For relative model_checkpoint_path and all_model_checkpoint_paths,
# prepend checkpoint_dir.
if not os.path.isabs(ckpt.model_checkpoint_path):
ckpt.model_checkpoint_path = os.path.join(checkpoint_dir,
ckpt.model_checkpoint_path)
for i in range(len(ckpt.all_model_checkpoint_paths)):
p = ckpt.all_model_checkpoint_paths[i]
if not os.path.isabs(p):
ckpt.all_model_checkpoint_paths[i] = os.path.join(checkpoint_dir, p)
except errors.OpError as e:
# It's ok if the file cannot be read
logging.warning(str(e))
logging.warning("%s: Checkpoint ignored", coord_checkpoint_filename)
return None
except text_format.ParseError as e:
logging.warning(str(e))
logging.warning("%s: Checkpoint ignored", coord_checkpoint_filename)
return None
finally:
if f:
f.close()
return ckpt