当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python pyspark CrossValidator用法及代码示例


本文简要介绍 pyspark.ml.tuning.CrossValidator 的用法。

用法:

class pyspark.ml.tuning.CrossValidator(*, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3, seed=None, parallelism=1, collectSubModels=False, foldCol='')

K-fold 交叉验证通过将数据集拆分为一组不重叠的随机分区折叠来执行模型选择,这些折叠用作单独的训练和测试数据集,例如,当 k=3 折叠时,K-fold 交叉验证将生成 3(训练、 test) 数据集对,每个数据集使用 2/3 的数据进行训练,使用 1/3 的数据进行测试。每个折叠仅用作测试集一次。

1.4.0 版中的新函数。

例子

>>> from pyspark.ml.classification import LogisticRegression
>>> from pyspark.ml.evaluation import BinaryClassificationEvaluator
>>> from pyspark.ml.linalg import Vectors
>>> from pyspark.ml.tuning import CrossValidator, ParamGridBuilder, CrossValidatorModel
>>> import tempfile
>>> dataset = spark.createDataFrame(
...     [(Vectors.dense([0.0]), 0.0),
...      (Vectors.dense([0.4]), 1.0),
...      (Vectors.dense([0.5]), 0.0),
...      (Vectors.dense([0.6]), 1.0),
...      (Vectors.dense([1.0]), 1.0)] * 10,
...     ["features", "label"])
>>> lr = LogisticRegression()
>>> grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build()
>>> evaluator = BinaryClassificationEvaluator()
>>> cv = CrossValidator(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator,
...     parallelism=2)
>>> cvModel = cv.fit(dataset)
>>> cvModel.getNumFolds()
3
>>> cvModel.avgMetrics[0]
0.5
>>> path = tempfile.mkdtemp()
>>> model_path = path + "/model"
>>> cvModel.write().save(model_path)
>>> cvModelRead = CrossValidatorModel.read().load(model_path)
>>> cvModelRead.avgMetrics
[0.5, ...
>>> evaluator.evaluate(cvModel.transform(dataset))
0.8333...
>>> evaluator.evaluate(cvModelRead.transform(dataset))
0.8333...

相关用法


注:本文由纯净天空筛选整理自spark.apache.org大神的英文原创作品 pyspark.ml.tuning.CrossValidator。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。