当前位置: 首页>>代码示例>>Python>>正文


Python torch.trtrs方法代码示例

本文整理汇总了Python中torch.trtrs方法的典型用法代码示例。如果您正苦于以下问题:Python torch.trtrs方法的具体用法?Python torch.trtrs怎么用?Python torch.trtrs使用的例子?那么恭喜您, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在torch的用法示例。


在下文中一共展示了torch.trtrs方法的5个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: _ssor_preconditioner

# 需要导入模块: import torch [as 别名]
# 或者: from torch import trtrs [as 别名]
def _ssor_preconditioner(self, lhs_mat, mat):
        if lhs_mat.ndimension() == 2:
            DL = lhs_mat.tril()
            D = lhs_mat.diag()
            upper_part = (1 / D).expand_as(DL).mul(DL.t())
            Minv_times_mat = torch.trtrs(torch.trtrs(mat, DL, upper=False)[0], upper_part)[0]

        elif lhs_mat.ndimension() == 3:
            if mat.size(0) == 1 and lhs_mat.size(0) != 1:
                mat = mat.expand(*([lhs_mat.size(0)] + list(mat.size())[1:]))
            Minv_times_mat = mat.new(*mat.size())
            for i in range(lhs_mat.size(0)):
                DL = lhs_mat[i].tril()
                D = lhs_mat[i].diag()
                upper_part = (1 / D).expand_as(DL).mul(DL.t())
                Minv_times_mat[i].copy_(torch.trtrs(torch.trtrs(mat[i], DL, upper=False)[0], upper_part)[0])

        else:
            raise RuntimeError('Invalid number of dimensions')

        return Minv_times_mat 
开发者ID:jrg365,项目名称:gpytorch,代码行数:23,代码来源:lincg.py

示例2: pre_factor_kkt

# 需要导入模块: import torch [as 别名]
# 或者: from torch import trtrs [as 别名]
def pre_factor_kkt(Q, G, A):
    """ Perform all one-time factorizations and cache relevant matrix products"""
    nineq, nz, neq, _ = get_sizes(G, A)

    # S = [ A Q^{-1} A^T        A Q^{-1} G^T           ]
    #     [ G Q^{-1} A^T        G Q^{-1} G^T + D^{-1} ]

    U_Q = torch.potrf(Q)
    # partial cholesky of S matrix
    U_S = torch.zeros(neq + nineq, neq + nineq).type_as(Q)

    G_invQ_GT = torch.mm(G, torch.potrs(G.t(), U_Q))
    R = G_invQ_GT
    if neq > 0:
        invQ_AT = torch.potrs(A.t(), U_Q)
        A_invQ_AT = torch.mm(A, invQ_AT)
        G_invQ_AT = torch.mm(G, invQ_AT)

        # TODO: torch.potrf sometimes says the matrix is not PSD but
        # numpy does? I filed an issue at
        # https://github.com/pytorch/pytorch/issues/199
        try:
            U11 = torch.potrf(A_invQ_AT)
        except:
            U11 = torch.Tensor(np.linalg.cholesky(
                A_invQ_AT.cpu().numpy())).type_as(A_invQ_AT)

        # TODO: torch.trtrs is currently not implemented on the GPU
        # and we are using gesv as a workaround.
        U12 = torch.gesv(G_invQ_AT.t(), U11.t())[0]
        U_S[:neq, :neq] = U11
        U_S[:neq, neq:] = U12
        R -= torch.mm(U12.t(), U12)

    return U_Q, U_S, R 
开发者ID:locuslab,项目名称:qpth,代码行数:37,代码来源:single.py

示例3: test_trtrs

# 需要导入模块: import torch [as 别名]
# 或者: from torch import trtrs [as 别名]
def test_trtrs(self):
        def _test_with_size(N, C):
            A = Variable(torch.rand(N, N), requires_grad=True)
            b = Variable(torch.rand(N, C), requires_grad=True)

            for upper, transpose, unitriangular in product((True, False), repeat=3):
                def func(A, b):
                    return torch.trtrs(b, A, upper, transpose, unitriangular)

                gradcheck(func, [A, b])
                gradgradcheck(func, [A, b])

        _test_with_size(S, S + 1)
        _test_with_size(S, S - 1) 
开发者ID:pytorch,项目名称:pytorch,代码行数:16,代码来源:test_autograd.py

示例4: test_trtrs

# 需要导入模块: import torch [as 别名]
# 或者: from torch import trtrs [as 别名]
def test_trtrs(self):
        a = torch.Tensor(((6.80, -2.11,  5.66,  5.97,  8.23),
                        (-6.05, -3.30,  5.36, -4.44,  1.08),
                        (-0.45,  2.58, -2.70,  0.27,  9.04),
                        (8.32,  2.71,  4.35, -7.17,  2.14),
                        (-9.67, -5.14, -7.26,  6.08, -6.87))).t()
        b = torch.Tensor(((4.02,  6.19, -8.22, -7.57, -3.03),
                        (-1.56,  4.00, -8.67,  1.75,  2.86),
                        (9.81, -4.09, -4.57, -8.61,  8.99))).t()

        U = torch.triu(a)
        L = torch.tril(a)

        # solve Ux = b
        x = torch.trtrs(b, U)[0]
        self.assertLessEqual(b.dist(torch.mm(U, x)), 1e-12)
        x = torch.trtrs(b, U, True, False, False)[0]
        self.assertLessEqual(b.dist(torch.mm(U, x)), 1e-12)

        # solve Lx = b
        x = torch.trtrs(b, L, False)[0]
        self.assertLessEqual(b.dist(torch.mm(L, x)), 1e-12)
        x = torch.trtrs(b, L, False, False, False)[0]
        self.assertLessEqual(b.dist(torch.mm(L, x)), 1e-12)

        # solve U'x = b
        x = torch.trtrs(b, U, True, True)[0]
        self.assertLessEqual(b.dist(torch.mm(U.t(), x)), 1e-12)
        x = torch.trtrs(b, U, True, True, False)[0]
        self.assertLessEqual(b.dist(torch.mm(U.t(), x)), 1e-12)

        # solve U'x = b by manual transposition
        y = torch.trtrs(b, U.t(), False, False)[0]
        self.assertLessEqual(x.dist(y), 1e-12)

        # solve L'x = b
        x = torch.trtrs(b, L, False, True)[0]
        self.assertLessEqual(b.dist(torch.mm(L.t(), x)), 1e-12)
        x = torch.trtrs(b, L, False, True, False)[0]
        self.assertLessEqual(b.dist(torch.mm(L.t(), x)), 1e-12)

        # solve L'x = b by manual transposition
        y = torch.trtrs(b, L.t(), True, False)[0]
        self.assertLessEqual(x.dist(y), 1e-12)

        # test reuse
        res1 = torch.trtrs(b,a)[0]
        ta = torch.Tensor()
        tb = torch.Tensor()
        torch.trtrs(tb,ta,b,a)
        self.assertEqual(res1, tb, 0)
        tb.zero_()
        torch.trtrs(tb,ta,b,a)
        self.assertEqual(res1, tb, 0) 
开发者ID:apaszke,项目名称:pytorch-dist,代码行数:56,代码来源:test_torch.py

示例5: test_trtrs

# 需要导入模块: import torch [as 别名]
# 或者: from torch import trtrs [as 别名]
def test_trtrs(self):
        a = torch.Tensor(((6.80, -2.11, 5.66, 5.97, 8.23),
                          (-6.05, -3.30, 5.36, -4.44, 1.08),
                          (-0.45, 2.58, -2.70, 0.27, 9.04),
                          (8.32, 2.71, 4.35, -7.17, 2.14),
                          (-9.67, -5.14, -7.26, 6.08, -6.87))).t()
        b = torch.Tensor(((4.02, 6.19, -8.22, -7.57, -3.03),
                          (-1.56, 4.00, -8.67, 1.75, 2.86),
                          (9.81, -4.09, -4.57, -8.61, 8.99))).t()

        U = torch.triu(a)
        L = torch.tril(a)

        # solve Ux = b
        x = torch.trtrs(b, U)[0]
        self.assertLessEqual(b.dist(torch.mm(U, x)), 1e-12)
        x = torch.trtrs(b, U, True, False, False)[0]
        self.assertLessEqual(b.dist(torch.mm(U, x)), 1e-12)

        # solve Lx = b
        x = torch.trtrs(b, L, False)[0]
        self.assertLessEqual(b.dist(torch.mm(L, x)), 1e-12)
        x = torch.trtrs(b, L, False, False, False)[0]
        self.assertLessEqual(b.dist(torch.mm(L, x)), 1e-12)

        # solve U'x = b
        x = torch.trtrs(b, U, True, True)[0]
        self.assertLessEqual(b.dist(torch.mm(U.t(), x)), 1e-12)
        x = torch.trtrs(b, U, True, True, False)[0]
        self.assertLessEqual(b.dist(torch.mm(U.t(), x)), 1e-12)

        # solve U'x = b by manual transposition
        y = torch.trtrs(b, U.t(), False, False)[0]
        self.assertLessEqual(x.dist(y), 1e-12)

        # solve L'x = b
        x = torch.trtrs(b, L, False, True)[0]
        self.assertLessEqual(b.dist(torch.mm(L.t(), x)), 1e-12)
        x = torch.trtrs(b, L, False, True, False)[0]
        self.assertLessEqual(b.dist(torch.mm(L.t(), x)), 1e-12)

        # solve L'x = b by manual transposition
        y = torch.trtrs(b, L.t(), True, False)[0]
        self.assertLessEqual(x.dist(y), 1e-12)

        # test reuse
        res1 = torch.trtrs(b, a)[0]
        ta = torch.Tensor()
        tb = torch.Tensor()
        torch.trtrs(b, a, out=(tb, ta))
        self.assertEqual(res1, tb, 0)
        tb.zero_()
        torch.trtrs(b, a, out=(tb, ta))
        self.assertEqual(res1, tb, 0) 
开发者ID:tylergenter,项目名称:pytorch,代码行数:56,代码来源:test_torch.py


注:本文中的torch.trtrs方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。