

Java PipelineStage Class Code Examples

This article collects typical usage examples of the Java class org.apache.spark.ml.PipelineStage. If you are wondering what the PipelineStage class does and how to use it in practice, the curated examples below should help.


The PipelineStage class belongs to the org.apache.spark.ml package. Fifteen code examples of the class are shown below, sorted by popularity by default.
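
All of the examples share one pattern: concrete transformers and estimators are collected into a PipelineStage[] array and handed to Pipeline.setStages. The following minimal sketch illustrates that pattern; the stage choices and column names ("text", "words", "features") are illustrative and not taken from any one example below.

import org.apache.spark.ml.Pipeline;
import org.apache.spark.ml.PipelineStage;
import org.apache.spark.ml.feature.HashingTF;
import org.apache.spark.ml.feature.Tokenizer;

public class PipelineStageSketch {
    // PipelineStage is the common supertype of all Transformers and Estimators,
    // which is what lets heterogeneous stages share a single array.
    static Pipeline buildPipeline() {
        Tokenizer tokenizer = new Tokenizer()
                .setInputCol("text")
                .setOutputCol("words");
        HashingTF hashingTF = new HashingTF()
                .setInputCol(tokenizer.getOutputCol())
                .setOutputCol("features");
        return new Pipeline()
                .setStages(new PipelineStage[]{tokenizer, hashingTF});
    }
}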

Example 1: createPipeline

import org.apache.spark.ml.PipelineStage; // import the required package/class
/**
 * Creates a processing pipeline.
 * @return a pipeline
 */
private Pipeline createPipeline() {
	Tokenizer tokenizer = new Tokenizer()
		.setInputCol("featureStrings")
		.setOutputCol("tokens");
	CountVectorizer countVectorizer = new CountVectorizer()
		.setInputCol("tokens")
		.setOutputCol("features")
		.setMinDF((Double)params.getOrDefault(params.getMinFF()))
		.setVocabSize((Integer)params.getOrDefault(params.getNumFeatures()));  
	StringIndexer tagIndexer = new StringIndexer()
		.setInputCol("tag")
		.setOutputCol("label");
	
	Pipeline pipeline = new Pipeline().setStages(new PipelineStage[]{tokenizer, countVectorizer, tagIndexer});
	return pipeline;
}
 
Author: phuonglh, Project: vn.vitk, Lines: 21, Source: CMM.java

Example 2: createPipeline

import org.apache.spark.ml.PipelineStage; // import the required package/class
/**
 * Creates a processing pipeline.
 * @return a pipeline
 */
protected Pipeline createPipeline() {
	Tokenizer tokenizer = new Tokenizer()
		.setInputCol("text")
		.setOutputCol("tokens");
	CountVectorizer countVectorizer = new CountVectorizer()
		.setInputCol("tokens")
		.setOutputCol("features")
		.setMinDF((Double)params.getOrDefault(params.getMinFF()))
		.setVocabSize((Integer)params.getOrDefault(params.getNumFeatures()));  
	StringIndexer transitionIndexer = new StringIndexer()
		.setInputCol("transition")
		.setOutputCol("label");
	
	Pipeline pipeline = new Pipeline().setStages(new PipelineStage[]{tokenizer, countVectorizer, transitionIndexer});
	return pipeline;
}
 
Author: phuonglh, Project: vn.vitk, Lines: 21, Source: TransitionClassifier.java
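
Examples 1 and 2 construct a Pipeline but do not fit it. The following hedged sketch (not taken from the vn.vitk sources) shows how such a pipeline would typically be applied; trainingData is a hypothetical DataFrame that must already contain the input columns the stages expect.

// Hypothetical usage; assumes the Spark 1.x DataFrame API that vn.vitk uses elsewhere (see Example 8).
Pipeline pipeline = createPipeline();
PipelineModel model = pipeline.fit(trainingData);       // fits CountVectorizer and StringIndexer, wraps all stages
DataFrame transformed = model.transform(trainingData);  // adds the "tokens", "features" and "label" columns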

Example 3: testNetwork

import org.apache.spark.ml.PipelineStage; // import the required package/class
@Test
public void testNetwork() {
    DatasetFacade df = DatasetFacade.dataRows(sqlContext.read().json("src/test/resources/dl4jnetwork"));
    Pipeline p = new Pipeline().setStages(new PipelineStage[] {getAssembler(new String[] {"x", "y"}, "features")});
    DatasetFacade part2 = DatasetFacade.dataRows(p.fit(df.get()).transform(df.get()).select("features", "label"));

    ParamSerializer ps = new ParamHelper();
    MultiLayerConfiguration mc = getNNConfiguration();
    Collection<IterationListener> il = new ArrayList<>();
    il.add(new ScoreIterationListener(1));

    SparkDl4jNetwork sparkDl4jNetwork =
                    new SparkDl4jNetwork(mc, 2, ps, 1, il, true).setFeaturesCol("features").setLabelCol("label");

    SparkDl4jModel sm = sparkDl4jNetwork.fit(part2.get());
    MultiLayerNetwork mln = sm.getMultiLayerNetwork();
    Assert.assertNotNull(mln);
    DatasetFacade transformed = DatasetFacade.dataRows(sm.transform(part2.get()));
    List<?> rows = transformed.get().collectAsList();
    Assert.assertNotNull(sm.getTrainingStats());
    Assert.assertNotNull(rows);
}
 
Author: deeplearning4j, Project: deeplearning4j, Lines: 23, Source: SparkDl4jNetworkTest.java

Example 4: testNetworkLoader

import org.apache.spark.ml.PipelineStage; // import the required package/class
@Test
public void testNetworkLoader() throws Exception {
    DatasetFacade df = DatasetFacade.dataRows(sqlContext.read().json("src/test/resources/dl4jnetwork"));
    Pipeline p = new Pipeline().setStages(new PipelineStage[] {getAssembler(new String[] {"x", "y"}, "features")});
    DatasetFacade part2 = DatasetFacade.dataRows(p.fit(df.get()).transform(df.get()).select("features", "label"));

    ParamSerializer ps = new ParamHelper();
    MultiLayerConfiguration mc = getNNConfiguration();
    Collection<IterationListener> il = new ArrayList<>();
    il.add(new ScoreIterationListener(1));

    SparkDl4jNetwork sparkDl4jNetwork =
                    new SparkDl4jNetwork(mc, 2, ps, 1, il, true).setFeaturesCol("features").setLabelCol("label");

    String fileName = UUID.randomUUID().toString();
    SparkDl4jModel sm = sparkDl4jNetwork.fit(part2.get());
    sm.write().overwrite().save(fileName);
    SparkDl4jModel spdm = SparkDl4jModel.load(fileName);
    Assert.assertNotNull(spdm);

    File file1 = new File(fileName);
    File file2 = new File(fileName + "_metadata");
    FileUtils.deleteDirectory(file1);
    FileUtils.deleteDirectory(file2);
}
 
Author: deeplearning4j, Project: deeplearning4j, Lines: 26, Source: SparkDl4jNetworkTest.java

Example 5: testAutoencoderSave

import org.apache.spark.ml.PipelineStage; // import the required package/class
@Test
public void testAutoencoderSave() throws IOException {
    DatasetFacade df = DatasetFacade.dataRows(sqlContext.read().json("src/test/resources/autoencoders"));
    Pipeline p = new Pipeline().setStages(new PipelineStage[] {
                    getAssembler(new String[] {"a", "b", "c", "d", "e", "f", "g", "h", "i", "j"}, "features")});
    DatasetFacade part2 = DatasetFacade.dataRows(p.fit(df.get()).transform(df.get()).select("features"));

    AutoEncoder sparkDl4jNetwork = new AutoEncoder().setInputCol("features").setOutputCol("auto_encoded")
                    .setCompressedLayer(2).setTrainingMaster(new ParamHelper())
                    .setMultiLayerConfiguration(getNNConfiguration());

    AutoEncoderModel sm = sparkDl4jNetwork.fit(part2.get());

    String fileName = UUID.randomUUID().toString();
    sm.write().save(fileName);
    AutoEncoderModel spdm = AutoEncoderModel.load(fileName);
    Assert.assertNotNull(spdm);
    Assert.assertNotNull(spdm.transform(part2.get()));

    File file = new File(fileName);
    File file2 = new File(fileName + "_metadata");
    FileUtils.deleteDirectory(file);
    FileUtils.deleteDirectory(file2);
}
 
Author: deeplearning4j, Project: deeplearning4j, Lines: 25, Source: AutoEncoderNetworkTest.java

Example 6: createPipeline

import org.apache.spark.ml.PipelineStage; // import the required package/class
static
private Pipeline createPipeline(FunctionType function, String formulaString){
	RFormula formula = new RFormula()
		.setFormula(formulaString);

	Predictor<?, ?, ?> predictor;

	switch(function){
		case CLASSIFICATION:
			predictor = new DecisionTreeClassifier()
				.setMinInstancesPerNode(10);
			break;
		case REGRESSION:
			predictor = new DecisionTreeRegressor()
				.setMinInstancesPerNode(10);
			break;
		default:
			throw new IllegalArgumentException();
	}

	predictor
		.setLabelCol(formula.getLabelCol())
		.setFeaturesCol(formula.getFeaturesCol());

	Pipeline pipeline = new Pipeline()
		.setStages(new PipelineStage[]{formula, predictor});

	return pipeline;
}
 
Author: jpmml, Project: jpmml-sparkml-bootstrap, Lines: 30, Source: Main.java

Example 7: testNetwork

import org.apache.spark.ml.PipelineStage; // import the required package/class
@Test
public void testNetwork() {
    DatasetFacade df = DatasetFacade.dataRows(sqlContext.read().json("src/test/resources/autoencoders"));
    Pipeline p = new Pipeline().setStages(new PipelineStage[] {
                    getAssembler(new String[] {"a", "b", "c", "d", "e", "f", "g", "h", "i", "j"}, "features")});
    DatasetFacade part2 = DatasetFacade.dataRows(p.fit(df.get()).transform(df.get()).select("features"));

    AutoEncoder sparkDl4jNetwork = new AutoEncoder().setInputCol("features").setOutputCol("auto_encoded")
                    .setCompressedLayer(2).setTrainingMaster(new ParamHelper())
                    .setMultiLayerConfiguration(getNNConfiguration());

    AutoEncoderModel sm = sparkDl4jNetwork.fit(part2.get());
    MultiLayerNetwork mln = sm.getNetwork();
    Assert.assertNotNull(mln);
}
 
Author: deeplearning4j, Project: deeplearning4j, Lines: 16, Source: AutoEncoderNetworkTest.java

Example 8: train

import org.apache.spark.ml.PipelineStage; // import the required package/class
/**
 * Trains a whitespace classifier model and saves the resulting pipeline model
 * to an external file.
 * @param sentences a list of tokenized sentences
 * @param pipelineModelFileName the file to save the pipeline model to
 * @param numFeatures the number of features used by the hashing term-frequency transformer
 */
public void train(List<String> sentences, String pipelineModelFileName, int numFeatures) {
	List<WhitespaceContext> contexts = new ArrayList<WhitespaceContext>(sentences.size());
	int id = 0;
	for (String sentence : sentences) {
		sentence = sentence.trim();
		for (int j = 0; j < sentence.length(); j++) {
			char c = sentence.charAt(j);
			if (c == ' ' || c == '_') {
				WhitespaceContext context = new WhitespaceContext();
				context.setId(id++);
				context.setContext(extractContext(sentence, j));
				context.setLabel(c == ' ' ? 0d : 1d);
				contexts.add(context);
			}
		}
	}
	JavaRDD<WhitespaceContext> jrdd = jsc.parallelize(contexts);
	DataFrame df = sqlContext.createDataFrame(jrdd, WhitespaceContext.class);
	df.show(false);
	System.out.println("N = " + df.count());
	df.groupBy("label").count().show();
	
	org.apache.spark.ml.feature.Tokenizer tokenizer = new Tokenizer()
			.setInputCol("context").setOutputCol("words");
	HashingTF hashingTF = new HashingTF().setNumFeatures(numFeatures)
			.setInputCol(tokenizer.getOutputCol()).setOutputCol("features");
	LogisticRegression lr = new LogisticRegression().setMaxIter(100)
			.setRegParam(0.01);
	Pipeline pipeline = new Pipeline().setStages(new PipelineStage[] {
			tokenizer, hashingTF, lr });
	model = pipeline.fit(df);
	
	try {
		model.write().overwrite().save(pipelineModelFileName);
	} catch (IOException e) {
		e.printStackTrace();
	}
	
	DataFrame predictions = model.transform(df);
	predictions.show();
	MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator().setMetricName("precision");
	double accuracy = evaluator.evaluate(predictions);
	System.out.println("training accuracy = " + accuracy);
	
	LogisticRegressionModel lrModel = (LogisticRegressionModel) model.stages()[2];
	LogisticRegressionTrainingSummary trainingSummary = lrModel.summary();
	double[] objectiveHistory = trainingSummary.objectiveHistory();
	System.out.println("#(iterations) = " + objectiveHistory.length);
	for (double lossPerIteration : objectiveHistory) {
	  System.out.println(lossPerIteration);
	}
	
}
 
Author: phuonglh, Project: vn.vitk, Lines: 61, Source: WhitespaceClassifier.java

Example 9: testDecisionTreeRegressionPrediction

import org.apache.spark.ml.PipelineStage; // import the required package/class
@Test
public void testDecisionTreeRegressionPrediction() {
    // Load the data stored in LIBSVM format as a DataFrame.
    String datapath = "src/test/resources/regression_test.libsvm";
    Dataset<Row> data = spark.read().format("libsvm").load(datapath);

    // Split the data into training and test sets (30% held out for testing)
    Dataset<Row>[] splits = data.randomSplit(new double[]{0.7, 0.3});
    Dataset<Row> trainingData = splits[0];
    Dataset<Row> testData = splits[1];

    StringIndexer indexer = new StringIndexer()
            .setInputCol("label")
            .setOutputCol("labelIndex")
            .setHandleInvalid("skip");

    DecisionTreeRegressor regressionModel =
            new DecisionTreeRegressor().setLabelCol("labelIndex").setFeaturesCol("features");

    Pipeline pipeline = new Pipeline()
            .setStages(new PipelineStage[]{indexer, regressionModel});

    PipelineModel sparkPipeline = pipeline.fit(trainingData);

    byte[] exportedModel = ModelExporter.export(sparkPipeline);

    Transformer transformer = ModelImporter.importAndGetTransformer(exportedModel);
    List<Row> output = sparkPipeline.transform(testData).select("features", "prediction", "label").collectAsList();

    // compare predictions
    for (Row row : output) {
        Map<String, Object> data_ = new HashMap<>();
        data_.put("features", ((SparseVector) row.get(0)).toArray());
        data_.put("label", (row.get(2)).toString());
        transformer.transform(data_);
        System.out.println(data_);
        System.out.println(data_.get("prediction"));
        assertEquals((double) data_.get("prediction"), (double) row.get(1), EPSILON);
    }
}
 
Author: flipkart-incubator, Project: spark-transformers, Lines: 42, Source: DecisionTreeRegressionModelBridgePipelineTest.java

Example 10: testGradientBoostClassification

import org.apache.spark.ml.PipelineStage; // import the required package/class
@Test
public void testGradientBoostClassification() {
    // Load the data stored in LIBSVM format as a DataFrame.
    String datapath = "src/test/resources/binary_classification_test.libsvm";

    Dataset<Row> data = spark.read().format("libsvm").load(datapath);
    StringIndexer indexer = new StringIndexer()
            .setInputCol("label")
            .setOutputCol("labelIndex");
    // Split the data into training and test sets (30% held out for testing)
    Dataset<Row>[] splits = data.randomSplit(new double[]{0.7, 0.3});
    Dataset<Row> trainingData = splits[0];
    Dataset<Row> testData = splits[1];

    // Train a gradient-boosted tree model.
    GBTClassifier classificationModel = new GBTClassifier()
            .setLabelCol("labelIndex")
            .setFeaturesCol("features");

    Pipeline pipeline = new Pipeline()
            .setStages(new PipelineStage[]{indexer, classificationModel});

    PipelineModel sparkPipeline = pipeline.fit(trainingData);

    // Export this model
    byte[] exportedModel = ModelExporter.export(sparkPipeline);

    // Import and get Transformer
    Transformer transformer = ModelImporter.importAndGetTransformer(exportedModel);

    List<Row> sparkOutput = sparkPipeline.transform(testData).select("features", "prediction", "label").collectAsList();

    // compare predictions
    for (Row row : sparkOutput) {
        Map<String, Object> data_ = new HashMap<>();
        data_.put("features", ((SparseVector) row.get(0)).toArray());
        data_.put("label", (row.get(2)).toString());
        transformer.transform(data_);
        System.out.println(data_);
        System.out.println(data_.get("prediction") + " ," + row.get(1));
        assertEquals((double) data_.get("prediction"), (double) row.get(1), EPSILON);
    }
}
 
Author: flipkart-incubator, Project: spark-transformers, Lines: 45, Source: GradientBoostClassificationModelPipelineTest.java

Example 11: testDecisionTreeClassificationWithPipeline

import org.apache.spark.ml.PipelineStage; // import the required package/class
@Test
public void testDecisionTreeClassificationWithPipeline() {
    // Load the data stored in LIBSVM format as a DataFrame.
    String datapath = "src/test/resources/classification_test.libsvm";
    Dataset<Row> data = spark.read().format("libsvm").load(datapath);

    // Split the data into training and test sets (30% held out for testing)
    Dataset<Row>[] splits = data.randomSplit(new double[]{0.7, 0.3});
    Dataset<Row> trainingData = splits[0];
    Dataset<Row> testData = splits[1];

    StringIndexer indexer = new StringIndexer()
            .setInputCol("label")
            .setOutputCol("labelIndex");

    // Train a DecisionTree model.
    DecisionTreeClassifier classificationModel = new DecisionTreeClassifier()
            .setLabelCol("labelIndex")
            .setFeaturesCol("features");

    Pipeline pipeline = new Pipeline()
            .setStages(new PipelineStage[]{indexer, classificationModel});


    // Train model.  This also runs the indexer.
    PipelineModel sparkPipeline = pipeline.fit(trainingData);

    //Export this model
    byte[] exportedModel = ModelExporter.export(sparkPipeline);

    //Import and get Transformer
    Transformer transformer = ModelImporter.importAndGetTransformer(exportedModel);

    List<Row> output = sparkPipeline.transform(testData).select("features", "label","prediction","rawPrediction").collectAsList();

    //compare predictions
    for (Row row : output) {
    	Map<String, Object> data_ = new HashMap<>();
    	double [] actualRawPrediction = ((DenseVector) row.get(3)).toArray();
        data_.put("features", ((SparseVector) row.get(0)).toArray());
        data_.put("label", (row.get(1)).toString());
        transformer.transform(data_);
        System.out.println(data_);
        System.out.println(data_.get("prediction"));
        assertEquals((double)data_.get("prediction"), (double)row.get(2), EPSILON);
        assertArrayEquals((double[]) data_.get("rawPrediction"), actualRawPrediction, EPSILON);
    }
}
 
Author: flipkart-incubator, Project: spark-transformers, Lines: 54, Source: DecisionTreeClassificationModelBridgePipelineTest.java

Example 12: testPipeline

import org.apache.spark.ml.PipelineStage; // import the required package/class
@Test
public void testPipeline() {
    // Prepare training documents, which are labeled.
    StructType schema = createStructType(new StructField[]{
            createStructField("id", LongType, false),
            createStructField("text", StringType, false),
            createStructField("label", DoubleType, false)
    });
    Dataset<Row> trainingData = spark.createDataFrame(Arrays.asList(
            cr(0L, "a b c d e spark", 1.0),
            cr(1L, "b d", 0.0),
            cr(2L, "spark f g h", 1.0),
            cr(3L, "hadoop mapreduce", 0.0)
    ), schema);

    // Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and LogisticRegression.
    RegexTokenizer tokenizer = new RegexTokenizer()
            .setInputCol("text")
            .setOutputCol("words")
            .setPattern("\\s")
            .setGaps(true)
            .setToLowercase(false);

    HashingTF hashingTF = new HashingTF()
            .setNumFeatures(1000)
            .setInputCol(tokenizer.getOutputCol())
            .setOutputCol("features");
    LogisticRegression lr = new LogisticRegression()
            .setMaxIter(10)
            .setRegParam(0.01);
    Pipeline pipeline = new Pipeline()
            .setStages(new PipelineStage[]{tokenizer, hashingTF, lr});

    // Fit the pipeline to training documents.
    PipelineModel sparkPipelineModel = pipeline.fit(trainingData);


    //Export this model
    byte[] exportedModel = ModelExporter.export(sparkPipelineModel);
    System.out.println(new String(exportedModel));

    //Import and get Transformer
    Transformer transformer = ModelImporter.importAndGetTransformer(exportedModel);

    //prepare test data
    StructType testSchema = createStructType(new StructField[]{
            createStructField("id", LongType, false),
            createStructField("text", StringType, false),
    });
    Dataset<Row> testData = spark.createDataFrame(Arrays.asList(
            cr(4L, "spark i j k"),
            cr(5L, "l m n"),
            cr(6L, "mapreduce spark"),
            cr(7L, "apache hadoop")
    ), testSchema);

    //verify that predictions for spark pipeline and exported pipeline are the same
    List<Row> predictions = sparkPipelineModel.transform(testData).select("id", "text", "probability", "prediction").collectAsList();
    for (Row r : predictions) {
        System.out.println(r);
        double sparkPipelineOp = r.getDouble(3);
        Map<String, Object> data = new HashMap<String, Object>();
        data.put("text", r.getString(1));
        transformer.transform(data);
        double exportedPipelineOp = (double) data.get("prediction");
        double exportedPipelineProb = (double) data.get("probability");
        assertEquals(sparkPipelineOp, exportedPipelineOp, 0.01);
    }
}
 
Author: flipkart-incubator, Project: spark-transformers, Lines: 70, Source: PipelineBridgeTest.java

Example 13: testRandomForestRegressionWithPipeline

import org.apache.spark.ml.PipelineStage; // import the required package/class
@Test
public void testRandomForestRegressionWithPipeline() {
    // Load the data stored in LIBSVM format as a DataFrame.
    DataFrame data = sqlContext.read().format("libsvm").load("src/test/resources/regression_test.libsvm");

    // Split the data into training and test sets (30% held out for testing)
    DataFrame[] splits = data.randomSplit(new double[]{0.7, 0.3});
    DataFrame trainingData = splits[0];
    DataFrame testData = splits[1];

    // Train a RandomForest model.
    RandomForestRegressionModel regressionModel = new RandomForestRegressor()
            .setFeaturesCol("features").fit(trainingData);

    Pipeline pipeline = new Pipeline()
            .setStages(new PipelineStage[]{regressionModel});

    // Train the pipeline model (here the only stage is the already-fitted regressor).
    PipelineModel sparkPipeline = pipeline.fit(trainingData);

    //Export this model
    byte[] exportedModel = ModelExporter.export(sparkPipeline, null);

    //Import and get Transformer
    Transformer transformer = ModelImporter.importAndGetTransformer(exportedModel);

    Row[] sparkOutput = sparkPipeline.transform(testData).select("features", "prediction").collect();

    //compare predictions
    for (Row row : sparkOutput) {
        Vector v = (Vector) row.get(0);
        double actual = row.getDouble(1);

        Map<String, Object> inputData = new HashMap<String, Object>();
        inputData.put(transformer.getInputKeys().iterator().next(), v.toArray());
        transformer.transform(inputData);
        double predicted = (double) inputData.get(transformer.getOutputKeys().iterator().next());

        assertEquals(actual, predicted, EPSILON);
    }
}
 
Author: flipkart-incubator, Project: spark-transformers, Lines: 42, Source: RandomForestRegressionModelInfoAdapterBridgeTest.java

Example 14: testDecisionTreeRegressionWithPipeline

import org.apache.spark.ml.PipelineStage; // import the required package/class
@Test
public void testDecisionTreeRegressionWithPipeline() {
    // Load the data stored in LIBSVM format as a DataFrame.
    DataFrame data = sqlContext.read().format("libsvm").load("src/test/resources/regression_test.libsvm");

    // Split the data into training and test sets (30% held out for testing)
    DataFrame[] splits = data.randomSplit(new double[]{0.7, 0.3});
    DataFrame trainingData = splits[0];
    DataFrame testData = splits[1];

    // Train a DecisionTree model.
    DecisionTreeRegressor dt = new DecisionTreeRegressor()
            .setFeaturesCol("features");

    Pipeline pipeline = new Pipeline()
            .setStages(new PipelineStage[]{dt});

    // Train the pipeline model (the regressor is the only stage; there is no indexer).
    PipelineModel sparkPipeline = pipeline.fit(trainingData);

    //Export this model
    byte[] exportedModel = ModelExporter.export(sparkPipeline, null);

    //Import and get Transformer
    Transformer transformer = ModelImporter.importAndGetTransformer(exportedModel);

    Row[] sparkOutput = sparkPipeline.transform(testData).select("features", "prediction").collect();

    //compare predictions
    for (Row row : sparkOutput) {
        Vector v = (Vector) row.get(0);
        double actual = row.getDouble(1);

        Map<String, Object> inputData = new HashMap<String, Object>();
        inputData.put(transformer.getInputKeys().iterator().next(), v.toArray());
        transformer.transform(inputData);
        double predicted = (double) inputData.get(transformer.getOutputKeys().iterator().next());

        assertEquals(actual, predicted, EPSILON);
    }
}
 
Author: flipkart-incubator, Project: spark-transformers, Lines: 42, Source: DecisionTreeRegressionModelBridgeTest.java

Example 15: testPipeline

import org.apache.spark.ml.PipelineStage; // import the required package/class
@Test
public void testPipeline() {
    // Prepare training documents, which are labeled.
    StructType schema = createStructType(new StructField[]{
            createStructField("id", LongType, false),
            createStructField("text", StringType, false),
            createStructField("label", DoubleType, false)
    });
    DataFrame trainingData = sqlContext.createDataFrame(Arrays.asList(
            cr(0L, "a b c d e spark", 1.0),
            cr(1L, "b d", 0.0),
            cr(2L, "spark f g h", 1.0),
            cr(3L, "hadoop mapreduce", 0.0)
    ), schema);

    // Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and LogisticRegression.
    RegexTokenizer tokenizer = new RegexTokenizer()
            .setInputCol("text")
            .setOutputCol("words")
            .setPattern("\\s")
            .setGaps(true)
            .setToLowercase(false);

    HashingTF hashingTF = new HashingTF()
            .setNumFeatures(1000)
            .setInputCol(tokenizer.getOutputCol())
            .setOutputCol("features");
    LogisticRegression lr = new LogisticRegression()
            .setMaxIter(10)
            .setRegParam(0.01);
    Pipeline pipeline = new Pipeline()
            .setStages(new PipelineStage[]{tokenizer, hashingTF, lr});

    // Fit the pipeline to training documents.
    PipelineModel sparkPipelineModel = pipeline.fit(trainingData);


    //Export this model
    byte[] exportedModel = ModelExporter.export(sparkPipelineModel, trainingData);
    System.out.println(new String(exportedModel));

    //Import and get Transformer
    Transformer transformer = ModelImporter.importAndGetTransformer(exportedModel);

    //prepare test data
    StructType testSchema = createStructType(new StructField[]{
            createStructField("id", LongType, false),
            createStructField("text", StringType, false),
    });
    DataFrame testData = sqlContext.createDataFrame(Arrays.asList(
            cr(4L, "spark i j k"),
            cr(5L, "l m n"),
            cr(6L, "mapreduce spark"),
            cr(7L, "apache hadoop")
    ), testSchema);

    //verify that predictions for spark pipeline and exported pipeline are the same
    Row[] predictions = sparkPipelineModel.transform(testData).select("id", "text", "probability", "prediction").collect();
    for (Row r : predictions) {
        System.out.println(r);
        double sparkPipelineOp = r.getDouble(3);
        Map<String, Object> data = new HashMap<String, Object>();
        data.put("text", r.getString(1));
        transformer.transform(data);
        double exportedPipelineOp = (double) data.get("prediction");
        double exportedPipelineProb = (double) data.get("probability");
        assertEquals(sparkPipelineOp, exportedPipelineOp, EPSILON);
    }
}
 
Author: flipkart-incubator, Project: spark-transformers, Lines: 70, Source: PipelineBridgeTest.java


Note: The org.apache.spark.ml.PipelineStage examples in this article were compiled by 纯净天空 from open-source code and documentation platforms such as GitHub and MSDocs. The snippets were selected from open-source projects contributed by their respective developers, and copyright remains with the original authors. Consult each project's license before distributing or reusing the code, and do not republish without permission.