當前位置: 首頁>>編程示例 >>用法及示例精選 >>正文


Python pyspark TrainValidationSplit用法及代碼示例

本文簡要介紹 pyspark.ml.tuning.TrainValidationSplit 的用法。

用法:

class pyspark.ml.tuning.TrainValidationSplit(*, estimator=None, estimatorParamMaps=None, evaluator=None, trainRatio=0.75, parallelism=1, collectSubModels=False, seed=None)

驗證hyper-parameter 調整。將輸入數據集隨機拆分為訓練集和驗證集,並使用驗證集上的評估指標來選擇最佳模型。類似於 CrossValidator ,但隻拆分集合一次。

2.0.0 版中的新函數。

例子

>>> from pyspark.ml.classification import LogisticRegression
>>> from pyspark.ml.evaluation import BinaryClassificationEvaluator
>>> from pyspark.ml.linalg import Vectors
>>> from pyspark.ml.tuning import TrainValidationSplit, ParamGridBuilder
>>> from pyspark.ml.tuning import TrainValidationSplitModel
>>> import tempfile
>>> dataset = spark.createDataFrame(
...     [(Vectors.dense([0.0]), 0.0),
...      (Vectors.dense([0.4]), 1.0),
...      (Vectors.dense([0.5]), 0.0),
...      (Vectors.dense([0.6]), 1.0),
...      (Vectors.dense([1.0]), 1.0)] * 10,
...     ["features", "label"]).repartition(1)
>>> lr = LogisticRegression()
>>> grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build()
>>> evaluator = BinaryClassificationEvaluator()
>>> tvs = TrainValidationSplit(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator,
...     parallelism=1, seed=42)
>>> tvsModel = tvs.fit(dataset)
>>> tvsModel.getTrainRatio()
0.75
>>> tvsModel.validationMetrics
[0.5, ...
>>> path = tempfile.mkdtemp()
>>> model_path = path + "/model"
>>> tvsModel.write().save(model_path)
>>> tvsModelRead = TrainValidationSplitModel.read().load(model_path)
>>> tvsModelRead.validationMetrics
[0.5, ...
>>> evaluator.evaluate(tvsModel.transform(dataset))
0.833...
>>> evaluator.evaluate(tvsModelRead.transform(dataset))
0.833...

相關用法


注:本文由純淨天空篩選整理自spark.apache.org大神的英文原創作品 pyspark.ml.tuning.TrainValidationSplit。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。