表示 RNN 单元的抽象对象。
用法
tf.keras.layers.AbstractRNNCell(
trainable=True, name=None, dtype=None, dynamic=False, **kwargs
)
属性
-
output_size
整数或 TensorShape:此单元产生的输出大小。 -
state_size
此单元格使用的状态大小。它可以由整数、TensorShape 或整数或 TensorShapes 的元组表示。
有关 RNN API 使用的详细信息,请参阅 Keras RNN API 指南。
这是实现具有自定义行为的 RNN 单元的基类。
每个 RNNCell
必须具有以下属性,并使用签名 call
实现 (output, next_state) = call(input, state)
。
例子:
class MinimalRNNCell(AbstractRNNCell):
def __init__(self, units, **kwargs):
self.units = units
super(MinimalRNNCell, self).__init__(**kwargs)
@property
def state_size(self):
return self.units
def build(self, input_shape):
self.kernel = self.add_weight(shape=(input_shape[-1], self.units),
initializer='uniform',
name='kernel')
self.recurrent_kernel = self.add_weight(
shape=(self.units, self.units),
initializer='uniform',
name='recurrent_kernel')
self.built = True
def call(self, inputs, states):
prev_output = states[0]
h = backend.dot(inputs, self.kernel)
output = h + backend.dot(prev_output, self.recurrent_kernel)
return output, output
细胞的这种定义不同于文献中使用的定义。在文献中,'cell' 指的是具有单个标量输出的对象。该定义是指此类单元的水平阵列。
在最抽象的设置中,RNN 单元是任何具有状态并执行接受输入矩阵的操作的任何东西。此操作产生一个带有self.output_size
列的输出矩阵。如果 self.state_size
是整数,则此操作还会生成一个新的状态矩阵,其中包含 self.state_size
列。如果 self.state_size
是 TensorShape 对象的(可能是嵌套的元组),那么它应该为 self.batch_size
中的每个 s
返回具有形状 [batch_size].concatenate(s)
的张量的匹配结构。
相关用法
- Python tf.keras.layers.Activation用法及代码示例
- Python tf.keras.layers.AveragePooling3D用法及代码示例
- Python tf.keras.layers.Attention用法及代码示例
- Python tf.keras.layers.AveragePooling2D用法及代码示例
- Python tf.keras.layers.Add用法及代码示例
- Python tf.keras.layers.Average用法及代码示例
- Python tf.keras.layers.AveragePooling1D用法及代码示例
- Python tf.keras.layers.AdditiveAttention用法及代码示例
- Python tf.keras.layers.InputLayer用法及代码示例
- Python tf.keras.layers.serialize用法及代码示例
- Python tf.keras.layers.Dropout用法及代码示例
- Python tf.keras.layers.maximum用法及代码示例
- Python tf.keras.layers.LayerNormalization用法及代码示例
- Python tf.keras.layers.Conv2D用法及代码示例
- Python tf.keras.layers.RepeatVector用法及代码示例
- Python tf.keras.layers.Multiply用法及代码示例
- Python tf.keras.layers.Conv1D用法及代码示例
- Python tf.keras.layers.experimental.preprocessing.PreprocessingLayer.adapt用法及代码示例
- Python tf.keras.layers.CategoryEncoding用法及代码示例
- Python tf.keras.layers.subtract用法及代码示例
注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.keras.layers.AbstractRNNCell。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。