Python pyspark GBTClassifier usage and code examples

This article briefly introduces the usage of pyspark.ml.classification.GBTClassifier.

Usage:

class pyspark.ml.classification.GBTClassifier(*, featuresCol='features', labelCol='label', predictionCol='prediction', maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, lossType='logistic', maxIter=20, stepSize=0.1, seed=None, subsamplingRate=1.0, impurity='variance', featureSubsetStrategy='all', validationTol=0.01, validationIndicatorCol=None, leafCol='', minWeightFractionPerNode=0.0, weightCol=None)

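Gradient-Boosted Trees (GBTs) learning algorithm for classification. It supports binary labels, as well as both continuous and categorical features.

New in version 1.4.0.

Notes

Multiclass labels are not currently supported.

The implementation is based upon: J.H. Friedman. "Stochastic Gradient Boosting." 1999.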

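Gradient Boosting vs. TreeBoost:

  • This implementation is for Stochastic Gradient Boosting, not for TreeBoost.

  • Both algorithms learn tree ensembles by minimizing loss functions.

  • TreeBoost (Friedman, 1999) additionally modifies the outputs at tree leaf nodes based on the loss function, whereas the original gradient boosting method does not.

  • We expect to implement TreeBoost in the future: SPARK-4240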
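The distinction in the list above can be sketched in plain Python for a single training instance. This is an illustrative sketch, not Spark's implementation. The helper names (`log_loss`, `pseudo_residual`, `boost_single_leaf`) are hypothetical. It assumes the log loss 2*log(1 + exp(-2*y*F)) with labels mapped to -1/+1, a first tree fit directly to the labels with weight 1.0, and later trees weighted by `stepSize`. Each tree is reduced to one leaf holding one instance, so its output is simply the pseudo-residual it was fit to; TreeBoost would instead re-optimize that leaf value against the loss.

```python
import math


def log_loss(y, f):
    """Assumed loss: 2*log(1 + exp(-2*y*f)) for label y in {-1, +1} and margin f."""
    return 2.0 * math.log1p(math.exp(-2.0 * y * f))


def pseudo_residual(y, f):
    """Negative gradient of log_loss with respect to the margin f."""
    return 4.0 * y / (1.0 + math.exp(2.0 * y * f))


def boost_single_leaf(y, max_iter=5, step_size=0.1):
    """Boost on one instance that sits alone in its leaf (hypothetical helper).

    Plain gradient boosting keeps the fitted pseudo-residual as the leaf value.
    TreeBoost would replace it with a value re-optimized against the loss.
    """
    # Assumption: first tree has weight 1.0, every later tree has weight step_size.
    tree_weights = [1.0] + [step_size] * (max_iter - 1)
    margin = 0.0
    losses = []
    for m, weight in enumerate(tree_weights):
        # The first tree is fit to the label itself, later trees to the residual.
        leaf_value = y if m == 0 else pseudo_residual(y, margin)
        margin += weight * leaf_value
        losses.append(log_loss(y, margin))
    return margin, tree_weights, losses


# A label of 0 is mapped to y = -1 under the assumption stated above.
margin, tree_weights, losses = boost_single_leaf(y=-1.0)
print(round(margin, 4))
print([round(loss, 3) for loss in losses])
```

Under these assumptions, the final margin for an instance labeled 0 comes out near -1.1697 with tree weights [1.0, 0.1, 0.1, 0.1, 0.1], which lines up with the `predictRaw`, `treeWeights`, and `evaluateEachIteration` values in the example below.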

Examples

>>> from numpy import allclose
>>> from pyspark.ml.classification import GBTClassifier, GBTClassificationModel
>>> from pyspark.ml.linalg import Vectors
>>> from pyspark.ml.feature import StringIndexer
>>> from pyspark.sql import SparkSession
>>> spark = SparkSession.builder.getOrCreate()  # assumed: reuse or start a session
>>> df = spark.createDataFrame([
...     (1.0, Vectors.dense(1.0)),
...     (0.0, Vectors.sparse(1, [], []))], ["label", "features"])
>>> stringIndexer = StringIndexer(inputCol="label", outputCol="indexed")
>>> si_model = stringIndexer.fit(df)
>>> td = si_model.transform(df)
>>> gbt = GBTClassifier(maxIter=5, maxDepth=2, labelCol="indexed", seed=42,
...     leafCol="leafId")
>>> gbt.setMaxIter(5)
GBTClassifier...
>>> gbt.setMinWeightFractionPerNode(0.049)
GBTClassifier...
>>> gbt.getMaxIter()
5
>>> gbt.getFeatureSubsetStrategy()
'all'
>>> model = gbt.fit(td)
>>> model.getLabelCol()
'indexed'
>>> model.setFeaturesCol("features")
GBTClassificationModel...
>>> model.setThresholds([0.3, 0.7])
GBTClassificationModel...
>>> model.getThresholds()
[0.3, 0.7]
>>> model.featureImportances
SparseVector(1, {0: 1.0})
>>> allclose(model.treeWeights, [1.0, 0.1, 0.1, 0.1, 0.1])
True
>>> test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
>>> model.predict(test0.head().features)
0.0
>>> model.predictRaw(test0.head().features)
DenseVector([1.1697, -1.1697])
>>> model.predictProbability(test0.head().features)
DenseVector([0.9121, 0.0879])
>>> result = model.transform(test0).head()
>>> result.prediction
0.0
>>> result.leafId
DenseVector([0.0, 0.0, 0.0, 0.0, 0.0])
>>> test1 = spark.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"])
>>> model.transform(test1).head().prediction
1.0
>>> model.totalNumNodes
15
>>> print(model.toDebugString)
GBTClassificationModel...numTrees=5...
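>>> import tempfile
>>> temp_path = tempfile.mkdtemp() + "/"  # assumed: any writable scratch directory
>>> gbtc_path = temp_path + "gbtc"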
>>> gbt.save(gbtc_path)
>>> gbt2 = GBTClassifier.load(gbtc_path)
>>> gbt2.getMaxDepth()
2
>>> model_path = temp_path + "gbtc_model"
>>> model.save(model_path)
>>> model2 = GBTClassificationModel.load(model_path)
>>> model.featureImportances == model2.featureImportances
True
>>> model.treeWeights == model2.treeWeights
True
>>> model.transform(test0).take(1) == model2.transform(test0).take(1)
True
>>> model.trees
[DecisionTreeRegressionModel...depth=..., DecisionTreeRegressionModel...]
>>> validation = spark.createDataFrame([(0.0, Vectors.dense(-1.0),)],
...              ["indexed", "features"])
>>> model.evaluateEachIteration(validation)
[0.25..., 0.23..., 0.21..., 0.19..., 0.18...]
>>> model.numClasses
2
>>> gbt = gbt.setValidationIndicatorCol("validationIndicator")
>>> gbt.getValidationIndicatorCol()
'validationIndicator'
>>> gbt.getValidationTol()
0.01
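The `predictRaw`, `predictProbability`, and `setThresholds` calls in the example above are related by simple arithmetic, sketched here in plain Python. The function names are hypothetical and this is not Spark's code. It assumes the raw vector is [-margin, margin], that probabilities come from a logistic function applied to twice the raw score, that the predicted class is the one with the largest p / t ratio, and that all thresholds are strictly positive.

```python
import math


def raw_to_probability(raw):
    """Map raw scores [-margin, margin] to [P(class 0), P(class 1)].

    Assumption: P(class 1) = 1 / (1 + exp(-2 * margin)).
    """
    p1 = 1.0 / (1.0 + math.exp(-2.0 * raw[1]))
    return [1.0 - p1, p1]


def predict_with_thresholds(probability, thresholds):
    """Return the class index (as a float) with the largest p / t ratio.

    Assumption: every threshold is strictly positive.
    """
    scaled = [p / t for p, t in zip(probability, thresholds)]
    return float(scaled.index(max(scaled)))


# Raw scores taken from the predictRaw output in the example above.
probability = raw_to_probability([1.1697, -1.1697])
print([round(p, 4) for p in probability])
print(predict_with_thresholds(probability, [0.3, 0.7]))
```

With thresholds [0.3, 0.7], class 0 scores about 0.9121 / 0.3 and class 1 about 0.0879 / 0.7, so class 0 is still predicted for this instance. Raising the threshold for a class makes that class harder to predict.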


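Note: This article was selected and compiled by 純淨天空 from the original English documentation pyspark.ml.classification.GBTClassifier on spark.apache.org. Unless otherwise stated, the copyright of the original code belongs to the original authors. Do not reproduce or copy this translation without permission or authorization.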