當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python tf.compat.v1.keras.utils.get_or_create_layer用法及代碼示例


使用此方法以 shim-decorated 方法跟蹤嵌套的 keras 模型。

用法

tf.compat.v1.keras.utils.get_or_create_layer(
    name, create_layer_method
)

參數

  • name 為嵌套層提供跟蹤的名稱。
  • create_layer_method 一個不帶參數並返回嵌套層的 Callable。

返回

  • 創建的圖層。

此方法可在由 track_tf1_style_variables shim 修飾的 tf.keras.Layer 方法中使用,以另外跟蹤在同一方法中創建的內部 keras 模型對象。內部模型的變量和損失可以通過外部模型的variableslosses 屬性訪問。

這允許使用 TF2 行為跟蹤內部 keras 模型,而對現有 TF1 樣式代碼的更改最少。

例子:

class NestedLayer(tf.keras.layers.Layer):

  def __init__(self, units, *args, **kwargs):
    super().__init__(*args, **kwargs)
    self.units = units

  def build_model(self):
    inp = tf.keras.Input(shape=(5, 5))
    dense_layer = tf.keras.layers.Dense(
        10, name="dense", kernel_regularizer="l2",
        kernel_initializer=tf.compat.v1.ones_initializer())
    model = tf.keras.Model(inputs=inp, outputs=dense_layer(inp))
    return model

  @tf.compat.v1.keras.utils.track_tf1_style_variables
  def call(self, inputs):
    model = tf.compat.v1.keras.utils.get_or_create_layer(
        "dense_model", self.build_model)
    return model(inputs)

內部模型創建應僅限於其自己的zero-arg 函數,該函數應傳遞給此方法。在 TF1 中,此方法將立即創建並返回所需的模型,無需任何跟蹤。

相關用法


注:本文由純淨天空篩選整理自tensorflow.org大神的英文原創作品 tf.compat.v1.keras.utils.get_or_create_layer。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。