當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python PyTorch householder_product用法及代碼示例


本文簡要介紹python語言中 torch.linalg.householder_product 的用法。

用法:

torch.linalg.householder_product(A, tau, *, out=None) → Tensor

參數

  • A(Tensor) -形狀為 (*, m, n) 的張量,其中 * 是零個或多個批次維度。

  • tau(Tensor) -形狀為 (*, k) 的張量,其中 * 是零個或多個批次維度。

關鍵字參數

out(Tensor,可選的) -輸出張量。如果 None 則忽略。默認值:None

拋出

RuntimeError - 如果 A 不滿足要求 m >= n ,或 tau 不滿足要求 n >= k

計算 Householder 矩陣乘積的前 n 列。

假設 ,對於具有列 的矩陣 和帶有 的向量 ,此函數計算第一個 列矩陣

其中 m維單位矩陣, 複數時的共軛轉置, 是實值時的轉置。

有關詳細信息,請參閱Representation of Orthogonal or Unitary Matrices

支持 float、double、cfloat 和 cdouble dtypes 的輸入。還支持矩陣批次,如果輸入是矩陣批次,則輸出具有相同的批次尺寸。

注意

此函數僅使用嚴格低於 A 主對角線的值。其他值被忽略。

例子:

>>> A = torch.randn(2, 2)
>>> h, tau = torch.geqrf(A)
>>> Q = torch.linalg.householder_product(h, tau)
>>> torch.dist(Q, torch.linalg.qr(A).Q)
tensor(0.)

>>> h = torch.randn(3, 2, 2, dtype=torch.complex128)
>>> tau = torch.randn(3, 1, dtype=torch.complex128)
>>> Q = torch.linalg.householder_product(h, tau)
>>> Q
tensor([[[ 1.8034+0.4184j,  0.2588-1.0174j],
        [-0.6853+0.7953j,  2.0790+0.5620j]],

        [[ 1.4581+1.6989j, -1.5360+0.1193j],
        [ 1.3877-0.6691j,  1.3512+1.3024j]],

        [[ 1.4766+0.5783j,  0.0361+0.6587j],
        [ 0.6396+0.1612j,  1.3693+0.4481j]]], dtype=torch.complex128)

相關用法


注:本文由純淨天空篩選整理自pytorch.org大神的英文原創作品 torch.linalg.householder_product。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。