当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python PyTorch lstsq用法及代码示例


本文简要介绍python语言中 torch.linalg.lstsq 的用法。

用法:

torch.linalg.lstsq(A, B, rcond=None, *, driver=None)

参数

  • A(Tensor) -形状为 (*, m, n) 的 lhs 张量,其中 * 是零个或多个批次维度。

  • B(Tensor) -形状为 (*, m, k) 的 rhs 张量,其中 * 是零个或多个批次维度。

  • rcond(float,可选的) -用于确定 A 的有效等级。如果 rcond = Nonercond 设置为 A 的 dtype 乘以 max(m, n) 的机器精度。默认值:None

关键字参数

driver(str,可选的) -要使用的 LAPACK/MAGMA 方法的名称。如果 None‘gelsy’ 用于 CPU 输入, ‘gels’ 用于 CUDA 输入。默认值:None

返回

命名元组 (solution, residuals, rank, singular_values)

计算线性方程组的最小二乘问题的解。

\mathbb{K} \mathbb{R} 或者\mathbb{C} , 这最小二乘问题对于线性系统AX = B A \in \mathbb{K}^{m \times n}, B \in \mathbb{K}^{m \times k} 定义为

其中 表示 Frobenius 范数。

支持 float、double、cfloat 和 cdouble dtypes 的输入。还支持矩阵批次,如果输入是矩阵批次,则输出具有相同的批次尺寸。

driver 选择将使用的 LAPACK/MAGMA 函数。对于 CPU 输入,有效值为 ‘gels’‘gelsy’‘gelsd‘gelss’ 。对于 CUDA 输入,唯一有效的驱动程序是 ‘gels’ ,它假定 A 是满秩的。要在 CPU 上选择最佳驱动程序,请考虑:

  • 如果A 条件良好(它的condition number 不是太大),或者您不介意一些精度损失。

    • 对于一般矩阵:‘gelsy’(带旋转的 QR)(默认)

    • 如果A 是满秩:‘gels’ (QR)

  • 如果 A 条件不佳。

    • ‘gelsd’(三对角缩减和 SVD)

    • 但是,如果您遇到内存问题:‘gelss’(完整的 SVD)。

另见full description of these drivers

driver 是(‘gelsy’‘gelsd’‘gelss’)之一时,rcond 用于确定A 中矩阵的有效秩。在这种情况下,如果 A 的降序奇异值,如果 ,则 将向下舍入为零。如果 rcond = None(默认),rcond 设置为 A dtype 的机器精度。

此函数在四个张量 (solution, residuals, rank, singular_values) 的命名元组中返回问题的解决方案和一些额外信息。对于形状为 (*, m, n)(*, m, k) 的输入 AB ,它包含

  • solution :最小二乘解。它的形状为 (*, n, k)

  • residuals :解的平方残差,即 。它的形状等于 A 的批量尺寸。当m > nA 中的每个矩阵都是满秩时计算它,否则它是一个空张量。如果A 是一批矩阵并且该批中的任何矩阵不是满秩的,则返回一个空张量。此行为可能会在未来的PyTorch 版本中发生变化。

  • rankA 中矩阵的秩张量。它的形状等于 A 的批量尺寸。当driver 是(‘gelsy’‘gelsd’‘gelss’)之一时计算,否则为空张量。

  • singular_valuesA 中矩阵的奇异值的张量。它的形状为 (*, min(m, n)) 。当 driver 是 (‘gelsd’ , ‘gelss’ ) 之一时计算它,否则它是一个空张量。

注意

此函数计算 X = A .pinverse() @ B 比单独执行计算更快且数值更稳定。

警告

rcond 的默认值可能会在未来的 PyTorch 版本中更改。因此,建议使用固定值以避免潜在的破坏性变化。

例子:

>>> A = torch.tensor([[[10, 2, 3], [3, 10, 5], [5, 6, 12]]], dtype=torch.float) # shape (1, 3, 3)
>>> B = torch.tensor([[[2, 5, 1], [3, 2, 1], [5, 1, 9]],
                      [[4, 2, 9], [2, 0, 3], [2, 5, 3]]], dtype=torch.float) # shape (2, 3, 3)
>>> X = torch.linalg.lstsq(A, B).solution # A is broadcasted to shape (2, 3, 3)
>>> torch.dist(X, torch.linalg.pinv(A) @ B)
tensor(2.0862e-07)

>>> S = torch.linalg.lstsq(A, B, driver='gelsd').singular_values
>>> torch.dist(S, torch.linalg.svdvals(A))
tensor(5.7220e-06)

>>> A[:, 0].zero_()  # Decrease the rank of A
>>> rank = torch.linalg.lstsq(A, B).rank
>>> rank
tensor([2])

相关用法


注:本文由纯净天空筛选整理自pytorch.org大神的英文原创作品 torch.linalg.lstsq。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。