本文整理汇总了Java中org.apache.spark.ml.classification.RandomForestClassificationModel类的典型用法代码示例。如果您正苦于以下问题：Java RandomForestClassificationModel类的具体用法是什么？Java RandomForestClassificationModel怎么用？Java RandomForestClassificationModel有哪些使用示例？那么，这里精选的类代码示例或许可以为您提供帮助。
RandomForestClassificationModel类属于org.apache.spark.ml.classification包，下文中一共展示了RandomForestClassificationModel类的5个代码示例，这些示例默认按受欢迎程度排序。您可以为喜欢或者觉得有用的代码点赞，您的评价将有助于系统推荐出更棒的Java代码示例。
示例1: encodeModel
import org.apache.spark.ml.classification.RandomForestClassificationModel; //导入依赖的package包/类
/**
 * Encodes the wrapped Spark random forest classifier as a PMML {@link MiningModel}.
 *
 * <p>Every tree of the ensemble is first encoded as a PMML {@link TreeModel}. The tree models
 * are then attached as segments combined with the {@code AVERAGE} multiple-model method, under
 * a classification mining function whose mining schema is derived from the schema's label.
 *
 * @param schema the schema describing the label and the features
 * @return the mining model holding the encoded tree ensemble
 */
@Override
public MiningModel encodeModel(Schema schema){
    RandomForestClassificationModel forest = getTransformer();
    List<TreeModel> segments = TreeModelUtil.encodeDecisionTreeEnsemble(forest, schema);

    MiningModel result =
        new MiningModel(MiningFunction.CLASSIFICATION, ModelUtil.createMiningSchema(schema.getLabel()));
    result.setSegmentation(
        MiningModelUtil.createSegmentation(Segmentation.MultipleModelMethod.AVERAGE, segments));
    return result;
}
示例2: getModelInfo
import org.apache.spark.ml.classification.RandomForestClassificationModel; //导入依赖的package包/类
/**
 * Builds a {@link RandomForestModelInfo} describing a trained Spark random forest classifier.
 *
 * <p>The returned info holds the class and feature counts and the per-tree weights. Each tree is
 * converted through {@code DECISION_TREE_ADAPTER}. The info also records, in insertion order, the
 * input keys (features column, then label column) and the output keys (prediction, probability
 * and raw-prediction columns).
 *
 * @param sparkRfModel the trained Spark model to describe
 * @param df the data frame forwarded unchanged to the per-tree adapter
 * @return the populated model info
 */
@Override
RandomForestModelInfo getModelInfo(final RandomForestClassificationModel sparkRfModel, final DataFrame df) {
    final RandomForestModelInfo info = new RandomForestModelInfo();
    info.setNumClasses(sparkRfModel.numClasses());
    info.setNumFeatures(sparkRfModel.numFeatures());
    // This adapter only describes classifiers, so the regression flag is always off.
    info.setRegression(false);

    final double[] rawWeights = sparkRfModel.treeWeights();
    final List<Double> weightList = new ArrayList<>();
    for (int idx = 0; idx < rawWeights.length; idx++) {
        weightList.add(rawWeights[idx]);
    }
    info.setTreeWeights(weightList);

    final List<DecisionTreeModelInfo> treeInfos = new ArrayList<>();
    for (DecisionTreeModel tree : sparkRfModel.trees()) {
        // Every tree of a classification forest is handled as a classification tree.
        final DecisionTreeClassificationModel classificationTree = (DecisionTreeClassificationModel) tree;
        treeInfos.add(DECISION_TREE_ADAPTER.getModelInfo(classificationTree, df));
    }
    info.setTrees(treeInfos);

    final Set<String> inputs = new LinkedHashSet<>();
    inputs.add(sparkRfModel.getFeaturesCol());
    inputs.add(sparkRfModel.getLabelCol());
    info.setInputKeys(inputs);

    final String predictionColumn = sparkRfModel.getPredictionCol();
    final String probabilityColumn = sparkRfModel.getProbabilityCol();
    final String rawPredictionColumn = sparkRfModel.getRawPredictionCol();
    final Set<String> outputs = new LinkedHashSet<>();
    outputs.add(predictionColumn);
    outputs.add(probabilityColumn);
    outputs.add(rawPredictionColumn);
    info.setProbabilityKey(probabilityColumn);
    info.setRawPredictionKey(rawPredictionColumn);
    info.setOutputKeys(outputs);

    return info;
}
开发者ID:flipkart-incubator,项目名称:spark-transformers,代码行数:36,代码来源:RandomForestClassificationModelInfoAdapter.java
示例3: testRandomForestClassification
import org.apache.spark.ml.classification.RandomForestClassificationModel; //导入依赖的package包/类
/**
 * Round-trip test for the random forest classification adapter.
 *
 * <p>The test trains a Spark {@code RandomForestClassifier} on a LIBSVM fixture and exports the
 * model with {@code ModelExporter}. It re-imports the bytes with {@code ModelImporter}, then
 * checks that the imported transformer reproduces Spark's prediction, probability and
 * raw-prediction outputs on every held-out row.
 */
@Test
public void testRandomForestClassification() {
    // Load the data stored in LIBSVM format as a DataFrame.
    DataFrame data = sqlContext.read().format("libsvm").load("src/test/resources/classification_test.libsvm");
    StringIndexerModel stringIndexerModel = new StringIndexer()
        .setInputCol("label")
        .setOutputCol("labelIndex")
        .fit(data);
    data = stringIndexerModel.transform(data);
    // Split the data into training and test sets (30% held out for testing).
    // The fixed seed makes the split, and therefore the trained forest and the compared rows,
    // reproducible across runs. Without it a failure could not be replayed.
    final long splitSeed = 42L;
    DataFrame[] splits = data.randomSplit(new double[]{0.7, 0.3}, splitSeed);
    DataFrame trainingData = splits[0];
    DataFrame testData = splits[1];
    // Train a RandomForest model.
    RandomForestClassificationModel classificationModel = new RandomForestClassifier()
        .setLabelCol("labelIndex")
        .setFeaturesCol("features")
        .setPredictionCol("prediction")
        .setRawPredictionCol("rawPrediction")
        .setProbabilityCol("probability")
        .fit(trainingData);
    byte[] exportedModel = ModelExporter.export(classificationModel, null);
    Transformer transformer = ModelImporter.importAndGetTransformer(exportedModel);
    Row[] sparkOutput = classificationModel.transform(testData).select("features", "prediction", "rawPrediction", "probability").collect();
    // Compare Spark's outputs with the imported transformer's outputs, row by row.
    for (Row row : sparkOutput) {
        Vector v = (Vector) row.get(0);
        double actual = row.getDouble(1);
        double[] actualProbability = ((Vector) row.get(3)).toArray();
        double[] actualRaw = ((Vector) row.get(2)).toArray();
        Map<String, Object> inputData = new HashMap<String, Object>();
        // The first input key is used as the feature-vector key.
        inputData.put(transformer.getInputKeys().iterator().next(), v.toArray());
        transformer.transform(inputData);
        double predicted = (double) inputData.get("prediction");
        double[] probability = (double[]) inputData.get("probability");
        double[] rawPrediction = (double[]) inputData.get("rawPrediction");
        assertEquals(actual, predicted, EPSILON);
        assertArrayEquals(actualProbability, probability, EPSILON);
        assertArrayEquals(actualRaw, rawPrediction, EPSILON);
    }
}
开发者ID:flipkart-incubator,项目名称:spark-transformers,代码行数:56,代码来源:RandomForestClassificationModelInfoAdapterBridgeTest.java
示例4: RandomForestClassificationModelConverter
import org.apache.spark.ml.classification.RandomForestClassificationModel; //导入依赖的package包/类
/**
 * Creates a converter for the given trained Spark random forest classification model.
 *
 * @param model the Spark model to convert; passed unchanged to the superclass constructor
 */
public RandomForestClassificationModelConverter(RandomForestClassificationModel model){
super(model);
}
示例5: getSource
import org.apache.spark.ml.classification.RandomForestClassificationModel; //导入依赖的package包/类
/**
 * Identifies the Spark model class that this adapter handles as its source type.
 *
 * @return {@code RandomForestClassificationModel.class}
 */
@Override
public Class<RandomForestClassificationModel> getSource() {
return RandomForestClassificationModel.class;
}
开发者ID:flipkart-incubator,项目名称:spark-transformers,代码行数:5,代码来源:RandomForestClassificationModelInfoAdapter.java