当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python PyTorch solve用法及代码示例


本文简要介绍python语言中 torch.solve 的用法。

用法:

torch.solve(input, A, *, out=None)

参数

  • input(Tensor) -大小为 的输入矩阵 ,其中 是零个或多个批处理维度。

  • A(Tensor) -输入大小为 的方阵,其中 是零个或多个批处理维度。

关键字参数

out((Tensor,Tensor),可选的) -可选的输出元组。

此函数返回由 和 A 的 LU 分解表示的线性方程组的解,以便作为命名元组 solution, LU

LU 包含 LU 因子,用于 A 的 LU 分解。

torch.solve(B, A) 可以接受 2D 输入 B, A 或作为 2D 矩阵批次的输入。如果输入是批处理,则返回批处理输出 solution, LU

支持实值和complex-valued 输入。

警告

torch.solve() 已弃用,取而代之的是 torch.linalg.solve() ,并将在未来的 PyTorch 版本中删除。 torch.linalg.solve() 的参数相反,并且不返回输入的 LU 分解。要获得 LU 分解,请参阅 torch.lu() ,它可以与 torch.lu_solve() torch.lu_unpack() 一起使用。

X = torch.solve(B, A).solution 应替换为

X = torch.linalg.solve(A, B)

注意

无论原始步幅如何,返回的矩阵 solutionLU 将被转置,即分别具有像 B.contiguous().transpose(-1, -2).stride()A.contiguous().transpose(-1, -2).stride() 这样的步幅。

例子:

>>> A = torch.tensor([[6.80, -2.11,  5.66,  5.97,  8.23],
...                   [-6.05, -3.30,  5.36, -4.44,  1.08],
...                   [-0.45,  2.58, -2.70,  0.27,  9.04],
...                   [8.32,  2.71,  4.35,  -7.17,  2.14],
...                   [-9.67, -5.14, -7.26,  6.08, -6.87]]).t()
>>> B = torch.tensor([[4.02,  6.19, -8.22, -7.57, -3.03],
...                   [-1.56,  4.00, -8.67,  1.75,  2.86],
...                   [9.81, -4.09, -4.57, -8.61,  8.99]]).t()
>>> X, LU = torch.solve(B, A)
>>> torch.dist(B, torch.mm(A, X))
tensor(1.00000e-06 *
       7.0977)

>>> # Batched solver example
>>> A = torch.randn(2, 3, 1, 4, 4)
>>> B = torch.randn(2, 3, 1, 4, 6)
>>> X, LU = torch.solve(B, A)
>>> torch.dist(B, A.matmul(X))
tensor(1.00000e-06 *
   3.6386)

相关用法


注:本文由纯净天空筛选整理自pytorch.org大神的英文原创作品 torch.solve。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。