當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python PyTorch einsum用法及代碼示例


本文簡要介紹python語言中 torch.einsum 的用法。

用法:

torch.einsum(equation, *operands) → Tensor

參數

  • equation(string) -愛因斯坦求和的下標。

  • operands(List[Tensor]) -計算愛因斯坦和的張量。

將輸入 operands 的元素的乘積與使用基於 Einstein 求和約定的符號指定的維度相加。

Einsum 允許計算許多常見的多維線性代數數組運算,通過以基於愛因斯坦求和約定的速記格式表示它們,由 equation 給出。下麵說明了這種格式的細節,但總體思路是用一些下標來標記輸入 operands 的每個維度,並定義哪些下標是輸出的一部分。然後,通過將 operands 的元素的乘積沿下標不屬於輸出的一部分的維度求和來計算輸出。例如,矩陣乘法可以使用 einsum 計算為 torch.einsum(“ij,jk->ik”, A, B) 。這裏,j 是求和下標,i 和 k 是輸出下標(有關原因的更多詳細信息,請參閱下麵的部分)。

方程:

equation 字符串以與維度相同的順序為輸入 operands 的每個維度指定下標([a-zA-Z] 中的字母),用逗號(',')分隔每個操作數的下標,例如‘ij,jk’ 為兩個二維操作數指定下標。標有相同下標的維度必須是可廣播的,即它們的大小必須匹配或為 1 。例外情況是,如果為相同的輸入操作數重複下標,在這種情況下,此操作數的標有此下標的維度必須在大小上匹配,並且操作數將被其沿這些維度的對角線替換。在equation 中恰好出現一次的下標將成為輸出的一部分,按字母升序排序。輸出是通過將輸入 operands 元素相乘,其尺寸基於下標對齊,然後將其下標不屬於輸出的尺寸相加來計算的。

可選地,輸出下標可以通過在等式末尾添加一個箭頭 ('->') 來明確定義,然後是輸出的下標。例如,以下等式計算矩陣乘法的轉置:“ij,jk->ki”。對於某些輸入操作數,輸出下標必須至少出現一次,而對於輸出則至多出現一次。

省略號(‘…’)可以用來代替下標來廣播省略號所覆蓋的維度。每個輸入操作數最多可以包含一個省略號,該省略號將覆蓋下標未覆蓋的維度,例如對於 5 個維度的輸入操作數,方程 ‘ab…c’ 中的省略號涵蓋第三個和第四個維度。省略號不需要在 operands 中覆蓋相同數量的維度,但省略號的 ‘shape’(它們覆蓋的維度的大小)必須一起廣播。如果未使用箭頭(“->”)符號顯式定義輸出,則省略號將首先出現在輸出中(最左側的維度),位於輸入操作數僅出現一次的下標標簽之前。例如以下等式實現批量矩陣乘法‘…ij,…jk’

最後幾點說明:等式可能包含不同元素(下標、省略號、箭頭和逗號)之間的空格,但 ‘…’ 之類的內容無效。空字符串 ‘’ 對標量操作數有效。

注意

torch.einsum 處理省略號(“...”)與 NumPy 的不同之處在於它允許對省略號覆蓋的維度求和,也就是說,省略號不需要成為輸出的一部分。

注意

此函數不會優化給定的表達式,因此相同計算的不同公式可能會運行得更快或消耗更少的內存。 opt_einsum (https://optimized-einsum.readthedocs.io/en/stable/) 等項目可以為您優化公式。

注意

從 PyTorch 1.10 開始,torch.einsum() 還支持子列表格式(請參見下麵的示例)。在此格式中,每個操作數的下標由子列表指定,子列表是 [0, 52) 範圍內的整數列表。這些子列表跟隨它們的操作數,並且額外的子列表可以出現在輸入的末尾以指定輸出的下標。 torch.einsum(op1, sublist1, op2, sublist2, …, [subslist_out]) 。 Python 的 Ellipsis 對象可以在子列表中提供,以啟用廣播,如上麵的公式部分所述。

例子:

# trace
>>> torch.einsum('ii', torch.randn(4, 4))
tensor(-1.2104)

# diagonal
>>> torch.einsum('ii->i', torch.randn(4, 4))
tensor([-0.1034,  0.7952, -0.2433,  0.4545])

# outer product
>>> x = torch.randn(5)
>>> y = torch.randn(4)
>>> torch.einsum('i,j->ij', x, y)
tensor([[ 0.1156, -0.2897, -0.3918,  0.4963],
        [-0.3744,  0.9381,  1.2685, -1.6070],
        [ 0.7208, -1.8058, -2.4419,  3.0936],
        [ 0.1713, -0.4291, -0.5802,  0.7350],
        [ 0.5704, -1.4290, -1.9323,  2.4480]])

# batch matrix multiplication
>>> As = torch.randn(3,2,5)
>>> Bs = torch.randn(3,5,4)
>>> torch.einsum('bij,bjk->bik', As, Bs)
tensor([[[-1.0564, -1.5904,  3.2023,  3.1271],
        [-1.6706, -0.8097, -0.8025, -2.1183]],

        [[ 4.2239,  0.3107, -0.5756, -0.2354],
        [-1.4558, -0.3460,  1.5087, -0.8530]],

        [[ 2.8153,  1.8787, -4.3839, -1.2112],
        [ 0.3728, -2.1131,  0.0921,  0.8305]]])

# with sublist format and ellipsis
>>> torch.einsum(As, [..., 0, 1], Bs, [..., 1, 2], [..., 0, 2])
tensor([[[-1.0564, -1.5904,  3.2023,  3.1271],
        [-1.6706, -0.8097, -0.8025, -2.1183]],

        [[ 4.2239,  0.3107, -0.5756, -0.2354],
        [-1.4558, -0.3460,  1.5087, -0.8530]],

        [[ 2.8153,  1.8787, -4.3839, -1.2112],
        [ 0.3728, -2.1131,  0.0921,  0.8305]]])

# batch permute
>>> A = torch.randn(2, 3, 4, 5)
>>> torch.einsum('...ij->...ji', A).shape
torch.Size([2, 3, 5, 4])

# equivalent to torch.nn.functional.bilinear
>>> A = torch.randn(3,5,4)
>>> l = torch.randn(2,5)
>>> r = torch.randn(2,4)
>>> torch.einsum('bn,anm,bm->ba', l, A, r)
tensor([[-0.3430, -5.2405,  0.4494],
        [ 0.3311,  5.5201, -3.0356]])

相關用法


注:本文由純淨天空篩選整理自pytorch.org大神的英文原創作品 torch.einsum。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。