當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python PyTorch RNNCell用法及代碼示例


本文簡要介紹python語言中 torch.nn.RNNCell 的用法。

用法:

class torch.nn.RNNCell(input_size, hidden_size, bias=True, nonlinearity='tanh', device=None, dtype=None)

參數

  • input_size-輸入x中的預期特征數

  • hidden_size-隱藏狀態的特征數h

  • bias-如果 False ,則該層不使用偏置權重 b_ihb_hh 。默認值:True

  • nonlinearity-使用的非線性。可以是 'tanh''relu' 。默認值:'tanh'

變量

  • ~RNNCell.weight_ih(torch.Tensor) -可學習的 input-hidden 權重,形狀為 (hidden_size, input_size)

  • ~RNNCell.weight_hh(torch.Tensor) -可學習的 hidden-hidden 權重,形狀為 (hidden_size, hidden_size)

  • ~RNNCell.bias_ih-形狀為 (hidden_size) 的可學習 input-hidden 偏差

  • ~RNNCell.bias_hh-形狀為 (hidden_size) 的可學習 hidden-hidden 偏差

具有 tanh 或 ReLU 非線性的 Elman RNN 單元。

如果 nonlinearity‘relu’ ,則使用 ReLU 代替 tanh。

輸入:輸入,隱藏
  • 輸入形狀的(batch, input_size):包含輸入特征的張量

  • 形狀的(batch, hidden_size):包含批次中每個元素的初始隱藏狀態的張量。如果未提供,則默認為零。

輸出:h'
  • H'形狀的(batch, hidden_size):包含批次中每個元素的下一個隱藏狀態的張量

形狀:
  • 輸入1: 張量包含輸入特征,其中 = input_size

  • 輸入 2: 張量,包含批次中每個元素的初始隱藏狀態,其中 = hidden_size 如果未提供,則默認為零。

  • 輸出: 張量,包含批次中每個元素的下一個隱藏狀態

注意

所有的權重和偏差都是從 初始化的,其中

例子:

>>> rnn = nn.RNNCell(10, 20)
>>> input = torch.randn(6, 3, 10)
>>> hx = torch.randn(3, 20)
>>> output = []
>>> for i in range(6):
        hx = rnn(input[i], hx)
        output.append(hx)

相關用法


注:本文由純淨天空篩選整理自pytorch.org大神的英文原創作品 torch.nn.RNNCell。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。