導出 tf.Module(和子類)obj
到SavedModel 格式.
用法
tf.saved_model.save(
obj, export_dir, signatures=None, options=None
)
參數
-
obj
要導出的可跟蹤對象(例如 tf.Module 或 tf.train.Checkpoint)。 -
export_dir
用於編寫 SavedModel 的目錄。 -
signatures
可選,三種類型之一:- 帶有指定輸入簽名的
tf.function
,它將使用默認的服務簽名 key , f.get_concrete_function
在@tf.function
裝飾函數f
上的結果,在這種情況下,f
將用於在默認服務簽名 key 下為 SavedModel 生成簽名,- 一個字典,它將簽名鍵映射到帶有輸入簽名或具體函數的
tf.function
實例。這種字典的鍵可以是任意字符串,但通常來自tf.saved_model.signature_constants
模塊。
- 帶有指定輸入簽名的
-
options
tf.saved_model.SaveOptions
對象用於配置保存選項。
拋出
-
ValueError
如果obj
不可追蹤。
obj
必須繼承自 Trackable 類。
示例用法:
class Adder(tf.Module):
@tf.function(input_signature=[tf.TensorSpec(shape=[], dtype=tf.float32)])
def add(self, x):
return x + x
model = Adder()
tf.saved_model.save(model, '/tmp/adder')
然後,生成的 SavedModel 可以使用名為 "x" 的輸入進行服務,這是一個 dtype 為 float32 的標量。
簽名
簽名定義計算的輸入和輸出類型。可選的 save signatures
參數控製 obj
中的哪些方法可用於消耗 SavedModel
的程序,例如服務 API。 Python 函數可以用 @tf.function(input_signature=...)
裝飾並作為簽名直接傳遞,或者在用 @tf.function
裝飾的方法上懶惰地調用 get_concrete_function
。
例子:
class Adder(tf.Module):
@tf.function
def add(self, x):
return x + x
model = Adder()
tf.saved_model.save(
model, '/tmp/adder',signatures=model.add.get_concrete_function(
tf.TensorSpec([], tf.float32)))
如果 @tf.function
沒有輸入簽名並且沒有在該方法上調用 get_concrete_function
,則該函數將無法在恢複的 SavedModel 中直接調用。
例子:
class Adder(tf.Module):
@tf.function
def add(self, x):
return x + x
model = Adder()
tf.saved_model.save(model, '/tmp/adder')
restored = tf.saved_model.load('/tmp/adder')
restored.add(1.)
Traceback (most recent call last):
ValueError:Found zero restored functions for caller function.
如果省略signatures
參數,則將在obj
中搜索@tf.function
修飾的方法。如果恰好找到一個跟蹤的@tf.function
,則該方法將用作 SavedModel 的默認簽名。否則,任何附加到 obj
的 @tf.function
或其依賴項都將被導出以用於 tf.saved_model.load
。
在導出的 SavedModel 中調用簽名時,Tensor
參數由名稱標識。默認情況下,這些名稱將來自 Python 函數的參數名稱。它們可以通過在相應的 tf.TensorSpec
對象中指定 name=...
參數來覆蓋。如果多個 Tensor
通過單個參數傳遞給 Python 函數,則需要顯式命名。
用作 signatures
的函數的輸出必須是平麵列表,在這種情況下輸出將被編號,或者是映射字符串鍵到 Tensor
的字典,在這種情況下,鍵將用於命名輸出。
簽名在 tf.saved_model.load
作為 .signatures
屬性返回的對象中可用。這是一個保留屬性:具有自定義 .signatures
屬性的對象上的 tf.saved_model.save
將引發異常。
_將 tf.savedmodel.save
與 Keras 模型一起使用
雖然 Keras 有自己的保存和加載 API,但此函數可用於導出 Keras 模型。例如,使用指定的簽名導出:
class Adder(tf.keras.Model):
@tf.function(input_signature=[tf.TensorSpec(shape=[], dtype=tf.string)])
def concat(self, x):
return x + x
model = Adder()
tf.saved_model.save(model, '/tmp/adder')
從沒有固定簽名的函數導出:
class Adder(tf.keras.Model):
@tf.function
def concat(self, x):
return x + x
model = Adder()
tf.saved_model.save(
model, '/tmp/adder',
signatures=model.concat.get_concrete_function(
tf.TensorSpec(shape=[], dtype=tf.string, name="string_input")))
從輸入和輸出構造的 tf.keras.Model
實例已經具有簽名,因此不需要 @tf.function
裝飾器或 signatures
參數。如果兩者均未指定,則導出模型的正向傳遞。
x = tf.keras.layers.Input((4,), name="x")
y = tf.keras.layers.Dense(5, name="out")(x)
model = tf.keras.Model(x, y)
tf.saved_model.save(model, '/tmp/saved_model/')
導出的 SavedModel 采用形狀為 [None, 4] 的 "x" 並返回形狀為 [None, 5] 的 "out"
變量和檢查點
必須通過將變量分配給被跟蹤對象的屬性或直接分配給obj
的屬性來跟蹤變量。 TensorFlow 對象(例如來自 tf.keras.layers
的層,來自 tf.train
的優化器)自動跟蹤它們的變量。這與tf.train.Checkpoint
使用的跟蹤方案相同,導出的Checkpoint
對象可以通過將tf.train.Checkpoint.restore
指向SavedModel 的"variables/" 子目錄來恢複為訓練檢查點。
tf.function
不使用函數體外部的hard-code 設備注釋,而不是使用調用上下文的設備。這意味著例如導出在 GPU 上運行的模型並在 CPU 上提供服務通常可以工作,但有一些異常:
- 函數體內的
tf.device
注解在導出的模型中將是hard-coded;不鼓勵使用這種類型的注釋。 - Device-specific 操作,例如名稱中帶有 "cuDNN" 或帶有 device-specific 布局可能會導致問題。
- 對於
ConcreteFunctions
,主動分布策略將導致設備放置在函數中為 hard-coded。
導出的 SavedModelstf.saved_model.save 剝離 default-valued 屬性自動,當 SavedModel 的消費者運行的 TensorFlow 版本比生產者更舊時,它會消除一個不兼容的來源。然而,還有其他無法自動處理的不兼容來源,例如當導出的模型包含消費者沒有定義的操作時。
eager模式兼容性
圖形構建時沒有得到很好的支持。從 TensorFlow 1.x 開始,tf.compat.v1.enable_eager_execution()
應該首先運行。從 TensorFlow 1.x 構建圖形時,在循環中調用 tf.saved_model.save 將在每次迭代時向默認圖形添加新的保存操作。
不能從函數體內調用。
相關用法
- Python tf.saved_model.load用法及代碼示例
- Python tf.saved_model.Asset用法及代碼示例
- Python tf.saved_model.SaveOptions用法及代碼示例
- Python tf.saved_model.experimental.TrackableResource用法及代碼示例
- Python tf.summary.scalar用法及代碼示例
- Python tf.strings.substr用法及代碼示例
- Python tf.strings.reduce_join用法及代碼示例
- Python tf.sparse.cross用法及代碼示例
- Python tf.sparse.mask用法及代碼示例
- Python tf.strings.regex_full_match用法及代碼示例
- Python tf.sparse.split用法及代碼示例
- Python tf.strings.regex_replace用法及代碼示例
- Python tf.signal.overlap_and_add用法及代碼示例
- Python tf.strings.length用法及代碼示例
- Python tf.strided_slice用法及代碼示例
- Python tf.sparse.to_dense用法及代碼示例
- Python tf.strings.bytes_split用法及代碼示例
- Python tf.summary.text用法及代碼示例
- Python tf.shape用法及代碼示例
- Python tf.sparse.expand_dims用法及代碼示例
注:本文由純淨天空篩選整理自tensorflow.org大神的英文原創作品 tf.saved_model.save。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。