当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python PyTorch householder_product用法及代码示例


本文简要介绍python语言中 torch.linalg.householder_product 的用法。

用法:

torch.linalg.householder_product(A, tau, *, out=None) → Tensor

参数

  • A(Tensor) -形状为 (*, m, n) 的张量,其中 * 是零个或多个批次维度。

  • tau(Tensor) -形状为 (*, k) 的张量,其中 * 是零个或多个批次维度。

关键字参数

out(Tensor,可选的) -输出张量。如果 None 则忽略。默认值:None

抛出

RuntimeError - 如果 A 不满足要求 m >= n ,或 tau 不满足要求 n >= k

计算 Householder 矩阵乘积的前 n 列。

假设 ,对于具有列 的矩阵 和带有 的向量 ,此函数计算第一个 列矩阵

其中 m维单位矩阵, 复数时的共轭转置, 是实值时的转置。

有关详细信息,请参阅Representation of Orthogonal or Unitary Matrices

支持 float、double、cfloat 和 cdouble dtypes 的输入。还支持矩阵批次,如果输入是矩阵批次,则输出具有相同的批次尺寸。

注意

此函数仅使用严格低于 A 主对角线的值。其他值被忽略。

例子:

>>> A = torch.randn(2, 2)
>>> h, tau = torch.geqrf(A)
>>> Q = torch.linalg.householder_product(h, tau)
>>> torch.dist(Q, torch.linalg.qr(A).Q)
tensor(0.)

>>> h = torch.randn(3, 2, 2, dtype=torch.complex128)
>>> tau = torch.randn(3, 1, dtype=torch.complex128)
>>> Q = torch.linalg.householder_product(h, tau)
>>> Q
tensor([[[ 1.8034+0.4184j,  0.2588-1.0174j],
        [-0.6853+0.7953j,  2.0790+0.5620j]],

        [[ 1.4581+1.6989j, -1.5360+0.1193j],
        [ 1.3877-0.6691j,  1.3512+1.3024j]],

        [[ 1.4766+0.5783j,  0.0361+0.6587j],
        [ 0.6396+0.1612j,  1.3693+0.4481j]]], dtype=torch.complex128)

相关用法


注:本文由纯净天空筛选整理自pytorch.org大神的英文原创作品 torch.linalg.householder_product。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。