本文简要介绍python语言中 torch.nn.GRU
的用法。
用法:
class torch.nn.GRU(*args, **kwargs)
input_size-输入
x
中的预期特征数hidden_size-隐藏状态的特征数
h
num_layers-循环层数。例如,设置
num_layers=2
意味着将两个 GRU 堆叠在一起形成一个stacked GRU
,第二个 GRU 接收第一个 GRU 的输出并计算最终结果。默认值:1bias-如果
False
,则该层不使用偏置权重b_ih
和b_hh
。默认值:True
batch_first-如果
True
,则输入和输出张量提供为(batch, seq, feature)
而不是(seq, batch, feature)
。请注意,这不适用于隐藏或单元状态。有关详细信息,请参阅下面的输入/输出部分。默认值:False
dropout-如果非零,则在除最后一层之外的每个 GRU 层的输出上引入
Dropout
层,丢弃概率等于dropout
。默认值:0bidirectional-如果
True
,则成为双向 GRU。默认值:False
~GRU.weight_ih_l[k]-
(3*hidden_size, input_size)
用于k = 0
。否则,形状为(3*hidden_size, num_directions * hidden_size)
层 (W_ir|W_iz|W_in) 的可学习 input-hidden 权重,形状为~GRU.weight_hh_l[k]-
(3*hidden_size, hidden_size)
层 (W_hr|W_hz|W_hn) 的可学习 hidden-hidden 权重,形状为~GRU.bias_ih_l[k]-
(3*hidden_size)
层 (b_ir|b_iz|b_in) 的可学习 input-hidden 偏差,形状为~GRU.bias_hh_l[k]-
(3*hidden_size)
层 (b_hr|b_hz|b_hn) 的可学习 hidden-hidden 偏差,形状为
将多层门控循环单元 (GRU) RNN 应用于输入序列。
对于输入序列中的每个元素,每一层计算以下函数:
其中
t
的隐藏状态, 是时间t
的输入; 是层在时间t-1
的隐藏状态或在时间0
的初始隐藏状态,而 、 、 分别是重置、更新和新门。 是 sigmoid 函数, 是 Hadamard 积。 是时间在多层 GRU 中,第
dropout
。 层 ( ) 的输入 是前一层的隐藏状态 乘以 dropout 其中每个 是一个伯努利随机变量这是 概率为- 输入:输入,h_0
input: 形状张量
当batch_first=False
或者 当batch_first=True
包含输入序列的特征。输入也可以是打包的可变长度序列。看torch.nn.utils.rnn.pack_padded_sequence()
或者torch.nn.utils.rnn.pack_sequence详情。h_0: 形状张量
包含批次中每个元素的初始隐藏状态。如果未提供,则默认为零。
其中:
- 输出:输出,h_n
output: 形状张量
当batch_first=False
或者 当batch_first=True
包含输出特征(h_t)
从 GRU 的最后一层,对于每个t
.如果一个torch.nn.utils.rnn.PackedSequence
已作为输入给出,输出也将是一个打包序列。h_n: 形状张量
包含批次中每个元素的最终隐藏状态。
注意
所有的权重和偏差都是从 初始化的,其中
注意
对于双向 GRU,向前和向后分别是方向 0 和 1。
batch_first=False
:output.view(seq_len, batch, num_directions, hidden_size)
时拆分输出层的示例。- 孤儿
注意
如果满足以下条件:1) cudnn 已启用,2) 输入数据在 GPU 上 3) 输入数据具有 dtype
torch.float16
4) 使用 V100 GPU,5) 输入数据不是PackedSequence
格式的持久化算法可以选择以提高性能。例子:
>>> rnn = nn.GRU(10, 20, 2) >>> input = torch.randn(5, 3, 10) >>> h0 = torch.randn(2, 3, 20) >>> output, hn = rnn(input, h0)
参数:
变量:
相关用法
- Python PyTorch GRUCell用法及代码示例
- Python PyTorch Graph.eliminate_dead_code用法及代码示例
- Python PyTorch Generator.set_state用法及代码示例
- Python PyTorch GroupedPositionWeightedModule.named_parameters用法及代码示例
- Python PyTorch Graph.inserting_before用法及代码示例
- Python PyTorch GradScaler.unscale_用法及代码示例
- Python PyTorch Generator.seed用法及代码示例
- Python PyTorch GroupedPooledEmbeddingsLookup.named_buffers用法及代码示例
- Python PyTorch GLU用法及代码示例
- Python PyTorch Graph.inserting_after用法及代码示例
- Python PyTorch GroupNorm用法及代码示例
- Python PyTorch GDriveReader用法及代码示例
- Python PyTorch Gumbel用法及代码示例
- Python PyTorch Graph用法及代码示例
- Python PyTorch Generator.get_state用法及代码示例
- Python PyTorch GELU用法及代码示例
- Python PyTorch Generator.manual_seed用法及代码示例
- Python PyTorch Geometric用法及代码示例
- Python PyTorch GroupedPositionWeightedModule.named_buffers用法及代码示例
- Python PyTorch GroupedPooledEmbeddingsLookup.state_dict用法及代码示例
- Python PyTorch GriffinLim用法及代码示例
- Python PyTorch GroupedPooledEmbeddingsLookup.named_parameters用法及代码示例
- Python PyTorch Graph.node_copy用法及代码示例
- Python PyTorch GaussianNLLLoss用法及代码示例
- Python PyTorch Generator用法及代码示例
注:本文由纯净天空筛选整理自pytorch.org大神的英文原创作品 torch.nn.GRU。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。