当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python PyTorch MixtureSameFamily用法及代码示例


本文简要介绍python语言中 torch.distributions.mixture_same_family.MixtureSameFamily 的用法。

用法:

class torch.distributions.mixture_same_family.MixtureSameFamily(mixture_distribution, component_distribution, validate_args=None)

参数

  • mixture_distribution-torch.distributions.Categorical 类似实例。管理选择组件的概率。类别数必须与 component_distribution 的最右侧批次维度匹配。必须有标量 batch_shapebatch_shape 匹配 component_distribution.batch_shape[:-1]

  • component_distribution-torch.distributions.Distribution 类似实例。最右边的批量维度索引组件。

基础:torch.distributions.distribution.Distribution

MixtureSameFamily 分布实现了(一批)混合分布,其中所有分量都来自相同分布类型的不同参数化。它由 Categorical “selecting distribution”(在 k 组件之上)和组件分布参数化,即具有最右侧批次形状(等于 [k] )的 Distribution ,它索引每个(批次)组件。

例子:

# Construct Gaussian Mixture Model in 1D consisting of 5 equally
# weighted normal distributions
>>> mix = D.Categorical(torch.ones(5,))
>>> comp = D.Normal(torch.randn(5,), torch.rand(5,))
>>> gmm = MixtureSameFamily(mix, comp)

# Construct Gaussian Mixture Modle in 2D consisting of 5 equally
# weighted bivariate normal distributions
>>> mix = D.Categorical(torch.ones(5,))
>>> comp = D.Independent(D.Normal(
             torch.randn(5,2), torch.rand(5,2)), 1)
>>> gmm = MixtureSameFamily(mix, comp)

# Construct a batch of 3 Gaussian Mixture Models in 2D each
# consisting of 5 random weighted bivariate normal distributions
>>> mix = D.Categorical(torch.rand(3,5))
>>> comp = D.Independent(D.Normal(
            torch.randn(3,5,2), torch.rand(3,5,2)), 1)
>>> gmm = MixtureSameFamily(mix, comp)

相关用法


注:本文由纯净天空筛选整理自pytorch.org大神的英文原创作品 torch.distributions.mixture_same_family.MixtureSameFamily。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。