當前位置: 首頁>>代碼示例>>Python>>正文


Python torch.slogdet方法代碼示例

本文整理匯總了Python中torch.slogdet方法的典型用法代碼示例。如果您正苦於以下問題:Python torch.slogdet方法的具體用法?Python torch.slogdet怎麽用?Python torch.slogdet使用的例子?那麽, 這裏精選的方法代碼示例或許可以為您提供幫助。您也可以進一步了解該方法所在torch的用法示例。


在下文中一共展示了torch.slogdet方法的10個代碼示例,這些例子默認根據受歡迎程度排序。您可以為喜歡或者感覺有用的代碼點讚,您的評價將有助於係統推薦出更棒的Python代碼示例。

示例1: forward

# 需要導入模塊: import torch [as 別名]
# 或者: from torch import slogdet [as 別名]
def forward(self, input: torch.Tensor, mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """

        Args:
            input: Tensor
                input tensor [batch, N1, N2, ..., Nl, in_features]
            mask: Tensor
                mask tensor [batch, N1, N2, ...,Nl]

        Returns: out: Tensor , logdet: Tensor
            out: [batch, N1, N2, ..., in_features], the output of the flow
            logdet: [batch], the log determinant of :math:`\partial output / \partial input`

        """
        dim = input.dim()
        # [batch, N1, N2, ..., in_features]
        out = F.linear(input, self.weight)
        _, logdet = torch.slogdet(self.weight)
        if dim > 2:
            num = mask.view(out.size(0), -1).sum(dim=1)
            logdet = logdet * num
        return out, logdet 
開發者ID:XuezheMax,項目名稱:flowseq,代碼行數:24,代碼來源:linear.py

示例2: backward

# 需要導入模塊: import torch [as 別名]
# 或者: from torch import slogdet [as 別名]
def backward(self, input: torch.Tensor, mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """

        Args:
            input: Tensor
                input tensor [batch, N1, N2, ..., Nl, in_features]
            mask: Tensor
                mask tensor [batch, N1, N2, ...,Nl]

        Returns: out: Tensor , logdet: Tensor
            out: [batch, N1, N2, ..., in_features], the output of the flow
            logdet: [batch], the log determinant of :math:`\partial output / \partial input`

        """
        dim = input.dim()
        # [batch, N1, N2, ..., in_features]
        out = F.linear(input, self.weight_inv)
        _, logdet = torch.slogdet(self.weight_inv)
        if dim > 2:
            num = mask.view(out.size(0), -1).sum(dim=1)
            logdet = logdet * num
        return out, logdet 
開發者ID:XuezheMax,項目名稱:flowseq,代碼行數:24,代碼來源:linear.py

示例3: logabsdet

# 需要導入模塊: import torch [as 別名]
# 或者: from torch import slogdet [as 別名]
def logabsdet(x):
    """Returns the log absolute determinant of square matrix x."""
    # Note: torch.logdet() only works for positive determinant.
    _, res = torch.slogdet(x)
    return res 
開發者ID:bayesiains,項目名稱:nsf,代碼行數:7,代碼來源:torchutils.py

示例4: forward

# 需要導入模塊: import torch [as 別名]
# 或者: from torch import slogdet [as 別名]
def forward(self, inputs, cond_inputs=None, mode='direct'):
        if mode == 'direct':
            return inputs @ self.W, torch.slogdet(
                self.W)[-1].unsqueeze(0).unsqueeze(0).repeat(
                    inputs.size(0), 1)
        else:
            return inputs @ torch.inverse(self.W), -torch.slogdet(
                self.W)[-1].unsqueeze(0).unsqueeze(0).repeat(
                    inputs.size(0), 1) 
開發者ID:ikostrikov,項目名稱:pytorch-flows,代碼行數:11,代碼來源:flows.py

示例5: forward

# 需要導入模塊: import torch [as 別名]
# 或者: from torch import slogdet [as 別名]
def forward(self, x):
        # x --> z
        # torch.slogdet() is not stable
        if self.train_sampling:
            W = torch.inverse(self.weight.double()).float()  
        else:
            W = self.weight
        logdet = self.log_determinant(x, W)        
        kernel = W.view(*self.w_shape, 1, 1)
        return F.conv2d(x, kernel), logdet 
開發者ID:cics-nd,項目名稱:pde-surrogate,代碼行數:12,代碼來源:glow_msc.py

示例6: backward

# 需要導入模塊: import torch [as 別名]
# 或者: from torch import slogdet [as 別名]
def backward(self, y: torch.tensor, x: torch.tensor=None, x_freqs: torch.tensor=None, require_log_probs=True, var=None, y_freqs=None):
        # from other language to this language
        x_prime = y.mm(self.W)
        if require_log_probs:
            assert x is not None, x_freqs is not None
            log_probs = self.cal_mixture_of_gaussian_fix_var(x_prime, x, x_freqs, var, x_prime_freqs=y_freqs)
            _, log_abs_det = torch.slogdet(self.W)
            log_probs = log_probs + log_abs_det
        else:
            log_probs = torch.tensor(0)
        return x_prime, log_probs 
開發者ID:violet-zct,項目名稱:DeMa-BWE,代碼行數:13,代碼來源:mog_flow.py

示例7: forward

# 需要導入模塊: import torch [as 別名]
# 或者: from torch import slogdet [as 別名]
def forward(self, x, sldj, reverse=False):
        ldj = torch.slogdet(self.weight)[1] * x.size(2) * x.size(3)

        if reverse:
            weight = torch.inverse(self.weight.double()).float()
            sldj = sldj - ldj
        else:
            weight = self.weight
            sldj = sldj + ldj

        weight = weight.view(self.num_channels, self.num_channels, 1, 1)
        z = F.conv2d(x, weight)

        return z, sldj 
開發者ID:chrischute,項目名稱:glow,代碼行數:16,代碼來源:inv_conv.py

示例8: get_weight

# 需要導入模塊: import torch [as 別名]
# 或者: from torch import slogdet [as 別名]
def get_weight(self, input, reverse):
        w_shape = self.w_shape
        if not self.LU:
            pixels = thops.pixels(input)
            dlogdet = torch.slogdet(self.weight)[1] * pixels
            if not reverse:
                weight = self.weight.view(w_shape[0], w_shape[1], 1, 1)
            else:
                weight = torch.inverse(self.weight.double()).float()\
                              .view(w_shape[0], w_shape[1], 1, 1)
            return weight, dlogdet
        else:
            self.p = self.p.to(input.device)
            self.sign_s = self.sign_s.to(input.device)
            self.l_mask = self.l_mask.to(input.device)
            self.eye = self.eye.to(input.device)
            l = self.l * self.l_mask + self.eye
            u = self.u * self.l_mask.transpose(0, 1).contiguous() + torch.diag(self.sign_s * torch.exp(self.log_s))
            dlogdet = thops.sum(self.log_s) * thops.pixels(input)
            if not reverse:
                w = torch.matmul(self.p, torch.matmul(l, u))
            else:
                l = torch.inverse(l.double()).float()
                u = torch.inverse(u.double()).float()
                w = torch.matmul(u, torch.matmul(l, self.p.inverse()))
            return w.view(w_shape[0], w_shape[1], 1, 1), dlogdet 
開發者ID:chaiyujin,項目名稱:glow-pytorch,代碼行數:28,代碼來源:modules.py

示例9: get_parameters

# 需要導入模塊: import torch [as 別名]
# 或者: from torch import slogdet [as 別名]
def get_parameters(self, x, inverse):
        w_shape = self.w_shape
        pixels = np.prod(x.size()[2:])
        device = x.device

        if not self.decomposed:
            logdet_jacobian = torch.slogdet(self.weight.cpu())[1].to(device) * pixels
            if not inverse:
                weight = self.weight.view(w_shape[0], w_shape[1], 1, 1)
            else:
                weight = torch.inverse(self.weight.double()).float().view(w_shape[0], w_shape[1], 1, 1)
            return weight, logdet_jacobian
        else:
            self.p = self.p.to(device)
            self.sign_s = self.sign_s.to(device)
            self.l_mask = self.l_mask.to(device)
            self.eye = self.eye.to(device)
            l = self.l * self.l_mask + self.eye
            u = self.u * self.l_mask.transpose(0, 1).contiguous() + torch.diag(self.sign_s * torch.exp(self.log_s))
            logdet_jacobian = torch.sum(self.log_s) * pixels
            if not inverse:
                w = torch.matmul(self.p, torch.matmul(l, u))
            else:
                l = torch.inverse(l.double()).float()
                u = torch.inverse(u.double()).float()
                w = torch.matmul(u, torch.matmul(l, self.p.inverse()))
            return w.view(w_shape[0], w_shape[1], 1, 1), logdet_jacobian 
開發者ID:masa-su,項目名稱:pixyz,代碼行數:29,代碼來源:conv.py

示例10: forward

# 需要導入模塊: import torch [as 別名]
# 或者: from torch import slogdet [as 別名]
def forward(self, input):
        _, _, height, width = input.shape

        out = F.conv2d(input, self.weight)
        logdet = (
            height * width * torch.slogdet(self.weight.squeeze().double())[1].float()
        )

        return out, logdet 
開發者ID:rosinality,項目名稱:glow-pytorch,代碼行數:11,代碼來源:model.py


注:本文中的torch.slogdet方法示例由純淨天空整理自Github/MSDocs等開源代碼及文檔管理平台,相關代碼片段篩選自各路編程大神貢獻的開源項目,源碼版權歸原作者所有,傳播和使用請參考對應項目的License;未經允許,請勿轉載。