當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python PyTorch LSTMCell用法及代碼示例


本文簡要介紹python語言中 torch.nn.LSTMCell 的用法。

用法:

class torch.nn.LSTMCell(input_size, hidden_size, bias=True, device=None, dtype=None)

參數

  • input_size-輸入x中的預期特征數

  • hidden_size-隱藏狀態的特征數h

  • bias-如果 False ,則該層不使用偏置權重 b_ihb_hh 。默認值:True

變量

  • ~LSTMCell.weight_ih(torch.Tensor) -可學習的 input-hidden 權重,形狀為 (4*hidden_size, input_size)

  • ~LSTMCell.weight_hh(torch.Tensor) -可學習的 hidden-hidden 權重,形狀為 (4*hidden_size, hidden_size)

  • ~LSTMCell.bias_ih-形狀為 (4*hidden_size) 的可學習 input-hidden 偏差

  • ~LSTMCell.bias_hh-形狀為 (4*hidden_size) 的可學習 hidden-hidden 偏差

長短期內存記憶(LSTM)單元

其中 是 sigmoid 函數, 是 Hadamard 積。

輸入:輸入,(h_0,c_0)
  • 輸入形狀的(batch, input_size):包含輸入特征的張量

  • h_0形狀的(batch, hidden_size):包含批次中每個元素的初始隱藏狀態的張量。

  • c_0形狀的(batch, hidden_size):包含批次中每個元素的初始單元狀態的張量。

    如果(h_0, c_0)不提供,兩者h_0c_0默認為零。

輸出:(h_1,c_1)
  • h_1形狀的(batch, hidden_size):包含批次中每個元素的下一個隱藏狀態的張量

  • c_1形狀的(batch, hidden_size): 包含批次中每個元素的下一個單元狀態的張量

注意

所有的權重和偏差都是從 初始化的,其中

例子:

>>> rnn = nn.LSTMCell(10, 20) # (input_size, hidden_size)
>>> input = torch.randn(2, 3, 10) # (time_steps, batch, input_size)
>>> hx = torch.randn(3, 20) # (batch, hidden_size)
>>> cx = torch.randn(3, 20)
>>> output = []
>>> for i in range(input.size()[0]):
        hx, cx = rnn(input[i], (hx, cx))
        output.append(hx)
>>> output = torch.stack(output, dim=0)

相關用法


注:本文由純淨天空篩選整理自pytorch.org大神的英文原創作品 torch.nn.LSTMCell。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。