当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python pyspark GaussianMixtureModel用法及代码示例


本文简要介绍 pyspark.mllib.clustering.GaussianMixtureModel 的用法。

用法:

class pyspark.mllib.clustering.GaussianMixtureModel(java_model)

源自高斯混合模型方法的聚类模型。

版本 1.3.0 中的新函数。

例子

>>> from pyspark.mllib.linalg import Vectors, DenseMatrix
>>> from numpy.testing import assert_equal
>>> from shutil import rmtree
>>> import os, tempfile
>>> clusterdata_1 =  sc.parallelize(array([-0.1,-0.05,-0.01,-0.1,
...                                         0.9,0.8,0.75,0.935,
...                                        -0.83,-0.68,-0.91,-0.76 ]).reshape(6, 2), 2)
>>> model = GaussianMixture.train(clusterdata_1, 3, convergenceTol=0.0001,
...                                 maxIterations=50, seed=10)
>>> labels = model.predict(clusterdata_1).collect()
>>> labels[0]==labels[1]
False
>>> labels[1]==labels[2]
False
>>> labels[4]==labels[5]
True
>>> model.predict([-0.1,-0.05])
0
>>> softPredicted = model.predictSoft([-0.1,-0.05])
>>> abs(softPredicted[0] - 1.0) < 0.03
True
>>> abs(softPredicted[1] - 0.0) < 0.03
True
>>> abs(softPredicted[2] - 0.0) < 0.03
True
>>> path = tempfile.mkdtemp()
>>> model.save(sc, path)
>>> sameModel = GaussianMixtureModel.load(sc, path)
>>> assert_equal(model.weights, sameModel.weights)
>>> mus, sigmas = list(
...     zip(*[(g.mu, g.sigma) for g in model.gaussians]))
>>> sameMus, sameSigmas = list(
...     zip(*[(g.mu, g.sigma) for g in sameModel.gaussians]))
>>> mus == sameMus
True
>>> sigmas == sameSigmas
True
>>> from shutil import rmtree
>>> try:
...     rmtree(path)
... except OSError:
...     pass
>>> data =  array([-5.1971, -2.5359, -3.8220,
...                -5.2211, -5.0602,  4.7118,
...                 6.8989, 3.4592,  4.6322,
...                 5.7048,  4.6567, 5.5026,
...                 4.5605,  5.2043,  6.2734])
>>> clusterdata_2 = sc.parallelize(data.reshape(5,3))
>>> model = GaussianMixture.train(clusterdata_2, 2, convergenceTol=0.0001,
...                               maxIterations=150, seed=4)
>>> labels = model.predict(clusterdata_2).collect()
>>> labels[0]==labels[1]
True
>>> labels[2]==labels[3]==labels[4]
True

相关用法


注:本文由纯净天空筛选整理自spark.apache.org大神的英文原创作品 pyspark.mllib.clustering.GaussianMixtureModel。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。