当前位置: 首页>>代码示例>>Python>>正文


Python torch.slogdet方法代码示例

本文整理汇总了Python中torch.slogdet方法的典型用法代码示例。如果您正苦于以下问题:Python torch.slogdet方法的具体用法?Python torch.slogdet怎么用?Python torch.slogdet使用的例子?那么恭喜您, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在torch的用法示例。


在下文中一共展示了torch.slogdet方法的10个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: forward

# 需要导入模块: import torch [as 别名]
# 或者: from torch import slogdet [as 别名]
def forward(self, input: torch.Tensor, mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """

        Args:
            input: Tensor
                input tensor [batch, N1, N2, ..., Nl, in_features]
            mask: Tensor
                mask tensor [batch, N1, N2, ...,Nl]

        Returns: out: Tensor , logdet: Tensor
            out: [batch, N1, N2, ..., in_features], the output of the flow
            logdet: [batch], the log determinant of :math:`\partial output / \partial input`

        """
        dim = input.dim()
        # [batch, N1, N2, ..., in_features]
        out = F.linear(input, self.weight)
        _, logdet = torch.slogdet(self.weight)
        if dim > 2:
            num = mask.view(out.size(0), -1).sum(dim=1)
            logdet = logdet * num
        return out, logdet 
开发者ID:XuezheMax,项目名称:flowseq,代码行数:24,代码来源:linear.py

示例2: backward

# 需要导入模块: import torch [as 别名]
# 或者: from torch import slogdet [as 别名]
def backward(self, input: torch.Tensor, mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """

        Args:
            input: Tensor
                input tensor [batch, N1, N2, ..., Nl, in_features]
            mask: Tensor
                mask tensor [batch, N1, N2, ...,Nl]

        Returns: out: Tensor , logdet: Tensor
            out: [batch, N1, N2, ..., in_features], the output of the flow
            logdet: [batch], the log determinant of :math:`\partial output / \partial input`

        """
        dim = input.dim()
        # [batch, N1, N2, ..., in_features]
        out = F.linear(input, self.weight_inv)
        _, logdet = torch.slogdet(self.weight_inv)
        if dim > 2:
            num = mask.view(out.size(0), -1).sum(dim=1)
            logdet = logdet * num
        return out, logdet 
开发者ID:XuezheMax,项目名称:flowseq,代码行数:24,代码来源:linear.py

示例3: logabsdet

# 需要导入模块: import torch [as 别名]
# 或者: from torch import slogdet [as 别名]
def logabsdet(x):
    """Returns the log absolute determinant of square matrix x."""
    # Note: torch.logdet() only works for positive determinant.
    _, res = torch.slogdet(x)
    return res 
开发者ID:bayesiains,项目名称:nsf,代码行数:7,代码来源:torchutils.py

示例4: forward

# 需要导入模块: import torch [as 别名]
# 或者: from torch import slogdet [as 别名]
def forward(self, inputs, cond_inputs=None, mode='direct'):
        if mode == 'direct':
            return inputs @ self.W, torch.slogdet(
                self.W)[-1].unsqueeze(0).unsqueeze(0).repeat(
                    inputs.size(0), 1)
        else:
            return inputs @ torch.inverse(self.W), -torch.slogdet(
                self.W)[-1].unsqueeze(0).unsqueeze(0).repeat(
                    inputs.size(0), 1) 
开发者ID:ikostrikov,项目名称:pytorch-flows,代码行数:11,代码来源:flows.py

示例5: forward

# 需要导入模块: import torch [as 别名]
# 或者: from torch import slogdet [as 别名]
def forward(self, x):
        # x --> z
        # torch.slogdet() is not stable
        if self.train_sampling:
            W = torch.inverse(self.weight.double()).float()  
        else:
            W = self.weight
        logdet = self.log_determinant(x, W)        
        kernel = W.view(*self.w_shape, 1, 1)
        return F.conv2d(x, kernel), logdet 
开发者ID:cics-nd,项目名称:pde-surrogate,代码行数:12,代码来源:glow_msc.py

示例6: backward

# 需要导入模块: import torch [as 别名]
# 或者: from torch import slogdet [as 别名]
def backward(self, y: torch.tensor, x: torch.tensor=None, x_freqs: torch.tensor=None, require_log_probs=True, var=None, y_freqs=None):
        # from other language to this language
        x_prime = y.mm(self.W)
        if require_log_probs:
            assert x is not None, x_freqs is not None
            log_probs = self.cal_mixture_of_gaussian_fix_var(x_prime, x, x_freqs, var, x_prime_freqs=y_freqs)
            _, log_abs_det = torch.slogdet(self.W)
            log_probs = log_probs + log_abs_det
        else:
            log_probs = torch.tensor(0)
        return x_prime, log_probs 
开发者ID:violet-zct,项目名称:DeMa-BWE,代码行数:13,代码来源:mog_flow.py

示例7: forward

# 需要导入模块: import torch [as 别名]
# 或者: from torch import slogdet [as 别名]
def forward(self, x, sldj, reverse=False):
        ldj = torch.slogdet(self.weight)[1] * x.size(2) * x.size(3)

        if reverse:
            weight = torch.inverse(self.weight.double()).float()
            sldj = sldj - ldj
        else:
            weight = self.weight
            sldj = sldj + ldj

        weight = weight.view(self.num_channels, self.num_channels, 1, 1)
        z = F.conv2d(x, weight)

        return z, sldj 
开发者ID:chrischute,项目名称:glow,代码行数:16,代码来源:inv_conv.py

示例8: get_weight

# 需要导入模块: import torch [as 别名]
# 或者: from torch import slogdet [as 别名]
def get_weight(self, input, reverse):
        w_shape = self.w_shape
        if not self.LU:
            pixels = thops.pixels(input)
            dlogdet = torch.slogdet(self.weight)[1] * pixels
            if not reverse:
                weight = self.weight.view(w_shape[0], w_shape[1], 1, 1)
            else:
                weight = torch.inverse(self.weight.double()).float()\
                              .view(w_shape[0], w_shape[1], 1, 1)
            return weight, dlogdet
        else:
            self.p = self.p.to(input.device)
            self.sign_s = self.sign_s.to(input.device)
            self.l_mask = self.l_mask.to(input.device)
            self.eye = self.eye.to(input.device)
            l = self.l * self.l_mask + self.eye
            u = self.u * self.l_mask.transpose(0, 1).contiguous() + torch.diag(self.sign_s * torch.exp(self.log_s))
            dlogdet = thops.sum(self.log_s) * thops.pixels(input)
            if not reverse:
                w = torch.matmul(self.p, torch.matmul(l, u))
            else:
                l = torch.inverse(l.double()).float()
                u = torch.inverse(u.double()).float()
                w = torch.matmul(u, torch.matmul(l, self.p.inverse()))
            return w.view(w_shape[0], w_shape[1], 1, 1), dlogdet 
开发者ID:chaiyujin,项目名称:glow-pytorch,代码行数:28,代码来源:modules.py

示例9: get_parameters

# 需要导入模块: import torch [as 别名]
# 或者: from torch import slogdet [as 别名]
def get_parameters(self, x, inverse):
        w_shape = self.w_shape
        pixels = np.prod(x.size()[2:])
        device = x.device

        if not self.decomposed:
            logdet_jacobian = torch.slogdet(self.weight.cpu())[1].to(device) * pixels
            if not inverse:
                weight = self.weight.view(w_shape[0], w_shape[1], 1, 1)
            else:
                weight = torch.inverse(self.weight.double()).float().view(w_shape[0], w_shape[1], 1, 1)
            return weight, logdet_jacobian
        else:
            self.p = self.p.to(device)
            self.sign_s = self.sign_s.to(device)
            self.l_mask = self.l_mask.to(device)
            self.eye = self.eye.to(device)
            l = self.l * self.l_mask + self.eye
            u = self.u * self.l_mask.transpose(0, 1).contiguous() + torch.diag(self.sign_s * torch.exp(self.log_s))
            logdet_jacobian = torch.sum(self.log_s) * pixels
            if not inverse:
                w = torch.matmul(self.p, torch.matmul(l, u))
            else:
                l = torch.inverse(l.double()).float()
                u = torch.inverse(u.double()).float()
                w = torch.matmul(u, torch.matmul(l, self.p.inverse()))
            return w.view(w_shape[0], w_shape[1], 1, 1), logdet_jacobian 
开发者ID:masa-su,项目名称:pixyz,代码行数:29,代码来源:conv.py

示例10: forward

# 需要导入模块: import torch [as 别名]
# 或者: from torch import slogdet [as 别名]
def forward(self, input):
        _, _, height, width = input.shape

        out = F.conv2d(input, self.weight)
        logdet = (
            height * width * torch.slogdet(self.weight.squeeze().double())[1].float()
        )

        return out, logdet 
开发者ID:rosinality,项目名称:glow-pytorch,代码行数:11,代码来源:model.py


注:本文中的torch.slogdet方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。