

Java DataFrame.randomSplit Method Code Examples

This article collects typical usage examples of the Java method org.apache.spark.sql.DataFrame.randomSplit. If you are wondering how to call DataFrame.randomSplit in Java, or are looking for concrete examples of its use, the curated code examples below may help. You can also explore further usage examples of its enclosing class, org.apache.spark.sql.DataFrame.


The following presents 9 code examples of the DataFrame.randomSplit method, sorted by popularity by default.
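
Before working through the examples, the snippet below is a minimal usage sketch of randomSplit itself. It is an illustration rather than code from any of the projects listed here: the input path, the variable names, and the pre-existing SQLContext named sqlContext are assumptions (the tests below set up their own contexts).

// Minimal randomSplit sketch against the Spark 1.x DataFrame API.
// The weights do not have to sum to 1; Spark normalizes them internally.
DataFrame data = sqlContext.read().format("libsvm").load("data/sample.libsvm"); // hypothetical input path
DataFrame[] parts = data.randomSplit(new double[]{0.7, 0.3}, 42L);              // the optional seed makes the split reproducible
DataFrame train = parts[0];
DataFrame test = parts[1];
System.out.println("train = " + train.count() + ", test = " + test.count());

Without an explicit seed the split changes from run to run, which is why the counts printed in the examples below can vary between executions.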

Example 1: testRandomSplit

import org.apache.spark.sql.DataFrame; // import the package/class that this method depends on
void testRandomSplit(String inputFileName, int numFeatures, String modelFileName) {
	CMMParams params = new CMMParams()
		.setMaxIter(600)
		.setRegParam(1E-6)
		.setMarkovOrder(2)
		.setNumFeatures(numFeatures);
	
	JavaRDD<String> lines = jsc.textFile(inputFileName);
	DataFrame dataset = createDataFrame(lines.collect());
	DataFrame[] splits = dataset.randomSplit(new double[]{0.9, 0.1}); 
	DataFrame trainingData = splits[0];
	System.out.println("Number of training sequences = " + trainingData.count());
	DataFrame testData = splits[1];
	System.out.println("Number of test sequences = " + testData.count());
	// train and save a model on the training data
	cmmModel = train(trainingData, modelFileName, params);
	// test the model on the test data
	System.out.println("Test accuracy:");
	evaluate(testData); 
	// test the model on the training data
	System.out.println("Training accuracy:");
	evaluate(trainingData);
}
 
Developer: phuonglh, Project: vn.vitk, Lines of code: 24, Source: Tagger.java

Example 2: testRandomForestRegression

import org.apache.spark.sql.DataFrame; // import the package/class that this method depends on
@Test
public void testRandomForestRegression() {
    // Load the data stored in LIBSVM format as a DataFrame.
    DataFrame data = sqlContext.read().format("libsvm").load("src/test/resources/regression_test.libsvm");

    // Split the data into training and test sets (30% held out for testing)
    DataFrame[] splits = data.randomSplit(new double[]{0.7, 0.3});
    DataFrame trainingData = splits[0];
    DataFrame testData = splits[1];

    // Train a RandomForest model.
    RandomForestRegressionModel regressionModel = new RandomForestRegressor()
            .setFeaturesCol("features").fit(trainingData);

    byte[] exportedModel = ModelExporter.export(regressionModel, null);

    Transformer transformer = ModelImporter.importAndGetTransformer(exportedModel);

    Row[] sparkOutput = regressionModel.transform(testData).select("features", "prediction").collect();

    //compare predictions
    for (Row row : sparkOutput) {
        Vector v = (Vector) row.get(0);
        double actual = row.getDouble(1);

        Map<String, Object> inputData = new HashMap<String, Object>();
        inputData.put(transformer.getInputKeys().iterator().next(), v.toArray());
        transformer.transform(inputData);
        double predicted = (double) inputData.get(transformer.getOutputKeys().iterator().next());

        System.out.println(actual + ", " + predicted);
        assertEquals(actual, predicted, EPSILON);
    }
}
 
Developer: flipkart-incubator, Project: spark-transformers, Lines of code: 35, Source: RandomForestRegressionModelInfoAdapterBridgeTest.java

Example 3: testDecisionTreeRegression

import org.apache.spark.sql.DataFrame; // import the package/class that this method depends on
@Test
public void testDecisionTreeRegression() {
    // Load the data stored in LIBSVM format as a DataFrame.
    DataFrame data = sqlContext.read().format("libsvm").load("src/test/resources/regression_test.libsvm");

    // Split the data into training and test sets (30% held out for testing)
    DataFrame[] splits = data.randomSplit(new double[]{0.7, 0.3});
    DataFrame trainingData = splits[0];
    DataFrame testData = splits[1];

    // Train a DecisionTree model.
    DecisionTreeRegressionModel regressionModel = new DecisionTreeRegressor()
            .setFeaturesCol("features").fit(trainingData);

    byte[] exportedModel = ModelExporter.export(regressionModel, null);

    Transformer transformer = ModelImporter.importAndGetTransformer(exportedModel);

    Row[] sparkOutput = regressionModel.transform(testData).select("features", "prediction").collect();

    //compare predictions
    for (Row row : sparkOutput) {
        Vector v = (Vector) row.get(0);
        double actual = row.getDouble(1);

        Map<String, Object> inputData = new HashMap<String, Object>();
        inputData.put(transformer.getInputKeys().iterator().next(), v.toArray());
        transformer.transform(inputData);
        double predicted = (double) inputData.get(transformer.getOutputKeys().iterator().next());

        System.out.println(actual + ", " + predicted);
        assertEquals(actual, predicted, EPSILON);
    }
}
 
Developer: flipkart-incubator, Project: spark-transformers, Lines of code: 35, Source: DecisionTreeRegressionModelBridgeTest.java

Example 4: testRandomForestClassification

import org.apache.spark.sql.DataFrame; // import the package/class that this method depends on
@Test
public void testRandomForestClassification() {
    // Load the data stored in LIBSVM format as a DataFrame.
    DataFrame data = sqlContext.read().format("libsvm").load("src/test/resources/classification_test.libsvm");

    StringIndexerModel stringIndexerModel = new StringIndexer()
            .setInputCol("label")
            .setOutputCol("labelIndex")
            .fit(data);

    data = stringIndexerModel.transform(data);

    // Split the data into training and test sets (30% held out for testing)
    DataFrame[] splits = data.randomSplit(new double[]{0.7, 0.3});
    DataFrame trainingData = splits[0];
    DataFrame testData = splits[1];

    // Train a RandomForest model.
    RandomForestClassificationModel classificationModel = new RandomForestClassifier()
            .setLabelCol("labelIndex")
            .setFeaturesCol("features")
            .setPredictionCol("prediction")
            .setRawPredictionCol("rawPrediction")
            .setProbabilityCol("probability")
            .fit(trainingData);


    byte[] exportedModel = ModelExporter.export(classificationModel, null);

    Transformer transformer = ModelImporter.importAndGetTransformer(exportedModel);

    Row[] sparkOutput = classificationModel.transform(testData).select("features", "prediction", "rawPrediction", "probability").collect();

    //compare predictions
    for (Row row : sparkOutput) {
        Vector v = (Vector) row.get(0);
        double actual = row.getDouble(1);
        double[] actualProbability = ((Vector) row.get(3)).toArray();
        double[] actualRaw = ((Vector) row.get(2)).toArray();

        Map<String, Object> inputData = new HashMap<String, Object>();
        inputData.put(transformer.getInputKeys().iterator().next(), v.toArray());
        transformer.transform(inputData);
        double predicted = (double) inputData.get("prediction");
        double[] probability = (double[]) inputData.get("probability");
        double[] rawPrediction = (double[]) inputData.get("rawPrediction");

        assertEquals(actual, predicted, EPSILON);
        assertArrayEquals(actualProbability, probability, EPSILON);
        assertArrayEquals(actualRaw, rawPrediction, EPSILON);


    }

}
 
Developer: flipkart-incubator, Project: spark-transformers, Lines of code: 56, Source: RandomForestClassificationModelInfoAdapterBridgeTest.java

Example 5: testRandomForestRegressionWithPipeline

import org.apache.spark.sql.DataFrame; // import the package/class that this method depends on
@Test
public void testRandomForestRegressionWithPipeline() {
    // Load the data stored in LIBSVM format as a DataFrame.
    DataFrame data = sqlContext.read().format("libsvm").load("src/test/resources/regression_test.libsvm");

    // Split the data into training and test sets (30% held out for testing)
    DataFrame[] splits = data.randomSplit(new double[]{0.7, 0.3});
    DataFrame trainingData = splits[0];
    DataFrame testData = splits[1];

    // Train a RandomForest model.
    RandomForestRegressionModel regressionModel = new RandomForestRegressor()
            .setFeaturesCol("features").fit(trainingData);

    Pipeline pipeline = new Pipeline()
            .setStages(new PipelineStage[]{regressionModel});

    // Train model.  This also runs the indexer.
    PipelineModel sparkPipeline = pipeline.fit(trainingData);

    //Export this model
    byte[] exportedModel = ModelExporter.export(sparkPipeline, null);

    //Import and get Transformer
    Transformer transformer = ModelImporter.importAndGetTransformer(exportedModel);

    Row[] sparkOutput = sparkPipeline.transform(testData).select("features", "prediction").collect();

    //compare predictions
    for (Row row : sparkOutput) {
        Vector v = (Vector) row.get(0);
        double actual = row.getDouble(1);

        Map<String, Object> inputData = new HashMap<String, Object>();
        inputData.put(transformer.getInputKeys().iterator().next(), v.toArray());
        transformer.transform(inputData);
        double predicted = (double) inputData.get(transformer.getOutputKeys().iterator().next());

        assertEquals(actual, predicted, EPSILON);
    }
}
 
Developer: flipkart-incubator, Project: spark-transformers, Lines of code: 42, Source: RandomForestRegressionModelInfoAdapterBridgeTest.java

Example 6: testDecisionTreeRegressionWithPipeline

import org.apache.spark.sql.DataFrame; // import the package/class that this method depends on
@Test
public void testDecisionTreeRegressionWithPipeline() {
    // Load the data stored in LIBSVM format as a DataFrame.
    DataFrame data = sqlContext.read().format("libsvm").load("src/test/resources/regression_test.libsvm");

    // Split the data into training and test sets (30% held out for testing)
    DataFrame[] splits = data.randomSplit(new double[]{0.7, 0.3});
    DataFrame trainingData = splits[0];
    DataFrame testData = splits[1];

    // Train a DecisionTree model.
    DecisionTreeRegressor dt = new DecisionTreeRegressor()
            .setFeaturesCol("features");

    Pipeline pipeline = new Pipeline()
            .setStages(new PipelineStage[]{dt});

    // Train model.  This also runs the indexer.
    PipelineModel sparkPipeline = pipeline.fit(trainingData);

    //Export this model
    byte[] exportedModel = ModelExporter.export(sparkPipeline, null);

    //Import and get Transformer
    Transformer transformer = ModelImporter.importAndGetTransformer(exportedModel);

    Row[] sparkOutput = sparkPipeline.transform(testData).select("features", "prediction").collect();

    //compare predictions
    for (Row row : sparkOutput) {
        Vector v = (Vector) row.get(0);
        double actual = row.getDouble(1);

        Map<String, Object> inputData = new HashMap<String, Object>();
        inputData.put(transformer.getInputKeys().iterator().next(), v.toArray());
        transformer.transform(inputData);
        double predicted = (double) inputData.get(transformer.getOutputKeys().iterator().next());

        assertEquals(actual, predicted, EPSILON);
    }
}
 
Developer: flipkart-incubator, Project: spark-transformers, Lines of code: 42, Source: DecisionTreeRegressionModelBridgeTest.java

Example 7: testRandomForestClassificationWithPipeline

import org.apache.spark.sql.DataFrame; // import the package/class that this method depends on
@Test
public void testRandomForestClassificationWithPipeline() {
    // Load the data stored in LIBSVM format as a DataFrame.
    DataFrame data = sqlContext.read().format("libsvm").load("src/test/resources/classification_test.libsvm");

    // Split the data into training and test sets (30% held out for testing)
    DataFrame[] splits = data.randomSplit(new double[]{0.7, 0.3});
    DataFrame trainingData = splits[0];
    DataFrame testData = splits[1];

    StringIndexer indexer = new StringIndexer()
            .setInputCol("label")
            .setOutputCol("labelIndex");

    // Train a DecisionTree model.
    RandomForestClassifier classifier = new RandomForestClassifier()
            .setLabelCol("labelIndex")
            .setFeaturesCol("features")
            .setPredictionCol("prediction")
            .setRawPredictionCol("rawPrediction")
            .setProbabilityCol("probability");


    Pipeline pipeline = new Pipeline()
            .setStages(new PipelineStage[]{indexer, classifier});

    // Train model.  This also runs the indexer.
    PipelineModel sparkPipeline = pipeline.fit(trainingData);

    //Export this model
    byte[] exportedModel = ModelExporter.export(sparkPipeline, null);

    //Import and get Transformer
    Transformer transformer = ModelImporter.importAndGetTransformer(exportedModel);

    Row[] sparkOutput = sparkPipeline.transform(testData).select("label", "features", "prediction", "rawPrediction", "probability").collect();

    //compare predictions
    for (Row row : sparkOutput) {
        Vector v = (Vector) row.get(1);
        double actual = row.getDouble(2);
        double[] actualProbability = ((Vector) row.get(4)).toArray();
        double[] actualRaw = ((Vector) row.get(3)).toArray();

        Map<String, Object> inputData = new HashMap<String, Object>();
        inputData.put("features", v.toArray());
        inputData.put("label", row.get(0).toString());
        transformer.transform(inputData);
        double predicted = (double) inputData.get("prediction");
        double[] probability = (double[]) inputData.get("probability");
        double[] rawPrediction = (double[]) inputData.get("rawPrediction");

        assertEquals(actual, predicted, EPSILON);
        assertArrayEquals(actualProbability, probability, EPSILON);
        assertArrayEquals(actualRaw, rawPrediction, EPSILON);
    }
}
 
Developer: flipkart-incubator, Project: spark-transformers, Lines of code: 58, Source: RandomForestClassificationModelInfoAdapterBridgeTest.java

Example 8: testDecisionTreeClassificationRawPrediction

import org.apache.spark.sql.DataFrame; // import the package/class that this method depends on
@Test
public void testDecisionTreeClassificationRawPrediction() {
    // Load the data stored in LIBSVM format as a DataFrame.
    DataFrame data = sqlContext.read().format("libsvm").load("src/test/resources/classification_test.libsvm");

    StringIndexerModel stringIndexerModel = new StringIndexer()
            .setInputCol("label")
            .setOutputCol("labelIndex")
            .fit(data);

    data = stringIndexerModel.transform(data);

    // Split the data into training and test sets (30% held out for testing)
    DataFrame[] splits = data.randomSplit(new double[]{0.7, 0.3});
    DataFrame trainingData = splits[0];
    DataFrame testData = splits[1];

    // Train a DecisionTree model.
    DecisionTreeClassificationModel classificationModel = new DecisionTreeClassifier()
            .setLabelCol("labelIndex")
            .setFeaturesCol("features")
            .setRawPredictionCol("rawPrediction")
            .setPredictionCol("prediction")
            .fit(trainingData);

    byte[] exportedModel = ModelExporter.export(classificationModel, null);

    Transformer transformer = (DecisionTreeTransformer) ModelImporter.importAndGetTransformer(exportedModel);

    Row[] sparkOutput = classificationModel.transform(testData).select("features", "prediction", "rawPrediction").collect();

    //compare predictions
    for (Row row : sparkOutput) {
        Vector inp = (Vector) row.get(0);
        double actual = row.getDouble(1);
        double[] actualRaw = ((Vector) row.get(2)).toArray();

        Map<String, Object> inputData = new HashMap<>();
        inputData.put(transformer.getInputKeys().iterator().next(), inp.toArray());
        transformer.transform(inputData);
        double predicted = (double) inputData.get(transformer.getOutputKeys().iterator().next());
        double[] rawPrediction = (double[]) inputData.get("rawPrediction");

        assertEquals(actual, predicted, EPSILON);
        assertArrayEquals(actualRaw, rawPrediction, EPSILON);
    }
}
 
Developer: flipkart-incubator, Project: spark-transformers, Lines of code: 48, Source: DecisionTreeClassificationModelBridgeTest.java

Example 9: testDecisionTreeClassificationWithPipeline

import org.apache.spark.sql.DataFrame; // import the package/class that this method depends on
@Test
public void testDecisionTreeClassificationWithPipeline() {
    // Load the data stored in LIBSVM format as a DataFrame.
    DataFrame data = sqlContext.read().format("libsvm").load("src/test/resources/classification_test.libsvm");

    // Split the data into training and test sets (30% held out for testing)
    DataFrame[] splits = data.randomSplit(new double[]{0.7, 0.3});
    DataFrame trainingData = splits[0];
    DataFrame testData = splits[1];

    StringIndexer indexer = new StringIndexer()
            .setInputCol("label")
            .setOutputCol("labelIndex");

    // Train a DecisionTree model.
    DecisionTreeClassifier classificationModel = new DecisionTreeClassifier()
            .setLabelCol("labelIndex")
            .setFeaturesCol("features");

    Pipeline pipeline = new Pipeline()
            .setStages(new PipelineStage[]{indexer, classificationModel});

    // Train model.  This also runs the indexer.
    PipelineModel sparkPipeline = pipeline.fit(trainingData);

    //Export this model
    byte[] exportedModel = ModelExporter.export(sparkPipeline, null);

    //Import and get Transformer
    Transformer transformer = ModelImporter.importAndGetTransformer(exportedModel);

    Row[] sparkOutput = sparkPipeline.transform(testData).select("label", "features", "prediction").collect();

    //compare predictions
    for (Row row : sparkOutput) {
        Vector v = (Vector) row.get(1);
        double actual = row.getDouble(2);

        Map<String, Object> inputData = new HashMap<String, Object>();
        inputData.put("features", v.toArray());
        inputData.put("label", row.get(0).toString());
        transformer.transform(inputData);
        double predicted = (double) inputData.get("prediction");

        assertEquals(actual, predicted, EPSILON);
    }
}
 
Developer: flipkart-incubator, Project: spark-transformers, Lines of code: 48, Source: DecisionTreeClassificationModelBridgeTest.java


Note: The org.apache.spark.sql.DataFrame.randomSplit method examples in this article were compiled by 纯净天空 from open-source code and documentation platforms such as GitHub and MSDocs. The code snippets are taken from open-source projects contributed by their respective authors; copyright of the source code remains with the original authors, and any distribution or use must follow the corresponding project's license. Do not reproduce without permission.