當前位置: 首頁>>機器學習>>正文


pyspark LogisticRegressionModel用法示例

類用法簡介

class pyspark.mllib.classification.LogisticRegressionModel(weights, intercept, numFeatures, numClasses)

LogisticRegressionModel: 使用多元/二元邏輯回歸訓練的分類模型。

參數說明

  • weights – 每個特征的權重。
  • intercept – 為此模型計算的截距。 (僅用於二元邏輯回歸,在多項Logistic回歸中,截距不會是單一值,所以截距將是權重的一部分。)
  • numFeatures – 特征的維度。
  • numClasses – 多項Logistic回歸中k類分類問題的可能結果的數量。默認情況下,它是二元logistic回歸,所以numClasses將被設置為2。

示例一

###訓練樣本數據,注意LabelPoint中的Label是用浮點數表示的
>>> data = [
...     LabeledPoint(0.0, [0.0, 1.0]),
...     LabeledPoint(1.0, [1.0, 0.0]),
... ]

>>> lrm = LogisticRegressionWithSGD.train(sc.parallelize(data), iterations=10)
>>> lrm.predict([1.0, 0.0])
1
>>> lrm.predict([0.0, 1.0])
0
>>> lrm.predict(sc.parallelize([[1.0, 0.0], [0.0, 1.0]])).collect()
[1, 0]
### 清空閾值,輸出概率值
>>> lrm.clearThreshold()
>>> lrm.predict([0.0, 1.0])
0.279...

示例二


>>> sparse_data = [
...     LabeledPoint(0.0, SparseVector(2, {0: 0.0})),
...     LabeledPoint(1.0, SparseVector(2, {1: 1.0})),
...     LabeledPoint(0.0, SparseVector(2, {0: 1.0})),
...     LabeledPoint(1.0, SparseVector(2, {1: 2.0}))
... ]
### 模型訓練,使用SGD
>>> lrm = LogisticRegressionWithSGD.train(sc.parallelize(sparse_data), iterations=10)
>>> lrm.predict(array([0.0, 1.0]))
1
>>> lrm.predict(array([1.0, 0.0]))
0
>>> lrm.predict(SparseVector(2, {1: 1.0}))
1
>>> lrm.predict(SparseVector(2, {0: 1.0}))
0
### 保存模型文件, 加載模型文件
>>> import os, tempfile
>>> path = tempfile.mkdtemp()
>>> lrm.save(sc, path)
>>> sameModel = LogisticRegressionModel.load(sc, path)
>>> sameModel.predict(array([0.0, 1.0]))
1
>>> sameModel.predict(SparseVector(2, {0: 1.0}))
0
### 清除臨時文件(rmtree可以刪除文件夾及文件夾下的文件)
>>> from shutil import rmtree
>>> try:
...    rmtree(path)
... except:
...    pass
### 多分類模型訓練和預測
>>> multi_class_data = [
...     LabeledPoint(0.0, [0.0, 1.0, 0.0]),
...     LabeledPoint(1.0, [1.0, 0.0, 0.0]),
...     LabeledPoint(2.0, [0.0, 0.0, 1.0])
... ]
>>> data = sc.parallelize(multi_class_data)
### 模型訓練,使用LBFGS
>>> mcm = LogisticRegressionWithLBFGS.train(data, iterations=10, numClasses=3)
>>> mcm.predict([0.0, 0.5, 0.0])
0
>>> mcm.predict([0.8, 0.0, 0.0])
1
>>> mcm.predict([0.0, 0.0, 0.3])
2

pyspark中LogisticRegressionModel類的更多介紹詳見:LogisticRegressionModel

本文由《純淨天空》出品。文章地址: https://vimsky.com/zh-tw/article/3338.html,未經允許,請勿轉載。