本文整理汇总了Python中torch.pinverse方法的典型用法代码示例。如果您正苦于以下问题：Python torch.pinverse方法的具体用法？Python torch.pinverse怎么用？Python torch.pinverse使用的例子？那么恭喜您，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块torch的用法示例。
在下文中一共展示了torch.pinverse方法的4个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: get_inverse_filters
# 需要导入模块: import torch [as 别名]
# 或者: from torch import pinverse [as 别名]
def get_inverse_filters(self):
    """Return the inverse filters as a trainable-or-frozen ``nn.Parameter``.

    The inverse filters are the Moore-Penrose pseudo-inverse of the basis
    returned by ``self._get_fft_basis()``. Whether gradients flow through the
    returned parameter is controlled by ``self.requires_grad``.
    """
    basis = self._get_fft_basis()
    # Run pinverse on a batch of one (leading axis added, then dropped again)
    # so the returned tensor has the same rank as the basis itself.
    batched_inverse = torch.pinverse(basis[None, ...])
    return nn.Parameter(batched_inverse[0], requires_grad=self.requires_grad)
示例2: _evaluate_sample
# 需要导入模块: import torch [as 别名]
# 或者: from torch import pinverse [as 别名]
def _evaluate_sample(self, sample: LogSample) -> Optional[EstimatorSampleResult]:
    """Compute a pseudo-inverse based importance weight for one logged sample.

    The logging and target slot/item expectations are flattened into vectors
    of length ``lm = len(slots) * len(items)``. ``gamma`` is the pseudo-inverse
    of the outer product of the logging vector with itself. The weight is
    ``tgt^T @ gamma @ ones``, where ``ones`` is the one-hot encoding of the
    logged slate. It is passed through ``self._weight_clamper`` before use.

    Args:
        sample: the logged sample to evaluate.

    Returns:
        ``EstimatorSampleResult(log_reward, log_reward * weight,
        ground_truth_reward, weight)``. Returns ``None`` (after logging a
        warning) when either the logging or the target slot distribution is
        unavailable.
    """
    log_slot_expects = sample.log_slot_item_expectations(sample.context.slots)
    if log_slot_expects is None:
        logger.warning("Log slot distribution not available")
        return None
    tgt_slot_expects = sample.tgt_slot_expectations(sample.context.slots)
    if tgt_slot_expects is None:
        logger.warning("Target slot distribution not available")
        return None
    log_indicator = log_slot_expects.values_tensor(self._device)
    tgt_indicator = tgt_slot_expects.values_tensor(self._device)
    lm = len(sample.context.slots) * len(sample.items)

    # gamma = pinv(v v^T) for the logging vector v. For this rank-1 matrix the
    # pseudo-inverse has the exact closed form v v^T / ||v||^4 (all zeros when
    # v == 0). Using it avoids two problems with an SVD-based pinv
    # (np.linalg.pinv / torch.pinverse):
    #   1. Its tiny default cutoff can invert round-off singular values.
    #   2. The numpy route needs `.numpy()`, which fails for non-CPU tensors
    #      and would leave gamma on the CPU while `ones` is on self._device.
    log_col = log_indicator.view((lm, 1))
    outer = torch.mm(log_col, log_col.t())
    sq_norm = torch.sum(log_col * log_col)
    if sq_norm > 0:
        # Divide twice instead of by sq_norm**2 to avoid underflow for small norms.
        gamma = outer / sq_norm / sq_norm
    else:
        gamma = torch.zeros_like(outer)

    ones = sample.log_slate.one_hots(sample.items, self._device)
    weight = self._weight_clamper(
        torch.mm(tgt_indicator.view((1, lm)), torch.mm(gamma, ones.view((lm, 1))))
    ).item()
    return EstimatorSampleResult(
        sample.log_reward,
        sample.log_reward * weight,
        sample.ground_truth_reward,
        weight,
    )
# pyre-fixme[14]: `evaluate` overrides method defined in `Estimator` inconsistently.
示例3: fit
# 需要导入模块: import torch [as 别名]
# 或者: from torch import pinverse [as 别名]
def fit(self):
if self.readout_training in {'gd', 'svd'}:
return
if self.readout_training == 'cholesky':
W = torch.solve(self.XTy,
self.XTX + self.lambda_reg * torch.eye(
self.XTX.size(0), device=self.XTX.device))[0].t()
self.XTX = None
self.XTy = None
self.readout.bias = nn.Parameter(W[:, 0])
self.readout.weight = nn.Parameter(W[:, 1:])
elif self.readout_training == 'inv':
I = (self.lambda_reg * torch.eye(self.XTX.size(0))).to(
self.XTX.device)
A = self.XTX + I
if torch.det(A) != 0:
W = torch.mm(torch.inverse(A), self.XTy).t()
else:
pinv = torch.pinverse(A)
W = torch.mm(pinv, self.XTy).t()
self.readout.bias = nn.Parameter(W[:, 0])
self.readout.weight = nn.Parameter(W[:, 1:])
self.XTX = None
self.XTy = None
示例4: compute_filter_pinv
# 需要导入模块: import torch [as 别名]
# 或者: from torch import pinverse [as 别名]
def compute_filter_pinv(self, filters):
    """Compute the pseudo-inverse filterbank of ``filters``.

    Steps:
    1. Squeeze all singleton dimensions out of ``filters``.
    2. Take the Moore-Penrose pseudo-inverse and swap its last two axes.
    3. View the result back to the original shape of ``filters``.
    4. Multiply by ``self.filterbank.stride / self.filterbank.kernel_size``.

    The final multiplication compensates for overlap-add.
    """
    ola_scale = self.filterbank.stride / self.filterbank.kernel_size
    original_shape = filters.shape
    inverse = torch.pinverse(filters.squeeze())
    # Swap the last two axes so the inverse has the squeezed filters' layout,
    # then restore the dropped singleton dimensions.
    inverse_t = torch.transpose(inverse, -2, -1)
    return inverse_t.view(original_shape) * ola_scale