本文簡要介紹python語言中 torch.einsum
的用法。
用法:
torch.einsum(equation, *operands) → Tensor
equation(string) -愛因斯坦求和的下標。
operands(List[Tensor]) -計算愛因斯坦和的張量。
將輸入
operands
的元素的乘積與使用基於 Einstein 求和約定的符號指定的維度相加。Einsum 允許計算許多常見的多維線性代數數組運算,通過以基於愛因斯坦求和約定的速記格式表示它們,由
equation
給出。下麵說明了這種格式的細節,但總體思路是用一些下標來標記輸入operands
的每個維度,並定義哪些下標是輸出的一部分。然後,通過將operands
的元素的乘積沿下標不屬於輸出的一部分的維度求和來計算輸出。例如,矩陣乘法可以使用 einsum 計算為torch.einsum(“ij,jk->ik”, A, B)
。這裏,j 是求和下標,i 和 k 是輸出下標(有關原因的更多詳細信息,請參閱下麵的部分)。方程:
equation
字符串以與維度相同的順序為輸入operands
的每個維度指定下標([a-zA-Z]
中的字母),用逗號(',')分隔每個操作數的下標,例如‘ij,jk’
為兩個二維操作數指定下標。標有相同下標的維度必須是可廣播的,即它們的大小必須匹配或為1
。例外情況是,如果為相同的輸入操作數重複下標,在這種情況下,此操作數的標有此下標的維度必須在大小上匹配,並且操作數將被其沿這些維度的對角線替換。在equation
中恰好出現一次的下標將成為輸出的一部分,按字母升序排序。輸出是通過將輸入operands
元素相乘,其尺寸基於下標對齊,然後將其下標不屬於輸出的尺寸相加來計算的。可選地,輸出下標可以通過在等式末尾添加一個箭頭 ('->') 來明確定義,然後是輸出的下標。例如,以下等式計算矩陣乘法的轉置:“ij,jk->ki”。對於某些輸入操作數,輸出下標必須至少出現一次,而對於輸出則至多出現一次。
省略號(‘…’)可以用來代替下標來廣播省略號所覆蓋的維度。每個輸入操作數最多可以包含一個省略號,該省略號將覆蓋下標未覆蓋的維度,例如對於 5 個維度的輸入操作數,方程
‘ab…c’
中的省略號涵蓋第三個和第四個維度。省略號不需要在operands
中覆蓋相同數量的維度,但省略號的 ‘shape’(它們覆蓋的維度的大小)必須一起廣播。如果未使用箭頭(“->”)符號顯式定義輸出,則省略號將首先出現在輸出中(最左側的維度),位於輸入操作數僅出現一次的下標標簽之前。例如以下等式實現批量矩陣乘法‘…ij,…jk’
。最後幾點說明:等式可能包含不同元素(下標、省略號、箭頭和逗號)之間的空格,但
‘…’
之類的內容無效。空字符串‘’
對標量操作數有效。注意
torch.einsum
處理省略號(“...”)與 NumPy 的不同之處在於它允許對省略號覆蓋的維度求和,也就是說,省略號不需要成為輸出的一部分。注意
此函數不會優化給定的表達式,因此相同計算的不同公式可能會運行得更快或消耗更少的內存。 opt_einsum (https://optimized-einsum.readthedocs.io/en/stable/) 等項目可以為您優化公式。
注意
從 PyTorch 1.10 開始,
torch.einsum()
還支持子列表格式(請參見下麵的示例)。在此格式中,每個操作數的下標由子列表指定,子列表是 [0, 52) 範圍內的整數列表。這些子列表跟隨它們的操作數,並且額外的子列表可以出現在輸入的末尾以指定輸出的下標。torch.einsum(op1, sublist1, op2, sublist2, …, [subslist_out])
。 Python 的Ellipsis
對象可以在子列表中提供,以啟用廣播,如上麵的公式部分所述。例子:
# trace >>> torch.einsum('ii', torch.randn(4, 4)) tensor(-1.2104) # diagonal >>> torch.einsum('ii->i', torch.randn(4, 4)) tensor([-0.1034, 0.7952, -0.2433, 0.4545]) # outer product >>> x = torch.randn(5) >>> y = torch.randn(4) >>> torch.einsum('i,j->ij', x, y) tensor([[ 0.1156, -0.2897, -0.3918, 0.4963], [-0.3744, 0.9381, 1.2685, -1.6070], [ 0.7208, -1.8058, -2.4419, 3.0936], [ 0.1713, -0.4291, -0.5802, 0.7350], [ 0.5704, -1.4290, -1.9323, 2.4480]]) # batch matrix multiplication >>> As = torch.randn(3,2,5) >>> Bs = torch.randn(3,5,4) >>> torch.einsum('bij,bjk->bik', As, Bs) tensor([[[-1.0564, -1.5904, 3.2023, 3.1271], [-1.6706, -0.8097, -0.8025, -2.1183]], [[ 4.2239, 0.3107, -0.5756, -0.2354], [-1.4558, -0.3460, 1.5087, -0.8530]], [[ 2.8153, 1.8787, -4.3839, -1.2112], [ 0.3728, -2.1131, 0.0921, 0.8305]]]) # with sublist format and ellipsis >>> torch.einsum(As, [..., 0, 1], Bs, [..., 1, 2], [..., 0, 2]) tensor([[[-1.0564, -1.5904, 3.2023, 3.1271], [-1.6706, -0.8097, -0.8025, -2.1183]], [[ 4.2239, 0.3107, -0.5756, -0.2354], [-1.4558, -0.3460, 1.5087, -0.8530]], [[ 2.8153, 1.8787, -4.3839, -1.2112], [ 0.3728, -2.1131, 0.0921, 0.8305]]]) # batch permute >>> A = torch.randn(2, 3, 4, 5) >>> torch.einsum('...ij->...ji', A).shape torch.Size([2, 3, 5, 4]) # equivalent to torch.nn.functional.bilinear >>> A = torch.randn(3,5,4) >>> l = torch.randn(2,5) >>> r = torch.randn(2,4) >>> torch.einsum('bn,anm,bm->ba', l, A, r) tensor([[-0.3430, -5.2405, 0.4494], [ 0.3311, 5.5201, -3.0356]])
參數:
相關用法
- Python PyTorch eigvals用法及代碼示例
- Python PyTorch eigvalsh用法及代碼示例
- Python PyTorch eig用法及代碼示例
- Python PyTorch eigh用法及代碼示例
- Python PyTorch enable_grad用法及代碼示例
- Python PyTorch equal用法及代碼示例
- Python PyTorch eq用法及代碼示例
- Python PyTorch erfc用法及代碼示例
- Python PyTorch exp用法及代碼示例
- Python PyTorch empty_like用法及代碼示例
- Python PyTorch expires用法及代碼示例
- Python PyTorch effect_names用法及代碼示例
- Python PyTorch entr用法及代碼示例
- Python PyTorch embedding用法及代碼示例
- Python PyTorch empty_strided用法及代碼示例
- Python PyTorch emit_nvtx用法及代碼示例
- Python PyTorch expm1用法及代碼示例
- Python PyTorch export用法及代碼示例
- Python PyTorch exp2用法及代碼示例
- Python PyTorch embedding_bag用法及代碼示例
- Python PyTorch erfinv用法及代碼示例
- Python PyTorch extract_archive用法及代碼示例
- Python PyTorch empty用法及代碼示例
- Python PyTorch eye用法及代碼示例
- Python PyTorch eye_用法及代碼示例
注:本文由純淨天空篩選整理自pytorch.org大神的英文原創作品 torch.einsum。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。