當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python tf.train.Checkpoint用法及代碼示例


管理將可跟蹤值保存/恢複到磁盤。

用法

tf.train.Checkpoint(
    root=None, **kwargs
)

參數

  • root 檢查點的根對象。
  • **kwargs 關鍵字參數設置為此對象的屬性,並與檢查點一起保存。值必須是可跟蹤的對象。

拋出

  • ValueError 如果rootkwargs 中的對象不可追蹤。如果 root 對象跟蹤的對象與 kwargs 中屬性中列出的對象不同(例如,root.child = Atf.train.Checkpoint(root, child=B) 不兼容),也會引發 ValueError

屬性

  • save_counter 調用 save() 時增加。用於對檢查點進行編號。

TensorFlow 對象可能包含可跟蹤狀態,例如 tf.Variable s、tf.keras.optimizers.Optimizer implementations、tf.data.Dataset iterators、tf.keras.Layer implementations 或 tf.keras.Model implementations。這些被稱為可跟蹤對象。

可以構造 Checkpoint 對象以將單個或一組可跟蹤對象保存到檢查點文件。它維護一個save_counter 用於對檢查點進行編號。

例子:

model = tf.keras.Model(...)
checkpoint = tf.train.Checkpoint(model)

# Save a checkpoint to /tmp/training_checkpoints-{save_counter}. Every time
# checkpoint.save is called, the save counter is increased.
save_path = checkpoint.save('/tmp/training_checkpoints')

# Restore the checkpointed values to the `model` object.
checkpoint.restore(save_path)

示例 2:

import tensorflow as tf
import os

checkpoint_directory = "/tmp/training_checkpoints"
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")

# Create a Checkpoint that will manage two objects with trackable state,
# one we name "optimizer" and the other we name "model".
checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
status = checkpoint.restore(tf.train.latest_checkpoint(checkpoint_directory))
for _ in range(num_training_steps):
  optimizer.minimize( ... )  # Variables will be restored on creation.
status.assert_consumed()  # Optional sanity checks.
checkpoint.save(file_prefix=checkpoint_prefix)

Checkpoint.save()Checkpoint.restore() 寫入和讀取基於對象的檢查點,而 TensorFlow 1.x 的 tf.compat.v1.train.Saver 寫入和讀取基於 variable.name 的檢查點。基於對象的檢查點保存帶有命名邊的 Python 對象(Layer s、Optimizer s、Variable s 等)之間的依賴關係圖,該圖用於在恢複檢查點時匹配變量。它可以對 Python 程序中的更改更加健壯,並有助於支持變量的restore-on-create。

Checkpoint 對象依賴於作為關鍵字參數傳遞給其構造函數的對象,並且每個依賴項的名稱都與創建它的關鍵字參數的名稱相同。 LayerOptimizer 等 TensorFlow 類將自動添加對它們自己變量的依賴項(例如 "kernel" 和 "bias" 用於 tf.keras.layers.Dense )。從tf.keras.Model 繼承使得在用戶定義的類中管理依賴關係變得容易,因為Model 掛鉤到屬性分配。例如:

class Regress(tf.keras.Model):

  def __init__(self):
    super(Regress, self).__init__()
    self.input_transform = tf.keras.layers.Dense(10)
    # ...

  def call(self, inputs):
    x = self.input_transform(inputs)
    # ...

這個Model 在它的Dense 層上有一個名為"input_transform" 的依賴關係,它又依賴於它的變量。因此,使用tf.train.Checkpoint 保存Regress 的實例也將保存Dense 層創建的所有變量。

當變量分配給多個工作人員時,每個工作人員都會編寫自己的檢查點部分。然後將這些部分合並/重新索引以充當單個檢查點。這避免了將所有變量複製到一個工作人員,但確實要求所有工作人員都看到一個公共文件係統。

此函數與 Keras Model save_weights 函數略有不同。 tf.keras.Model.save_weights 使用在 filepath 中指定的名稱創建檢查點文件,而 tf.train.Checkpoint 使用 filepath 作為檢查點文件名的前綴對檢查點進行編號。除此之外,model.save_weights()tf.train.Checkpoint(model).save() 是等價的。

有關詳細信息,請參閱訓練檢查點指南。

相關用法


注:本文由純淨天空篩選整理自tensorflow.org大神的英文原創作品 tf.train.Checkpoint。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。