管理將可跟蹤值保存/恢複到磁盤。
用法
tf.train.Checkpoint(
root=None, **kwargs
)
參數
-
root
檢查點的根對象。 -
**kwargs
關鍵字參數設置為此對象的屬性,並與檢查點一起保存。值必須是可跟蹤的對象。
拋出
-
ValueError
如果root
或kwargs
中的對象不可追蹤。如果root
對象跟蹤的對象與 kwargs 中屬性中列出的對象不同(例如,root.child = A
和tf.train.Checkpoint(root, child=B)
不兼容),也會引發ValueError
。
屬性
-
save_counter
調用save()
時增加。用於對檢查點進行編號。
TensorFlow 對象可能包含可跟蹤狀態,例如 tf.Variable
s、tf.keras.optimizers.Optimizer
implementations、tf.data.Dataset
iterators、tf.keras.Layer
implementations 或 tf.keras.Model
implementations。這些被稱為可跟蹤對象。
可以構造 Checkpoint
對象以將單個或一組可跟蹤對象保存到檢查點文件。它維護一個save_counter
用於對檢查點進行編號。
例子:
model = tf.keras.Model(...)
checkpoint = tf.train.Checkpoint(model)
# Save a checkpoint to /tmp/training_checkpoints-{save_counter}. Every time
# checkpoint.save is called, the save counter is increased.
save_path = checkpoint.save('/tmp/training_checkpoints')
# Restore the checkpointed values to the `model` object.
checkpoint.restore(save_path)
示例 2:
import tensorflow as tf
import os
checkpoint_directory = "/tmp/training_checkpoints"
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
# Create a Checkpoint that will manage two objects with trackable state,
# one we name "optimizer" and the other we name "model".
checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
status = checkpoint.restore(tf.train.latest_checkpoint(checkpoint_directory))
for _ in range(num_training_steps):
optimizer.minimize( ... ) # Variables will be restored on creation.
status.assert_consumed() # Optional sanity checks.
checkpoint.save(file_prefix=checkpoint_prefix)
Checkpoint.save()
和 Checkpoint.restore()
寫入和讀取基於對象的檢查點,而 TensorFlow 1.x 的 tf.compat.v1.train.Saver
寫入和讀取基於 variable.name
的檢查點。基於對象的檢查點保存帶有命名邊的 Python 對象(Layer
s、Optimizer
s、Variable
s 等)之間的依賴關係圖,該圖用於在恢複檢查點時匹配變量。它可以對 Python 程序中的更改更加健壯,並有助於支持變量的restore-on-create。
Checkpoint
對象依賴於作為關鍵字參數傳遞給其構造函數的對象,並且每個依賴項的名稱都與創建它的關鍵字參數的名稱相同。 Layer
和 Optimizer
等 TensorFlow 類將自動添加對它們自己變量的依賴項(例如 "kernel" 和 "bias" 用於 tf.keras.layers.Dense
)。從tf.keras.Model
繼承使得在用戶定義的類中管理依賴關係變得容易,因為Model
掛鉤到屬性分配。例如:
class Regress(tf.keras.Model):
def __init__(self):
super(Regress, self).__init__()
self.input_transform = tf.keras.layers.Dense(10)
# ...
def call(self, inputs):
x = self.input_transform(inputs)
# ...
這個Model
在它的Dense
層上有一個名為"input_transform" 的依賴關係,它又依賴於它的變量。因此,使用tf.train.Checkpoint
保存Regress
的實例也將保存Dense
層創建的所有變量。
當變量分配給多個工作人員時,每個工作人員都會編寫自己的檢查點部分。然後將這些部分合並/重新索引以充當單個檢查點。這避免了將所有變量複製到一個工作人員,但確實要求所有工作人員都看到一個公共文件係統。
此函數與 Keras Model save_weights
函數略有不同。 tf.keras.Model.save_weights
使用在 filepath
中指定的名稱創建檢查點文件,而 tf.train.Checkpoint
使用 filepath
作為檢查點文件名的前綴對檢查點進行編號。除此之外,model.save_weights()
和tf.train.Checkpoint(model).save()
是等價的。
有關詳細信息,請參閱訓練檢查點指南。
相關用法
- Python tf.train.Checkpoint.restore用法及代碼示例
- Python tf.train.Checkpoint.read用法及代碼示例
- Python tf.train.CheckpointOptions用法及代碼示例
- Python tf.train.Checkpoint.save用法及代碼示例
- Python tf.train.Checkpoint.write用法及代碼示例
- Python tf.train.CheckpointManager用法及代碼示例
- Python tf.train.Coordinator.stop_on_exception用法及代碼示例
- Python tf.train.ClusterSpec用法及代碼示例
- Python tf.train.Coordinator用法及代碼示例
- Python tf.train.ExponentialMovingAverage用法及代碼示例
- Python tf.train.list_variables用法及代碼示例
- Python tf.transpose用法及代碼示例
注:本文由純淨天空篩選整理自tensorflow.org大神的英文原創作品 tf.train.Checkpoint。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。