This article briefly introduces the usage of pyspark.ml.classification.GBTClassifier.
Usage:
class pyspark.ml.classification.GBTClassifier(*, featuresCol='features', labelCol='label', predictionCol='prediction', maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, lossType='logistic', maxIter=20, stepSize=0.1, seed=None, subsamplingRate=1.0, impurity='variance', featureSubsetStrategy='all', validationTol=0.01, validationIndicatorCol=None, leafCol='', minWeightFractionPerNode=0.0, weightCol=None)
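Gradient-Boosted Trees (GBTs) learning algorithm for classification. It supports binary labels, as well as both continuous and categorical features.
New in version 1.4.0.
Notes:
Multiclass labels are not currently supported.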
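The implementation is based upon: J.H. Friedman. "Stochastic Gradient Boosting." 1999.
Gradient Boosting vs. TreeBoost:
This implementation is for Stochastic Gradient Boosting, not for TreeBoost.
Both algorithms learn tree ensembles by minimizing loss functions.
TreeBoost (Friedman, 1999) additionally modifies the outputs at tree leaf nodes based on the loss function, whereas the original gradient boosting method does not.
We expect to implement TreeBoost in the future: SPARK-4240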
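The difference between the two leaf rules can be sketched in plain Python. This is an illustration only, not Spark code: the helper names (`residual`, `hessian`, `gradient_boosting_leaf`, `treeboost_leaf`) are hypothetical, the loss is assumed to have the form 2·log(1 + exp(-2yF)) with labels mapped to {-1, +1}, and the TreeBoost-style update is approximated here by a single Newton step inside the leaf.

```python
import math


def residual(y, f):
    """Pseudo-residual (negative gradient) of the assumed loss
    2 * log(1 + exp(-2 * y * f)) at margin f, for y in {-1, +1}."""
    return 4.0 * y / (1.0 + math.exp(2.0 * y * f))


def hessian(y, f):
    """Second derivative of the same assumed loss with respect to f (uses y**2 == 1)."""
    e = math.exp(2.0 * y * f)
    return 8.0 * e / (1.0 + e) ** 2


def gradient_boosting_leaf(samples):
    """Plain gradient boosting: the leaf outputs the mean pseudo-residual,
    which is what a regression tree fit to the residuals predicts."""
    return sum(residual(y, f) for y, f in samples) / len(samples)


def treeboost_leaf(samples):
    """TreeBoost-style sketch: re-fit the leaf output to the loss itself,
    approximated by one Newton step over the samples in the leaf."""
    numerator = sum(residual(y, f) for y, f in samples)
    denominator = sum(hessian(y, f) for y, f in samples)
    return numerator / denominator


# Hypothetical leaf: (label in {-1, +1}, current ensemble margin) pairs.
leaf = [(-1.0, -1.0), (-1.0, -0.5), (1.0, -0.2)]
plain_value = gradient_boosting_leaf(leaf)
newton_value = treeboost_leaf(leaf)
print(plain_value, newton_value)
```

Both rules use the same tree structure; they differ only in the value stored at each leaf. The plain rule averages the residuals, while the Newton-style rule rescales that sum by the local curvature of the loss.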
Examples:
>>> # `spark` (an active SparkSession) and `temp_path` (a writable temporary
>>> # directory) are supplied by the doctest harness.
>>> from numpy import allclose
>>> from pyspark.ml.classification import GBTClassifier, GBTClassificationModel
>>> from pyspark.ml.linalg import Vectors
>>> from pyspark.ml.feature import StringIndexer
>>> df = spark.createDataFrame([
...     (1.0, Vectors.dense(1.0)),
...     (0.0, Vectors.sparse(1, [], []))], ["label", "features"])
>>> stringIndexer = StringIndexer(inputCol="label", outputCol="indexed")
>>> si_model = stringIndexer.fit(df)
>>> td = si_model.transform(df)
>>> gbt = GBTClassifier(maxIter=5, maxDepth=2, labelCol="indexed", seed=42,
...     leafCol="leafId")
>>> gbt.setMaxIter(5)
GBTClassifier...
>>> gbt.setMinWeightFractionPerNode(0.049)
GBTClassifier...
>>> gbt.getMaxIter()
5
>>> gbt.getFeatureSubsetStrategy()
'all'
>>> model = gbt.fit(td)
>>> model.getLabelCol()
'indexed'
>>> model.setFeaturesCol("features")
GBTClassificationModel...
>>> model.setThresholds([0.3, 0.7])
GBTClassificationModel...
>>> model.getThresholds()
[0.3, 0.7]
>>> model.featureImportances
SparseVector(1, {0: 1.0})
>>> allclose(model.treeWeights, [1.0, 0.1, 0.1, 0.1, 0.1])
True
>>> test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
>>> model.predict(test0.head().features)
0.0
>>> model.predictRaw(test0.head().features)
DenseVector([1.1697, -1.1697])
>>> model.predictProbability(test0.head().features)
DenseVector([0.9121, 0.0879])
>>> result = model.transform(test0).head()
>>> result.prediction
0.0
>>> result.leafId
DenseVector([0.0, 0.0, 0.0, 0.0, 0.0])
>>> test1 = spark.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"])
>>> model.transform(test1).head().prediction
1.0
>>> model.totalNumNodes
15
>>> print(model.toDebugString)
GBTClassificationModel...numTrees=5...
>>> gbtc_path = temp_path + "gbtc"
>>> gbt.save(gbtc_path)
>>> gbt2 = GBTClassifier.load(gbtc_path)
>>> gbt2.getMaxDepth()
2
>>> model_path = temp_path + "gbtc_model"
>>> model.save(model_path)
>>> model2 = GBTClassificationModel.load(model_path)
>>> model.featureImportances == model2.featureImportances
True
>>> model.treeWeights == model2.treeWeights
True
>>> model.transform(test0).take(1) == model2.transform(test0).take(1)
True
>>> model.trees
[DecisionTreeRegressionModel...depth=..., DecisionTreeRegressionModel...]
>>> validation = spark.createDataFrame([(0.0, Vectors.dense(-1.0),)],
...     ["indexed", "features"])
>>> model.evaluateEachIteration(validation)
[0.25..., 0.23..., 0.21..., 0.19..., 0.18...]
>>> model.numClasses
2
>>> gbt = gbt.setValidationIndicatorCol("validationIndicator")
>>> gbt.getValidationIndicatorCol()
'validationIndicator'
>>> gbt.getValidationTol()
0.01
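The `predictRaw`, `predictProbability` and `evaluateEachIteration` values in the example can be traced by hand. The sketch below is plain Python, not Spark code, and rests on stated assumptions: labels are mapped to {-1, +1}, the loss is 2·log(1 + exp(-2yF)), the first tree is fit to the labels and later trees to the pseudo-residuals, and the test point -1.0 always falls into a leaf holding only the label-0 training row. The function names are hypothetical.

```python
import math


def log_loss(y, f):
    """Assumed loss: 2 * log(1 + exp(-2 * y * f)) for y in {-1, +1}."""
    return 2.0 * math.log1p(math.exp(-2.0 * y * f))


def negative_gradient(y, f):
    """Pseudo-residual of the assumed loss at margin f."""
    return 4.0 * y / (1.0 + math.exp(2.0 * y * f))


def boost_single_point(y, tree_weights):
    """Trace the margin and loss of one point that sits alone in its leaf."""
    margin = 0.0
    margins, losses = [], []
    for i, weight in enumerate(tree_weights):
        # First tree predicts the +/-1 label; each later tree predicts the
        # pseudo-residual evaluated at the current margin.
        leaf_value = y if i == 0 else negative_gradient(y, margin)
        margin += weight * leaf_value
        margins.append(margin)
        losses.append(log_loss(y, margin))
    return margins, losses


tree_weights = [1.0, 0.1, 0.1, 0.1, 0.1]  # model.treeWeights from the example
margins, losses = boost_single_point(y=-1.0, tree_weights=tree_weights)

final_margin = margins[-1]
raw = [-final_margin, final_margin]                    # compare with predictRaw
p1 = 1.0 / (1.0 + math.exp(-2.0 * final_margin))
probability = [1.0 - p1, p1]                           # compare with predictProbability

print([round(v, 4) for v in raw])
print([round(v, 4) for v in probability])
print([round(v, 2) for v in losses])                   # compare with evaluateEachIteration
```

Under these assumptions the trace agrees with the printed values (a raw margin of about 1.1697 for class 0 and a class-0 probability of about 0.9121), which also explains why the first tree carries weight 1.0 while the remaining four carry the step size 0.1.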
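Note: this article was compiled by 纯净天空 from the original English documentation for pyspark.ml.classification.GBTClassifier on spark.apache.org. Unless otherwise stated, copyright of the original code belongs to the original authors. Do not reproduce or copy this translation without permission or authorization.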