This article collects typical usage examples of the torch.btriunpack method in Python. If you are wondering what exactly torch.btriunpack does, or how to use it, the curated method examples below may help. You can also explore further usage examples from the torch module, where this method lives.
The following shows 3 code examples of torch.btriunpack, sorted by popularity by default. You can upvote the examples you like or find useful; your ratings help the system recommend better Python code examples.
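Before the examples, a minimal sketch of the basic pattern may help. Note that btrifact and btriunpack are the old names from early PyTorch releases; later versions renamed them to torch.lu and torch.lu_unpack, so this sketch assumes a version old enough to still ship the old API:

import torch

# Minimal sketch (assumes an old PyTorch that still has btrifact/btriunpack).
A = torch.randn(4, 3, 3)               # a batch of four 3x3 matrices
A_LU, pivots = A.btrifact()            # packed LU data plus pivot indices
P, L, U = torch.btriunpack(A_LU, pivots)
recon = torch.bmm(P, torch.bmm(L, U))  # each element reconstructs as P @ L @ U
print((recon - A).abs().max())         # should be near zero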
Example 1: _test_btrifact
# Required import: import torch [as alias]
# Or: from torch import btriunpack [as alias]
def _test_btrifact(self, cast):
    # A batch of three 2x2 matrices to factor.
    a = torch.FloatTensor((((1.3722, -0.9020),
                            (1.8849, 1.9169)),
                           ((0.7187, -1.1695),
                            (-0.0139, 1.3572)),
                           ((-1.6181, 0.7148),
                            (1.3728, 0.1319))))
    a = cast(a)
    info = cast(torch.IntTensor())
    a_LU = a.btrifact(info=info)
    # info holds per-batch status codes; all zeros means every
    # batch element factored successfully.
    self.assertEqual(info.abs().sum(), 0)
    P, a_L, a_U = torch.btriunpack(*a_LU)
    # Each batch element should reconstruct as P @ L @ U.
    a_ = torch.bmm(P, torch.bmm(a_L, a_U))
    self.assertEqual(a_, a)
Example 2: factor_kkt
# Required import: import torch [as alias]
# Or: from torch import btriunpack [as alias]
def factor_kkt(S_LU, R, d):
    """Factor the U22 block, which we can only do once D is known."""
    nBatch, nineq = d.size()
    neq = S_LU[1].size(1) - nineq

    # TODO: There's probably a better way to add a batched diagonal.
    global factor_kkt_eye
    if factor_kkt_eye is None or factor_kkt_eye.size() != d.size():
        # print('Updating batchedEye size.')
        factor_kkt_eye = torch.eye(nineq).repeat(
            nBatch, 1, 1).type_as(R).byte()
    # Form T = R + D^{-1} by adding 1/d along each batched diagonal.
    T = R.clone()
    T[factor_kkt_eye] += (1. / d).squeeze()
    T_LU = btrifact_hack(T)

    global shown_btrifact_warning
    if shown_btrifact_warning or not T.is_cuda:
        # TODO: Don't use pivoting in most cases because
        # torch.btriunpack is inefficient here:
        oldPivotsPacked = S_LU[1][:, -nineq:] - neq
        oldPivots, _, _ = torch.btriunpack(
            T_LU[0], oldPivotsPacked, unpack_data=False)
        newPivotsPacked = T_LU[1]
        newPivots, _, _ = torch.btriunpack(
            T_LU[0], newPivotsPacked, unpack_data=False)

        # Re-pivot the S_LU_21 block.
        if neq > 0:
            S_LU_21 = S_LU[0][:, -nineq:, :neq]
            S_LU[0][:, -nineq:,
                    :neq] = newPivots.transpose(1, 2).bmm(oldPivots.bmm(S_LU_21))

        # Add the new S_LU_22 block pivots.
        S_LU[1][:, -nineq:] = newPivotsPacked + neq

    # Add the new S_LU_22 block.
    S_LU[0][:, -nineq:, -nineq:] = T_LU[0]
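This example (and Example 3 below) calls btrifact_hack, which is not defined anywhere in this excerpt. Judging from the surrounding pivot handling (the shown_btrifact_warning / T.is_cuda branch), a plausible minimal stand-in is a thin wrapper that disables pivoting on CUDA, where older batched LU kernels did not support it. Treat the following as an assumption, not the real qpth helper:

# Hypothetical stand-in for qpth's btrifact_hack helper (not shown in this
# excerpt; the real implementation may differ). Older CUDA batched LU
# kernels did not support pivoting, hence pivot=False on GPU tensors.
def btrifact_hack(x):
    return x.btrifact(pivot=not x.is_cuda)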
Example 3: pre_factor_kkt
# Required import: import torch [as alias]
# Or: from torch import btriunpack [as alias]
def pre_factor_kkt(Q, G, A):
    """Perform all one-time factorizations and cache relevant matrix products."""
    nineq, nz, neq, nBatch = get_sizes(G, A)

    try:
        Q_LU = btrifact_hack(Q)
    except:
        raise RuntimeError("""
qpth Error: Cannot perform LU factorization on Q.
Please make sure that your Q matrix is PSD and has
a non-zero diagonal.
""")

    # S = [ A Q^{-1} A^T    A Q^{-1} G^T          ]
    #     [ G Q^{-1} A^T    G Q^{-1} G^T + D^{-1} ]
    #
    # We compute a partial LU decomposition of the S matrix
    # that can be completed once D^{-1} is known.
    # See the 'Block LU factorization' part of our website
    # for more details.
    G_invQ_GT = torch.bmm(G, G.transpose(1, 2).btrisolve(*Q_LU))
    R = G_invQ_GT.clone()
    S_LU_pivots = torch.IntTensor(range(1, 1 + neq + nineq)).unsqueeze(0) \
        .repeat(nBatch, 1).type_as(Q).int()
    if neq > 0:
        invQ_AT = A.transpose(1, 2).btrisolve(*Q_LU)
        A_invQ_AT = torch.bmm(A, invQ_AT)
        G_invQ_AT = torch.bmm(G, invQ_AT)

        LU_A_invQ_AT = btrifact_hack(A_invQ_AT)
        # Unpack the factorization of the (1,1) block into explicit P, L, U.
        P_A_invQ_AT, L_A_invQ_AT, U_A_invQ_AT = torch.btriunpack(*LU_A_invQ_AT)
        P_A_invQ_AT = P_A_invQ_AT.type_as(A_invQ_AT)

        S_LU_11 = LU_A_invQ_AT[0]
        U_A_invQ_AT_inv = (P_A_invQ_AT.bmm(L_A_invQ_AT)
                           ).btrisolve(*LU_A_invQ_AT)
        S_LU_21 = G_invQ_AT.bmm(U_A_invQ_AT_inv)
        T = G_invQ_AT.transpose(1, 2).btrisolve(*LU_A_invQ_AT)
        S_LU_12 = U_A_invQ_AT.bmm(T)
        # The (2,2) block is left as zeros; factor_kkt fills it in once D is known.
        S_LU_22 = torch.zeros(nBatch, nineq, nineq).type_as(Q)
        S_LU_data = torch.cat((torch.cat((S_LU_11, S_LU_12), 2),
                               torch.cat((S_LU_21, S_LU_22), 2)),
                              1)
        S_LU_pivots[:, :neq] = LU_A_invQ_AT[1]

        R -= G_invQ_AT.bmm(T)
    else:
        S_LU_data = torch.zeros(nBatch, nineq, nineq).type_as(Q)

    S_LU = [S_LU_data, S_LU_pivots]
    return Q_LU, S_LU, R
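The block comment in pre_factor_kkt is the heart of the trick: only the lower-right block of S depends on D^{-1}, so everything else can be factored once up front, and factor_kkt only has to factor the small remaining block per iteration. Below is a small, unbatched numeric illustration of that block-LU identity; it is purely illustrative (the block sizes and names are made up, and it is not qpth code):

import torch

# Illustrative only: the block-LU identity behind pre_factor_kkt/factor_kkt,
# shown on a single unbatched matrix with made-up block sizes.
torch.manual_seed(0)
n1, n2 = 3, 2
S = torch.randn(n1 + n2, n1 + n2) + (n1 + n2) * torch.eye(n1 + n2)
S11, S12 = S[:n1, :n1], S[:n1, n1:]
S21, S22 = S[n1:, :n1], S[n1:, n1:]

# Schur complement of S11 in S: the only piece that involves S22, i.e. the
# only piece that changes when D^{-1} changes in the qpth code above.
schur = S22 - S21.mm(S11.inverse()).mm(S12)

# Block-triangular factors: S = L_blk @ U_blk.
L_blk = torch.cat((torch.cat((torch.eye(n1), torch.zeros(n1, n2)), 1),
                   torch.cat((S21.mm(S11.inverse()), torch.eye(n2)), 1)), 0)
U_blk = torch.cat((torch.cat((S11, S12), 1),
                   torch.cat((torch.zeros(n2, n1), schur), 1)), 0)
print((L_blk.mm(U_blk) - S).abs().max())  # ~0: the block factorization holds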