本文简要介绍
pyspark.mllib.tree.GradientBoostedTrees.trainClassifier
的用法。用法:
classmethod trainClassifier(data, categoricalFeaturesInfo, loss='logLoss', numIterations=100, learningRate=0.1, maxDepth=3, maxBins=32)
训练梯度增强树模型进行分类。
版本 1.3.0 中的新函数。
- data:
pyspark.RDD
训练数据集:LabeledPoint 的 RDD。标签应采用值 {0, 1}。
- categoricalFeaturesInfo:dict
Map存储分类特征的数量。条目 (n -> k) 表示特征 n 是分类的,其中 k 个类别从 0 开始索引:{0, 1, ..., k-1}。
- loss:str,可选
梯度提升期间用于最小化的损失函数。支持的值:“logLoss”、“leastSquaresError”、“leastAbsoluteError”。 (默认:“logLoss”)
- numIterations:整数,可选
提升的迭代次数。 (默认值:100)
- learningRate:浮点数,可选
缩小每个估计器贡献的学习率。学习率应该在区间 (0, 1] 之间。(默认值:0.1)
- maxDepth:整数,可选
树的最大深度(例如,深度 0 表示 1 个叶节点,深度 1 表示 1 个内部节点 + 2 个叶节点)。 (默认值:3)
- maxBins:整数,可选
用于分割要素的最大箱数。 DecisionTree 要求 maxBins >= 最大类别。 (默认值:32)
- data:
参数:
返回:
例子:
>>> from pyspark.mllib.regression import LabeledPoint >>> from pyspark.mllib.tree import GradientBoostedTrees >>> >>> data = [ ... LabeledPoint(0.0, [0.0]), ... LabeledPoint(0.0, [1.0]), ... LabeledPoint(1.0, [2.0]), ... LabeledPoint(1.0, [3.0]) ... ] >>> >>> model = GradientBoostedTrees.trainClassifier(sc.parallelize(data), {}, numIterations=10) >>> model.numTrees() 10 >>> model.totalNumNodes() 30 >>> print(model) # it already has newline TreeEnsembleModel classifier with 10 trees >>> model.predict([2.0]) 1.0 >>> model.predict([0.0]) 0.0 >>> rdd = sc.parallelize([[2.0], [0.0]]) >>> model.predict(rdd).collect() [1.0, 0.0]
相关用法
- Python pyspark GradientBoostedTrees.trainRegressor用法及代码示例
- Python pyspark GroupBy.mean用法及代码示例
- Python pyspark GroupBy.head用法及代码示例
- Python pyspark GroupedData.applyInPandas用法及代码示例
- Python pyspark GroupBy.cumsum用法及代码示例
- Python pyspark GroupBy.rank用法及代码示例
- Python pyspark GroupBy.bfill用法及代码示例
- Python pyspark GroupBy.cummin用法及代码示例
- Python pyspark GroupBy.cummax用法及代码示例
- Python pyspark GroupedData.mean用法及代码示例
- Python pyspark GroupBy.fillna用法及代码示例
- Python pyspark GroupBy.apply用法及代码示例
- Python pyspark GroupedData.agg用法及代码示例
- Python pyspark GroupedData.pivot用法及代码示例
- Python pyspark GroupBy.diff用法及代码示例
- Python pyspark GroupBy.filter用法及代码示例
- Python pyspark GroupBy.transform用法及代码示例
- Python pyspark GroupedData.apply用法及代码示例
- Python pyspark GroupBy.cumcount用法及代码示例
- Python pyspark GroupedData.max用法及代码示例
- Python pyspark GroupedData.count用法及代码示例
- Python pyspark GroupedData.min用法及代码示例
- Python pyspark GroupBy.idxmax用法及代码示例
- Python pyspark GroupBy.shift用法及代码示例
- Python pyspark GroupBy.idxmin用法及代码示例
注:本文由纯净天空筛选整理自spark.apache.org大神的英文原创作品 pyspark.mllib.tree.GradientBoostedTrees.trainClassifier。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。