当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python tf.compat.v1.train.Checkpoint用法及代码示例


对可跟踪对象进行分组,保存和恢复它们。

用法

tf.compat.v1.train.Checkpoint(
    **kwargs
)

参数

  • **kwargs 关键字参数设置为此对象的属性,并与检查点一起保存。值必须是可跟踪的对象。

抛出

  • ValueError 如果kwargs 中的对象不可追踪。

属性

  • save_counter 调用 save() 时增加。用于对检查点进行编号。

Checkpoint 的构造函数接受关键字参数,其值是包含可跟踪状态的类型,例如 tf.compat.v1.train.Optimizer implementations、tf.Variabletf.keras.Layer implementations 或 tf.keras.Model implementations。它使用检查点保存这些值,并维护save_counter 用于编号检查点。

图构建时的示例用法:

import tensorflow as tf
import os

checkpoint_directory = "/tmp/training_checkpoints"
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")

checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
status = checkpoint.restore(tf.train.latest_checkpoint(checkpoint_directory))
train_op = optimizer.minimize( ... )
status.assert_consumed()  # Optional sanity checks.
with tf.compat.v1.Session() as session:
  # Use the Session to restore variables, or initialize them if
  # tf.train.latest_checkpoint returned None.
  status.initialize_or_restore(session)
  for _ in range(num_training_steps):
    session.run(train_op)
  checkpoint.save(file_prefix=checkpoint_prefix)

启用即刻执行的示例用法:

import tensorflow as tf
import os

tf.compat.v1.enable_eager_execution()

checkpoint_directory = "/tmp/training_checkpoints"
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")

checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
status = checkpoint.restore(tf.train.latest_checkpoint(checkpoint_directory))
for _ in range(num_training_steps):
  optimizer.minimize( ... )  # Variables will be restored on creation.
status.assert_consumed()  # Optional sanity checks.
checkpoint.save(file_prefix=checkpoint_prefix)

Checkpoint.saveCheckpoint.restore 写入和读取基于对象的检查点,而 tf.compat.v1.train.Saver 写入和读取基于 variable.name 的检查点。基于对象的检查点保存带有命名边的 Python 对象(Layer s、Optimizer s、Variable s 等)之间的依赖关系图,该图用于在恢复检查点时匹配变量。它可以对 Python 程序中的更改更加健壮,并有助于在即刻执行时支持变量的restore-on-create。对于新代码,首选 tf.train.Checkpoint 而不是 tf.compat.v1.train.Saver

Checkpoint 对象依赖于作为关键字参数传递给其构造函数的对象,并且每个依赖项的名称都与创建它的关键字参数的名称相同。 LayerOptimizer 等 TensorFlow 类将自动添加对其变量的依赖项(例如 "kernel" 和 "bias" 用于 tf.keras.layers.Dense )。从tf.keras.Model 继承使得在用户定义的类中管理依赖关系变得容易,因为Model 挂钩到属性分配。例如:

class Regress(tf.keras.Model):

  def __init__(self):
    super(Regress, self).__init__()
    self.input_transform = tf.keras.layers.Dense(10)
    # ...

  def call(self, inputs):
    x = self.input_transform(inputs)
    # ...

这个Model 在它的Dense 层上有一个名为"input_transform" 的依赖关系,它又依赖于它的变量。因此,使用tf.train.Checkpoint 保存Regress 的实例也将保存Dense 层创建的所有变量。

当变量分配给多个工作人员时,每个工作人员都会编写自己的检查点部分。然后将这些部分合并/重新索引以充当单个检查点。这避免了将所有变量复制到一个工作人员,但确实要求所有工作人员都看到一个公共文件系统。

虽然 tf.keras.Model.save_weightstf.train.Checkpoint.save 以相同的格式保存,但请注意生成的检查点的根是保存方法附加到的对象。这意味着使用 save_weights 保存 tf.keras.Model 并加载到带有 Modeltf.train.Checkpoint (反之亦然)将不匹配 Model 的变量。有关详细信息,请参阅训练检查点指南。首选 tf.train.Checkpoint 而不是 tf.keras.Model.save_weights 用于训练检查点。

相关用法


注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.compat.v1.train.Checkpoint。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。