

Java JavaDoubleRDD Class Code Examples

This article collects typical usage examples of the Java class org.apache.spark.api.java.JavaDoubleRDD, gathered from open-source projects. If you are wondering what JavaDoubleRDD is for, or how to use it, the curated class examples below should help.


The JavaDoubleRDD class belongs to the org.apache.spark.api.java package. Fifteen code examples of the class are shown below, ordered by popularity by default.
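Before the collected examples, here is a minimal, self-contained quick-start sketch of typical JavaDoubleRDD usage. It is illustrative only and not taken from any of the projects below; the local master, app name, class name, and sample values are placeholders.

import java.util.Arrays;
import org.apache.spark.api.java.JavaDoubleRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.util.StatCounter;

public class JavaDoubleRDDQuickStart {
    public static void main(String[] args) {
        // Local Spark context; master and app name are illustrative placeholders.
        JavaSparkContext sc = new JavaSparkContext("local", "JavaDoubleRDDQuickStart");
        // JavaDoubleRDD is an RDD of doubles with numeric helpers (mean, stats, histogram, ...).
        JavaDoubleRDD rdd = sc.parallelizeDoubles(Arrays.asList(1.0, 2.0, 3.0, 4.0));
        // stats() computes count, min, max, sum, mean, variance, and stdev in a single pass.
        StatCounter stats = rdd.stats();
        System.out.println("mean=" + stats.mean() + ", stdev=" + stats.stdev());
        sc.stop();
    }
}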

Example 1: main

import org.apache.spark.api.java.JavaDoubleRDD; // import the required package/class
public static void main(String[] args) {
    // Sample test data: all numbers from 1 to 99999.
    List<Double> testData = IntStream.range(1, 100000).mapToDouble(d -> d)
            .collect(ArrayList::new, ArrayList::add, ArrayList::addAll);

    // `sc` (a JavaSparkContext) and `LOGGER` are assumed to be defined by the enclosing class.
    JavaDoubleRDD rdd = sc.parallelizeDoubles(testData);

    LOGGER.info("Mean: " + rdd.mean());

    // For efficiency, use StatCounter when more than one statistic is required; it computes all of them in a single pass.
    StatCounter statCounter = rdd.stats();

    LOGGER.info("Using StatCounter");
    LOGGER.info("Count:    " + statCounter.count());
    LOGGER.info("Min:      " + statCounter.min());
    LOGGER.info("Max:      " + statCounter.max());
    LOGGER.info("Sum:      " + statCounter.sum());
    LOGGER.info("Mean:     " + statCounter.mean());
    LOGGER.info("Variance: " + statCounter.variance());
    LOGGER.info("Stdev:    " + statCounter.stdev());
}
 
Author: sujee81, Project: SparkApps, Lines: 22, Source: Main.java

Example 2: contentSizeStats

import org.apache.spark.api.java.JavaDoubleRDD; // import the required package/class
public static final @Nullable Tuple4<Long, Long, Long, Long> contentSizeStats(
    JavaRDD<ApacheAccessLog> accessLogRDD) {
  JavaDoubleRDD contentSizes =
    accessLogRDD.mapToDouble(new GetContentSize()).cache();
  long count = contentSizes.count();
  if (count == 0) {
    return null;
  }
  // Natural ordering of Doubles; the raw Object indirection sidesteps a generics limitation.
  Object ordering = Ordering.natural();
  @SuppressWarnings("unchecked")
  final Comparator<Double> cmp = (Comparator<Double>) ordering;
  return new Tuple4<>(count,
                      contentSizes.reduce(new SumReducer()).longValue(),
                      contentSizes.min(cmp).longValue(),
                      contentSizes.max(cmp).longValue());
}
 
Author: holdenk, Project: learning-spark-examples, Lines: 17, Source: Functions.java

Example 3: main

import org.apache.spark.api.java.JavaDoubleRDD; // import the required package/class
public static void main(String[] args) throws Exception {
  String master;
  if (args.length > 0) {
    master = args[0];
  } else {
    master = "local";
  }
  JavaSparkContext sc = new JavaSparkContext(
      master, "basicmaptodouble", System.getenv("SPARK_HOME"), System.getenv("JARS"));
  JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4));
  JavaDoubleRDD result = rdd.mapToDouble(
      new DoubleFunction<Integer>() {
        public double call(Integer x) {
          double y = (double) x;
          return y * y;
        }
      });
  System.out.println(StringUtils.join(result.collect(), ","));
}
 
Author: holdenk, Project: learning-spark-examples, Lines: 20, Source: BasicMapToDouble.java

Example 4: main

import org.apache.spark.api.java.JavaDoubleRDD; // import the required package/class
/**
 * Cluster the C-alpha chains of a set of PDB ids.
 * @param args the input args - currently none taken
 * @throws IOException an error reading from the URL or the sequence file
 */
public static void main(String[] args) throws IOException {
	// Read the arguments
	Namespace ns = parseArgs(args);
	// Get the actual arguments
	String alignMethod = ns.getString("align");
	String filePath = ns.getString("hadoop");
	int minLength = ns.getInt("minlength");
	double sample = ns.getDouble("sample");
	boolean useFiles = ns.getBoolean("files");
	
	// Get the list of PDB ids
	List<String> pdbIdList = ns.<String> getList("pdbId");

	// Get the chains that correspond to those PDB ids
	JavaPairRDD<String, Atom[]>  chainRDD;
	if(pdbIdList.size()>0){
		if(useFiles==true){
			StructureDataRDD structureDataRDD = new StructureDataRDD(
					BiojavaSparkUtils.getFromList(convertToFiles(pdbIdList))
					.mapToPair(t -> new Tuple2<String, StructureDataInterface>(t._1, BiojavaSparkUtils.convertToStructDataInt(t._2))));
			chainRDD = BiojavaSparkUtils.getChainRDD(structureDataRDD, minLength);

		}
		else{
			chainRDD = BiojavaSparkUtils.getChainRDD(pdbIdList, minLength);
		}
	}
	else if(!filePath.equals(defaultPath)){
		chainRDD = BiojavaSparkUtils.getChainRDD(filePath, minLength, sample);
	}
	else{
		System.out.println("Must specify PDB ids or an hadoop sequence file");
		return;
	}

	System.out.println("Analysisng " + chainRDD.count() + " chains");
	JavaPairRDD<Tuple2<String,Atom[]>,Tuple2<String, Atom[]>> comparisons = SparkUtils.getHalfCartesian(chainRDD, chainRDD.getNumPartitions());
	JavaRDD<Tuple3<String, String,  AFPChain>> similarities = comparisons.map(t -> new Tuple3<String, String, AFPChain>(t._1._1, t._2._1, 
			AlignmentTools.getBiojavaAlignment(t._1._2, t._2._2, alignMethod)));
	JavaRDD<Tuple6<String, String, Double, Double, Double, Double>> allScores = similarities.map(t -> new Tuple6<String, String, Double, Double, Double, Double>(
			t._1(), t._2(), t._3().getTMScore(), t._3().getTotalRmsdOpt(),  (double) t._3().getTotalLenOpt(),  t._3().getAlignScore())).cache();
	if(alignMethod.equals("DUMMY")){
		JavaDoubleRDD doubleDist = allScores.mapToDouble(t -> t._3());
		System.out.println("Average dist: "+doubleDist.mean());
	}
	else{
		writeData(allScores);
	}
}
 
Author: biojava, Project: biojava-spark, Lines: 55, Source: ChainAligner.java

Example 5: main

import org.apache.spark.api.java.JavaDoubleRDD; // import the required package/class
public static void main(String[] args) throws UnknownHostException {
   // Obtain the Infinispan address
   String infinispanAddress = args[0];

   // Adjust log levels
   Logger.getLogger("org").setLevel(Level.WARN);

   // Create the remote cache manager
   Configuration build = new ConfigurationBuilder().addServer().host(infinispanAddress).build();
   RemoteCacheManager remoteCacheManager = new RemoteCacheManager(build);

   // Obtain the remote cache
   RemoteCache<Integer, Temperature> cache = remoteCacheManager.getCache();

   // Add some data
   cache.put(1, new Temperature(21, "London"));
   cache.put(2, new Temperature(34, "Rome"));
   cache.put(3, new Temperature(33, "Barcelona"));
   cache.put(4, new Temperature(8, "Oslo"));

   // Create java spark context
   SparkConf conf = new SparkConf().setAppName("infinispan-spark-simple-job");
   JavaSparkContext jsc = new JavaSparkContext(conf);

   // Create InfinispanRDD
   ConnectorConfiguration config = new ConnectorConfiguration().setServerList(infinispanAddress);

   JavaPairRDD<Integer, Temperature> infinispanRDD = InfinispanJavaRDD.createInfinispanRDD(jsc, config);

   // Convert RDD to RDD of doubles
   JavaDoubleRDD javaDoubleRDD = infinispanRDD.values().mapToDouble(Temperature::getValue);

   // Calculate average temperature
   Double meanTemp = javaDoubleRDD.mean();
   System.out.printf("\nAVERAGE TEMPERATURE: %f C\n", meanTemp);

   // Calculate standard deviation
   Double stdDev = javaDoubleRDD.sampleStdev();
   System.out.printf("STD DEVIATION: %f C\n ", stdDev);

   // Calculate histogram of temperatures
   System.out.println("TEMPERATURE HISTOGRAM:");
   double[] buckets = {0d, 10d, 20d, 30d, 40d};
   long[] histogram = javaDoubleRDD.histogram(buckets);

   for (int i = 0; i < buckets.length - 1; i++) {
      System.out.printf("Between %f C and %f C: %d cities\n", buckets[i], buckets[i + 1], histogram[i]);
   }
}
 
Author: infinispan, Project: infinispan-simple-tutorials, Lines: 50, Source: SimpleSparkJob.java
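A detail worth noting about histogram(buckets) as used above: with the bucket array {0, 10, 20, 30, 40}, Spark counts values in [0,10), [10,20), [20,30), and [30,40]; every range is half-open except the last, which includes its upper bound. A hedged micro-example (assumes a JavaSparkContext named sc):

// Sketch of bucket semantics; counts come out as {1, 1, 0, 2}.
JavaDoubleRDD temps = sc.parallelizeDoubles(Arrays.asList(0.0, 10.0, 39.9, 40.0));
long[] counts = temps.histogram(new double[]{0d, 10d, 20d, 30d, 40d});
// 0.0 -> [0,10), 10.0 -> [10,20), 39.9 and 40.0 -> [30,40] (the last bucket is closed).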

Example 6: getSlopes

import org.apache.spark.api.java.JavaDoubleRDD; // import the required package/class
/**
 * Action: Calculates the slope of a linear regression of every time series.
 *
 * Where: value = slope * timestamp
 * .. or:     y = slope * x
 *
 * @return the slopes (simple linear regression) of each and every time series in the RDD
 */
public JavaDoubleRDD getSlopes() {
    return this.mapToDouble((DoubleFunction<MetricTimeSeries>) mts -> {
        SimpleRegression regression = new SimpleRegression();
        mts.points().forEach(p -> regression.addData(p.getTimestamp(), p.getValue()));
        return regression.getSlope();
    });
}
 
Author: ChronixDB, Project: chronix.spark, Lines: 17, Source: ChronixRDD.java
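To see what each per-series task computes without the Chronix types, here is a hedged, standalone sketch of the same slope calculation using commons-math3 directly; the class name, timestamps, and values are made-up stand-ins for MetricTimeSeries points.

import org.apache.commons.math3.stat.regression.SimpleRegression;

public class SlopeSketch {
    public static void main(String[] args) {
        // Hypothetical time series: x = timestamps, y = values.
        long[] timestamps = {1000L, 2000L, 3000L, 4000L};
        double[] values = {1.0, 2.1, 2.9, 4.2};

        SimpleRegression regression = new SimpleRegression();
        for (int i = 0; i < timestamps.length; i++) {
            regression.addData(timestamps[i], values[i]);
        }
        // Slope of the least-squares fit y = slope * x + intercept.
        System.out.println("slope = " + regression.getSlope());
    }
}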

Example 7: validate

import org.apache.spark.api.java.JavaDoubleRDD; // import the required package/class
public double validate(JavaRDD<Rating> predictionJavaRdd, CassandraJavaRDD<CassandraRow> validationsCassRdd) {
	// Key predictions by (user, product) so they can be joined with the validation set.
	JavaPairRDD<Tuple2<Integer, Integer>, Double> predictionsJavaPairs = JavaPairRDD.fromJavaRDD(predictionJavaRdd.map(new org.apache.spark.api.java.function.Function<Rating, Tuple2<Tuple2<Integer, Integer>, Double>>() {
		@Override
		public Tuple2<Tuple2<Integer, Integer>, Double> call(Rating pred) throws Exception {
			return new Tuple2<Tuple2<Integer, Integer>, Double>(new Tuple2<Integer, Integer>(pred.user(), pred.product()), pred.rating());
		}
	}));
	JavaRDD<Rating> validationRatings = validationsCassRdd.map(new org.apache.spark.api.java.function.Function<CassandraRow, Rating>() {
		@Override
		public Rating call(CassandraRow validation) throws Exception {
			return new Rating(validation.getInt(RatingDO.USER_COL), validation.getInt(RatingDO.PRODUCT_COL), validation.getInt(RatingDO.RATING_COL));
		}
	});
	// Join actual ratings with predictions on (user, product), keeping the (actual, predicted) value pairs.
	JavaRDD<Tuple2<Double, Double>> validationAndPredictions = JavaPairRDD.fromJavaRDD(validationRatings.map(new org.apache.spark.api.java.function.Function<Rating, Tuple2<Tuple2<Integer, Integer>, Double>>() {
		@Override
		public Tuple2<Tuple2<Integer, Integer>, Double> call(Rating validationRating) throws Exception {
			return new Tuple2<Tuple2<Integer, Integer>, Double>(new Tuple2<Integer, Integer>(validationRating.user(), validationRating.product()), validationRating.rating());
		}
	})).join(predictionsJavaPairs).values();

	double meanSquaredError = JavaDoubleRDD.fromRDD(validationAndPredictions.map(new org.apache.spark.api.java.function.Function<Tuple2<Double, Double>, Object>() {
		@Override
		public Object call(Tuple2<Double, Double> pair) throws Exception {
			Double err = pair._1() - pair._2();
			return (Object) (err * err); // fromRDD expects RDD<Object>; the generics force this cast
		}
	}).rdd()).mean();
	return Math.sqrt(meanSquaredError);
}
 
Author: JoshuaFox, Project: spark-cassandra-collabfiltering, Lines: 36, Source: CollabFilterCassandra7.java
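On Java 8 and later, the MSE/RMSE step in this example (and in several examples below) can be written more directly with mapToDouble, which yields a JavaDoubleRDD without routing through RDD<Object>. A hedged sketch, assuming the same validationAndPredictions RDD of (actual, predicted) pairs:

// Sketch: squared error per pair, then mean, via JavaDoubleRDD.
JavaDoubleRDD squaredErrors = validationAndPredictions.mapToDouble(pair -> {
    double err = pair._1() - pair._2();
    return err * err;
});
double rmse = Math.sqrt(squaredErrors.mean());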

Example 8: main

import org.apache.spark.api.java.JavaDoubleRDD; // import the required package/class
public static void main(String[] args) {
  String master;
  if (args.length > 0) {
    master = args[0];
  } else {
    master = "local";
  }
  JavaSparkContext sc = new JavaSparkContext(
      master, "basicmap", System.getenv("SPARK_HOME"), System.getenv("JARS"));
  JavaDoubleRDD input = sc.parallelizeDoubles(
      Arrays.asList(1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 1000.0));
  JavaDoubleRDD result = removeOutliers(input);
  System.out.println(StringUtils.join(result.collect(), ","));
}
 
Author: holdenk, Project: learning-spark-examples, Lines: 14, Source: RemoveOutliers.java

Example 9: removeOutliers

import org.apache.spark.api.java.JavaDoubleRDD; // import the required package/class
static JavaDoubleRDD removeOutliers(JavaDoubleRDD rdd) {
  final StatCounter summaryStats = rdd.stats();
  final Double stddev = Math.sqrt(summaryStats.variance());
  return rdd.filter(new Function<Double, Boolean>() {
    public Boolean call(Double x) {
      // Keep only values within three standard deviations of the mean.
      return (Math.abs(x - summaryStats.mean()) < 3 * stddev);
    }
  });
}
 
Author: holdenk, Project: learning-spark-examples, Lines: 8, Source: RemoveOutliers.java
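On newer Spark/Java versions, the same 3-sigma filter can be expressed with a lambda. A sketch with equivalent behavior (removeOutliersLambda is a hypothetical name, not part of the original project):

static JavaDoubleRDD removeOutliersLambda(JavaDoubleRDD rdd) {
  StatCounter stats = rdd.stats();
  double mean = stats.mean();
  double stddev = stats.stdev(); // same as sqrt(variance())
  // Keep only values within three standard deviations of the mean.
  return rdd.filter(x -> Math.abs(x - mean) < 3 * stddev);
}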

Example 10: validate

import org.apache.spark.api.java.JavaDoubleRDD; // import the required package/class
public void validate(JavaPairRDD<Object, BSONObject> mongoRDD,
                         AthenaMLFeatureConfiguration athenaMLFeatureConfiguration,
                         RidgeRegressionDetectionModel ridgeRegressionDetectionModel,
                         RidgeRegressionValidationSummary ridgeRegressionValidationSummary) {
        List<AthenaFeatureField> listOfTargetFeatures = athenaMLFeatureConfiguration.getListOfTargetFeatures();
        Map<AthenaFeatureField, Integer> weight = athenaMLFeatureConfiguration.getWeight();
        Marking marking = ridgeRegressionDetectionModel.getMarking();
        RidgeRegressionModel model = (RidgeRegressionModel) ridgeRegressionDetectionModel.getDetectionModel();
        Normalizer normalizer = new Normalizer();

        int numberOfTargetValue = listOfTargetFeatures.size();

        JavaRDD<Tuple2<Double, Double>> valuesAndPreds = mongoRDD.map(
                (Function<Tuple2<Object, BSONObject>, Tuple2<Double, Double>>) t -> {

                    BSONObject feature = (BSONObject) t._2().get(AthenaFeatureField.FEATURE);
                    BSONObject idx = (BSONObject) t._2();
                    int originLabel = marking.checkClassificationMarkingElements(idx, feature);

                    double[] values = new double[numberOfTargetValue];
                    for (int j = 0; j < numberOfTargetValue; j++) {
                        values[j] = 0;
                        if (feature.containsField(listOfTargetFeatures.get(j).getValue())) {
                            Object obj = feature.get(listOfTargetFeatures.get(j).getValue());
                            if (obj instanceof Long) {
                                values[j] = (Long) obj;
                            } else if (obj instanceof Double) {
                                values[j] = (Double) obj;
                            } else if (obj instanceof Boolean) {
                                values[j] = (Boolean) obj ? 1 : 0;
                            } else {
                                System.out.println("not supported");
//                                return;
                            }

                            //check weight
                            if (weight.containsKey(listOfTargetFeatures.get(j))) {
                                values[j] *= weight.get(listOfTargetFeatures.get(j));
                            }

                            //check absolute
                            if (athenaMLFeatureConfiguration.isAbsolute()) {
                                values[j] = Math.abs(values[j]);
                            }
                        }
                    }

                    Vector normedForVal;
                    if (athenaMLFeatureConfiguration.isNormalization()) {
                        normedForVal = normalizer.transform(Vectors.dense(values));
                    } else {
                        normedForVal = Vectors.dense(values);
                    }

                    LabeledPoint p = new LabeledPoint(originLabel, normedForVal);
                    // Predict with the trained ridge regression model.

                    double prediction = model.predict(p.features());


                    ridgeRegressionValidationSummary.addEntry();
                    return new Tuple2<Double, Double>(prediction, p.label());
                });

        double MSE = new JavaDoubleRDD(valuesAndPreds.map(
                new Function<Tuple2<Double, Double>, Object>() {
                    public Object call(Tuple2<Double, Double> pair) {
                        return Math.pow(pair._1() - pair._2(), 2.0);
                    }
                }
        ).rdd()).mean();
        ridgeRegressionValidationSummary.setMSE(MSE);
        ridgeRegressionValidationSummary.setRidgeRegressionDetectionAlgorithm((RidgeRegressionDetectionAlgorithm) ridgeRegressionDetectionModel.getDetectionAlgorithm());
    }
 
Author: shlee89, Project: athena, Lines: 75, Source: RidgeRegressionDistJob.java

Example 11: validate

import org.apache.spark.api.java.JavaDoubleRDD; // import the required package/class
public void validate(JavaPairRDD<Object, BSONObject> mongoRDD,
                         AthenaMLFeatureConfiguration athenaMLFeatureConfiguration,
                         LassoDetectionModel lassoDetectionModel,
                         LassoValidationSummary lassoValidationSummary) {
        List<AthenaFeatureField> listOfTargetFeatures = athenaMLFeatureConfiguration.getListOfTargetFeatures();
        Map<AthenaFeatureField, Integer> weight = athenaMLFeatureConfiguration.getWeight();
        Marking marking = lassoDetectionModel.getMarking();
        LassoModel model = (LassoModel) lassoDetectionModel.getDetectionModel();
        Normalizer normalizer = new Normalizer();

        int numberOfTargetValue = listOfTargetFeatures.size();

        JavaRDD<Tuple2<Double, Double>> valuesAndPreds = mongoRDD.map(
                (Function<Tuple2<Object, BSONObject>, Tuple2<Double, Double>>) t -> {

                    BSONObject feature = (BSONObject) t._2().get(AthenaFeatureField.FEATURE);
                    BSONObject idx = (BSONObject) t._2();
                    int originLabel = marking.checkClassificationMarkingElements(idx, feature);

                    double[] values = new double[numberOfTargetValue];
                    for (int j = 0; j < numberOfTargetValue; j++) {
                        values[j] = 0;
                        if (feature.containsField(listOfTargetFeatures.get(j).getValue())) {
                            Object obj = feature.get(listOfTargetFeatures.get(j).getValue());
                            if (obj instanceof Long) {
                                values[j] = (Long) obj;
                            } else if (obj instanceof Double) {
                                values[j] = (Double) obj;
                            } else if (obj instanceof Boolean) {
                                values[j] = (Boolean) obj ? 1 : 0;
                            } else {
                                System.out.println("not supported");
//                                return;
                            }

                            //check weight
                            if (weight.containsKey(listOfTargetFeatures.get(j))) {
                                values[j] *= weight.get(listOfTargetFeatures.get(j));
                            }

                            //check absolute
                            if (athenaMLFeatureConfiguration.isAbsolute()) {
                                values[j] = Math.abs(values[j]);
                            }
                        }
                    }

                    Vector normedForVal;
                    if (athenaMLFeatureConfiguration.isNormalization()) {
                        normedForVal = normalizer.transform(Vectors.dense(values));
                    } else {
                        normedForVal = Vectors.dense(values);
                    }

                    LabeledPoint p = new LabeledPoint(originLabel, normedForVal);
                    // Predict with the trained lasso model.

                    double prediction = model.predict(p.features());


                    lassoValidationSummary.addEntry();
                    return new Tuple2<Double, Double>(prediction, p.label());
                });

        double MSE = new JavaDoubleRDD(valuesAndPreds.map(
                new Function<Tuple2<Double, Double>, Object>() {
                    public Object call(Tuple2<Double, Double> pair) {
                        return Math.pow(pair._1() - pair._2(), 2.0);
                    }
                }
        ).rdd()).mean();
        lassoValidationSummary.setMSE(MSE);
        lassoValidationSummary.setLassoDetectionAlgorithm((LassoDetectionAlgorithm) lassoDetectionModel.getDetectionAlgorithm());
    }
 
Author: shlee89, Project: athena, Lines: 75, Source: LassoDistJob.java

Example 12: validate

import org.apache.spark.api.java.JavaDoubleRDD; // import the required package/class
public void validate(JavaPairRDD<Object, BSONObject> mongoRDD,
                         AthenaMLFeatureConfiguration athenaMLFeatureConfiguration,
                         LinearRegressionDetectionModel linearRegressionDetectionModel,
                         LinearRegressionValidationSummary linearRegressionValidationSummary) {
        List<AthenaFeatureField> listOfTargetFeatures = athenaMLFeatureConfiguration.getListOfTargetFeatures();
        Map<AthenaFeatureField, Integer> weight = athenaMLFeatureConfiguration.getWeight();
        Marking marking = linearRegressionDetectionModel.getMarking();
        LinearRegressionModel model = (LinearRegressionModel) linearRegressionDetectionModel.getDetectionModel();
        Normalizer normalizer = new Normalizer();

        int numberOfTargetValue = listOfTargetFeatures.size();

        JavaRDD<Tuple2<Double, Double>> valuesAndPreds = mongoRDD.map(
                (Function<Tuple2<Object, BSONObject>, Tuple2<Double, Double>>) t -> {

                    BSONObject feature = (BSONObject) t._2().get(AthenaFeatureField.FEATURE);
                    BSONObject idx = (BSONObject) t._2();
                    int originLabel = marking.checkClassificationMarkingElements(idx, feature);

                    double[] values = new double[numberOfTargetValue];
                    for (int j = 0; j < numberOfTargetValue; j++) {
                        values[j] = 0;
                        if (feature.containsField(listOfTargetFeatures.get(j).getValue())) {
                            Object obj = feature.get(listOfTargetFeatures.get(j).getValue());
                            if (obj instanceof Long) {
                                values[j] = (Long) obj;
                            } else if (obj instanceof Double) {
                                values[j] = (Double) obj;
                            } else if (obj instanceof Boolean) {
                                values[j] = (Boolean) obj ? 1 : 0;
                            } else {
                                System.out.println("not supported");
//                                return;
                            }

                            //check weight
                            if (weight.containsKey(listOfTargetFeatures.get(j))) {
                                values[j] *= weight.get(listOfTargetFeatures.get(j));
                            }

                            //check absolute
                            if (athenaMLFeatureConfiguration.isAbsolute()) {
                                values[j] = Math.abs(values[j]);
                            }
                        }
                    }

                    Vector normedForVal;
                    if (athenaMLFeatureConfiguration.isNormalization()) {
                        normedForVal = normalizer.transform(Vectors.dense(values));
                    } else {
                        normedForVal = Vectors.dense(values);
                    }

                    LabeledPoint p = new LabeledPoint(originLabel, normedForVal);
                    // Predict with the trained linear regression model.

                    double prediction = model.predict(p.features());


                    linearRegressionValidationSummary.addEntry();
                    return new Tuple2<Double, Double>(prediction, p.label());
                });

        double MSE = new JavaDoubleRDD(valuesAndPreds.map(
                new Function<Tuple2<Double, Double>, Object>() {
                    public Object call(Tuple2<Double, Double> pair) {
                        return Math.pow(pair._1() - pair._2(), 2.0);
                    }
                }
        ).rdd()).mean();
        linearRegressionValidationSummary.setMSE(MSE);
        linearRegressionValidationSummary.setLinearRegressionDetectionAlgorithm((LinearRegressionDetectionAlgorithm) linearRegressionDetectionModel.getDetectionAlgorithm());
    }
 
Author: shlee89, Project: athena, Lines: 75, Source: LinearRegressionDistJob.java

Example 13: countObservations

import org.apache.spark.api.java.JavaDoubleRDD; // import the required package/class
/**
 * Action: Counts the number of observations.
 *
 * @return the number of overall observations in all time series
 */
public long countObservations() {
    JavaDoubleRDD sizesRdd = this.mapToDouble(
            (DoubleFunction<MetricTimeSeries>) value -> (double) value.size());
    return sizesRdd.sum().longValue();
}
 
Author: ChronixDB, Project: chronix.spark, Lines: 11, Source: ChronixRDD.java

Example 14: main

import org.apache.spark.api.java.JavaDoubleRDD; // import the required package/class
public static void main(String[] args) throws IOException {

    JavaSparkContext sc = new JavaSparkContext("local", "WikipediaKMeans");

    // `input_file` is a field of the surrounding class naming the input text file.
    JavaRDD<String> lines = sc.textFile("data/" + input_file);

    JavaRDD<LabeledPoint> points = lines.map(new ParsePoint());

    // Split initial RDD into two with 70% training data and 30% testing data (13L is a random seed):
    JavaRDD<LabeledPoint>[] splits = points.randomSplit(new double[]{0.7, 0.3}, 13L);
    JavaRDD<LabeledPoint> training = splits[0].cache();
    JavaRDD<LabeledPoint> testing = splits[1];
    training.cache();

    // Building the model
    int numIterations = 500;
    final SVMModel model =
        SVMWithSGD.train(JavaRDD.toRDD(training), numIterations);
    model.clearThreshold();
    // Evaluate the model on the testing examples and compute the test error
    JavaRDD<Tuple2<Double, Double>> valuesAndPreds = testing.map(
        new Function<LabeledPoint, Tuple2<Double, Double>>() {
          public Tuple2<Double, Double> call(LabeledPoint point) {
            double prediction = model.predict(point.features());
            System.out.println(" ++ prediction: " + prediction + " original: " + map_to_print_original_text.get(point.features().compressed().toString()));
            return new Tuple2<Double, Double>(prediction, point.label());
          }
        }
    );

    double MSE = new JavaDoubleRDD(valuesAndPreds.map(
        new Function<Tuple2<Double, Double>, Object>() {
          public Object call(Tuple2<Double, Double> pair) {
            return Math.pow(pair._1() - pair._2(), 2.0);
          }
        }
    ).rdd()).mean();
    System.out.println("Test Data Mean Squared Error = " + MSE);

    sc.stop();
  }
 
Author: mark-watson, Project: power-java, Lines: 42, Source: SvmTextClassifier.java

Example 15: main

import org.apache.spark.api.java.JavaDoubleRDD; // import the required package/class
public static void main(String[] args) {
  JavaSparkContext sc = new JavaSparkContext("local", "University of Wisconson Cancer Data");

  // Load and parse the data
  String path = "data/university_of_wisconson_data_.txt";
  JavaRDD<String> data = sc.textFile(path);
  JavaRDD<LabeledPoint> parsedData = data.map(
      new Function<String, LabeledPoint>() {
        public LabeledPoint call(String line) {
          String[] features = line.split(",");
          double label = 0;
          double[] v = new double[features.length - 2];
          for (int i = 0; i < features.length - 2; i++)
            v[i] = Double.parseDouble(features[i + 1]) * 0.09;
          if (features[10].equals("2"))
            label = 0; // benign
          else
            label = 1; // malignant
          return new LabeledPoint(label, Vectors.dense(v));
        }
      }
  );
  // Split initial RDD into two with 70% training data and 30% testing data (13L is a random seed):
  JavaRDD<LabeledPoint>[] splits = parsedData.randomSplit(new double[]{0.7, 0.3}, 13L);
  JavaRDD<LabeledPoint> training = splits[0].cache();
  JavaRDD<LabeledPoint> testing = splits[1];
  training.cache();

  // Building the model
  int numIterations = 100;
  final LinearRegressionModel model =
      LinearRegressionWithSGD.train(JavaRDD.toRDD(training), numIterations);

  // Evaluate the model on the testing examples and compute the test error
  JavaRDD<Tuple2<Double, Double>> valuesAndPreds = testing.map(
      new Function<LabeledPoint, Tuple2<Double, Double>>() {
        public Tuple2<Double, Double> call(LabeledPoint point) {
          double prediction = model.predict(point.features());
          return new Tuple2<Double, Double>(prediction, point.label());
        }
      }
  );
  double MSE = new JavaDoubleRDD(valuesAndPreds.map(
      new Function<Tuple2<Double, Double>, Object>() {
        public Object call(Tuple2<Double, Double> pair) {
          return Math.pow(pair._1() - pair._2(), 2.0);
        }
      }
  ).rdd()).mean();
  System.out.println("Test Data Mean Squared Error = " + MSE);

  // Save and load model and test:
  model.save(sc.sc(), "generated_models");
  LinearRegressionModel loaded_model = LinearRegressionModel.load(sc.sc(), "generated_models");
  double[] malignant_test_data_1 = {0.81, 0.6, 0.92, 0.8, 0.55, 0.83, 0.88, 0.71, 0.81};
  System.err.println("Should be malignant (close to 1.0): " +
      testModel(loaded_model, malignant_test_data_1));
  double[] benign_test_data_1 = {0.55, 0.25, 0.34, 0.31, 0.29, 0.016, 0.51, 0.01, 0.05};
  System.err.println("Should be benign (close to 0.0): " +
      testModel(loaded_model, benign_test_data_1));
}
 
Author: mark-watson, Project: power-java, Lines: 62, Source: LogisticRegression.java


Note: The org.apache.spark.api.java.JavaDoubleRDD class examples in this article were compiled by 纯净天空 from open-source code and documentation platforms such as GitHub and MSDocs. The snippets were selected from open-source projects contributed by the community; copyright in the source code remains with the original authors, and any distribution or use should follow the corresponding project licenses. Please do not republish without permission.