当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python PyTorch einsum用法及代码示例


本文简要介绍python语言中 torch.einsum 的用法。

用法:

torch.einsum(equation, *operands) → Tensor

参数

  • equation(string) -爱因斯坦求和的下标。

  • operands(List[Tensor]) -计算爱因斯坦和的张量。

将输入 operands 的元素的乘积与使用基于 Einstein 求和约定的符号指定的维度相加。

Einsum 允许计算许多常见的多维线性代数数组运算,通过以基于爱因斯坦求和约定的速记格式表示它们,由 equation 给出。下面说明了这种格式的细节,但总体思路是用一些下标来标记输入 operands 的每个维度,并定义哪些下标是输出的一部分。然后,通过将 operands 的元素的乘积沿下标不属于输出的一部分的维度求和来计算输出。例如,矩阵乘法可以使用 einsum 计算为 torch.einsum(“ij,jk->ik”, A, B) 。这里,j 是求和下标,i 和 k 是输出下标(有关原因的更多详细信息,请参阅下面的部分)。

方程:

equation 字符串以与维度相同的顺序为输入 operands 的每个维度指定下标([a-zA-Z] 中的字母),用逗号(',')分隔每个操作数的下标,例如‘ij,jk’ 为两个二维操作数指定下标。标有相同下标的维度必须是可广播的,即它们的大小必须匹配或为 1 。例外情况是,如果为相同的输入操作数重复下标,在这种情况下,此操作数的标有此下标的维度必须在大小上匹配,并且操作数将被其沿这些维度的对角线替换。在equation 中恰好出现一次的下标将成为输出的一部分,按字母升序排序。输出是通过将输入 operands 元素相乘,其尺寸基于下标对齐,然后将其下标不属于输出的尺寸相加来计算的。

可选地,输出下标可以通过在等式末尾添加一个箭头 ('->') 来明确定义,然后是输出的下标。例如,以下等式计算矩阵乘法的转置:“ij,jk->ki”。对于某些输入操作数,输出下标必须至少出现一次,而对于输出则至多出现一次。

省略号(‘…’)可以用来代替下标来广播省略号所覆盖的维度。每个输入操作数最多可以包含一个省略号,该省略号将覆盖下标未覆盖的维度,例如对于 5 个维度的输入操作数,方程 ‘ab…c’ 中的省略号涵盖第三个和第四个维度。省略号不需要在 operands 中覆盖相同数量的维度,但省略号的 ‘shape’(它们覆盖的维度的大小)必须一起广播。如果未使用箭头(“->”)符号显式定义输出,则省略号将首先出现在输出中(最左侧的维度),位于输入操作数仅出现一次的下标标签之前。例如以下等式实现批量矩阵乘法‘…ij,…jk’

最后几点说明:等式可能包含不同元素(下标、省略号、箭头和逗号)之间的空格,但 ‘…’ 之类的内容无效。空字符串 ‘’ 对标量操作数有效。

注意

torch.einsum 处理省略号(“...”)与 NumPy 的不同之处在于它允许对省略号覆盖的维度求和,也就是说,省略号不需要成为输出的一部分。

注意

此函数不会优化给定的表达式,因此相同计算的不同公式可能会运行得更快或消耗更少的内存。 opt_einsum (https://optimized-einsum.readthedocs.io/en/stable/) 等项目可以为您优化公式。

注意

从 PyTorch 1.10 开始,torch.einsum() 还支持子列表格式(请参见下面的示例)。在此格式中,每个操作数的下标由子列表指定,子列表是 [0, 52) 范围内的整数列表。这些子列表跟随它们的操作数,并且额外的子列表可以出现在输入的末尾以指定输出的下标。 torch.einsum(op1, sublist1, op2, sublist2, …, [subslist_out]) 。 Python 的 Ellipsis 对象可以在子列表中提供,以启用广播,如上面的公式部分所述。

例子:

# trace
>>> torch.einsum('ii', torch.randn(4, 4))
tensor(-1.2104)

# diagonal
>>> torch.einsum('ii->i', torch.randn(4, 4))
tensor([-0.1034,  0.7952, -0.2433,  0.4545])

# outer product
>>> x = torch.randn(5)
>>> y = torch.randn(4)
>>> torch.einsum('i,j->ij', x, y)
tensor([[ 0.1156, -0.2897, -0.3918,  0.4963],
        [-0.3744,  0.9381,  1.2685, -1.6070],
        [ 0.7208, -1.8058, -2.4419,  3.0936],
        [ 0.1713, -0.4291, -0.5802,  0.7350],
        [ 0.5704, -1.4290, -1.9323,  2.4480]])

# batch matrix multiplication
>>> As = torch.randn(3,2,5)
>>> Bs = torch.randn(3,5,4)
>>> torch.einsum('bij,bjk->bik', As, Bs)
tensor([[[-1.0564, -1.5904,  3.2023,  3.1271],
        [-1.6706, -0.8097, -0.8025, -2.1183]],

        [[ 4.2239,  0.3107, -0.5756, -0.2354],
        [-1.4558, -0.3460,  1.5087, -0.8530]],

        [[ 2.8153,  1.8787, -4.3839, -1.2112],
        [ 0.3728, -2.1131,  0.0921,  0.8305]]])

# with sublist format and ellipsis
>>> torch.einsum(As, [..., 0, 1], Bs, [..., 1, 2], [..., 0, 2])
tensor([[[-1.0564, -1.5904,  3.2023,  3.1271],
        [-1.6706, -0.8097, -0.8025, -2.1183]],

        [[ 4.2239,  0.3107, -0.5756, -0.2354],
        [-1.4558, -0.3460,  1.5087, -0.8530]],

        [[ 2.8153,  1.8787, -4.3839, -1.2112],
        [ 0.3728, -2.1131,  0.0921,  0.8305]]])

# batch permute
>>> A = torch.randn(2, 3, 4, 5)
>>> torch.einsum('...ij->...ji', A).shape
torch.Size([2, 3, 5, 4])

# equivalent to torch.nn.functional.bilinear
>>> A = torch.randn(3,5,4)
>>> l = torch.randn(2,5)
>>> r = torch.randn(2,4)
>>> torch.einsum('bn,anm,bm->ba', l, A, r)
tensor([[-0.3430, -5.2405,  0.4494],
        [ 0.3311,  5.5201, -3.0356]])

相关用法


注:本文由纯净天空筛选整理自pytorch.org大神的英文原创作品 torch.einsum。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。