当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python pyspark NaiveBayesModel用法及代码示例


本文简要介绍 pyspark.mllib.classification.NaiveBayesModel 的用法。

用法:

class pyspark.mllib.classification.NaiveBayesModel(labels, pi, theta)

朴素贝叶斯分类器模型。

0.9.0 版中的新函数。

参数

labelsnumpy.ndarray

标签列表。

pinumpy.ndarray

类别先验的日志,维度为 C,标签数。

thetanumpy.ndarray

类条件概率的日志,其维度为C-by-D,其中 D 是特征数。

例子

>>> from pyspark.mllib.linalg import SparseVector
>>> data = [
...     LabeledPoint(0.0, [0.0, 0.0]),
...     LabeledPoint(0.0, [0.0, 1.0]),
...     LabeledPoint(1.0, [1.0, 0.0]),
... ]
>>> model = NaiveBayes.train(sc.parallelize(data))
>>> model.predict(numpy.array([0.0, 1.0]))
0.0
>>> model.predict(numpy.array([1.0, 0.0]))
1.0
>>> model.predict(sc.parallelize([[1.0, 0.0]])).collect()
[1.0]
>>> sparse_data = [
...     LabeledPoint(0.0, SparseVector(2, {1: 0.0})),
...     LabeledPoint(0.0, SparseVector(2, {1: 1.0})),
...     LabeledPoint(1.0, SparseVector(2, {0: 1.0}))
... ]
>>> model = NaiveBayes.train(sc.parallelize(sparse_data))
>>> model.predict(SparseVector(2, {1: 1.0}))
0.0
>>> model.predict(SparseVector(2, {0: 1.0}))
1.0
>>> import os, tempfile
>>> path = tempfile.mkdtemp()
>>> model.save(sc, path)
>>> sameModel = NaiveBayesModel.load(sc, path)
>>> sameModel.predict(SparseVector(2, {0: 1.0})) == model.predict(SparseVector(2, {0: 1.0}))
True
>>> from shutil import rmtree
>>> try:
...     rmtree(path)
... except OSError:
...     pass

相关用法


注:本文由纯净天空筛选整理自spark.apache.org大神的英文原创作品 pyspark.mllib.classification.NaiveBayesModel。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。