本文整理汇总了Java中org.apache.spark.mllib.classification.LogisticRegressionModel类的典型用法代码示例。如果您正苦于以下问题：Java LogisticRegressionModel类的具体用法？Java LogisticRegressionModel怎么用？Java LogisticRegressionModel使用的例子？那么恭喜您，这里精选的类代码示例或许可以为您提供帮助。
LogisticRegressionModel类属于org.apache.spark.mllib.classification包，在下文中一共展示了LogisticRegressionModel类的15个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Java代码示例。
示例1: predictForMetrics
import org.apache.spark.mllib.classification.LogisticRegressionModel; //导入依赖的package包/类
/**
 * Scores {@code data} with {@code model} and pairs every prediction with the point's true label,
 * producing the input expected by Spark's evaluation-metrics classes.
 *
 * <p>For binary ({@code numClasses == 2}) logistic-regression and SVM models the decision
 * threshold is cleared first, so {@code predict()} yields a raw score instead of a 0/1 class.
 * This mutates the passed-in {@code model}.
 *
 * @param modelName  one of {@code "LogisticRegressionModel"}, {@code "SVMModel"},
 *                   {@code "NaiveBayesModel"}
 * @param model      trained model; must be an instance of the class named by {@code modelName}
 * @param data       labeled points to score
 * @param numClasses number of classes; the threshold is only cleared when this equals 2
 * @return RDD of {@code (prediction, label)} tuples, or {@code null} for an unrecognised name
 */
public JavaRDD<Tuple2<Object, Object>> predictForMetrics(String modelName, T model, JavaRDD<LabeledPoint> data, int numClasses){
    final boolean binary = numClasses == 2;
    switch (modelName) {
        case "LogisticRegressionModel": {
            LogisticRegressionModel logistic = (LogisticRegressionModel) model;
            if (binary) {
                logistic.clearThreshold();
            }
            return PredictUnit.predictForMetrics_LogisticRegressionModel(logistic, data);
        }
        case "SVMModel": {
            SVMModel svm = (SVMModel) model;
            if (binary) {
                svm.clearThreshold();
            }
            return PredictUnit.predictForMetrics_SVMModel(svm, data);
        }
        case "NaiveBayesModel": {
            NaiveBayesModel bayes = (NaiveBayesModel) model;
            return PredictUnit.predictForMetrics_NaiveBayesModel(bayes, data);
        }
        default:
            // Unsupported model name: nothing to score.
            return null;
    }
}
示例2: predictForOutput
import org.apache.spark.mllib.classification.LogisticRegressionModel; //导入依赖的package包/类
/**
 * Scores {@code data} with {@code model} and pairs every point's feature vector with the
 * model's prediction, for writing results out.
 *
 * <p>For binary ({@code numClasses == 2}) logistic-regression and SVM models the given
 * {@code threshold} is applied to the model first, which mutates the passed-in {@code model}.
 *
 * @param modelName  one of {@code "LogisticRegressionModel"}, {@code "SVMModel"},
 *                   {@code "NaiveBayesModel"}
 * @param model      trained model; must be an instance of the class named by {@code modelName}
 * @param data       points to score
 * @param numClasses number of classes; the threshold is only applied when this equals 2
 * @param threshold  decision threshold for binary LR/SVM models
 * @return RDD of {@code (features, prediction)} tuples, or {@code null} for an unrecognised name
 */
public JavaRDD<Tuple2<Object, Object>> predictForOutput(String modelName, T model, JavaRDD<LabeledPoint> data, int numClasses, double threshold){
    final boolean binary = numClasses == 2;
    switch (modelName) {
        case "LogisticRegressionModel": {
            LogisticRegressionModel logistic = (LogisticRegressionModel) model;
            if (binary) {
                logistic.setThreshold(threshold);
            }
            return PredictUnit.predictForOutput_LogisticRegressionModel(logistic, data);
        }
        case "SVMModel": {
            SVMModel svm = (SVMModel) model;
            if (binary) {
                svm.setThreshold(threshold);
            }
            return PredictUnit.predictForOutput_SVMModel(svm, data);
        }
        case "NaiveBayesModel": {
            NaiveBayesModel bayes = (NaiveBayesModel) model;
            return PredictUnit.predictForOutput_NaiveBayesModel(bayes, data);
        }
        default:
            // Unsupported model name: nothing to score.
            return null;
    }
}
示例3: getModelInfo
import org.apache.spark.mllib.classification.LogisticRegressionModel; //导入依赖的package包/类
/**
 * Converts a Spark MLlib {@link LogisticRegressionModel} into a {@link LogisticRegressionModelInfo}
 * holding its weights, intercept, class/feature counts and decision threshold.
 *
 * <p>The produced info declares {@code "features"} as its input key and {@code "prediction"} and
 * {@code "probability"} as its output keys.
 *
 * @param sparkLRModel trained Spark model; must still have a decision threshold set
 * @param df           not used by this adapter
 * @return the populated model info
 * @throws IllegalArgumentException if the model's threshold has been cleared (for example via
 *                                  {@code clearThreshold()}), since there is no threshold to export
 */
@Override
public LogisticRegressionModelInfo getModelInfo(final LogisticRegressionModel sparkLRModel, DataFrame df) {
    // getThreshold() is an empty Option after clearThreshold(); calling get() on it would fail
    // with an opaque NoSuchElementException, so reject that case with a clear message instead.
    if (sparkLRModel.getThreshold().isEmpty()) {
        throw new IllegalArgumentException(
                "Cannot export LogisticRegressionModel: no threshold is set "
                        + "(was clearThreshold() called?). Call setThreshold(...) before exporting.");
    }
    final LogisticRegressionModelInfo logisticRegressionModelInfo = new LogisticRegressionModelInfo();
    logisticRegressionModelInfo.setWeights(sparkLRModel.weights().toArray());
    logisticRegressionModelInfo.setIntercept(sparkLRModel.intercept());
    logisticRegressionModelInfo.setNumClasses(sparkLRModel.numClasses());
    logisticRegressionModelInfo.setNumFeatures(sparkLRModel.numFeatures());
    logisticRegressionModelInfo.setThreshold((double) sparkLRModel.getThreshold().get());

    Set<String> inputKeys = new LinkedHashSet<String>();
    inputKeys.add("features");
    logisticRegressionModelInfo.setInputKeys(inputKeys);

    Set<String> outputKeys = new LinkedHashSet<String>();
    outputKeys.add("prediction");
    outputKeys.add("probability");
    logisticRegressionModelInfo.setOutputKeys(outputKeys);
    return logisticRegressionModelInfo;
}
开发者ID:flipkart-incubator,项目名称:spark-transformers,代码行数:21,代码来源:LogisticRegressionModelInfoAdapter.java
示例4: predictForOutput_LogisticRegressionModel
import org.apache.spark.mllib.classification.LogisticRegressionModel; //导入依赖的package包/类
/**
 * Applies {@code model} to every point of {@code data} and returns {@code (features, prediction)}
 * pairs; the point's label is ignored.
 *
 * @param model trained logistic-regression model
 * @param data  points to score
 * @return RDD of {@code (feature vector, predicted value)} tuples
 */
public static JavaRDD<Tuple2<Object, Object>> predictForOutput_LogisticRegressionModel(LogisticRegressionModel model, JavaRDD<LabeledPoint> data){
    // Spark's Java Function interface extends Serializable, so this lambda can be shipped to executors.
    return data.map(point -> {
        final Double prediction = model.predict(point.features());
        return new Tuple2<Object, Object>(point.features(), prediction);
    });
}
示例5: PredictWithModel
import org.apache.spark.mllib.classification.LogisticRegressionModel; //导入依赖的package包/类
@SuppressWarnings("unchecked")
public PredictWithModel(String modelName, String modelPath, String testFile, int numClasses, int minPartition, double threshold, SparkContext sc){
this.numClasses = numClasses;
this.threshold = threshold;
if(modelName.equals("LogisticRegressionModel")){
LogisticRegressionModel lrmodel = LogisticRegressionModel.load(sc, modelPath);
this.model = (T)(Object) lrmodel;
}
else if(modelName.equals("SVMModel")){
SVMModel svmmodel = SVMModel.load(sc, modelPath);
this.model = (T)(Object) svmmodel;
}
else if(modelName.equals("NaiveBayesModel")){
NaiveBayesModel bayesmodel = NaiveBayesModel.load(sc, modelPath);
this.model = (T)(Object) bayesmodel;
}
//Load testing data
LoadProcess loadProcess = new LoadProcess(sc, minPartition);
testingData = loadProcess.load(testFile, "Vector");
testingData.cache();
}
示例6: trainWithLBFGS
import org.apache.spark.mllib.classification.LogisticRegressionModel; //导入依赖的package包/类
/**
 * Trains a logistic-regression model with L-BFGS when {@code modelName} is
 * {@code "LogisticRegressionModel"}, prints its weights, and then evaluates the current model on
 * {@code validData}.
 *
 * <p>For any other model name no training happens. Whatever is already stored in {@code model}
 * (possibly nothing) is evaluated and returned.
 *
 * @return the (possibly newly trained) model
 */
@SuppressWarnings("unchecked")
public T trainWithLBFGS(){
    if (modelName.equals("LogisticRegressionModel")) {
        final LogisticRegressionWithLBFGS trainer = new LogisticRegressionWithLBFGS();
        trainer.setNumClasses(numClasses);
        final LogisticRegressionModel trained = trainer.run(trainingData.rdd());

        System.out.println("\n--------------------------------------\n weights: " + trained.weights());
        System.out.println("--------------------------------------\n");
        this.model = (T) (Object) trained;
    }

    // Evaluate on the validation split ("evalute" is the project's method name).
    new EvaluateProcess<T>(model, modelName, validData, numClasses).evalute(numClasses);
    return model;
}
示例7: trainWithSGD
import org.apache.spark.mllib.classification.LogisticRegressionModel; //导入依赖的package包/类
/**
 * Trains an SVM or logistic-regression model with stochastic gradient descent, chosen by
 * {@code modelName}, then evaluates the current model on {@code validData}.
 *
 * <p>{@code numClasses} is not passed to the SGD trainers; it is only used for evaluation.
 * An unrecognised {@code modelName} skips training and evaluates whatever model is already stored.
 *
 * @param numIterations number of SGD iterations
 * @return the (possibly newly trained) model
 */
@SuppressWarnings("unchecked")
public T trainWithSGD(int numIterations){
    switch (modelName) {
        case "SVMModel":
            this.model = (T) (Object) SVMWithSGD.train(trainingData.rdd(), numIterations);
            break;
        case "LogisticRegressionModel":
            this.model = (T) (Object) LogisticRegressionWithSGD.train(trainingData.rdd(), numIterations);
            break;
        default:
            break;
    }

    // Evaluate on the validation split ("evalute" is the project's method name).
    new EvaluateProcess<T>(model, modelName, validData, numClasses).evalute(numClasses);
    return model;
}
示例8: generateDecisionTreeWithPreprocessing
import org.apache.spark.mllib.classification.LogisticRegressionModel; //导入依赖的package包/类
/**
 * Preprocesses the raw Mongo records into labeled points and trains a model on them.
 *
 * <p>Despite the method name, the result is a {@link LogisticRegressionModel}, not a decision
 * tree. The name is kept for caller compatibility. Training is delegated to
 * {@code generateKMeansModel}, which returns that logistic-regression model.
 *
 * @param mongoRDD                             raw records keyed by Mongo id
 * @param athenaMLFeatureConfiguration         feature-extraction configuration
 * @param logisticRegressionDetectionAlgorithm algorithm settings used for training
 * @param marking                              marking passed through to preprocessing
 * @param logisticRegressionModelSummary       summary object passed to both steps
 * @return the trained model
 */
public LogisticRegressionModel generateDecisionTreeWithPreprocessing(JavaPairRDD<Object, BSONObject> mongoRDD,
                                                                     AthenaMLFeatureConfiguration athenaMLFeatureConfiguration,
                                                                     LogisticRegressionDetectionAlgorithm logisticRegressionDetectionAlgorithm,
                                                                     Marking marking,
                                                                     LogisticRegressionModelSummary logisticRegressionModelSummary) {
    final JavaRDD<LabeledPoint> labeledPoints = rddPreProcessing(
            mongoRDD, athenaMLFeatureConfiguration, logisticRegressionModelSummary, marking);
    return generateKMeansModel(labeledPoints, logisticRegressionDetectionAlgorithm, logisticRegressionModelSummary);
}
示例9: OnlineFeatureHandler
import org.apache.spark.mllib.classification.LogisticRegressionModel; //导入依赖的package包/类
/**
 * Creates an online feature handler around {@code detectionModel}.
 *
 * <p>The Spark model wrapped by {@code detectionModel} is unwrapped into the typed field that
 * matches the {@code DetectionModel} subclass. An unsupported subclass only prints
 * "Not supported model". The handler then creates an event delivery manager on
 * {@code controllerConnector} and registers an online feature query for {@code featureConstraint}.
 *
 * @param featureConstraint     constraint the registered online query is filtered by
 * @param detectionModel        wrapper around the trained Spark model to use
 * @param onlineMLEventListener listener stored on this handler
 * @param controllerConnector   connection used by the event delivery manager
 */
public OnlineFeatureHandler(FeatureConstraint featureConstraint,
                            DetectionModel detectionModel,
                            onlineMLEventListener onlineMLEventListener,
                            ControllerConnector controllerConnector) {
    this.featureConstraint = featureConstraint;
    this.detectionModel = detectionModel;
    setAthenaMLFeatureConfiguration(detectionModel.getAthenaMLFeatureConfiguration());

    // Unwrap the concrete Spark model into the field matching the wrapper type.
    if (detectionModel instanceof KMeansDetectionModel) {
        this.kMeansModel = (KMeansModel) detectionModel.getDetectionModel();
    } else if (detectionModel instanceof GaussianMixtureDetectionModel) {
        this.gaussianMixtureModel = (GaussianMixtureModel) detectionModel.getDetectionModel();
    } else if (detectionModel instanceof DecisionTreeDetectionModel) {
        this.decisionTreeModel = (DecisionTreeModel) detectionModel.getDetectionModel();
    } else if (detectionModel instanceof NaiveBayesDetectionModel) {
        this.naiveBayesModel = (NaiveBayesModel) detectionModel.getDetectionModel();
    } else if (detectionModel instanceof RandomForestDetectionModel) {
        this.randomForestModel = (RandomForestModel) detectionModel.getDetectionModel();
    } else if (detectionModel instanceof GradientBoostedTreesDetectionModel) {
        this.gradientBoostedTreesModel = (GradientBoostedTreesModel) detectionModel.getDetectionModel();
    } else if (detectionModel instanceof SVMDetectionModel) {
        this.svmModel = (SVMModel) detectionModel.getDetectionModel();
    } else if (detectionModel instanceof LogisticRegressionDetectionModel) {
        this.logisticRegressionModel = (LogisticRegressionModel) detectionModel.getDetectionModel();
    } else if (detectionModel instanceof LinearRegressionDetectionModel) {
        this.linearRegressionModel = (LinearRegressionModel) detectionModel.getDetectionModel();
    } else if (detectionModel instanceof LassoDetectionModel) {
        this.lassoModel = (LassoModel) detectionModel.getDetectionModel();
    } else if (detectionModel instanceof RidgeRegressionDetectionModel) {
        this.ridgeRegressionModel = (RidgeRegressionModel) detectionModel.getDetectionModel();
    } else {
        //not supported ML model
        System.out.println("Not supported model");
    }

    // Assign every field before handing a new InternalAthenaFeatureEventListener to the delivery
    // manager and registering the query: once registered, events may be delivered before this
    // constructor returns, and the listener must not observe a null onlineMLEventListener.
    this.onlineMLEventListener = onlineMLEventListener;
    this.eventDeliveryManager = new EventDeliveryManagerImpl(controllerConnector, new InternalAthenaFeatureEventListener());
    this.eventDeliveryManager.registerOnlineAthenaFeature(null, new QueryIdentifier(QUERY_IDENTIFIER), featureConstraint);
    System.out.println("Install handler!");
}
示例10: getModelInfo
import org.apache.spark.mllib.classification.LogisticRegressionModel; //导入依赖的package包/类
/**
 * Converts a Spark MLlib {@link LogisticRegressionModel} into a {@link LogisticRegressionModelInfo}
 * holding its weights, intercept, class/feature counts and decision threshold.
 *
 * <p>The produced info declares {@code "features"} as its input key and {@code "prediction"} and
 * {@code "probability"} as its output keys.
 *
 * @param sparkLRModel trained Spark model; must still have a decision threshold set
 * @return the populated model info
 * @throws IllegalArgumentException if the model's threshold has been cleared (for example via
 *                                  {@code clearThreshold()}), since there is no threshold to export
 */
@Override
public LogisticRegressionModelInfo getModelInfo(final LogisticRegressionModel sparkLRModel) {
    // getThreshold() is an empty Option after clearThreshold(); calling get() on it would fail
    // with an opaque NoSuchElementException, so reject that case with a clear message instead.
    if (sparkLRModel.getThreshold().isEmpty()) {
        throw new IllegalArgumentException(
                "Cannot export LogisticRegressionModel: no threshold is set "
                        + "(was clearThreshold() called?). Call setThreshold(...) before exporting.");
    }
    final LogisticRegressionModelInfo logisticRegressionModelInfo = new LogisticRegressionModelInfo();
    logisticRegressionModelInfo.setWeights(sparkLRModel.weights().toArray());
    logisticRegressionModelInfo.setIntercept(sparkLRModel.intercept());
    logisticRegressionModelInfo.setNumClasses(sparkLRModel.numClasses());
    logisticRegressionModelInfo.setNumFeatures(sparkLRModel.numFeatures());
    logisticRegressionModelInfo.setThreshold((double) sparkLRModel.getThreshold().get());

    Set<String> inputKeys = new LinkedHashSet<String>();
    inputKeys.add("features");
    logisticRegressionModelInfo.setInputKeys(inputKeys);

    Set<String> outputKeys = new LinkedHashSet<String>();
    outputKeys.add("prediction");
    outputKeys.add("probability");
    logisticRegressionModelInfo.setOutputKeys(outputKeys);
    return logisticRegressionModelInfo;
}
开发者ID:flipkart-incubator,项目名称:spark-transformers,代码行数:21,代码来源:LogisticRegressionModelInfoAdapter.java
示例11: shouldExportAndImportCorrectly
import org.apache.spark.mllib.classification.LogisticRegressionModel; //导入依赖的package包/类
/**
 * Trains a logistic-regression model, exports it, imports it back and checks that intercept,
 * class/feature counts, threshold and every weight survive the round trip (within 0.01).
 */
@Test
public void shouldExportAndImportCorrectly() {
    String datapath = "src/test/resources/binary_classification_test.libsvm";
    JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(jsc.sc(), datapath).toJavaRDD();
    //Train model in spark
    LogisticRegressionModel lrmodel = new LogisticRegressionWithSGD().run(data.rdd());
    //Export this model
    byte[] exportedModel = ModelExporter.export(lrmodel);
    //Import it back
    LogisticRegressionModelInfo importedModel = (LogisticRegressionModelInfo) ModelImporter.importModelInfo(exportedModel);

    // Compare field by field, with a small tolerance for floating-point values.
    assertEquals(lrmodel.intercept(), importedModel.getIntercept(), 0.01);
    assertEquals(lrmodel.numClasses(), importedModel.getNumClasses(), 0.01);
    assertEquals(lrmodel.numFeatures(), importedModel.getNumFeatures(), 0.01);
    assertEquals((double) lrmodel.getThreshold().get(), importedModel.getThreshold(), 0.01);

    // Copy the weights once (toArray() allocates) and check length explicitly so a truncated or
    // padded import cannot pass unnoticed.
    final double[] expectedWeights = lrmodel.weights().toArray();
    final double[] actualWeights = importedModel.getWeights();
    assertEquals(expectedWeights.length, actualWeights.length);
    for (int i = 0; i < expectedWeights.length; i++) {
        assertEquals(expectedWeights[i], actualWeights[i], 0.01);
    }
}
示例12: shouldExportAndImportCorrectly
import org.apache.spark.mllib.classification.LogisticRegressionModel; //导入依赖的package包/类
/**
 * Trains a logistic-regression model, exports it, imports it back and checks that intercept,
 * class/feature counts, threshold and every weight survive the round trip (within EPSILON).
 */
@Test
public void shouldExportAndImportCorrectly() {
    String datapath = "src/test/resources/binary_classification_test.libsvm";
    JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(sc.sc(), datapath).toJavaRDD();
    //Train model in spark
    LogisticRegressionModel lrmodel = new LogisticRegressionWithSGD().run(data.rdd());
    //Export this model
    byte[] exportedModel = ModelExporter.export(lrmodel, null);
    //Import it back
    LogisticRegressionModelInfo importedModel = (LogisticRegressionModelInfo) ModelImporter.importModelInfo(exportedModel);

    // Compare field by field, with a small tolerance for floating-point values.
    assertEquals(lrmodel.intercept(), importedModel.getIntercept(), EPSILON);
    assertEquals(lrmodel.numClasses(), importedModel.getNumClasses(), EPSILON);
    assertEquals(lrmodel.numFeatures(), importedModel.getNumFeatures(), EPSILON);
    assertEquals((double) lrmodel.getThreshold().get(), importedModel.getThreshold(), EPSILON);

    // Copy the weights once (toArray() allocates) and check length explicitly so a truncated or
    // padded import cannot pass unnoticed.
    final double[] expectedWeights = lrmodel.weights().toArray();
    final double[] actualWeights = importedModel.getWeights();
    assertEquals(expectedWeights.length, actualWeights.length);
    for (int i = 0; i < expectedWeights.length; i++) {
        assertEquals(expectedWeights[i], actualWeights[i], EPSILON);
    }
}
示例13: predictForMetrics_LogisticRegressionModel
import org.apache.spark.mllib.classification.LogisticRegressionModel; //导入依赖的package包/类
/**
 * Applies {@code model} to every point of {@code data} and returns {@code (prediction, label)}
 * pairs suitable for Spark's evaluation-metrics classes.
 *
 * @param model trained logistic-regression model
 * @param data  labeled points to score
 * @return RDD of {@code (predicted value, true label)} tuples
 */
public static JavaRDD<Tuple2<Object, Object>> predictForMetrics_LogisticRegressionModel(LogisticRegressionModel model, JavaRDD<LabeledPoint> data){
    // Spark's Java Function interface extends Serializable, so this lambda can be shipped to executors.
    return data.map(point -> {
        final Double prediction = model.predict(point.features());
        return new Tuple2<Object, Object>(prediction, point.label());
    });
}
示例14: generateKMeansModel
import org.apache.spark.mllib.classification.LogisticRegressionModel; //导入依赖的package包/类
/**
 * Trains a logistic-regression model with L-BFGS on {@code parsedData} and records the algorithm
 * settings in the summary.
 *
 * <p>Despite the method name, no k-means clustering happens here. The name is kept for caller
 * compatibility. The summary is only updated after training returns.
 *
 * @param parsedData                           labeled training points
 * @param logisticRegressionDetectionAlgorithm supplies the number of classes
 * @param logisticRegressionModelSummary       receives the algorithm settings
 * @return the trained model
 */
public LogisticRegressionModel generateKMeansModel(JavaRDD<LabeledPoint> parsedData,
                                                   LogisticRegressionDetectionAlgorithm logisticRegressionDetectionAlgorithm,
                                                   LogisticRegressionModelSummary logisticRegressionModelSummary) {
    final LogisticRegressionWithLBFGS lbfgs = new LogisticRegressionWithLBFGS();
    lbfgs.setNumClasses(logisticRegressionDetectionAlgorithm.getNumClasses());
    final LogisticRegressionModel trained = lbfgs.run(parsedData.rdd());

    logisticRegressionModelSummary.setLogisticRegressionDetectionAlgorithm(logisticRegressionDetectionAlgorithm);
    return trained;
}
示例15: testLogisticRegression
import org.apache.spark.mllib.classification.LogisticRegressionModel; //导入依赖的package包/类
/**
 * Round-trips a Spark-trained logistic-regression model through the exporter/importer and checks
 * that the imported transformer predicts the same value as Spark for every training point.
 */
@Test
public void testLogisticRegression() {
    // The same LIBSVM fixture serves as training set and prediction set.
    final String datapath = "src/test/resources/binary_classification_test.libsvm";
    final JavaRDD<LabeledPoint> trainingData = MLUtils.loadLibSVMFile(jsc.sc(), datapath).toJavaRDD();

    final LogisticRegressionModel lrmodel = new LogisticRegressionWithSGD().run(trainingData.rdd());
    final Transformer transformer = ModelImporter.importAndGetTransformer(ModelExporter.export(lrmodel));

    for (LabeledPoint point : trainingData.collect()) {
        final Vector features = point.features();
        final double expected = lrmodel.predict(features);

        // The transformer reads "features" and writes "prediction" into the same map.
        final Map<String, Object> row = new HashMap<String, Object>();
        row.put("features", features.toArray());
        transformer.transform(row);

        assertEquals(expected, (double) row.get("prediction"), 0.01);
    }
}