當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python PyTorch pad_sequence用法及代碼示例


本文簡要介紹python語言中 torch.nn.utils.rnn.pad_sequence 的用法。

用法:

torch.nn.utils.rnn.pad_sequence(sequences, batch_first=False, padding_value=0.0)

參數

  • sequences(list[Tensor]) -可變長度序列列表。

  • batch_first(bool,可選的) -如果為真,輸出將在 B x T x * 中,否則在 T x B x *

  • padding_value(float,可選的) -填充元素的值。默認值:0。

返回

如果 batch_firstFalse ,則張量大小為 T x B x * 。大小為 B x T x * 的張量,否則

padding_value填充可變長度張量列表

pad_sequence 沿新維度堆疊張量列表,並將它們填充到相等的長度。例如,如果輸入是大小為 L x * 的序列列表,並且如果 batch_first 為 False,否則為 T x B x *

B 是批量大小。它等於 sequences 中的元素數。 T 是最長序列的長度。 L 是序列的長度。 * 是任意數量的尾隨維度,包括無。

示例

>>> from torch.nn.utils.rnn import pad_sequence
>>> a = torch.ones(25, 300)
>>> b = torch.ones(22, 300)
>>> c = torch.ones(15, 300)
>>> pad_sequence([a, b, c]).size()
torch.Size([25, 3, 300])

注意

此函數返回大小為 T x B x *B x T x * 的張量,其中 T 是最長序列的長度。該函數假設序列中所有張量的尾隨維度和類型相同。

相關用法


注:本文由純淨天空篩選整理自pytorch.org大神的英文原創作品 torch.nn.utils.rnn.pad_sequence。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。