当前位置: 首页>>代码示例>>Java>>正文


Java RandomForestClassificationModel类代码示例

本文整理汇总了Java中org.apache.spark.ml.classification.RandomForestClassificationModel的典型用法代码示例。如果您正苦于以下问题：Java RandomForestClassificationModel类的具体用法？Java RandomForestClassificationModel怎么用？Java RandomForestClassificationModel使用的例子？那么，这里精选的类代码示例或许可以为您提供帮助。


RandomForestClassificationModel类属于org.apache.spark.ml.classification包，在下文中一共展示了RandomForestClassificationModel类的5个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Java代码示例。

示例1: encodeModel

import org.apache.spark.ml.classification.RandomForestClassificationModel; //导入依赖的package包/类
/**
 * Encodes the wrapped random forest as a classification {@code MiningModel}.
 *
 * <p>Each tree of the ensemble is encoded against the given schema, and the
 * resulting tree models are attached to the mining model as a segmentation
 * that uses the {@code AVERAGE} multiple-model method.
 *
 * @param schema schema handed to the tree encoder; its label is used to build the mining schema
 * @return the mining model returned by {@code setSegmentation}
 */
@Override
public MiningModel encodeModel(Schema schema){
	RandomForestClassificationModel forest = getTransformer();
	List<TreeModel> encodedTrees = TreeModelUtil.encodeDecisionTreeEnsemble(forest, schema);

	// Build the container first, then the segmentation, mirroring the original evaluation order.
	MiningModel ensemble = new MiningModel(MiningFunction.CLASSIFICATION, ModelUtil.createMiningSchema(schema.getLabel()));
	Segmentation averaged = MiningModelUtil.createSegmentation(Segmentation.MultipleModelMethod.AVERAGE, encodedTrees);

	return ensemble.setSegmentation(averaged);
}
 
开发者ID:jpmml,项目名称:jpmml-sparkml,代码行数:12,代码来源:RandomForestClassificationModelConverter.java

示例2: getModelInfo

import org.apache.spark.ml.classification.RandomForestClassificationModel; //导入依赖的package包/类
/**
 * Copies the state of a Spark random forest classification model into a
 * {@code RandomForestModelInfo}.
 *
 * <p>The result carries the class and feature counts, the per-tree weights,
 * one {@code DecisionTreeModelInfo} per tree (built via {@code DECISION_TREE_ADAPTER}),
 * the input column names (features, label) and the output column names
 * (prediction, probability, raw prediction), in that insertion order.
 *
 * @param sparkRfModel the Spark model to read from
 * @param df data frame forwarded unchanged to the per-tree adapter
 * @return the populated model info
 */
@Override
RandomForestModelInfo getModelInfo(final RandomForestClassificationModel sparkRfModel, final DataFrame df) {
    final RandomForestModelInfo info = new RandomForestModelInfo();

    info.setNumClasses(sparkRfModel.numClasses());
    info.setNumFeatures(sparkRfModel.numFeatures());
    // This adapter describes a classifier, so the regression flag is always off.
    info.setRegression(false);

    // Box the per-tree weights into a list.
    final List<Double> weights = new ArrayList<>();
    for (double weight : sparkRfModel.treeWeights()) {
        weights.add(weight);
    }
    info.setTreeWeights(weights);

    // Convert every member tree through the decision-tree adapter.
    final List<DecisionTreeModelInfo> treeInfos = new ArrayList<>();
    for (DecisionTreeModel tree : sparkRfModel.trees()) {
        final DecisionTreeClassificationModel classificationTree = (DecisionTreeClassificationModel) tree;
        treeInfos.add(DECISION_TREE_ADAPTER.getModelInfo(classificationTree, df));
    }
    info.setTrees(treeInfos);

    // Insertion-ordered sets keep the column order stable.
    final Set<String> inputs = new LinkedHashSet<>();
    inputs.add(sparkRfModel.getFeaturesCol());
    inputs.add(sparkRfModel.getLabelCol());
    info.setInputKeys(inputs);

    final String probabilityCol = sparkRfModel.getProbabilityCol();
    final String rawPredictionCol = sparkRfModel.getRawPredictionCol();

    final Set<String> outputs = new LinkedHashSet<>();
    outputs.add(sparkRfModel.getPredictionCol());
    outputs.add(probabilityCol);
    outputs.add(rawPredictionCol);
    info.setProbabilityKey(probabilityCol);
    info.setRawPredictionKey(rawPredictionCol);
    info.setOutputKeys(outputs);

    return info;
}
 
开发者ID:flipkart-incubator,项目名称:spark-transformers,代码行数:36,代码来源:RandomForestClassificationModelInfoAdapter.java

示例3: testRandomForestClassification

import org.apache.spark.ml.classification.RandomForestClassificationModel; //导入依赖的package包/类
/**
 * Round-trips a trained Spark random forest classifier through
 * {@code ModelExporter} / {@code ModelImporter} and verifies that the imported
 * transformer reproduces Spark's prediction, probability and raw prediction
 * for every held-out row.
 *
 * <p>The train/test split uses a fixed seed so that the same rows are held out
 * on every run and any failure can be reproduced.
 */
@Test
public void testRandomForestClassification() {
    // Load the data stored in LIBSVM format as a DataFrame.
    DataFrame data = sqlContext.read().format("libsvm").load("src/test/resources/classification_test.libsvm");

    StringIndexerModel stringIndexerModel = new StringIndexer()
            .setInputCol("label")
            .setOutputCol("labelIndex")
            .fit(data);

    data = stringIndexerModel.transform(data);

    // Split the data into training and test sets (30% held out for testing).
    // A fixed seed keeps the split identical across runs; an unseeded split
    // would make a failing comparison impossible to reproduce.
    final long splitSeed = 12345L;
    DataFrame[] splits = data.randomSplit(new double[]{0.7, 0.3}, splitSeed);
    DataFrame trainingData = splits[0];
    DataFrame testData = splits[1];

    // Train a RandomForest model.
    RandomForestClassificationModel classificationModel = new RandomForestClassifier()
            .setLabelCol("labelIndex")
            .setFeaturesCol("features")
            .setPredictionCol("prediction")
            .setRawPredictionCol("rawPrediction")
            .setProbabilityCol("probability")
            .fit(trainingData);

    // Export the Spark model and load it back as a standalone transformer.
    byte[] exportedModel = ModelExporter.export(classificationModel, null);

    Transformer transformer = ModelImporter.importAndGetTransformer(exportedModel);

    Row[] sparkOutput = classificationModel.transform(testData).select("features", "prediction", "rawPrediction", "probability").collect();

    // The transformer's first input key receives the feature array; it does not
    // change between rows, so look it up once.
    final String featuresKey = transformer.getInputKeys().iterator().next();

    //compare predictions
    for (Row row : sparkOutput) {
        // Column order follows the select(...) above: features, prediction, rawPrediction, probability.
        Vector v = (Vector) row.get(0);
        double actual = row.getDouble(1);
        double[] actualRaw = ((Vector) row.get(2)).toArray();
        double[] actualProbability = ((Vector) row.get(3)).toArray();

        // The transformer writes its outputs back into the same map.
        Map<String, Object> inputData = new HashMap<String, Object>();
        inputData.put(featuresKey, v.toArray());
        transformer.transform(inputData);
        double predicted = (double) inputData.get("prediction");
        double[] probability = (double[]) inputData.get("probability");
        double[] rawPrediction = (double[]) inputData.get("rawPrediction");

        assertEquals(actual, predicted, EPSILON);
        assertArrayEquals(actualProbability, probability, EPSILON);
        assertArrayEquals(actualRaw, rawPrediction, EPSILON);
    }

}
 
开发者ID:flipkart-incubator,项目名称:spark-transformers,代码行数:56,代码来源:RandomForestClassificationModelInfoAdapterBridgeTest.java

示例4: RandomForestClassificationModelConverter

import org.apache.spark.ml.classification.RandomForestClassificationModel; //导入依赖的package包/类
/**
 * Creates a converter for the given model by delegating to the superclass constructor.
 *
 * @param model the random forest classification model passed on to the superclass
 */
public RandomForestClassificationModelConverter(RandomForestClassificationModel model){
	super(model);
}
 
开发者ID:jpmml,项目名称:jpmml-sparkml,代码行数:4,代码来源:RandomForestClassificationModelConverter.java

示例5: getSource

import org.apache.spark.ml.classification.RandomForestClassificationModel; //导入依赖的package包/类
/**
 * Returns the class object of the source model type.
 *
 * @return {@code RandomForestClassificationModel.class}
 */
@Override
public Class<RandomForestClassificationModel> getSource() {
    return RandomForestClassificationModel.class;
}
 
开发者ID:flipkart-incubator,项目名称:spark-transformers,代码行数:5,代码来源:RandomForestClassificationModelInfoAdapter.java


注:本文中的org.apache.spark.ml.classification.RandomForestClassificationModel类示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。