本文简要介绍
pyspark.ml.classification.OneVsRest
的用法。用法:
class pyspark.ml.classification.OneVsRest(*, featuresCol='features', labelCol='label', predictionCol='prediction', rawPredictionCol='rawPrediction', classifier=None, weightCol=None, parallelism=1)
将多类分类简化为二元分类。使用一对一的策略执行减少。对于具有 k 个类的多类分类,训练 k 个模型(每类一个)。每个示例都针对所有 k 个模型进行评分,并选择得分最高的模型来标记示例。
2.0.0 版中的新函数。
例子:
>>> from pyspark.sql import Row >>> from pyspark.ml.linalg import Vectors >>> data_path = "data/mllib/sample_multiclass_classification_data.txt" >>> df = spark.read.format("libsvm").load(data_path) >>> lr = LogisticRegression(regParam=0.01) >>> ovr = OneVsRest(classifier=lr) >>> ovr.getRawPredictionCol() 'rawPrediction' >>> ovr.setPredictionCol("newPrediction") OneVsRest... >>> model = ovr.fit(df) >>> model.models[0].coefficients DenseVector([0.5..., -1.0..., 3.4..., 4.2...]) >>> model.models[1].coefficients DenseVector([-2.1..., 3.1..., -2.6..., -2.3...]) >>> model.models[2].coefficients DenseVector([0.3..., -3.4..., 1.0..., -1.1...]) >>> [x.intercept for x in model.models] [-2.7..., -2.5..., -1.3...] >>> test0 = sc.parallelize([Row(features=Vectors.dense(-1.0, 0.0, 1.0, 1.0))]).toDF() >>> model.transform(test0).head().newPrediction 0.0 >>> test1 = sc.parallelize([Row(features=Vectors.sparse(4, [0], [1.0]))]).toDF() >>> model.transform(test1).head().newPrediction 2.0 >>> test2 = sc.parallelize([Row(features=Vectors.dense(0.5, 0.4, 0.3, 0.2))]).toDF() >>> model.transform(test2).head().newPrediction 0.0 >>> model_path = temp_path + "/ovr_model" >>> model.save(model_path) >>> model2 = OneVsRestModel.load(model_path) >>> model2.transform(test0).head().newPrediction 0.0 >>> model.transform(test0).take(1) == model2.transform(test0).take(1) True >>> model.transform(test2).columns ['features', 'rawPrediction', 'newPrediction']
相关用法
- Python pyspark OneHotEncoder用法及代码示例
- Python pyspark create_map用法及代码示例
- Python pyspark date_add用法及代码示例
- Python pyspark DataFrame.to_latex用法及代码示例
- Python pyspark DataStreamReader.schema用法及代码示例
- Python pyspark MultiIndex.size用法及代码示例
- Python pyspark arrays_overlap用法及代码示例
- Python pyspark Series.asof用法及代码示例
- Python pyspark DataFrame.align用法及代码示例
- Python pyspark Index.is_monotonic_decreasing用法及代码示例
- Python pyspark IsotonicRegression用法及代码示例
- Python pyspark DataFrame.plot.bar用法及代码示例
- Python pyspark DataFrame.to_delta用法及代码示例
- Python pyspark element_at用法及代码示例
- Python pyspark explode用法及代码示例
- Python pyspark MultiIndex.hasnans用法及代码示例
- Python pyspark Series.to_frame用法及代码示例
- Python pyspark DataFrame.quantile用法及代码示例
- Python pyspark Column.withField用法及代码示例
- Python pyspark Index.values用法及代码示例
- Python pyspark Index.drop_duplicates用法及代码示例
- Python pyspark aggregate用法及代码示例
- Python pyspark IndexedRowMatrix.computeGramianMatrix用法及代码示例
- Python pyspark DecisionTreeClassifier用法及代码示例
- Python pyspark Index.value_counts用法及代码示例
注:本文由纯净天空筛选整理自spark.apache.org大神的英文原创作品 pyspark.ml.classification.OneVsRest。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。