本文整理汇总了Python中tensorflow.core.protobuf.saver_pb2.SaverDef类的典型用法代码示例。如果您正苦于以下问题：Python saver_pb2.SaverDef的具体用法？Python saver_pb2.SaverDef怎么用？Python saver_pb2.SaverDef使用的例子？那么，这里精选的代码示例或许可以为您提供帮助。您也可以进一步了解该类所在模块tensorflow.core.protobuf.saver_pb2的用法示例。
在下文中一共展示了saver_pb2.SaverDef类（包括其枚举值SaverDef.V1/SaverDef.V2）的11个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者觉得有用的代码点赞，您的评价将有助于系统推荐出更好的Python代码示例。
示例1: _prefix_to_checkpoint_path
# 需要导入模块: from tensorflow.core.protobuf import saver_pb2 [as 别名]
# 或者: from tensorflow.core.protobuf.saver_pb2 import SaverDef [as 别名]
def _prefix_to_checkpoint_path(prefix, format_version):
    """Map a checkpoint prefix to the file that identifies the checkpoint.

    For a V2 checkpoint the ``.index`` file is returned. Any other format
    version is treated like V1, whose prefix already names the data file, so
    the prefix is returned unchanged.

    Args:
        prefix: a string, the prefix of a checkpoint.
        format_version: the checkpoint format version that corresponds to the
            prefix.

    Returns:
        The pathname of a checkpoint file, taking into account the checkpoint
        format version.
    """
    is_v2 = format_version == saver_pb2.SaverDef.V2
    # The index file is what identifies a V2 checkpoint on disk.
    return prefix + ".index" if is_v2 else prefix
示例2: checkpoint_exists
# 需要导入模块: from tensorflow.core.protobuf import saver_pb2 [as 别名]
# 或者: from tensorflow.core.protobuf.saver_pb2 import SaverDef [as 别名]
def checkpoint_exists(checkpoint_prefix):
    """Report whether a V1 or V2 checkpoint exists for ``checkpoint_prefix``.

    This is the recommended way to check if a checkpoint exists, since it takes
    into account the naming difference between V1 and V2 formats.

    Args:
        checkpoint_prefix: the prefix of a V1 or V2 checkpoint, with V2 taking
            priority. Typically the result of `Saver.save()` or that of
            `tf.train.latest_checkpoint()`, regardless of sharded/non-sharded
            or V1/V2.

    Returns:
        True iff some file matches either the V2 index pattern or the bare
        prefix; False otherwise.
    """
    v2_index_pattern = _prefix_to_checkpoint_path(checkpoint_prefix,
                                                  saver_pb2.SaverDef.V2)
    # Probe the V2 index file first. Fall back to the bare prefix, which is
    # the V1 data file. The loop stops at the first pattern that matches.
    for pattern in (v2_index_pattern, checkpoint_prefix):
        if file_io.get_matching_files(pattern):
            return True
    return False
示例3: main
# 需要导入模块: from tensorflow.core.protobuf import saver_pb2 [as 别名]
# 或者: from tensorflow.core.protobuf.saver_pb2 import SaverDef [as 别名]
def main(unused_args, flags):
    """Translate the checkpoint-version flag and invoke ``freeze_graph``.

    Args:
        unused_args: positional arguments; ignored.
        flags: parsed command-line flags. Every attribute read below is
            forwarded to ``freeze_graph`` positionally.

    Returns:
        -1 (after printing a message) if ``flags.checkpoint_version`` is
        neither 1 nor 2; otherwise None once ``freeze_graph`` has run.
    """
    requested_version = flags.checkpoint_version
    if requested_version != 1 and requested_version != 2:
        print("Invalid checkpoint version (must be '1' or '2'): %d" %
              requested_version)
        return -1

    if requested_version == 1:
        checkpoint_version = saver_pb2.SaverDef.V1
    else:
        checkpoint_version = saver_pb2.SaverDef.V2

    freeze_graph(
        flags.input_graph,
        flags.input_saver,
        flags.input_binary,
        flags.input_checkpoint,
        flags.output_node_names,
        flags.restore_op_name,
        flags.filename_tensor_name,
        flags.output_graph,
        flags.clear_devices,
        flags.initializer_nodes,
        flags.variable_names_whitelist,
        flags.variable_names_blacklist,
        flags.input_meta_graph,
        flags.input_saved_model_dir,
        flags.saved_model_tags,
        checkpoint_version,
    )
示例4: _parse_input_saver_proto
# 需要导入模块: from tensorflow.core.protobuf import saver_pb2 [as 别名]
# 或者: from tensorflow.core.protobuf.saver_pb2 import SaverDef [as 别名]
def _parse_input_saver_proto(input_saver, input_binary):
    """Parse a serialized TensorFlow Saver file into a ``SaverDef`` proto.

    Args:
        input_saver: path of the file holding the serialized ``SaverDef``.
        input_binary: if truthy, the file is opened in binary mode and parsed
            with ``ParseFromString``; otherwise it is opened in text mode and
            parsed as protobuf text format via ``text_format.Merge``.

    Returns:
        The parsed ``saver_pb2.SaverDef``, or -1 (after printing an error
        message) if ``input_saver`` does not exist.
    """
    if not gfile.Exists(input_saver):
        print("Input saver file '" + input_saver + "' does not exist!")
        return -1
    mode = "rb" if input_binary else "r"
    # gfile.FastGFile is deprecated; gfile.GFile offers the same file-like
    # interface and is the supported reader.
    with gfile.GFile(input_saver, mode) as f:
        contents = f.read()
    # Parse after the file is closed so the handle is not held during parsing.
    saver_def = saver_pb2.SaverDef()
    if input_binary:
        saver_def.ParseFromString(contents)
    else:
        text_format.Merge(contents, saver_def)
    return saver_def
示例5: __init__
# 需要导入模块: from tensorflow.core.protobuf import saver_pb2 [as 别名]
# 或者: from tensorflow.core.protobuf.saver_pb2 import SaverDef [as 别名]
def __init__(self, write_version=saver_pb2.SaverDef.V2):
    """Store the checkpoint format version this builder will write.

    Args:
        write_version: checkpoint format version, defaulting to
            ``saver_pb2.SaverDef.V2``. Stored as ``self._write_version``.
    """
    self._write_version = write_version
示例6: save_op
# 需要导入模块: from tensorflow.core.protobuf import saver_pb2 [as 别名]
# 或者: from tensorflow.core.protobuf.saver_pb2 import SaverDef [as 别名]
def save_op(self, filename_tensor, saveables):
    """Create an Op to save 'saveables'.

    This is intended to be overridden by subclasses that want to generate
    different Ops.

    Args:
        filename_tensor: String Tensor. For V1 it is passed as the filename.
            For V2 it is interpreted as a checkpoint prefix.
        saveables: A list of BaseSaverBuilder.SaveableObject objects.

    Returns:
        An Operation that saves the variables.

    Raises:
        RuntimeError: (implementation detail) if "self._write_version" is an
            unexpected value.
    """
    # Flatten the specs of every saveable, preserving order, so the three
    # parallel lists line up index-for-index.
    specs = [spec for saveable in saveables for spec in saveable.specs]
    tensor_names = [spec.name for spec in specs]
    tensors = [spec.tensor for spec in specs]
    tensor_slices = [spec.slice_spec for spec in specs]

    if self._write_version == saver_pb2.SaverDef.V1:
        # pylint: disable=protected-access
        return io_ops._save(
            filename=filename_tensor,
            tensor_names=tensor_names,
            tensors=tensors,
            tensor_slices=tensor_slices)
    if self._write_version == saver_pb2.SaverDef.V2:
        # "filename_tensor" is interpreted *NOT AS A FILENAME*, but as a prefix
        # of a V2 checkpoint: e.g. "/fs/train/ckpt-<step>/tmp/worker<i>-<step>".
        return io_ops.save_v2(filename_tensor, tensor_names, tensor_slices,
                              tensors)
    # str() is required: concatenating the raw (non-str) version value would
    # raise TypeError and hide the intended RuntimeError.
    raise RuntimeError("Unexpected write_version: " + str(self._write_version))
# pylint: disable=unused-argument
示例7: _AddShardedSaveOps
# 需要导入模块: from tensorflow.core.protobuf import saver_pb2 [as 别名]
# 或者: from tensorflow.core.protobuf.saver_pb2 import SaverDef [as 别名]
def _AddShardedSaveOps(self, filename_tensor, per_device):
    """Add ops to save the params per shard.

    V2 checkpoints are delegated to ``self._AddShardedSaveOpsForV2``. Otherwise
    one save op is built per device and a sharded filespec is returned once
    every shard has been written.

    Args:
        filename_tensor: a scalar String Tensor.
        per_device: A list of (device, BaseSaverBuilder.SaveableObject) pairs,
            as returned by _GroupByDevices().

    Returns:
        An op to save the variables.
    """
    if self._write_version == saver_pb2.SaverDef.V2:
        return self._AddShardedSaveOpsForV2(filename_tensor, per_device)

    shard_count = len(per_device)
    shard_count_tensor = constant_op.constant(shard_count, name="num_shards")
    shard_save_ops = []
    for shard_index, (device, saveables) in enumerate(per_device):
        # Place each shard's filename computation and save op on its device.
        with ops.device(device):
            shard_path = self.sharded_filename(filename_tensor, shard_index,
                                               shard_count_tensor)
            shard_save_ops.append(self._AddSaveOps(shard_path, saveables))

    # The returned filespec depends on every shard's save having run.
    save_dependencies = [save.op for save in shard_save_ops]
    with ops.control_dependencies(save_dependencies):
        # pylint: disable=protected-access
        return gen_io_ops._sharded_filespec(filename_tensor, shard_count_tensor)
示例8: _check_saver_def
# 需要导入模块: from tensorflow.core.protobuf import saver_pb2 [as 别名]
# 或者: from tensorflow.core.protobuf.saver_pb2 import SaverDef [as 别名]
def _check_saver_def(self):
    """Validate ``self.saver_def``.

    Raises:
        ValueError: if ``self.saver_def`` is not a ``saver_pb2.SaverDef``, or
            if its ``save_tensor_name`` or ``restore_op_name`` is empty.
    """
    if not isinstance(self.saver_def, saver_pb2.SaverDef):
        # Format via str(), like the messages below. A bare ``"%s" % value``
        # raises TypeError (or drops text) when value is a tuple, masking
        # this ValueError.
        raise ValueError("saver_def must be a saver_pb2.SaverDef: %s" %
                         str(self.saver_def))
    if not self.saver_def.save_tensor_name:
        raise ValueError("saver_def must specify the save_tensor_name: %s" %
                         str(self.saver_def))
    if not self.saver_def.restore_op_name:
        raise ValueError("saver_def must specify the restore_op_name: %s" %
                         str(self.saver_def))
示例9: to_proto
# 需要导入模块: from tensorflow.core.protobuf import saver_pb2 [as 别名]
# 或者: from tensorflow.core.protobuf.saver_pb2 import SaverDef [as 别名]
def to_proto(self, export_scope=None):
    """Converts this `Saver` to a `SaverDef` protocol buffer.

    Args:
        export_scope: Optional `string`. Name scope to remove.

    Returns:
        ``self.saver_def`` itself when ``export_scope`` is None. Otherwise,
        either a copy whose three tensor/op names have ``export_scope``
        stripped, or None if any of those names does not start with
        ``export_scope``.
    """
    if export_scope is None:
        return self.saver_def

    scoped_fields = ("filename_tensor_name", "save_tensor_name",
                     "restore_op_name")
    source_def = self.saver_def
    # Every referenced name must live under the export scope to be exported.
    if not all(getattr(source_def, field).startswith(export_scope)
               for field in scoped_fields):
        return None

    exported = saver_pb2.SaverDef()
    exported.CopyFrom(source_def)
    for field in scoped_fields:
        setattr(exported, field,
                ops.strip_name_scope(getattr(exported, field), export_scope))
    return exported
示例10: latest_checkpoint
# 需要导入模块: from tensorflow.core.protobuf import saver_pb2 [as 别名]
# 或者: from tensorflow.core.protobuf.saver_pb2 import SaverDef [as 别名]
def latest_checkpoint(checkpoint_dir, latest_filename=None):
    """Finds the filename of latest saved checkpoint file.

    Args:
        checkpoint_dir: Directory where the variables were saved.
        latest_filename: Optional name for the protocol buffer file that
            contains the list of most recent checkpoint filenames.
            See the corresponding argument to `Saver.save()`.

    Returns:
        The full path to the latest checkpoint or `None` if no checkpoint was
        found. An error is logged when the checkpoint state names a path but
        no matching V1 or V2 files exist.
    """
    ckpt = get_checkpoint_state(checkpoint_dir, latest_filename)
    if not ckpt or not ckpt.model_checkpoint_path:
        return None

    checkpoint_prefix = ckpt.model_checkpoint_path
    # V2 layout is listed first so it is probed first; V1 is the fallback.
    candidate_paths = (
        _prefix_to_checkpoint_path(checkpoint_prefix, saver_pb2.SaverDef.V2),
        _prefix_to_checkpoint_path(checkpoint_prefix, saver_pb2.SaverDef.V1),
    )
    if any(file_io.get_matching_files(path) for path in candidate_paths):
        return checkpoint_prefix

    logging.error("Couldn't match files for checkpoint %s", checkpoint_prefix)
    return None
示例11: to_proto
# 需要导入模块: from tensorflow.core.protobuf import saver_pb2 [as 别名]
# 或者: from tensorflow.core.protobuf.saver_pb2 import SaverDef [as 别名]
def to_proto(self, export_scope=None):
    """Return this saver's `SaverDef` if it belongs to ``export_scope``.

    Args:
        export_scope: Optional `string`. When given, only a saver whose
            ``_name`` starts with it is exported. No scope is stripped here;
            ``self.saver_def`` is returned as-is.

    Returns:
        ``self.saver_def`` when ``export_scope`` is None or ``self._name``
        starts with ``export_scope``; otherwise None.
    """
    belongs_to_scope = (export_scope is None
                        or self._name.startswith(export_scope))
    return self.saver_def if belongs_to_scope else None