当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python tf.compat.v1.keras.utils.get_or_create_layer用法及代码示例


使用此方法以 shim-decorated 方法跟踪嵌套的 keras 模型。

用法

tf.compat.v1.keras.utils.get_or_create_layer(
    name, create_layer_method
)

参数

  • name 为嵌套层提供跟踪的名称。
  • create_layer_method 一个不带参数并返回嵌套层的 Callable。

返回

  • 创建的图层。

此方法可在由 track_tf1_style_variables shim 修饰的 tf.keras.Layer 方法中使用,以另外跟踪在同一方法中创建的内部 keras 模型对象。内部模型的变量和损失可以通过外部模型的variableslosses 属性访问。

这允许使用 TF2 行为跟踪内部 keras 模型,而对现有 TF1 样式代码的更改最少。

例子:

class NestedLayer(tf.keras.layers.Layer):

  def __init__(self, units, *args, **kwargs):
    super().__init__(*args, **kwargs)
    self.units = units

  def build_model(self):
    inp = tf.keras.Input(shape=(5, 5))
    dense_layer = tf.keras.layers.Dense(
        10, name="dense", kernel_regularizer="l2",
        kernel_initializer=tf.compat.v1.ones_initializer())
    model = tf.keras.Model(inputs=inp, outputs=dense_layer(inp))
    return model

  @tf.compat.v1.keras.utils.track_tf1_style_variables
  def call(self, inputs):
    model = tf.compat.v1.keras.utils.get_or_create_layer(
        "dense_model", self.build_model)
    return model(inputs)

内部模型创建应仅限于其自己的zero-arg 函数,该函数应传递给此方法。在 TF1 中,此方法将立即创建并返回所需的模型,无需任何跟踪。

相关用法


注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.compat.v1.keras.utils.get_or_create_layer。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。