用法
restore(
save_path, options=None
)
参数
-
save_path
检查点的路径,由save
或tf.train.latest_checkpoint
返回。如果检查点是由基于名称的tf.compat.v1.train.Saver
编写的,则名称用于匹配变量。此路径也可能是 SavedModel 目录。 -
options
可选的tf.train.CheckpointOptions
对象。
返回
-
加载状态对象,可用于对检查点恢复的状态进行断言。
返回的状态对象有以下方法:
assert_consumed()
:如果任何变量不匹配,则引发异常:检查点值没有匹配的 Python 对象或依赖关系图中的 Python 对象在检查点中没有值。此方法返回状态对象,因此可以与其他断言链接。assert_existing_objects_matched()
:如果依赖图中的任何现有 Python 对象不匹配,则引发异常。与assert_consumed
不同,如果检查点中的值没有对应的 Python 对象,则此断言将通过。例如,尚未构建的tf.keras.Layer
对象,因此尚未创建任何变量,将通过此断言但失败assert_consumed
。在将较大检查点的一部分加载到新的 Python 程序中时很有用,例如保存了带有tf.compat.v1.train.Optimizer
的训练检查点,但仅加载了推理所需的状态。此方法返回状态对象,因此可以与其他断言链接。assert_nontrivial_match()
:断言除根对象之外的其他内容已匹配。这是一个非常弱的断言,但对于库代码中的健全性检查很有用,其中对象可能存在于检查点中,而这些对象可能尚未在 Python 中创建,并且某些 Python 对象可能没有检查点值。expect_partial()
:关于不完整的检查点恢复的静默警告。当Checkpoint
对象被删除时(通常在程序关闭时),检查点文件或对象的未使用部分会打印警告。
抛出
-
NotFoundError
如果在save_path
找不到检查点或 SavedModel。
恢复训练检查点。
恢复此 Checkpoint
及其依赖的任何对象。
此方法旨在用于加载由 save()
创建的检查点。对于由 write()
创建的检查点,请使用 read()
方法,该方法不需要 save()
添加的 save_counter
变量。
restore()
如果要恢复的变量已经创建,则立即赋值,或者推迟恢复直到创建变量。如果在检查点中有相应的对象,则在此调用之后添加的依赖项将被匹配(恢复请求将在任何可跟踪的对象中排队等待添加预期的依赖项)。
checkpoint = tf.train.Checkpoint( ... )
checkpoint.restore(path)
# You can additionally pass options to restore():
options = tf.CheckpointOptions(experimental_io_device="/job:localhost")
checkpoint.restore(path, options=options)
为确保加载完成并且不再发生延迟恢复,请使用 restore()
返回的状态对象的 assert_consumed()
方法:
checkpoint.restore(path, options=options).assert_consumed()
如果在检查点中未找到依赖图中的任何 Python 对象,或者任何检查点值没有匹配的 Python 对象,则断言将引发错误。
可以使用此方法加载来自 TensorFlow 1.x 的基于名称的 tf.compat.v1.train.Saver
检查点。名称用于匹配变量。尽快使用tf.train.Checkpoint.save
重新编码基于名称的检查点。
从 SavedModel 检查点加载
要从 SavedModel 加载值,只需将 SavedModel 目录传递给 checkpoint.restore:
model = tf.keras.Model(...)
tf.saved_model.save(model, path) # or model.save(path, save_format='tf')
checkpoint = tf.train.Checkpoint(model)
checkpoint.restore(path).expect_partial()
这个例子在加载状态上调用expect_partial()
,因为从 Keras 保存的 SavedModels 通常会在检查点中生成额外的键。否则,程序会在退出时打印很多关于未使用 key 的警告。
相关用法
- Python tf.train.Checkpoint.read用法及代码示例
- Python tf.train.Checkpoint.save用法及代码示例
- Python tf.train.Checkpoint.write用法及代码示例
- Python tf.train.CheckpointOptions用法及代码示例
- Python tf.train.CheckpointManager用法及代码示例
- Python tf.train.Checkpoint用法及代码示例
- Python tf.train.Coordinator.stop_on_exception用法及代码示例
- Python tf.train.ClusterSpec用法及代码示例
- Python tf.train.Coordinator用法及代码示例
- Python tf.train.ExponentialMovingAverage用法及代码示例
- Python tf.train.list_variables用法及代码示例
- Python tf.transpose用法及代码示例
注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.train.Checkpoint.restore。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。