本文整理汇总了Python中tensorflow.python.ops.io_ops.save_v2方法的典型用法代码示例。如果您正苦于以下问题:Python io_ops.save_v2方法的具体用法?Python io_ops.save_v2怎么用?Python io_ops.save_v2使用的例子?那么恭喜您, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块tensorflow.python.ops.io_ops的用法示例。
在下文中一共展示了io_ops.save_v2方法的7个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: main
# 需要导入模块: from tensorflow.python.ops import io_ops [as 别名]
# 或者: from tensorflow.python.ops.io_ops import save_v2 [as 别名]
def main(argv):
  """Saves a Keras dataset's train/test splits as a TF V2 checkpoint.

  The dataset is looked up by name (FLAGS.dataset) on tf.keras.datasets and
  its four arrays are written as tensors named "x_train", "y_train",
  "x_test" and "y_test" under the prefix FLAGS.out, or
  /tmp/<dataset>/<dataset> when FLAGS.out is empty.

  Args:
    argv: Command-line arguments; unused.
  """
  del argv  # Unused.
  dataset = getattr(tf.keras.datasets, FLAGS.dataset)
  (x_train, y_train), (x_test, y_test) = dataset.load_data()

  def wrap(val):
    """Returns a py_func tensor that yields the numpy array `val` at run time.

    NOTE(review): presumably py_func is used instead of tf.constant to keep
    the arrays out of the GraphDef -- confirm.
    """
    dtype = tf.as_dtype(val.dtype)
    assert dtype != tf.string  # tf.string is not supported by py_func.
    return tf.py_func(lambda: val, [], dtype)

  out_prefix = FLAGS.out or os.path.join("/tmp", FLAGS.dataset, FLAGS.dataset)
  # Hand the format arguments to the logger rather than pre-formatting with
  # `%`, so the message is only built when INFO logging is enabled.
  tf.logging.info("Save %s dataset to %s ckpt.", FLAGS.dataset, out_prefix)
  with tf.Session() as sess:
    sess.run(
        io_ops.save_v2(
            prefix=out_prefix,
            tensor_names=["x_train", "y_train", "x_test", "y_test"],
            # Empty slice specs: each tensor is saved whole (not partitioned).
            shape_and_slices=[""] * 4,
            tensors=[wrap(x_train),
                     wrap(y_train),
                     wrap(x_test),
                     wrap(y_test)]))
示例2: FakeMnistData
# 需要导入模块: from tensorflow.python.ops import io_ops [as 别名]
# 或者: from tensorflow.python.ops.io_ops import save_v2 [as 别名]
def FakeMnistData(tmpdir, train_size=60000, test_size=10000):
  """Writes fake MNIST-shaped data to a checkpoint for unit tests.

  Every image pixel and every label is the uint8 value 1.

  Args:
    tmpdir: Directory in which the checkpoint is written.
    train_size: Number of training examples.
    test_size: Number of test examples.

  Returns:
    The checkpoint prefix (<tmpdir>/ckpt) holding the tensors "x_train",
    "y_train", "x_test" and "y_test".
  """
  data_path = os.path.join(tmpdir, 'ckpt')
  # The Session is created after the fresh graph becomes the default, so all
  # ops below live in that graph.
  with tf.Graph().as_default(), tf.Session() as sess:
    # Images are (N, 28, 28, 1). Labels are rank-1 of length N; their shapes
    # are written as explicit 1-tuples rather than a parenthesised int.
    x_train = tf.ones((train_size, 28, 28, 1), dtype=tf.uint8)
    y_train = tf.ones((train_size,), dtype=tf.uint8)
    x_test = tf.ones((test_size, 28, 28, 1), dtype=tf.uint8)
    y_test = tf.ones((test_size,), dtype=tf.uint8)
    sess.run(
        io_ops.save_v2(data_path, ['x_train', 'y_train', 'x_test', 'y_test'],
                       [''] * 4, [x_train, y_train, x_test, y_test]))
  return data_path
示例3: _BuildSave
# 需要导入模块: from tensorflow.python.ops import io_ops [as 别名]
# 或者: from tensorflow.python.ops.io_ops import save_v2 [as 别名]
def _BuildSave(self):
  """Builds the op that writes every variable in self._vars to a checkpoint.

  Sets:
    self._save_global_step: tensor returned by py_utils.GetGlobalStep().
    self._save_prefix: string tensor self._logdir_ph + "/ckpt-" + the global
      step zero-padded to 8 characters.
    self._save_op: the save_v2 op writing all variables under that prefix.
  """
  global_step = py_utils.GetGlobalStep()
  self._save_global_step = global_step
  padded_step = tf.as_string(global_step, width=8, fill="0")
  self._save_prefix = tf.strings.join([self._logdir_ph, "/ckpt-", padded_step])

  variables = self._vars
  names = [_VarKey(var) for var in variables]
  values = [var.read_value() for var in variables]
  # Empty slice specs: every variable is saved as a full tensor.
  slice_specs = [""] * len(variables)
  self._save_op = io_ops.save_v2(
      prefix=self._save_prefix,
      tensor_names=names,
      tensors=values,
      shape_and_slices=slice_specs)
示例4: WriteNpArrays
# 需要导入模块: from tensorflow.python.ops import io_ops [as 别名]
# 或者: from tensorflow.python.ops.io_ops import save_v2 [as 别名]
def WriteNpArrays(file_prefix, nmap):
  """Saves every numpy array of a NestedMap into a TF checkpoint.

  Each flattened key of `nmap` becomes the tensor name in the checkpoint.

  Args:
    file_prefix: A TF checkpoint filename prefix.
    nmap: A NestedMap whose leaves are numpy arrays.
  """
  graph = tf.Graph()
  with graph.as_default():

    def _ToTensor(array):
      # Exposes the array through a py_func op that returns it at run time.
      array_dtype = tf.as_dtype(array.dtype)
      assert array_dtype != tf.string  # tf.string is not supported by py_func.
      return tf.py_func(lambda: array, [], array_dtype)

    keys = []
    tensors = []
    for key, leaf in nmap.FlattenItems():
      keys.append(key)
      assert isinstance(leaf, np.ndarray)
      tensors.append(_ToTensor(leaf))
    save_v2_op = io_ops.save_v2(
        prefix=file_prefix,
        tensor_names=keys,
        tensors=tensors,
        shape_and_slices=[""] * len(keys))

  # The session runs on the private graph built above.
  with tf.Session(graph=graph) as session:
    session.run(save_v2_op)
示例5: save_op
# 需要导入模块: from tensorflow.python.ops import io_ops [as 别名]
# 或者: from tensorflow.python.ops.io_ops import save_v2 [as 别名]
def save_op(self, filename_tensor, saveables):
  """Create an Op to save 'saveables'.

  This is intended to be overridden by subclasses that want to generate
  different Ops.

  Args:
    filename_tensor: String Tensor. Used as the checkpoint filename for V1
      and as the checkpoint prefix for V2.
    saveables: A list of BaseSaverBuilder.SaveableObject objects.

  Returns:
    An Operation that saves the variables.

  Raises:
    RuntimeError: (implementation detail) if "self._write_version" is an
      unexpected value.
  """
  # pylint: disable=protected-access
  tensor_names = []
  tensors = []
  tensor_slices = []
  for saveable in saveables:
    for spec in saveable.specs:
      tensor_names.append(spec.name)
      tensors.append(spec.tensor)
      tensor_slices.append(spec.slice_spec)
  if self._write_version == saver_pb2.SaverDef.V1:
    return io_ops._save(
        filename=filename_tensor,
        tensor_names=tensor_names,
        tensors=tensors,
        tensor_slices=tensor_slices)
  elif self._write_version == saver_pb2.SaverDef.V2:
    # "filename_tensor" is interpreted *NOT AS A FILENAME*, but as a prefix
    # of a V2 checkpoint: e.g. "/fs/train/ckpt-<step>/tmp/worker<i>-<step>".
    return io_ops.save_v2(filename_tensor, tensor_names, tensor_slices,
                          tensors)
  else:
    # _write_version is compared against protobuf enum values (ints);
    # concatenating it to a str would raise TypeError and hide this error.
    raise RuntimeError("Unexpected write_version: " +
                       str(self._write_version))
# pylint: disable=unused-argument
示例6: save_op
# 需要导入模块: from tensorflow.python.ops import io_ops [as 别名]
# 或者: from tensorflow.python.ops.io_ops import save_v2 [as 别名]
def save_op(self, filename_tensor, saveables):
  """Create an Op to save 'saveables', keeping one copy of replicated specs.

  Specs whose name starts with 'replicated_' are saved only when the name
  starts with 'replicated_0' or contains 'avg'. Those kept have their first
  '/'-separated component stripped from the saved name. NOTE(review):
  presumably this drops the duplicate copies held by the other replicas --
  confirm. All other specs are saved under their own name.

  Args:
    filename_tensor: String Tensor. Used as the checkpoint filename for V1
      and as the checkpoint prefix for V2.
    saveables: A list of BaseSaverBuilder.SaveableObject objects.

  Returns:
    An Operation that saves the selected tensors.

  Raises:
    RuntimeError: if "self._write_version" is an unexpected value.
  """
  # pylint: disable=protected-access
  tensor_names = []
  tensors = []
  tensor_slices = []
  for saveable in saveables:
    for spec in saveable.specs:
      name = spec.name
      if name.startswith('replicated_'):
        if not (name.startswith('replicated_0') or 'avg' in name):
          continue
        # Drop the leading scope component, e.g. 'replicated_0/a/b' -> 'a/b'.
        name = '/'.join(name.split('/')[1:])
      tensor_names.append(name)
      tensors.append(spec.tensor)
      tensor_slices.append(spec.slice_spec)
  if self._write_version == saver_pb2.SaverDef.V1:
    return io_ops._save(
        filename=filename_tensor,
        tensor_names=tensor_names,
        tensors=tensors,
        tensor_slices=tensor_slices)
  elif self._write_version == saver_pb2.SaverDef.V2:
    # "filename_tensor" is interpreted *NOT AS A FILENAME*, but as a prefix
    # of a V2 checkpoint: e.g. "/fs/train/ckpt-<step>/tmp/worker<i>-<step>".
    return io_ops.save_v2(filename_tensor, tensor_names, tensor_slices,
                          tensors)
  else:
    # _write_version is compared against protobuf enum values (ints);
    # concatenating it to a str would raise TypeError and hide this error.
    raise RuntimeError("Unexpected write_version: " +
                       str(self._write_version))
示例7: save_op
# 需要导入模块: from tensorflow.python.ops import io_ops [as 别名]
# 或者: from tensorflow.python.ops.io_ops import save_v2 [as 别名]
def save_op(self, filename_tensor, saveables):
  """Create an Op to save 'saveables'.

  This is intended to be overridden by subclasses that want to generate
  different Ops.

  Args:
    filename_tensor: String Tensor. Used as the checkpoint filename for V1
      and as the checkpoint prefix for V2.
    saveables: A list of BaseSaverBuilder.SaveableObject objects.

  Returns:
    An Operation that saves the variables.

  Raises:
    RuntimeError: (implementation detail) if "self._write_version" is an
      unexpected value.
  """
  # pylint: disable=protected-access
  tensor_names = []
  tensors = []
  tensor_slices = []
  for saveable in saveables:
    for spec in saveable.specs:
      tensor_names.append(spec.name)
      tensors.append(spec.tensor)
      tensor_slices.append(spec.slice_spec)
  if self._write_version == saver_pb2.SaverDef.V1:
    return io_ops._save(
        filename=filename_tensor,
        tensor_names=tensor_names,
        tensors=tensors,
        tensor_slices=tensor_slices)
  elif self._write_version == saver_pb2.SaverDef.V2:
    # "filename_tensor" is interpreted *NOT AS A FILENAME*, but as a prefix
    # of a V2 checkpoint: e.g. "/fs/train/ckpt-<step>/tmp/worker<i>-<step>".
    return io_ops.save_v2(filename_tensor, tensor_names, tensor_slices,
                          tensors)
  else:
    # _write_version is compared against protobuf enum values (ints);
    # concatenating it to a str would raise TypeError and hide this error.
    raise RuntimeError("Unexpected write_version: " +
                       str(self._write_version))
# pylint: disable=unused-argument