从 export_dir
加载 SavedModel。
用法
tf.saved_model.load(
export_dir, tags=None, options=None
)
参数
-
export_dir
要从中加载的 SavedModel 目录。 -
tags
标识要加载的 MetaGraph 的标签或标签序列。如果 SavedModel 包含单个 MetaGraph,则为可选,如从tf.saved_model.save
导出的那些。 -
options
tf.saved_model.LoadOptions
指定加载选项的对象。
返回
-
具有从签名键映射到函数的
signatures
属性的可跟踪对象。如果 SavedModel 由tf.saved_model.save
导出,它还指向已保存的可跟踪对象、函数、调试信息。
抛出
-
ValueError
如果tags
与 SavedModel 中的 MetaGraph 不匹配。
与 SavedModel 关联的签名可作为函数使用:
imported = tf.saved_model.load(path)
f = imported.signatures["serving_default"]
print(f(x=tf.constant([[1.]])))
使用 tf.saved_model.save
导出的对象还具有分配给属性的可跟踪对象和函数:
exported = tf.train.Checkpoint(v=tf.Variable(3.))
exported.f = tf.function(
lambda x:exported.v * x,
input_signature=[tf.TensorSpec(shape=None, dtype=tf.float32)])
tf.saved_model.save(exported, path)
imported = tf.saved_model.load(path)
assert 3. == imported.v.numpy()
assert 6. == imported.f(x=tf.constant(2.)).numpy()
加载 Keras 模型
Keras 模型是可跟踪的,因此可以将它们保存到 SavedModel。 tf.saved_model.load
返回的对象不是 Keras 对象(即没有 .fit
, .predict
等方法)。一些属性和函数仍然可用:.variables
, .trainable_variables
和 .__call__
。
model = tf.keras.Model(...)
tf.saved_model.save(model, path)
imported = tf.saved_model.load(path)
outputs = imported(inputs)
使用tf.keras.models.load_model
恢复 Keras 模型。
从 TensorFlow 1.x 导入 SavedModel
tf.estimator.Estimator
或 1.x SavedModel API 的 SavedModel 具有平面图,而不是 tf.function
对象。这些 SavedModel 将加载以下属性:
.signatures
:将签名名称映射到函数的字典。.prune(feeds, fetches)
:一种允许您为新子图提取函数的方法。这相当于在 TensorFlow 1.x 的会话中导入 SavedModel 并命名提要和提取。imported = tf.saved_model.load(path_to_v1_saved_model) pruned = imported.prune("x:0", "out:0") pruned(tf.ones([]))
有关详细信息,请参阅
tf.compat.v1.wrap_function
。.variables
:导入变量的列表。.graph
:整个导入的图形。.restore(save_path)
:从tf.compat.v1.Saver
保存的检查点恢复变量的函数。
异步使用 SavedModels
当异步使用 SavedModels 时(生产者是一个单独的进程),SavedModel 目录会在所有文件写入之前出现,如果指向不完整的 SavedModel,tf.saved_model.load
将失败。与其检查目录,不如检查"saved_model_dir/saved_model.pb"。该文件作为最后一个tf.saved_model.save
文件操作以原子方式写入。
相关用法
- Python tf.saved_model.Asset用法及代码示例
- Python tf.saved_model.SaveOptions用法及代码示例
- Python tf.saved_model.experimental.TrackableResource用法及代码示例
- Python tf.saved_model.save用法及代码示例
- Python tf.summary.scalar用法及代码示例
- Python tf.strings.substr用法及代码示例
- Python tf.strings.reduce_join用法及代码示例
- Python tf.sparse.cross用法及代码示例
- Python tf.sparse.mask用法及代码示例
- Python tf.strings.regex_full_match用法及代码示例
- Python tf.sparse.split用法及代码示例
- Python tf.strings.regex_replace用法及代码示例
- Python tf.signal.overlap_and_add用法及代码示例
- Python tf.strings.length用法及代码示例
- Python tf.strided_slice用法及代码示例
- Python tf.sparse.to_dense用法及代码示例
- Python tf.strings.bytes_split用法及代码示例
- Python tf.summary.text用法及代码示例
- Python tf.shape用法及代码示例
- Python tf.sparse.expand_dims用法及代码示例
注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.saved_model.load。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。