当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python PyTorch LowRankMultivariateNormal用法及代码示例


本文简要介绍python语言中 torch.distributions.lowrank_multivariate_normal.LowRankMultivariateNormal 的用法。

用法:

class torch.distributions.lowrank_multivariate_normal.LowRankMultivariateNormal(loc, cov_factor, cov_diag, validate_args=None)

参数

  • loc(Tensor) -形状为 batch_shape + event_shape 的分布的平均值

  • cov_factor(Tensor) -形状为 batch_shape + event_shape + (rank,) 的协方差矩阵的低阶形式的因子部分

  • cov_diag(Tensor) -形状为 batch_shape + event_shape 的协方差矩阵低秩形式的对角部分

基础:torch.distributions.distribution.Distribution

创建一个多元正态分布,其协方差矩阵具有由 cov_factorcov_diag 参数化的低秩形式:

covariance_matrix = cov_factor @ cov_factor.T + cov_diag

示例

>>> m = LowRankMultivariateNormal(torch.zeros(2), torch.tensor([[1.], [0.]]), torch.ones(2))
>>> m.sample()  # normally distributed with mean=`[0,0]`, cov_factor=`[[1],[0]]`, cov_diag=`[1,1]`
tensor([-0.2102, -0.5429])

注意

由于 Woodbury matrix identitymatrix determinant lemma ,在 cov_factor.shape[1] << cov_factor.shape[0] 时避免了协方差矩阵的行列式和逆矩阵的计算。由于这些公式,我们只需要计算小尺寸“capacitance” 矩阵的行列式和逆矩阵:

capacitance = I + cov_factor.T @ inv(cov_diag) @ cov_factor

相关用法


注:本文由纯净天空筛选整理自pytorch.org大神的英文原创作品 torch.distributions.lowrank_multivariate_normal.LowRankMultivariateNormal。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。