當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python tf.compat.v1.train.Checkpoint用法及代碼示例


對可跟蹤對象進行分組,保存和恢複它們。

用法

tf.compat.v1.train.Checkpoint(
    **kwargs
)

參數

  • **kwargs 關鍵字參數設置為此對象的屬性,並與檢查點一起保存。值必須是可跟蹤的對象。

拋出

  • ValueError 如果kwargs 中的對象不可追蹤。

屬性

  • save_counter 調用 save() 時增加。用於對檢查點進行編號。

Checkpoint 的構造函數接受關鍵字參數,其值是包含可跟蹤狀態的類型,例如 tf.compat.v1.train.Optimizer implementations、tf.Variabletf.keras.Layer implementations 或 tf.keras.Model implementations。它使用檢查點保存這些值,並維護save_counter 用於編號檢查點。

圖構建時的示例用法:

import tensorflow as tf
import os

checkpoint_directory = "/tmp/training_checkpoints"
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")

checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
status = checkpoint.restore(tf.train.latest_checkpoint(checkpoint_directory))
train_op = optimizer.minimize( ... )
status.assert_consumed()  # Optional sanity checks.
with tf.compat.v1.Session() as session:
  # Use the Session to restore variables, or initialize them if
  # tf.train.latest_checkpoint returned None.
  status.initialize_or_restore(session)
  for _ in range(num_training_steps):
    session.run(train_op)
  checkpoint.save(file_prefix=checkpoint_prefix)

啟用即刻執行的示例用法:

import tensorflow as tf
import os

tf.compat.v1.enable_eager_execution()

checkpoint_directory = "/tmp/training_checkpoints"
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")

checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
status = checkpoint.restore(tf.train.latest_checkpoint(checkpoint_directory))
for _ in range(num_training_steps):
  optimizer.minimize( ... )  # Variables will be restored on creation.
status.assert_consumed()  # Optional sanity checks.
checkpoint.save(file_prefix=checkpoint_prefix)

Checkpoint.saveCheckpoint.restore 寫入和讀取基於對象的檢查點,而 tf.compat.v1.train.Saver 寫入和讀取基於 variable.name 的檢查點。基於對象的檢查點保存帶有命名邊的 Python 對象(Layer s、Optimizer s、Variable s 等)之間的依賴關係圖,該圖用於在恢複檢查點時匹配變量。它可以對 Python 程序中的更改更加健壯,並有助於在即刻執行時支持變量的restore-on-create。對於新代碼,首選 tf.train.Checkpoint 而不是 tf.compat.v1.train.Saver

Checkpoint 對象依賴於作為關鍵字參數傳遞給其構造函數的對象,並且每個依賴項的名稱都與創建它的關鍵字參數的名稱相同。 LayerOptimizer 等 TensorFlow 類將自動添加對其變量的依賴項(例如 "kernel" 和 "bias" 用於 tf.keras.layers.Dense )。從tf.keras.Model 繼承使得在用戶定義的類中管理依賴關係變得容易,因為Model 掛鉤到屬性分配。例如:

class Regress(tf.keras.Model):

  def __init__(self):
    super(Regress, self).__init__()
    self.input_transform = tf.keras.layers.Dense(10)
    # ...

  def call(self, inputs):
    x = self.input_transform(inputs)
    # ...

這個Model 在它的Dense 層上有一個名為"input_transform" 的依賴關係,它又依賴於它的變量。因此,使用tf.train.Checkpoint 保存Regress 的實例也將保存Dense 層創建的所有變量。

當變量分配給多個工作人員時,每個工作人員都會編寫自己的檢查點部分。然後將這些部分合並/重新索引以充當單個檢查點。這避免了將所有變量複製到一個工作人員,但確實要求所有工作人員都看到一個公共文件係統。

雖然 tf.keras.Model.save_weightstf.train.Checkpoint.save 以相同的格式保存,但請注意生成的檢查點的根是保存方法附加到的對象。這意味著使用 save_weights 保存 tf.keras.Model 並加載到帶有 Modeltf.train.Checkpoint (反之亦然)將不匹配 Model 的變量。有關詳細信息,請參閱訓練檢查點指南。首選 tf.train.Checkpoint 而不是 tf.keras.Model.save_weights 用於訓練檢查點。

相關用法


注:本文由純淨天空篩選整理自tensorflow.org大神的英文原創作品 tf.compat.v1.train.Checkpoint。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。