当前位置: 首页>>代码示例>>Python>>正文


Python torch.pinverse方法代码示例

本文整理汇总了Python中torch.pinverse方法的典型用法代码示例。如果您正苦于以下问题:Python torch.pinverse方法的具体用法?Python torch.pinverse怎么用?Python torch.pinverse使用的例子?那么恭喜您, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在torch的用法示例。


在下文中一共展示了torch.pinverse方法的4个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: get_inverse_filters

# 需要导入模块: import torch [as 别名]
# 或者: from torch import pinverse [as 别名]
def get_inverse_filters(self):
        fourier_basis = self._get_fft_basis()
        inverse_filters = torch.pinverse(
            fourier_basis.unsqueeze(0)).squeeze(0)
        return nn.Parameter(inverse_filters, requires_grad=self.requires_grad) 
开发者ID:nussl,项目名称:nussl,代码行数:7,代码来源:filter_bank.py

示例2: _evaluate_sample

# 需要导入模块: import torch [as 别名]
# 或者: from torch import pinverse [as 别名]
def _evaluate_sample(self, sample: LogSample) -> Optional[EstimatorSampleResult]:
        log_slot_expects = sample.log_slot_item_expectations(sample.context.slots)
        if log_slot_expects is None:
            logger.warning("Log slot distribution not available")
            return None
        tgt_slot_expects = sample.tgt_slot_expectations(sample.context.slots)
        if tgt_slot_expects is None:
            logger.warning("Target slot distribution not available")
            return None
        log_indicator = log_slot_expects.values_tensor(self._device)
        tgt_indicator = tgt_slot_expects.values_tensor(self._device)
        lm = len(sample.context.slots) * len(sample.items)
        gamma = torch.as_tensor(
            np.linalg.pinv(
                torch.mm(
                    log_indicator.view((lm, 1)), log_indicator.view((1, lm))
                ).numpy()
            )
        )
        # torch.pinverse is not very stable
        # gamma = torch.pinverse(
        #     torch.mm(log_indicator.view((lm, 1)), log_indicator.view((1, lm)))
        # )
        ones = sample.log_slate.one_hots(sample.items, self._device)
        weight = self._weight_clamper(
            torch.mm(tgt_indicator.view((1, lm)), torch.mm(gamma, ones.view((lm, 1))))
        ).item()
        return EstimatorSampleResult(
            sample.log_reward,
            sample.log_reward * weight,
            sample.ground_truth_reward,
            weight,
        )

    # pyre-fixme[14]: `evaluate` overrides method defined in `Estimator` inconsistently. 
开发者ID:facebookresearch,项目名称:ReAgent,代码行数:37,代码来源:slate_estimators.py

示例3: fit

# 需要导入模块: import torch [as 别名]
# 或者: from torch import pinverse [as 别名]
def fit(self):
        if self.readout_training in {'gd', 'svd'}:
            return

        if self.readout_training == 'cholesky':
            W = torch.solve(self.XTy,
                           self.XTX + self.lambda_reg * torch.eye(
                               self.XTX.size(0), device=self.XTX.device))[0].t()
            self.XTX = None
            self.XTy = None

            self.readout.bias = nn.Parameter(W[:, 0])
            self.readout.weight = nn.Parameter(W[:, 1:])
        elif self.readout_training == 'inv':
            I = (self.lambda_reg * torch.eye(self.XTX.size(0))).to(
                self.XTX.device)
            A = self.XTX + I

            if torch.det(A) != 0:
                W = torch.mm(torch.inverse(A), self.XTy).t()
            else:
                pinv = torch.pinverse(A)
                W = torch.mm(pinv, self.XTy).t()

            self.readout.bias = nn.Parameter(W[:, 0])
            self.readout.weight = nn.Parameter(W[:, 1:])

            self.XTX = None
            self.XTy = None 
开发者ID:stefanonardo,项目名称:pytorch-esn,代码行数:31,代码来源:echo_state_network.py

示例4: compute_filter_pinv

# 需要导入模块: import torch [as 别名]
# 或者: from torch import pinverse [as 别名]
def compute_filter_pinv(self, filters):
        """ Computes pseudo inverse filterbank of given filters."""
        scale = self.filterbank.stride / self.filterbank.kernel_size
        shape = filters.shape
        ifilt = torch.pinverse(filters.squeeze()).transpose(-1, -2).view(shape)
        # Compensate for the overlap-add.
        return ifilt * scale 
开发者ID:mpariente,项目名称:asteroid,代码行数:9,代码来源:enc_dec.py


注:本文中的torch.pinverse方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。