用法
compile(
optimizer='rmsprop', loss=None, metrics=None, loss_weights=None,
weighted_metrics=None, run_eagerly=None, steps_per_execution=None,
jit_compile=None, **kwargs
)
參數
-
optimizer
字符串(優化器名稱)或優化器實例。見tf.keras.optimizers
。 -
loss
損失函數。可能是一個字符串(損失函數的名稱),或者一個tf.keras.losses.Loss
實例。請參閱tf.keras.losses
。損失函數是任何帶有簽名loss = fn(y_true, y_pred)
的可調用函數,其中y_true
是真實值,而y_pred
是模型的預測。y_true
應該具有形狀(batch_size, d0, .. dN)
(在稀疏損失函數的情況下,例如稀疏分類交叉熵,它需要形狀為(batch_size, d0, .. dN-1)
的整數數組)。y_pred
應該具有形狀(batch_size, d0, .. dN)
。損失函數應該返回一個浮點張量。如果使用自定義Loss
實例並將縮減設置為None
,則返回值具有形狀(batch_size, d0, .. dN-1)
即 per-sample 或 per-timestep 損失值;否則,它是一個標量。如果模型有多個輸出,您可以通過傳遞字典或損失列表對每個輸出使用不同的損失。除非指定了loss_weights
,否則模型將最小化的損失值將是所有單個損失的總和。 -
metrics
模型在訓練和測試期間要評估的指標列表。每個都可以是字符串(內置函數的名稱)、函數或tf.keras.metrics.Metric
實例。請參閱tf.keras.metrics
。通常,您將使用metrics=['accuracy']
。函數是任何帶有簽名result = fn(y_true, y_pred)
的可調用函數。要為 multi-output 模型的不同輸出指定不同的指標,您還可以傳遞字典,例如metrics={'output_a':'accuracy', 'output_b':['accuracy', 'mse']}
。您還可以傳遞一個列表來為每個輸出指定一個指標或指標列表,例如metrics=[['accuracy'], ['accuracy', 'mse']]
或metrics=['accuracy', ['accuracy', 'mse']]
。當您傳遞字符串 'accuracy' 或 'acc' 時,我們會根據使用的損失函數和模型輸出形狀將其轉換為tf.keras.metrics.BinaryAccuracy
、tf.keras.metrics.CategoricalAccuracy
、tf.keras.metrics.SparseCategoricalAccuracy
之一。我們也對字符串'crossentropy' 和'ce' 進行類似的轉換。 -
loss_weights
可選的列表或字典,指定標量係數(Python 浮點數)以加權不同模型輸出的損失貢獻。模型將最小化的損失值將是加權和所有個人損失,加權loss_weights
係數。如果是列表,則預計與模型的輸出有 1:1 的映射關係。如果是 dict,則應將輸出名稱(字符串)映射到標量係數。 -
weighted_metrics
在訓練和測試期間由sample_weight
或class_weight
評估和加權的指標列表。 -
run_eagerly
布爾。默認為False
。如果True
,這個Model
的邏輯將不會被包裝在tf.function
中。建議將其保留為None
,除非您的Model
不能在tf.function
內運行。使用tf.distribute.experimental.ParameterServerStrategy
時不支持run_eagerly=True
。 -
steps_per_execution
Int. 默認為 1。每次tf.function
調用期間要運行的批次數。在單個tf.function
調用中運行多個批處理可以極大地提高 TPU 或具有大量 Python 開銷的小型模型的性能。每次執行最多將運行一個完整的 epoch。如果傳遞的數字大於 epoch 的大小,則執行將被截斷為 epoch 的大小。請注意,如果steps_per_execution
設置為N
,則Callback.on_batch_begin
和Callback.on_batch_end
方法將僅在每個N
批次(即在每個tf.function
執行之前/之後)調用。 -
jit_compile
如果True
,用 XLA 編譯模型訓練步驟。XLA是機器學習的優化編譯器。jit_compile
默認情況下未啟用。此選項無法啟用run_eagerly=True
.注意jit_compile=True
不一定適用於所有型號。有關支持的操作的更多信息,請參閱XLA 文檔.另請參閱已知的 XLA 問題更多細節。 -
**kwargs
僅支持向後兼容的參數。
配置模型進行訓練。
例子:
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
loss=tf.keras.losses.BinaryCrossentropy(),
metrics=[tf.keras.metrics.BinaryAccuracy(),
tf.keras.metrics.FalseNegatives()])
相關用法
- Python tf.keras.Model.compute_loss用法及代碼示例
- Python tf.keras.Model.compute_metrics用法及代碼示例
- Python tf.keras.Model.reset_metrics用法及代碼示例
- Python tf.keras.Model.save_spec用法及代碼示例
- Python tf.keras.Model.save用法及代碼示例
- Python tf.keras.Model用法及代碼示例
- Python tf.keras.applications.inception_resnet_v2.preprocess_input用法及代碼示例
- Python tf.keras.metrics.Mean.merge_state用法及代碼示例
- Python tf.keras.layers.InputLayer用法及代碼示例
- Python tf.keras.callbacks.ReduceLROnPlateau用法及代碼示例
- Python tf.keras.layers.serialize用法及代碼示例
- Python tf.keras.metrics.Hinge用法及代碼示例
- Python tf.keras.experimental.WideDeepModel.compute_loss用法及代碼示例
- Python tf.keras.metrics.SparseCategoricalAccuracy.merge_state用法及代碼示例
- Python tf.keras.metrics.RootMeanSquaredError用法及代碼示例
- Python tf.keras.applications.resnet50.preprocess_input用法及代碼示例
- Python tf.keras.metrics.SparseCategoricalCrossentropy.merge_state用法及代碼示例
- Python tf.keras.metrics.sparse_categorical_accuracy用法及代碼示例
- Python tf.keras.layers.Dropout用法及代碼示例
- Python tf.keras.activations.softplus用法及代碼示例
注:本文由純淨天空篩選整理自tensorflow.org大神的英文原創作品 tf.keras.Model.compile。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。