当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python tf.compat.v1.train.Saver用法及代码示例


保存和恢复变量。

用法

tf.compat.v1.train.Saver(
    var_list=None, reshape=False, sharded=False, max_to_keep=5,
    keep_checkpoint_every_n_hours=10000.0, name=None, restore_sequentially=False,
    saver_def=None, builder=None, defer_build=False, allow_empty=False,
    write_version=tf.train.SaverDef.V2, pad_step_number=False,
    save_relative_paths=False, filename=None
)

参数

  • var_list Variable /SaveableObject 的列表,或映射名称到 SaveableObject 的字典。如果 None ,默认为所有可保存对象的列表。
  • reshape 如果 True ,允许从变量具有不同形状的检查点恢复参数。
  • sharded 如果 True ,将检查点分片,每个设备一个。
  • max_to_keep 要保留的最近检查点的最大数量。默认为 5。
  • keep_checkpoint_every_n_hours 多久设置一次检查点。默认为 10,000 小时。
  • name String 。添加操作时用作前缀的可选名称。
  • restore_sequentially A Bool ,如果为真,则会导致在每个设备内按顺序恢复不同的变量。这可以在恢复非常大的模型时降低内存使用量。
  • saver_def 使用可选的SaverDef proto 而不是运行构建器。这仅对想要为先前构建的具有 SaverGraph 重新创建 Saver 对象的特殊代码有用。 saver_def 原型应该是为该 Graph 创建的 Saveras_saver_def() 调用返回的原型。
  • builder 如果未提供 saver_def,则使用可选的 SaverBuilder。默认为 BulkSaverBuilder()
  • defer_build 如果 True ,请推迟将保存和恢复操作添加到 build() 调用。在这种情况下,应在完成图形或使用保护程序之前调用 build()
  • allow_empty 如果 False(默认)在图中没有变量的情况下引发错误。否则,无论如何构建保护程序并将其设为no-op。
  • write_version 控制保存检查点时使用的格式。它还会影响某些文件路径匹配逻辑。推荐选择 V2 格式:在所需的内存和还原期间产生的延迟方面,它比 V1 优化得多。无论此标志如何,Saver 都能够从 V2 和 V1 检查点恢复。
  • pad_step_number 如果为 True,则将检查点文件路径中的全局步骤编号填充到某个固定宽度(默认为 8)。这是默认关闭的。
  • save_relative_paths 如果 True ,将写入检查点状态文件的相对路径。如果用户想要复制检查点目录并从复制的目录重新加载,则需要这样做。
  • filename 如果在图形构建时已知,则用于变量加载/保存的文件名。

抛出

  • TypeError 如果var_list 无效。
  • ValueError 如果 var_list 中的任何键或值不是唯一的。
  • RuntimeError 如果启用了即刻执行并且var_list 未指定要保存的变量列表。

属性

  • last_checkpoints not-yet-deleted 检查点文件名列表。

    您可以将任何返回值传递给 restore()

迁移到 TF2

警告:这个 API 是为 TensorFlow v1 设计的。继续阅读有关如何从该 API 迁移到本机 TensorFlow v2 等效项的详细信息。见TensorFlow v1 到 TensorFlow v2 迁移指南有关如何迁移其余代码的说明。

tf.compat.v1.train.Saver 不支持在 TF2 中保存和恢复检查点。请切换到 tf.train.Checkpointtf.keras.Model.save_weights ,它们执行更强大的基于对象的保存。

如何重写检查点

请立即使用基于对象的检查点 API 重写您的检查点。

您可以使用 tf.train.Checkpoint.restoretf.keras.Model.load_weights 加载由 tf.compat.v1.train.Saver 编写的基于名称的检查点。但是,您可能必须更改模型中的变量名称以匹配基于名称的检查点中的变量名称,可以使用 tf.train.list_variables(path) 查看。

另一种选择是创建一个assignment_map,将基于名称的检查点中的变量名称映射到模型中的变量,例如:

{
    'sequential/dense/bias':model.variables[0],
    'sequential/dense/kernel':model.variables[1]
}

并使用tf.compat.v1.train.init_from_checkpoint恢复基于名称的检查点。

恢复后,使用 tf.train.Checkpoint.savetf.keras.Model.save_weights 重新编码您的检查点。

有关更多详细信息,请参阅迁移指南的检查点兼容性部分。

TF2 中的检查点管理

使用tf.train.CheckpointManager 管理 TF2 中的检查点。 tf.train.CheckpointManager 提供等效的 keep_checkpoint_every_n_hoursmax_to_keep 参数。

要恢复最新的检查点,

checkpoint = tf.train.Checkpoint(model)
manager = tf.train.CheckpointManager(checkpoint)
status = checkpoint.restore(manager.latest_checkpoint)

tf.train.CheckpointManager 还编写了一个 CheckpointState 原型,其中包含每个检查点创建时的时间戳。

在 TF2 中写入 MetaGraphDef

要替换 tf.compat.v1.train.Saver.save(write_meta_graph=True) ,请使用 tf.saved_model.save 编写 MetaGraphDef (包含在 saved_model.pb 中)。

有关变量、保存和恢复的概述,请参阅变量。

Saver 类添加了操作来保存和恢复检查点的变量。它还提供了运行这些操作的便捷方法。

检查点是专有格式的二进制文件,将变量名称映射到张量值。检查检查点内容的最佳方法是使用 Saver 加载它。

Savers 可以使用提供的计数器自动编号检查点文件名。这使您可以在训练模型时在不同的步骤中保留多个检查点。例如,您可以使用训练步骤编号对检查点文件名进行编号。为了避免填满磁盘,保存程序会自动管理检查点文件。例如,他们可以只保留 N 个最近的文件,或者每 N 小时的训练一个检查点。

通过将值传递给可选的 global_step 参数给 save() 来为检查点文件名编号:

saver.save(sess, 'my-model', global_step=0) ==> filename:'my-model-0'
...
saver.save(sess, 'my-model', global_step=1000) ==> filename:'my-model-1000'

此外,Saver() 构造函数的可选参数允许您控制磁盘上检查点文件的扩散:

  • max_to_keep 表示要保留的最近检查点文件的最大数量。创建新文件时,会删除旧文件。如果 None 或 0,则不会从文件系统中删除检查点,但只有最后一个保留在 checkpoint 文件中。默认为 5(即保留 5 个最近的检查点文件。)

  • keep_checkpoint_every_n_hours :除了保留最新的max_to_keep 检查点文件外,您可能希望每训练 N 小时保留一个检查点文件。如果您想稍后分析模型在长时间训练期间的进展情况,这将很有用。例如,通过 keep_checkpoint_every_n_hours=2 可确保您为每 2 小时的训练保留一个检查点文件。默认值 10,000 小时有效地禁用了该函数。

请注意,您仍然必须调用save() 方法来保存模型。将这些参数传递给构造函数不会自动为您保存变量。

定期保存的训练计划如下所示:

...
# Create a saver.
saver = tf.compat.v1.train.Saver(...variables...)
# Launch the graph and train, saving the model every 1,000 steps.
sess = tf.compat.v1.Session()
for step in range(1000000):
    sess.run(..training_op..)
    if step % 1000 == 0:
        # Append the step number to the checkpoint name:
        saver.save(sess, 'my-model', global_step=step)

除了检查点文件之外,保存程序还在磁盘上保留一个协议缓冲区,其中包含最近检查点的列表。这用于管理编号的检查点文件和 latest_checkpoint() ,这使得发现最近检查点的路径变得容易。该协议缓冲区存储在检查点文件旁边名为'checkpoint' 的文件中。

如果创建多个保存程序,则可以在调用 save() 时为协议缓冲区文件指定不同的文件名。

相关用法


注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.compat.v1.train.Saver。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。