Model
將圖層分組為具有訓練和推理函數的對象。
用法
tf.keras.Model(
*args, **kwargs
)
參數
-
inputs
模型的輸入:keras.Input
對象或keras.Input
對象列表。 -
outputs
模型的輸出。請參閱下麵的函數 API 示例。 -
name
字符串,模型的名稱。
屬性
-
distribute_strategy
該模型是在tf.distribute.Strategy
下創建的。 -
layers
-
metrics_names
返回所有輸出的模型顯示標簽。注意:
metrics_names
僅在keras.Model
已根據實際數據進行訓練/評估後可用。inputs = tf.keras.layers.Input(shape=(3,)) outputs = tf.keras.layers.Dense(2)(inputs) model = tf.keras.models.Model(inputs=inputs, outputs=outputs) model.compile(optimizer="Adam", loss="mse", metrics=["mae"]) model.metrics_names []
x = np.random.random((2, 3)) y = np.random.randint(0, 2, (2, 2)) model.fit(x, y) model.metrics_names ['loss', 'mae']
inputs = tf.keras.layers.Input(shape=(3,)) d = tf.keras.layers.Dense(2, name='out') output_1 = d(inputs) output_2 = d(inputs) model = tf.keras.models.Model( inputs=inputs, outputs=[output_1, output_2]) model.compile(optimizer="Adam", loss="mse", metrics=["mae", "acc"]) model.fit(x, (y, y)) model.metrics_names ['loss', 'out_loss', 'out_1_loss', 'out_mae', 'out_acc', 'out_1_mae', 'out_1_acc']
-
run_eagerly
指示模型是否應立即運行的可設置屬性。即刻地運行意味著您的模型將像 Python 代碼一樣逐步運行。您的模型可能會運行得更慢,但通過單步調用各個層調用,您應該可以更輕鬆地對其進行調試。
默認情況下,我們將嘗試將您的模型編譯為靜態圖以提供最佳執行性能。
有兩種方法可以實例化 Model
:
1 - 使用 "Functional API",從 Input
開始,鏈接層調用以指定模型的前向傳遞,最後從輸入和輸出創建模型:
import tensorflow as tf
inputs = tf.keras.Input(shape=(3,))
x = tf.keras.layers.Dense(4, activation=tf.nn.relu)(inputs)
outputs = tf.keras.layers.Dense(5, activation=tf.nn.softmax)(x)
model = tf.keras.Model(inputs=inputs, outputs=outputs)
注意:僅支持輸入張量的字典、列表和元組。不支持嵌套輸入(例如列表列表或 dict 的 dicts)。
也可以使用中間張量創建新的函數 API 模型。這使您能夠快速提取模型的sub-components。
例子:
inputs = keras.Input(shape=(None, None, 3))
processed = keras.layers.RandomCrop(width=32, height=32)(inputs)
conv = keras.layers.Conv2D(filters=2, kernel_size=3)(processed)
pooling = keras.layers.GlobalAveragePooling2D()(conv)
feature = keras.layers.Dense(10)(pooling)
full_model = keras.Model(inputs, feature)
backbone = keras.Model(processed, conv)
activations = keras.Model(conv, feature)
請注意,backbone
和 activations
模型不是使用 keras.Input
對象創建的,而是使用源自 keras.Inputs
對象的張量創建的。在底層,這些模型將共享層和權重,以便用戶可以訓練 full_model
,並使用 backbone
或 activations
進行特征提取。模型的輸入和輸出也可以是張量的嵌套結構,創建的模型是標準的函數 API 模型,支持所有現有的 API。
2 - 通過繼承 Model
類:在這種情況下,您應該在 __init__()
中定義您的層,並且您應該在 call()
中實現模型的前向傳遞。
import tensorflow as tf
class MyModel(tf.keras.Model):
def __init__(self):
super().__init__()
self.dense1 = tf.keras.layers.Dense(4, activation=tf.nn.relu)
self.dense2 = tf.keras.layers.Dense(5, activation=tf.nn.softmax)
def call(self, inputs):
x = self.dense1(inputs)
return self.dense2(x)
model = MyModel()
如果您將 Model
子類化,則可以選擇在 call()
中有一個 training
參數(布爾值),您可以使用它來指定訓練和推理中的不同行為:
import tensorflow as tf
class MyModel(tf.keras.Model):
def __init__(self):
super().__init__()
self.dense1 = tf.keras.layers.Dense(4, activation=tf.nn.relu)
self.dense2 = tf.keras.layers.Dense(5, activation=tf.nn.softmax)
self.dropout = tf.keras.layers.Dropout(0.5)
def call(self, inputs, training=False):
x = self.dense1(inputs)
if training:
x = self.dropout(x, training=training)
return self.dense2(x)
model = MyModel()
創建模型後,您可以使用 model.compile()
為模型配置損失和指標,使用 model.fit()
訓練模型,或使用 model.predict()
使用模型進行預測。
相關用法
- Python tf.keras.Model.compute_loss用法及代碼示例
- Python tf.keras.Model.reset_metrics用法及代碼示例
- Python tf.keras.Model.compile用法及代碼示例
- Python tf.keras.Model.save_spec用法及代碼示例
- Python tf.keras.Model.save用法及代碼示例
- Python tf.keras.Model.compute_metrics用法及代碼示例
- Python tf.keras.applications.inception_resnet_v2.preprocess_input用法及代碼示例
- Python tf.keras.metrics.Mean.merge_state用法及代碼示例
- Python tf.keras.layers.InputLayer用法及代碼示例
- Python tf.keras.callbacks.ReduceLROnPlateau用法及代碼示例
- Python tf.keras.layers.serialize用法及代碼示例
- Python tf.keras.metrics.Hinge用法及代碼示例
- Python tf.keras.experimental.WideDeepModel.compute_loss用法及代碼示例
- Python tf.keras.metrics.SparseCategoricalAccuracy.merge_state用法及代碼示例
- Python tf.keras.metrics.RootMeanSquaredError用法及代碼示例
- Python tf.keras.applications.resnet50.preprocess_input用法及代碼示例
- Python tf.keras.metrics.SparseCategoricalCrossentropy.merge_state用法及代碼示例
- Python tf.keras.metrics.sparse_categorical_accuracy用法及代碼示例
- Python tf.keras.layers.Dropout用法及代碼示例
- Python tf.keras.activations.softplus用法及代碼示例
注:本文由純淨天空篩選整理自tensorflow.org大神的英文原創作品 tf.keras.Model。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。