对可跟踪对象进行分组,保存和恢复它们。
用法
tf.compat.v1.train.Checkpoint(
**kwargs
)
参数
-
**kwargs
关键字参数设置为此对象的属性,并与检查点一起保存。值必须是可跟踪的对象。
抛出
-
ValueError
如果kwargs
中的对象不可追踪。
属性
-
save_counter
调用save()
时增加。用于对检查点进行编号。
Checkpoint
的构造函数接受关键字参数,其值是包含可跟踪状态的类型,例如 tf.compat.v1.train.Optimizer
implementations、tf.Variable
、tf.keras.Layer
implementations 或 tf.keras.Model
implementations。它使用检查点保存这些值,并维护save_counter
用于编号检查点。
图构建时的示例用法:
import tensorflow as tf
import os
checkpoint_directory = "/tmp/training_checkpoints"
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
status = checkpoint.restore(tf.train.latest_checkpoint(checkpoint_directory))
train_op = optimizer.minimize( ... )
status.assert_consumed() # Optional sanity checks.
with tf.compat.v1.Session() as session:
# Use the Session to restore variables, or initialize them if
# tf.train.latest_checkpoint returned None.
status.initialize_or_restore(session)
for _ in range(num_training_steps):
session.run(train_op)
checkpoint.save(file_prefix=checkpoint_prefix)
启用即刻执行的示例用法:
import tensorflow as tf
import os
tf.compat.v1.enable_eager_execution()
checkpoint_directory = "/tmp/training_checkpoints"
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
status = checkpoint.restore(tf.train.latest_checkpoint(checkpoint_directory))
for _ in range(num_training_steps):
optimizer.minimize( ... ) # Variables will be restored on creation.
status.assert_consumed() # Optional sanity checks.
checkpoint.save(file_prefix=checkpoint_prefix)
Checkpoint.save
和 Checkpoint.restore
写入和读取基于对象的检查点,而 tf.compat.v1.train.Saver
写入和读取基于 variable.name
的检查点。基于对象的检查点保存带有命名边的 Python 对象(Layer
s、Optimizer
s、Variable
s 等)之间的依赖关系图,该图用于在恢复检查点时匹配变量。它可以对 Python 程序中的更改更加健壮,并有助于在即刻执行时支持变量的restore-on-create。对于新代码,首选 tf.train.Checkpoint
而不是 tf.compat.v1.train.Saver
。
Checkpoint
对象依赖于作为关键字参数传递给其构造函数的对象,并且每个依赖项的名称都与创建它的关键字参数的名称相同。 Layer
和 Optimizer
等 TensorFlow 类将自动添加对其变量的依赖项(例如 "kernel" 和 "bias" 用于 tf.keras.layers.Dense
)。从tf.keras.Model
继承使得在用户定义的类中管理依赖关系变得容易,因为Model
挂钩到属性分配。例如:
class Regress(tf.keras.Model):
def __init__(self):
super(Regress, self).__init__()
self.input_transform = tf.keras.layers.Dense(10)
# ...
def call(self, inputs):
x = self.input_transform(inputs)
# ...
这个Model
在它的Dense
层上有一个名为"input_transform" 的依赖关系,它又依赖于它的变量。因此,使用tf.train.Checkpoint
保存Regress
的实例也将保存Dense
层创建的所有变量。
当变量分配给多个工作人员时,每个工作人员都会编写自己的检查点部分。然后将这些部分合并/重新索引以充当单个检查点。这避免了将所有变量复制到一个工作人员,但确实要求所有工作人员都看到一个公共文件系统。
虽然 tf.keras.Model.save_weights
和 tf.train.Checkpoint.save
以相同的格式保存,但请注意生成的检查点的根是保存方法附加到的对象。这意味着使用 save_weights
保存 tf.keras.Model
并加载到带有 Model
的 tf.train.Checkpoint
(反之亦然)将不匹配 Model
的变量。有关详细信息,请参阅训练检查点指南。首选 tf.train.Checkpoint 而不是 tf.keras.Model.save_weights 用于训练检查点。
相关用法
- Python tf.compat.v1.train.Checkpoint.restore用法及代码示例
- Python tf.compat.v1.train.FtrlOptimizer.compute_gradients用法及代码示例
- Python tf.compat.v1.train.get_or_create_global_step用法及代码示例
- Python tf.compat.v1.train.cosine_decay_restarts用法及代码示例
- Python tf.compat.v1.train.Optimizer用法及代码示例
- Python tf.compat.v1.train.AdagradOptimizer.compute_gradients用法及代码示例
- Python tf.compat.v1.train.init_from_checkpoint用法及代码示例
- Python tf.compat.v1.train.Supervisor.managed_session用法及代码示例
- Python tf.compat.v1.train.global_step用法及代码示例
- Python tf.compat.v1.train.MonitoredSession.run_step_fn用法及代码示例
- Python tf.compat.v1.train.RMSPropOptimizer.compute_gradients用法及代码示例
- Python tf.compat.v1.train.exponential_decay用法及代码示例
- Python tf.compat.v1.train.natural_exp_decay用法及代码示例
- Python tf.compat.v1.train.MomentumOptimizer用法及代码示例
- Python tf.compat.v1.train.RMSPropOptimizer用法及代码示例
- Python tf.compat.v1.train.get_global_step用法及代码示例
- Python tf.compat.v1.train.GradientDescentOptimizer.compute_gradients用法及代码示例
- Python tf.compat.v1.train.linear_cosine_decay用法及代码示例
- Python tf.compat.v1.train.Supervisor用法及代码示例
- Python tf.compat.v1.train.AdagradDAOptimizer.compute_gradients用法及代码示例
注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.compat.v1.train.Checkpoint。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。