当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python PyTorch pad_packed_sequence用法及代码示例


本文简要介绍python语言中 torch.nn.utils.rnn.pad_packed_sequence 的用法。

用法:

torch.nn.utils.rnn.pad_packed_sequence(sequence, batch_first=False, padding_value=0.0, total_length=None)

参数

  • sequence(PackedSequence) -批量填充

  • batch_first(bool,可选的) -如果 True ,输出将采用 B x T x * 格式。

  • padding_value(float,可选的) -填充元素的值。

  • total_length(int,可选的) -如果不是 None ,输出将被填充到长度为 total_length 。如果 total_length 小于 sequence 中的最大序列长度,此方法将抛出 ValueError

返回

包含填充序列的张量元组,以及包含批次中每个序列长度列表的张量。当批次传递给 pack_padded_sequencepack_sequence 时,批次元素将按照最初的顺序重新排序。

填充一组打包的可变长度序列。

这是 pack_padded_sequence() 的逆运算。

返回的张量数据大小为 T x B x * ,其中 T 是最长序列的长度,B 是批量大小。如果batch_first 为True,则数据将转置为B x T x * 格式。

示例

>>> from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
>>> seq = torch.tensor([[1,2,0], [3,0,0], [4,5,6]])
>>> lens = [2, 1, 3]
>>> packed = pack_padded_sequence(seq, lens, batch_first=True, enforce_sorted=False)
>>> packed
PackedSequence(data=tensor([4, 1, 3, 5, 2, 6]), batch_sizes=tensor([3, 2, 1]),
               sorted_indices=tensor([2, 0, 1]), unsorted_indices=tensor([1, 2, 0]))
>>> seq_unpacked, lens_unpacked = pad_packed_sequence(packed, batch_first=True)
>>> seq_unpacked
tensor([[1, 2, 0],
        [3, 0, 0],
        [4, 5, 6]])
>>> lens_unpacked
tensor([2, 1, 3])

注意

total_length 可用于在 Module 中实现 pack sequence -> recurrent network -> unpack sequence 模式,该 DataParallel 包装在 DataParallel 中。有关详细信息,请参阅此常见问题解答部分。

相关用法


注:本文由纯净天空筛选整理自pytorch.org大神的英文原创作品 torch.nn.utils.rnn.pad_packed_sequence。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。