當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python tf.saved_model.save用法及代碼示例


導出 tf.Module(和子類)objSavedModel 格式.

用法

tf.saved_model.save(
    obj, export_dir, signatures=None, options=None
)

參數

  • obj 要導出的可跟蹤對象(例如 tf.Module 或 tf.train.Checkpoint)。
  • export_dir 用於編寫 SavedModel 的目錄。
  • signatures 可選,三種類型之一:
    • 帶有指定輸入簽名的tf.function,它將使用默認的服務簽名 key ,
    • f.get_concrete_function@tf.function 裝飾函數 f 上的結果,在這種情況下,f 將用於在默認服務簽名 key 下為 SavedModel 生成簽名,
    • 一個字典,它將簽名鍵映射到帶有輸入簽名或具體函數的 tf.function 實例。這種字典的鍵可以是任意字符串,但通常來自tf.saved_model.signature_constants 模塊。
  • options tf.saved_model.SaveOptions 對象用於配置保存選項。

拋出

  • ValueError 如果obj 不可追蹤。

obj 必須繼承自 Trackable 類。

示例用法:

class Adder(tf.Module):
  @tf.function(input_signature=[tf.TensorSpec(shape=[], dtype=tf.float32)])
  def add(self, x):
    return x + x
model = Adder()
tf.saved_model.save(model, '/tmp/adder')

然後,生成的 SavedModel 可以使用名為 "x" 的輸入進行服務,這是一個 dtype 為 float32 的標量。

簽名

簽名定義計算的輸入和輸出類型。可選的 save signatures 參數控製 obj 中的哪些方法可用於消耗 SavedModel 的程序,例如服務 API。 Python 函數可以用 @tf.function(input_signature=...) 裝飾並作為簽名直接傳遞,或者在用 @tf.function 裝飾的方法上懶惰地調用 get_concrete_function

例子:

class Adder(tf.Module):
  @tf.function
  def add(self, x):
    return x + x
model = Adder()
tf.saved_model.save(
  model, '/tmp/adder',signatures=model.add.get_concrete_function(
    tf.TensorSpec([], tf.float32)))

如果 @tf.function 沒有輸入簽名並且沒有在該方法上調用 get_concrete_function,則該函數將無法在恢複的 SavedModel 中直接調用。

例子:

class Adder(tf.Module):
  @tf.function
  def add(self, x):
    return x + x
model = Adder()
tf.saved_model.save(model, '/tmp/adder')
restored = tf.saved_model.load('/tmp/adder')
restored.add(1.)
Traceback (most recent call last):

ValueError:Found zero restored functions for caller function.

如果省略signatures 參數,則將在obj 中搜索@tf.function 修飾的方法。如果恰好找到一個跟蹤的@tf.function,則該方法將用作 SavedModel 的默認簽名。否則,任何附加到 obj@tf.function 或其依賴項都將被導出以用於 tf.saved_model.load

在導出的 SavedModel 中調用簽名時,Tensor 參數由名稱標識。默認情況下,這些名稱將來自 Python 函數的參數名稱。它們可以通過在相應的 tf.TensorSpec 對象中指定 name=... 參數來覆蓋。如果多個 Tensor 通過單個參數傳遞給 Python 函數,則需要顯式命名。

用作 signatures 的函數的輸出必須是平麵列表,在這種情況下輸出將被編號,或者是映射字符串鍵到 Tensor 的字典,在這種情況下,鍵將用於命名輸出。

簽名在 tf.saved_model.load 作為 .signatures 屬性返回的對象中可用。這是一個保留屬性:具有自定義 .signatures 屬性的對象上的 tf.saved_model.save 將引發異常。

_將 tf.savedmodel.save 與 Keras 模型一起使用

雖然 Keras 有自己的保存和加載 API,但此函數可用於導出 Keras 模型。例如,使用指定的簽名導出:

class Adder(tf.keras.Model):
  @tf.function(input_signature=[tf.TensorSpec(shape=[], dtype=tf.string)])
  def concat(self, x):
     return x + x
model = Adder()
tf.saved_model.save(model, '/tmp/adder')

從沒有固定簽名的函數導出:

class Adder(tf.keras.Model):
  @tf.function
  def concat(self, x):
     return x + x
model = Adder()
tf.saved_model.save(
  model, '/tmp/adder',
  signatures=model.concat.get_concrete_function(
    tf.TensorSpec(shape=[], dtype=tf.string, name="string_input")))

從輸入和輸出構造的 tf.keras.Model 實例已經具有簽名,因此不需要 @tf.function 裝飾器或 signatures 參數。如果兩者均未指定,則導出模型的正向傳遞。

x = tf.keras.layers.Input((4,), name="x")
y = tf.keras.layers.Dense(5, name="out")(x)
model = tf.keras.Model(x, y)
tf.saved_model.save(model, '/tmp/saved_model/')

導出的 SavedModel 采用形狀為 [None, 4] 的 "x" 並返回形狀為 [None, 5] 的 "out"

變量和檢查點

必須通過將變量分配給被跟蹤對象的屬性或直接分配給obj 的屬性來跟蹤變量。 TensorFlow 對象(例如來自 tf.keras.layers 的層,來自 tf.train 的優化器)自動跟蹤它們的變量。這與tf.train.Checkpoint 使用的跟蹤方案相同,導出的Checkpoint 對象可以通過將tf.train.Checkpoint.restore 指向SavedModel 的"variables/" 子目錄來恢複為訓練檢查點。

tf.function 不使用函數體外部的hard-code 設備注釋,而不是使用調用上下文的設備。這意味著例如導出在 GPU 上運行的模型並在 CPU 上提供服務通常可以工作,但有一些異常:

  • 函數體內的tf.device注解在導出的模型中將是hard-coded;不鼓勵使用這種類型的注釋。
  • Device-specific 操作,例如名稱中帶有 "cuDNN" 或帶有 device-specific 布局可能會導致問題。
  • 對於 ConcreteFunctions ,主動分布策略將導致設備放置在函數中為 hard-coded。

導出的 SavedModelstf.saved_model.save 剝離 default-valued 屬性自動,當 SavedModel 的消費者運行的 TensorFlow 版本比生產者更舊時,它會消除一個不兼容的來源。然而,還有其他無法自動處理的不兼容來源,例如當導出的模型包含消費者沒有定義的操作時。

eager模式兼容性

圖形構建時沒有得到很好的支持。從 TensorFlow 1.x 開始,tf.compat.v1.enable_eager_execution() 應該首先運行。從 TensorFlow 1.x 構建圖形時,在循環中調用 tf.saved_model.save 將在每次迭代時向默認圖形添加新的保存操作。

不能從函數體內調用。

相關用法


注:本文由純淨天空篩選整理自tensorflow.org大神的英文原創作品 tf.saved_model.save。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。