当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python PyTorch qr用法及代码示例


本文简要介绍python语言中 torch.linalg.qr 的用法。

用法:

torch.linalg.qr(A, mode='reduced', *, out=None)

参数

  • A(Tensor) -形状为 (*, m, n) 的张量,其中 * 是零个或多个批次维度。

  • mode(str,可选的) -‘reduced’‘complete’‘r’ 之一。控制返回张量的形状。默认值:‘reduced’

关键字参数

out(tuple,可选的) -两个张量的输出元组。如果 None 则忽略。默认值:None

返回

命名元组 (Q, R)

计算矩阵的 QR 分解。

\mathbb{K} \mathbb{R} 或者\mathbb{C} , 这全二维分解矩阵的A \in \mathbb{K}^{m \times n} 定义为

其中 在真实情况下是正交的,在复杂情况下是单一的,而 是上三角形。

什么时候m > n(高矩阵),如R是上三角形,最后一个m - n行为零。在这种情况下,我们可以删除最后一个m - nQ形成减少 QR 分解

n >= m(宽矩阵)时,简化的 QR 分解与完整的 QR 分解一致。

支持 float、double、cfloat 和 cdouble dtypes 的输入。还支持批量矩阵,如果 A 是批量矩阵,则输出具有相同的批量维度。

参数mode 在完全和简化的 QR 分解之间进行选择。如果 A 具有形状 (*, m, n) ,表示 k = min(m, n)

  • mode = ‘reduced’(默认):分别返回形状 (*, m, k)(*, k, n)(Q, R)

  • mode = ‘complete’ :分别返回形状 (*, m, m)(*, m, n)(Q, R)

  • mode = ‘r’ :仅计算简化后的 R 。返回 (Q, R)Q 为空和 R 形状为 (*, k, n)

numpy.linalg.qr 的区别:

  • mode = ‘raw’ 未实现。

  • numpy.linalg.qr 不同,此函数始终返回两个张量的元组。当 mode = ‘r’ 时,Q 张量是一个空张量。此行为可能会在未来的PyTorch 版本中发生变化。

注意

R 的对角线元素不一定是正数。

注意

mode = ‘r’ 不支持反向传播。请改用mode = ‘reduced’

警告

A 的前k = min(m, n) 列线性无关时,QR 分解仅在R 的对角线符号处唯一。如果不是这种情况,不同的平台(如 NumPy)或不同设备上的输入可能会产生不同的有效分解。

警告

仅当 A 中每个矩阵的前 k = min(m, n) 列线性无关时,才支持梯度计算。如果不满足这个条件,不会抛出错误,但是产生的梯度会不正确。这是因为 QR 分解在这些点上是不可微的。

例子:

>>> A = torch.tensor([[12., -51, 4], [6, 167, -68], [-4, 24, -41]])
>>> Q, R = torch.linalg.qr(A)
>>> Q
tensor([[-0.8571,  0.3943,  0.3314],
        [-0.4286, -0.9029, -0.0343],
        [ 0.2857, -0.1714,  0.9429]])
>>> R
tensor([[ -14.0000,  -21.0000,   14.0000],
        [   0.0000, -175.0000,   70.0000],
        [   0.0000,    0.0000,  -35.0000]])
>>> (Q @ R).round()
tensor([[  12.,  -51.,    4.],
        [   6.,  167.,  -68.],
        [  -4.,   24.,  -41.]])
>>> (Q.T @ Q).round()
tensor([[ 1.,  0.,  0.],
        [ 0.,  1., -0.],
        [ 0., -0.,  1.]])
>>> Q2, R2 = torch.linalg.qr(A, mode='r')
>>> Q2
tensor([])
>>> torch.equal(R, R2)
True
>>> A = torch.randn(3, 4, 5)
>>> Q, R = torch.linalg.qr(A, mode='complete')
>>> torch.dist(Q @ R, A)
tensor(1.6099e-06)
>>> torch.dist(Q.transpose(-2, -1) @ Q, torch.eye(4))
tensor(6.2158e-07)

相关用法


注:本文由纯净天空筛选整理自pytorch.org大神的英文原创作品 torch.linalg.qr。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。