

Java KMeans.train Method Code Examples

This article collects typical usage examples of the KMeans.train method from the Java class org.apache.spark.mllib.clustering.KMeans. If you are wondering how to call KMeans.train, what a concrete invocation looks like, or how it is used in real projects, the curated code examples below should help. You can also explore further usage examples of the enclosing class, org.apache.spark.mllib.clustering.KMeans.


Nine code examples of the KMeans.train method are shown below, ordered by popularity by default.
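Before the full examples, here is a minimal, self-contained sketch of a typical KMeans.train call against the RDD-based MLlib API. This is an illustration only: the class name, file name, and parameter values are placeholders, not taken from any example below.

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.clustering.KMeans;
import org.apache.spark.mllib.clustering.KMeansModel;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;

public class KMeansTrainSketch {
    public static void main(String[] args) {
        JavaSparkContext sc = new JavaSparkContext("local", "KMeansTrainSketch");

        // Each line of the (placeholder) input file is assumed to be
        // a space-separated list of doubles.
        JavaRDD<Vector> points = sc.textFile("data/points.txt").map(line -> {
            String[] parts = line.split(" ");
            double[] values = new double[parts.length];
            for (int i = 0; i < parts.length; i++) {
                values[i] = Double.parseDouble(parts[i]);
            }
            return Vectors.dense(values);
        });
        points.cache(); // k-means makes multiple passes over the data, so caching pays off

        // Simplest overload: train(RDD<Vector> data, int k, int maxIterations).
        KMeansModel model = KMeans.train(points.rdd(), 2, 20);
        System.out.println("First cluster center: " + model.clusterCenters()[0]);

        sc.stop();
    }
}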

Example 1: buildModel

import org.apache.spark.mllib.clustering.KMeans; // import the package/class the method depends on
/**
 * @param sparkContext    active Spark Context
 * @param trainData       training data on which to build a model
 * @param hyperParameters ordered list of hyper parameter values to use in building model
 * @param candidatePath   directory where additional model files can be written
 * @return a {@link PMML} representation of a model trained on the given data
 */
@Override
public PMML buildModel(JavaSparkContext sparkContext,
                       JavaRDD<String> trainData,
                       List<?> hyperParameters,
                       Path candidatePath) {
  int numClusters = (Integer) hyperParameters.get(0);
  Preconditions.checkArgument(numClusters > 1);
  log.info("Building KMeans Model with {} clusters", numClusters);

  JavaRDD<Vector> trainingData = parsedToVectorRDD(trainData.map(MLFunctions.PARSE_FN));
  KMeansModel kMeansModel = KMeans.train(trainingData.rdd(), numClusters, maxIterations,
                                         numberOfRuns, initializationStrategy);

  return kMeansModelToPMML(kMeansModel, fetchClusterCountsFromModel(trainingData, kMeansModel));
}
 
Developer ID: oncewang, Project: oryx2, Lines: 23, Source: KMeansUpdate.java
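The static train overload used above (data, k, maxIterations, runs, initializationStrategy) can also be written with KMeans' builder-style setters. Here is a sketch of an equivalent call, assuming the same trainingData and parameter fields are in scope; note that the runs parameter is deprecated and ignored in newer Spark releases.

// Builder-style equivalent (a sketch, not code from the oryx2 project):
KMeansModel kMeansModel = new KMeans()
    .setK(numClusters)
    .setMaxIterations(maxIterations)
    .setRuns(numberOfRuns)                          // deprecated; a no-op in recent Spark
    .setInitializationMode(initializationStrategy)  // "random" or "k-means||"
    .run(trainingData.rdd());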

Example 2: main

import org.apache.spark.mllib.clustering.KMeans; // import the package/class the method depends on
public static void main(String[] args) {

    String inputFile = "data/kmeans_data.txt";
    int k = 2; // two clusters
    int iterations = 10;
    int runs = 1;

    JavaSparkContext sc = new JavaSparkContext("local", "JavaKMeans");
    JavaRDD<String> lines = sc.textFile(inputFile);

    JavaRDD<Vector> points = lines.map(new ParsePoint());

    KMeansModel model = KMeans.train(points.rdd(), k, iterations, runs, KMeans.K_MEANS_PARALLEL());

    System.out.println("Cluster centers:");
    for (Vector center : model.clusterCenters()) {
      System.out.println(" " + center);
    }
    double cost = model.computeCost(points.rdd());
    System.out.println("Cost: " + cost);

    sc.stop();
  }
 
Developer ID: mark-watson, Project: power-java, Lines: 24, Source: JavaKMeans.java
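The ParsePoint mapper is defined elsewhere in this project and not shown. A plausible implementation, assuming the common space-separated kmeans_data.txt format, would look like the following (a hypothetical sketch, not the project's actual code):

// Hypothetical ParsePoint: one space-separated line of doubles -> dense Vector.
// Uses org.apache.spark.api.java.function.Function.
static class ParsePoint implements Function<String, Vector> {
  @Override
  public Vector call(String line) {
    String[] tok = line.split(" ");
    double[] point = new double[tok.length];
    for (int i = 0; i < tok.length; i++) {
      point[i] = Double.parseDouble(tok[i]);
    }
    return Vectors.dense(point);
  }
}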

Example 3: main

import org.apache.spark.mllib.clustering.KMeans; // import the package/class the method depends on
public static void main(String[] args) {
    SparkConf conf = new SparkConf().setMaster("local").setAppName("SparkStreamsSampleTrainingApplication");
    JavaSparkContext jsc = new JavaSparkContext(conf);

    JavaRDD<String> lines = jsc.textFile("data/random_2d_training.csv");
    JavaRDD<Vector> parsedData = lines.map(
        new Function<String, Vector>() {
            @Override
            public Vector call(String s) {
                String[] sarray = s.split(",");
                double[] values = new double[sarray.length];
                for (int i = 0; i < sarray.length; i++) {
                    values[i] = Double.parseDouble(sarray[i]);
                }
                return Vectors.dense(values);
            }
        }
    );
    parsedData.cache();

    int numClusters = 10;
    int numIterations = 20;
    KMeansModel clusters = KMeans.train(parsedData.rdd(), numClusters, numIterations);
    clusters.save(jsc.sc(), "etc/kmeans_model");
    jsc.close();
}
 
Developer ID: IBMStreams, Project: streamsx.sparkMLLib, Lines: 27, Source: JavaTrainingApplication.java

Example 4: main

import org.apache.spark.mllib.clustering.KMeans; // import the package/class the method depends on
public static void main(String[] args) {
    SparkConf conf = new SparkConf().setMaster("local[4]").setAppName("K-means Example");
    JavaSparkContext sc = new JavaSparkContext(conf);

    // Load and parse data
    String path = "data/km-data.txt";
    JavaRDD<String> data = sc.textFile(path);
    JavaRDD<Vector> parsedData = data.map(
        new Function<String, Vector>() {
            public Vector call(String s) {
                String[] sarray = s.split(" ");
                double[] values = new double[sarray.length];
                for (int i = 0; i < sarray.length; i++) {
                    values[i] = Double.parseDouble(sarray[i]);
                }
                return Vectors.dense(values);
            }
        }
    );
    parsedData.cache();

    // Cluster the data into two classes using KMeans
    int numClusters = 2;
    int numIterations = 20;
    KMeansModel clusters = KMeans.train(parsedData.rdd(), numClusters, numIterations);

    // Evaluate clustering by computing the Within Set Sum of Squared Errors
    double WSSSE = clusters.computeCost(parsedData.rdd());
    System.out.println("Within Set Sum of Squared Errors = " + WSSSE);

    sc.stop();
}
 
Developer ID: PacktPublishing, Project: Java-Data-Science-Cookbook, Lines: 33, Source: KMeansClusteringMlib.java
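Because this example already computes WSSSE, a natural extension is to sweep the cluster count and look for where WSSSE stops improving (the "elbow"). A minimal sketch, reusing the parsedData RDD and numIterations from above:

// Sketch: inspect WSSSE across candidate values of k to pick a cluster count.
for (int k = 2; k <= 10; k++) {
    KMeansModel m = KMeans.train(parsedData.rdd(), k, numIterations);
    System.out.println("k = " + k + ", WSSSE = " + m.computeCost(parsedData.rdd()));
}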

Example 5: main

import org.apache.spark.mllib.clustering.KMeans; // import the package/class the method depends on
public static void main(String[] args) throws IOException {

    int number_of_clusters = 8;
    int iterations = 100;
    int runs = 1;

    JavaSparkContext sc = new JavaSparkContext("local", "WikipediaKMeans");

    JavaRDD<String> lines = sc.textFile("data/" + input_file);

    JavaRDD<Vector> points = lines.map(new ParsePoint());

    KMeansModel model = KMeans.train(points.rdd(), number_of_clusters, iterations, runs, KMeans.K_MEANS_PARALLEL());

    System.out.println("Cluster centers:");
    for (Vector center : model.clusterCenters()) {
      System.out.println("\n " + center);
      String [] bestWords = sparseVectorGenerator.bestWords(center.toArray());
      System.out.println(" bestWords: " + Arrays.asList(bestWords));
    }
    double cost = model.computeCost(points.rdd());
    System.out.println("Cost: " + cost);

    // Print out documents by cluster index. Note: this is really inefficient
    // because I am cycling through the input file number_of_clusters times.
    // In a normal application the cluster index for each document would be saved
    // as metadata for each document. So, please consider the following loop and
    // the method printClusterIndex to be only for pretty-printing the results
    // of this example program:
    for (int i=0; i<number_of_clusters; i++)
      printClusterIndex(i, model);

    sc.stop();
  }
 
Developer ID: mark-watson, Project: power-java, Lines: 35, Source: WikipediaKMeans.java

Example 6: main

import org.apache.spark.mllib.clustering.KMeans; // import the package/class the method depends on
public static void main(String[] args) {
    if (args.length < 2) {
        System.err.println(
                "Usage: KMeansMP <input_file> <results>");
        System.exit(1);
    }
    String inputFile = args[0];
    String results_path = args[1];
    JavaPairRDD<Integer, Iterable<String>> results;
    int k = 4;
    int iterations = 100;
    int runs = 1;
    long seed = 0;
    final KMeansModel model;

    SparkConf sparkConf = new SparkConf().setAppName("KMeans MP");
    JavaSparkContext sc = new JavaSparkContext(sparkConf);

    JavaRDD<String> lines = sc.textFile(inputFile);
    JavaRDD<Vector> points = lines.map(new ParsePoint());
    JavaRDD<String> titles = lines.map(new ParseTitle());
    model = KMeans.train(points.rdd(), k, iterations, runs, RANDOM(), seed);
    results = titles.zip(points).mapToPair(new ClusterCars(model)).groupByKey();
    results.foreach(new PrintCluster());
    results.saveAsTextFile(results_path);

    sc.stop();
}
 
Developer ID: kgrodzicki, Project: cloud-computing-specialization, Lines: 29, Source: KMeansMP.java
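ClusterCars, ParsePoint, ParseTitle, and PrintCluster are defined elsewhere in the project. Given how ClusterCars is used (turning a (title, point) pair into a (clusterIndex, title) pair via the trained model), a plausible implementation would be the following hypothetical sketch:

// Hypothetical ClusterCars: keys each title by the cluster index that
// KMeansModel.predict assigns to its feature vector.
// Uses org.apache.spark.api.java.function.PairFunction and scala.Tuple2.
static class ClusterCars implements PairFunction<Tuple2<String, Vector>, Integer, String> {
    private final KMeansModel model;

    ClusterCars(KMeansModel model) {
        this.model = model;
    }

    @Override
    public Tuple2<Integer, String> call(Tuple2<String, Vector> pair) {
        return new Tuple2<>(model.predict(pair._2()), pair._1());
    }
}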

Example 7: train

import org.apache.spark.mllib.clustering.KMeans; // import the package/class the method depends on
@Override
public MLModel train(LensConf conf, String db, String table, String modelId, String... params) throws LensException {
  List<String> features = AlgoArgParser.parseArgs(this, params);
  final int[] featurePositions = new int[features.size()];
  final int NUM_FEATURES = features.size();

  JavaPairRDD<WritableComparable, HCatRecord> rdd = null;
  try {
    // Map feature names to positions
    Table tbl = Hive.get(toHiveConf(conf)).getTable(db, table);
    List<FieldSchema> allCols = tbl.getAllCols();
    int f = 0;
    for (int i = 0; i < tbl.getAllCols().size(); i++) {
      String colName = allCols.get(i).getName();
      if (features.contains(colName)) {
        featurePositions[f++] = i;
      }
    }

    rdd = HiveTableRDD.createHiveTableRDD(sparkContext, toHiveConf(conf), db, table, partFilter);
    JavaRDD<Vector> trainableRDD = rdd.map(new Function<Tuple2<WritableComparable, HCatRecord>, Vector>() {
      @Override
      public Vector call(Tuple2<WritableComparable, HCatRecord> v1) throws Exception {
        HCatRecord hCatRecord = v1._2();
        double[] arr = new double[NUM_FEATURES];
        for (int i = 0; i < NUM_FEATURES; i++) {
          Object val = hCatRecord.get(featurePositions[i]);
          arr[i] = val == null ? 0d : (Double) val;
        }
        return Vectors.dense(arr);
      }
    });

    KMeansModel model = KMeans.train(trainableRDD.rdd(), k, maxIterations, runs, initializationMode);
    return new KMeansClusteringModel(modelId, model);
  } catch (Exception e) {
    throw new LensException("KMeans algo failed for " + db + "." + table, e);
  }
}
 
Developer ID: apache, Project: lens, Lines: 40, Source: KMeansAlgo.java

Example 8: main

import org.apache.spark.mllib.clustering.KMeans; // import the package/class the method depends on
public static void main(String[] args) {
  if (args.length < 3) {
    System.err.println(
      "Usage: JavaKMeans <input_file> <k> <max_iterations> [<runs>]");
    System.exit(1);
  }
  String inputFile = args[0];
  int k = Integer.parseInt(args[1]);
  int iterations = Integer.parseInt(args[2]);
  int runs = 1;

  if (args.length >= 4) {
    runs = Integer.parseInt(args[3]);
  }
  SparkConf sparkConf = new SparkConf().setAppName("JavaKMeans");
  JavaSparkContext sc = new JavaSparkContext(sparkConf);

  JavaPairRDD<LongWritable, VectorWritable> data = sc.sequenceFile(inputFile,
              LongWritable.class, VectorWritable.class);

  JavaRDD<Vector> points =
      data.map(new Function<Tuple2<LongWritable, VectorWritable>, Vector>() {
        @Override
        public Vector call(Tuple2<LongWritable, VectorWritable> e) {
          VectorWritable val = e._2();
          double[] v = new double[val.get().size()];
          for (int i = 0; i < val.get().size(); ++i) {
            v[i] = val.get().get(i);
          }
          return Vectors.dense(v);
        }
      }).cache();

  KMeansModel model = KMeans
      .train(points.rdd(), k, iterations, runs, KMeans.K_MEANS_PARALLEL());

  System.out.println("Cluster centers:");
  for (Vector center : model.clusterCenters()) {
    System.out.println(" " + center);
  }
  double cost = model.computeCost(points.rdd());
  System.out.println("Cost: " + cost);

  sc.stop();
}
 
Developer ID: thrill, Project: fst-bench, Lines: 46, Source: JavaKMeans.java

Example 9: main

import org.apache.spark.mllib.clustering.KMeans; // import the package/class the method depends on
public static void main(String[] args) {
    SparkConf conf = new SparkConf().setAppName("JavaKMeansExample").setMaster("local[4]");
    JavaSparkContext jsc = new JavaSparkContext(conf);

    // $example on$
    // Load and parse data
    String path = "src/main/resources/kmeans-data.txt";
    JavaRDD<String> data = jsc.textFile(path);
    JavaRDD<Vector> parsedData = data.map(
            new Function<String, Vector>() {
                public Vector call(String s) {
                    String[] sarray = s.split(" ");
                    double[] values = new double[sarray.length];
                    for (int i = 0; i < sarray.length; i++) {
                        values[i] = Double.parseDouble(sarray[i]);
                    }
                    return Vectors.dense(values);
                }
            }
    );
    parsedData.cache();

    // Cluster the data into two classes using KMeans
    int numClusters = 2;
    int numIterations = 20;
    KMeansModel clusters = KMeans.train(parsedData.rdd(), numClusters, numIterations);

    System.out.println("Cluster centers:");
    for (Vector center: clusters.clusterCenters()) {
        System.out.println(" " + center);
    }
    // Evaluate clustering by computing the Within Set Sum of Squared Errors
    // (this is the same quantity computeCost returns as the model cost)
    double WSSSE = clusters.computeCost(parsedData.rdd());
    System.out.println("Within Set Sum of Squared Errors = " + WSSSE);

    // Save and load model
    clusters.save(jsc.sc(), "target/org/apache/spark/JavaKMeansExample/KMeansModel");
    KMeansModel sameModel = KMeansModel.load(jsc.sc(),
            "target/org/apache/spark/JavaKMeansExample/KMeansModel");
    // $example off$

    jsc.stop();
}
 
Developer ID: knoldus, Project: Sparkathon, Lines: 47, Source: Kmeans.java
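Once saved and reloaded, the model can score new points with KMeansModel.predict. A short usage sketch (the coordinates are made up for illustration):

// Sketch: assign a new point to its nearest cluster with the reloaded model.
Vector newPoint = Vectors.dense(0.2, 0.2, 0.2); // made-up 3-dimensional point
int clusterIndex = sameModel.predict(newPoint);
System.out.println("New point belongs to cluster " + clusterIndex);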


Note: The org.apache.spark.mllib.clustering.KMeans.train method examples in this article were compiled by 纯净天空 from open-source code and documentation platforms such as GitHub and MSDocs. The snippets were selected from open-source projects contributed by their respective developers, and copyright remains with the original authors; consult each project's license before distributing or reusing the code. Do not reproduce without permission.