當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python PyTorch tensorsolve用法及代碼示例


本文簡要介紹python語言中 torch.linalg.tensorsolve 的用法。

用法:

torch.linalg.tensorsolve(A, B, dims=None, *, out=None) → Tensor

參數

  • A(Tensor) -要解決的張量。它的形狀必須滿足 prod( A .shape[: B .ndim]) == prod( A .shape[ B .ndim:])

  • B(Tensor) -形狀的張量 A .shape[ B .ndim]

  • dims(元組[int],可選的) -要移動的 A 的尺寸。如果是 None ,則不移動尺寸。默認值:None

關鍵字參數

out(Tensor,可選的) -輸出張量。如果 None 則忽略。默認值:None

拋出

RuntimeError - 如果重新整形的 A .view(m, m) 與上述 m 不可逆,或者第一個 ind 維度的乘積不等於其餘維度的乘積。

計算解決方案 X 到係統 torch.tensordot(A, X) = B

如果mA 的第一個B .ndim 維度的乘積,而n 是其餘維度的乘積,則此函數期望mn 相等。

返回的張量 x 滿足 tensordot( A , x, dims=x.ndim) == Bx 具有形狀 A [B.ndim:]

如果指定dimsA 將被重新整形為

A = movedim(A, dims, range(len(dims) - A.ndim + 1, 0))

支持 float、double、cfloat 和 cdouble dtypes 的輸入。

例子:

>>> A = torch.eye(2 * 3 * 4).reshape((2 * 3, 4, 2, 3, 4))
>>> B = torch.randn(2 * 3, 4)
>>> X = torch.linalg.tensorsolve(A, B)
>>> X.shape
torch.Size([2, 3, 4])
>>> torch.allclose(torch.tensordot(A, X, dims=X.ndim), B)
True

>>> A = torch.randn(6, 4, 4, 3, 2)
>>> B = torch.randn(4, 3, 2)
>>> X = torch.linalg.tensorsolve(A, B, dims=(0, 2))
>>> X.shape
torch.Size([6, 4])
>>> A = A.permute(1, 3, 4, 0, 2)
>>> A.shape[B.ndim:]
torch.Size([6, 4])
>>> torch.allclose(torch.tensordot(A, X, dims=X.ndim), B, atol=1e-6)
True

相關用法


注:本文由純淨天空篩選整理自pytorch.org大神的英文原創作品 torch.linalg.tensorsolve。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。