回調以備份和恢複訓練狀態。
繼承自:Callback
用法
tf.keras.callbacks.BackupAndRestore(
backup_dir
)
參數
-
backup_dir
字符串,存儲檢查點的路徑。例如backup_dir = os.path.join(working_dir, 'backup') 這是係統存儲臨時文件以從意外終止的作業中恢複模型的目錄。該目錄不能在其他地方重複使用來存儲其他文件,例如通過另一個訓練的 BackupAndRestore 回調,或通過同一訓練的另一個回調 (ModelCheckpoint)。
BackupAndRestore
回調旨在通過將訓練狀態備份到臨時檢查點文件中(在 tf.train.CheckpointManager
的幫助下),從 Model.fit
執行過程中發生的中斷中恢複訓練,最後每個時代的。每個備份都會覆蓋先前寫入的檢查點文件,因此在任何給定時間,最多有一個這樣的檢查點文件用於備份/恢複目的。
如果訓練在完成之前重新開始,則訓練狀態(包括 Model
權重和 epoch 編號)將在新的 Model.fit
運行開始時恢複到最近保存的狀態。 Model.fit
運行完成後,臨時檢查點文件將被刪除。
請注意,用戶有責任在中斷後恢複作業。這個回調對於容錯目的的備份和恢複機製很重要,並且從前一個檢查點恢複的模型預計與用於備份的模型相同。如果用戶更改傳遞給 compile 或 fit 的參數,則為容錯保存的檢查點可能會變得無效。
注意:
- 此回調與禁用的即刻執行不兼容。
- 在每個 epoch 結束時保存一個檢查點。恢複後,
Model.fit
在訓練重新開始的未完成時期重做任何部分工作(因此中斷前完成的工作不會影響最終模型狀態)。 - 這適用於單工作人員和multi-worker 模式。當
Model.fit
與tf.distribute
一起使用時,它支持tf.distribute.MirroredStrategy
、tf.distribute.MultiWorkerMirroredStrategy
、tf.distribute.TPUStrategy
和tf.distribute.experimental.ParameterServerStrategy
。
例子:
class InterruptingCallback(tf.keras.callbacks.Callback):
def on_epoch_begin(self, epoch, logs=None):
if epoch == 4:
raise RuntimeError('Interrupting!')
callback = tf.keras.callbacks.experimental.BackupAndRestore(
backup_dir="/tmp/backup")
model = tf.keras.models.Sequential([tf.keras.layers.Dense(10)])
model.compile(tf.keras.optimizers.SGD(), loss='mse')
try:
model.fit(np.arange(100).reshape(5, 20), np.zeros(5), epochs=10,
batch_size=1, callbacks=[callback, InterruptingCallback()],
verbose=0)
except:
pass
history = model.fit(np.arange(100).reshape(5, 20), np.zeros(5), epochs=10,
batch_size=1, callbacks=[callback], verbose=0)
# Only 6 more epochs are run, since first trainning got interrupted at
# zero-indexed epoch 4, second training will continue from 4 to 9.
len(history.history['loss'])
6
相關用法
- Python tf.keras.callbacks.ReduceLROnPlateau用法及代碼示例
- Python tf.keras.callbacks.EarlyStopping用法及代碼示例
- Python tf.keras.callbacks.CSVLogger用法及代碼示例
- Python tf.keras.callbacks.TensorBoard用法及代碼示例
- Python tf.keras.callbacks.Callback用法及代碼示例
- Python tf.keras.callbacks.ModelCheckpoint用法及代碼示例
- Python tf.keras.callbacks.LambdaCallback用法及代碼示例
- Python tf.keras.callbacks.LearningRateScheduler用法及代碼示例
- Python tf.keras.callbacks.History用法及代碼示例
- Python tf.keras.applications.inception_resnet_v2.preprocess_input用法及代碼示例
- Python tf.keras.metrics.Mean.merge_state用法及代碼示例
- Python tf.keras.layers.InputLayer用法及代碼示例
- Python tf.keras.layers.serialize用法及代碼示例
- Python tf.keras.metrics.Hinge用法及代碼示例
- Python tf.keras.experimental.WideDeepModel.compute_loss用法及代碼示例
- Python tf.keras.metrics.SparseCategoricalAccuracy.merge_state用法及代碼示例
- Python tf.keras.metrics.RootMeanSquaredError用法及代碼示例
- Python tf.keras.applications.resnet50.preprocess_input用法及代碼示例
注:本文由純淨天空篩選整理自tensorflow.org大神的英文原創作品 tf.keras.callbacks.BackupAndRestore。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。