当前位置: 首页>>代码示例>>Java>>正文


Java LogisticRegressionModel类代码示例

本文整理汇总了Java中org.apache.spark.mllib.classification.LogisticRegressionModel类的典型用法代码示例。如果您想了解Java LogisticRegressionModel类的具体用法、调用方式或使用示例，这里精选的代码示例或许可以为您提供帮助。


LogisticRegressionModel类属于org.apache.spark.mllib.classification包，下文共展示了该类的15个代码示例，这些示例默认按受欢迎程度排序。您可以为喜欢或觉得有用的代码点赞，您的评价将有助于系统推荐更优质的Java代码示例。

示例1: predictForMetrics

import org.apache.spark.mllib.classification.LogisticRegressionModel; //导入依赖的package包/类
/**
 * Runs the named model over {@code data} to produce output suitable for computing metrics.
 *
 * <p>For binary problems ({@code numClasses == 2}) the threshold of a logistic-regression or
 * SVM model is cleared first, so that {@code predict} returns raw scores instead of 0/1 labels.
 * This mutates the model instance passed in.
 *
 * @param modelName  "LogisticRegressionModel", "SVMModel" or "NaiveBayesModel"
 * @param model      trained model; must be an instance of the type named by {@code modelName}
 * @param data       labelled points to score
 * @param numClasses number of classes the model was trained for
 * @return the RDD built by the matching {@code PredictUnit.predictForMetrics_*} helper
 *         (for logistic regression: (prediction, label) pairs), or {@code null} when
 *         {@code modelName} matches none of the supported names
 */
public JavaRDD<Tuple2<Object, Object>> predictForMetrics(String modelName, T model, JavaRDD<LabeledPoint> data, int numClasses){
    final boolean binary = numClasses == 2;
    switch (modelName) {
        case "LogisticRegressionModel": {
            LogisticRegressionModel logistic = (LogisticRegressionModel) model;
            if (binary) {
                logistic.clearThreshold();
            }
            return PredictUnit.predictForMetrics_LogisticRegressionModel(logistic, data);
        }
        case "SVMModel": {
            SVMModel svm = (SVMModel) model;
            if (binary) {
                svm.clearThreshold();
            }
            return PredictUnit.predictForMetrics_SVMModel(svm, data);
        }
        case "NaiveBayesModel":
            return PredictUnit.predictForMetrics_NaiveBayesModel((NaiveBayesModel) model, data);
        default:
            // Unknown model name: nothing to predict with.
            return null;
    }
}
 
开发者ID:Chih-Ling-Hsu,项目名称:Spark-Machine-Learning-Modules,代码行数:26,代码来源:PredictUnit.java

示例2: predictForOutput

import org.apache.spark.mllib.classification.LogisticRegressionModel; //导入依赖的package包/类
/**
 * Runs the named model over {@code data} and returns its prediction output.
 *
 * <p>For binary problems ({@code numClasses == 2}) the decision threshold of a logistic-regression
 * or SVM model is set to {@code threshold} before predicting. This mutates the model instance passed in.
 * The threshold is ignored for Naive Bayes and for multi-class problems.
 *
 * @param modelName  "LogisticRegressionModel", "SVMModel" or "NaiveBayesModel"
 * @param model      trained model; must be an instance of the type named by {@code modelName}
 * @param data       points to score
 * @param numClasses number of classes the model was trained for
 * @param threshold  decision threshold applied to binary LR/SVM models
 * @return the RDD built by the matching {@code PredictUnit.predictForOutput_*} helper
 *         (for logistic regression: (features, prediction) pairs), or {@code null} when
 *         {@code modelName} matches none of the supported names
 */
public JavaRDD<Tuple2<Object, Object>> predictForOutput(String modelName, T model, JavaRDD<LabeledPoint> data, int numClasses, double threshold){
    final boolean binary = numClasses == 2;
    switch (modelName) {
        case "LogisticRegressionModel": {
            LogisticRegressionModel logistic = (LogisticRegressionModel) model;
            if (binary) {
                logistic.setThreshold(threshold);
            }
            return PredictUnit.predictForOutput_LogisticRegressionModel(logistic, data);
        }
        case "SVMModel": {
            SVMModel svm = (SVMModel) model;
            if (binary) {
                svm.setThreshold(threshold);
            }
            return PredictUnit.predictForOutput_SVMModel(svm, data);
        }
        case "NaiveBayesModel":
            return PredictUnit.predictForOutput_NaiveBayesModel((NaiveBayesModel) model, data);
        default:
            // Unknown model name: nothing to predict with.
            return null;
    }
}
 
开发者ID:Chih-Ling-Hsu,项目名称:Spark-Machine-Learning-Modules,代码行数:27,代码来源:PredictUnit.java

示例3: getModelInfo

import org.apache.spark.mllib.classification.LogisticRegressionModel; //导入依赖的package包/类
@Override
public LogisticRegressionModelInfo getModelInfo(final LogisticRegressionModel sparkLRModel, DataFrame df) {
    final LogisticRegressionModelInfo logisticRegressionModelInfo = new LogisticRegressionModelInfo();
    logisticRegressionModelInfo.setWeights(sparkLRModel.weights().toArray());
    logisticRegressionModelInfo.setIntercept(sparkLRModel.intercept());
    logisticRegressionModelInfo.setNumClasses(sparkLRModel.numClasses());
    logisticRegressionModelInfo.setNumFeatures(sparkLRModel.numFeatures());
    logisticRegressionModelInfo.setThreshold((double) sparkLRModel.getThreshold().get());

    Set<String> inputKeys = new LinkedHashSet<String>();
    inputKeys.add("features");
    logisticRegressionModelInfo.setInputKeys(inputKeys);

    Set<String> outputKeys = new LinkedHashSet<String>();
    outputKeys.add("prediction");
    outputKeys.add("probability");
    logisticRegressionModelInfo.setOutputKeys(outputKeys);

    return logisticRegressionModelInfo;
}
 
开发者ID:flipkart-incubator,项目名称:spark-transformers,代码行数:21,代码来源:LogisticRegressionModelInfoAdapter.java

示例4: predictForOutput_LogisticRegressionModel

import org.apache.spark.mllib.classification.LogisticRegressionModel; //导入依赖的package包/类
/**
 * Scores every point in {@code data} with {@code model}.
 *
 * @param model trained logistic-regression model; its current threshold setting decides whether
 *              {@code predict} yields class labels or raw scores
 * @param data  points to score (labels are ignored)
 * @return one (features, prediction) pair per input point
 */
public static JavaRDD<Tuple2<Object, Object>> predictForOutput_LogisticRegressionModel(LogisticRegressionModel model, JavaRDD<LabeledPoint> data){
    // Spark's Java Function interface extends Serializable, so this lambda (and the captured
    // model) can be shipped to executors just like the equivalent anonymous class.
    return data.map(point -> {
        Double predicted = model.predict(point.features());
        return new Tuple2<Object, Object>(point.features(), predicted);
    });
}
 
开发者ID:Chih-Ling-Hsu,项目名称:Spark-Machine-Learning-Modules,代码行数:13,代码来源:PredictUnit.java

示例5: PredictWithModel

import org.apache.spark.mllib.classification.LogisticRegressionModel; //导入依赖的package包/类
/**
 * Loads a persisted model and the test data it will be applied to.
 *
 * @param modelName    "LogisticRegressionModel", "SVMModel" or "NaiveBayesModel"; any other
 *                     value leaves {@code this.model} unset
 * @param modelPath    location the model was saved to
 * @param testFile     location of the test data
 * @param numClasses   number of classes, stored for later prediction
 * @param minPartition minimum partition count passed to the loader
 * @param threshold    decision threshold, stored for later prediction
 * @param sc           Spark context used for loading
 */
@SuppressWarnings("unchecked")
public PredictWithModel(String modelName, String modelPath, String testFile, int numClasses, int minPartition, double threshold, SparkContext sc){
    this.numClasses = numClasses;
    this.threshold = threshold;

    // Restore the model type named by modelName (the double cast bridges to the generic T).
    switch (modelName) {
        case "LogisticRegressionModel":
            this.model = (T) (Object) LogisticRegressionModel.load(sc, modelPath);
            break;
        case "SVMModel":
            this.model = (T) (Object) SVMModel.load(sc, modelPath);
            break;
        case "NaiveBayesModel":
            this.model = (T) (Object) NaiveBayesModel.load(sc, modelPath);
            break;
        default:
            break;
    }

    // Load the test set in "Vector" format and cache it, presumably because it is reused later.
    LoadProcess loader = new LoadProcess(sc, minPartition);
    testingData = loader.load(testFile, "Vector");
    testingData.cache();
}
 
开发者ID:Chih-Ling-Hsu,项目名称:Spark-Machine-Learning-Modules,代码行数:24,代码来源:PredictWithModel.java

示例6: trainWithLBFGS

import org.apache.spark.mllib.classification.LogisticRegressionModel; //导入依赖的package包/类
/**
 * Trains a logistic-regression model with L-BFGS when {@code modelName} is
 * "LogisticRegressionModel", then evaluates the current model on the validation data.
 *
 * <p>For any other {@code modelName} no training happens, and the existing {@code this.model}
 * is evaluated and returned as is.
 *
 * @return the (possibly newly trained) model
 */
@SuppressWarnings("unchecked")
public T trainWithLBFGS(){
    if (modelName.equals("LogisticRegressionModel")) {
        LogisticRegressionWithLBFGS trainer = new LogisticRegressionWithLBFGS();
        trainer.setNumClasses(numClasses);
        LogisticRegressionModel trained = trainer.run(trainingData.rdd());

        System.out.println("\n--------------------------------------\n weights: " + trained.weights());
        System.out.println("--------------------------------------\n");

        this.model = (T) (Object) trained;
    }

    // Evaluate on the validation set ("evalute" is the project's method name).
    new EvaluateProcess<T>(model, modelName, validData, numClasses).evalute(numClasses);
    return model;
}
 
开发者ID:Chih-Ling-Hsu,项目名称:Spark-Machine-Learning-Modules,代码行数:21,代码来源:TrainModel.java

示例7: trainWithSGD

import org.apache.spark.mllib.classification.LogisticRegressionModel; //导入依赖的package包/类
/**
 * Trains an SVM or logistic-regression model with stochastic gradient descent, depending on
 * {@code modelName}, then evaluates the current model on the validation data.
 *
 * <p>For any other {@code modelName} no training happens, and the existing {@code this.model}
 * is evaluated and returned as is.
 *
 * @param numIterations number of SGD iterations
 * @return the (possibly newly trained) model
 */
@SuppressWarnings("unchecked")
public T trainWithSGD(int numIterations){
    switch (modelName) {
        case "SVMModel":
            this.model = (T) (Object) SVMWithSGD.train(trainingData.rdd(), numIterations);
            break;
        case "LogisticRegressionModel":
            this.model = (T) (Object) LogisticRegressionWithSGD.train(trainingData.rdd(), numIterations);
            break;
        default:
            break;
    }

    // Evaluate on the validation set ("evalute" is the project's method name).
    new EvaluateProcess<T>(model, modelName, validData, numClasses).evalute(numClasses);
    return model;
}
 
开发者ID:Chih-Ling-Hsu,项目名称:Spark-Machine-Learning-Modules,代码行数:18,代码来源:TrainModel.java

示例8: generateDecisionTreeWithPreprocessing

import org.apache.spark.mllib.classification.LogisticRegressionModel; //导入依赖的package包/类
/**
 * Preprocesses raw MongoDB records into labelled points and trains a logistic-regression model on them.
 *
 * <p>NOTE(review): despite the name, no decision tree is built. This delegates to
 * {@code generateKMeansModel}, which returns a {@link LogisticRegressionModel}. The name is
 * kept for caller compatibility.
 *
 * @param mongoRDD                             raw records keyed by MongoDB id
 * @param athenaMLFeatureConfiguration         feature configuration used during preprocessing
 * @param logisticRegressionDetectionAlgorithm training settings
 * @param marking                              passed through to preprocessing
 * @param logisticRegressionModelSummary       summary object updated by both steps
 * @return the trained model
 */
public LogisticRegressionModel generateDecisionTreeWithPreprocessing(JavaPairRDD<Object, BSONObject> mongoRDD,
                                                               AthenaMLFeatureConfiguration athenaMLFeatureConfiguration,
                                                               LogisticRegressionDetectionAlgorithm logisticRegressionDetectionAlgorithm,
                                                               Marking marking,
                                                               LogisticRegressionModelSummary logisticRegressionModelSummary) {
    JavaRDD<LabeledPoint> preprocessed = rddPreProcessing(
            mongoRDD, athenaMLFeatureConfiguration, logisticRegressionModelSummary, marking);
    return generateKMeansModel(preprocessed, logisticRegressionDetectionAlgorithm, logisticRegressionModelSummary);
}
 
开发者ID:shlee89,项目名称:athena,代码行数:13,代码来源:LogisticRegressionDistJob.java

示例9: OnlineFeatureHandler

import org.apache.spark.mllib.classification.LogisticRegressionModel; //导入依赖的package包/类
/**
 * Creates a handler that applies a trained detection model to online feature events.
 *
 * <p>The concrete Spark MLlib model wrapped by {@code detectionModel} is cast and stored in the
 * typed field that matches its wrapper class. The handler then registers itself for online
 * features matching {@code featureConstraint} through a new EventDeliveryManagerImpl.
 *
 * @param featureConstraint     constraint used when registering for online features
 * @param detectionModel        wrapper around the trained model and its feature configuration
 * @param onlineMLEventListener listener stored for later notifications
 * @param controllerConnector   connection passed to the event delivery manager
 */
public OnlineFeatureHandler(FeatureConstraint featureConstraint,
                            DetectionModel detectionModel,
                            onlineMLEventListener onlineMLEventListener,
                            ControllerConnector controllerConnector) {
    this.featureConstraint = featureConstraint;
    this.detectionModel = detectionModel;
    // NOTE(review): a method called from a constructor — if it is overridable, subclasses run before
    // they are fully initialised; confirm it is private/final.
    setAthenaMLFeatureConfiguration(detectionModel.getAthenaMLFeatureConfiguration());

    // The checks run in order and the first match wins; only the matching typed field is assigned.
    // Others stay at their default values.
    if (detectionModel instanceof KMeansDetectionModel) {
        this.kMeansModel = (KMeansModel) detectionModel.getDetectionModel();
    } else if (detectionModel instanceof GaussianMixtureDetectionModel) {
        this.gaussianMixtureModel = (GaussianMixtureModel) detectionModel.getDetectionModel();
    } else if (detectionModel instanceof DecisionTreeDetectionModel) {
        this.decisionTreeModel = (DecisionTreeModel) detectionModel.getDetectionModel();
    } else if (detectionModel instanceof NaiveBayesDetectionModel) {
        this.naiveBayesModel = (NaiveBayesModel) detectionModel.getDetectionModel();
    } else if (detectionModel instanceof RandomForestDetectionModel) {
        this.randomForestModel = (RandomForestModel) detectionModel.getDetectionModel();
    } else if (detectionModel instanceof GradientBoostedTreesDetectionModel) {
        this.gradientBoostedTreesModel = (GradientBoostedTreesModel) detectionModel.getDetectionModel();
    } else if (detectionModel instanceof SVMDetectionModel) {
        this.svmModel = (SVMModel) detectionModel.getDetectionModel();
    } else if (detectionModel instanceof LogisticRegressionDetectionModel) {
        this.logisticRegressionModel = (LogisticRegressionModel) detectionModel.getDetectionModel();
    } else if (detectionModel instanceof LinearRegressionDetectionModel) {
        this.linearRegressionModel = (LinearRegressionModel) detectionModel.getDetectionModel();
    } else if (detectionModel instanceof LassoDetectionModel) {
        this.lassoModel = (LassoModel) detectionModel.getDetectionModel();
    } else if (detectionModel instanceof RidgeRegressionDetectionModel) {
        this.ridgeRegressionModel = (RidgeRegressionModel) detectionModel.getDetectionModel();
    } else {
        // Unsupported model type: only a message is printed. Construction continues and no model
        // field is set.
        System.out.println("Not supported model");
    }

    // Register for online feature events matching the constraint, identified by QUERY_IDENTIFIER.
    this.eventDeliveryManager = new EventDeliveryManagerImpl(controllerConnector, new InternalAthenaFeatureEventListener());
    this.eventDeliveryManager.registerOnlineAthenaFeature(null, new QueryIdentifier(QUERY_IDENTIFIER), featureConstraint);
    // NOTE(review): the listener is assigned after registration; confirm no event can be delivered
    // before this line runs.
    this.onlineMLEventListener = onlineMLEventListener;
    System.out.println("Install handler!");
}
 
开发者ID:shlee89,项目名称:athena,代码行数:41,代码来源:OnlineFeatureHandler.java

示例10: getModelInfo

import org.apache.spark.mllib.classification.LogisticRegressionModel; //导入依赖的package包/类
@Override
public LogisticRegressionModelInfo getModelInfo(final LogisticRegressionModel sparkLRModel) {
    final LogisticRegressionModelInfo logisticRegressionModelInfo = new LogisticRegressionModelInfo();
    logisticRegressionModelInfo.setWeights(sparkLRModel.weights().toArray());
    logisticRegressionModelInfo.setIntercept(sparkLRModel.intercept());
    logisticRegressionModelInfo.setNumClasses(sparkLRModel.numClasses());
    logisticRegressionModelInfo.setNumFeatures(sparkLRModel.numFeatures());
    logisticRegressionModelInfo.setThreshold((double) sparkLRModel.getThreshold().get());

    Set<String> inputKeys = new LinkedHashSet<String>();
    inputKeys.add("features");
    logisticRegressionModelInfo.setInputKeys(inputKeys);

    Set<String> outputKeys = new LinkedHashSet<String>();
    outputKeys.add("prediction");
    outputKeys.add("probability");
    logisticRegressionModelInfo.setOutputKeys(outputKeys);

    return logisticRegressionModelInfo;
}
 
开发者ID:flipkart-incubator,项目名称:spark-transformers,代码行数:21,代码来源:LogisticRegressionModelInfoAdapter.java

示例11: shouldExportAndImportCorrectly

import org.apache.spark.mllib.classification.LogisticRegressionModel; //导入依赖的package包/类
/**
 * Round-trips a trained logistic-regression model through ModelExporter/ModelImporter and checks
 * that every field survives within a 0.01 tolerance.
 */
@Test
public void shouldExportAndImportCorrectly() {
    // Train on the bundled binary-classification fixture.
    String datapath = "src/test/resources/binary_classification_test.libsvm";
    JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(jsc.sc(), datapath).toJavaRDD();
    LogisticRegressionModel lrmodel = new LogisticRegressionWithSGD().run(data.rdd());

    // Export, then import the serialized form back.
    byte[] exportedModel = ModelExporter.export(lrmodel);
    LogisticRegressionModelInfo importedModel = (LogisticRegressionModelInfo) ModelImporter.importModelInfo(exportedModel);

    final double tolerance = 0.01;
    assertEquals(lrmodel.intercept(), importedModel.getIntercept(), tolerance);
    assertEquals(lrmodel.numClasses(), importedModel.getNumClasses(), tolerance);
    assertEquals(lrmodel.numFeatures(), importedModel.getNumFeatures(), tolerance);
    assertEquals((double) lrmodel.getThreshold().get(), importedModel.getThreshold(), tolerance);

    // Weights are compared element by element over the imported feature count.
    double[] sparkWeights = lrmodel.weights().toArray();
    double[] importedWeights = importedModel.getWeights();
    for (int k = 0; k < importedModel.getNumFeatures(); k++) {
        assertEquals(sparkWeights[k], importedWeights[k], tolerance);
    }
}
 
开发者ID:flipkart-incubator,项目名称:spark-transformers,代码行数:25,代码来源:LogisticRegressionExporterTest.java

示例12: shouldExportAndImportCorrectly

import org.apache.spark.mllib.classification.LogisticRegressionModel; //导入依赖的package包/类
/**
 * Round-trips a trained logistic-regression model through ModelExporter/ModelImporter and checks
 * that every field survives within EPSILON.
 */
@Test
public void shouldExportAndImportCorrectly() {
    // Train on the bundled binary-classification fixture.
    String datapath = "src/test/resources/binary_classification_test.libsvm";
    JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(sc.sc(), datapath).toJavaRDD();
    LogisticRegressionModel lrmodel = new LogisticRegressionWithSGD().run(data.rdd());

    // Export (this overload takes a second argument; null is passed here), then import back.
    byte[] exportedModel = ModelExporter.export(lrmodel, null);
    LogisticRegressionModelInfo importedModel = (LogisticRegressionModelInfo) ModelImporter.importModelInfo(exportedModel);

    assertEquals(lrmodel.intercept(), importedModel.getIntercept(), EPSILON);
    assertEquals(lrmodel.numClasses(), importedModel.getNumClasses(), EPSILON);
    assertEquals(lrmodel.numFeatures(), importedModel.getNumFeatures(), EPSILON);
    assertEquals((double) lrmodel.getThreshold().get(), importedModel.getThreshold(), EPSILON);

    // Weights are compared element by element over the imported feature count.
    double[] sparkWeights = lrmodel.weights().toArray();
    double[] importedWeights = importedModel.getWeights();
    for (int k = 0; k < importedModel.getNumFeatures(); k++) {
        assertEquals(sparkWeights[k], importedWeights[k], EPSILON);
    }
}
 
开发者ID:flipkart-incubator,项目名称:spark-transformers,代码行数:25,代码来源:LogisticRegressionExporterTest.java

示例13: predictForMetrics_LogisticRegressionModel

import org.apache.spark.mllib.classification.LogisticRegressionModel; //导入依赖的package包/类
/**
 * Scores every labelled point with {@code model} and pairs the prediction with the true label.
 *
 * @param model trained logistic-regression model; its current threshold setting decides whether
 *              {@code predict} yields class labels or raw scores
 * @param data  labelled points to score
 * @return one (prediction, label) pair per input point
 */
public static JavaRDD<Tuple2<Object, Object>> predictForMetrics_LogisticRegressionModel(LogisticRegressionModel model, JavaRDD<LabeledPoint> data){
    // Spark's Java Function interface extends Serializable, so this lambda (and the captured
    // model) can be shipped to executors just like the equivalent anonymous class.
    return data.map(point -> {
        Double predicted = model.predict(point.features());
        return new Tuple2<Object, Object>(predicted, point.label());
    });
}
 
开发者ID:Chih-Ling-Hsu,项目名称:Spark-Machine-Learning-Modules,代码行数:13,代码来源:PredictUnit.java

示例14: generateKMeansModel

import org.apache.spark.mllib.classification.LogisticRegressionModel; //导入依赖的package包/类
/**
 * Trains a logistic-regression model on {@code parsedData} with L-BFGS. After training succeeds,
 * the algorithm settings are recorded in {@code logisticRegressionModelSummary}.
 *
 * <p>NOTE(review): the "KMeans" in the name is misleading — no k-means clustering happens here.
 * The name is kept for caller compatibility.
 *
 * @param parsedData                           labelled training points
 * @param logisticRegressionDetectionAlgorithm supplies the number of classes
 * @param logisticRegressionModelSummary       receives the algorithm settings
 * @return the trained model
 */
public LogisticRegressionModel generateKMeansModel(JavaRDD<LabeledPoint> parsedData,
                                                   LogisticRegressionDetectionAlgorithm logisticRegressionDetectionAlgorithm,
                                                   LogisticRegressionModelSummary logisticRegressionModelSummary) {
    LogisticRegressionWithLBFGS trainer = new LogisticRegressionWithLBFGS();
    trainer.setNumClasses(logisticRegressionDetectionAlgorithm.getNumClasses());
    LogisticRegressionModel trained = trainer.run(parsedData.rdd());

    logisticRegressionModelSummary.setLogisticRegressionDetectionAlgorithm(logisticRegressionDetectionAlgorithm);
    return trained;
}
 
开发者ID:shlee89,项目名称:athena,代码行数:12,代码来源:LogisticRegressionDistJob.java

示例15: testLogisticRegression

import org.apache.spark.mllib.classification.LogisticRegressionModel; //导入依赖的package包/类
/**
 * Checks that the Transformer rebuilt from an exported logistic-regression model reproduces
 * Spark's prediction (within 0.01) for every training point.
 */
@Test
public void testLogisticRegression() {
    // Train on the bundled binary-classification fixture.
    String datapath = "src/test/resources/binary_classification_test.libsvm";
    JavaRDD<LabeledPoint> trainingData = MLUtils.loadLibSVMFile(jsc.sc(), datapath).toJavaRDD();
    LogisticRegressionModel lrmodel = new LogisticRegressionWithSGD().run(trainingData.rdd());

    // Export the model and rebuild a standalone Transformer from the exported bytes.
    Transformer transformer = ModelImporter.importAndGetTransformer(ModelExporter.export(lrmodel));

    for (LabeledPoint point : trainingData.collect()) {
        Vector features = point.features();
        double expected = lrmodel.predict(features);

        // The transformer reads "features" from the map and writes "prediction" back into it.
        Map<String, Object> row = new HashMap<String, Object>();
        row.put("features", features.toArray());
        transformer.transform(row);
        double predicted = (double) row.get("prediction");

        assertEquals(expected, predicted, 0.01);
    }
}
 
开发者ID:flipkart-incubator,项目名称:spark-transformers,代码行数:30,代码来源:LogisticRegressionBridgeTest.java


注:本文中的org.apache.spark.mllib.classification.LogisticRegressionModel类示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。