當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python PyTorch pad_packed_sequence用法及代碼示例


本文簡要介紹python語言中 torch.nn.utils.rnn.pad_packed_sequence 的用法。

用法:

torch.nn.utils.rnn.pad_packed_sequence(sequence, batch_first=False, padding_value=0.0, total_length=None)

參數

  • sequence(PackedSequence) -批量填充

  • batch_first(bool,可選的) -如果 True ,輸出將采用 B x T x * 格式。

  • padding_value(float,可選的) -填充元素的值。

  • total_length(int,可選的) -如果不是 None ,輸出將被填充到長度為 total_length 。如果 total_length 小於 sequence 中的最大序列長度,此方法將拋出 ValueError

返回

包含填充序列的張量元組,以及包含批次中每個序列長度列表的張量。當批次傳遞給 pack_padded_sequencepack_sequence 時,批次元素將按照最初的順序重新排序。

填充一組打包的可變長度序列。

這是 pack_padded_sequence() 的逆運算。

返回的張量數據大小為 T x B x * ,其中 T 是最長序列的長度,B 是批量大小。如果batch_first 為True,則數據將轉置為B x T x * 格式。

示例

>>> from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
>>> seq = torch.tensor([[1,2,0], [3,0,0], [4,5,6]])
>>> lens = [2, 1, 3]
>>> packed = pack_padded_sequence(seq, lens, batch_first=True, enforce_sorted=False)
>>> packed
PackedSequence(data=tensor([4, 1, 3, 5, 2, 6]), batch_sizes=tensor([3, 2, 1]),
               sorted_indices=tensor([2, 0, 1]), unsorted_indices=tensor([1, 2, 0]))
>>> seq_unpacked, lens_unpacked = pad_packed_sequence(packed, batch_first=True)
>>> seq_unpacked
tensor([[1, 2, 0],
        [3, 0, 0],
        [4, 5, 6]])
>>> lens_unpacked
tensor([2, 1, 3])

注意

total_length 可用於在 Module 中實現 pack sequence -> recurrent network -> unpack sequence 模式,該 DataParallel 包裝在 DataParallel 中。有關詳細信息,請參閱此常見問題解答部分。

相關用法


注:本文由純淨天空篩選整理自pytorch.org大神的英文原創作品 torch.nn.utils.rnn.pad_packed_sequence。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。