當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python tf.compat.v1.train.init_from_checkpoint用法及代碼示例


替換 tf.Variable 初始化程序,以便它們從檢查點文件加載。

用法

tf.compat.v1.train.init_from_checkpoint(
    ckpt_dir_or_file, assignment_map
)

參數

  • ckpt_dir_or_file 帶有檢查點文件或檢查點路徑的目錄。
  • assignment_map 字典或鍵值對列表,其中鍵是檢查點中變量的名稱,值是當前變量或當前變量的名稱(在默認圖中)。

拋出

  • ValueError 如果當前圖中缺少變量,或者檢查點中缺少檢查點或張量。

遷移到 TF2

警告:這個 API 是為 TensorFlow v1 設計的。繼續閱讀有關如何從該 API 遷移到本機 TensorFlow v2 等效項的詳細信息。見TensorFlow v1 到 TensorFlow v2 遷移指南有關如何遷移其餘代碼的說明。

tf.compat.v1.train.init_from_checkpoint 不推薦用於恢複 TF2 中的變量值。

要恢複 TF2 中的檢查點,請使用 tf.keras.Model.load_weightstf.train.Checkpoint.restore 。這些 API 使用基於對象的檢查點方法,而 tf.compat.v1.init_from_checkpoint 依賴於基於 more-fragile variable-name 的檢查點方法。 TF2 中沒有基於對象的等效項init_from_checkpoint

請立即使用基於對象的 API 重寫您的檢查點,有關詳細信息,請參閱遷移指南。

您可以使用 tf.train.Checkpoint.restoretf.keras.Model.load_weights 加載由 tf.compat.v1.train.Saver 編寫的基於名稱的檢查點。但是,您可能必須更改模型中的變量名稱以匹配基於名稱的檢查點中的變量名稱,可以使用 tf.train.list_variables(path) 查看。

另一種選擇是創建一個assignment_map,將基於名稱的檢查點中的變量名稱映射到模型中的變量,例如:

{
    'sequential/dense/bias':model.variables[0],
    'sequential/dense/kernel':model.variables[1]
}

並使用tf.compat.v1.train.init_from_checkpoint恢複基於名稱的檢查點。

恢複後,使用 tf.train.Checkpoint.savetf.keras.Model.save_weights 重新編碼您的檢查點。

值不會立即加載,而是在初始化程序運行時加載(通常通過運行 tf.compat.v1.global_variables_initializer 操作)。

注意:這會覆蓋指定變量的默認初始化操作並重新定義 dtype。

分配映射支持以下語法:

  • 'checkpoint_scope_name/':'scope_name/' - 將從checkpoint_scope_name 加載當前scope_name 中具有匹配張量名稱的所有變量。
  • 'checkpoint_scope_name/some_other_variable':'scope_name/variable_name' - 將從 checkpoint_scope_name/some_other_variable 初始化 scope_name/variable_name 變量。
  • 'scope_variable_name':variable - 將從檢查點使用張量 'scope_variable_name' 初始化給定的 tf.Variable 對象。
  • 'scope_variable_name':list(variable) - 將從檢查點使用張量 'scope_variable_name' 初始化分區變量列表。
  • '/':'scope_name/' - 將從檢查點的根目錄加載當前 scope_name 中的所有變量(例如,無範圍)。

支持加載到分區變量中,表示為 '<variable>/part_<part #>'

分配映射可以是字典,也可以是對列表。後者對於從檢查點中的同一變量初始化當前圖中的多個變量是必要的。

例子:

# Say, '/tmp/model.ckpt' has the following tensors:
#  -- name='old_scope_1/var1', shape=[20, 2]
#  -- name='old_scope_1/var2', shape=[50, 4]
#  -- name='old_scope_2/var3', shape=[100, 100]

# Create new model's variables
with tf.compat.v1.variable_scope('new_scope_1'):
  var1 = tf.compat.v1.get_variable('var1', shape=[20, 2],
                         initializer=tf.compat.v1.zeros_initializer())
with tf.compat.v1.variable_scope('new_scope_2'):
  var2 = tf.compat.v1.get_variable('var2', shape=[50, 4],
                         initializer=tf.compat.v1.zeros_initializer())
  # Partition into 5 variables along the first axis.
  var3 = tf.compat.v1.get_variable(name='var3', shape=[100, 100],
                         initializer=tf.compat.v1.zeros_initializer(),
                         partitioner=lambda shape, dtype:[5, 1])

# Initialize all variables in `new_scope_1` from `old_scope_1`.
init_from_checkpoint('/tmp/model.ckpt', {'old_scope_1/':'new_scope_1/'})

# Use names to specify which variables to initialize from checkpoint.
init_from_checkpoint('/tmp/model.ckpt',
                     {'old_scope_1/var1':'new_scope_1/var1',
                      'old_scope_1/var2':'new_scope_2/var2'})

# Or use tf.Variable objects to identify what to initialize.
init_from_checkpoint('/tmp/model.ckpt',
                     {'old_scope_1/var1':var1,
                      'old_scope_1/var2':var2})

# Initialize partitioned variables using variable's name
init_from_checkpoint('/tmp/model.ckpt',
                     {'old_scope_2/var3':'new_scope_2/var3'})

# Or specify the list of tf.Variable objects.
init_from_checkpoint('/tmp/model.ckpt',
                     {'old_scope_2/var3':var3._get_variable_list()})

相關用法


注:本文由純淨天空篩選整理自tensorflow.org大神的英文原創作品 tf.compat.v1.train.init_from_checkpoint。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。