當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python PyTorch inner用法及代碼示例


本文簡要介紹python語言中 torch.inner 的用法。

用法:

torch.inner(input, other, *, out=None) → Tensor

參數

  • input(Tensor) -第一個輸入張量

  • other(Tensor) -第二個輸入張量

關鍵字參數

out(Tensor,可選的) -將結果寫入的可選輸出張量。輸出形狀為 input.shape[:-1] + other.shape[:-1]

計算一維張量的點積。對於更高維度,將 inputother 中元素的乘積沿其最後一個維度求和。

注意

如果 inputother 是標量,則結果等效於 torch.mul(input, other)

如果 inputother 都是非標量,則它們最後一個維度的大小必須匹配,結果等價於 torch.tensordot(input, other, dims=([-1], [-1]))

例子:

# Dot product
>>> torch.inner(torch.tensor([1, 2, 3]), torch.tensor([0, 2, 1]))
tensor(7)

# Multidimensional input tensors
>>> a = torch.randn(2, 3)
>>> a
tensor([[0.8173, 1.0874, 1.1784],
        [0.3279, 0.1234, 2.7894]])
>>> b = torch.randn(2, 4, 3)
>>> b
tensor([[[-0.4682, -0.7159,  0.1506],
        [ 0.4034, -0.3657,  1.0387],
        [ 0.9892, -0.6684,  0.1774],
        [ 0.9482,  1.3261,  0.3917]],

        [[ 0.4537,  0.7493,  1.1724],
        [ 0.2291,  0.5749, -0.2267],
        [-0.7920,  0.3607, -0.3701],
        [ 1.3666, -0.5850, -1.7242]]])
>>> torch.inner(a, b)
tensor([[[-0.9837,  1.1560,  0.2907,  2.6785],
        [ 2.5671,  0.5452, -0.6912, -1.5509]],

        [[ 0.1782,  2.9843,  0.7366,  1.5672],
        [ 3.5115, -0.4864, -1.2476, -4.4337]]])

# Scalar input
>>> torch.inner(a, torch.tensor(2))
tensor([[1.6347, 2.1748, 2.3567],
        [0.6558, 0.2469, 5.5787]])

相關用法


注:本文由純淨天空篩選整理自pytorch.org大神的英文原創作品 torch.inner。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。