导出 tf.Module(和子类)obj
到SavedModel 格式.
用法
tf.saved_model.save(
obj, export_dir, signatures=None, options=None
)
参数
-
obj
要导出的可跟踪对象(例如 tf.Module 或 tf.train.Checkpoint)。 -
export_dir
用于编写 SavedModel 的目录。 -
signatures
可选,三种类型之一:- 带有指定输入签名的
tf.function
,它将使用默认的服务签名 key , f.get_concrete_function
在@tf.function
装饰函数f
上的结果,在这种情况下,f
将用于在默认服务签名 key 下为 SavedModel 生成签名,- 一个字典,它将签名键映射到带有输入签名或具体函数的
tf.function
实例。这种字典的键可以是任意字符串,但通常来自tf.saved_model.signature_constants
模块。
- 带有指定输入签名的
-
options
tf.saved_model.SaveOptions
对象用于配置保存选项。
抛出
-
ValueError
如果obj
不可追踪。
obj
必须继承自 Trackable 类。
示例用法:
class Adder(tf.Module):
@tf.function(input_signature=[tf.TensorSpec(shape=[], dtype=tf.float32)])
def add(self, x):
return x + x
model = Adder()
tf.saved_model.save(model, '/tmp/adder')
然后,生成的 SavedModel 可以使用名为 "x" 的输入进行服务,这是一个 dtype 为 float32 的标量。
签名
签名定义计算的输入和输出类型。可选的 save signatures
参数控制 obj
中的哪些方法可用于消耗 SavedModel
的程序,例如服务 API。 Python 函数可以用 @tf.function(input_signature=...)
装饰并作为签名直接传递,或者在用 @tf.function
装饰的方法上懒惰地调用 get_concrete_function
。
例子:
class Adder(tf.Module):
@tf.function
def add(self, x):
return x + x
model = Adder()
tf.saved_model.save(
model, '/tmp/adder',signatures=model.add.get_concrete_function(
tf.TensorSpec([], tf.float32)))
如果 @tf.function
没有输入签名并且没有在该方法上调用 get_concrete_function
,则该函数将无法在恢复的 SavedModel 中直接调用。
例子:
class Adder(tf.Module):
@tf.function
def add(self, x):
return x + x
model = Adder()
tf.saved_model.save(model, '/tmp/adder')
restored = tf.saved_model.load('/tmp/adder')
restored.add(1.)
Traceback (most recent call last):
ValueError:Found zero restored functions for caller function.
如果省略signatures
参数,则将在obj
中搜索@tf.function
修饰的方法。如果恰好找到一个跟踪的@tf.function
,则该方法将用作 SavedModel 的默认签名。否则,任何附加到 obj
的 @tf.function
或其依赖项都将被导出以用于 tf.saved_model.load
。
在导出的 SavedModel 中调用签名时,Tensor
参数由名称标识。默认情况下,这些名称将来自 Python 函数的参数名称。它们可以通过在相应的 tf.TensorSpec
对象中指定 name=...
参数来覆盖。如果多个 Tensor
通过单个参数传递给 Python 函数,则需要显式命名。
用作 signatures
的函数的输出必须是平面列表,在这种情况下输出将被编号,或者是映射字符串键到 Tensor
的字典,在这种情况下,键将用于命名输出。
签名在 tf.saved_model.load
作为 .signatures
属性返回的对象中可用。这是一个保留属性:具有自定义 .signatures
属性的对象上的 tf.saved_model.save
将引发异常。
_将 tf.savedmodel.save
与 Keras 模型一起使用
虽然 Keras 有自己的保存和加载 API,但此函数可用于导出 Keras 模型。例如,使用指定的签名导出:
class Adder(tf.keras.Model):
@tf.function(input_signature=[tf.TensorSpec(shape=[], dtype=tf.string)])
def concat(self, x):
return x + x
model = Adder()
tf.saved_model.save(model, '/tmp/adder')
从没有固定签名的函数导出:
class Adder(tf.keras.Model):
@tf.function
def concat(self, x):
return x + x
model = Adder()
tf.saved_model.save(
model, '/tmp/adder',
signatures=model.concat.get_concrete_function(
tf.TensorSpec(shape=[], dtype=tf.string, name="string_input")))
从输入和输出构造的 tf.keras.Model
实例已经具有签名,因此不需要 @tf.function
装饰器或 signatures
参数。如果两者均未指定,则导出模型的正向传递。
x = tf.keras.layers.Input((4,), name="x")
y = tf.keras.layers.Dense(5, name="out")(x)
model = tf.keras.Model(x, y)
tf.saved_model.save(model, '/tmp/saved_model/')
导出的 SavedModel 采用形状为 [None, 4] 的 "x" 并返回形状为 [None, 5] 的 "out"
变量和检查点
必须通过将变量分配给被跟踪对象的属性或直接分配给obj
的属性来跟踪变量。 TensorFlow 对象(例如来自 tf.keras.layers
的层,来自 tf.train
的优化器)自动跟踪它们的变量。这与tf.train.Checkpoint
使用的跟踪方案相同,导出的Checkpoint
对象可以通过将tf.train.Checkpoint.restore
指向SavedModel 的"variables/" 子目录来恢复为训练检查点。
tf.function
不使用函数体外部的hard-code 设备注释,而不是使用调用上下文的设备。这意味着例如导出在 GPU 上运行的模型并在 CPU 上提供服务通常可以工作,但有一些异常:
- 函数体内的
tf.device
注解在导出的模型中将是hard-coded;不鼓励使用这种类型的注释。 - Device-specific 操作,例如名称中带有 "cuDNN" 或带有 device-specific 布局可能会导致问题。
- 对于
ConcreteFunctions
,主动分布策略将导致设备放置在函数中为 hard-coded。
导出的 SavedModelstf.saved_model.save 剥离 default-valued 属性自动,当 SavedModel 的消费者运行的 TensorFlow 版本比生产者更旧时,它会消除一个不兼容的来源。然而,还有其他无法自动处理的不兼容来源,例如当导出的模型包含消费者没有定义的操作时。
eager模式兼容性
图形构建时没有得到很好的支持。从 TensorFlow 1.x 开始,tf.compat.v1.enable_eager_execution()
应该首先运行。从 TensorFlow 1.x 构建图形时,在循环中调用 tf.saved_model.save 将在每次迭代时向默认图形添加新的保存操作。
不能从函数体内调用。
相关用法
- Python tf.saved_model.load用法及代码示例
- Python tf.saved_model.Asset用法及代码示例
- Python tf.saved_model.SaveOptions用法及代码示例
- Python tf.saved_model.experimental.TrackableResource用法及代码示例
- Python tf.summary.scalar用法及代码示例
- Python tf.strings.substr用法及代码示例
- Python tf.strings.reduce_join用法及代码示例
- Python tf.sparse.cross用法及代码示例
- Python tf.sparse.mask用法及代码示例
- Python tf.strings.regex_full_match用法及代码示例
- Python tf.sparse.split用法及代码示例
- Python tf.strings.regex_replace用法及代码示例
- Python tf.signal.overlap_and_add用法及代码示例
- Python tf.strings.length用法及代码示例
- Python tf.strided_slice用法及代码示例
- Python tf.sparse.to_dense用法及代码示例
- Python tf.strings.bytes_split用法及代码示例
- Python tf.summary.text用法及代码示例
- Python tf.shape用法及代码示例
- Python tf.sparse.expand_dims用法及代码示例
注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.saved_model.save。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。