当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python tf.train.Checkpoint用法及代码示例


管理将可跟踪值保存/恢复到磁盘。

用法

tf.train.Checkpoint(
    root=None, **kwargs
)

参数

  • root 检查点的根对象。
  • **kwargs 关键字参数设置为此对象的属性,并与检查点一起保存。值必须是可跟踪的对象。

抛出

  • ValueError 如果rootkwargs 中的对象不可追踪。如果 root 对象跟踪的对象与 kwargs 中属性中列出的对象不同(例如,root.child = Atf.train.Checkpoint(root, child=B) 不兼容),也会引发 ValueError

属性

  • save_counter 调用 save() 时增加。用于对检查点进行编号。

TensorFlow 对象可能包含可跟踪状态,例如 tf.Variable s、tf.keras.optimizers.Optimizer implementations、tf.data.Dataset iterators、tf.keras.Layer implementations 或 tf.keras.Model implementations。这些被称为可跟踪对象。

可以构造 Checkpoint 对象以将单个或一组可跟踪对象保存到检查点文件。它维护一个save_counter 用于对检查点进行编号。

例子:

model = tf.keras.Model(...)
checkpoint = tf.train.Checkpoint(model)

# Save a checkpoint to /tmp/training_checkpoints-{save_counter}. Every time
# checkpoint.save is called, the save counter is increased.
save_path = checkpoint.save('/tmp/training_checkpoints')

# Restore the checkpointed values to the `model` object.
checkpoint.restore(save_path)

示例 2:

import tensorflow as tf
import os

checkpoint_directory = "/tmp/training_checkpoints"
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")

# Create a Checkpoint that will manage two objects with trackable state,
# one we name "optimizer" and the other we name "model".
checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
status = checkpoint.restore(tf.train.latest_checkpoint(checkpoint_directory))
for _ in range(num_training_steps):
  optimizer.minimize( ... )  # Variables will be restored on creation.
status.assert_consumed()  # Optional sanity checks.
checkpoint.save(file_prefix=checkpoint_prefix)

Checkpoint.save()Checkpoint.restore() 写入和读取基于对象的检查点,而 TensorFlow 1.x 的 tf.compat.v1.train.Saver 写入和读取基于 variable.name 的检查点。基于对象的检查点保存带有命名边的 Python 对象(Layer s、Optimizer s、Variable s 等)之间的依赖关系图,该图用于在恢复检查点时匹配变量。它可以对 Python 程序中的更改更加健壮,并有助于支持变量的restore-on-create。

Checkpoint 对象依赖于作为关键字参数传递给其构造函数的对象,并且每个依赖项的名称都与创建它的关键字参数的名称相同。 LayerOptimizer 等 TensorFlow 类将自动添加对它们自己变量的依赖项(例如 "kernel" 和 "bias" 用于 tf.keras.layers.Dense )。从tf.keras.Model 继承使得在用户定义的类中管理依赖关系变得容易,因为Model 挂钩到属性分配。例如:

class Regress(tf.keras.Model):

  def __init__(self):
    super(Regress, self).__init__()
    self.input_transform = tf.keras.layers.Dense(10)
    # ...

  def call(self, inputs):
    x = self.input_transform(inputs)
    # ...

这个Model 在它的Dense 层上有一个名为"input_transform" 的依赖关系,它又依赖于它的变量。因此,使用tf.train.Checkpoint 保存Regress 的实例也将保存Dense 层创建的所有变量。

当变量分配给多个工作人员时,每个工作人员都会编写自己的检查点部分。然后将这些部分合并/重新索引以充当单个检查点。这避免了将所有变量复制到一个工作人员,但确实要求所有工作人员都看到一个公共文件系统。

此函数与 Keras Model save_weights 函数略有不同。 tf.keras.Model.save_weights 使用在 filepath 中指定的名称创建检查点文件,而 tf.train.Checkpoint 使用 filepath 作为检查点文件名的前缀对检查点进行编号。除此之外,model.save_weights()tf.train.Checkpoint(model).save() 是等价的。

有关详细信息,请参阅训练检查点指南。

相关用法


注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.train.Checkpoint。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。