当前位置: 首页>>代码示例>>Python>>正文


Python torch.potrs方法代码示例

本文整理汇总了Python中torch.potrs方法的典型用法代码示例。如果您正苦于以下问题:Python torch.potrs方法的具体用法?Python torch.potrs怎么用?Python torch.potrs使用的例子?那么恭喜您, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在torch的用法示例。


在下文中一共展示了torch.potrs方法的5个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: test_potrs

# 需要导入模块: import torch [as 别名]
# 或者: from torch import potrs [as 别名]
def test_potrs(self):
        a=torch.Tensor(((6.80, -2.11,  5.66,  5.97,  8.23),
                        (-6.05, -3.30,  5.36, -4.44,  1.08),
                        (-0.45,  2.58, -2.70,  0.27,  9.04),
                        (8.32,  2.71,  4.35, -7.17,  2.14),
                        (-9.67, -5.14, -7.26,  6.08, -6.87))).t()
        b=torch.Tensor(((4.02,  6.19, -8.22, -7.57, -3.03),
                        (-1.56,  4.00, -8.67,  1.75,  2.86),
                        (9.81, -4.09, -4.57, -8.61,  8.99))).t()

        # make sure 'a' is symmetric PSD
        a = torch.mm(a, a.t())

        # upper Triangular Test
        U = torch.potrf(a)
        x = torch.potrs(b, U)
        self.assertLessEqual(b.dist(torch.mm(a, x)), 1e-12)

        # lower Triangular Test
        L = torch.potrf(a, False)
        x = torch.potrs(b, L, False)
        self.assertLessEqual(b.dist(torch.mm(a, x)), 1e-12) 
开发者ID:apaszke,项目名称:pytorch-dist,代码行数:24,代码来源:test_torch.py

示例2: test_potrs

# 需要导入模块: import torch [as 别名]
# 或者: from torch import potrs [as 别名]
def test_potrs(self):
        a = torch.Tensor(((6.80, -2.11, 5.66, 5.97, 8.23),
                          (-6.05, -3.30, 5.36, -4.44, 1.08),
                          (-0.45, 2.58, -2.70, 0.27, 9.04),
                          (8.32, 2.71, 4.35, -7.17, 2.14),
                          (-9.67, -5.14, -7.26, 6.08, -6.87))).t()
        b = torch.Tensor(((4.02, 6.19, -8.22, -7.57, -3.03),
                          (-1.56, 4.00, -8.67, 1.75, 2.86),
                          (9.81, -4.09, -4.57, -8.61, 8.99))).t()

        # make sure 'a' is symmetric PSD
        a = torch.mm(a, a.t())

        # upper Triangular Test
        U = torch.potrf(a)
        x = torch.potrs(b, U)
        self.assertLessEqual(b.dist(torch.mm(a, x)), 1e-12)

        # lower Triangular Test
        L = torch.potrf(a, False)
        x = torch.potrs(b, L, False)
        self.assertLessEqual(b.dist(torch.mm(a, x)), 1e-12) 
开发者ID:tylergenter,项目名称:pytorch,代码行数:24,代码来源:test_torch.py

示例3: solve_kkt

# 需要导入模块: import torch [as 别名]
# 或者: from torch import potrs [as 别名]
def solve_kkt(U_Q, d, G, A, U_S, rx, rs, rz, ry, dbg=False):
    """ Solve KKT equations for the affine step"""
    nineq, nz, neq, _ = get_sizes(G, A)

    invQ_rx = torch.potrs(rx.view(-1, 1), U_Q).view(-1)
    if neq > 0:
        h = torch.cat([torch.mv(A, invQ_rx) - ry,
                       torch.mv(G, invQ_rx) + rs / d - rz], 0)
    else:
        h = torch.mv(G, invQ_rx) + rs / d - rz

    w = -torch.potrs(h.view(-1, 1), U_S).view(-1)

    g1 = -rx - torch.mv(G.t(), w[neq:])
    if neq > 0:
        g1 -= torch.mv(A.t(), w[:neq])
    g2 = -rs - w[neq:]

    dx = torch.potrs(g1.view(-1, 1), U_Q).view(-1)
    ds = g2 / d
    dz = w[neq:]
    dy = w[:neq] if neq > 0 else None

    # if np.all(np.array([x.norm() for x in [rx, rs, rz, ry]]) != 0):
    if dbg:
        import IPython
        import sys
        IPython.embed()
        sys.exit(-1)

    # if rs.norm() > 0: import IPython, sys; IPython.embed(); sys.exit(-1)
    return dx, ds, dz, dy 
开发者ID:locuslab,项目名称:qpth,代码行数:34,代码来源:single.py

示例4: pre_factor_kkt

# 需要导入模块: import torch [as 别名]
# 或者: from torch import potrs [as 别名]
def pre_factor_kkt(Q, G, A):
    """ Perform all one-time factorizations and cache relevant matrix products"""
    nineq, nz, neq, _ = get_sizes(G, A)

    # S = [ A Q^{-1} A^T        A Q^{-1} G^T           ]
    #     [ G Q^{-1} A^T        G Q^{-1} G^T + D^{-1} ]

    U_Q = torch.potrf(Q)
    # partial cholesky of S matrix
    U_S = torch.zeros(neq + nineq, neq + nineq).type_as(Q)

    G_invQ_GT = torch.mm(G, torch.potrs(G.t(), U_Q))
    R = G_invQ_GT
    if neq > 0:
        invQ_AT = torch.potrs(A.t(), U_Q)
        A_invQ_AT = torch.mm(A, invQ_AT)
        G_invQ_AT = torch.mm(G, invQ_AT)

        # TODO: torch.potrf sometimes says the matrix is not PSD but
        # numpy does? I filed an issue at
        # https://github.com/pytorch/pytorch/issues/199
        try:
            U11 = torch.potrf(A_invQ_AT)
        except:
            U11 = torch.Tensor(np.linalg.cholesky(
                A_invQ_AT.cpu().numpy())).type_as(A_invQ_AT)

        # TODO: torch.trtrs is currently not implemented on the GPU
        # and we are using gesv as a workaround.
        U12 = torch.gesv(G_invQ_AT.t(), U11.t())[0]
        U_S[:neq, :neq] = U11
        U_S[:neq, neq:] = U12
        R -= torch.mm(U12.t(), U12)

    return U_Q, U_S, R 
开发者ID:locuslab,项目名称:qpth,代码行数:37,代码来源:single.py

示例5: factor_solve_kkt

# 需要导入模块: import torch [as 别名]
# 或者: from torch import potrs [as 别名]
def factor_solve_kkt(Q, D, G, A, rx, rs, rz, ry):
    nineq, nz, neq, _ = get_sizes(G, A)

    if neq > 0:
        H_ = torch.cat([torch.cat([Q, torch.zeros(nz, nineq).type_as(Q)], 1),
                        torch.cat([torch.zeros(nineq, nz).type_as(Q), D], 1)], 0)
        A_ = torch.cat([torch.cat([G, torch.eye(nineq).type_as(Q)], 1),
                        torch.cat([A, torch.zeros(neq, nineq).type_as(Q)], 1)], 0)
        g_ = torch.cat([rx, rs], 0)
        h_ = torch.cat([rz, ry], 0)
    else:
        H_ = torch.cat([torch.cat([Q, torch.zeros(nz, nineq).type_as(Q)], 1),
                        torch.cat([torch.zeros(nineq, nz).type_as(Q), D], 1)], 0)
        A_ = torch.cat([G, torch.eye(nineq).type_as(Q)], 1)
        g_ = torch.cat([rx, rs], 0)
        h_ = rz

    U_H_ = torch.potrf(H_)

    invH_A_ = torch.potrs(A_.t(), U_H_)
    invH_g_ = torch.potrs(g_.view(-1, 1), U_H_).view(-1)

    S_ = torch.mm(A_, invH_A_)
    U_S_ = torch.potrf(S_)
    t_ = torch.mv(A_, invH_g_).view(-1, 1) - h_
    w_ = -torch.potrs(t_, U_S_).view(-1)
    v_ = torch.potrs(-g_.view(-1, 1) - torch.mv(A_.t(), w_), U_H_).view(-1)

    return v_[:nz], v_[nz:], w_[:nineq], w_[nineq:] if neq > 0 else None 
开发者ID:locuslab,项目名称:qpth,代码行数:31,代码来源:single.py


注:本文中的torch.potrs方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。