當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python tf.keras.layers.AbstractRNNCell用法及代碼示例


表示 RNN 單元的抽象對象。

繼承自:LayerModule

用法

tf.keras.layers.AbstractRNNCell(
    trainable=True, name=None, dtype=None, dynamic=False, **kwargs
)

屬性

  • output_size 整數或 TensorShape:此單元產生的輸出大小。
  • state_size 此單元格使用的狀態大小。

    它可以由整數、TensorShape 或整數或 TensorShapes 的元組表示。

有關 RNN API 使用的詳細信息,請參閱 Keras RNN API 指南。

這是實現具有自定義行為的 RNN 單元的基類。

每個 RNNCell 必須具有以下屬性,並使用簽名 call 實現 (output, next_state) = call(input, state)

例子:

class MinimalRNNCell(AbstractRNNCell):

    def __init__(self, units, **kwargs):
      self.units = units
      super(MinimalRNNCell, self).__init__(**kwargs)

    @property
    def state_size(self):
      return self.units

    def build(self, input_shape):
      self.kernel = self.add_weight(shape=(input_shape[-1], self.units),
                                    initializer='uniform',
                                    name='kernel')
      self.recurrent_kernel = self.add_weight(
          shape=(self.units, self.units),
          initializer='uniform',
          name='recurrent_kernel')
      self.built = True

    def call(self, inputs, states):
      prev_output = states[0]
      h = backend.dot(inputs, self.kernel)
      output = h + backend.dot(prev_output, self.recurrent_kernel)
      return output, output

細胞的這種定義不同於文獻中使用的定義。在文獻中,'cell' 指的是具有單個標量輸出的對象。該定義是指此類單元的水平陣列。

在最抽象的設置中,RNN 單元是任何具有狀態並執行接受輸入矩陣的操作的任何東西。此操作產生一個帶有self.output_size 列的輸出矩陣。如果 self.state_size 是整數,則此操作還會生成一個新的狀態矩陣,其中包含 self.state_size 列。如果 self.state_size 是 TensorShape 對象的(可能是嵌套的元組),那麽它應該為 self.batch_size 中的每個 s 返回具有形狀 [batch_size].concatenate(s) 的張量的匹配結構。

相關用法


注:本文由純淨天空篩選整理自tensorflow.org大神的英文原創作品 tf.keras.layers.AbstractRNNCell。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。