当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python tf.keras.layers.AbstractRNNCell用法及代码示例


表示 RNN 单元的抽象对象。

继承自:LayerModule

用法

tf.keras.layers.AbstractRNNCell(
    trainable=True, name=None, dtype=None, dynamic=False, **kwargs
)

属性

  • output_size 整数或 TensorShape:此单元产生的输出大小。
  • state_size 此单元格使用的状态大小。

    它可以由整数、TensorShape 或整数或 TensorShapes 的元组表示。

有关 RNN API 使用的详细信息,请参阅 Keras RNN API 指南。

这是实现具有自定义行为的 RNN 单元的基类。

每个 RNNCell 必须具有以下属性,并使用签名 call 实现 (output, next_state) = call(input, state)

例子:

class MinimalRNNCell(AbstractRNNCell):

    def __init__(self, units, **kwargs):
      self.units = units
      super(MinimalRNNCell, self).__init__(**kwargs)

    @property
    def state_size(self):
      return self.units

    def build(self, input_shape):
      self.kernel = self.add_weight(shape=(input_shape[-1], self.units),
                                    initializer='uniform',
                                    name='kernel')
      self.recurrent_kernel = self.add_weight(
          shape=(self.units, self.units),
          initializer='uniform',
          name='recurrent_kernel')
      self.built = True

    def call(self, inputs, states):
      prev_output = states[0]
      h = backend.dot(inputs, self.kernel)
      output = h + backend.dot(prev_output, self.recurrent_kernel)
      return output, output

细胞的这种定义不同于文献中使用的定义。在文献中,'cell' 指的是具有单个标量输出的对象。该定义是指此类单元的水平阵列。

在最抽象的设置中,RNN 单元是任何具有状态并执行接受输入矩阵的操作的任何东西。此操作产生一个带有self.output_size 列的输出矩阵。如果 self.state_size 是整数,则此操作还会生成一个新的状态矩阵,其中包含 self.state_size 列。如果 self.state_size 是 TensorShape 对象的(可能是嵌套的元组),那么它应该为 self.batch_size 中的每个 s 返回具有形状 [batch_size].concatenate(s) 的张量的匹配结构。

相关用法


注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.keras.layers.AbstractRNNCell。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。