This article briefly introduces the usage of pyspark.ml.classification.DecisionTreeClassifier.
Usage:
class pyspark.ml.classification.DecisionTreeClassifier(*, featuresCol='features', labelCol='label', predictionCol='prediction', probabilityCol='probability', rawPredictionCol='rawPrediction', maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity='gini', seed=None, weightCol=None, leafCol='', minWeightFractionPerNode=0.0)
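Decision tree learning algorithm for classification. It supports both binary and multiclass labels, as well as both continuous and categorical features.
New in version 1.4.0.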
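The `impurity` and `minInfoGain` parameters in the signature above control how candidate splits are scored. The following is a minimal pure-Python sketch of the textbook definitions of Gini impurity, entropy, and information gain. It is not Spark's implementation, and the function names `gini`, `entropy`, and `info_gain` are hypothetical helpers introduced only for illustration.

```python
from collections import Counter
from math import log2


def gini(labels):
    """Gini impurity: 1 - sum(p_k^2) over the class proportions p_k."""
    n = len(labels)
    return 1.0 - sum((c / n) ** 2 for c in Counter(labels).values())


def entropy(labels):
    """Entropy impurity in bits: -sum(p_k * log2(p_k))."""
    n = len(labels)
    return -sum((c / n) * log2(c / n) for c in Counter(labels).values())


def info_gain(parent, left, right, impurity=gini):
    """Parent impurity minus the size-weighted impurity of the two children."""
    n = len(parent)
    children = (len(left) / n) * impurity(left) + (len(right) / n) * impurity(right)
    return impurity(parent) - children


parent = [0, 0, 1, 1]
print(gini(parent))                       # 0.5
print(entropy(parent))                    # 1.0
print(info_gain(parent, [0, 0], [1, 1]))  # 0.5, a perfect split removes all impurity
```

Under these definitions, a split whose gain falls below `minInfoGain` would not be worth making, which is the role that parameter plays as a stopping criterion.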
Examples:
>>> # Assumes an active SparkSession named spark and a writable directory path in temp_path.
>>> from pyspark.ml.classification import DecisionTreeClassifier, DecisionTreeClassificationModel
>>> from pyspark.ml.linalg import Vectors
>>> from pyspark.ml.feature import StringIndexer
>>> df = spark.createDataFrame([
...     (1.0, Vectors.dense(1.0)),
...     (0.0, Vectors.sparse(1, [], []))], ["label", "features"])
>>> stringIndexer = StringIndexer(inputCol="label", outputCol="indexed")
>>> si_model = stringIndexer.fit(df)
>>> td = si_model.transform(df)
>>> dt = DecisionTreeClassifier(maxDepth=2, labelCol="indexed", leafCol="leafId")
>>> model = dt.fit(td)
>>> model.getLabelCol()
'indexed'
>>> model.setFeaturesCol("features")
DecisionTreeClassificationModel...
>>> model.numNodes
3
>>> model.depth
1
>>> model.featureImportances
SparseVector(1, {0: 1.0})
>>> model.numFeatures
1
>>> model.numClasses
2
>>> print(model.toDebugString)
DecisionTreeClassificationModel...depth=1, numNodes=3...
>>> test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
>>> model.predict(test0.head().features)
0.0
>>> model.predictRaw(test0.head().features)
DenseVector([1.0, 0.0])
>>> model.predictProbability(test0.head().features)
DenseVector([1.0, 0.0])
>>> result = model.transform(test0).head()
>>> result.prediction
0.0
>>> result.probability
DenseVector([1.0, 0.0])
>>> result.rawPrediction
DenseVector([1.0, 0.0])
>>> result.leafId
0.0
>>> test1 = spark.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"])
>>> model.transform(test1).head().prediction
1.0
>>> dtc_path = temp_path + "/dtc"
>>> dt.save(dtc_path)
>>> dt2 = DecisionTreeClassifier.load(dtc_path)
>>> dt2.getMaxDepth()
2
>>> model_path = temp_path + "/dtc_model"
>>> model.save(model_path)
>>> model2 = DecisionTreeClassificationModel.load(model_path)
>>> model.featureImportances == model2.featureImportances
True
>>> model.transform(test0).take(1) == model2.transform(test0).take(1)
True
>>> df3 = spark.createDataFrame([
...     (1.0, 0.2, Vectors.dense(1.0)),
...     (1.0, 0.8, Vectors.dense(1.0)),
...     (0.0, 1.0, Vectors.sparse(1, [], []))], ["label", "weight", "features"])
>>> si3 = StringIndexer(inputCol="label", outputCol="indexed")
>>> si_model3 = si3.fit(df3)
>>> td3 = si_model3.transform(df3)
>>> dt3 = DecisionTreeClassifier(maxDepth=2, weightCol="weight", labelCol="indexed")
>>> model3 = dt3.fit(td3)
>>> print(model3.toDebugString)
DecisionTreeClassificationModel...depth=1, numNodes=3...
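The last part of the example passes `weightCol="weight"`. As a rough illustration of what instance weights can change, the sketch below computes Gini impurity at the root for the three rows of `df3`, once with every row counted equally and once with each row contributing its weight. It assumes that weights behave like fractional row counts. This is an illustrative assumption, not a description of Spark's internals, and `weighted_gini` is a hypothetical helper.

```python
def weighted_gini(labels, weights):
    """Gini impurity where each row contributes its weight instead of a count of 1."""
    totals = {}
    for label, w in zip(labels, weights):
        totals[label] = totals.get(label, 0.0) + w
    total = sum(totals.values())
    return 1.0 - sum((t / total) ** 2 for t in totals.values())


labels = [1.0, 1.0, 0.0]   # "label" column of df3
weights = [0.2, 0.8, 1.0]  # "weight" column of df3

unweighted = weighted_gini(labels, [1.0, 1.0, 1.0])
weighted = weighted_gini(labels, weights)
print(round(unweighted, 4))  # 0.4444
print(round(weighted, 4))    # 0.5
```

With these weights the two rows labeled 1.0 together carry the same total weight as the single row labeled 0.0, so under this assumption the weighted class distribution at the root is balanced even though the row counts are not.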
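Note: This article was selected and compiled by 纯净天空 from the original English documentation for pyspark.ml.classification.DecisionTreeClassifier on spark.apache.org. Unless otherwise stated, copyright of the original code belongs to the original authors. Do not reproduce or copy this translation without permission or authorization.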