當前位置: 首頁>>編程示例 >>用法及示例精選 >>正文


Python PyTorch qr用法及代碼示例

本文簡要介紹python語言中 torch.linalg.qr 的用法。

用法:

torch.linalg.qr(A, mode='reduced', *, out=None)

參數

  • A(Tensor) -形狀為 (*, m, n) 的張量,其中 * 是零個或多個批次維度。

  • mode(str,可選的) -‘reduced’‘complete’‘r’ 之一。控製返回張量的形狀。默認值:‘reduced’

關鍵字參數

out(tuple,可選的) -兩個張量的輸出元組。如果 None 則忽略。默認值:None

返回

命名元組 (Q, R)

計算矩陣的 QR 分解。

\mathbb{K} \mathbb{R} 或者\mathbb{C} , 這全二維分解矩陣的A \in \mathbb{K}^{m \times n} 定義為

其中 在真實情況下是正交的,在複雜情況下是單一的,而 是上三角形。

什麽時候m > n(高矩陣),如R是上三角形,最後一個m - n行為零。在這種情況下,我們可以刪除最後一個m - nQ形成減少 QR 分解

n >= m(寬矩陣)時,簡化的 QR 分解與完整的 QR 分解一致。

支持 float、double、cfloat 和 cdouble dtypes 的輸入。還支持批量矩陣,如果 A 是批量矩陣,則輸出具有相同的批量維度。

參數mode 在完全和簡化的 QR 分解之間進行選擇。如果 A 具有形狀 (*, m, n) ,表示 k = min(m, n)

  • mode = ‘reduced’(默認):分別返回形狀 (*, m, k)(*, k, n)(Q, R)

  • mode = ‘complete’ :分別返回形狀 (*, m, m)(*, m, n)(Q, R)

  • mode = ‘r’ :僅計算簡化後的 R 。返回 (Q, R)Q 為空和 R 形狀為 (*, k, n)

numpy.linalg.qr 的區別:

  • mode = ‘raw’ 未實現。

  • numpy.linalg.qr 不同,此函數始終返回兩個張量的元組。當 mode = ‘r’ 時,Q 張量是一個空張量。此行為可能會在未來的PyTorch 版本中發生變化。

注意

R 的對角線元素不一定是正數。

注意

mode = ‘r’ 不支持反向傳播。請改用mode = ‘reduced’

警告

A 的前k = min(m, n) 列線性無關時,QR 分解僅在R 的對角線符號處唯一。如果不是這種情況,不同的平台(如 NumPy)或不同設備上的輸入可能會產生不同的有效分解。

警告

僅當 A 中每個矩陣的前 k = min(m, n) 列線性無關時,才支持梯度計算。如果不滿足這個條件,不會拋出錯誤,但是產生的梯度會不正確。這是因為 QR 分解在這些點上是不可微的。

例子:

>>> A = torch.tensor([[12., -51, 4], [6, 167, -68], [-4, 24, -41]])
>>> Q, R = torch.linalg.qr(A)
>>> Q
tensor([[-0.8571,  0.3943,  0.3314],
        [-0.4286, -0.9029, -0.0343],
        [ 0.2857, -0.1714,  0.9429]])
>>> R
tensor([[ -14.0000,  -21.0000,   14.0000],
        [   0.0000, -175.0000,   70.0000],
        [   0.0000,    0.0000,  -35.0000]])
>>> (Q @ R).round()
tensor([[  12.,  -51.,    4.],
        [   6.,  167.,  -68.],
        [  -4.,   24.,  -41.]])
>>> (Q.T @ Q).round()
tensor([[ 1.,  0.,  0.],
        [ 0.,  1., -0.],
        [ 0., -0.,  1.]])
>>> Q2, R2 = torch.linalg.qr(A, mode='r')
>>> Q2
tensor([])
>>> torch.equal(R, R2)
True
>>> A = torch.randn(3, 4, 5)
>>> Q, R = torch.linalg.qr(A, mode='complete')
>>> torch.dist(Q @ R, A)
tensor(1.6099e-06)
>>> torch.dist(Q.transpose(-2, -1) @ Q, torch.eye(4))
tensor(6.2158e-07)

相關用法


注:本文由純淨天空篩選整理自pytorch.org大神的英文原創作品 torch.linalg.qr。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。