当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python PyTorch pad_sequence用法及代码示例


本文简要介绍python语言中 torch.nn.utils.rnn.pad_sequence 的用法。

用法:

torch.nn.utils.rnn.pad_sequence(sequences, batch_first=False, padding_value=0.0)

参数

  • sequences(list[Tensor]) -可变长度序列列表。

  • batch_first(bool,可选的) -如果为真,输出将在 B x T x * 中,否则在 T x B x *

  • padding_value(float,可选的) -填充元素的值。默认值:0。

返回

如果 batch_firstFalse ,则张量大小为 T x B x * 。大小为 B x T x * 的张量,否则

padding_value填充可变长度张量列表

pad_sequence 沿新维度堆叠张量列表,并将它们填充到相等的长度。例如,如果输入是大小为 L x * 的序列列表,并且如果 batch_first 为 False,否则为 T x B x *

B 是批量大小。它等于 sequences 中的元素数。 T 是最长序列的长度。 L 是序列的长度。 * 是任意数量的尾随维度,包括无。

示例

>>> from torch.nn.utils.rnn import pad_sequence
>>> a = torch.ones(25, 300)
>>> b = torch.ones(22, 300)
>>> c = torch.ones(15, 300)
>>> pad_sequence([a, b, c]).size()
torch.Size([25, 3, 300])

注意

此函数返回大小为 T x B x *B x T x * 的张量,其中 T 是最长序列的长度。该函数假设序列中所有张量的尾随维度和类型相同。

相关用法


注:本文由纯净天空筛选整理自pytorch.org大神的英文原创作品 torch.nn.utils.rnn.pad_sequence。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。