当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python tf.train.Checkpoint.restore用法及代码示例


用法

restore(
    save_path, options=None
)

参数

返回

  • 加载状态对象,可用于对检查点恢复的状态进行断言。

    返回的状态对象有以下方法:

    • assert_consumed() :如果任何变量不匹配,则引发异常:检查点值没有匹配的 Python 对象或依赖关系图中的 Python 对象在检查点中没有值。此方法返回状态对象,因此可以与其他断言链接。

    • assert_existing_objects_matched() :如果依赖图中的任何现有 Python 对象不匹配,则引发异常。与 assert_consumed 不同,如果检查点中的值没有对应的 Python 对象,则此断言将通过。例如,尚未构建的 tf.keras.Layer 对象,因此尚未创建任何变量,将通过此断言但失败 assert_consumed 。在将较大检查点的一部分加载到新的 Python 程序中时很有用,例如保存了带有 tf.compat.v1.train.Optimizer 的训练检查点,但仅加载了推理所需的状态。此方法返回状态对象,因此可以与其他断言链接。

    • assert_nontrivial_match() :断言除根对象之外的其他内容已匹配。这是一个非常弱的断言,但对于库代码中的健全性检查很有用,其中对象可能存在于检查点中,而这些对象可能尚未在 Python 中创建,并且某些 Python 对象可能没有检查点值。

    • expect_partial():关于不完整的检查点恢复的静默警告。当Checkpoint 对象被删除时(通常在程序关闭时),检查点文件或对象的未使用部分会打印警告。

抛出

  • NotFoundError 如果在 save_path 找不到检查点或 SavedModel。

恢复训练检查点。

恢复此 Checkpoint 及其依赖的任何对象。

此方法旨在用于加载由 save() 创建的检查点。对于由 write() 创建的检查点,请使用 read() 方法,该方法不需要 save() 添加的 save_counter 变量。

restore() 如果要恢复的变量已经创建,则立即赋值,或者推迟恢复直到创建变量。如果在检查点中有相应的对象,则在此调用之后添加的依赖项将被匹配(恢复请求将在任何可跟踪的对象中排队等待添加预期的依赖项)。

checkpoint = tf.train.Checkpoint( ... )
checkpoint.restore(path)

# You can additionally pass options to restore():
options = tf.CheckpointOptions(experimental_io_device="/job:localhost")
checkpoint.restore(path, options=options)

为确保加载完成并且不再发生延迟恢复,请使用 restore() 返回的状态对象的 assert_consumed() 方法:

checkpoint.restore(path, options=options).assert_consumed()

如果在检查点中未找到依赖图中的任何 Python 对象,或者任何检查点值没有匹配的 Python 对象,则断言将引发错误。

可以使用此方法加载来自 TensorFlow 1.x 的基于名称的 tf.compat.v1.train.Saver 检查点。名称用于匹配变量。尽快使用tf.train.Checkpoint.save 重新编码基于名称的检查点。

从 SavedModel 检查点加载

要从 SavedModel 加载值,只需将 SavedModel 目录传递给 checkpoint.restore:

model = tf.keras.Model(...)
tf.saved_model.save(model, path)  # or model.save(path, save_format='tf')

checkpoint = tf.train.Checkpoint(model)
checkpoint.restore(path).expect_partial()

这个例子在加载状态上调用expect_partial(),因为从 Keras 保存的 SavedModels 通常会在检查点中生成额外的键。否则,程序会在退出时打印很多关于未使用 key 的警告。

相关用法


注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.train.Checkpoint.restore。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。