本文简要介绍
pyspark.mllib.clustering.GaussianMixtureModel
的用法。用法:
class pyspark.mllib.clustering.GaussianMixtureModel(java_model)
源自高斯混合模型方法的聚类模型。
版本 1.3.0 中的新函数。
例子:
>>> from pyspark.mllib.linalg import Vectors, DenseMatrix >>> from numpy.testing import assert_equal >>> from shutil import rmtree >>> import os, tempfile
>>> clusterdata_1 = sc.parallelize(array([-0.1,-0.05,-0.01,-0.1, ... 0.9,0.8,0.75,0.935, ... -0.83,-0.68,-0.91,-0.76 ]).reshape(6, 2), 2) >>> model = GaussianMixture.train(clusterdata_1, 3, convergenceTol=0.0001, ... maxIterations=50, seed=10) >>> labels = model.predict(clusterdata_1).collect() >>> labels[0]==labels[1] False >>> labels[1]==labels[2] False >>> labels[4]==labels[5] True >>> model.predict([-0.1,-0.05]) 0 >>> softPredicted = model.predictSoft([-0.1,-0.05]) >>> abs(softPredicted[0] - 1.0) < 0.03 True >>> abs(softPredicted[1] - 0.0) < 0.03 True >>> abs(softPredicted[2] - 0.0) < 0.03 True
>>> path = tempfile.mkdtemp() >>> model.save(sc, path) >>> sameModel = GaussianMixtureModel.load(sc, path) >>> assert_equal(model.weights, sameModel.weights) >>> mus, sigmas = list( ... zip(*[(g.mu, g.sigma) for g in model.gaussians])) >>> sameMus, sameSigmas = list( ... zip(*[(g.mu, g.sigma) for g in sameModel.gaussians])) >>> mus == sameMus True >>> sigmas == sameSigmas True >>> from shutil import rmtree >>> try: ... rmtree(path) ... except OSError: ... pass
>>> data = array([-5.1971, -2.5359, -3.8220, ... -5.2211, -5.0602, 4.7118, ... 6.8989, 3.4592, 4.6322, ... 5.7048, 4.6567, 5.5026, ... 4.5605, 5.2043, 6.2734]) >>> clusterdata_2 = sc.parallelize(data.reshape(5,3)) >>> model = GaussianMixture.train(clusterdata_2, 2, convergenceTol=0.0001, ... maxIterations=150, seed=4) >>> labels = model.predict(clusterdata_2).collect() >>> labels[0]==labels[1] True >>> labels[2]==labels[3]==labels[4] True
相关用法
- Python pyspark GaussianMixture用法及代码示例
- Python pyspark GroupBy.mean用法及代码示例
- Python pyspark GroupBy.head用法及代码示例
- Python pyspark GroupedData.applyInPandas用法及代码示例
- Python pyspark GroupBy.cumsum用法及代码示例
- Python pyspark GroupBy.rank用法及代码示例
- Python pyspark GroupBy.bfill用法及代码示例
- Python pyspark GradientBoostedTrees.trainRegressor用法及代码示例
- Python pyspark GroupBy.cummin用法及代码示例
- Python pyspark GroupBy.cummax用法及代码示例
- Python pyspark GroupedData.mean用法及代码示例
- Python pyspark GroupBy.fillna用法及代码示例
- Python pyspark GroupBy.apply用法及代码示例
- Python pyspark GroupedData.agg用法及代码示例
- Python pyspark GroupedData.pivot用法及代码示例
- Python pyspark GroupBy.diff用法及代码示例
- Python pyspark GroupBy.filter用法及代码示例
- Python pyspark GroupBy.transform用法及代码示例
- Python pyspark GroupedData.apply用法及代码示例
- Python pyspark GroupBy.cumcount用法及代码示例
- Python pyspark GroupedData.max用法及代码示例
- Python pyspark GroupedData.count用法及代码示例
- Python pyspark GroupedData.min用法及代码示例
- Python pyspark GroupBy.idxmax用法及代码示例
- Python pyspark GroupBy.shift用法及代码示例
注:本文由纯净天空筛选整理自spark.apache.org大神的英文原创作品 pyspark.mllib.clustering.GaussianMixtureModel。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。