将模型另存为 TensorFlow SavedModel 或 HDF5 文件。
用法
tf.keras.models.save_model(
model, filepath, overwrite=True, include_optimizer=True, save_format=None,
signatures=None, options=None, save_traces=True
)
参数
-
model
要保存的 Keras 模型实例。 -
filepath
以下之一:- 字符串或
pathlib.Path
对象,保存模型的路径 h5py.File
对象保存模型的位置
- 字符串或
-
overwrite
我们是否应该覆盖目标位置的任何现有模型,或者通过手动提示询问用户。 -
include_optimizer
如果为 True,则将优化器的状态一起保存。 -
save_format
'tf' 或'h5',表示是将模型保存到 Tensorflow SavedModel 还是 HDF5。在 TF 2.X 中默认为 'tf',在 TF 1.X 中默认为 'h5'。 -
signatures
使用 SavedModel 保存的签名。仅适用于'tf' 格式。有关详细信息,请参阅tf.saved_model.save
中的signatures
参数。 -
options
(仅适用于 SavedModel 格式)tf.saved_model.SaveOptions
对象,指定保存到 SavedModel 的选项。 -
save_traces
(仅适用于 SavedModel 格式)启用后,SavedModel 将存储每一层的函数轨迹。这可以禁用,以便只存储每一层的配置。默认为True
。禁用此函数将减少序列化时间并减小文件大小,但它要求所有自定义层/模型都实现get_config()
方法。
抛出
-
ImportError
如果保存格式为 hdf5,则 h5py 不可用。
有关详细信息,请参阅序列化和保存指南。
用法:
model = tf.keras.Sequential([
tf.keras.layers.Dense(5, input_shape=(3,)),
tf.keras.layers.Softmax()])
model.save('/tmp/model')
loaded_model = tf.keras.models.load_model('/tmp/model')
x = tf.random.uniform((10, 3))
assert np.allclose(model.predict(x), loaded_model.predict(x))
SavedModel 和 HDF5 文件包含:
- 模型的配置(拓扑)
- 模型的权重
- model's optimizer's 状态(如果有)
因此,模型可以在完全相同的状态下重新实例化,而无需任何用于模型定义或训练的代码。
请注意,模型权重在加载后可能具有不同的作用域名称。范围名称包括模型/层名称,例如 "dense_1/kernel:0"
。建议您使用图层属性来访问特定变量,例如model.get_layer("dense_1").kernel
。
SavedModel 序列化格式
Keras SavedModel 使用tf.saved_model.save
保存模型和附加到模型的所有可跟踪对象(例如层和变量)。模型配置、权重和优化器保存在 SavedModel 中。此外,对于附加到模型的每个 Keras 层,SavedModel 存储:
- 配置和元数据——例如名称、数据类型、可训练状态
- 跟踪调用和损失函数,它们存储为 TensorFlow 子图。
跟踪函数允许 SavedModel 格式在没有原始类定义的情况下保存和加载自定义层。
您可以通过禁用save_traces
选项来选择不保存跟踪的函数。这将减少保存模型所需的时间以及输出 SavedModel 占用的磁盘空间量。如果启用此选项,则必须在加载模型时提供所有自定义类定义。请参阅 tf.keras.models.load_model
中的 custom_objects
参数。
相关用法
- Python tf.keras.models.clone_model用法及代码示例
- Python tf.keras.models.model_from_json用法及代码示例
- Python tf.keras.models.load_model用法及代码示例
- Python tf.keras.models.model_from_config用法及代码示例
- Python tf.keras.metrics.Mean.merge_state用法及代码示例
- Python tf.keras.metrics.Hinge用法及代码示例
- Python tf.keras.metrics.SparseCategoricalAccuracy.merge_state用法及代码示例
- Python tf.keras.metrics.RootMeanSquaredError用法及代码示例
- Python tf.keras.metrics.SparseCategoricalCrossentropy.merge_state用法及代码示例
- Python tf.keras.metrics.sparse_categorical_accuracy用法及代码示例
- Python tf.keras.metrics.FalseNegatives用法及代码示例
- Python tf.keras.metrics.TrueNegatives用法及代码示例
- Python tf.keras.metrics.RecallAtPrecision.merge_state用法及代码示例
- Python tf.keras.metrics.SpecificityAtSensitivity用法及代码示例
- Python tf.keras.metrics.Mean用法及代码示例
- Python tf.keras.metrics.poisson用法及代码示例
- Python tf.keras.metrics.LogCoshError用法及代码示例
- Python tf.keras.metrics.MeanSquaredLogarithmicError用法及代码示例
- Python tf.keras.metrics.FalsePositives.merge_state用法及代码示例
- Python tf.keras.metrics.OneHotMeanIoU.merge_state用法及代码示例
注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.keras.models.save_model。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。