当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python PyTorch pack_sequence用法及代码示例


本文简要介绍python语言中 torch.nn.utils.rnn.pack_sequence 的用法。

用法:

torch.nn.utils.rnn.pack_sequence(sequences, enforce_sorted=True)

参数

  • sequences(list[Tensor]) -长度递减的序列列表。

  • enforce_sorted(bool,可选的) -如果 True ,检查输入是否包含按长度降序排序的序列。如果是 False ,则不检查此条件。默认值:True

返回

PackedSequence 对象

打包可变长度张量列表

sequences 应该是大小为 L x * 的张量列表,其中 L 是序列的长度,而 * 是任意数量的尾随维度,包括零。

对于未排序的序列,请使用 enforce_sorted = False 。如果 enforce_sortedTrue ,则应按长度递减的顺序对序列进行排序。 enforce_sorted = True 仅用于 ONNX 导出。

示例

>>> from torch.nn.utils.rnn import pack_sequence
>>> a = torch.tensor([1,2,3])
>>> b = torch.tensor([4,5])
>>> c = torch.tensor([6])
>>> pack_sequence([a, b, c])
PackedSequence(data=tensor([ 1,  4,  6,  2,  5,  3]), batch_sizes=tensor([ 3,  2,  1]))

相关用法


注:本文由纯净天空筛选整理自pytorch.org大神的英文原创作品 torch.nn.utils.rnn.pack_sequence。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。