回调以备份和恢复训练状态。
继承自:Callback
用法
tf.keras.callbacks.BackupAndRestore(
backup_dir
)
参数
-
backup_dir
字符串,存储检查点的路径。例如backup_dir = os.path.join(working_dir, 'backup') 这是系统存储临时文件以从意外终止的作业中恢复模型的目录。该目录不能在其他地方重复使用来存储其他文件,例如通过另一个训练的 BackupAndRestore 回调,或通过同一训练的另一个回调 (ModelCheckpoint)。
BackupAndRestore
回调旨在通过将训练状态备份到临时检查点文件中(在 tf.train.CheckpointManager
的帮助下),从 Model.fit
执行过程中发生的中断中恢复训练,最后每个时代的。每个备份都会覆盖先前写入的检查点文件,因此在任何给定时间,最多有一个这样的检查点文件用于备份/恢复目的。
如果训练在完成之前重新开始,则训练状态(包括 Model
权重和 epoch 编号)将在新的 Model.fit
运行开始时恢复到最近保存的状态。 Model.fit
运行完成后,临时检查点文件将被删除。
请注意,用户有责任在中断后恢复作业。这个回调对于容错目的的备份和恢复机制很重要,并且从前一个检查点恢复的模型预计与用于备份的模型相同。如果用户更改传递给 compile 或 fit 的参数,则为容错保存的检查点可能会变得无效。
注意:
- 此回调与禁用的即刻执行不兼容。
- 在每个 epoch 结束时保存一个检查点。恢复后,
Model.fit
在训练重新开始的未完成时期重做任何部分工作(因此中断前完成的工作不会影响最终模型状态)。 - 这适用于单工作人员和multi-worker 模式。当
Model.fit
与tf.distribute
一起使用时,它支持tf.distribute.MirroredStrategy
、tf.distribute.MultiWorkerMirroredStrategy
、tf.distribute.TPUStrategy
和tf.distribute.experimental.ParameterServerStrategy
。
例子:
class InterruptingCallback(tf.keras.callbacks.Callback):
def on_epoch_begin(self, epoch, logs=None):
if epoch == 4:
raise RuntimeError('Interrupting!')
callback = tf.keras.callbacks.experimental.BackupAndRestore(
backup_dir="/tmp/backup")
model = tf.keras.models.Sequential([tf.keras.layers.Dense(10)])
model.compile(tf.keras.optimizers.SGD(), loss='mse')
try:
model.fit(np.arange(100).reshape(5, 20), np.zeros(5), epochs=10,
batch_size=1, callbacks=[callback, InterruptingCallback()],
verbose=0)
except:
pass
history = model.fit(np.arange(100).reshape(5, 20), np.zeros(5), epochs=10,
batch_size=1, callbacks=[callback], verbose=0)
# Only 6 more epochs are run, since first trainning got interrupted at
# zero-indexed epoch 4, second training will continue from 4 to 9.
len(history.history['loss'])
6
相关用法
- Python tf.keras.callbacks.ReduceLROnPlateau用法及代码示例
- Python tf.keras.callbacks.EarlyStopping用法及代码示例
- Python tf.keras.callbacks.CSVLogger用法及代码示例
- Python tf.keras.callbacks.TensorBoard用法及代码示例
- Python tf.keras.callbacks.Callback用法及代码示例
- Python tf.keras.callbacks.ModelCheckpoint用法及代码示例
- Python tf.keras.callbacks.LambdaCallback用法及代码示例
- Python tf.keras.callbacks.LearningRateScheduler用法及代码示例
- Python tf.keras.callbacks.History用法及代码示例
- Python tf.keras.applications.inception_resnet_v2.preprocess_input用法及代码示例
- Python tf.keras.metrics.Mean.merge_state用法及代码示例
- Python tf.keras.layers.InputLayer用法及代码示例
- Python tf.keras.layers.serialize用法及代码示例
- Python tf.keras.metrics.Hinge用法及代码示例
- Python tf.keras.experimental.WideDeepModel.compute_loss用法及代码示例
- Python tf.keras.metrics.SparseCategoricalAccuracy.merge_state用法及代码示例
- Python tf.keras.metrics.RootMeanSquaredError用法及代码示例
- Python tf.keras.applications.resnet50.preprocess_input用法及代码示例
注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.keras.callbacks.BackupAndRestore。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。