管理将可跟踪值保存/恢复到磁盘。
用法
tf.train.Checkpoint(
root=None, **kwargs
)
参数
-
root
检查点的根对象。 -
**kwargs
关键字参数设置为此对象的属性,并与检查点一起保存。值必须是可跟踪的对象。
抛出
-
ValueError
如果root
或kwargs
中的对象不可追踪。如果root
对象跟踪的对象与 kwargs 中属性中列出的对象不同(例如,root.child = A
和tf.train.Checkpoint(root, child=B)
不兼容),也会引发ValueError
。
属性
-
save_counter
调用save()
时增加。用于对检查点进行编号。
TensorFlow 对象可能包含可跟踪状态,例如 tf.Variable
s、tf.keras.optimizers.Optimizer
implementations、tf.data.Dataset
iterators、tf.keras.Layer
implementations 或 tf.keras.Model
implementations。这些被称为可跟踪对象。
可以构造 Checkpoint
对象以将单个或一组可跟踪对象保存到检查点文件。它维护一个save_counter
用于对检查点进行编号。
例子:
model = tf.keras.Model(...)
checkpoint = tf.train.Checkpoint(model)
# Save a checkpoint to /tmp/training_checkpoints-{save_counter}. Every time
# checkpoint.save is called, the save counter is increased.
save_path = checkpoint.save('/tmp/training_checkpoints')
# Restore the checkpointed values to the `model` object.
checkpoint.restore(save_path)
示例 2:
import tensorflow as tf
import os
checkpoint_directory = "/tmp/training_checkpoints"
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
# Create a Checkpoint that will manage two objects with trackable state,
# one we name "optimizer" and the other we name "model".
checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
status = checkpoint.restore(tf.train.latest_checkpoint(checkpoint_directory))
for _ in range(num_training_steps):
optimizer.minimize( ... ) # Variables will be restored on creation.
status.assert_consumed() # Optional sanity checks.
checkpoint.save(file_prefix=checkpoint_prefix)
Checkpoint.save()
和 Checkpoint.restore()
写入和读取基于对象的检查点,而 TensorFlow 1.x 的 tf.compat.v1.train.Saver
写入和读取基于 variable.name
的检查点。基于对象的检查点保存带有命名边的 Python 对象(Layer
s、Optimizer
s、Variable
s 等)之间的依赖关系图,该图用于在恢复检查点时匹配变量。它可以对 Python 程序中的更改更加健壮,并有助于支持变量的restore-on-create。
Checkpoint
对象依赖于作为关键字参数传递给其构造函数的对象,并且每个依赖项的名称都与创建它的关键字参数的名称相同。 Layer
和 Optimizer
等 TensorFlow 类将自动添加对它们自己变量的依赖项(例如 "kernel" 和 "bias" 用于 tf.keras.layers.Dense
)。从tf.keras.Model
继承使得在用户定义的类中管理依赖关系变得容易,因为Model
挂钩到属性分配。例如:
class Regress(tf.keras.Model):
def __init__(self):
super(Regress, self).__init__()
self.input_transform = tf.keras.layers.Dense(10)
# ...
def call(self, inputs):
x = self.input_transform(inputs)
# ...
这个Model
在它的Dense
层上有一个名为"input_transform" 的依赖关系,它又依赖于它的变量。因此,使用tf.train.Checkpoint
保存Regress
的实例也将保存Dense
层创建的所有变量。
当变量分配给多个工作人员时,每个工作人员都会编写自己的检查点部分。然后将这些部分合并/重新索引以充当单个检查点。这避免了将所有变量复制到一个工作人员,但确实要求所有工作人员都看到一个公共文件系统。
此函数与 Keras Model save_weights
函数略有不同。 tf.keras.Model.save_weights
使用在 filepath
中指定的名称创建检查点文件,而 tf.train.Checkpoint
使用 filepath
作为检查点文件名的前缀对检查点进行编号。除此之外,model.save_weights()
和tf.train.Checkpoint(model).save()
是等价的。
有关详细信息,请参阅训练检查点指南。
相关用法
- Python tf.train.Checkpoint.restore用法及代码示例
- Python tf.train.Checkpoint.read用法及代码示例
- Python tf.train.CheckpointOptions用法及代码示例
- Python tf.train.Checkpoint.save用法及代码示例
- Python tf.train.Checkpoint.write用法及代码示例
- Python tf.train.CheckpointManager用法及代码示例
- Python tf.train.Coordinator.stop_on_exception用法及代码示例
- Python tf.train.ClusterSpec用法及代码示例
- Python tf.train.Coordinator用法及代码示例
- Python tf.train.ExponentialMovingAverage用法及代码示例
- Python tf.train.list_variables用法及代码示例
- Python tf.transpose用法及代码示例
注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.train.Checkpoint。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。