當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python PyTorch LowRankMultivariateNormal用法及代碼示例


本文簡要介紹python語言中 torch.distributions.lowrank_multivariate_normal.LowRankMultivariateNormal 的用法。

用法:

class torch.distributions.lowrank_multivariate_normal.LowRankMultivariateNormal(loc, cov_factor, cov_diag, validate_args=None)

參數

  • loc(Tensor) -形狀為 batch_shape + event_shape 的分布的平均值

  • cov_factor(Tensor) -形狀為 batch_shape + event_shape + (rank,) 的協方差矩陣的低階形式的因子部分

  • cov_diag(Tensor) -形狀為 batch_shape + event_shape 的協方差矩陣低秩形式的對角部分

基礎:torch.distributions.distribution.Distribution

創建一個多元正態分布,其協方差矩陣具有由 cov_factorcov_diag 參數化的低秩形式:

covariance_matrix = cov_factor @ cov_factor.T + cov_diag

示例

>>> m = LowRankMultivariateNormal(torch.zeros(2), torch.tensor([[1.], [0.]]), torch.ones(2))
>>> m.sample()  # normally distributed with mean=`[0,0]`, cov_factor=`[[1],[0]]`, cov_diag=`[1,1]`
tensor([-0.2102, -0.5429])

注意

由於 Woodbury matrix identitymatrix determinant lemma ,在 cov_factor.shape[1] << cov_factor.shape[0] 時避免了協方差矩陣的行列式和逆矩陣的計算。由於這些公式,我們隻需要計算小尺寸“capacitance” 矩陣的行列式和逆矩陣:

capacitance = I + cov_factor.T @ inv(cov_diag) @ cov_factor

相關用法


注:本文由純淨天空篩選整理自pytorch.org大神的英文原創作品 torch.distributions.lowrank_multivariate_normal.LowRankMultivariateNormal。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。