当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python PyTorch tensorsolve用法及代码示例


本文简要介绍python语言中 torch.linalg.tensorsolve 的用法。

用法:

torch.linalg.tensorsolve(A, B, dims=None, *, out=None) → Tensor

参数

  • A(Tensor) -要解决的张量。它的形状必须满足 prod( A .shape[: B .ndim]) == prod( A .shape[ B .ndim:])

  • B(Tensor) -形状的张量 A .shape[ B .ndim]

  • dims(元组[int],可选的) -要移动的 A 的尺寸。如果是 None ,则不移动尺寸。默认值:None

关键字参数

out(Tensor,可选的) -输出张量。如果 None 则忽略。默认值:None

抛出

RuntimeError - 如果重新整形的 A .view(m, m) 与上述 m 不可逆,或者第一个 ind 维度的乘积不等于其余维度的乘积。

计算解决方案 X 到系统 torch.tensordot(A, X) = B

如果mA 的第一个B .ndim 维度的乘积,而n 是其余维度的乘积,则此函数期望mn 相等。

返回的张量 x 满足 tensordot( A , x, dims=x.ndim) == Bx 具有形状 A [B.ndim:]

如果指定dimsA 将被重新整形为

A = movedim(A, dims, range(len(dims) - A.ndim + 1, 0))

支持 float、double、cfloat 和 cdouble dtypes 的输入。

例子:

>>> A = torch.eye(2 * 3 * 4).reshape((2 * 3, 4, 2, 3, 4))
>>> B = torch.randn(2 * 3, 4)
>>> X = torch.linalg.tensorsolve(A, B)
>>> X.shape
torch.Size([2, 3, 4])
>>> torch.allclose(torch.tensordot(A, X, dims=X.ndim), B)
True

>>> A = torch.randn(6, 4, 4, 3, 2)
>>> B = torch.randn(4, 3, 2)
>>> X = torch.linalg.tensorsolve(A, B, dims=(0, 2))
>>> X.shape
torch.Size([6, 4])
>>> A = A.permute(1, 3, 4, 0, 2)
>>> A.shape[B.ndim:]
torch.Size([6, 4])
>>> torch.allclose(torch.tensordot(A, X, dims=X.ndim), B, atol=1e-6)
True

相关用法


注:本文由纯净天空筛选整理自pytorch.org大神的英文原创作品 torch.linalg.tensorsolve。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。