

Java RandomForestClassificationModel Class Code Examples

This article collects typical usage examples of org.apache.spark.ml.classification.RandomForestClassificationModel in Java. If you want to see how the RandomForestClassificationModel class is used in real projects, the examples selected below may help.


RandomForestClassificationModel belongs to the org.apache.spark.ml.classification package. Five code examples of the class are shown below, sorted by popularity by default.

Example 1: encodeModel

import org.apache.spark.ml.classification.RandomForestClassificationModel; // import the required package/class
@Override
public MiningModel encodeModel(Schema schema){
	RandomForestClassificationModel model = getTransformer();

	List<TreeModel> treeModels = TreeModelUtil.encodeDecisionTreeEnsemble(model, schema);

	MiningModel miningModel = new MiningModel(MiningFunction.CLASSIFICATION, ModelUtil.createMiningSchema(schema.getLabel()))
		.setSegmentation(MiningModelUtil.createSegmentation(Segmentation.MultipleModelMethod.AVERAGE, treeModels));

	return miningModel;
}
 
Developer ID: jpmml, Project: jpmml-sparkml, Lines of code: 12, Source: RandomForestClassificationModelConverter.java
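Example 1 combines the per-tree PMML models with Segmentation.MultipleModelMethod.AVERAGE. The plain-Java sketch below illustrates what averaging an ensemble's class-probability vectors means. It uses no JPMML or Spark API, the class name AverageEnsembleSketch is illustrative, and the per-tree outputs are made-up numbers; it assumes every tree returns a probability vector of the same length, with one entry per class.

```java
import java.util.Arrays;
import java.util.List;

public class AverageEnsembleSketch {

    // Averages per-tree class-probability vectors element-wise.
    // Assumes a non-empty list whose vectors all have the same length.
    static double[] average(List<double[]> treeProbabilities) {
        int numClasses = treeProbabilities.get(0).length;
        double[] sum = new double[numClasses];
        for (double[] p : treeProbabilities) {
            for (int k = 0; k < numClasses; k++) {
                sum[k] += p[k];
            }
        }
        for (int k = 0; k < numClasses; k++) {
            sum[k] /= treeProbabilities.size();
        }
        return sum;
    }

    // Returns the index of the largest entry, i.e. the predicted class index.
    static int argmax(double[] values) {
        int best = 0;
        for (int k = 1; k < values.length; k++) {
            if (values[k] > values[best]) {
                best = k;
            }
        }
        return best;
    }

    public static void main(String[] args) {
        // Hypothetical outputs of three trees for a two-class problem.
        List<double[]> trees = Arrays.asList(
                new double[]{0.9, 0.1},
                new double[]{0.6, 0.4},
                new double[]{0.3, 0.7});
        double[] avg = average(trees);
        System.out.println("predicted class index: " + argmax(avg)); // prints "predicted class index: 0"
    }
}
```

When all trees carry equal weight, the averaged vector and the plain sum of the vectors have the same argmax, so the predicted class is the same either way; averaging additionally keeps the result a valid probability vector.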

Example 2: getModelInfo

import org.apache.spark.ml.classification.RandomForestClassificationModel; // import the required package/class
@Override
RandomForestModelInfo getModelInfo(final RandomForestClassificationModel sparkRfModel, final DataFrame df) {
    final RandomForestModelInfo modelInfo = new RandomForestModelInfo();

    modelInfo.setNumClasses(sparkRfModel.numClasses());
    modelInfo.setNumFeatures(sparkRfModel.numFeatures());
    modelInfo.setRegression(false); //false for classification

    final List<Double> treeWeights = new ArrayList<Double>();
    for (double w : sparkRfModel.treeWeights()) {
        treeWeights.add(w);
    }
    modelInfo.setTreeWeights(treeWeights);

    final List<DecisionTreeModelInfo> decisionTrees = new ArrayList<>();
    for (DecisionTreeModel decisionTreeModel : sparkRfModel.trees()) {
        decisionTrees.add(DECISION_TREE_ADAPTER.getModelInfo((DecisionTreeClassificationModel) decisionTreeModel, df));
    }
    modelInfo.setTrees(decisionTrees);

    final Set<String> inputKeys = new LinkedHashSet<String>();
    inputKeys.add(sparkRfModel.getFeaturesCol());
    inputKeys.add(sparkRfModel.getLabelCol());
    modelInfo.setInputKeys(inputKeys);

    final Set<String> outputKeys = new LinkedHashSet<String>();
    outputKeys.add(sparkRfModel.getPredictionCol());
    outputKeys.add(sparkRfModel.getProbabilityCol());
    outputKeys.add(sparkRfModel.getRawPredictionCol());
    modelInfo.setProbabilityKey(sparkRfModel.getProbabilityCol());
    modelInfo.setRawPredictionKey(sparkRfModel.getRawPredictionCol());
    modelInfo.setOutputKeys(outputKeys);

    return modelInfo;
}
 
Developer ID: flipkart-incubator, Project: spark-transformers, Lines of code: 36, Source: RandomForestClassificationModelInfoAdapter.java
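Example 2 copies treeWeights() into a List<Double> with an explicit loop and records the input and output column names in LinkedHashSet instances. Insertion order matters here: Example 3 retrieves the features key with getInputKeys().iterator().next(), which presumably relies on the features column being added first. The standalone sketch below uses only the JDK (no Spark types; ModelInfoKeysSketch and its methods are illustrative names) to show a stream-based equivalent of the boxing loop and the insertion-order guarantee of LinkedHashSet. The column names are the ones used in Example 3.

```java
import java.util.Arrays;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;

public class ModelInfoKeysSketch {

    // Boxes a primitive double[] (such as the array returned by treeWeights())
    // into a List<Double>; equivalent to the for-loop in Example 2.
    static List<Double> toList(double[] weights) {
        return Arrays.stream(weights).boxed().collect(Collectors.toList());
    }

    // Builds the input-key set in the same order as Example 2: features first, then label.
    static Set<String> inputKeys(String featuresCol, String labelCol) {
        Set<String> keys = new LinkedHashSet<>();
        keys.add(featuresCol);
        keys.add(labelCol);
        return keys;
    }

    public static void main(String[] args) {
        System.out.println(toList(new double[]{1.0, 1.0, 1.0})); // prints "[1.0, 1.0, 1.0]"
        // LinkedHashSet iterates in insertion order, so the first key is the features column.
        System.out.println(inputKeys("features", "labelIndex").iterator().next()); // prints "features"
    }
}
```

A plain HashSet makes no ordering promise, so with it iterator().next() could return the label column instead of the features column.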

Example 3: testRandomForestClassification

import org.apache.spark.ml.classification.RandomForestClassificationModel; // import the required package/class
@Test
public void testRandomForestClassification() {
    // Load the data stored in LIBSVM format as a DataFrame.
    DataFrame data = sqlContext.read().format("libsvm").load("src/test/resources/classification_test.libsvm");

    StringIndexerModel stringIndexerModel = new StringIndexer()
            .setInputCol("label")
            .setOutputCol("labelIndex")
            .fit(data);

    data = stringIndexerModel.transform(data);

    // Split the data into training and test sets (30% held out for testing)
    DataFrame[] splits = data.randomSplit(new double[]{0.7, 0.3});
    DataFrame trainingData = splits[0];
    DataFrame testData = splits[1];

    // Train a RandomForest model.
    RandomForestClassificationModel classificationModel = new RandomForestClassifier()
            .setLabelCol("labelIndex")
            .setFeaturesCol("features")
            .setPredictionCol("prediction")
            .setRawPredictionCol("rawPrediction")
            .setProbabilityCol("probability")
            .fit(trainingData);
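
    // Export the trained model, then import the exported bytes back as a transformer.
    byte[] exportedModel = ModelExporter.export(classificationModel, null);
    Transformer transformer = ModelImporter.importAndGetTransformer(exportedModel);

    // Score the test set with Spark to get the reference output.
    Row[] sparkOutput = classificationModel.transform(testData)
            .select("features", "prediction", "rawPrediction", "probability")
            .collect();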

    // Compare Spark's output with the output of the imported transformer, row by row.
    for (Row row : sparkOutput) {
        // Column indices follow the order of the select(...) call.
        Vector v = (Vector) row.get(0);
        double actual = row.getDouble(1);
        double[] actualRaw = ((Vector) row.get(2)).toArray();
        double[] actualProbability = ((Vector) row.get(3)).toArray();

        // The imported transformer works on a plain Map; its first input key is the features column.
        // transform() writes prediction, probability and rawPrediction back into the same map.
        Map<String, Object> inputData = new HashMap<String, Object>();
        inputData.put(transformer.getInputKeys().iterator().next(), v.toArray());
        transformer.transform(inputData);
        double predicted = (double) inputData.get("prediction");
        double[] probability = (double[]) inputData.get("probability");
        double[] rawPrediction = (double[]) inputData.get("rawPrediction");

        assertEquals(actual, predicted, EPSILON);
        assertArrayEquals(actualProbability, probability, EPSILON);
        assertArrayEquals(actualRaw, rawPrediction, EPSILON);
    }
}
 
Developer ID: flipkart-incubator, Project: spark-transformers, Lines of code: 56, Source: RandomForestClassificationModelInfoAdapterBridgeTest.java
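Example 3 checks prediction, probability and rawPrediction separately. As we understand Spark's random forest classifier, these are related: the probability vector is the raw prediction vector normalized so its entries sum to 1, and, assuming no custom class thresholds are set, the prediction is the index of the largest entry. The plain-Java sketch below illustrates that relationship together with a tolerance-based array comparison in the spirit of assertArrayEquals(expected, actual, EPSILON). It is not Spark's implementation; RawToProbabilitySketch and its methods are illustrative names, and the raw vector is a made-up value.

```java
import java.util.Arrays;

public class RawToProbabilitySketch {

    // Normalizes a raw-prediction vector so its entries sum to 1.
    // Assumes all entries are non-negative and at least one is positive.
    static double[] toProbability(double[] raw) {
        double sum = 0.0;
        for (double v : raw) {
            sum += v;
        }
        double[] probability = new double[raw.length];
        for (int k = 0; k < raw.length; k++) {
            probability[k] = raw[k] / sum;
        }
        return probability;
    }

    // Picks the class index with the largest score; ties resolve to the lowest index.
    static double predict(double[] scores) {
        int best = 0;
        for (int k = 1; k < scores.length; k++) {
            if (scores[k] > scores[best]) {
                best = k;
            }
        }
        return best;
    }

    // Element-wise comparison with an absolute tolerance.
    static boolean closeTo(double[] expected, double[] actual, double epsilon) {
        if (expected.length != actual.length) {
            return false;
        }
        for (int k = 0; k < expected.length; k++) {
            if (Math.abs(expected[k] - actual[k]) > epsilon) {
                return false;
            }
        }
        return true;
    }

    public static void main(String[] args) {
        double[] raw = {3.0, 1.0}; // hypothetical summed per-tree votes for two classes
        double[] probability = toProbability(raw);
        System.out.println(Arrays.toString(probability) + " prediction=" + predict(raw));
        // prints "[0.75, 0.25] prediction=0.0"
    }
}
```

Because normalization divides every entry by the same positive sum, the argmax of the raw vector and of the probability vector coincide, which is why the test can expect all three outputs to agree within EPSILON.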

Example 4: RandomForestClassificationModelConverter

import org.apache.spark.ml.classification.RandomForestClassificationModel; // import the required package/class
public RandomForestClassificationModelConverter(RandomForestClassificationModel model){
	super(model);
}
 
Developer ID: jpmml, Project: jpmml-sparkml, Lines of code: 4, Source: RandomForestClassificationModelConverter.java

Example 5: getSource

import org.apache.spark.ml.classification.RandomForestClassificationModel; // import the required package/class
@Override
public Class<RandomForestClassificationModel> getSource() {
    return RandomForestClassificationModel.class;
}
 
Developer ID: flipkart-incubator, Project: spark-transformers, Lines of code: 5, Source: RandomForestClassificationModelInfoAdapter.java
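Example 5's getSource() returns the model class that the adapter handles. One plausible use for such a method, sketched below under stated assumptions, is a registry that maps each model class to its adapter so the right adapter can be found from a model instance. ModelAdapter, FakeForestModel, FakeForestAdapter and AdapterRegistrySketch are hypothetical names for illustration only; they are not the spark-transformers API.

```java
import java.util.HashMap;
import java.util.Map;

public class AdapterRegistrySketch {

    // Hypothetical adapter contract: each adapter declares which model class it handles.
    interface ModelAdapter<S> {
        Class<S> getSource();
        String describe(S model);
    }

    // Hypothetical stand-in for a trained model class (not a real Spark type).
    static final class FakeForestModel {
        final int numTrees;

        FakeForestModel(int numTrees) {
            this.numTrees = numTrees;
        }
    }

    // Hypothetical adapter for FakeForestModel, analogous in shape to Example 5.
    static final class FakeForestAdapter implements ModelAdapter<FakeForestModel> {
        @Override
        public Class<FakeForestModel> getSource() {
            return FakeForestModel.class;
        }

        @Override
        public String describe(FakeForestModel model) {
            return "forest with " + model.numTrees + " trees";
        }
    }

    private final Map<Class<?>, ModelAdapter<?>> adapters = new HashMap<>();

    // Keys each adapter by the class it reports through getSource().
    void register(ModelAdapter<?> adapter) {
        adapters.put(adapter.getSource(), adapter);
    }

    // Looks up the adapter by the model's exact runtime class. The unchecked cast is
    // safe because register() keys each adapter by its own getSource() value.
    @SuppressWarnings("unchecked")
    <S> String describe(S model) {
        ModelAdapter<S> adapter = (ModelAdapter<S>) adapters.get(model.getClass());
        if (adapter == null) {
            throw new IllegalArgumentException("No adapter for " + model.getClass().getName());
        }
        return adapter.describe(model);
    }

    public static void main(String[] args) {
        AdapterRegistrySketch registry = new AdapterRegistrySketch();
        registry.register(new FakeForestAdapter());
        System.out.println(registry.describe(new FakeForestModel(3))); // prints "forest with 3 trees"
    }
}
```

The lookup uses the exact runtime class, so an instance of a subclass of a registered model would not be found; a registry that needs that case could walk the superclass chain instead.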


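Note: The org.apache.spark.ml.classification.RandomForestClassificationModel examples in this article were compiled by 純淨天空 from open-source code and documentation platforms such as GitHub and MSDocs. The code snippets are taken from open-source projects contributed by their developers, and copyright remains with the original authors. Refer to the corresponding project's license before distributing or using the code, and do not reproduce this article without permission.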