當前位置: 首頁>>代碼示例>>Java>>正文


Java LassoModel類代碼示例

本文整理匯總了Java中org.apache.spark.mllib.regression.LassoModel的典型用法代碼示例。如果您正苦於以下問題:Java LassoModel類的具體用法?Java LassoModel怎麼用?Java LassoModel使用的例子?那麼, 這裡精選的類代碼示例或許可以為您提供幫助。


LassoModel類屬於org.apache.spark.mllib.regression包,在下文中一共展示了LassoModel類的13個代碼示例,這些例子默認根據受歡迎程度排序。您可以為喜歡或者感覺有用的代碼點讚,您的評價將有助於系統推薦出更棒的Java代碼示例。

示例1: generateKMeansModel

import org.apache.spark.mllib.regression.LassoModel; //導入依賴的package包/類
/**
 * Trains a Lasso regression model with SGD on the given labeled points.
 *
 * <p>The {@code LassoWithSGD.train} overload is chosen from the algorithm
 * configuration, where a value of {@code -1} is treated as "not set":
 * <ul>
 *   <li>mini-batch fraction set: iterations, step size, regularization
 *       parameter and mini-batch fraction are all passed;</li>
 *   <li>otherwise, regularization parameter set: iterations, step size and
 *       regularization parameter are passed;</li>
 *   <li>otherwise only the number of iterations is passed.</li>
 * </ul>
 * Despite its name, this method produces a Lasso model, not a k-means model.
 *
 * @param parsedData              training data
 * @param lassoDetectionAlgorithm hyper-parameter source; also recorded on the summary
 * @param lassoModelSummary       summary that receives the algorithm configuration after training
 * @return the trained model
 */
public LassoModel generateKMeansModel(JavaRDD<LabeledPoint> parsedData,
                                                 LassoDetectionAlgorithm lassoDetectionAlgorithm,
                                                 LassoModelSummary lassoModelSummary) {
    final LassoModel trained;

    if (lassoDetectionAlgorithm.getMiniBatchFraction() == -1) {
        if (lassoDetectionAlgorithm.getRegParam() == -1) {
            // Neither optional hyper-parameter is set: iterations only.
            trained = LassoWithSGD.train(parsedData.rdd(),
                    lassoDetectionAlgorithm.getNumIterations());
        } else {
            // Regularization parameter is set, mini-batch fraction is not.
            trained = LassoWithSGD.train(parsedData.rdd(),
                    lassoDetectionAlgorithm.getNumIterations(),
                    lassoDetectionAlgorithm.getStepSize(),
                    lassoDetectionAlgorithm.getRegParam());
        }
    } else {
        // Mini-batch fraction is set: use the fully parameterized overload.
        trained = LassoWithSGD.train(parsedData.rdd(),
                lassoDetectionAlgorithm.getNumIterations(),
                lassoDetectionAlgorithm.getStepSize(),
                lassoDetectionAlgorithm.getRegParam(),
                lassoDetectionAlgorithm.getMiniBatchFraction());
    }

    lassoModelSummary.setLassoDetectionAlgorithm(lassoDetectionAlgorithm);
    return trained;
}
 
開發者ID:shlee89,項目名稱:athena,代碼行數:27,代碼來源:LassoDistJob.java

示例2: OnlineFeatureHandler

import org.apache.spark.mllib.regression.LassoModel; //導入依賴的package包/類
/**
 * Creates a handler that keeps the given detection model and registers an
 * online feature listener with a newly created event delivery manager.
 *
 * <p>The concrete Spark model is extracted from {@code detectionModel} and
 * stored in the field matching its wrapper type. An unrecognised wrapper type
 * only prints a message; construction still continues.
 *
 * @param featureConstraint     constraint passed to the online feature registration
 * @param detectionModel        wrapper holding the trained model and its feature configuration
 * @param onlineMLEventListener listener stored on this handler
 * @param controllerConnector   connector handed to the event delivery manager
 */
public OnlineFeatureHandler(FeatureConstraint featureConstraint,
                            DetectionModel detectionModel,
                            onlineMLEventListener onlineMLEventListener,
                            ControllerConnector controllerConnector) {
    this.featureConstraint = featureConstraint;
    this.detectionModel = detectionModel;
    setAthenaMLFeatureConfiguration(detectionModel.getAthenaMLFeatureConfiguration());

    // Exactly one model field is populated, selected by the wrapper's runtime type.
    if (detectionModel instanceof KMeansDetectionModel) {
        this.kMeansModel = (KMeansModel) detectionModel.getDetectionModel();
    } else if (detectionModel instanceof GaussianMixtureDetectionModel) {
        this.gaussianMixtureModel = (GaussianMixtureModel) detectionModel.getDetectionModel();
    } else if (detectionModel instanceof DecisionTreeDetectionModel) {
        this.decisionTreeModel = (DecisionTreeModel) detectionModel.getDetectionModel();
    } else if (detectionModel instanceof NaiveBayesDetectionModel) {
        this.naiveBayesModel = (NaiveBayesModel) detectionModel.getDetectionModel();
    } else if (detectionModel instanceof RandomForestDetectionModel) {
        this.randomForestModel = (RandomForestModel) detectionModel.getDetectionModel();
    } else if (detectionModel instanceof GradientBoostedTreesDetectionModel) {
        this.gradientBoostedTreesModel = (GradientBoostedTreesModel) detectionModel.getDetectionModel();
    } else if (detectionModel instanceof SVMDetectionModel) {
        this.svmModel = (SVMModel) detectionModel.getDetectionModel();
    } else if (detectionModel instanceof LogisticRegressionDetectionModel) {
        this.logisticRegressionModel = (LogisticRegressionModel) detectionModel.getDetectionModel();
    } else if (detectionModel instanceof LinearRegressionDetectionModel) {
        this.linearRegressionModel = (LinearRegressionModel) detectionModel.getDetectionModel();
    } else if (detectionModel instanceof LassoDetectionModel) {
        this.lassoModel = (LassoModel) detectionModel.getDetectionModel();
    } else if (detectionModel instanceof RidgeRegressionDetectionModel) {
        this.ridgeRegressionModel = (RidgeRegressionModel) detectionModel.getDetectionModel();
    } else {
        // Unsupported model type: report it and leave every model field unset.
        System.out.println("Not supported model");
    }

    // NOTE(review): the feature listener is registered before onlineMLEventListener
    // is assigned below; if events can arrive immediately after registration the
    // listener field may still be unset at that point — confirm.
    this.eventDeliveryManager = new EventDeliveryManagerImpl(controllerConnector, new InternalAthenaFeatureEventListener());
    this.eventDeliveryManager.registerOnlineAthenaFeature(null, new QueryIdentifier(QUERY_IDENTIFIER), featureConstraint);
    this.onlineMLEventListener = onlineMLEventListener;
    System.out.println("Install handler!");
}
 
開發者ID:shlee89,項目名稱:athena,代碼行數:41,代碼來源:OnlineFeatureHandler.java

示例3: generateDecisionTreeWithPreprocessing

import org.apache.spark.mllib.regression.LassoModel; //導入依賴的package包/類
/**
 * Runs {@code rddPreProcessing} on the MongoDB-backed RDD and trains a Lasso
 * model on its output via {@link #generateKMeansModel}.
 *
 * <p>Despite its name, this method produces a Lasso model, not a decision tree.
 *
 * @param mongoRDD                     raw key/document pairs
 * @param athenaMLFeatureConfiguration feature configuration forwarded to pre-processing
 * @param lassoDetectionAlgorithm      hyper-parameters forwarded to training
 * @param marking                      marking forwarded to pre-processing
 * @param lassoModelSummary            summary forwarded to both pre-processing and training
 * @return the trained model
 */
public LassoModel generateDecisionTreeWithPreprocessing(JavaPairRDD<Object, BSONObject> mongoRDD,
                                                                   AthenaMLFeatureConfiguration athenaMLFeatureConfiguration,
                                                                   LassoDetectionAlgorithm lassoDetectionAlgorithm,
                                                                   Marking marking,
                                                                   LassoModelSummary lassoModelSummary) {
    // Step 1: turn the raw documents into training points.
    JavaRDD<LabeledPoint> preparedPoints = rddPreProcessing(
            mongoRDD, athenaMLFeatureConfiguration, lassoModelSummary, marking);

    // Step 2: train on the prepared points.
    return generateKMeansModel(preparedPoints, lassoDetectionAlgorithm, lassoModelSummary);
}
 
開發者ID:shlee89,項目名稱:athena,代碼行數:13,代碼來源:LassoDistJob.java

示例4: AgePredictModel

import org.apache.spark.mllib.regression.LassoModel; //導入依賴的package包/類
/**
 * Bundles a trained Lasso model with the data stored alongside it.
 *
 * <p>All arguments are stored by reference; the vocabulary array is not copied.
 *
 * @param languageCode    language code stored on this model
 * @param agePredictModel the Lasso regression model
 * @param vocabulary      vocabulary array stored on this model
 * @param wrapper         context generator wrapper stored on this model
 */
public AgePredictModel(String languageCode, LassoModel agePredictModel, String[] vocabulary,
        AgeClassifyContextGeneratorWrapper wrapper) {
    this.wrapper = wrapper;
    this.vocabulary = vocabulary;
    this.model = agePredictModel;
    this.languageCode = languageCode;
}
 
開發者ID:USCDataScience,項目名稱:AgePredictor,代碼行數:9,代碼來源:AgePredictModel.java

示例5: test

import org.apache.spark.mllib.regression.LassoModel; //導入依賴的package包/類
/**
 * Scores every point of a dataset with the given lasso regression model.
 *
 * @param lassoModel        Lasso regression model used for prediction
 * @param testingDataset    Testing dataset as a JavaRDD of LabeledPoints
 * @return                  RDD of (predicted value, actual label) pairs, one per input point
 */
public JavaRDD<Tuple2<Double, Double>> test(final LassoModel lassoModel,
        JavaRDD<LabeledPoint> testingDataset) {
    // Pairs the model's prediction for a point with that point's own label.
    Function<LabeledPoint, Tuple2<Double, Double>> scorePoint =
            new Function<LabeledPoint, Tuple2<Double, Double>>() {
                private static final long serialVersionUID = -156144873494491437L;

                public Tuple2<Double, Double> call(LabeledPoint point) {
                    double prediction = lassoModel.predict(point.features());
                    double actual = point.label();
                    return new Tuple2<Double, Double>(prediction, actual);
                }
            };

    return testingDataset.map(scorePoint);
}
 
開發者ID:wso2-attic,項目名稱:carbon-ml,代碼行數:21,代碼來源:LassoRegression.java

示例6: setModel

import org.apache.spark.mllib.regression.LassoModel; //導入依賴的package包/類
/**
 * Replaces the Lasso model held by this object.
 *
 * @param model the model to store
 */
public void setModel(LassoModel model) {
    this.model = model;
}
 
開發者ID:shlee89,項目名稱:athena,代碼行數:4,代碼來源:LassoDetectionModel.java

示例7: validate

import org.apache.spark.mllib.regression.LassoModel; //導入依賴的package包/類
/**
 * Validates a trained Lasso model against MongoDB-backed records and stores the
 * mean squared error on the validation summary.
 *
 * <p>For each record a feature vector is built from the configured target
 * features (optionally weighted, made absolute and normalized), the model's
 * prediction is computed, and it is paired with the label obtained from the
 * model's {@code Marking}. The MSE is the mean of the squared differences
 * between prediction and label.
 *
 * @param mongoRDD                     records to validate against
 * @param athenaMLFeatureConfiguration supplies target features, weights and the
 *                                     absolute/normalization switches
 * @param lassoDetectionModel          supplies the model, the marking and the algorithm
 * @param lassoValidationSummary       receives the entry count, the MSE and the algorithm
 */
public void validate(JavaPairRDD<Object, BSONObject> mongoRDD,
                         AthenaMLFeatureConfiguration athenaMLFeatureConfiguration,
                         LassoDetectionModel lassoDetectionModel,
                         LassoValidationSummary lassoValidationSummary) {
        List<AthenaFeatureField> listOfTargetFeatures = athenaMLFeatureConfiguration.getListOfTargetFeatures();
        Map<AthenaFeatureField, Integer> weight = athenaMLFeatureConfiguration.getWeight();
        Marking marking = lassoDetectionModel.getMarking();
        LassoModel model = (LassoModel) lassoDetectionModel.getDetectionModel();
        Normalizer normalizer = new Normalizer();

        int numberOfTargetValue = listOfTargetFeatures.size();

        // Map every record to a (prediction, label) pair.
        JavaRDD<Tuple2<Double, Double>> valuesAndPreds = mongoRDD.map(
                (Function<Tuple2<Object, BSONObject>, Tuple2<Double, Double>>) t -> {

                    // The feature sub-document is read from the record under the FEATURE key;
                    // the whole record is passed to the marking together with it.
                    BSONObject feature = (BSONObject) t._2().get(AthenaFeatureField.FEATURE);
                    BSONObject idx = (BSONObject) t._2();
                    int originLabel = marking.checkClassificationMarkingElements(idx, feature);

                    // One slot per target feature; a feature missing from the record stays 0.
                    double[] values = new double[numberOfTargetValue];
                    for (int j = 0; j < numberOfTargetValue; j++) {
                        values[j] = 0;
                        if (feature.containsField(listOfTargetFeatures.get(j).getValue())) {
                            Object obj = feature.get(listOfTargetFeatures.get(j).getValue());
                            // Only Long, Double and Boolean values are converted; any other
                            // type is reported and the slot keeps its 0 value.
                            if (obj instanceof Long) {
                                values[j] = (Long) obj;
                            } else if (obj instanceof Double) {
                                values[j] = (Double) obj;
                            } else if (obj instanceof Boolean) {
                                values[j] = (Boolean) obj ? 1 : 0;
                            } else {
                                System.out.println("not supported");
//                                return;
                            }

                            //check weight
                            if (weight.containsKey(listOfTargetFeatures.get(j))) {
                                values[j] *= weight.get(listOfTargetFeatures.get(j));
                            }

                            //check absolute
                            if (athenaMLFeatureConfiguration.isAbsolute()) {
                                values[j] = Math.abs(values[j]);
                            }
                        }
                    }

                    // Normalize the vector only when the configuration asks for it.
                    Vector normedForVal;
                    if (athenaMLFeatureConfiguration.isNormalization()) {
                        normedForVal = normalizer.transform(Vectors.dense(values));
                    } else {
                        normedForVal = Vectors.dense(values);
                    }

                    LabeledPoint p = new LabeledPoint(originLabel, normedForVal);
                    //Only SVM!!

                    double prediction = model.predict(p.features());


                    // NOTE(review): addEntry() is invoked from inside the RDD map function;
                    // presumably this only reaches the caller's summary if it is backed by
                    // something like a Spark accumulator — confirm.
                    lassoValidationSummary.addEntry();
                    return new Tuple2<Double, Double>(prediction, p.label());
                });

        // Mean of (prediction - label)^2 over all pairs.
        double MSE = new JavaDoubleRDD(valuesAndPreds.map(
                new Function<Tuple2<Double, Double>, Object>() {
                    public Object call(Tuple2<Double, Double> pair) {
                        return Math.pow(pair._1() - pair._2(), 2.0);
                    }
                }
        ).rdd()).mean();
        lassoValidationSummary.setMSE(MSE);
        lassoValidationSummary.setLassoDetectionAlgorithm((LassoDetectionAlgorithm) lassoDetectionModel.getDetectionAlgorithm());
    }
 
開發者ID:shlee89,項目名稱:athena,代碼行數:75,代碼來源:LassoDistJob.java

示例8: generateLassoAthenaDetectionModel

import org.apache.spark.mllib.regression.LassoModel; //導入依賴的package包/類
/**
 * Builds a fully populated {@code LassoDetectionModel}: loads records through
 * {@code MongoInputFormat}, trains a Lasso model on them with
 * {@code LassoDistJob}, and attaches a summary that includes the elapsed time.
 *
 * <p>The recorded learning time is measured with {@code System.nanoTime()} and
 * spans model configuration, RDD creation and training.
 *
 * @param sc                           Spark context used to create the input RDD
 * @param featureConstraint            constraint stored on the resulting model
 * @param athenaMLFeatureConfiguration feature configuration stored on the model and used for training
 * @param detectionAlgorithm           must be a {@code LassoDetectionAlgorithm}; it is cast to that type
 * @param indexing                     indexing stored on the model and the summary
 * @param marking                      marking stored on the model and the summary, and used for training
 * @return the detection model holding the trained Lasso model and its summary
 */
public LassoDetectionModel generateLassoAthenaDetectionModel(JavaSparkContext sc,
                                                             FeatureConstraint featureConstraint,
                                                             AthenaMLFeatureConfiguration athenaMLFeatureConfiguration,
                                                             DetectionAlgorithm detectionAlgorithm,
                                                             Indexing indexing,
                                                             Marking marking) {
    LassoModelSummary summary = new LassoModelSummary(sc.sc(), indexing, marking);

    // Start of the measured learning interval.
    final long startedAt = System.nanoTime();

    LassoDetectionAlgorithm lassoAlgorithm = (LassoDetectionAlgorithm) detectionAlgorithm;

    // Record the inputs on the result object (and the algorithm on the summary).
    LassoDetectionModel result = new LassoDetectionModel();
    result.setLassoDetectionAlgorithm(lassoAlgorithm);
    summary.setLassoDetectionAlgorithm(lassoAlgorithm);
    result.setFeatureConstraint(featureConstraint);
    result.setAthenaMLFeatureConfiguration(athenaMLFeatureConfiguration);
    result.setIndexing(indexing);
    result.setMarking(marking);

    // Input records: Object keys and BSON documents read through MongoInputFormat.
    JavaPairRDD<Object, BSONObject> mongoRDD = sc.newAPIHadoopRDD(
            mongodbConfig,
            MongoInputFormat.class,
            Object.class,
            BSONObject.class);

    LassoDistJob trainingJob = new LassoDistJob();
    LassoModel trainedModel = trainingJob.generateDecisionTreeWithPreprocessing(
            mongoRDD, athenaMLFeatureConfiguration, lassoAlgorithm, marking, summary);
    result.setModel(trainedModel);

    // End of the measured learning interval.
    summary.setTotalLearningTime(System.nanoTime() - startedAt);
    result.setClassificationModelSummary(summary);

    return result;
}
 
開發者ID:shlee89,項目名稱:athena,代碼行數:45,代碼來源:MachineLearningManagerImpl.java

示例9: detection

import org.apache.spark.mllib.regression.LassoModel; //導入依賴的package包/類
/**
 * Predicts a value for the given feature vector using the model that matches
 * the runtime type of {@code detectionModel}.
 *
 * <p>Every supported model type delegates to the corresponding model field's
 * {@code predict(value)}. The Lasso branch previously re-assigned
 * {@code lassoModel} and fell through to {@code return 0}, so Lasso models
 * never produced a prediction; it now predicts like its sibling branches.
 *
 * @param value feature vector to score
 * @return the selected model's prediction, or {@code 0} when the model type is
 *         not supported
 */
public double detection(Vector value) {

        if (detectionModel instanceof KMeansDetectionModel) {
            return kMeansModel.predict(value);

        } else if (detectionModel instanceof GaussianMixtureDetectionModel) {
            return gaussianMixtureModel.predict(value);

        } else if (detectionModel instanceof DecisionTreeDetectionModel) {
            return decisionTreeModel.predict(value);

        } else if (detectionModel instanceof NaiveBayesDetectionModel) {
            return naiveBayesModel.predict(value);

        } else if (detectionModel instanceof RandomForestDetectionModel) {
            return randomForestModel.predict(value);

        } else if (detectionModel instanceof GradientBoostedTreesDetectionModel) {
            return gradientBoostedTreesModel.predict(value);

        } else if (detectionModel instanceof SVMDetectionModel) {
            return svmModel.predict(value);

        } else if (detectionModel instanceof LogisticRegressionDetectionModel) {
            return logisticRegressionModel.predict(value);

        } else if (detectionModel instanceof LinearRegressionDetectionModel) {
            return linearRegressionModel.predict(value);

        } else if (detectionModel instanceof LassoDetectionModel) {
            // Fixed: this branch used to assign lassoModel and then return 0.
            return lassoModel.predict(value);

        } else if (detectionModel instanceof RidgeRegressionDetectionModel) {
            return ridgeRegressionModel.predict(value);

        } else {
            //not supported ML model
            System.out.println("Not supported model");
            return 0;
        }
        // No trailing return: every branch above returns, so a statement here
        // would be unreachable and rejected by the compiler.
    }
 
開發者ID:shlee89,項目名稱:athena,代碼行數:46,代碼來源:OnlineFeatureHandler.java

示例10: getModel

import org.apache.spark.mllib.regression.LassoModel; //導入依賴的package包/類
/**
 * Returns the Lasso model held by this object.
 *
 * @return the stored model reference
 */
public LassoModel getModel() {
return this.model;
   }
 
開發者ID:USCDataScience,項目名稱:AgePredictor,代碼行數:4,代碼來源:AgePredictModel.java

示例11: buildLassoRegressionModel

import org.apache.spark.mllib.regression.LassoModel; //導入依賴的package包/類
/**
 * Builds a lasso regression model, evaluates it on the testing data and
 * returns a summary of the result.
 *
 * <p>The trained model is also stored on {@code mlModel}, wrapped in an
 * {@code MLGeneralizedLinearModel}.
 *
 * @param sparkContext JavaSparkContext initialized with the application
 * @param modelID Model ID (not referenced in this method's body)
 * @param trainingData Training data as a JavaRDD of LabeledPoints; unpersisted after training
 * @param testingData Testing data as a JavaRDD of LabeledPoints; cached during evaluation and
 *                    unpersisted afterwards
 * @param workflow Machine learning workflow supplying the hyper-parameters and dataset version
 * @param mlModel Deployable machine learning model that receives the trained model
 * @param includedFeatures Features included in the model; their values become the summary's
 *                         feature names and are passed to the feature-weight computation
 * @return Summary holding the feature names, algorithm name, feature importance, mean squared
 *         error and dataset version
 * @throws MLModelBuilderException if the model weights are invalid or any other exception
 *                                 occurs while building or evaluating the model
 */
private ModelSummary buildLassoRegressionModel(JavaSparkContext sparkContext, long modelID,
                                               JavaRDD<LabeledPoint> trainingData, JavaRDD<LabeledPoint> testingData, Workflow workflow, MLModel mlModel,
                                               SortedMap<Integer, String> includedFeatures) throws MLModelBuilderException {
    try {
        LassoRegression lassoRegression = new LassoRegression();
        // Hyper-parameters arrive as strings and are parsed here; a missing or malformed
        // value fails in the parse call and is wrapped by the catch block below.
        Map<String, String> hyperParameters = workflow.getHyperParameters();
        LassoModel lassoModel = lassoRegression.train(trainingData,
                Integer.parseInt(hyperParameters.get(MLConstants.ITERATIONS)),
                Double.parseDouble(hyperParameters.get(MLConstants.LEARNING_RATE)),
                Double.parseDouble(hyperParameters.get(MLConstants.REGULARIZATION_PARAMETER)),
                Double.parseDouble(hyperParameters.get(MLConstants.SGD_DATA_FRACTION)));

        // remove from cache
        trainingData.unpersist();
        // add test data to cache
        testingData.cache();

        // Reject the model before evaluation when its weights fail validation.
        Vector weights = lassoModel.weights();
        if (!isValidWeights(weights)) {
            throw new MLModelBuilderException("Weights of the model generated are null or infinity. [Weights] "
                    + vectorToString(weights));
        }
        // Cached because the pairs are used both for the summary and for the metrics below.
        JavaRDD<Tuple2<Double, Double>> predictionsAndLabels = lassoRegression.test(lassoModel, testingData)
                .cache();
        ClassClassificationAndRegressionModelSummary regressionModelSummary = SparkModelUtils
                .generateRegressionModelSummary(sparkContext, testingData, predictionsAndLabels);

        // remove from cache
        testingData.unpersist();

        mlModel.setModel(new MLGeneralizedLinearModel(lassoModel));

        List<FeatureImportance> featureWeights = getFeatureWeights(includedFeatures, lassoModel.weights().toArray());
        regressionModelSummary.setFeatures(includedFeatures.values().toArray(new String[0]));
        regressionModelSummary.setAlgorithm(SUPERVISED_ALGORITHM.LASSO_REGRESSION.toString());
        regressionModelSummary.setFeatureImportance(featureWeights);

        RegressionMetrics regressionMetrics = getRegressionMetrics(sparkContext, predictionsAndLabels);

        predictionsAndLabels.unpersist();

        Double meanSquaredError = regressionMetrics.meanSquaredError();
        regressionModelSummary.setMeanSquaredError(meanSquaredError);
        regressionModelSummary.setDatasetVersion(workflow.getDatasetVersion());

        return regressionModelSummary;
    } catch (Exception e) {
        // Catches everything, including the MLModelBuilderException thrown above for
        // invalid weights, and rethrows it wrapped with the original as the cause.
        // NOTE(review): the unpersist calls are skipped on this path, so cached RDDs may
        // stay cached after a failure — confirm whether that matters to callers.
        throw new MLModelBuilderException("An error occurred while building lasso regression model: "
                + e.getMessage(), e);
    }
}
 
開發者ID:wso2-attic,項目名稱:carbon-ml,代碼行數:63,代碼來源:SupervisedSparkModelBuilder.java

示例12: predictAge

import org.apache.spark.mllib.regression.LassoModel; //導入依賴的package包/類
/**
 * Predicts an age value for a document with the stored Lasso model.
 *
 * <p>The document is tokenized, features are extracted by the model's feature
 * generators, and the classifier's best category (when not {@code null}) is
 * appended as a {@code "cat=<category>"} feature once per 18 tokens. The
 * features are count-vectorized against the model vocabulary, normalized with
 * {@code p = 1.0}, and the resulting vector is scored.
 *
 * <p>When no features are produced, the data frame is built from an empty row
 * list and the subsequent {@code first()} call is made on an empty RDD.
 *
 * @param document text to score
 * @return the Lasso model's prediction for the document
 * @throws InvalidFormatException declared by the original signature
 * @throws IOException            declared by the original signature
 */
public double predictAge(String document) throws InvalidFormatException, IOException {
	FeatureGenerator[] generators = model.getContext().getFeatureGenerators();

	List<Row> rows = new ArrayList<Row>();

	String[] tokens = tokenizer.tokenize(document);

	double[] categoryProbabilities = classify.getProbabilities(tokens);
	String bestCategory = classify.getBestCategory(categoryProbabilities);

	// Collect the features produced by every generator, in generator order.
	Collection<String> features = new ArrayList<String>();
	for (FeatureGenerator generator : generators) {
		features.addAll(generator.extractFeatures(tokens));
	}

	// Repeat the category feature once per 18 tokens (integer division).
	if (bestCategory != null) {
		int repetitions = tokens.length / 18;
		for (int k = 0; k < repetitions; k++) {
			features.add("cat=" + bestCategory);
		}
	}

	if (!features.isEmpty()) {
		rows.add(RowFactory.create(document, features.toArray()));
	}

	// Two columns: the raw document and its array of string features.
	StructField documentField = new StructField("document", DataTypes.StringType, false, Metadata.empty());
	StructField textField = new StructField("text", new ArrayType(DataTypes.StringType, true), false,
			Metadata.empty());
	StructType schema = new StructType(new StructField[] { documentField, textField });

	Dataset<Row> documentFrame = spark.createDataFrame(rows, schema);

	CountVectorizerModel vectorizer = new CountVectorizerModel(model.getVocabulary());
	vectorizer.setInputCol("text").setOutputCol("feature");

	Dataset<Row> countedFrame = vectorizer.transform(documentFrame);

	Normalizer featureNormalizer = new Normalizer().setInputCol("feature").setOutputCol("normFeature").setP(1.0);

	Row firstRow = featureNormalizer.transform(countedFrame).javaRDD().first();

	SparseVector normalized = (SparseVector) firstRow.getAs("normFeature");

	final LassoModel regressionModel = model.getModel();

	// Rebuild the vector with Vectors.sparse before handing it to the model.
	Vector modelInput = Vectors.sparse(normalized.size(), normalized.indices(), normalized.values());
	return regressionModel.predict(modelInput.compressed());

}
 
開發者ID:USCDataScience,項目名稱:AgePredictor,代碼行數:53,代碼來源:AgePredicterLocal.java

示例13: train

import org.apache.spark.mllib.regression.LassoModel; //導入依賴的package包/類
/**
 * Trains a lasso regression model with the stochastic gradient descent (SGD) algorithm.
 *
 * @param trainingDataset           Training dataset as a JavaRDD of LabeledPoints
 * @param noOfIterations            Number of iterations
 * @param initialLearningRate       Initial learning rate (SGD step size)
 * @param regularizationParameter   Regularization parameter
 * @param miniBatchFraction         SGD minibatch fraction
 * @return                          Lasso regression model
 */
public LassoModel train(JavaRDD<LabeledPoint> trainingDataset, int noOfIterations, double initialLearningRate,
        double regularizationParameter, double miniBatchFraction) {
    // Delegates directly to Spark MLlib, passing the dataset's underlying RDD.
    LassoModel trainedModel = LassoWithSGD.train(
            trainingDataset.rdd(),
            noOfIterations,
            initialLearningRate,
            regularizationParameter,
            miniBatchFraction);
    return trainedModel;
}
 
開發者ID:wso2-attic,項目名稱:carbon-ml,代碼行數:16,代碼來源:LassoRegression.java


注:本文中的org.apache.spark.mllib.regression.LassoModel類示例由純淨天空整理自Github/MSDocs等開源代碼及文檔管理平台,相關代碼片段篩選自各路編程大神貢獻的開源項目,源碼版權歸原作者所有,傳播和使用請參考對應項目的License;未經允許,請勿轉載。