tf.estimator.Estimators
中的熱啟動設置。
用法
tf.estimator.WarmStartSettings(
ckpt_to_initialize_from, vars_to_warm_start='.*',
var_name_to_vocab_info=None, var_name_to_prev_var_name=None
)
屬性
-
ckpt_to_initialize_from
[必需] 一個字符串,指定帶有檢查點文件的目錄或檢查點的路徑,從該目錄中熱啟動模型參數。 -
vars_to_warm_start
[可選] 以下之一:捕獲要熱啟動的變量的正則表達式(字符串)(請參閱 tf.compat.v1.get_collection)。此表達式將僅考慮 TRAINABLE_VARIABLES 集合中的變量——如果您需要熱啟動 non_TRAINABLE 變量(例如優化器累加器或批量標準統計信息),請使用以下選項。
字符串列表,每個字符串都是通過GLOBAL_VARIABLES 提供給 tf.compat.v1.get_collection 的正則表達式範圍(請參閱 tf.compat.v1.get_collection)。出於向後兼容性的原因,這與 single-string 參數類型是分開的。
要熱啟動的變量列表。如果您無權訪問調用站點的
Variable
對象,請使用上述選項。None
,在這種情況下,隻有在var_name_to_vocab_info
中指定的 TRAINABLE 變量將被熱啟動。
默認為
'.*'
,其中 warm-starts TRAINABLE_VARIABLES 集合中的所有變量。請注意,這不包括變量,例如累加器和批量標準的移動統計信息。 -
var_name_to_vocab_info
[可選]tf.estimator.VocabInfo
的變量名稱(字符串)字典。變量名稱應該是"full" 變量,而不是分區的名稱。如果未明確提供,則假定該變量沒有(更改)詞匯表。 -
var_name_to_prev_var_name
[可選] 變量名稱(字符串)的字典到ckpt_to_initialize_from
中的 previously-trained 變量的名稱。如果未明確提供,則假定變量的名稱在先前的檢查點和當前模型之間是相同的。請注意,這對熱啟動的變量集沒有影響,並且僅控製名稱映射(使用vars_to_warm_start
來控製要熱啟動的變量)。
與罐裝 tf.estimator.DNNEstimator
一起使用的示例:
emb_vocab_file = tf.feature_column.embedding_column(
tf.feature_column.categorical_column_with_vocabulary_file(
"sc_vocab_file", "new_vocab.txt", vocab_size=100),
dimension=8)
emb_vocab_list = tf.feature_column.embedding_column(
tf.feature_column.categorical_column_with_vocabulary_list(
"sc_vocab_list", vocabulary_list=["a", "b"]),
dimension=8)
estimator = tf.estimator.DNNClassifier(
hidden_units=[128, 64], feature_columns=[emb_vocab_file, emb_vocab_list],
warm_start_from=ws)
其中ws
可以定義為:
熱啟動模型中的所有權重(輸入層和隱藏權重)。可以提供目錄或特定檢查點(在前者的情況下,將使用最新的檢查點):
ws = WarmStartSettings(ckpt_to_initialize_from="/tmp")
ws = WarmStartSettings(ckpt_to_initialize_from="/tmp/model-1000")
僅熱啟動嵌入(輸入層):
ws = WarmStartSettings(ckpt_to_initialize_from="/tmp",
vars_to_warm_start=".*input_layer.*")
熱啟動所有權重,但對應於 sc_vocab_file
的嵌入參數與當前模型中使用的詞匯不同:
vocab_info = tf.estimator.VocabInfo(
new_vocab=sc_vocab_file.vocabulary_file,
new_vocab_size=sc_vocab_file.vocabulary_size,
num_oov_buckets=sc_vocab_file.num_oov_buckets,
old_vocab="old_vocab.txt"
)
ws = WarmStartSettings(
ckpt_to_initialize_from="/tmp",
var_name_to_vocab_info={
"input_layer/sc_vocab_file_embedding/embedding_weights":vocab_info
})
僅熱啟動 sc_vocab_file
嵌入(沒有其他變量),其詞匯與當前模型中使用的詞匯不同:
vocab_info = tf.estimator.VocabInfo(
new_vocab=sc_vocab_file.vocabulary_file,
new_vocab_size=sc_vocab_file.vocabulary_size,
num_oov_buckets=sc_vocab_file.num_oov_buckets,
old_vocab="old_vocab.txt"
)
ws = WarmStartSettings(
ckpt_to_initialize_from="/tmp",
vars_to_warm_start=None,
var_name_to_vocab_info={
"input_layer/sc_vocab_file_embedding/embedding_weights":vocab_info
})
熱啟動所有權重,但對應於 sc_vocab_file
的參數與當前檢查點中使用的詞匯不同,並且僅使用了其中的 100 個條目:
vocab_info = tf.estimator.VocabInfo(
new_vocab=sc_vocab_file.vocabulary_file,
new_vocab_size=sc_vocab_file.vocabulary_size,
num_oov_buckets=sc_vocab_file.num_oov_buckets,
old_vocab="old_vocab.txt",
old_vocab_size=100
)
ws = WarmStartSettings(
ckpt_to_initialize_from="/tmp",
var_name_to_vocab_info={
"input_layer/sc_vocab_file_embedding/embedding_weights":vocab_info
})
熱啟動所有權重,但sc_vocab_file
對應的參數與當前檢查點中使用的參數不同,sc_vocab_list
對應的參數名稱與當前檢查點不同:
vocab_info = tf.estimator.VocabInfo(
new_vocab=sc_vocab_file.vocabulary_file,
new_vocab_size=sc_vocab_file.vocabulary_size,
num_oov_buckets=sc_vocab_file.num_oov_buckets,
old_vocab="old_vocab.txt",
old_vocab_size=100
)
ws = WarmStartSettings(
ckpt_to_initialize_from="/tmp",
var_name_to_vocab_info={
"input_layer/sc_vocab_file_embedding/embedding_weights":vocab_info
},
var_name_to_prev_var_name={
"input_layer/sc_vocab_list_embedding/embedding_weights":
"old_tensor_name"
})
熱啟動所有 TRAINABLE 變量:
ws = WarmStartSettings(ckpt_to_initialize_from="/tmp",
vars_to_warm_start=".*")
熱啟動所有變量(包括非 TRAINABLE):
ws = WarmStartSettings(ckpt_to_initialize_from="/tmp",
vars_to_warm_start=[".*"])
熱啟動非 TRAINABLE 變量 "v1"、"v1/Momentum" 和 "v2" 但不是 "v2/momentum":
ws = WarmStartSettings(ckpt_to_initialize_from="/tmp",
vars_to_warm_start=["v1", "v2[^/]"])
相關用法
- Python tf.estimator.TrainSpec用法及代碼示例
- Python tf.estimator.LogisticRegressionHead用法及代碼示例
- Python tf.estimator.MultiHead用法及代碼示例
- Python tf.estimator.PoissonRegressionHead用法及代碼示例
- Python tf.estimator.experimental.stop_if_lower_hook用法及代碼示例
- Python tf.estimator.RunConfig用法及代碼示例
- Python tf.estimator.MultiLabelHead用法及代碼示例
- Python tf.estimator.experimental.stop_if_no_increase_hook用法及代碼示例
- Python tf.estimator.BaselineEstimator用法及代碼示例
- Python tf.estimator.DNNLinearCombinedEstimator用法及代碼示例
- Python tf.estimator.Estimator用法及代碼示例
- Python tf.estimator.experimental.LinearSDCA用法及代碼示例
- Python tf.estimator.experimental.RNNClassifier用法及代碼示例
- Python tf.estimator.experimental.make_early_stopping_hook用法及代碼示例
- Python tf.estimator.LinearRegressor用法及代碼示例
- Python tf.estimator.LinearEstimator用法及代碼示例
- Python tf.estimator.DNNClassifier用法及代碼示例
- Python tf.estimator.BaselineClassifier用法及代碼示例
- Python tf.estimator.experimental.stop_if_higher_hook用法及代碼示例
- Python tf.estimator.train_and_evaluate用法及代碼示例
注:本文由純淨天空篩選整理自tensorflow.org大神的英文原創作品 tf.estimator.WarmStartSettings。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。