

Python pyspark TrainValidationSplit Usage and Code Examples


This article briefly introduces the usage of pyspark.ml.tuning.TrainValidationSplit.

Usage:

class pyspark.ml.tuning.TrainValidationSplit(*, estimator=None, estimatorParamMaps=None, evaluator=None, trainRatio=0.75, parallelism=1, collectSubModels=False, seed=None)

Validation for hyper-parameter tuning. Randomly splits the input dataset into a training set and a validation set, and uses the evaluation metric on the validation set to select the best model. Similar to CrossValidator, but only splits the dataset once.

New in version 2.0.0.
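
As a quick illustration of the constructor parameters above, the following is a minimal sketch (added here, not part of the original page) showing how trainRatio can be passed and then read back through its getter; the LogisticRegression estimator, the regParam grid, and the evaluator are placeholder choices for this sketch:

>>> from pyspark.ml.classification import LogisticRegression
>>> from pyspark.ml.evaluation import BinaryClassificationEvaluator
>>> from pyspark.ml.tuning import ParamGridBuilder, TrainValidationSplit
>>> lr = LogisticRegression()
>>> grid = ParamGridBuilder().addGrid(lr.regParam, [0.01, 0.1]).build()
>>> tvs = TrainValidationSplit(estimator=lr, estimatorParamMaps=grid,
...     evaluator=BinaryClassificationEvaluator(), trainRatio=0.8)
>>> tvs.getTrainRatio()  # 80% of the rows go to training, the rest to validation
0.8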

Examples

>>> from pyspark.ml.classification import LogisticRegression
>>> from pyspark.ml.evaluation import BinaryClassificationEvaluator
>>> from pyspark.ml.linalg import Vectors
>>> from pyspark.ml.tuning import TrainValidationSplit, ParamGridBuilder
>>> from pyspark.ml.tuning import TrainValidationSplitModel
>>> import tempfile
>>> dataset = spark.createDataFrame(
...     [(Vectors.dense([0.0]), 0.0),
...      (Vectors.dense([0.4]), 1.0),
...      (Vectors.dense([0.5]), 0.0),
...      (Vectors.dense([0.6]), 1.0),
...      (Vectors.dense([1.0]), 1.0)] * 10,
...     ["features", "label"]).repartition(1)
>>> lr = LogisticRegression()
>>> grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build()
>>> evaluator = BinaryClassificationEvaluator()
>>> tvs = TrainValidationSplit(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator,
...     parallelism=1, seed=42)
>>> tvsModel = tvs.fit(dataset)
>>> tvsModel.getTrainRatio()
0.75
>>> tvsModel.validationMetrics
[0.5, ...
>>> path = tempfile.mkdtemp()
>>> model_path = path + "/model"
>>> tvsModel.write().save(model_path)
>>> tvsModelRead = TrainValidationSplitModel.read().load(model_path)
>>> tvsModelRead.validationMetrics
[0.5, ...
>>> evaluator.evaluate(tvsModel.transform(dataset))
0.833...
>>> evaluator.evaluate(tvsModelRead.transform(dataset))
0.833...

Note: This article was selected and compiled by 纯净天空 from the original English work pyspark.ml.tuning.TrainValidationSplit on spark.apache.org. Unless otherwise stated, the copyright of the original code belongs to the original authors; please do not reproduce or copy this translation without permission or authorization.