當前位置: 首頁>>編程示例 >>用法及示例精選 >>正文


Python pyspark CrossValidator用法及代碼示例

本文簡要介紹 pyspark.ml.tuning.CrossValidator 的用法。

用法:

class pyspark.ml.tuning.CrossValidator(*, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3, seed=None, parallelism=1, collectSubModels=False, foldCol='')

K-fold 交叉驗證通過將數據集拆分為一組不重疊的隨機分區折疊來執行模型選擇,這些折疊用作單獨的訓練和測試數據集,例如,當 k=3 折疊時,K-fold 交叉驗證將生成 3(訓練、 test) 數據集對,每個數據集使用 2/3 的數據進行訓練,使用 1/3 的數據進行測試。每個折疊僅用作測試集一次。

1.4.0 版中的新函數。

例子

>>> from pyspark.ml.classification import LogisticRegression
>>> from pyspark.ml.evaluation import BinaryClassificationEvaluator
>>> from pyspark.ml.linalg import Vectors
>>> from pyspark.ml.tuning import CrossValidator, ParamGridBuilder, CrossValidatorModel
>>> import tempfile
>>> dataset = spark.createDataFrame(
...     [(Vectors.dense([0.0]), 0.0),
...      (Vectors.dense([0.4]), 1.0),
...      (Vectors.dense([0.5]), 0.0),
...      (Vectors.dense([0.6]), 1.0),
...      (Vectors.dense([1.0]), 1.0)] * 10,
...     ["features", "label"])
>>> lr = LogisticRegression()
>>> grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build()
>>> evaluator = BinaryClassificationEvaluator()
>>> cv = CrossValidator(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator,
...     parallelism=2)
>>> cvModel = cv.fit(dataset)
>>> cvModel.getNumFolds()
3
>>> cvModel.avgMetrics[0]
0.5
>>> path = tempfile.mkdtemp()
>>> model_path = path + "/model"
>>> cvModel.write().save(model_path)
>>> cvModelRead = CrossValidatorModel.read().load(model_path)
>>> cvModelRead.avgMetrics
[0.5, ...
>>> evaluator.evaluate(cvModel.transform(dataset))
0.8333...
>>> evaluator.evaluate(cvModelRead.transform(dataset))
0.8333...

相關用法


注:本文由純淨天空篩選整理自spark.apache.org大神的英文原創作品 pyspark.ml.tuning.CrossValidator。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。