本文簡要介紹
pyspark.mllib.tree.DecisionTree.trainClassifier
的用法。用法:
classmethod trainClassifier(data, numClasses, categoricalFeaturesInfo, impurity='gini', maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0)
訓練用於分類的決策樹模型。
1.1.0 版中的新函數。
- data:
pyspark.RDD
訓練數據:LabeledPoint 的 RDD。標簽應采用值 {0, 1, ..., numClasses-1}。
- numClasses:int
分類的類數。
- categoricalFeaturesInfo:dict
Map存儲分類特征的數量。條目 (n -> k) 表示特征 n 是分類的,其中 k 個類別從 0 開始索引:{0, 1, ..., k-1}。
- impurity:str,可選
用於信息增益計算的標準。支持的值:“gini” 或“entropy”。 (默認:“gini”)
- maxDepth:整數,可選
樹的最大深度(例如,深度 0 表示 1 個葉節點,深度 1 表示 1 個內部節點 + 2 個葉節點)。 (默認值:5)
- maxBins:整數,可選
用於在每個節點處查找拆分的箱數。 (默認值:32)
- minInstancesPerNode:整數,可選
子節點創建父拆分所需的最小實例數。 (默認值:1)
- minInfoGain:浮點數,可選
創建拆分所需的最小信息增益。 (默認值:0.0)
- data:
參數:
返回:
例子:
>>> from numpy import array >>> from pyspark.mllib.regression import LabeledPoint >>> from pyspark.mllib.tree import DecisionTree >>> >>> data = [ ... LabeledPoint(0.0, [0.0]), ... LabeledPoint(1.0, [1.0]), ... LabeledPoint(1.0, [2.0]), ... LabeledPoint(1.0, [3.0]) ... ] >>> model = DecisionTree.trainClassifier(sc.parallelize(data), 2, {}) >>> print(model) DecisionTreeModel classifier of depth 1 with 3 nodes
>>> print(model.toDebugString()) DecisionTreeModel classifier of depth 1 with 3 nodes If (feature 0 <= 0.5) Predict: 0.0 Else (feature 0 > 0.5) Predict: 1.0 >>> model.predict(array([1.0])) 1.0 >>> model.predict(array([0.0])) 0.0 >>> rdd = sc.parallelize([[1.0], [0.0]]) >>> model.predict(rdd).collect() [1.0, 0.0]
相關用法
- Python pyspark DecisionTree.trainRegressor用法及代碼示例
- Python pyspark DecisionTreeClassifier用法及代碼示例
- Python pyspark DecisionTreeRegressor用法及代碼示例
- Python pyspark DenseMatrix.toArray用法及代碼示例
- Python pyspark DenseVector.parse用法及代碼示例
- Python pyspark DenseVector用法及代碼示例
- Python pyspark DenseVector.squared_distance用法及代碼示例
- Python pyspark DenseVector.norm用法及代碼示例
- Python pyspark DenseVector.dot用法及代碼示例
- Python pyspark DataFrame.to_latex用法及代碼示例
- Python pyspark DataStreamReader.schema用法及代碼示例
- Python pyspark DataFrame.align用法及代碼示例
- Python pyspark DataFrame.plot.bar用法及代碼示例
- Python pyspark DataFrame.to_delta用法及代碼示例
- Python pyspark DataFrame.quantile用法及代碼示例
- Python pyspark DataFrameWriter.partitionBy用法及代碼示例
- Python pyspark DataFrame.cumsum用法及代碼示例
- Python pyspark DatetimeIndex.is_month_start用法及代碼示例
- Python pyspark DataFrame.iloc用法及代碼示例
- Python pyspark DatetimeIndex.normalize用法及代碼示例
- Python pyspark DataFrame.dropDuplicates用法及代碼示例
- Python pyspark DatetimeIndex.is_month_end用法及代碼示例
- Python pyspark DataFrame.printSchema用法及代碼示例
- Python pyspark DataFrame.to_table用法及代碼示例
- Python pyspark DatetimeIndex.is_quarter_start用法及代碼示例
注:本文由純淨天空篩選整理自spark.apache.org大神的英文原創作品 pyspark.mllib.tree.DecisionTree.trainClassifier。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。