當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python PyTorch baddbmm用法及代碼示例


本文簡要介紹python語言中 torch.baddbmm 的用法。

用法:

torch.baddbmm(input, batch1, batch2, *, beta=1, alpha=1, out=None) → Tensor

參數

  • input(Tensor) -要添加的張量

  • batch1(Tensor) -第一批要相乘的矩陣

  • batch2(Tensor) -第二批要相乘的矩陣

關鍵字參數

  • beta(數字,可選的) -input ( ) 的乘數

  • alpha(數字,可選的) - ( ) 的乘數

  • out(Tensor,可選的) -輸出張量。

batch1batch2 中的矩陣執行批次 matrix-matrix 乘積。 input 添加到最終結果中。

batch1batch2 必須是 3-D 張量,每個張量都包含相同數量的矩陣。

如果batch1 張量,batch2 張量,那麽input必須是可廣播的 張量並且out將是 張量。 alphabeta 的含義與 torch.addbmm() 中使用的比例因子相同。

如果beta為0,則input將被忽略,其中的naninf不會被傳播。

對於 FloatTensorDoubleTensor 類型的輸入,參數 betaalpha 必須是實數,否則它們應該是整數。

該運算符支持 TensorFloat32。

例子:

>>> M = torch.randn(10, 3, 5)
>>> batch1 = torch.randn(10, 3, 4)
>>> batch2 = torch.randn(10, 4, 5)
>>> torch.baddbmm(M, batch1, batch2).size()
torch.Size([10, 3, 5])

相關用法


注:本文由純淨天空篩選整理自pytorch.org大神的英文原創作品 torch.baddbmm。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。