

Java MLUtils Class Code Examples

This article collects typical usage examples of the Java class org.apache.spark.mllib.util.MLUtils. If you are asking yourself how the MLUtils class works, how to use it, or what it looks like in real code, the curated examples below should help.


The MLUtils class belongs to the org.apache.spark.mllib.util package. The following sections present 14 code examples of the MLUtils class, sorted by popularity by default.
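Before diving into the examples: nearly all of them revolve around MLUtils.loadLibSVMFile, which parses a LIBSVM-format text file into an RDD of LabeledPoint. As a quick orientation, here is a minimal sketch of that call; the master URL, app name, and data path are placeholders of my choosing, not taken from any of the projects below.

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.util.MLUtils;

public class MLUtilsQuickStart {
    public static void main(String[] args) {
        // local Spark context; the master URL and app name are placeholders
        SparkConf conf = new SparkConf().setMaster("local[*]").setAppName("MLUtilsQuickStart");
        JavaSparkContext jsc = new JavaSparkContext(conf);

        // parse a LIBSVM text file ("data/sample.libsvm" is a hypothetical path)
        // into labeled points; labels and sparse feature vectors come from the file format
        JavaRDD<LabeledPoint> points =
                MLUtils.loadLibSVMFile(jsc.sc(), "data/sample.libsvm").toJavaRDD();

        System.out.println("Loaded " + points.count() + " labeled points");
        jsc.stop();
    }
}

The same pattern — obtain the underlying SparkContext via jsc.sc(), load, then convert to a JavaRDD — appears throughout the examples that follow.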

Example 1: shouldExportAndImportCorrectly

import org.apache.spark.mllib.util.MLUtils; // import the required package/class
@Test
public void shouldExportAndImportCorrectly() {
    String datapath = "src/test/resources/binary_classification_test.libsvm";
    JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(jsc.sc(), datapath).toJavaRDD();

    //Train model in spark
    LogisticRegressionModel lrmodel = new LogisticRegressionWithSGD().run(data.rdd());

    //Export this model
    byte[] exportedModel = ModelExporter.export(lrmodel);

    //Import it back
    LogisticRegressionModelInfo importedModel = (LogisticRegressionModelInfo) ModelImporter.importModelInfo(exportedModel);

    //check that the corresponding fields are exactly equal
    //note: there may be edge cases, e.g. the order of elements in a list could change
    assertEquals(lrmodel.intercept(), importedModel.getIntercept(), 0.01);
    assertEquals(lrmodel.numClasses(), importedModel.getNumClasses(), 0.01);
    assertEquals(lrmodel.numFeatures(), importedModel.getNumFeatures(), 0.01);
    assertEquals((double) lrmodel.getThreshold().get(), importedModel.getThreshold(), 0.01);
    for (int i = 0; i < importedModel.getNumFeatures(); i++)
        assertEquals(lrmodel.weights().toArray()[i], importedModel.getWeights()[i], 0.01);

}
 
Developer: flipkart-incubator, Project: spark-transformers, Lines: 25, Source: LogisticRegressionExporterTest.java

Example 2: shouldExportAndImportCorrectly

import org.apache.spark.mllib.util.MLUtils; // import the required package/class
@Test
public void shouldExportAndImportCorrectly() {
    String datapath = "src/test/resources/binary_classification_test.libsvm";
    JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(sc.sc(), datapath).toJavaRDD();

    //Train model in spark
    LogisticRegressionModel lrmodel = new LogisticRegressionWithSGD().run(data.rdd());

    //Export this model
    byte[] exportedModel = ModelExporter.export(lrmodel, null);

    //Import it back
    LogisticRegressionModelInfo importedModel = (LogisticRegressionModelInfo) ModelImporter.importModelInfo(exportedModel);

    //check that the corresponding fields are exactly equal
    //note: there may be edge cases, e.g. the order of elements in a list could change
    assertEquals(lrmodel.intercept(), importedModel.getIntercept(), EPSILON);
    assertEquals(lrmodel.numClasses(), importedModel.getNumClasses(), EPSILON);
    assertEquals(lrmodel.numFeatures(), importedModel.getNumFeatures(), EPSILON);
    assertEquals((double) lrmodel.getThreshold().get(), importedModel.getThreshold(), EPSILON);
    for (int i = 0; i < importedModel.getNumFeatures(); i++)
        assertEquals(lrmodel.weights().toArray()[i], importedModel.getWeights()[i], EPSILON);

}
 
Developer: flipkart-incubator, Project: spark-transformers, Lines: 25, Source: LogisticRegressionExporterTest.java

Example 3: testLogisticRegression

import org.apache.spark.mllib.util.MLUtils; // import the required package/class
@Test
public void testLogisticRegression() {
    //prepare data
    String datapath = "src/test/resources/binary_classification_test.libsvm";
    JavaRDD<LabeledPoint> trainingData = MLUtils.loadLibSVMFile(jsc.sc(), datapath).toJavaRDD();

    //Train model in spark
    LogisticRegressionModel lrmodel = new LogisticRegressionWithSGD().run(trainingData.rdd());

    //Export this model
    byte[] exportedModel = ModelExporter.export(lrmodel);

    //Import and get Transformer
    Transformer transformer = ModelImporter.importAndGetTransformer(exportedModel);

    //validate predictions
    List<LabeledPoint> testPoints = trainingData.collect();
    for (LabeledPoint i : testPoints) {
        Vector v = i.features();
        double actual = lrmodel.predict(v);

        Map<String, Object> data = new HashMap<String, Object>();
        data.put("features", v.toArray());
        transformer.transform(data);
        double predicted = (double) data.get("prediction");

        assertEquals(actual, predicted, 0.01);
    }
}
 
Developer: flipkart-incubator, Project: spark-transformers, Lines: 30, Source: LogisticRegressionBridgeTest.java

Example 4: testLogisticRegression

import org.apache.spark.mllib.util.MLUtils; // import the required package/class
@Test
public void testLogisticRegression() {
    //prepare data
    String datapath = "src/test/resources/binary_classification_test.libsvm";

    Dataset<Row> trainingData = spark.read().format("libsvm").load(datapath);

    //Train model in spark
    LogisticRegressionModel lrmodel = new LogisticRegression().fit(trainingData);

    //Export this model
    byte[] exportedModel = ModelExporter.export(lrmodel);

    //Import and get Transformer
    Transformer transformer = ModelImporter.importAndGetTransformer(exportedModel);

    //validate predictions
    List<LabeledPoint> testPoints = MLUtils.loadLibSVMFile(jsc.sc(), datapath).toJavaRDD().collect();
    for (LabeledPoint i : testPoints) {
        Vector v = i.features().asML();
        double actual = lrmodel.predict(v);

        Map<String, Object> data = new HashMap<String, Object>();
        data.put("features", v.toArray());
        transformer.transform(data);
        double predicted = (double) data.get("prediction");

        assertEquals(actual, predicted, 0.01);
    }
}
 
Developer: flipkart-incubator, Project: spark-transformers, Lines: 31, Source: LogisticRegression1BridgeTest.java

Example 5: testLogisticRegression

import org.apache.spark.mllib.util.MLUtils; // import the required package/class
@Test
public void testLogisticRegression() {
    //prepare data
    String datapath = "src/test/resources/binary_classification_test.libsvm";
    JavaRDD<LabeledPoint> trainingData = MLUtils.loadLibSVMFile(sc.sc(), datapath).toJavaRDD();

    //Train model in spark
    LogisticRegressionModel lrmodel = new LogisticRegressionWithSGD().run(trainingData.rdd());

    //Export this model
    byte[] exportedModel = ModelExporter.export(lrmodel, null);

    //Import and get Transformer
    Transformer transformer = ModelImporter.importAndGetTransformer(exportedModel);

    //validate predictions
    List<LabeledPoint> testPoints = trainingData.collect();
    for (LabeledPoint i : testPoints) {
        Vector v = i.features();
        double actual = lrmodel.predict(v);

        Map<String, Object> data = new HashMap<String, Object>();
        data.put("features", v.toArray());
        transformer.transform(data);
        double predicted = (double) data.get("prediction");

        assertEquals(actual, predicted, EPSILON);
    }
}
 
Developer: flipkart-incubator, Project: spark-transformers, Lines: 30, Source: LogisticRegressionBridgeTest.java

Example 6: testLogisticRegression

import org.apache.spark.mllib.util.MLUtils; // import the required package/class
@Test
public void testLogisticRegression() {
    //prepare data
    String datapath = "src/test/resources/binary_classification_test.libsvm";

    DataFrame trainingData = sqlContext.read().format("libsvm").load(datapath);

    //Train model in spark
    LogisticRegressionModel lrmodel = new LogisticRegression().fit(trainingData);

    //Export this model
    byte[] exportedModel = ModelExporter.export(lrmodel, trainingData);

    //Import and get Transformer
    Transformer transformer = ModelImporter.importAndGetTransformer(exportedModel);

    //validate predictions
    List<LabeledPoint> testPoints = MLUtils.loadLibSVMFile(sc.sc(), datapath).toJavaRDD().collect();
    for (LabeledPoint i : testPoints) {
        Vector v = i.features();
        double actual = lrmodel.predict(v);

        Map<String, Object> data = new HashMap<String, Object>();
        data.put("features", v.toArray());
        transformer.transform(data);
        double predicted = (double) data.get("prediction");

        assertEquals(actual, predicted, EPSILON);
    }
}
 
Developer: flipkart-incubator, Project: spark-transformers, Lines: 31, Source: LogisticRegression1BridgeTest.java

Example 7: main

import org.apache.spark.mllib.util.MLUtils; // import the required package/class
public static void main(String args[]){

		SparkConf configuration = new SparkConf().setMaster("local[4]").setAppName("Any");
		JavaSparkContext sc = new JavaSparkContext(configuration);

		// Load and parse the data file.
		String input = "data/rf-data.txt";
		JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(sc.sc(), input).toJavaRDD();
		// Split the data into training and test sets (30% held out for testing)
		JavaRDD<LabeledPoint>[] dataSplits = data.randomSplit(new double[]{0.7, 0.3});
		JavaRDD<LabeledPoint> trainingData = dataSplits[0];
		JavaRDD<LabeledPoint> testData = dataSplits[1];

		// Train a RandomForest model.
		Integer numClasses = 2;
		HashMap<Integer, Integer> categoricalFeaturesInfo = new HashMap<Integer, Integer>(); // an empty categoricalFeaturesInfo indicates all features are continuous
		Integer numTrees = 3; // Use more in practice.
		String featureSubsetStrategy = "auto"; // Let the algorithm choose.
		String impurity = "gini";
		Integer maxDepth = 5;
		Integer maxBins = 32;
		Integer seed = 12345;

		final RandomForestModel rfModel = RandomForest.trainClassifier(trainingData, numClasses,
				categoricalFeaturesInfo, numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins,
				seed);

		// Evaluate model on test instances and compute test error
		JavaPairRDD<Double, Double> label =
				testData.mapToPair(new PairFunction<LabeledPoint, Double, Double>() {
					public Tuple2<Double, Double> call(LabeledPoint p) {
						return new Tuple2<Double, Double>(rfModel.predict(p.features()), p.label());
					}
				});

		Double testError =
				1.0 * label.filter(new Function<Tuple2<Double, Double>, Boolean>() {
					public Boolean call(Tuple2<Double, Double> pl) {
						return !pl._1().equals(pl._2());
					}
				}).count() / testData.count();

		System.out.println("Test Error: " + testError);
		System.out.println("Learned classification forest model:\n" + rfModel.toDebugString());
	}
 
Developer: PacktPublishing, Project: Java-Data-Science-Cookbook, Lines: 46, Source: RandomForestMlib.java

Example 8: main

import org.apache.spark.mllib.util.MLUtils; // import the required package/class
public static void main(String[] args) {
  MudrodEngine me = new MudrodEngine();

  JavaSparkContext jsc = me.startSparkDriver().sc;

  String path = SparkSVM.class.getClassLoader().getResource("inputDataForSVM_spark.txt").toString();
  JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(jsc.sc(), path).toJavaRDD();

  // Run training algorithm to build the model.
  int numIterations = 100;
  final SVMModel model = SVMWithSGD.train(data.rdd(), numIterations);

  // Save the trained model (this example does not load it back)
  model.save(jsc.sc(), SparkSVM.class.getClassLoader().getResource("javaSVMWithSGDModel").toString());

  jsc.sc().stop();

}
 
Developer: apache, Project: incubator-sdap-mudrod, Lines: 19, Source: SparkSVM.java

Example 9: testFromSvmLightBackprop

import org.apache.spark.mllib.util.MLUtils; // import the required package/class
@Test
public void testFromSvmLightBackprop() throws Exception {
    JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(sc.sc(), new ClassPathResource("iris_svmLight_0.txt").getTempFileFromArchive().getAbsolutePath()).toJavaRDD().map(new Function<LabeledPoint, LabeledPoint>() {
        @Override
        public LabeledPoint call(LabeledPoint v1) throws Exception {
            return new LabeledPoint(v1.label(), Vectors.dense(v1.features().toArray()));
        }
    }).cache();
    Nd4j.ENFORCE_NUMERICAL_STABILITY = true;

    DataSet d = new IrisDataSetIterator(150,150).next();
    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .seed(123)
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .iterations(10)
            .list()
            .layer(0, new DenseLayer.Builder()
                    .nIn(4).nOut(100)
                    .weightInit(WeightInit.XAVIER)
                    .activation("relu")
                    .build())
            .layer(1, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                    .nIn(100).nOut(3)
                    .activation("softmax")
                    .weightInit(WeightInit.XAVIER)
                    .build())
            .backprop(true)
            .build();



    MultiLayerNetwork network = new MultiLayerNetwork(conf);
    network.init();
    System.out.println("Initializing network");

    SparkDl4jMultiLayer master = new SparkDl4jMultiLayer(sc,conf,new ParameterAveragingTrainingMaster(true,numExecutors(),1,5,1,0));

    MultiLayerNetwork network2 = master.fitLabeledPoint(data);
    Evaluation evaluation = new Evaluation();
    evaluation.eval(d.getLabels(), network2.output(d.getFeatureMatrix()));
    System.out.println(evaluation.stats());


}
 
Developer: PacktPublishing, Project: Deep-Learning-with-Hadoop, Lines: 45, Source: TestSparkMultiLayerParameterAveraging.java

Example 10: testFromSvmLight

import org.apache.spark.mllib.util.MLUtils; // import the required package/class
@Test
public void testFromSvmLight() throws Exception {
    JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(sc.sc(), new ClassPathResource("svmLight/iris_svmLight_0.txt").getTempFileFromArchive().getAbsolutePath()).toJavaRDD().map(new Function<LabeledPoint, LabeledPoint>() {
        @Override
        public LabeledPoint call(LabeledPoint v1) throws Exception {
            return new LabeledPoint(v1.label(), Vectors.dense(v1.features().toArray()));
        }
    }).cache();

    DataSet d = new IrisDataSetIterator(150,150).next();
    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .seed(123)
            .optimizationAlgo(OptimizationAlgorithm.LINE_GRADIENT_DESCENT)
            .iterations(100).miniBatch(true)
            .maxNumLineSearchIterations(10)
            .list()
            .layer(0, new RBM.Builder(RBM.HiddenUnit.RECTIFIED, RBM.VisibleUnit.GAUSSIAN)
                    .nIn(4).nOut(100)
                    .weightInit(WeightInit.XAVIER)
                    .activation("relu")
                    .lossFunction(LossFunctions.LossFunction.RMSE_XENT).build())
            .layer(1, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                    .nIn(100).nOut(3)
                    .activation("softmax")
                    .weightInit(WeightInit.XAVIER)
                    .build())
            .backprop(false)
            .build();



    MultiLayerNetwork network = new MultiLayerNetwork(conf);
    network.init();
    System.out.println("Initializing network");
    SparkDl4jMultiLayer master = new SparkDl4jMultiLayer(sc,getBasicConf(),new ParameterAveragingTrainingMaster(true,numExecutors(),1,5,1,0));

    MultiLayerNetwork network2 = master.fitLabeledPoint(data);
    Evaluation evaluation = new Evaluation();
    evaluation.eval(d.getLabels(), network2.output(d.getFeatureMatrix()));
    System.out.println(evaluation.stats());
}
 
Developer: PacktPublishing, Project: Deep-Learning-with-Hadoop, Lines: 42, Source: TestSparkMultiLayerParameterAveraging.java

Example 11: libsvmToBinaryBlock

import org.apache.spark.mllib.util.MLUtils; // import the required package/class
/**
 * Converts a libsvm text input file into two binary block matrices for features 
 * and labels, and saves these to the specified output files. This call also deletes 
 * existing files at the specified output locations, as well as determines and 
 * writes the meta data files of both output matrices. 
 * <p>
 * Note: We use {@code org.apache.spark.mllib.util.MLUtils.loadLibSVMFile} for parsing 
 * the libsvm input files in order to ensure consistency with Spark.
 * 
 * @param sc java spark context
 * @param pathIn path to libsvm input file
 * @param pathX path to binary block output file of features
 * @param pathY path to binary block output file of labels
 * @param mcOutX matrix characteristics of output matrix X
 * @throws DMLRuntimeException if output path not writable or conversion failure
 */
public static void libsvmToBinaryBlock(JavaSparkContext sc, String pathIn, 
		String pathX, String pathY, MatrixCharacteristics mcOutX) 
	throws DMLRuntimeException
{
	if( !mcOutX.dimsKnown() )
		throw new DMLRuntimeException("Matrix characteristics "
			+ "required to convert sparse input representation.");
	try {
		//cleanup existing output files
		MapReduceTool.deleteFileIfExistOnHDFS(pathX);
		MapReduceTool.deleteFileIfExistOnHDFS(pathY);
		
		//convert libsvm to labeled points
		int numFeatures = (int) mcOutX.getCols();
		int numPartitions = SparkUtils.getNumPreferredPartitions(mcOutX, null);
		JavaRDD<org.apache.spark.mllib.regression.LabeledPoint> lpoints = 
				MLUtils.loadLibSVMFile(sc.sc(), pathIn, numFeatures, numPartitions).toJavaRDD();
		
		//append row index and best-effort caching to avoid repeated text parsing
		JavaPairRDD<org.apache.spark.mllib.regression.LabeledPoint,Long> ilpoints = 
				lpoints.zipWithIndex().persist(StorageLevel.MEMORY_AND_DISK()); 
		
		//extract labels and convert to binary block
		MatrixCharacteristics mc1 = new MatrixCharacteristics(mcOutX.getRows(), 1, 
				mcOutX.getRowsPerBlock(), mcOutX.getColsPerBlock(), -1);
		LongAccumulator aNnz1 = sc.sc().longAccumulator("nnz");
		JavaPairRDD<MatrixIndexes,MatrixBlock> out1 = ilpoints
				.mapPartitionsToPair(new LabeledPointToBinaryBlockFunction(mc1, true, aNnz1));
		int numPartitions2 = SparkUtils.getNumPreferredPartitions(mc1, null);
		out1 = RDDAggregateUtils.mergeByKey(out1, numPartitions2, false);
		out1.saveAsHadoopFile(pathY, MatrixIndexes.class, MatrixBlock.class, SequenceFileOutputFormat.class);
		mc1.setNonZeros(aNnz1.value()); //update nnz after triggered save
		MapReduceTool.writeMetaDataFile(pathY+".mtd", ValueType.DOUBLE, mc1, OutputInfo.BinaryBlockOutputInfo);
		
		//extract data and convert to binary block
		MatrixCharacteristics mc2 = new MatrixCharacteristics(mcOutX.getRows(), mcOutX.getCols(),
				mcOutX.getRowsPerBlock(), mcOutX.getColsPerBlock(), -1);
		LongAccumulator aNnz2 = sc.sc().longAccumulator("nnz");
		JavaPairRDD<MatrixIndexes,MatrixBlock> out2 = ilpoints
				.mapPartitionsToPair(new LabeledPointToBinaryBlockFunction(mc2, false, aNnz2));
		out2 = RDDAggregateUtils.mergeByKey(out2, numPartitions, false);
		out2.saveAsHadoopFile(pathX, MatrixIndexes.class, MatrixBlock.class, SequenceFileOutputFormat.class);
		mc2.setNonZeros(aNnz2.value()); //update nnz after triggered save
		MapReduceTool.writeMetaDataFile(pathX+".mtd", ValueType.DOUBLE, mc2, OutputInfo.BinaryBlockOutputInfo);
		
		//asynchronous cleanup of cached intermediates
		ilpoints.unpersist(false);
	}
	catch(IOException ex) {
		throw new DMLRuntimeException(ex);
	}
}
 
Developer: apache, Project: systemml, Lines: 69, Source: RDDConverterUtils.java

Example 12: main

import org.apache.spark.mllib.util.MLUtils; // import the required package/class
public static void main(String[] args) throws IOException {

        SparkConf sparkConf = new SparkConf().setAppName("JavaDecisionTree");
        // just run this locally
        sparkConf.setMaster("local[" + Runtime.getRuntime().availableProcessors() + "]");
        JavaSparkContext sc = new JavaSparkContext(sparkConf);

        // Load and parse the data file.
        String datapath = "/media/an/fixes.libsvm";

        // the feature names are substituted into the model debugString later to
        // make it readable
        List<String> names = Arrays.asList("lat", "lon", "speedKnots", "courseHeadingDiff",
                "preEffectiveSpeedKnots", "preError", "postEffectiveSpeedKnots", "postError");
        List<String> classifications = Arrays.asList("other", "moored", "anchored");

        JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(sc.sc(), datapath).toJavaRDD();
        // Split the data into training and test sets (30% held out for testing)
        JavaRDD<LabeledPoint>[] splits = data.randomSplit(new double[] { 0.7, 0.3 });
        JavaRDD<LabeledPoint> trainingData = splits[0];
        JavaRDD<LabeledPoint> testData = splits[1];

        // Set parameters.
        // Empty categoricalFeaturesInfo indicates all features are continuous.
        Integer numClassifications = classifications.size();
        Map<Integer, Integer> categoricalFeaturesInfo = new HashMap<Integer, Integer>();
        String impurity = "gini";
        Integer maxDepth = 8;
        Integer maxBins = 32;

        // Train a DecisionTree model for classification.
        final DecisionTreeModel model = DecisionTree.trainClassifier(trainingData,
                numClassifications, categoricalFeaturesInfo, impurity, maxDepth, maxBins);

        // Evaluate model on test instances and compute test error
        Double testErr = (double) testData
        // pair up actual and predicted classification numerical representation
                .map(toPredictionAndActual(model))
                // get the ones that don't match
                .filter(predictionWrong())
                // count them
                .count()
        // divide by total count to get ratio failing test
                / testData.count();

        // Save and load model to demo possible usage in prediction mode
        String modelPath = "target/myModelPath";
        FileUtils.deleteDirectory(new File(modelPath));
        model.save(sc.sc(), modelPath);
        DecisionTreeModel sameModel = DecisionTreeModel.load(sc.sc(), modelPath);

        System.out.println("Test Error: " + testErr);

        String s = useNames(model.toDebugString(), names, classifications);

        System.out.println("Learned classification tree model:\n" + s);

        FileOutputStream fos = new FileOutputStream("target/model.txt");
        fos.write(("Test Error: " + testErr + "\n").getBytes());
        fos.write(s.getBytes());
        fos.close();

    }
 
Developer: amsa-code, Project: risky, Lines: 64, Source: AnchoredTrainerMain.java

Example 13: testFromSvmLightBackprop

import org.apache.spark.mllib.util.MLUtils; // import the required package/class
@Test
public void testFromSvmLightBackprop() throws Exception {
    JavaRDD<LabeledPoint> data = MLUtils
                    .loadLibSVMFile(sc.sc(),
                                    new ClassPathResource("svmLight/iris_svmLight_0.txt").getTempFileFromArchive()
                                                    .getAbsolutePath())
                    .toJavaRDD().map(new Function<LabeledPoint, LabeledPoint>() {
                        @Override
                        public LabeledPoint call(LabeledPoint v1) throws Exception {
                            return new LabeledPoint(v1.label(), Vectors.dense(v1.features().toArray()));
                        }
                    });
    Nd4j.ENFORCE_NUMERICAL_STABILITY = true;

    DataSet d = new IrisDataSetIterator(150, 150).next();
    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(123)
                    .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).list()
                    .layer(0, new DenseLayer.Builder().nIn(4).nOut(100).weightInit(WeightInit.XAVIER)
                                    .activation(Activation.RELU).build())
                    .layer(1, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(
                                    LossFunctions.LossFunction.MCXENT).nIn(100).nOut(3)
                                                    .activation(Activation.SOFTMAX).weightInit(WeightInit.XAVIER)
                                                    .build())
                    .backprop(true).build();



    MultiLayerNetwork network = new MultiLayerNetwork(conf);
    network.init();
    System.out.println("Initializing network");

    SparkDl4jMultiLayer master = new SparkDl4jMultiLayer(sc, conf,
                    new ParameterAveragingTrainingMaster(true, numExecutors(), 1, 5, 1, 0));

    MultiLayerNetwork network2 = master.fitLabeledPoint(data);
    Evaluation evaluation = new Evaluation();
    evaluation.eval(d.getLabels(), network2.output(d.getFeatureMatrix()));
    System.out.println(evaluation.stats());


}
 
Developer: deeplearning4j, Project: deeplearning4j, Lines: 42, Source: TestSparkMultiLayerParameterAveraging.java

Example 14: testFromSvmLight

import org.apache.spark.mllib.util.MLUtils; // import the required package/class
@Test
public void testFromSvmLight() throws Exception {
    JavaRDD<LabeledPoint> data = MLUtils
                    .loadLibSVMFile(sc.sc(),
                                    new ClassPathResource("svmLight/iris_svmLight_0.txt").getTempFileFromArchive()
                                                    .getAbsolutePath())
                    .toJavaRDD().map(new Function<LabeledPoint, LabeledPoint>() {
                        @Override
                        public LabeledPoint call(LabeledPoint v1) throws Exception {
                            return new LabeledPoint(v1.label(), Vectors.dense(v1.features().toArray()));
                        }
                    });

    DataSet d = new IrisDataSetIterator(150, 150).next();
    MultiLayerConfiguration conf =
                    new NeuralNetConfiguration.Builder().seed(123)
                                    .optimizationAlgo(OptimizationAlgorithm.LINE_GRADIENT_DESCENT)
                                    .miniBatch(true).maxNumLineSearchIterations(10)
                                    .list().layer(0,
                                                    new DenseLayer.Builder().nIn(4).nOut(100)
                                                            .weightInit(WeightInit.XAVIER)
                                                            .activation(Activation.RELU)
                                                            .build())
                                    .layer(1, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(
                                                    LossFunctions.LossFunction.MCXENT).nIn(100).nOut(3)
                                                                    .activation(Activation.SOFTMAX)
                                                                    .weightInit(WeightInit.XAVIER).build())
                                    .backprop(false).build();



    MultiLayerNetwork network = new MultiLayerNetwork(conf);
    network.init();
    System.out.println("Initializing network");
    SparkDl4jMultiLayer master = new SparkDl4jMultiLayer(sc, getBasicConf(),
                    new ParameterAveragingTrainingMaster(true, numExecutors(), 1, 5, 1, 0));

    MultiLayerNetwork network2 = master.fitLabeledPoint(data);
    Evaluation evaluation = new Evaluation();
    evaluation.eval(d.getLabels(), network2.output(d.getFeatureMatrix()));
    System.out.println(evaluation.stats());
}
 
Developer: deeplearning4j, Project: deeplearning4j, Lines: 43, Source: TestSparkMultiLayerParameterAveraging.java


Note: The org.apache.spark.mllib.util.MLUtils class examples in this article were compiled by 纯净天空 from open-source code and documentation platforms such as GitHub and MSDocs. The snippets were selected from open-source projects contributed by various developers; copyright of the source code remains with the original authors, and distribution and use should follow each project's License.