當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python PyTorch tensordot用法及代碼示例


本文簡要介紹python語言中 torch.tensordot 的用法。

用法:

torch.tensordot(a, b, dims=2, out=None)

參數

  • a(Tensor) -左張量收縮

  • b(Tensor) -右張量收縮

  • dims(int或者元組[List[int],List[int]] 或者List[List[int]]包含兩個列表或者Tensor) - 要收縮的維度數量或明確的維度列表ab分別

返回 a 和 b 在多個維度上的收縮。

tensordot 實現廣義矩陣乘積。

當使用非負整數參數 dims = 調用,並且 ab 的維數分別為 時,tensordot() 計算

當使用列表形式的dims調用時,給定的維度將被收縮以代替a的最後一個 的第一個 。這些維度中的大小必須匹配,但 tensordot() 將處理廣播維度。

例子:

>>> a = torch.arange(60.).reshape(3, 4, 5)
>>> b = torch.arange(24.).reshape(4, 3, 2)
>>> torch.tensordot(a, b, dims=([1, 0], [0, 1]))
tensor([[4400., 4730.],
        [4532., 4874.],
        [4664., 5018.],
        [4796., 5162.],
        [4928., 5306.]])

>>> a = torch.randn(3, 4, 5, device='cuda')
>>> b = torch.randn(4, 5, 6, device='cuda')
>>> c = torch.tensordot(a, b, dims=2).cpu()
tensor([[ 8.3504, -2.5436,  6.2922,  2.7556, -1.0732,  3.2741],
        [ 3.3161,  0.0704,  5.0187, -0.4079, -4.3126,  4.8744],
        [ 0.8223,  3.9445,  3.2168, -0.2400,  3.4117,  1.7780]])

>>> a = torch.randn(3, 5, 4, 6)
>>> b = torch.randn(6, 4, 5, 3)
>>> torch.tensordot(a, b, dims=([2, 1, 3], [1, 2, 0]))
tensor([[  7.7193,  -2.4867, -10.3204],
        [  1.5513, -14.4737,  -6.5113],
        [ -0.2850,   4.2573,  -3.5997]])

相關用法


注:本文由純淨天空篩選整理自pytorch.org大神的英文原創作品 torch.tensordot。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。