當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python PyTorch multi_dot用法及代碼示例


本文簡要介紹python語言中 torch.linalg.multi_dot 的用法。

用法:

torch.linalg.multi_dot(tensors, *, out=None)

參數

tensors(Sequence[Tensor]) -兩個或多個張量相乘。第一個和最後一個張量可以是 1D 或 2D。每個其他張量必須是 2D 的。

關鍵字參數

out(Tensor,可選的) -輸出張量。如果 None 則忽略。默認值:None

通過重新排序乘法來有效地將兩個或多個矩陣相乘,以便執行最少的算術運算。

支持 float、double、cfloat 和 cdouble dtypes 的輸入。此函數不支持批量輸入。

tensors 中的每個張量都必須是 2D,但第一個和最後一個可能是 1D 的除外。如果第一個張量是形狀為 (n,) 的一維向量,則將其視為形狀為 (1, n) 的行向量,類似地,如果最後一個張量是形狀為 (n,) 的一維向量,則將其視為形狀為列向量(n, 1)

如果第一個和最後一個張量是矩陣,則輸出將是一個矩陣。但是,如果其中任何一個是一維向量,則輸出將是一維向量。

numpy.linalg.multi_dot 的區別:

  • numpy.linalg.multi_dot 不同,第一個和最後一個張量必須是 1D 或 2D,而 NumPy 允許它們是 nD

警告

此函數不廣播。

注意

此函數通過在計算最佳矩陣乘法順序後鏈接 torch.mm() 調用來實現。

注意

將兩個矩陣與形狀 (a, b)(b, c) 相乘的成本是 a * b * c 。給定矩陣 A , B , C 形狀分別為 (10, 100) , (100, 5) , (5, 50) ,我們可以計算不同乘法順序的成本如下:

在這種情況下,先將 AB 相乘,然後再乘以 C 會快 10 倍。

例子:

>>> from torch.linalg import multi_dot

>>> multi_dot([torch.tensor([1, 2]), torch.tensor([2, 3])])
tensor(8)
>>> multi_dot([torch.tensor([[1, 2]]), torch.tensor([2, 3])])
tensor([8])
>>> multi_dot([torch.tensor([[1, 2]]), torch.tensor([[2], [3]])])
tensor([[8]])

>>> A = torch.arange(2 * 3).view(2, 3)
>>> B = torch.arange(3 * 2).view(3, 2)
>>> C = torch.arange(2 * 2).view(2, 2)
>>> multi_dot((A, B, C))
tensor([[ 26,  49],
        [ 80, 148]])

相關用法


注:本文由純淨天空篩選整理自pytorch.org大神的英文原創作品 torch.linalg.multi_dot。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。