當前位置: 首頁>>代碼示例>>Python>>正文


Python torch.pinverse方法代碼示例

本文整理匯總了Python中torch.pinverse方法的典型用法代碼示例。如果您正苦於以下問題:Python torch.pinverse方法的具體用法?Python torch.pinverse怎麽用?Python torch.pinverse使用的例子?那麽, 這裏精選的方法代碼示例或許可以為您提供幫助。您也可以進一步了解該方法所在torch的用法示例。


在下文中一共展示了torch.pinverse方法的4個代碼示例,這些例子默認根據受歡迎程度排序。您可以為喜歡或者感覺有用的代碼點讚,您的評價將有助於係統推薦出更棒的Python代碼示例。

示例1: get_inverse_filters

# 需要導入模塊: import torch [as 別名]
# 或者: from torch import pinverse [as 別名]
def get_inverse_filters(self):
        fourier_basis = self._get_fft_basis()
        inverse_filters = torch.pinverse(
            fourier_basis.unsqueeze(0)).squeeze(0)
        return nn.Parameter(inverse_filters, requires_grad=self.requires_grad) 
開發者ID:nussl,項目名稱:nussl,代碼行數:7,代碼來源:filter_bank.py

示例2: _evaluate_sample

# 需要導入模塊: import torch [as 別名]
# 或者: from torch import pinverse [as 別名]
def _evaluate_sample(self, sample: LogSample) -> Optional[EstimatorSampleResult]:
        log_slot_expects = sample.log_slot_item_expectations(sample.context.slots)
        if log_slot_expects is None:
            logger.warning("Log slot distribution not available")
            return None
        tgt_slot_expects = sample.tgt_slot_expectations(sample.context.slots)
        if tgt_slot_expects is None:
            logger.warning("Target slot distribution not available")
            return None
        log_indicator = log_slot_expects.values_tensor(self._device)
        tgt_indicator = tgt_slot_expects.values_tensor(self._device)
        lm = len(sample.context.slots) * len(sample.items)
        gamma = torch.as_tensor(
            np.linalg.pinv(
                torch.mm(
                    log_indicator.view((lm, 1)), log_indicator.view((1, lm))
                ).numpy()
            )
        )
        # torch.pinverse is not very stable
        # gamma = torch.pinverse(
        #     torch.mm(log_indicator.view((lm, 1)), log_indicator.view((1, lm)))
        # )
        ones = sample.log_slate.one_hots(sample.items, self._device)
        weight = self._weight_clamper(
            torch.mm(tgt_indicator.view((1, lm)), torch.mm(gamma, ones.view((lm, 1))))
        ).item()
        return EstimatorSampleResult(
            sample.log_reward,
            sample.log_reward * weight,
            sample.ground_truth_reward,
            weight,
        )

    # pyre-fixme[14]: `evaluate` overrides method defined in `Estimator` inconsistently. 
開發者ID:facebookresearch,項目名稱:ReAgent,代碼行數:37,代碼來源:slate_estimators.py

示例3: fit

# 需要導入模塊: import torch [as 別名]
# 或者: from torch import pinverse [as 別名]
def fit(self):
        if self.readout_training in {'gd', 'svd'}:
            return

        if self.readout_training == 'cholesky':
            W = torch.solve(self.XTy,
                           self.XTX + self.lambda_reg * torch.eye(
                               self.XTX.size(0), device=self.XTX.device))[0].t()
            self.XTX = None
            self.XTy = None

            self.readout.bias = nn.Parameter(W[:, 0])
            self.readout.weight = nn.Parameter(W[:, 1:])
        elif self.readout_training == 'inv':
            I = (self.lambda_reg * torch.eye(self.XTX.size(0))).to(
                self.XTX.device)
            A = self.XTX + I

            if torch.det(A) != 0:
                W = torch.mm(torch.inverse(A), self.XTy).t()
            else:
                pinv = torch.pinverse(A)
                W = torch.mm(pinv, self.XTy).t()

            self.readout.bias = nn.Parameter(W[:, 0])
            self.readout.weight = nn.Parameter(W[:, 1:])

            self.XTX = None
            self.XTy = None 
開發者ID:stefanonardo,項目名稱:pytorch-esn,代碼行數:31,代碼來源:echo_state_network.py

示例4: compute_filter_pinv

# 需要導入模塊: import torch [as 別名]
# 或者: from torch import pinverse [as 別名]
def compute_filter_pinv(self, filters):
        """ Computes pseudo inverse filterbank of given filters."""
        scale = self.filterbank.stride / self.filterbank.kernel_size
        shape = filters.shape
        ifilt = torch.pinverse(filters.squeeze()).transpose(-1, -2).view(shape)
        # Compensate for the overlap-add.
        return ifilt * scale 
開發者ID:mpariente,項目名稱:asteroid,代碼行數:9,代碼來源:enc_dec.py


注:本文中的torch.pinverse方法示例由純淨天空整理自Github/MSDocs等開源代碼及文檔管理平台,相關代碼片段篩選自各路編程大神貢獻的開源項目,源碼版權歸原作者所有,傳播和使用請參考對應項目的License;未經允許,請勿轉載。