管理將可跟蹤值保存/恢複到磁盤。
用法
tf.train.Checkpoint(
    root=None, **kwargs
)參數
- 
root檢查點的根對象。
- 
**kwargs關鍵字參數設置為此對象的屬性,並與檢查點一起保存。值必須是可跟蹤的對象。
拋出
- 
ValueError如果root或kwargs中的對象不可追蹤。如果root對象跟蹤的對象與 kwargs 中屬性中列出的對象不同(例如,root.child = A和tf.train.Checkpoint(root, child=B)不兼容),也會引發ValueError。
屬性
- 
save_counter調用save()時增加。用於對檢查點進行編號。
TensorFlow 對象可能包含可跟蹤狀態,例如 tf.Variable s、tf.keras.optimizers.Optimizer implementations、tf.data.Dataset iterators、tf.keras.Layer implementations 或 tf.keras.Model implementations。這些被稱為可跟蹤對象。
可以構造 Checkpoint 對象以將單個或一組可跟蹤對象保存到檢查點文件。它維護一個save_counter 用於對檢查點進行編號。
例子:
model = tf.keras.Model(...)
checkpoint = tf.train.Checkpoint(model)
# Save a checkpoint to /tmp/training_checkpoints-{save_counter}. Every time
# checkpoint.save is called, the save counter is increased.
save_path = checkpoint.save('/tmp/training_checkpoints')
# Restore the checkpointed values to the `model` object.
checkpoint.restore(save_path)示例 2:
import tensorflow as tf
import os
checkpoint_directory = "/tmp/training_checkpoints"
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
# Create a Checkpoint that will manage two objects with trackable state,
# one we name "optimizer" and the other we name "model".
checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
status = checkpoint.restore(tf.train.latest_checkpoint(checkpoint_directory))
for _ in range(num_training_steps):
  optimizer.minimize( ... )  # Variables will be restored on creation.
status.assert_consumed()  # Optional sanity checks.
checkpoint.save(file_prefix=checkpoint_prefix)Checkpoint.save() 和 Checkpoint.restore() 寫入和讀取基於對象的檢查點,而 TensorFlow 1.x 的 tf.compat.v1.train.Saver 寫入和讀取基於 variable.name 的檢查點。基於對象的檢查點保存帶有命名邊的 Python 對象(Layer s、Optimizer s、Variable s 等)之間的依賴關係圖,該圖用於在恢複檢查點時匹配變量。它可以對 Python 程序中的更改更加健壯,並有助於支持變量的restore-on-create。
Checkpoint 對象依賴於作為關鍵字參數傳遞給其構造函數的對象,並且每個依賴項的名稱都與創建它的關鍵字參數的名稱相同。 Layer 和 Optimizer 等 TensorFlow 類將自動添加對它們自己變量的依賴項(例如 "kernel" 和 "bias" 用於 tf.keras.layers.Dense )。從tf.keras.Model 繼承使得在用戶定義的類中管理依賴關係變得容易,因為Model 掛鉤到屬性分配。例如:
class Regress(tf.keras.Model):
  def __init__(self):
    super(Regress, self).__init__()
    self.input_transform = tf.keras.layers.Dense(10)
    # ...
  def call(self, inputs):
    x = self.input_transform(inputs)
    # ...這個Model 在它的Dense 層上有一個名為"input_transform" 的依賴關係,它又依賴於它的變量。因此,使用tf.train.Checkpoint 保存Regress 的實例也將保存Dense 層創建的所有變量。
當變量分配給多個工作人員時,每個工作人員都會編寫自己的檢查點部分。然後將這些部分合並/重新索引以充當單個檢查點。這避免了將所有變量複製到一個工作人員,但確實要求所有工作人員都看到一個公共文件係統。
此函數與 Keras Model save_weights 函數略有不同。 tf.keras.Model.save_weights 使用在 filepath 中指定的名稱創建檢查點文件,而 tf.train.Checkpoint 使用 filepath 作為檢查點文件名的前綴對檢查點進行編號。除此之外,model.save_weights() 和tf.train.Checkpoint(model).save() 是等價的。
有關詳細信息,請參閱訓練檢查點指南。
相關用法
- Python tf.train.Checkpoint.restore用法及代碼示例
- Python tf.train.Checkpoint.read用法及代碼示例
- Python tf.train.CheckpointOptions用法及代碼示例
- Python tf.train.Checkpoint.save用法及代碼示例
- Python tf.train.Checkpoint.write用法及代碼示例
- Python tf.train.CheckpointManager用法及代碼示例
- Python tf.train.Coordinator.stop_on_exception用法及代碼示例
- Python tf.train.ClusterSpec用法及代碼示例
- Python tf.train.Coordinator用法及代碼示例
- Python tf.train.ExponentialMovingAverage用法及代碼示例
- Python tf.train.list_variables用法及代碼示例
- Python tf.transpose用法及代碼示例
注:本文由純淨天空篩選整理自tensorflow.org大神的英文原創作品 tf.train.Checkpoint。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。
