本文简要介绍python语言中 torch.einsum
的用法。
用法:
torch.einsum(equation, *operands) → Tensor
equation(string) -爱因斯坦求和的下标。
operands(List[Tensor]) -计算爱因斯坦和的张量。
将输入
operands
的元素的乘积与使用基于 Einstein 求和约定的符号指定的维度相加。Einsum 允许计算许多常见的多维线性代数数组运算,通过以基于爱因斯坦求和约定的速记格式表示它们,由
equation
给出。下面说明了这种格式的细节,但总体思路是用一些下标来标记输入operands
的每个维度,并定义哪些下标是输出的一部分。然后,通过将operands
的元素的乘积沿下标不属于输出的一部分的维度求和来计算输出。例如,矩阵乘法可以使用 einsum 计算为torch.einsum(“ij,jk->ik”, A, B)
。这里,j 是求和下标,i 和 k 是输出下标(有关原因的更多详细信息,请参阅下面的部分)。方程:
equation
字符串以与维度相同的顺序为输入operands
的每个维度指定下标([a-zA-Z]
中的字母),用逗号(',')分隔每个操作数的下标,例如‘ij,jk’
为两个二维操作数指定下标。标有相同下标的维度必须是可广播的,即它们的大小必须匹配或为1
。例外情况是,如果为相同的输入操作数重复下标,在这种情况下,此操作数的标有此下标的维度必须在大小上匹配,并且操作数将被其沿这些维度的对角线替换。在equation
中恰好出现一次的下标将成为输出的一部分,按字母升序排序。输出是通过将输入operands
元素相乘,其尺寸基于下标对齐,然后将其下标不属于输出的尺寸相加来计算的。可选地,输出下标可以通过在等式末尾添加一个箭头 ('->') 来明确定义,然后是输出的下标。例如,以下等式计算矩阵乘法的转置:“ij,jk->ki”。对于某些输入操作数,输出下标必须至少出现一次,而对于输出则至多出现一次。
省略号(‘…’)可以用来代替下标来广播省略号所覆盖的维度。每个输入操作数最多可以包含一个省略号,该省略号将覆盖下标未覆盖的维度,例如对于 5 个维度的输入操作数,方程
‘ab…c’
中的省略号涵盖第三个和第四个维度。省略号不需要在operands
中覆盖相同数量的维度,但省略号的 ‘shape’(它们覆盖的维度的大小)必须一起广播。如果未使用箭头(“->”)符号显式定义输出,则省略号将首先出现在输出中(最左侧的维度),位于输入操作数仅出现一次的下标标签之前。例如以下等式实现批量矩阵乘法‘…ij,…jk’
。最后几点说明:等式可能包含不同元素(下标、省略号、箭头和逗号)之间的空格,但
‘…’
之类的内容无效。空字符串‘’
对标量操作数有效。注意
torch.einsum
处理省略号(“...”)与 NumPy 的不同之处在于它允许对省略号覆盖的维度求和,也就是说,省略号不需要成为输出的一部分。注意
此函数不会优化给定的表达式,因此相同计算的不同公式可能会运行得更快或消耗更少的内存。 opt_einsum (https://optimized-einsum.readthedocs.io/en/stable/) 等项目可以为您优化公式。
注意
从 PyTorch 1.10 开始,
torch.einsum()
还支持子列表格式(请参见下面的示例)。在此格式中,每个操作数的下标由子列表指定,子列表是 [0, 52) 范围内的整数列表。这些子列表跟随它们的操作数,并且额外的子列表可以出现在输入的末尾以指定输出的下标。torch.einsum(op1, sublist1, op2, sublist2, …, [subslist_out])
。 Python 的Ellipsis
对象可以在子列表中提供,以启用广播,如上面的公式部分所述。例子:
# trace >>> torch.einsum('ii', torch.randn(4, 4)) tensor(-1.2104) # diagonal >>> torch.einsum('ii->i', torch.randn(4, 4)) tensor([-0.1034, 0.7952, -0.2433, 0.4545]) # outer product >>> x = torch.randn(5) >>> y = torch.randn(4) >>> torch.einsum('i,j->ij', x, y) tensor([[ 0.1156, -0.2897, -0.3918, 0.4963], [-0.3744, 0.9381, 1.2685, -1.6070], [ 0.7208, -1.8058, -2.4419, 3.0936], [ 0.1713, -0.4291, -0.5802, 0.7350], [ 0.5704, -1.4290, -1.9323, 2.4480]]) # batch matrix multiplication >>> As = torch.randn(3,2,5) >>> Bs = torch.randn(3,5,4) >>> torch.einsum('bij,bjk->bik', As, Bs) tensor([[[-1.0564, -1.5904, 3.2023, 3.1271], [-1.6706, -0.8097, -0.8025, -2.1183]], [[ 4.2239, 0.3107, -0.5756, -0.2354], [-1.4558, -0.3460, 1.5087, -0.8530]], [[ 2.8153, 1.8787, -4.3839, -1.2112], [ 0.3728, -2.1131, 0.0921, 0.8305]]]) # with sublist format and ellipsis >>> torch.einsum(As, [..., 0, 1], Bs, [..., 1, 2], [..., 0, 2]) tensor([[[-1.0564, -1.5904, 3.2023, 3.1271], [-1.6706, -0.8097, -0.8025, -2.1183]], [[ 4.2239, 0.3107, -0.5756, -0.2354], [-1.4558, -0.3460, 1.5087, -0.8530]], [[ 2.8153, 1.8787, -4.3839, -1.2112], [ 0.3728, -2.1131, 0.0921, 0.8305]]]) # batch permute >>> A = torch.randn(2, 3, 4, 5) >>> torch.einsum('...ij->...ji', A).shape torch.Size([2, 3, 5, 4]) # equivalent to torch.nn.functional.bilinear >>> A = torch.randn(3,5,4) >>> l = torch.randn(2,5) >>> r = torch.randn(2,4) >>> torch.einsum('bn,anm,bm->ba', l, A, r) tensor([[-0.3430, -5.2405, 0.4494], [ 0.3311, 5.5201, -3.0356]])
参数:
相关用法
- Python PyTorch eigvals用法及代码示例
- Python PyTorch eigvalsh用法及代码示例
- Python PyTorch eig用法及代码示例
- Python PyTorch eigh用法及代码示例
- Python PyTorch enable_grad用法及代码示例
- Python PyTorch equal用法及代码示例
- Python PyTorch eq用法及代码示例
- Python PyTorch erfc用法及代码示例
- Python PyTorch exp用法及代码示例
- Python PyTorch empty_like用法及代码示例
- Python PyTorch expires用法及代码示例
- Python PyTorch effect_names用法及代码示例
- Python PyTorch entr用法及代码示例
- Python PyTorch embedding用法及代码示例
- Python PyTorch empty_strided用法及代码示例
- Python PyTorch emit_nvtx用法及代码示例
- Python PyTorch expm1用法及代码示例
- Python PyTorch export用法及代码示例
- Python PyTorch exp2用法及代码示例
- Python PyTorch embedding_bag用法及代码示例
- Python PyTorch erfinv用法及代码示例
- Python PyTorch extract_archive用法及代码示例
- Python PyTorch empty用法及代码示例
- Python PyTorch eye用法及代码示例
- Python PyTorch eye_用法及代码示例
注:本文由纯净天空筛选整理自pytorch.org大神的英文原创作品 torch.einsum。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。