当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python tf.compat.v1.layers.experimental.keras_style_scope用法及代码示例


使用Keras-style 变量管理。

用法

@tf_contextlib.contextmanager
tf.compat.v1.layers.experimental.keras_style_scope()

生成(Yield)

  • keras 图层样式范围。

在此范围内创建的所有 tf​​.layers 和 tf RNN 单元都使用Keras-style 变量管理。不允许使用 scope= 参数创建此类层,并且不允许使用 reuse=True。

此范围的目的是允许现有层的用户在不破坏现有函数的情况下缓慢过渡到 Keras 层 API。

其中一个示例是在将 TensorFlow 的 RNN 类与 Keras 模型或网络一起使用时。由于 Keras 模型没有正确设置变量范围,RNN 的用户可能会意外地在两个不同模型之间共享范围,或者得到关于已经存在的变量的错误。

例子:

class RNNModel(tf.keras.Model):

  def __init__(self, name):
    super(RNNModel, self).__init__(name=name)
    self.rnn = tf.compat.v1.nn.rnn_cell.MultiRNNCell(
      [tf.compat.v1.nn.rnn_cell.LSTMCell(64) for _ in range(2)])

  def call(self, input, state):
    return self.rnn(input, state)

model_1 = RNNModel("model_1")
model_2 = RNNModel("model_2")

# OK
output_1, next_state_1 = model_1(input, state)
# Raises an error about trying to create an already existing variable.
output_2, next_state_2 = model_2(input, state)

解决方案是将模型构造和执行包装在 keras-style 范围内:

with keras_style_scope():
  model_1 = RNNModel("model_1")
  model_2 = RNNModel("model_2")

  # model_1 and model_2 are guaranteed to create their own variables.
  output_1, next_state_1 = model_1(input, state)
  output_2, next_state_2 = model_2(input, state)

  assert len(model_1.weights) > 0
  assert len(model_2.weights) > 0
  assert(model_1.weights != model_2.weights)

相关用法


注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.compat.v1.layers.experimental.keras_style_scope。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。