当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python tf.keras.models.load_model用法及代码示例


加载通过 model.save() 保存的模型。

用法

tf.keras.models.load_model(
    filepath, custom_objects=None, compile=True, options=None
)

参数

  • filepath 以下之一:
    • 字符串或pathlib.Path 对象,保存模型的路径
    • h5py.File 从中加载模型的对象
  • custom_objects 可选字典映射名称(字符串)到反序列化期间要考虑的自定义类或函数。
  • compile 布尔值,加载后是否编译模型。
  • options 可选的 tf.saved_model.LoadOptions 对象,指定从 SavedModel 加载的选项。

返回

  • 一个 Keras 模型实例。如果原始模型已编译并使用优化器保存,则将编译返回的模型。否则,模型将未编译。在返回未编译模型的情况下,如果 compile 参数设置为 True ,则会显示警告。

抛出

  • ImportError 如果从 hdf5 文件加载并且 h5py 不可用。
  • IOError 如果保存文件无效。

用法:

model = tf.keras.Sequential([
    tf.keras.layers.Dense(5, input_shape=(3,)),
    tf.keras.layers.Softmax()])
model.save('/tmp/model')
loaded_model = tf.keras.models.load_model('/tmp/model')
x = tf.random.uniform((10, 3))
assert np.allclose(model.predict(x), loaded_model.predict(x))

请注意,模型权重在加载后可能具有不同的作用域名称。范围名称包括模型/层名称,例如 "dense_1/kernel:0" 。建议您使用图层属性来访问特定变量,例如model.get_layer("dense_1").kernel

相关用法


注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.keras.models.load_model。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。