当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python PyTorch MultivariateNormal用法及代码示例


本文简要介绍python语言中 torch.distributions.multivariate_normal.MultivariateNormal 的用法。

用法:

class torch.distributions.multivariate_normal.MultivariateNormal(loc, covariance_matrix=None, precision_matrix=None, scale_tril=None, validate_args=None)

参数

  • loc(Tensor) -分布的平均值

  • covariance_matrix(Tensor) -正定协方差矩阵

  • precision_matrix(Tensor) -正定精度矩阵

  • scale_tril(Tensor) -协方差的下三角因子,positive-valued 对角线

基础:torch.distributions.distribution.Distribution

创建由均值向量和协方差矩阵参数化的多元正态(也称为高斯)分布。

多元正态分布可以根据正定协方差矩阵 或正定精度矩阵 或具有positive-valued对角线条目的下三角矩阵 进行参数化,例如 。这个三角矩阵可以通过例如获得协方差的 Cholesky 分解。

示例

>>> m = MultivariateNormal(torch.zeros(2), torch.eye(2))
>>> m.sample()  # normally distributed with mean=`[0,0]` and covariance_matrix=`I`
tensor([-0.2102, -0.5429])

注意

只能指定covariance_matrixprecision_matrixscale_tril 之一。

使用 scale_tril 会更高效:内部所有计算都基于 scale_tril 。如果改为传递 covariance_matrixprecision_matrix,则它仅用于使用 Cholesky 分解计算相应的下三角矩阵。

相关用法


注:本文由纯净天空筛选整理自pytorch.org大神的英文原创作品 torch.distributions.multivariate_normal.MultivariateNormal。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。