當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python PyTorch lu_unpack用法及代碼示例


本文簡要介紹python語言中 torch.lu_unpack 的用法。

用法:

torch.lu_unpack(LU_data, LU_pivots, unpack_data=True, unpack_pivots=True, *, out=None)

參數

  • LU_data(Tensor) -打包的 LU 分解數據

  • LU_pivots(Tensor) -打包的 LU 分解支點

  • unpack_data(bool) -指示是否應解包數據的標誌。如果 False ,則返回的 LUNone 。默認值:True

  • unpack_pivots(bool) -指示是否應將主元解壓縮為置換矩陣 P 的標誌。如果 False ,則返回的 PNone 。默認值:True

  • out(tuple,可選的) -用於輸出 (P, L, U) 的三個張量的元組。

將數據解包並從張量的 LU 因式分解為張量 LU 以及排列張量 P 以便 LU_data, LU_pivots = (P @ L @ U).lu()

返回張量的元組為(the P tensor (permutation matrix), the L tensor, the U tensor).

注意

P.dtype == LU_data.dtypeP.dtype 不是整數類型,因此可以使用 P 的矩陣乘積而不將其轉換為浮點類型。

例子:

>>> A = torch.randn(2, 3, 3)
>>> A_LU, pivots = A.lu()
>>> P, A_L, A_U = torch.lu_unpack(A_LU, pivots)
>>>
>>> # can recover A from factorization
>>> A_ = torch.bmm(P, torch.bmm(A_L, A_U))

>>> # LU factorization of a rectangular matrix:
>>> A = torch.randn(2, 3, 2)
>>> A_LU, pivots = A.lu()
>>> P, A_L, A_U = torch.lu_unpack(A_LU, pivots)
>>> P
tensor([[[1., 0., 0.],
         [0., 1., 0.],
         [0., 0., 1.]],

        [[0., 0., 1.],
         [0., 1., 0.],
         [1., 0., 0.]]])
>>> A_L
tensor([[[ 1.0000,  0.0000],
         [ 0.4763,  1.0000],
         [ 0.3683,  0.1135]],

        [[ 1.0000,  0.0000],
         [ 0.2957,  1.0000],
         [-0.9668, -0.3335]]])
>>> A_U
tensor([[[ 2.1962,  1.0881],
         [ 0.0000, -0.8681]],

        [[-1.0947,  0.3736],
         [ 0.0000,  0.5718]]])
>>> A_ = torch.bmm(P, torch.bmm(A_L, A_U))
>>> torch.norm(A_ - A)
tensor(2.9802e-08)

相關用法


注:本文由純淨天空篩選整理自pytorch.org大神的英文原創作品 torch.lu_unpack。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。