当前位置: 首页>>代码示例>>Java>>正文


Java DecisionTreeModel类代码示例

本文整理汇总了Java中org.apache.spark.mllib.tree.model.DecisionTreeModel类的典型用法代码示例。如果您想了解Java DecisionTreeModel类的具体用法、使用方式或实际示例，这里精选的代码示例或许能为您提供帮助。


DecisionTreeModel类属于org.apache.spark.mllib.tree.model包，下文共展示了该类的15个代码示例，默认按受欢迎程度排序。您可以为喜欢或觉得有用的代码点赞，您的评价将有助于系统推荐更优质的Java代码示例。

示例1: OnlineFeatureHandler

import org.apache.spark.mllib.tree.model.DecisionTreeModel; //导入依赖的package包/类
/**
 * Creates a handler that connects an already-trained detection model to the online feature
 * stream.
 *
 * <p>The concrete MLlib model inside {@code detectionModel} is unwrapped into the typed field
 * that matches the runtime class of the detection-model wrapper. For an unrecognized wrapper
 * class a message is printed to stdout and no typed model field is assigned here. The handler is
 * then registered with a new event delivery manager either way.
 *
 * @param featureConstraint constraint used when registering the online feature query
 * @param detectionModel wrapper holding the trained model and its feature configuration
 * @param onlineMLEventListener listener stored on this handler (presumably notified of online
 *        detection results — confirm)
 * @param controllerConnector connector handed to the event delivery manager
 */
public OnlineFeatureHandler(FeatureConstraint featureConstraint,
                            DetectionModel detectionModel,
                            onlineMLEventListener onlineMLEventListener,
                            ControllerConnector controllerConnector) {
    this.featureConstraint = featureConstraint;
    this.detectionModel = detectionModel;
    setAthenaMLFeatureConfiguration(detectionModel.getAthenaMLFeatureConfiguration());

    // Dispatch on the wrapper type: each branch casts the wrapped model to its MLlib class.
    if (detectionModel instanceof KMeansDetectionModel) {
        this.kMeansModel = (KMeansModel) detectionModel.getDetectionModel();
    } else if (detectionModel instanceof GaussianMixtureDetectionModel) {
        this.gaussianMixtureModel = (GaussianMixtureModel) detectionModel.getDetectionModel();
    } else if (detectionModel instanceof DecisionTreeDetectionModel) {
        this.decisionTreeModel = (DecisionTreeModel) detectionModel.getDetectionModel();
    } else if (detectionModel instanceof NaiveBayesDetectionModel) {
        this.naiveBayesModel = (NaiveBayesModel) detectionModel.getDetectionModel();
    } else if (detectionModel instanceof RandomForestDetectionModel) {
        this.randomForestModel = (RandomForestModel) detectionModel.getDetectionModel();
    } else if (detectionModel instanceof GradientBoostedTreesDetectionModel) {
        this.gradientBoostedTreesModel = (GradientBoostedTreesModel) detectionModel.getDetectionModel();
    } else if (detectionModel instanceof SVMDetectionModel) {
        this.svmModel = (SVMModel) detectionModel.getDetectionModel();
    } else if (detectionModel instanceof LogisticRegressionDetectionModel) {
        this.logisticRegressionModel = (LogisticRegressionModel) detectionModel.getDetectionModel();
    } else if (detectionModel instanceof LinearRegressionDetectionModel) {
        this.linearRegressionModel = (LinearRegressionModel) detectionModel.getDetectionModel();
    } else if (detectionModel instanceof LassoDetectionModel) {
        this.lassoModel = (LassoModel) detectionModel.getDetectionModel();
    } else if (detectionModel instanceof RidgeRegressionDetectionModel) {
        this.ridgeRegressionModel = (RidgeRegressionModel) detectionModel.getDetectionModel();
    } else {
        //not supported ML model
        // No typed model field is assigned; construction continues and the handler is still registered below.
        System.out.println("Not supported model");
    }

    // Register this handler's internal listener for the online feature query.
    this.eventDeliveryManager = new EventDeliveryManagerImpl(controllerConnector, new InternalAthenaFeatureEventListener());
    this.eventDeliveryManager.registerOnlineAthenaFeature(null, new QueryIdentifier(QUERY_IDENTIFIER), featureConstraint);
    // NOTE(review): this field is assigned only after registration; if events can be delivered
    // during registerOnlineAthenaFeature, they would observe a null listener — confirm.
    this.onlineMLEventListener = onlineMLEventListener;
    System.out.println("Install handler!");
}
 
开发者ID:shlee89,项目名称:athena,代码行数:41,代码来源:OnlineFeatureHandler.java

示例2: predictorExampleCounts

import org.apache.spark.mllib.tree.model.DecisionTreeModel; //导入依赖的package包/类
/**
 * Tallies, for every predictor, how many training examples passed through an internal node
 * that splits on it, summed over all trees of the forest.
 *
 * @param trainPointData data to run down trees
 * @param model random decision forest model to count on
 * @return map of predictor index to the number of training examples that reached a
 *  node whose decision is based on that feature. The index is among predictors, not all
 *  features, since there are fewer predictors than features. That is, the index will
 *  match the one used in the {@link RandomForestModel}.
 */
private static Map<Integer,Long> predictorExampleCounts(JavaRDD<LabeledPoint> trainPointData,
                                                        RandomForestModel model) {
  return trainPointData.mapPartitions(partition -> {
      IntLongMap countsByFeature = HashIntLongMaps.newMutableMap();
      while (partition.hasNext()) {
        double[] features = partition.next().features().toArray();
        for (DecisionTreeModel tree : model.trees()) {
          // Descend root-to-leaf exactly as Node.predict does, counting each split feature seen.
          org.apache.spark.mllib.tree.model.Node current = tree.topNode();
          while (!current.isLeaf()) {
            Split split = current.split().get();
            int splitFeature = split.feature();
            countsByFeature.addValue(splitFeature, 1);
            current = nextNode(features, current, split, splitFeature);
          }
        }
      }
      // Copy into a plain HashMap to avoid problems with Kryo serializing Koloboke maps.
      Map<Integer,Long> snapshot = new HashMap<>(countsByFeature);
      return Collections.<Map<Integer,Long>>singletonList(snapshot).iterator();
  }).reduce(RDFUpdate::merge);
}
 
开发者ID:oncewang,项目名称:oryx2,代码行数:32,代码来源:RDFUpdate.java

示例3: generateKMeansModel

import org.apache.spark.mllib.tree.model.DecisionTreeModel; //导入依赖的package包/类
/**
 * Trains an MLlib decision-tree classifier on {@code parsedData} with the hyper-parameters held
 * by {@code decisionTreeDetectionAlgorithm}, then records that algorithm on the summary.
 *
 * <p>NOTE: despite the method name, no k-means model is involved; the name is kept so existing
 * callers continue to compile.
 *
 * @param parsedData labeled training points
 * @param decisionTreeDetectionAlgorithm source of the number of classes, categorical feature
 *        info, impurity, maximum depth and maximum bins
 * @param decisionTreeModelSummary summary that receives the algorithm once training returns
 * @return the trained decision tree
 */
public DecisionTreeModel generateKMeansModel(JavaRDD<LabeledPoint> parsedData,
                                             DecisionTreeDetectionAlgorithm decisionTreeDetectionAlgorithm,
                                             DecisionTreeModelSummary decisionTreeModelSummary) {
    DecisionTreeDetectionAlgorithm algo = decisionTreeDetectionAlgorithm;
    DecisionTreeModel trained = DecisionTree.trainClassifier(
            parsedData,
            algo.getNumClasses(),
            algo.getCategoricalFeaturesInfo(),
            algo.getImpurity(),
            algo.getMaxDepth(),
            algo.getMaxBins());
    decisionTreeModelSummary.setDecisionTreeDetectionAlgorithm(algo);
    return trained;
}
 
开发者ID:shlee89,项目名称:athena,代码行数:14,代码来源:DecisionTreeDistJob.java

示例4: generateDecisionTreeWithPreprocessing

import org.apache.spark.mllib.tree.model.DecisionTreeModel; //导入依赖的package包/类
/**
 * Converts the raw Mongo documents into labeled points with {@code rddPreProcessing} and trains
 * a decision tree on the result via {@code generateKMeansModel}.
 *
 * @param mongoRDD raw documents read from MongoDB
 * @param athenaMLFeatureConfiguration feature selection and transformation settings
 * @param decisionTreeDetectionAlgorithm decision-tree hyper-parameters
 * @param marking labelling rules handed to preprocessing
 * @param decisionTreeModelSummary summary shared by preprocessing and training
 * @return the trained decision tree
 */
public DecisionTreeModel generateDecisionTreeWithPreprocessing(JavaPairRDD<Object, BSONObject> mongoRDD,
                                                               AthenaMLFeatureConfiguration athenaMLFeatureConfiguration,
                                                               DecisionTreeDetectionAlgorithm decisionTreeDetectionAlgorithm,
                                                               Marking marking,
                                                               DecisionTreeModelSummary decisionTreeModelSummary) {
    JavaRDD<LabeledPoint> labeledPoints =
            rddPreProcessing(mongoRDD, athenaMLFeatureConfiguration, decisionTreeModelSummary, marking);
    return generateKMeansModel(labeledPoints, decisionTreeDetectionAlgorithm, decisionTreeModelSummary);
}
 
开发者ID:shlee89,项目名称:athena,代码行数:13,代码来源:DecisionTreeDistJob.java

示例5: treeNodeExampleCounts

import org.apache.spark.mllib.tree.model.DecisionTreeModel; //导入依赖的package包/类
/**
 * Counts, separately for every tree of the forest, how many training examples visited each
 * node on their root-to-leaf path (internal nodes and the leaf that was reached).
 *
 * @param trainPointData data to run down trees
 * @param model random decision forest model to count on
 * @return maps of node IDs to the count of training examples that reached that node, one
 *  per tree in the model
 * @see #predictorExampleCounts(JavaRDD,RandomForestModel)
 */
private static List<Map<Integer,Long>> treeNodeExampleCounts(JavaRDD<LabeledPoint> trainPointData,
                                                             RandomForestModel model) {
  return trainPointData.mapPartitions(partition -> {
      DecisionTreeModel[] forest = model.trees();
      List<IntLongMap> countsPerTree = new ArrayList<>(forest.length);
      for (int t = 0; t < forest.length; t++) {
        countsPerTree.add(HashIntLongMaps.newMutableMap());
      }
      while (partition.hasNext()) {
        double[] features = partition.next().features().toArray();
        for (int t = 0; t < forest.length; t++) {
          IntLongMap visits = countsPerTree.get(t);
          // Same descent as Node.predict, counting every node on the path, including the leaf.
          org.apache.spark.mllib.tree.model.Node current = forest[t].topNode();
          while (!current.isLeaf()) {
            visits.addValue(current.id(), 1);
            Split split = current.split().get();
            current = nextNode(features, current, split, split.feature());
          }
          visits.addValue(current.id(), 1);
        }
      }
      // Copy into plain HashMaps before shipping the partition result back.
      List<Map<Integer,Long>> snapshots = new ArrayList<>(countsPerTree.size());
      for (IntLongMap visits : countsPerTree) {
        snapshots.add(new HashMap<>(visits));
      }
      return Collections.<List<Map<Integer,Long>>>singletonList(snapshots).iterator();
    }
  ).reduce((left, right) -> {
      Preconditions.checkArgument(left.size() == right.size());
      // merge() folds right's counts into left's per-tree maps, so left is the combined result.
      for (int t = 0; t < left.size(); t++) {
        merge(left.get(t), right.get(t));
      }
      return left;
    });
}
 
开发者ID:oncewang,项目名称:oryx2,代码行数:42,代码来源:RDFUpdate.java

示例6: main

import org.apache.spark.mllib.tree.model.DecisionTreeModel; //导入依赖的package包/类
/**
 * Trains a Spark MLlib decision-tree classifier on HOG features extracted from serialized IDX
 * datasets, evaluates it on a separate test set, and writes the results under {@code results/}.
 *
 * <p>Inputs: {@code data/idx.ser} (training) and {@code data/idx-test.ser} (testing). Outputs:
 * the test error, the model's debug string, a confusion matrix and the training time in
 * milliseconds.
 *
 * @param args ignored
 */
public static void main(String[] args) {
    Logger.getLogger("org").setLevel(Level.WARN);

    SparkConf sparkConf = new SparkConf()
            .setAppName("ExampleSpark")
            .setMaster("local");

    // JavaSparkContext is Closeable: try-with-resources stops the local Spark context even when
    // loading, training or evaluation throws, instead of leaving it running.
    try (JavaSparkContext jsc = new JavaSparkContext(sparkConf)) {
        IdxManager idx = IOUtils.deserialize("data/idx.ser");
        IdxManager idxTest = IOUtils.deserialize("data/idx-test.ser");
        double[][] inputs = idx.getData();
        double[] outputs = idx.getLabelsVec();
        double[][] inputsTest = idxTest.getData();
        double[] outputsTest = idxTest.getLabelsVec();
        inputs = HogManager.exportDataFeatures(inputs, idx.getNumOfRows(),
                idx.getNumOfCols());
        // Extract test features with the test set's own dimensions (the training set's were
        // previously reused, which is only correct when both sets happen to share them).
        inputsTest = HogManager.exportDataFeatures(inputsTest, idxTest.getNumOfRows(),
                idxTest.getNumOfCols());

        List<LabeledPoint> pointList = new ArrayList<>(outputs.length);
        for (int i = 0; i < outputs.length; i++) {
            pointList.add(new LabeledPoint(outputs[i], Vectors.dense(inputs[i])));
        }

        List<LabeledPoint> pointListTest = new ArrayList<>(outputsTest.length);
        for (int i = 0; i < outputsTest.length; i++) {
            pointListTest.add(new LabeledPoint(outputsTest[i],
                    Vectors.dense(inputsTest[i])));
        }

        JavaRDD<LabeledPoint> trainingData = jsc.parallelize(pointList);
        JavaRDD<LabeledPoint> testData = jsc.parallelize(pointListTest);

        // Hyper-parameters. An empty categoricalFeaturesInfo marks every feature as continuous.
        int numClasses = 10;
        Map<Integer, Integer> categoricalFeaturesInfo = new HashMap<>();
        String impurity = "gini";
        int maxDepth = 10;
        int maxBins = 256;

        // Train a DecisionTree model for classification, timing only the training call.
        long startTime = System.currentTimeMillis();
        final DecisionTreeModel model = DecisionTree.trainClassifier(trainingData,
                numClasses, categoricalFeaturesInfo, impurity, maxDepth, maxBins);
        long learnTime = System.currentTimeMillis() - startTime;

        // Test error = fraction of test points whose predicted class differs from the label.
        JavaPairRDD<Double, Double> predictionAndLabel =
                testData.mapToPair(
                        p -> new Tuple2<>(model.predict(p.features()), p.label()));
        double testErr = (double) predictionAndLabel.filter(
                pl -> !pl._1().equals(pl._2())).count() / testData.count();

        // results
        new File("results").mkdir();
        IOUtils.writeStr("results/dtree_error.data", Double.toString(testErr));
        IOUtils.writeStr("results/dtree_model.data", model.toDebugString());

        double[][] outFinal = new double[outputsTest.length][];
        for (int i = 0; i < outputsTest.length; i++) {
            outFinal[i] = valToVec(model.predict(Vectors.dense(inputsTest[i])));
        }

        ConfusionMatrix cm = new ConfusionMatrix(outFinal, idxTest.getLabels());
        cm.writeClassErrorMatrix("results/confusion_matrix.data");
        IOUtils.writeStr("results/learn_time_ms.data", Long.toString(learnTime));
    }
}
 
开发者ID:lukago,项目名称:neural-algorithms,代码行数:82,代码来源:ExampleSpark.java

示例7: setDecisionTreeModel

import org.apache.spark.mllib.tree.model.DecisionTreeModel; //导入依赖的package包/类
/**
 * Stores the trained MLlib decision tree held by this detection model.
 *
 * @param decisionTreeModel trained tree to keep; it replaces any previously stored tree
 */
public void setDecisionTreeModel(DecisionTreeModel decisionTreeModel) {
    this.decisionTreeModel = decisionTreeModel;
}
 
开发者ID:shlee89,项目名称:athena,代码行数:4,代码来源:DecisionTreeDetectionModel.java

示例8: validate

import org.apache.spark.mllib.tree.model.DecisionTreeModel; //导入依赖的package包/类
/**
 * Runs the trained decision tree over every document in {@code mongoRDD} and records each
 * prediction in {@code decisionTreeValidationSummary}.
 *
 * <p>For each document, the configured target features are read from its
 * {@code AthenaFeatureField.FEATURE} sub-document and converted to doubles. Any numeric value
 * is accepted, and booleans map to 1/0. Each value is then optionally weighted, made absolute
 * and normalized according to {@code athenaMLFeatureConfiguration`. Target features missing
 * from the document contribute {@code 0}. A document whose target feature holds any other kind
 * of value is skipped entirely.
 *
 * <p>NOTE(review): the summary is mutated inside {@code foreach}, which Spark executes on
 * executors; unless the summary is backed by accumulators, those updates are not visible on
 * the driver outside local mode — confirm.
 *
 * @param mongoRDD documents to validate
 * @param athenaMLFeatureConfiguration target features, weights, absolute/normalization flags
 * @param decisionTreeDetectionModel wrapper holding the trained tree and the marking rules
 * @param decisionTreeValidationSummary receives per-document results and timings
 */
public void validate(JavaPairRDD<Object, BSONObject> mongoRDD,
                     AthenaMLFeatureConfiguration athenaMLFeatureConfiguration,
                     DecisionTreeDetectionModel decisionTreeDetectionModel,
                     DecisionTreeValidationSummary decisionTreeValidationSummary) {
    List<AthenaFeatureField> listOfTargetFeatures = athenaMLFeatureConfiguration.getListOfTargetFeatures();
    Map<AthenaFeatureField, Integer> weight = athenaMLFeatureConfiguration.getWeight();
    Marking marking = decisionTreeDetectionModel.getMarking();
    DecisionTreeModel model = (DecisionTreeModel) decisionTreeDetectionModel.getDetectionModel();
    Normalizer normalizer = new Normalizer();

    int numberOfTargetValue = listOfTargetFeatures.size();

    mongoRDD.foreach(new VoidFunction<Tuple2<Object, BSONObject>>() {
        @Override
        public void call(Tuple2<Object, BSONObject> t) throws UnknownHostException {
            long start2 = System.nanoTime(); // <-- start
            BSONObject idx = t._2();
            BSONObject feature = (BSONObject) idx.get(AthenaFeatureField.FEATURE);
            int originLabel = marking.checkClassificationMarkingElements(idx, feature);

            // Java zero-initializes the array, so absent features stay 0.
            double[] values = new double[numberOfTargetValue];
            for (int j = 0; j < numberOfTargetValue; j++) {
                AthenaFeatureField targetField = listOfTargetFeatures.get(j);
                if (!feature.containsField(targetField.getValue())) {
                    continue;
                }
                Object obj = feature.get(targetField.getValue());
                if (obj instanceof Number) {
                    // Covers Long and Double as before, plus Integer and other numeric types
                    // that previously caused the whole document to be dropped.
                    values[j] = ((Number) obj).doubleValue();
                } else if (obj instanceof Boolean) {
                    values[j] = (Boolean) obj ? 1 : 0;
                } else {
                    // Unsupported value type: skip this document entirely.
                    return;
                }

                //check weight
                if (weight.containsKey(targetField)) {
                    values[j] *= weight.get(targetField);
                }

                //check absolute
                if (athenaMLFeatureConfiguration.isAbsolute()) {
                    values[j] = Math.abs(values[j]);
                }
            }

            Vector normedForVal;
            if (athenaMLFeatureConfiguration.isNormalization()) {
                normedForVal = normalizer.transform(Vectors.dense(values));
            } else {
                normedForVal = Vectors.dense(values);
            }

            LabeledPoint p = new LabeledPoint(originLabel, normedForVal);

            int validatedLabel = (int) model.predict(p.features());

            decisionTreeValidationSummary.updateSummary(validatedLabel, idx, feature);

            long end2 = System.nanoTime();
            long result2 = end2 - start2;
            decisionTreeValidationSummary.addTotalNanoSeconds(result2);
        }
    });
    // NOTE(review): the return value is discarded; kept in case the call has side effects.
    decisionTreeValidationSummary.getAverageNanoSeconds();
    decisionTreeValidationSummary.setDecisionTreeDetectionAlgorithm((DecisionTreeDetectionAlgorithm) decisionTreeDetectionModel.getDetectionAlgorithm());
}
 
开发者ID:shlee89,项目名称:athena,代码行数:69,代码来源:DecisionTreeDistJob.java

示例9: generateDecisionTreeAthenaDetectionModel

import org.apache.spark.mllib.tree.model.DecisionTreeModel; //导入依赖的package包/类
/**
 * Trains a decision-tree detection model on the MongoDB input described by
 * {@code mongodbConfig} and packages it with its configuration and training summary.
 *
 * <p>The recorded total learning time runs from just after the summary is created until the
 * trained tree has been attached. It therefore includes the preprocessing and training done by
 * {@code DecisionTreeDistJob}.
 *
 * @param sc Spark context used to read the MongoDB input
 * @param featureConstraint constraint stored on the resulting model
 * @param athenaMLFeatureConfiguration feature configuration used for preprocessing and stored
 * @param detectionAlgorithm must be a {@code DecisionTreeDetectionAlgorithm}
 * @param indexing indexing settings stored on the model and summary
 * @param marking labelling rules used for preprocessing and stored on the model
 * @return the populated detection model, including its summary
 */
public DecisionTreeDetectionModel generateDecisionTreeAthenaDetectionModel(JavaSparkContext sc,
                                                                           FeatureConstraint featureConstraint,
                                                                           AthenaMLFeatureConfiguration athenaMLFeatureConfiguration,
                                                                           DetectionAlgorithm detectionAlgorithm,
                                                                           Indexing indexing,
                                                                           Marking marking) {
    DecisionTreeModelSummary summary = new DecisionTreeModelSummary(sc.sc(), indexing, marking);

    long startNanos = System.nanoTime();

    DecisionTreeDetectionAlgorithm algorithm = (DecisionTreeDetectionAlgorithm) detectionAlgorithm;

    DecisionTreeDetectionModel detectionModel = new DecisionTreeDetectionModel();
    detectionModel.setDecisionTreeDetectionAlgorithm(algorithm);
    summary.setDecisionTreeDetectionAlgorithm(algorithm);
    detectionModel.setFeatureConstraint(featureConstraint);
    detectionModel.setAthenaMLFeatureConfiguration(athenaMLFeatureConfiguration);
    detectionModel.setIndexing(indexing);
    detectionModel.setMarking(marking);

    // Key/value pairs read through MongoInputFormat using mongodbConfig.
    JavaPairRDD<Object, BSONObject> mongoRDD = sc.newAPIHadoopRDD(
            mongodbConfig, MongoInputFormat.class, Object.class, BSONObject.class);

    DecisionTreeDistJob distJob = new DecisionTreeDistJob();
    DecisionTreeModel trainedTree = distJob.generateDecisionTreeWithPreprocessing(
            mongoRDD, athenaMLFeatureConfiguration, algorithm, marking, summary);

    detectionModel.setDecisionTreeModel(trainedTree);
    long endNanos = System.nanoTime();
    summary.setTotalLearningTime(endNanos - startNanos);
    detectionModel.setClassificationModelSummary(summary);

    return detectionModel;
}
 
开发者ID:shlee89,项目名称:athena,代码行数:45,代码来源:MachineLearningManagerImpl.java

示例10: rdfModelToPMML

import org.apache.spark.mllib.tree.model.DecisionTreeModel; //导入依赖的package包/类
/**
 * Converts a trained MLlib random forest into a PMML document.
 *
 * <p>A single-tree forest becomes a bare tree model. A larger forest becomes a
 * {@code MiningModel} with one weight-1.0 segment per tree. The segments are combined by
 * weighted majority vote for classification or by weighted average for regression. Importances
 * derived from {@code predictorIndexCounts} go into the mining schema, and the training
 * hyper-parameters are recorded as PMML extensions.
 *
 * @param rfModel trained forest
 * @param categoricalValueEncodings encodings for categorical features
 * @param maxDepth max tree depth used in training, recorded as an extension
 * @param maxSplitCandidates max split candidates used in training, recorded as an extension
 * @param impurity impurity measure used in training, recorded as an extension
 * @param nodeIDCounts per-tree node visit counts, indexed like {@code rfModel.trees()}
 * @param predictorIndexCounts per-predictor example counts used to derive importances
 * @return the assembled PMML
 * @throws IllegalStateException if the forest's task type disagrees with the input schema
 */
private PMML rdfModelToPMML(RandomForestModel rfModel,
                            CategoricalValueEncodings categoricalValueEncodings,
                            int maxDepth,
                            int maxSplitCandidates,
                            String impurity,
                            List<Map<Integer,Long>> nodeIDCounts,
                            Map<Integer,Long> predictorIndexCounts) {

  boolean classificationTask = rfModel.algo().equals(Algo.Classification());
  Preconditions.checkState(classificationTask == inputSchema.isClassification());

  DecisionTreeModel[] forest = rfModel.trees();

  Model model;
  if (forest.length == 1) {
    model = toTreeModel(forest[0], categoricalValueEncodings, nodeIDCounts.get(0));
  } else {
    List<Segment> segments = new ArrayList<>(forest.length);
    for (int treeID = 0; treeID < forest.length; treeID++) {
      TreeModel treeModel =
          toTreeModel(forest[treeID], categoricalValueEncodings, nodeIDCounts.get(treeID));
      Segment segment = new Segment()
          .setId(Integer.toString(treeID))
          .setPredicate(new True())
          .setModel(treeModel)
          .setWeight(1.0); // No weights in MLlib impl now
      segments.add(segment);
    }

    Segmentation.MultipleModelMethod combiner;
    if (classificationTask) {
      combiner = Segmentation.MultipleModelMethod.WEIGHTED_MAJORITY_VOTE;
    } else {
      combiner = Segmentation.MultipleModelMethod.WEIGHTED_AVERAGE;
    }
    MiningModel miningModel = new MiningModel();
    miningModel.setSegmentation(new Segmentation(combiner, segments));
    model = miningModel;
  }

  MiningFunction miningFunction;
  if (classificationTask) {
    miningFunction = MiningFunction.CLASSIFICATION;
  } else {
    miningFunction = MiningFunction.REGRESSION;
  }
  model.setMiningFunction(miningFunction);

  double[] importances = countsToImportances(predictorIndexCounts);
  model.setMiningSchema(AppPMMLUtils.buildMiningSchema(inputSchema, importances));
  DataDictionary dictionary =
      AppPMMLUtils.buildDataDictionary(inputSchema, categoricalValueEncodings);

  PMML pmml = PMMLUtils.buildSkeletonPMML();
  pmml.setDataDictionary(dictionary);
  pmml.addModels(model);

  AppPMMLUtils.addExtension(pmml, "maxDepth", maxDepth);
  AppPMMLUtils.addExtension(pmml, "maxSplitCandidates", maxSplitCandidates);
  AppPMMLUtils.addExtension(pmml, "impurity", impurity);

  return pmml;
}
 
开发者ID:oncewang,项目名称:oryx2,代码行数:55,代码来源:RDFUpdate.java

示例11: MLDecisionTreeModel

import org.apache.spark.mllib.tree.model.DecisionTreeModel; //导入依赖的package包/类
/**
 * Wraps an MLlib decision tree.
 *
 * @param model trained decision tree; the reference is stored as-is, not copied
 */
public MLDecisionTreeModel(DecisionTreeModel model) {
    this.model = model;
}
 
开发者ID:wso2-attic,项目名称:carbon-ml,代码行数:4,代码来源:MLDecisionTreeModel.java

示例12: readExternal

import org.apache.spark.mllib.tree.model.DecisionTreeModel; //导入依赖的package包/类
/**
 * Restores the wrapped decision tree from a serialized stream.
 *
 * <p>SECURITY NOTE(review): this uses native Java deserialization and must only be given
 * trusted model files.
 *
 * @param in stream positioned at the serialized model (which may be {@code null})
 * @throws IOException if reading fails, or (as {@link java.io.InvalidObjectException}) if the
 *         stream holds an object that is not a {@code DecisionTreeModel}
 * @throws ClassNotFoundException if the serialized class cannot be resolved
 */
@Override
public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {
    Object restored = in.readObject();
    // Report a wrong payload as a checked I/O failure instead of an unchecked
    // ClassCastException. A null model is still accepted, as before.
    if (restored != null && !(restored instanceof DecisionTreeModel)) {
        throw new java.io.InvalidObjectException(
                "Expected a DecisionTreeModel but found " + restored.getClass().getName());
    }
    model = (DecisionTreeModel) restored;
}
 
开发者ID:wso2-attic,项目名称:carbon-ml,代码行数:6,代码来源:MLDecisionTreeModel.java

示例13: getModel

import org.apache.spark.mllib.tree.model.DecisionTreeModel; //导入依赖的package包/类
/**
 * Returns the wrapped MLlib decision tree.
 *
 * @return the stored model instance itself, not a copy
 */
public DecisionTreeModel getModel() {
    return model;
}
 
开发者ID:wso2-attic,项目名称:carbon-ml,代码行数:4,代码来源:MLDecisionTreeModel.java

示例14: setModel

import org.apache.spark.mllib.tree.model.DecisionTreeModel; //导入依赖的package包/类
/**
 * Replaces the wrapped MLlib decision tree.
 *
 * @param model decision tree to store; the reference is kept as-is
 */
public void setModel(DecisionTreeModel model) {
    this.model = model;
}
 
开发者ID:wso2-attic,项目名称:carbon-ml,代码行数:4,代码来源:MLDecisionTreeModel.java

示例15: buildDecisionTreeModel

import org.apache.spark.mllib.tree.model.DecisionTreeModel; //导入依赖的package包/类
/**
 * Builds a decision tree model, evaluates it on the testing data, and stores it in
 * {@code mlModel}.
 *
 * <p>The training RDD is unpersisted after training. The testing RDD and the
 * prediction/label RDD are cached only while they are needed. Any RDD cached by this method is
 * released even when the build fails part-way through.
 *
 * @param sparkContext JavaSparkContext initialized with the application
 * @param modelID Model ID (not used by this method)
 * @param trainingData Training data as a JavaRDD of LabeledPoints
 * @param testingData Testing data as a JavaRDD of LabeledPoints
 * @param workflow Machine learning workflow supplying hyper-parameters and dataset version
 * @param mlModel Deployable machine learning model; receives the trained tree
 * @param includedFeatures feature names recorded in the summary, in key order
 * @param categoricalFeatureInfo categorical feature description passed to the trainer
 * @return classification summary with confusion matrix, accuracy and dataset version
 * @throws MLModelBuilderException if any training or evaluation step fails; the original
 *         exception is attached as the cause
 */
private ModelSummary buildDecisionTreeModel(JavaSparkContext sparkContext, long modelID,
                                            JavaRDD<LabeledPoint> trainingData, JavaRDD<LabeledPoint> testingData, Workflow workflow, MLModel mlModel,
                                            SortedMap<Integer, String> includedFeatures, Map<Integer, Integer> categoricalFeatureInfo)
        throws MLModelBuilderException {
    // Tracked outside the try so the finally block can release exactly what this method cached.
    boolean testingDataCached = false;
    JavaPairRDD<Double, Double> predictionsAndLabels = null;
    try {
        Map<String, String> hyperParameters = workflow.getHyperParameters();
        DecisionTree decisionTree = new DecisionTree();
        DecisionTreeModel decisionTreeModel = decisionTree.train(trainingData, getNoOfClasses(mlModel),
                categoricalFeatureInfo, hyperParameters.get(MLConstants.IMPURITY),
                Integer.parseInt(hyperParameters.get(MLConstants.MAX_DEPTH)),
                Integer.parseInt(hyperParameters.get(MLConstants.MAX_BINS)));

        // remove from cache
        trainingData.unpersist();
        // add test data to cache
        testingData.cache();
        testingDataCached = true;

        predictionsAndLabels = decisionTree.test(decisionTreeModel, testingData)
                .cache();
        ClassClassificationAndRegressionModelSummary classClassificationAndRegressionModelSummary = SparkModelUtils
                .getClassClassificationModelSummary(sparkContext, testingData, predictionsAndLabels);

        // remove from cache
        testingData.unpersist();
        testingDataCached = false;

        mlModel.setModel(new MLDecisionTreeModel(decisionTreeModel));

        classClassificationAndRegressionModelSummary.setFeatures(includedFeatures.values().toArray(new String[0]));
        classClassificationAndRegressionModelSummary.setAlgorithm(SUPERVISED_ALGORITHM.DECISION_TREE.toString());

        MulticlassMetrics multiclassMetrics = getMulticlassMetrics(sparkContext, predictionsAndLabels);

        predictionsAndLabels.unpersist();
        predictionsAndLabels = null;

        classClassificationAndRegressionModelSummary.setMulticlassConfusionMatrix(getMulticlassConfusionMatrix(
                multiclassMetrics, mlModel));
        Double modelAccuracy = getModelAccuracy(multiclassMetrics);
        classClassificationAndRegressionModelSummary.setModelAccuracy(modelAccuracy);
        classClassificationAndRegressionModelSummary.setDatasetVersion(workflow.getDatasetVersion());

        return classClassificationAndRegressionModelSummary;
    } catch (Exception e) {
        throw new MLModelBuilderException(
                "An error occurred while building decision tree model: " + e.getMessage(), e);
    } finally {
        // On failure, do not leave RDDs cached by this method pinned in executor memory.
        if (testingDataCached) {
            testingData.unpersist();
        }
        if (predictionsAndLabels != null) {
            predictionsAndLabels.unpersist();
        }
    }
}
 
开发者ID:wso2-attic,项目名称:carbon-ml,代码行数:59,代码来源:SupervisedSparkModelBuilder.java


注:本文中的org.apache.spark.mllib.tree.model.DecisionTreeModel类示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。