本文整理匯總了Java中org.apache.spark.ml.classification.RandomForestClassificationModel類的典型用法代碼示例。如果您正苦於以下問題：Java RandomForestClassificationModel類的具體用法？Java RandomForestClassificationModel怎麼用？Java RandomForestClassificationModel使用的例子？那麼，這裏精選的類代碼示例或許可以為您提供幫助。
RandomForestClassificationModel類屬於org.apache.spark.ml.classification包，在下文中一共展示了RandomForestClassificationModel類的5個代碼示例，這些例子默認根據受歡迎程度排序。您可以為喜歡或者感覺有用的代碼點讚，您的評價將有助於系統推薦出更棒的Java代碼示例。
示例1: encodeModel
import org.apache.spark.ml.classification.RandomForestClassificationModel; //導入依賴的package包/類
/**
 * Encodes the wrapped Spark random forest classifier as a PMML classification {@code MiningModel}.
 *
 * <p>Every decision tree of the ensemble is encoded through
 * {@code TreeModelUtil.encodeDecisionTreeEnsemble}. The resulting tree models become the segments
 * of a segmentation whose multiple-model method is {@code AVERAGE}.
 *
 * @param schema the schema used to encode the trees; its label also defines the mining schema
 * @return the classification mining model holding one segment per encoded tree
 */
@Override
public MiningModel encodeModel(Schema schema){
    RandomForestClassificationModel forest = getTransformer();

    List<TreeModel> memberTrees = TreeModelUtil.encodeDecisionTreeEnsemble(forest, schema);

    // The member trees' results are combined by averaging (MultipleModelMethod.AVERAGE).
    return new MiningModel(MiningFunction.CLASSIFICATION, ModelUtil.createMiningSchema(schema.getLabel()))
        .setSegmentation(MiningModelUtil.createSegmentation(Segmentation.MultipleModelMethod.AVERAGE, memberTrees));
}
示例2: getModelInfo
import org.apache.spark.ml.classification.RandomForestClassificationModel; //導入依賴的package包/類
/**
 * Converts a fitted Spark {@code RandomForestClassificationModel} into a {@code RandomForestModelInfo}.
 *
 * <p>The method copies the following from the Spark model:
 * <ul>
 *   <li>the class count and feature count;</li>
 *   <li>the per-tree weights;</li>
 *   <li>every member tree, each converted through {@code DECISION_TREE_ADAPTER};</li>
 *   <li>the model's column names, recorded as input keys (features, label) and output keys
 *       (prediction, probability, raw prediction).</li>
 * </ul>
 *
 * @param sparkRfModel the trained Spark random forest classifier
 * @param df passed unchanged to the decision-tree adapter for each member tree
 * @return the populated model info, marked as non-regression (classification)
 */
@Override
RandomForestModelInfo getModelInfo(final RandomForestClassificationModel sparkRfModel, final DataFrame df) {
    final RandomForestModelInfo info = new RandomForestModelInfo();
    info.setNumClasses(sparkRfModel.numClasses());
    info.setNumFeatures(sparkRfModel.numFeatures());
    // This adapter handles the classification forest only.
    info.setRegression(false);

    // Box the primitive weight array into a list, preserving tree order.
    final double[] rawWeights = sparkRfModel.treeWeights();
    final List<Double> weights = new ArrayList<>(rawWeights.length);
    for (int i = 0; i < rawWeights.length; i++) {
        weights.add(rawWeights[i]);
    }
    info.setTreeWeights(weights);

    // trees() is consumed as DecisionTreeModel; the tree adapter takes the classification subtype,
    // so each member is cast (assumes all members of a classification forest are classification trees).
    final DecisionTreeModel[] members = sparkRfModel.trees();
    final List<DecisionTreeModelInfo> treeInfos = new ArrayList<>(members.length);
    for (int t = 0; t < members.length; t++) {
        treeInfos.add(DECISION_TREE_ADAPTER.getModelInfo((DecisionTreeClassificationModel) members[t], df));
    }
    info.setTrees(treeInfos);

    // LinkedHashSet keeps insertion order, so the features column is the first input key.
    final Set<String> inputs = new LinkedHashSet<>();
    inputs.add(sparkRfModel.getFeaturesCol());
    inputs.add(sparkRfModel.getLabelCol());
    info.setInputKeys(inputs);

    final String probabilityCol = sparkRfModel.getProbabilityCol();
    final String rawPredictionCol = sparkRfModel.getRawPredictionCol();
    final Set<String> outputs = new LinkedHashSet<>();
    outputs.add(sparkRfModel.getPredictionCol());
    outputs.add(probabilityCol);
    outputs.add(rawPredictionCol);
    info.setProbabilityKey(probabilityCol);
    info.setRawPredictionKey(rawPredictionCol);
    info.setOutputKeys(outputs);

    return info;
}
開發者ID:flipkart-incubator,項目名稱:spark-transformers,代碼行數:36,代碼來源:RandomForestClassificationModelInfoAdapter.java
示例3: testRandomForestClassification
import org.apache.spark.ml.classification.RandomForestClassificationModel; //導入依賴的package包/類
/**
 * Round-trip test for the random forest classification adapter.
 *
 * <p>Steps:
 * <ol>
 *   <li>Train a Spark {@code RandomForestClassifier}.</li>
 *   <li>Export the fitted model with {@code ModelExporter}.</li>
 *   <li>Re-import it with {@code ModelImporter}.</li>
 *   <li>Assert that the imported transformer reproduces Spark's prediction, probability and
 *       raw-prediction values for every row of the held-out test set.</li>
 * </ol>
 */
@Test
public void testRandomForestClassification() {
    // Fixed seed for the train/test split so the data seen by the test, and any failure, is reproducible.
    final long splitSeed = 42L;

    // Load the data stored in LIBSVM format as a DataFrame.
    DataFrame data = sqlContext.read().format("libsvm").load("src/test/resources/classification_test.libsvm");

    // Index the raw "label" column into "labelIndex", which the classifier trains on.
    StringIndexerModel stringIndexerModel = new StringIndexer()
            .setInputCol("label")
            .setOutputCol("labelIndex")
            .fit(data);
    data = stringIndexerModel.transform(data);

    // Split the data into training and test sets (30% held out for testing).
    DataFrame[] splits = data.randomSplit(new double[]{0.7, 0.3}, splitSeed);
    DataFrame trainingData = splits[0];
    DataFrame testData = splits[1];

    // Train a RandomForest model.
    RandomForestClassificationModel classificationModel = new RandomForestClassifier()
            .setLabelCol("labelIndex")
            .setFeaturesCol("features")
            .setPredictionCol("prediction")
            .setRawPredictionCol("rawPrediction")
            .setProbabilityCol("probability")
            .fit(trainingData);

    // Export/import round trip through the library under test.
    byte[] exportedModel = ModelExporter.export(classificationModel, null);
    Transformer transformer = ModelImporter.importAndGetTransformer(exportedModel);

    // The select order fixes the row indices read below:
    // 0 = features, 1 = prediction, 2 = rawPrediction, 3 = probability.
    Row[] sparkOutput = classificationModel.transform(testData)
            .select("features", "prediction", "rawPrediction", "probability")
            .collect();

    // With no held-out rows, the comparison loop would pass without checking anything.
    if (sparkOutput.length == 0) {
        throw new AssertionError("Held-out test set is empty; no predictions were compared");
    }

    // The feature vector is supplied under the transformer's first input key
    // (assumes the adapter registers the features column first).
    final String featuresKey = transformer.getInputKeys().iterator().next();

    // Compare Spark's output (expected) with the imported transformer's output, row by row.
    for (Row row : sparkOutput) {
        Vector features = (Vector) row.get(0);
        double expectedPrediction = row.getDouble(1);
        double[] expectedRaw = ((Vector) row.get(2)).toArray();
        double[] expectedProbability = ((Vector) row.get(3)).toArray();

        Map<String, Object> inputData = new HashMap<>();
        inputData.put(featuresKey, features.toArray());
        transformer.transform(inputData);

        double predicted = (double) inputData.get("prediction");
        double[] probability = (double[]) inputData.get("probability");
        double[] rawPrediction = (double[]) inputData.get("rawPrediction");

        assertEquals(expectedPrediction, predicted, EPSILON);
        assertArrayEquals(expectedProbability, probability, EPSILON);
        assertArrayEquals(expectedRaw, rawPrediction, EPSILON);
    }
}
開發者ID:flipkart-incubator,項目名稱:spark-transformers,代碼行數:56,代碼來源:RandomForestClassificationModelInfoAdapterBridgeTest.java
示例4: RandomForestClassificationModelConverter
import org.apache.spark.ml.classification.RandomForestClassificationModel; //導入依賴的package包/類
/**
 * Creates a converter for the given fitted Spark random forest classifier.
 *
 * @param model the trained Spark model; passed unchanged to the superclass constructor
 */
public RandomForestClassificationModelConverter(RandomForestClassificationModel model){
super(model);
}
示例5: getSource
import org.apache.spark.ml.classification.RandomForestClassificationModel; //導入依賴的package包/類
/**
 * Identifies the Spark model class this adapter works with.
 *
 * @return {@code RandomForestClassificationModel.class}
 */
@Override
public Class<RandomForestClassificationModel> getSource() {
    final Class<RandomForestClassificationModel> sourceType = RandomForestClassificationModel.class;
    return sourceType;
}
開發者ID:flipkart-incubator,項目名稱:spark-transformers,代碼行數:5,代碼來源:RandomForestClassificationModelInfoAdapter.java