將模型另存為 TensorFlow SavedModel 或 HDF5 文件。
用法
tf.keras.models.save_model(
model, filepath, overwrite=True, include_optimizer=True, save_format=None,
signatures=None, options=None, save_traces=True
)
參數
-
model
要保存的 Keras 模型實例。 -
filepath
以下之一:- 字符串或
pathlib.Path
對象,保存模型的路徑 h5py.File
對象保存模型的位置
- 字符串或
-
overwrite
我們是否應該覆蓋目標位置的任何現有模型,或者通過手動提示詢問用戶。 -
include_optimizer
如果為 True,則將優化器的狀態一起保存。 -
save_format
'tf' 或'h5',表示是將模型保存到 Tensorflow SavedModel 還是 HDF5。在 TF 2.X 中默認為 'tf',在 TF 1.X 中默認為 'h5'。 -
signatures
使用 SavedModel 保存的簽名。僅適用於'tf' 格式。有關詳細信息,請參閱tf.saved_model.save
中的signatures
參數。 -
options
(僅適用於 SavedModel 格式)tf.saved_model.SaveOptions
對象,指定保存到 SavedModel 的選項。 -
save_traces
(僅適用於 SavedModel 格式)啟用後,SavedModel 將存儲每一層的函數軌跡。這可以禁用,以便隻存儲每一層的配置。默認為True
。禁用此函數將減少序列化時間並減小文件大小,但它要求所有自定義層/模型都實現get_config()
方法。
拋出
-
ImportError
如果保存格式為 hdf5,則 h5py 不可用。
有關詳細信息,請參閱序列化和保存指南。
用法:
model = tf.keras.Sequential([
tf.keras.layers.Dense(5, input_shape=(3,)),
tf.keras.layers.Softmax()])
model.save('/tmp/model')
loaded_model = tf.keras.models.load_model('/tmp/model')
x = tf.random.uniform((10, 3))
assert np.allclose(model.predict(x), loaded_model.predict(x))
SavedModel 和 HDF5 文件包含:
- 模型的配置(拓撲)
- 模型的權重
- model's optimizer's 狀態(如果有)
因此,模型可以在完全相同的狀態下重新實例化,而無需任何用於模型定義或訓練的代碼。
請注意,模型權重在加載後可能具有不同的作用域名稱。範圍名稱包括模型/層名稱,例如 "dense_1/kernel:0"
。建議您使用圖層屬性來訪問特定變量,例如model.get_layer("dense_1").kernel
。
SavedModel 序列化格式
Keras SavedModel 使用tf.saved_model.save
保存模型和附加到模型的所有可跟蹤對象(例如層和變量)。模型配置、權重和優化器保存在 SavedModel 中。此外,對於附加到模型的每個 Keras 層,SavedModel 存儲:
- 配置和元數據——例如名稱、數據類型、可訓練狀態
- 跟蹤調用和損失函數,它們存儲為 TensorFlow 子圖。
跟蹤函數允許 SavedModel 格式在沒有原始類定義的情況下保存和加載自定義層。
您可以通過禁用save_traces
選項來選擇不保存跟蹤的函數。這將減少保存模型所需的時間以及輸出 SavedModel 占用的磁盤空間量。如果啟用此選項,則必須在加載模型時提供所有自定義類定義。請參閱 tf.keras.models.load_model
中的 custom_objects
參數。
相關用法
- Python tf.keras.models.clone_model用法及代碼示例
- Python tf.keras.models.model_from_json用法及代碼示例
- Python tf.keras.models.load_model用法及代碼示例
- Python tf.keras.models.model_from_config用法及代碼示例
- Python tf.keras.metrics.Mean.merge_state用法及代碼示例
- Python tf.keras.metrics.Hinge用法及代碼示例
- Python tf.keras.metrics.SparseCategoricalAccuracy.merge_state用法及代碼示例
- Python tf.keras.metrics.RootMeanSquaredError用法及代碼示例
- Python tf.keras.metrics.SparseCategoricalCrossentropy.merge_state用法及代碼示例
- Python tf.keras.metrics.sparse_categorical_accuracy用法及代碼示例
- Python tf.keras.metrics.FalseNegatives用法及代碼示例
- Python tf.keras.metrics.TrueNegatives用法及代碼示例
- Python tf.keras.metrics.RecallAtPrecision.merge_state用法及代碼示例
- Python tf.keras.metrics.SpecificityAtSensitivity用法及代碼示例
- Python tf.keras.metrics.Mean用法及代碼示例
- Python tf.keras.metrics.poisson用法及代碼示例
- Python tf.keras.metrics.LogCoshError用法及代碼示例
- Python tf.keras.metrics.MeanSquaredLogarithmicError用法及代碼示例
- Python tf.keras.metrics.FalsePositives.merge_state用法及代碼示例
- Python tf.keras.metrics.OneHotMeanIoU.merge_state用法及代碼示例
注:本文由純淨天空篩選整理自tensorflow.org大神的英文原創作品 tf.keras.models.save_model。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。