當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python PyTorch pack_sequence用法及代碼示例


本文簡要介紹python語言中 torch.nn.utils.rnn.pack_sequence 的用法。

用法:

torch.nn.utils.rnn.pack_sequence(sequences, enforce_sorted=True)

參數

  • sequences(list[Tensor]) -長度遞減的序列列表。

  • enforce_sorted(bool,可選的) -如果 True ,檢查輸入是否包含按長度降序排序的序列。如果是 False ,則不檢查此條件。默認值:True

返回

PackedSequence 對象

打包可變長度張量列表

sequences 應該是大小為 L x * 的張量列表,其中 L 是序列的長度,而 * 是任意數量的尾隨維度,包括零。

對於未排序的序列,請使用 enforce_sorted = False 。如果 enforce_sortedTrue ,則應按長度遞減的順序對序列進行排序。 enforce_sorted = True 僅用於 ONNX 導出。

示例

>>> from torch.nn.utils.rnn import pack_sequence
>>> a = torch.tensor([1,2,3])
>>> b = torch.tensor([4,5])
>>> c = torch.tensor([6])
>>> pack_sequence([a, b, c])
PackedSequence(data=tensor([ 1,  4,  6,  2,  5,  3]), batch_sizes=tensor([ 3,  2,  1]))

相關用法


注:本文由純淨天空篩選整理自pytorch.org大神的英文原創作品 torch.nn.utils.rnn.pack_sequence。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。