

Java RandomForestModel Class Code Examples

This article collects and organizes typical usage examples of the Java class org.apache.spark.mllib.tree.model.RandomForestModel. If you have been wondering what exactly the RandomForestModel class does, how to use it, or what real-world code that uses it looks like, the curated examples below should help.


The RandomForestModel class belongs to the org.apache.spark.mllib.tree.model package. Fifteen code examples of the class are presented below, sorted by popularity by default. You can upvote the examples you like or find useful; your feedback helps the system recommend better Java code examples.
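
Before the numbered examples, here is a minimal, self-contained sketch of the typical RandomForestModel lifecycle: train a classifier with RandomForest.trainClassifier, predict on a feature vector, then save and reload the model. The data path, app name, and hyperparameter values are illustrative placeholders, not values taken from any example below.

import java.util.HashMap;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.tree.RandomForest;
import org.apache.spark.mllib.tree.model.RandomForestModel;
import org.apache.spark.mllib.util.MLUtils;

public class RandomForestQuickStart {
    public static void main(String[] args) {
        JavaSparkContext sc = new JavaSparkContext(
                new SparkConf().setMaster("local[2]").setAppName("RandomForestQuickStart"));

        // Load LibSVM-formatted data; the path is a placeholder.
        JavaRDD<LabeledPoint> data =
                MLUtils.loadLibSVMFile(sc.sc(), "data/sample_libsvm_data.txt").toJavaRDD();

        // Train a binary classifier; the empty map means all features are continuous.
        RandomForestModel model = RandomForest.trainClassifier(
                data, 2, new HashMap<Integer, Integer>(), 10, "auto", "gini", 5, 32, 12345);

        // Predict the label of the first data point.
        System.out.println("Predicted label: " + model.predict(data.first().features()));

        // Persist the model and load it back (supported since Spark 1.3).
        model.save(sc.sc(), "target/rf-model");
        RandomForestModel sameModel = RandomForestModel.load(sc.sc(), "target/rf-model");

        sc.stop();
    }
}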

Example 1: generateKMeansModel

import org.apache.spark.mllib.tree.model.RandomForestModel; // import the required package/class
// Note: despite its name, this method trains a random forest classifier, not a k-means model.
public RandomForestModel generateKMeansModel(JavaRDD<LabeledPoint> parsedData,
                                             RandomForestDetectionAlgorithm randomForestDetectionAlgorithm,
                                             RandomForestModelSummary randomForestModelSummary) {
    RandomForestModel randomForestModel
            = RandomForest.trainClassifier(parsedData,
            randomForestDetectionAlgorithm.getNumClasses(),
            randomForestDetectionAlgorithm.getCategoricalFeaturesInfo(),
            randomForestDetectionAlgorithm.getNumTrees(),
            randomForestDetectionAlgorithm.getFeatureSubsetStrategy(),
            randomForestDetectionAlgorithm.getImpurity(),
            randomForestDetectionAlgorithm.getMaxDepth(),
            randomForestDetectionAlgorithm.getMaxBins(),
            randomForestDetectionAlgorithm.seed);

    randomForestModelSummary.setRandomForestDetectionAlgorithm(randomForestDetectionAlgorithm);
    return randomForestModel;
}
 
Developer: shlee89 | Project: athena | Lines: 18 | Source: RandomForestDistJob.java

Example 2: OnlineFeatureHandler

import org.apache.spark.mllib.tree.model.RandomForestModel; // import the required package/class
public OnlineFeatureHandler(FeatureConstraint featureConstraint,
                            DetectionModel detectionModel,
                            onlineMLEventListener onlineMLEventListener,
                            ControllerConnector controllerConnector) {
    this.featureConstraint = featureConstraint;
    this.detectionModel = detectionModel;
    setAthenaMLFeatureConfiguration(detectionModel.getAthenaMLFeatureConfiguration());

    if (detectionModel instanceof KMeansDetectionModel) {
        this.kMeansModel = (KMeansModel) detectionModel.getDetectionModel();
    } else if (detectionModel instanceof GaussianMixtureDetectionModel) {
        this.gaussianMixtureModel = (GaussianMixtureModel) detectionModel.getDetectionModel();
    } else if (detectionModel instanceof DecisionTreeDetectionModel) {
        this.decisionTreeModel = (DecisionTreeModel) detectionModel.getDetectionModel();
    } else if (detectionModel instanceof NaiveBayesDetectionModel) {
        this.naiveBayesModel = (NaiveBayesModel) detectionModel.getDetectionModel();
    } else if (detectionModel instanceof RandomForestDetectionModel) {
        this.randomForestModel = (RandomForestModel) detectionModel.getDetectionModel();
    } else if (detectionModel instanceof GradientBoostedTreesDetectionModel) {
        this.gradientBoostedTreesModel = (GradientBoostedTreesModel) detectionModel.getDetectionModel();
    } else if (detectionModel instanceof SVMDetectionModel) {
        this.svmModel = (SVMModel) detectionModel.getDetectionModel();
    } else if (detectionModel instanceof LogisticRegressionDetectionModel) {
        this.logisticRegressionModel = (LogisticRegressionModel) detectionModel.getDetectionModel();
    } else if (detectionModel instanceof LinearRegressionDetectionModel) {
        this.linearRegressionModel = (LinearRegressionModel) detectionModel.getDetectionModel();
    } else if (detectionModel instanceof LassoDetectionModel) {
        this.lassoModel = (LassoModel) detectionModel.getDetectionModel();
    } else if (detectionModel instanceof RidgeRegressionDetectionModel) {
        this.ridgeRegressionModel = (RidgeRegressionModel) detectionModel.getDetectionModel();
    } else {
        //not supported ML model
        System.out.println("Not supported model");
    }

    this.eventDeliveryManager = new EventDeliveryManagerImpl(controllerConnector, new InternalAthenaFeatureEventListener());
    this.eventDeliveryManager.registerOnlineAthenaFeature(null, new QueryIdentifier(QUERY_IDENTIFIER), featureConstraint);
    this.onlineMLEventListener = onlineMLEventListener;
    System.out.println("Install handler!");
}
 
Developer: shlee89 | Project: athena | Lines: 41 | Source: OnlineFeatureHandler.java

Example 3: predictorExampleCounts

import org.apache.spark.mllib.tree.model.RandomForestModel; // import the required package/class
/**
 * @param trainPointData data to run down trees
 * @param model random decision forest model to count on
 * @return map of predictor index to the number of training examples that reached a
 *  node whose decision is based on that feature. The index is among predictors, not all
 *  features, since there are fewer predictors than features. That is, the index will
 *  match the one used in the {@link RandomForestModel}.
 */
private static Map<Integer,Long> predictorExampleCounts(JavaRDD<LabeledPoint> trainPointData,
                                                        RandomForestModel model) {
  return trainPointData.mapPartitions(data -> {
      IntLongMap featureIndexCount = HashIntLongMaps.newMutableMap();
      data.forEachRemaining(datum -> {
        double[] featureVector = datum.features().toArray();
        for (DecisionTreeModel tree : model.trees()) {
          org.apache.spark.mllib.tree.model.Node node = tree.topNode();
          // This logic cloned from Node.predict:
          while (!node.isLeaf()) {
            Split split = node.split().get();
            int featureIndex = split.feature();
            // Count feature
            featureIndexCount.addValue(featureIndex, 1);
            node = nextNode(featureVector, node, split, featureIndex);
          }
        }
      });
      // Clone to avoid problem with Kryo serializing Koloboke
      return Collections.<Map<Integer,Long>>singleton(
          new HashMap<>(featureIndexCount)).iterator();
  }).reduce(RDFUpdate::merge);
}
 
Developer: oncewang | Project: oryx2 | Lines: 32 | Source: RDFUpdate.java
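
Example 3 reduces the per-partition maps with RDFUpdate::merge, a helper the excerpt does not show. A plausible reconstruction, assuming the helper simply sums counts key by key, is:

// Assumed shape of the omitted RDFUpdate.merge helper: combine two count
// maps by summing the values of matching keys; mutates and returns the first.
private static Map<Integer,Long> merge(Map<Integer,Long> a, Map<Integer,Long> b) {
    b.forEach((key, count) -> a.merge(key, count, Long::sum));
    return a;
}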

Example 4: generateDecisionTreeWithPreprocessing

import org.apache.spark.mllib.tree.model.RandomForestModel; // import the required package/class
public RandomForestModel generateDecisionTreeWithPreprocessing(JavaPairRDD<Object, BSONObject> mongoRDD,
                                                               AthenaMLFeatureConfiguration athenaMLFeatureConfiguration,
                                                               RandomForestDetectionAlgorithm randomForestDetectionAlgorithm,
                                                               Marking marking,
                                                               RandomForestModelSummary randomForestModelSummary) {

    // Delegates to generateKMeansModel (Example 1), which actually trains a random forest.
    return generateKMeansModel(
            rddPreProcessing(mongoRDD, athenaMLFeatureConfiguration, randomForestModelSummary,
                    marking),
            randomForestDetectionAlgorithm, randomForestModelSummary
    );
}
 
Developer: shlee89 | Project: athena | Lines: 13 | Source: RandomForestDistJob.java

Example 5: treeNodeExampleCounts

import org.apache.spark.mllib.tree.model.RandomForestModel; // import the required package/class
/**
 * @param trainPointData data to run down trees
 * @param model random decision forest model to count on
 * @return maps of node IDs to the count of training examples that reached that node, one
 *  per tree in the model
 * @see #predictorExampleCounts(JavaRDD,RandomForestModel)
 */
private static List<Map<Integer,Long>> treeNodeExampleCounts(JavaRDD<LabeledPoint> trainPointData,
                                                             RandomForestModel model) {
  return trainPointData.mapPartitions(data -> {
      DecisionTreeModel[] trees = model.trees();
      List<IntLongMap> treeNodeIDCounts = IntStream.range(0, trees.length).
          mapToObj(i -> HashIntLongMaps.newMutableMap()).collect(Collectors.toList());
      data.forEachRemaining(datum -> {
        double[] featureVector = datum.features().toArray();
        for (int i = 0; i < trees.length; i++) {
          DecisionTreeModel tree = trees[i];
          IntLongMap nodeIDCount = treeNodeIDCounts.get(i);
          org.apache.spark.mllib.tree.model.Node node = tree.topNode();
          // This logic cloned from Node.predict:
          while (!node.isLeaf()) {
            // Count node ID
            nodeIDCount.addValue(node.id(), 1);
            Split split = node.split().get();
            int featureIndex = split.feature();
            node = nextNode(featureVector, node, split, featureIndex);
          }
          nodeIDCount.addValue(node.id(), 1);
        }
      });
      return Collections.<List<Map<Integer,Long>>>singleton(
          treeNodeIDCounts.stream().map(HashMap::new).collect(Collectors.toList())).iterator();
    }
  ).reduce((a, b) -> {
      Preconditions.checkArgument(a.size() == b.size());
      for (int i = 0; i < a.size(); i++) {
        merge(a.get(i), b.get(i));
      }
      return a;
    });
}
 
Developer: oncewang | Project: oryx2 | Lines: 42 | Source: RDFUpdate.java
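
Examples 3 and 5 both walk each tree with a nextNode helper that the excerpts omit. A plausible reconstruction, mirroring the Node.predict logic of Spark 1.x MLlib (continuous splits branch on a threshold, categorical splits on set membership), is sketched below; treat it as an assumption, not the verified original:

import org.apache.spark.mllib.tree.configuration.FeatureType;
import org.apache.spark.mllib.tree.model.Node;
import org.apache.spark.mllib.tree.model.Split;

// Assumed reconstruction of the omitted nextNode helper.
private static Node nextNode(double[] featureVector, Node node, Split split, int featureIndex) {
    double featureValue = featureVector[featureIndex];
    if (split.featureType().equals(FeatureType.Continuous())) {
        // Continuous feature: go left iff the value is at most the threshold.
        return featureValue <= split.threshold() ? node.leftNode().get() : node.rightNode().get();
    } else {
        // Categorical feature: go left iff the value is in the split's category set.
        return split.categories().contains(featureValue) ? node.leftNode().get() : node.rightNode().get();
    }
}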

Example 6: main

import org.apache.spark.mllib.tree.model.RandomForestModel; // import the required package/class
public static void main(String[] args) {
        if (args.length < 3) {
            System.err.println(
                    "Usage: RandomForestMP <training_data> <test_data> <results>");
            System.exit(1);
        }
        String training_data_path = args[0];
        String test_data_path = args[1];
        String results_path = args[2];

        SparkConf sparkConf = new SparkConf().setAppName("RandomForestMP");
        JavaSparkContext sc = new JavaSparkContext(sparkConf);
        final RandomForestModel model;

        Integer numClasses = 2;
        HashMap<Integer, Integer> categoricalFeaturesInfo = new HashMap<Integer, Integer>();
        Integer numTrees = 3;
        String featureSubsetStrategy = "auto";
        String impurity = "gini";
        Integer maxDepth = 5;
        Integer maxBins = 32;
        Integer seed = 12345;

        // The original left a TODO here; the completion below is an assumption,
        // not part of the source: load both LibSVM files, train the classifier,
        // and score the test vectors.
        JavaRDD<LabeledPoint> trainingData =
                MLUtils.loadLibSVMFile(sc.sc(), training_data_path).toJavaRDD();
        JavaRDD<Vector> test = MLUtils.loadLibSVMFile(sc.sc(), test_data_path)
                .toJavaRDD().map(new Function<LabeledPoint, Vector>() {
                    public Vector call(LabeledPoint p) {
                        return p.features();
                    }
                });

        model = RandomForest.trainClassifier(trainingData, numClasses,
                categoricalFeaturesInfo, numTrees, featureSubsetStrategy, impurity,
                maxDepth, maxBins, seed);

        JavaRDD<LabeledPoint> results = test.map(new Function<Vector, LabeledPoint>() {
            public LabeledPoint call(Vector points) {
                return new LabeledPoint(model.predict(points), points);
            }
        });

        results.saveAsTextFile(results_path);

        sc.stop();
    }
 
Developer: kgrodzicki | Project: cloud-computing-specialization | Lines: 36 | Source: RandomForestMP.java

Example 7: main

import org.apache.spark.mllib.tree.model.RandomForestModel; // import the required package/class
public static void main(String args[]){

		SparkConf configuration = new SparkConf().setMaster("local[4]").setAppName("Any");
		JavaSparkContext sc = new JavaSparkContext(configuration);

		// Load and parse the data file.
		String input = "data/rf-data.txt";
		JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(sc.sc(), input).toJavaRDD();
		// Split the data into training and test sets (30% held out for testing)
		JavaRDD<LabeledPoint>[] dataSplits = data.randomSplit(new double[]{0.7, 0.3});
		JavaRDD<LabeledPoint> trainingData = dataSplits[0];
		JavaRDD<LabeledPoint> testData = dataSplits[1];

		// Train a RandomForest model.
		Integer numClasses = 2;
		HashMap<Integer, Integer> categoricalFeaturesInfo = new HashMap<Integer, Integer>();//  Empty categoricalFeaturesInfo indicates all features are continuous.
		Integer numTrees = 3; // Use more in practice.
		String featureSubsetStrategy = "auto"; // Let the algorithm choose.
		String impurity = "gini";
		Integer maxDepth = 5;
		Integer maxBins = 32;
		Integer seed = 12345;

		final RandomForestModel rfModel = RandomForest.trainClassifier(trainingData, numClasses,
				categoricalFeaturesInfo, numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins,
				seed);

		// Evaluate model on test instances and compute test error
		JavaPairRDD<Double, Double> label =
				testData.mapToPair(new PairFunction<LabeledPoint, Double, Double>() {
					public Tuple2<Double, Double> call(LabeledPoint p) {
						return new Tuple2<Double, Double>(rfModel.predict(p.features()), p.label());
					}
				});

		Double testError =
				1.0 * label.filter(new Function<Tuple2<Double, Double>, Boolean>() {
					public Boolean call(Tuple2<Double, Double> pl) {
						return !pl._1().equals(pl._2());
					}
				}).count() / testData.count();

		System.out.println("Test Error: " + testError);
		System.out.println("Learned classification forest model:\n" + rfModel.toDebugString());
	}
 
Developer: PacktPublishing | Project: Java-Data-Science-Cookbook | Lines: 46 | Source: RandomForestMlib.java

Example 8: explain

import org.apache.spark.mllib.tree.model.RandomForestModel; // import the required package/class
public ExplainerResults explain(String stringQuery) throws Exception {
  DataFrame inputData =
      SparkUtils.getInstance().getSQLContext().read().format("com.databricks.spark.csv")
          .option("inferSchema", "true").option("delimiter", this.expParams.getDelimiter())
          .option("header", String.valueOf(this.expParams.isColumnNameSpecified()))
          .load(this.expParams.getDataPath());

  final String impurity = this.expParams.getImpurity();
  final int maxDepth = this.expParams.getMaxDepth();
  final int maxBins = this.expParams.getMaxBins();
  final int numTrees = this.expParams.getNumTrees();
  final int seed = this.expParams.getSeed();
  final int numClasses = this.expParams.getNumClasses();
  final String featureSubsetStrategy = this.expParams.getFeatureSubsetStrategy();
  final String delimiter = Pattern.quote(this.expParams.getDelimiter());

  Map<Integer, Integer> categoricalFeaturesInfo = new HashMap<Integer, Integer>();

  JavaRDD<String> javardd = inputData.toJavaRDD().map(new Function<Row, String>() {
    private static final long serialVersionUID = 1L;

    public String call(Row row) {
      return row.mkString(expParams.getDelimiter());
    }
  });

  JavaRDD<LabeledPoint> data = ExplainerUtils.convertRDDStringToLabeledPoint(javardd, delimiter);

  final RandomForestModel model =
      RandomForest.trainClassifier(data, numClasses, categoricalFeaturesInfo, numTrees,
          featureSubsetStrategy, impurity, maxDepth, maxBins, seed);

  String splitter[] = stringQuery.split(delimiter);
  double[] features = new double[splitter.length];
  for (int i = 0; i < splitter.length; i++) {
    splitter[i] = splitter[i].trim();
    if (splitter[i].isEmpty()) {
      throw new Exception(this.getClass() + " : Value missing in " + i
          + " column in the given query \"" + stringQuery + "\"");
    }
    features[i] = Double.parseDouble(splitter[i].trim());
  }
  Vector featureVector = Vectors.dense(features);
  double labelToBeExplained = model.predict(featureVector);

  ExplainerResults expResult = explainerImpl(model, stringQuery, labelToBeExplained, inputData);
  return expResult;
}
 
Developer: zoho-labs | Project: Explainer | Lines: 49 | Source: Explainer.java

Example 9: validate

import org.apache.spark.mllib.tree.model.RandomForestModel; // import the required package/class
public void validate(JavaPairRDD<Object, BSONObject> mongoRDD,
                     AthenaMLFeatureConfiguration athenaMLFeatureConfiguration,
                     RandomForestDetectionModel randomForestDetectionModel,
                     RandomForestValidationSummary randomForestValidationSummary) {
    List<AthenaFeatureField> listOfTargetFeatures = athenaMLFeatureConfiguration.getListOfTargetFeatures();
    Map<AthenaFeatureField, Integer> weight = athenaMLFeatureConfiguration.getWeight();
    Marking marking = randomForestDetectionModel.getMarking();
    RandomForestModel model = (RandomForestModel) randomForestDetectionModel.getDetectionModel();
    Normalizer normalizer = new Normalizer();

    int numberOfTargetValue = listOfTargetFeatures.size();

    mongoRDD.foreach(new VoidFunction<Tuple2<Object, BSONObject>>() {
        public void call(Tuple2<Object, BSONObject> t) throws UnknownHostException {
            long start2 = System.nanoTime(); // <-- start
            BSONObject feature = (BSONObject) t._2().get(AthenaFeatureField.FEATURE);
            BSONObject idx = (BSONObject) t._2();
            int originLabel = marking.checkClassificationMarkingElements(idx,feature);

            double[] values = new double[numberOfTargetValue];
            for (int j = 0; j < numberOfTargetValue; j++) {
                values[j] = 0;
                if (feature.containsField(listOfTargetFeatures.get(j).getValue())) {
                    Object obj = feature.get(listOfTargetFeatures.get(j).getValue());
                    if (obj instanceof Long) {
                        values[j] = (Long) obj;
                    } else if (obj instanceof Double) {
                        values[j] = (Double) obj;
                    } else if (obj instanceof Boolean) {
                        values[j] = (Boolean) obj ? 1 : 0;
                    } else {
                        return;
                    }

                    //check weight
                    if (weight.containsKey(listOfTargetFeatures.get(j))) {
                        values[j] *= weight.get(listOfTargetFeatures.get(j));
                    }

                    //check absolute
                    if (athenaMLFeatureConfiguration.isAbsolute()){
                        values[j] = Math.abs(values[j]);
                    }
                }
            }

            Vector normedForVal;
            if (athenaMLFeatureConfiguration.isNormalization()) {
                normedForVal = normalizer.transform(Vectors.dense(values));
            } else {
                normedForVal = Vectors.dense(values);
            }

            LabeledPoint p = new LabeledPoint(originLabel,normedForVal);

            int validatedLabel = (int) model.predict(p.features());


            randomForestValidationSummary.updateSummary(validatedLabel,idx,feature);

            long end2 = System.nanoTime();
            long result2 = end2 - start2;
            randomForestValidationSummary.addTotalNanoSeconds(result2);
        }
    });
    randomForestValidationSummary.getAverageNanoSeconds();
    randomForestValidationSummary.setRandomForestDetectionAlgorithm((RandomForestDetectionAlgorithm) randomForestDetectionModel.getDetectionAlgorithm());
}
 
Developer: shlee89 | Project: athena | Lines: 69 | Source: RandomForestDistJob.java

Example 10: setRandomForestModel

import org.apache.spark.mllib.tree.model.RandomForestModel; // import the required package/class
public void setRandomForestModel(RandomForestModel randomForestModel) {
    this.randomForestModel = randomForestModel;
}
 
Developer: shlee89 | Project: athena | Lines: 4 | Source: RandomForestDetectionModel.java

Example 11: generateRandomForestAthenaDetectionModel

import org.apache.spark.mllib.tree.model.RandomForestModel; // import the required package/class
public RandomForestDetectionModel generateRandomForestAthenaDetectionModel(JavaSparkContext sc,
                                                                           FeatureConstraint featureConstraint,
                                                                           AthenaMLFeatureConfiguration athenaMLFeatureConfiguration,
                                                                           DetectionAlgorithm detectionAlgorithm,
                                                                           Indexing indexing,
                                                                           Marking marking) {
    RandomForestModelSummary randomForestModelSummary = new RandomForestModelSummary(
            sc.sc(), indexing, marking);

    long start = System.nanoTime(); // <-- start

    RandomForestDetectionAlgorithm randomForestDetectionAlgorithm = (RandomForestDetectionAlgorithm) detectionAlgorithm;

    RandomForestDetectionModel randomForestDetectionModel = new RandomForestDetectionModel();

    randomForestDetectionModel.setRandomForestDetectionAlgorithm(randomForestDetectionAlgorithm);
    randomForestModelSummary.setRandomForestDetectionAlgorithm(randomForestDetectionAlgorithm);
    randomForestDetectionModel.setFeatureConstraint(featureConstraint);
    randomForestDetectionModel.setAthenaMLFeatureConfiguration(athenaMLFeatureConfiguration);
    randomForestDetectionModel.setIndexing(indexing);
    randomForestDetectionModel.setMarking(marking);

    JavaPairRDD<Object, BSONObject> mongoRDD;
    mongoRDD = sc.newAPIHadoopRDD(
            mongodbConfig,            // Configuration
            MongoInputFormat.class,   // InputFormat: read from a live cluster.
            Object.class,             // Key class
            BSONObject.class          // Value class
    );

    RandomForestDistJob randomForestDistJob = new RandomForestDistJob();

    RandomForestModel decisionTreeModel = randomForestDistJob.generateDecisionTreeWithPreprocessing(mongoRDD,
            athenaMLFeatureConfiguration, randomForestDetectionAlgorithm, marking, randomForestModelSummary);


    randomForestDetectionModel.setRandomForestModel(decisionTreeModel);
    long end = System.nanoTime(); // <-- end
    long time = end - start;
    randomForestModelSummary.setTotalLearningTime(time);
    randomForestDetectionModel.setClassificationModelSummary(randomForestModelSummary);

    return randomForestDetectionModel;
}
 
Developer: shlee89 | Project: athena | Lines: 45 | Source: MachineLearningManagerImpl.java
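
Example 11 reads its training set through a mongodbConfig field that the excerpt never defines. With the mongo-hadoop connector, that configuration is typically built as below; the URI is a placeholder and the exact field setup in athena is an assumption:

import org.apache.hadoop.conf.Configuration;

// Assumed setup for the mongodbConfig field used in Example 11.
Configuration mongodbConfig = new Configuration();
mongodbConfig.set("mongo.job.input.format", "com.mongodb.hadoop.MongoInputFormat");
// Placeholder URI; point it at the collection that holds the feature documents.
mongodbConfig.set("mongo.input.uri", "mongodb://localhost:27017/athena.features");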

Example 12: buildModel

import org.apache.spark.mllib.tree.model.RandomForestModel; // import the required package/class
@Override
public PMML buildModel(JavaSparkContext sparkContext,
                       JavaRDD<String> trainData,
                       List<?> hyperParameters,
                       Path candidatePath) {

  int maxSplitCandidates = (Integer) hyperParameters.get(0);
  int maxDepth = (Integer) hyperParameters.get(1);
  String impurity = (String) hyperParameters.get(2);
  Preconditions.checkArgument(maxSplitCandidates >= 2,
                              "max-split-candidates must be at least 2");
  Preconditions.checkArgument(maxDepth > 0,
                              "max-depth must be at least 1");

  JavaRDD<String[]> parsedRDD = trainData.map(MLFunctions.PARSE_FN);
  CategoricalValueEncodings categoricalValueEncodings =
      new CategoricalValueEncodings(getDistinctValues(parsedRDD));
  JavaRDD<LabeledPoint> trainPointData =
      parseToLabeledPointRDD(parsedRDD, categoricalValueEncodings);

  Map<Integer,Integer> categoryInfo = categoricalValueEncodings.getCategoryCounts();
  categoryInfo.remove(inputSchema.getTargetFeatureIndex()); // Don't specify target count
  // Need to translate indices to predictor indices
  Map<Integer,Integer> categoryInfoByPredictor = new HashMap<>(categoryInfo.size());
  categoryInfo.forEach((k, v) -> categoryInfoByPredictor.put(inputSchema.featureToPredictorIndex(k), v));

  int seed = RandomManager.getRandom().nextInt();

  RandomForestModel model;
  if (inputSchema.isClassification()) {
    int numTargetClasses =
        categoricalValueEncodings.getValueCount(inputSchema.getTargetFeatureIndex());
    model = RandomForest.trainClassifier(trainPointData,
                                         numTargetClasses,
                                         categoryInfoByPredictor,
                                         numTrees,
                                         "auto",
                                         impurity,
                                         maxDepth,
                                         maxSplitCandidates,
                                         seed);
  } else {
    model = RandomForest.trainRegressor(trainPointData,
                                        categoryInfoByPredictor,
                                        numTrees,
                                        "auto",
                                        impurity,
                                        maxDepth,
                                        maxSplitCandidates,
                                        seed);
  }

  List<Map<Integer,Long>> treeNodeIDCounts = treeNodeExampleCounts(trainPointData, model);
  Map<Integer,Long> predictorIndexCounts = predictorExampleCounts(trainPointData, model);

  return rdfModelToPMML(model,
                        categoricalValueEncodings,
                        maxDepth,
                        maxSplitCandidates,
                        impurity,
                        treeNodeIDCounts,
                        predictorIndexCounts);
}
 
Developer: oncewang | Project: oryx2 | Lines: 64 | Source: RDFUpdate.java

Example 13: rdfModelToPMML

import org.apache.spark.mllib.tree.model.RandomForestModel; // import the required package/class
private PMML rdfModelToPMML(RandomForestModel rfModel,
                            CategoricalValueEncodings categoricalValueEncodings,
                            int maxDepth,
                            int maxSplitCandidates,
                            String impurity,
                            List<Map<Integer,Long>> nodeIDCounts,
                            Map<Integer,Long> predictorIndexCounts) {

  boolean classificationTask = rfModel.algo().equals(Algo.Classification());
  Preconditions.checkState(classificationTask == inputSchema.isClassification());

  DecisionTreeModel[] trees = rfModel.trees();

  Model model;
  if (trees.length == 1) {
    model = toTreeModel(trees[0], categoricalValueEncodings, nodeIDCounts.get(0));
  } else {
    MiningModel miningModel = new MiningModel();
    model = miningModel;
    Segmentation.MultipleModelMethod multipleModelMethodType = classificationTask ?
        Segmentation.MultipleModelMethod.WEIGHTED_MAJORITY_VOTE :
        Segmentation.MultipleModelMethod.WEIGHTED_AVERAGE;
    List<Segment> segments = new ArrayList<>(trees.length);
    for (int treeID = 0; treeID < trees.length; treeID++) {
      TreeModel treeModel =
          toTreeModel(trees[treeID], categoricalValueEncodings, nodeIDCounts.get(treeID));
      segments.add(new Segment()
           .setId(Integer.toString(treeID))
           .setPredicate(new True())
           .setModel(treeModel)
           .setWeight(1.0)); // No weights in MLlib impl now
    }
    miningModel.setSegmentation(new Segmentation(multipleModelMethodType, segments));
  }

  model.setMiningFunction(classificationTask ?
                          MiningFunction.CLASSIFICATION :
                          MiningFunction.REGRESSION);

  double[] importances = countsToImportances(predictorIndexCounts);
  model.setMiningSchema(AppPMMLUtils.buildMiningSchema(inputSchema, importances));
  DataDictionary dictionary =
      AppPMMLUtils.buildDataDictionary(inputSchema, categoricalValueEncodings);

  PMML pmml = PMMLUtils.buildSkeletonPMML();
  pmml.setDataDictionary(dictionary);
  pmml.addModels(model);

  AppPMMLUtils.addExtension(pmml, "maxDepth", maxDepth);
  AppPMMLUtils.addExtension(pmml, "maxSplitCandidates", maxSplitCandidates);
  AppPMMLUtils.addExtension(pmml, "impurity", impurity);

  return pmml;
}
 
Developer: oncewang | Project: oryx2 | Lines: 55 | Source: RDFUpdate.java

Example 14: MLRandomForestModel

import org.apache.spark.mllib.tree.model.RandomForestModel; // import the required package/class
public MLRandomForestModel(RandomForestModel model) {
    this.model = model;
}
 
Developer: wso2-attic | Project: carbon-ml | Lines: 4 | Source: MLRandomForestModel.java

Example 15: readExternal

import org.apache.spark.mllib.tree.model.RandomForestModel; // import the required package/class
@Override
public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {

    model = (RandomForestModel) in.readObject();
}
 
Developer: wso2-attic | Project: carbon-ml | Lines: 6 | Source: MLRandomForestModel.java
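
Because the class in Example 15 implements Externalizable, readExternal implies a matching writeExternal. A minimal counterpart, assumed rather than taken from the carbon-ml excerpt, would be:

// Assumed counterpart to Example 15's readExternal: RandomForestModel is
// Serializable, so it can be written directly to the ObjectOutput stream.
@Override
public void writeExternal(ObjectOutput out) throws IOException {
    out.writeObject(model);
}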


Note: the org.apache.spark.mllib.tree.model.RandomForestModel class examples in this article were compiled by 纯净天空 from open-source code and documentation platforms such as GitHub and MSDocs. The snippets were selected from open-source projects contributed by their developers; copyright in the source code remains with the original authors, and any distribution or use should follow the corresponding project's license. Do not reproduce without permission.