本文整理汇总了Python中torch.slogdet方法的典型用法代码示例。如果您正苦于以下问题：Python torch.slogdet方法的具体用法？Python torch.slogdet怎么用？Python torch.slogdet使用的例子？那么恭喜您，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块torch的用法示例。
在下文中一共展示了torch.slogdet方法的10个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: forward
# 需要导入模块: import torch [as 别名]
# 或者: from torch import slogdet [as 别名]
def forward(self, input: torch.Tensor, mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""Apply the linear flow ``F.linear(input, self.weight)`` and its log-determinant.

    Args:
        input: Tensor
            input tensor [batch, N1, N2, ..., Nl, in_features]
        mask: Tensor
            mask tensor [batch, N1, N2, ..., Nl]

    Returns: out: Tensor, logdet: Tensor
        out: [batch, N1, N2, ..., in_features], the output of the flow
        logdet: [batch], the log determinant of :math:`\partial output / \partial input`.
            When ``input`` is 2-D ([batch, in_features]), ``mask`` is not used and
            the single 0-dim value log|det(self.weight)| is returned unscaled.
    """
    # [batch, N1, N2, ..., in_features]
    out = F.linear(input, self.weight)
    # Only the log of the absolute determinant is needed; the sign is discarded.
    _, logdet = torch.slogdet(self.weight)
    if input.dim() > 2:
        # The same matrix acts at every position, so the total log-det is
        # log|det W| times the per-sample mask sum (the count of unmasked
        # positions, assuming a 0/1 mask). reshape (not view) so that a
        # non-contiguous mask does not raise.
        num = mask.reshape(out.size(0), -1).sum(dim=1)
        logdet = logdet * num
    return out, logdet
示例2: backward
# 需要导入模块: import torch [as 别名]
# 或者: from torch import slogdet [as 别名]
def backward(self, input: torch.Tensor, mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""Apply the reverse linear flow ``F.linear(input, self.weight_inv)`` and its log-determinant.

    ``self.weight_inv`` is presumably the inverse of the forward weight (confirm
    against the module's constructor).

    Args:
        input: Tensor
            input tensor [batch, N1, N2, ..., Nl, in_features]
        mask: Tensor
            mask tensor [batch, N1, N2, ..., Nl]

    Returns: out: Tensor, logdet: Tensor
        out: [batch, N1, N2, ..., in_features], the output of the flow
        logdet: [batch], the log determinant of :math:`\partial output / \partial input`.
            When ``input`` is 2-D ([batch, in_features]), ``mask`` is not used and
            the single 0-dim value log|det(self.weight_inv)| is returned unscaled.
    """
    # [batch, N1, N2, ..., in_features]
    out = F.linear(input, self.weight_inv)
    # Only the log of the absolute determinant is needed; the sign is discarded.
    _, logdet = torch.slogdet(self.weight_inv)
    if input.dim() > 2:
        # Per-position log-det scaled by the per-sample mask sum. reshape
        # (not view) so that a non-contiguous mask does not raise.
        num = mask.reshape(out.size(0), -1).sum(dim=1)
        logdet = logdet * num
    return out, logdet
示例3: logabsdet
# 需要导入模块: import torch [as 别名]
# 或者: from torch import slogdet [as 别名]
def logabsdet(x):
    """Return log|det(x)| for the square matrix ``x``.

    ``torch.logdet`` is only usable for a positive determinant, so the value
    is taken from ``torch.slogdet`` and its sign output is ignored.
    """
    sign_and_logabs = torch.slogdet(x)
    return sign_and_logabs[1]
示例4: forward
# 需要导入模块: import torch [as 别名]
# 或者: from torch import slogdet [as 别名]
def forward(self, inputs, cond_inputs=None, mode='direct'):
    """Right-multiply ``inputs`` by ``self.W`` (or its inverse) and return the log-det.

    Args:
        inputs: tensor whose rows are multiplied by the matrix; its first
            dimension is treated as the batch size.
        cond_inputs: accepted for interface compatibility; not used here.
        mode: ``'direct'`` uses ``self.W``; any other value uses
            ``torch.inverse(self.W)``.

    Returns:
        ``(outputs, logdet)``. ``logdet`` is log|det W| (``'direct'``) or
        -log|det W| (otherwise), repeated to shape ``[batch, 1]``.
    """
    batch_size = inputs.size(0)
    log_abs_det = torch.slogdet(self.W)[-1]
    if mode != 'direct':
        return inputs @ torch.inverse(self.W), -log_abs_det[None, None].repeat(batch_size, 1)
    return inputs @ self.W, log_abs_det[None, None].repeat(batch_size, 1)
示例5: forward
# 需要导入模块: import torch [as 别名]
# 或者: from torch import slogdet [as 别名]
def forward(self, x):
    """Map x -> z with a 1x1 convolution and return ``(z, logdet)``.

    When ``self.train_sampling`` is truthy, the kernel matrix is the inverse
    of ``self.weight``, inverted in float64 and cast to float32. Otherwise it
    is ``self.weight`` itself. The log-determinant is whatever
    ``self.log_determinant(x, matrix)`` returns.
    """
    # NOTE(review): the original carried the remark "torch.slogdet() is not
    # stable" here; presumably that motivates the float64 inversion — confirm.
    if not self.train_sampling:
        matrix = self.weight
    else:
        matrix = torch.inverse(self.weight.double()).float()
    logdet = self.log_determinant(x, matrix)
    # Reshape the matrix into a (w_shape..., 1, 1) convolution kernel.
    kernel = matrix.view(*self.w_shape, 1, 1)
    return F.conv2d(x, kernel), logdet
示例6: backward
# 需要导入模块: import torch [as 别名]
# 或者: from torch import slogdet [as 别名]
def backward(self, y: torch.Tensor, x: torch.Tensor = None, x_freqs: torch.Tensor = None, require_log_probs=True, var=None, y_freqs=None):
    """Map ``y`` through the linear transform ``x_prime = y @ self.W``.

    The original comment describes this direction as "from other language to
    this language".

    Args:
        y: 2-D tensor (``Tensor.mm`` requires a matrix).
        x, x_freqs: reference points and their frequencies. Both are passed to
            ``self.cal_mixture_of_gaussian_fix_var`` and are required when
            ``require_log_probs`` is true.
        require_log_probs: when false, skip scoring and return ``torch.tensor(0)``.
        var: forwarded unchanged to ``self.cal_mixture_of_gaussian_fix_var``.
        y_freqs: forwarded as ``x_prime_freqs``.

    Returns:
        ``(x_prime, log_probs)``, where ``log_probs`` is the helper's score
        plus log|det(self.W)|.

    Raises:
        AssertionError: if ``require_log_probs`` is true and ``x`` or
            ``x_freqs`` is None.
    """
    # from other language to this language
    x_prime = y.mm(self.W)
    if not require_log_probs:
        return x_prime, torch.tensor(0)
    # Both arguments must be tested in the condition itself; in
    # `assert a, b` the second expression is only the failure message.
    assert x is not None and x_freqs is not None, "x and x_freqs are required when require_log_probs=True"
    log_probs = self.cal_mixture_of_gaussian_fix_var(x_prime, x, x_freqs, var, x_prime_freqs=y_freqs)
    # Change of variables: add log|det W| of the linear map.
    _, log_abs_det = torch.slogdet(self.W)
    return x_prime, log_probs + log_abs_det
示例7: forward
# 需要导入模块: import torch [as 别名]
# 或者: from torch import slogdet [as 别名]
def forward(self, x, sldj, reverse=False):
    """Invertible 1x1 convolution over the channels of ``x``.

    Returns ``(z, sldj)``. The running log-det sum ``sldj`` grows by
    ``log|det(self.weight)| * x.size(2) * x.size(3)`` in the forward
    direction and shrinks by the same amount when ``reverse`` is set.
    """
    # Per-position log|det W| times the number of spatial positions.
    ldj = torch.slogdet(self.weight)[1] * x.size(2) * x.size(3)
    if not reverse:
        matrix, sldj = self.weight, sldj + ldj
    else:
        # Inverse computed in float64, then cast to float32.
        matrix, sldj = torch.inverse(self.weight.double()).float(), sldj - ldj
    c = self.num_channels
    z = F.conv2d(x, matrix.view(c, c, 1, 1))
    return z, sldj
示例8: get_weight
# 需要导入模块: import torch [as 别名]
# 或者: from torch import slogdet [as 别名]
def get_weight(self, input, reverse):
    """Build the 1x1 convolution kernel and its log-determinant term.

    Returns ``(weight, dlogdet)`` with ``weight`` shaped
    ``(w_shape[0], w_shape[1], 1, 1)``; ``reverse`` selects the inverse matrix.
    With ``self.LU`` falsy, the dense ``self.weight`` is used. Otherwise the
    matrix is assembled from the stored factors ``p``, ``l``, ``u``, ``sign_s``
    and ``log_s``.
    """
    shape = self.w_shape
    if not self.LU:
        n_pixels = thops.pixels(input)
        dlogdet = torch.slogdet(self.weight)[1] * n_pixels
        if reverse:
            # Inverse computed in float64, then cast to float32.
            matrix = torch.inverse(self.weight.double()).float()
        else:
            matrix = self.weight
        return matrix.view(shape[0], shape[1], 1, 1), dlogdet

    # Move the factor buffers to the input's device; the moved copies are
    # stored back on self.
    device = input.device
    self.p = self.p.to(device)
    self.sign_s = self.sign_s.to(device)
    self.l_mask = self.l_mask.to(device)
    self.eye = self.eye.to(device)
    # NOTE(review): l_mask presumably keeps the strictly-lower triangle, so
    # `lower` is unit-lower and `upper` has diagonal sign_s * exp(log_s) — confirm.
    lower = self.l * self.l_mask + self.eye
    diagonal = torch.diag(self.sign_s * torch.exp(self.log_s))
    upper = self.u * self.l_mask.transpose(0, 1).contiguous() + diagonal
    dlogdet = thops.sum(self.log_s) * thops.pixels(input)
    if reverse:
        lower_inv = torch.inverse(lower.double()).float()
        upper_inv = torch.inverse(upper.double()).float()
        matrix = torch.matmul(upper_inv, torch.matmul(lower_inv, self.p.inverse()))
    else:
        matrix = torch.matmul(self.p, torch.matmul(lower, upper))
    return matrix.view(shape[0], shape[1], 1, 1), dlogdet
示例9: get_parameters
# 需要导入模块: import torch [as 别名]
# 或者: from torch import slogdet [as 别名]
def get_parameters(self, x, inverse):
    """Return the 1x1 convolution kernel and the log-det Jacobian for ``x``.

    The log-det is scaled by the number of spatial positions,
    ``np.prod(x.size()[2:])``. With ``self.decomposed`` falsy, the dense
    ``self.weight`` is used. Otherwise the matrix is rebuilt from the factors
    ``p``, ``l``, ``u``, ``sign_s`` and ``log_s``. ``inverse`` selects the
    inverse matrix. The kernel has shape ``(w_shape[0], w_shape[1], 1, 1)``.
    """
    shape = self.w_shape
    n_pixels = np.prod(x.size()[2:])
    dev = x.device
    if not self.decomposed:
        # slogdet runs on a CPU copy of the weight; the result is moved back to x's device.
        logdet_jacobian = torch.slogdet(self.weight.cpu())[1].to(dev) * n_pixels
        if inverse:
            # Inverse computed in float64, then cast to float32.
            kernel = torch.inverse(self.weight.double()).float().view(shape[0], shape[1], 1, 1)
        else:
            kernel = self.weight.view(shape[0], shape[1], 1, 1)
        return kernel, logdet_jacobian

    # Move the factor buffers to x's device; the moved copies are stored on self.
    self.p = self.p.to(dev)
    self.sign_s = self.sign_s.to(dev)
    self.l_mask = self.l_mask.to(dev)
    self.eye = self.eye.to(dev)
    lower = self.l * self.l_mask + self.eye
    diagonal = torch.diag(self.sign_s * torch.exp(self.log_s))
    upper = self.u * self.l_mask.transpose(0, 1).contiguous() + diagonal
    logdet_jacobian = torch.sum(self.log_s) * n_pixels
    if inverse:
        lower_inv = torch.inverse(lower.double()).float()
        upper_inv = torch.inverse(upper.double()).float()
        matrix = torch.matmul(upper_inv, torch.matmul(lower_inv, self.p.inverse()))
    else:
        matrix = torch.matmul(self.p, torch.matmul(lower, upper))
    return matrix.view(shape[0], shape[1], 1, 1), logdet_jacobian
示例10: forward
# 需要导入模块: import torch [as 别名]
# 或者: from torch import slogdet [as 别名]
def forward(self, input):
    """Apply the 1x1 convolution ``self.weight`` and return ``(out, logdet)``.

    ``logdet`` is ``height * width * log|det W|``. ``W`` is ``self.weight``
    with its two trailing kernel dimensions removed (assumed 1x1 — confirm
    against the layer's constructor). The determinant is evaluated in
    float64 and the result cast to float32.
    """
    _, _, height, width = input.shape
    out = F.conv2d(input, self.weight)
    # Squeeze only the trailing kernel dims. A bare .squeeze() would also drop
    # the channel dims when there is a single channel, leaving a 0-dim tensor
    # that torch.slogdet rejects.
    matrix = self.weight.squeeze(-1).squeeze(-1)
    logdet = height * width * torch.slogdet(matrix.double())[1].float()
    return out, logdet