本文整理汇总了Java中org.apache.spark.sql.DataFrame.randomSplit方法的典型用法代码示例。如果您正苦于以下问题：Java DataFrame.randomSplit方法的具体用法？Java DataFrame.randomSplit怎么用？Java DataFrame.randomSplit使用的例子？那么，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在类org.apache.spark.sql.DataFrame的用法示例。
在下文中一共展示了DataFrame.randomSplit方法的9个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Java代码示例。
示例1: testRandomSplit
import org.apache.spark.sql.DataFrame; //导入方法依赖的package包/类
/**
 * Randomly splits the sequences read from {@code inputFileName} with weights 0.9 / 0.1,
 * trains a model on the first part and reports accuracy on both parts.
 *
 * @param inputFileName path of the text file handed to {@code jsc.textFile}
 * @param numFeatures   value passed to {@code CMMParams.setNumFeatures}
 * @param modelFileName forwarded to {@code train} together with the training split
 */
void testRandomSplit(String inputFileName, int numFeatures, String modelFileName) {
    CMMParams params = new CMMParams()
        .setMaxIter(600)
        .setRegParam(1E-6)
        .setMarkovOrder(2)
        .setNumFeatures(numFeatures);

    // Collect the input lines into a local list and build the full DataFrame from it.
    JavaRDD<String> rawLines = jsc.textFile(inputFileName);
    DataFrame allSequences = createDataFrame(rawLines.collect());

    // First weight -> training part, second weight -> test part.
    final double[] splitWeights = {0.9, 0.1};
    DataFrame[] parts = allSequences.randomSplit(splitWeights);

    DataFrame trainPart = parts[0];
    System.out.println("Number of training sequences = " + trainPart.count());
    DataFrame testPart = parts[1];
    System.out.println("Number of test sequences = " + testPart.count());

    // Fit on the training part; the resulting model is kept in the cmmModel field.
    cmmModel = train(trainPart, modelFileName, params);

    // Evaluate on unseen data first, then on the data the model was fitted on.
    System.out.println("Test accuracy:");
    evaluate(testPart);
    System.out.println("Training accuracy:");
    evaluate(trainPart);
}
示例2: testRandomForestRegression
import org.apache.spark.sql.DataFrame; //导入方法依赖的package包/类
/**
 * Verifies that a random forest regression model, after being exported with
 * {@code ModelExporter} and re-imported as a {@code Transformer}, produces the same
 * prediction as the Spark model for every held-out row.
 */
@Test
public void testRandomForestRegression() {
    // Read the LIBSVM regression fixture.
    DataFrame dataset = sqlContext.read().format("libsvm").load("src/test/resources/regression_test.libsvm");

    // 70% of the rows for fitting, 30% held out for the comparison.
    final double[] splitWeights = {0.7, 0.3};
    DataFrame[] parts = dataset.randomSplit(splitWeights);
    DataFrame trainSet = parts[0];
    DataFrame heldOut = parts[1];

    // Fit the Spark model.
    RandomForestRegressor regressor = new RandomForestRegressor().setFeaturesCol("features");
    RandomForestRegressionModel sparkModel = regressor.fit(trainSet);

    // Round-trip the fitted model through export / import.
    byte[] serialized = ModelExporter.export(sparkModel, null);
    Transformer imported = ModelImporter.importAndGetTransformer(serialized);

    DataFrame scored = sparkModel.transform(heldOut);
    Row[] rows = scored.select("features", "prediction").collect();

    // Feed each feature vector to the imported transformer and compare with Spark's output.
    for (int i = 0; i < rows.length; i++) {
        Row row = rows[i];
        Vector features = (Vector) row.get(0);
        double sparkPrediction = row.getDouble(1);

        Map<String, Object> record = new HashMap<>();
        String inputKey = imported.getInputKeys().iterator().next();
        record.put(inputKey, features.toArray());
        imported.transform(record);

        double importedPrediction = (Double) record.get(imported.getOutputKeys().iterator().next());
        System.out.println(sparkPrediction + ", " + importedPrediction);
        assertEquals(sparkPrediction, importedPrediction, EPSILON);
    }
}
开发者ID:flipkart-incubator,项目名称:spark-transformers,代码行数:35,代码来源:RandomForestRegressionModelInfoAdapterBridgeTest.java
示例3: testDecisionTreeRegression
import org.apache.spark.sql.DataFrame; //导入方法依赖的package包/类
/**
 * Verifies that a decision tree regression model, after being exported with
 * {@code ModelExporter} and re-imported as a {@code Transformer}, produces the same
 * prediction as the Spark model for every held-out row.
 */
@Test
public void testDecisionTreeRegression() {
    // Read the LIBSVM regression fixture.
    DataFrame dataset = sqlContext.read().format("libsvm").load("src/test/resources/regression_test.libsvm");

    // 70% of the rows for fitting, 30% held out for the comparison.
    final double[] splitWeights = {0.7, 0.3};
    DataFrame[] parts = dataset.randomSplit(splitWeights);
    DataFrame trainSet = parts[0];
    DataFrame heldOut = parts[1];

    // Fit the Spark model.
    DecisionTreeRegressor regressor = new DecisionTreeRegressor().setFeaturesCol("features");
    DecisionTreeRegressionModel sparkModel = regressor.fit(trainSet);

    // Round-trip the fitted model through export / import.
    byte[] serialized = ModelExporter.export(sparkModel, null);
    Transformer imported = ModelImporter.importAndGetTransformer(serialized);

    DataFrame scored = sparkModel.transform(heldOut);
    Row[] rows = scored.select("features", "prediction").collect();

    // Feed each feature vector to the imported transformer and compare with Spark's output.
    for (int i = 0; i < rows.length; i++) {
        Row row = rows[i];
        Vector features = (Vector) row.get(0);
        double sparkPrediction = row.getDouble(1);

        Map<String, Object> record = new HashMap<>();
        String inputKey = imported.getInputKeys().iterator().next();
        record.put(inputKey, features.toArray());
        imported.transform(record);

        double importedPrediction = (Double) record.get(imported.getOutputKeys().iterator().next());
        System.out.println(sparkPrediction + ", " + importedPrediction);
        assertEquals(sparkPrediction, importedPrediction, EPSILON);
    }
}
开发者ID:flipkart-incubator,项目名称:spark-transformers,代码行数:35,代码来源:DecisionTreeRegressionModelBridgeTest.java
示例4: testRandomForestClassification
import org.apache.spark.sql.DataFrame; //导入方法依赖的package包/类
/**
 * Verifies that a random forest classification model, after export and re-import as a
 * {@code Transformer}, reproduces Spark's prediction, probability and raw prediction
 * for every held-out row.
 */
@Test
public void testRandomForestClassification() {
    // Read the LIBSVM classification fixture.
    DataFrame raw = sqlContext.read().format("libsvm").load("src/test/resources/classification_test.libsvm");

    // Index the labels on the full data set, before it is split.
    StringIndexer labelIndexer = new StringIndexer()
        .setInputCol("label")
        .setOutputCol("labelIndex");
    StringIndexerModel labelIndexerModel = labelIndexer.fit(raw);
    DataFrame indexed = labelIndexerModel.transform(raw);

    // 70% of the rows for fitting, 30% held out for the comparison.
    final double[] splitWeights = {0.7, 0.3};
    DataFrame[] parts = indexed.randomSplit(splitWeights);
    DataFrame trainSet = parts[0];
    DataFrame heldOut = parts[1];

    // Configure and fit the Spark classifier.
    RandomForestClassifier classifier = new RandomForestClassifier()
        .setLabelCol("labelIndex")
        .setFeaturesCol("features")
        .setPredictionCol("prediction")
        .setRawPredictionCol("rawPrediction")
        .setProbabilityCol("probability");
    RandomForestClassificationModel sparkModel = classifier.fit(trainSet);

    // Round-trip the fitted model through export / import.
    byte[] serialized = ModelExporter.export(sparkModel, null);
    Transformer imported = ModelImporter.importAndGetTransformer(serialized);

    DataFrame scored = sparkModel.transform(heldOut);
    Row[] rows = scored.select("features", "prediction", "rawPrediction", "probability").collect();

    // Column order in each row: 0 = features, 1 = prediction, 2 = rawPrediction, 3 = probability.
    for (int i = 0; i < rows.length; i++) {
        Row row = rows[i];
        Vector features = (Vector) row.get(0);
        double sparkPrediction = row.getDouble(1);
        double[] sparkProbability = ((Vector) row.get(3)).toArray();
        double[] sparkRaw = ((Vector) row.get(2)).toArray();

        Map<String, Object> record = new HashMap<>();
        String inputKey = imported.getInputKeys().iterator().next();
        record.put(inputKey, features.toArray());
        imported.transform(record);

        double importedPrediction = (Double) record.get("prediction");
        double[] importedProbability = (double[]) record.get("probability");
        double[] importedRaw = (double[]) record.get("rawPrediction");

        assertEquals(sparkPrediction, importedPrediction, EPSILON);
        assertArrayEquals(sparkProbability, importedProbability, EPSILON);
        assertArrayEquals(sparkRaw, importedRaw, EPSILON);
    }
}
开发者ID:flipkart-incubator,项目名称:spark-transformers,代码行数:56,代码来源:RandomForestClassificationModelInfoAdapterBridgeTest.java
示例5: testRandomForestRegressionWithPipeline
import org.apache.spark.sql.DataFrame; //导入方法依赖的package包/类
/**
 * Verifies that a pipeline whose single stage is a fitted random forest regression model,
 * after export and re-import as a {@code Transformer}, produces the same prediction as
 * the Spark pipeline for every held-out row.
 */
@Test
public void testRandomForestRegressionWithPipeline() {
    // Read the LIBSVM regression fixture.
    DataFrame dataset = sqlContext.read().format("libsvm").load("src/test/resources/regression_test.libsvm");

    // 70% of the rows for fitting, 30% held out for the comparison.
    final double[] splitWeights = {0.7, 0.3};
    DataFrame[] parts = dataset.randomSplit(splitWeights);
    DataFrame trainSet = parts[0];
    DataFrame heldOut = parts[1];

    // Fit the forest up front.
    RandomForestRegressionModel fittedForest = new RandomForestRegressor()
        .setFeaturesCol("features")
        .fit(trainSet);

    // The pipeline's only stage is the already-fitted model; there is no indexer stage here.
    PipelineStage[] stages = {fittedForest};
    Pipeline pipeline = new Pipeline().setStages(stages);
    PipelineModel sparkPipeline = pipeline.fit(trainSet);

    // Round-trip the fitted pipeline through export / import.
    byte[] serialized = ModelExporter.export(sparkPipeline, null);
    Transformer imported = ModelImporter.importAndGetTransformer(serialized);

    DataFrame scored = sparkPipeline.transform(heldOut);
    Row[] rows = scored.select("features", "prediction").collect();

    // Feed each feature vector to the imported transformer and compare with Spark's output.
    for (int i = 0; i < rows.length; i++) {
        Row row = rows[i];
        Vector features = (Vector) row.get(0);
        double sparkPrediction = row.getDouble(1);

        Map<String, Object> record = new HashMap<>();
        String inputKey = imported.getInputKeys().iterator().next();
        record.put(inputKey, features.toArray());
        imported.transform(record);

        double importedPrediction = (Double) record.get(imported.getOutputKeys().iterator().next());
        assertEquals(sparkPrediction, importedPrediction, EPSILON);
    }
}
开发者ID:flipkart-incubator,项目名称:spark-transformers,代码行数:42,代码来源:RandomForestRegressionModelInfoAdapterBridgeTest.java
示例6: testDecisionTreeRegressionWithPipeline
import org.apache.spark.sql.DataFrame; //导入方法依赖的package包/类
/**
 * Verifies that a single-stage pipeline around a decision tree regressor, after export
 * and re-import as a {@code Transformer}, produces the same prediction as the Spark
 * pipeline for every held-out row.
 */
@Test
public void testDecisionTreeRegressionWithPipeline() {
    // Read the LIBSVM regression fixture.
    DataFrame dataset = sqlContext.read().format("libsvm").load("src/test/resources/regression_test.libsvm");

    // 70% of the rows for fitting, 30% held out for the comparison.
    final double[] splitWeights = {0.7, 0.3};
    DataFrame[] parts = dataset.randomSplit(splitWeights);
    DataFrame trainSet = parts[0];
    DataFrame heldOut = parts[1];

    // The unfitted regressor is the pipeline's only stage; no indexer is involved.
    DecisionTreeRegressor treeRegressor = new DecisionTreeRegressor().setFeaturesCol("features");
    PipelineStage[] stages = {treeRegressor};
    Pipeline pipeline = new Pipeline().setStages(stages);

    // Fitting the pipeline fits the regressor on the training part.
    PipelineModel sparkPipeline = pipeline.fit(trainSet);

    // Round-trip the fitted pipeline through export / import.
    byte[] serialized = ModelExporter.export(sparkPipeline, null);
    Transformer imported = ModelImporter.importAndGetTransformer(serialized);

    DataFrame scored = sparkPipeline.transform(heldOut);
    Row[] rows = scored.select("features", "prediction").collect();

    // Feed each feature vector to the imported transformer and compare with Spark's output.
    for (int i = 0; i < rows.length; i++) {
        Row row = rows[i];
        Vector features = (Vector) row.get(0);
        double sparkPrediction = row.getDouble(1);

        Map<String, Object> record = new HashMap<>();
        String inputKey = imported.getInputKeys().iterator().next();
        record.put(inputKey, features.toArray());
        imported.transform(record);

        double importedPrediction = (Double) record.get(imported.getOutputKeys().iterator().next());
        assertEquals(sparkPrediction, importedPrediction, EPSILON);
    }
}
开发者ID:flipkart-incubator,项目名称:spark-transformers,代码行数:42,代码来源:DecisionTreeRegressionModelBridgeTest.java
示例7: testRandomForestClassificationWithPipeline
import org.apache.spark.sql.DataFrame; //导入方法依赖的package包/类
/**
 * Verifies that a two-stage pipeline (label indexer followed by a random forest classifier),
 * after export and re-import as a {@code Transformer}, reproduces Spark's prediction,
 * probability and raw prediction for every held-out row.
 */
@Test
public void testRandomForestClassificationWithPipeline() {
    // Read the LIBSVM classification fixture.
    DataFrame dataset = sqlContext.read().format("libsvm").load("src/test/resources/classification_test.libsvm");

    // 70% of the rows for fitting, 30% held out for the comparison.
    final double[] splitWeights = {0.7, 0.3};
    DataFrame[] parts = dataset.randomSplit(splitWeights);
    DataFrame trainSet = parts[0];
    DataFrame heldOut = parts[1];

    // Stage 1: map "label" to "labelIndex".
    StringIndexer labelIndexer = new StringIndexer()
        .setInputCol("label")
        .setOutputCol("labelIndex");

    // Stage 2: random forest classifier reading the indexed label.
    RandomForestClassifier forest = new RandomForestClassifier()
        .setLabelCol("labelIndex")
        .setFeaturesCol("features")
        .setPredictionCol("prediction")
        .setRawPredictionCol("rawPrediction")
        .setProbabilityCol("probability");

    PipelineStage[] stages = {labelIndexer, forest};
    Pipeline pipeline = new Pipeline().setStages(stages);

    // Fitting the pipeline fits the indexer and then the classifier on the training part.
    PipelineModel sparkPipeline = pipeline.fit(trainSet);

    // Round-trip the fitted pipeline through export / import.
    byte[] serialized = ModelExporter.export(sparkPipeline, null);
    Transformer imported = ModelImporter.importAndGetTransformer(serialized);

    DataFrame scored = sparkPipeline.transform(heldOut);
    Row[] rows = scored.select("label", "features", "prediction", "rawPrediction", "probability").collect();

    // Column order: 0 = label, 1 = features, 2 = prediction, 3 = rawPrediction, 4 = probability.
    for (int i = 0; i < rows.length; i++) {
        Row row = rows[i];
        Vector features = (Vector) row.get(1);
        double sparkPrediction = row.getDouble(2);
        double[] sparkProbability = ((Vector) row.get(4)).toArray();
        double[] sparkRaw = ((Vector) row.get(3)).toArray();

        // The imported pipeline is given both the features and the label rendered as a string.
        Map<String, Object> record = new HashMap<>();
        record.put("features", features.toArray());
        record.put("label", row.get(0).toString());
        imported.transform(record);

        double importedPrediction = (Double) record.get("prediction");
        double[] importedProbability = (double[]) record.get("probability");
        double[] importedRaw = (double[]) record.get("rawPrediction");

        assertEquals(sparkPrediction, importedPrediction, EPSILON);
        assertArrayEquals(sparkProbability, importedProbability, EPSILON);
        assertArrayEquals(sparkRaw, importedRaw, EPSILON);
    }
}
开发者ID:flipkart-incubator,项目名称:spark-transformers,代码行数:58,代码来源:RandomForestClassificationModelInfoAdapterBridgeTest.java
示例8: testDecisionTreeClassificationRawPrediction
import org.apache.spark.sql.DataFrame; //导入方法依赖的package包/类
/**
 * Verifies that a decision tree classification model, after export and re-import,
 * reproduces Spark's prediction and raw prediction for every held-out row.
 */
@Test
public void testDecisionTreeClassificationRawPrediction() {
    // Read the LIBSVM classification fixture.
    DataFrame raw = sqlContext.read().format("libsvm").load("src/test/resources/classification_test.libsvm");

    // Index the labels on the full data set, before it is split.
    StringIndexer labelIndexer = new StringIndexer()
        .setInputCol("label")
        .setOutputCol("labelIndex");
    StringIndexerModel labelIndexerModel = labelIndexer.fit(raw);
    DataFrame indexed = labelIndexerModel.transform(raw);

    // 70% of the rows for fitting, 30% held out for the comparison.
    final double[] splitWeights = {0.7, 0.3};
    DataFrame[] parts = indexed.randomSplit(splitWeights);
    DataFrame trainSet = parts[0];
    DataFrame heldOut = parts[1];

    // Configure and fit the Spark classifier.
    DecisionTreeClassifier treeClassifier = new DecisionTreeClassifier()
        .setLabelCol("labelIndex")
        .setFeaturesCol("features")
        .setRawPredictionCol("rawPrediction")
        .setPredictionCol("prediction");
    DecisionTreeClassificationModel sparkModel = treeClassifier.fit(trainSet);

    // Round-trip through export / import; the cast fails fast if the imported
    // transformer is not a DecisionTreeTransformer.
    byte[] serialized = ModelExporter.export(sparkModel, null);
    Transformer imported = (DecisionTreeTransformer) ModelImporter.importAndGetTransformer(serialized);

    DataFrame scored = sparkModel.transform(heldOut);
    Row[] rows = scored.select("features", "prediction", "rawPrediction").collect();

    // Column order in each row: 0 = features, 1 = prediction, 2 = rawPrediction.
    for (int i = 0; i < rows.length; i++) {
        Row row = rows[i];
        Vector features = (Vector) row.get(0);
        double sparkPrediction = row.getDouble(1);
        double[] sparkRaw = ((Vector) row.get(2)).toArray();

        Map<String, Object> record = new HashMap<>();
        String inputKey = imported.getInputKeys().iterator().next();
        record.put(inputKey, features.toArray());
        imported.transform(record);

        double importedPrediction = (Double) record.get(imported.getOutputKeys().iterator().next());
        double[] importedRaw = (double[]) record.get("rawPrediction");

        assertEquals(sparkPrediction, importedPrediction, EPSILON);
        assertArrayEquals(sparkRaw, importedRaw, EPSILON);
    }
}
开发者ID:flipkart-incubator,项目名称:spark-transformers,代码行数:48,代码来源:DecisionTreeClassificationModelBridgeTest.java
示例9: testDecisionTreeClassificationWithPipeline
import org.apache.spark.sql.DataFrame; //导入方法依赖的package包/类
/**
 * Verifies that a two-stage pipeline (label indexer followed by a decision tree classifier),
 * after export and re-import as a {@code Transformer}, produces the same prediction as
 * the Spark pipeline for every held-out row.
 */
@Test
public void testDecisionTreeClassificationWithPipeline() {
    // Read the LIBSVM classification fixture.
    DataFrame dataset = sqlContext.read().format("libsvm").load("src/test/resources/classification_test.libsvm");

    // 70% of the rows for fitting, 30% held out for the comparison.
    final double[] splitWeights = {0.7, 0.3};
    DataFrame[] parts = dataset.randomSplit(splitWeights);
    DataFrame trainSet = parts[0];
    DataFrame heldOut = parts[1];

    // Stage 1: map "label" to "labelIndex".
    StringIndexer labelIndexer = new StringIndexer()
        .setInputCol("label")
        .setOutputCol("labelIndex");

    // Stage 2: decision tree classifier (still unfitted here) reading the indexed label.
    DecisionTreeClassifier treeClassifier = new DecisionTreeClassifier()
        .setLabelCol("labelIndex")
        .setFeaturesCol("features");

    PipelineStage[] stages = {labelIndexer, treeClassifier};
    Pipeline pipeline = new Pipeline().setStages(stages);

    // Fitting the pipeline fits the indexer and then the classifier on the training part.
    PipelineModel sparkPipeline = pipeline.fit(trainSet);

    // Round-trip the fitted pipeline through export / import.
    byte[] serialized = ModelExporter.export(sparkPipeline, null);
    Transformer imported = ModelImporter.importAndGetTransformer(serialized);

    DataFrame scored = sparkPipeline.transform(heldOut);
    Row[] rows = scored.select("label", "features", "prediction").collect();

    // Column order in each row: 0 = label, 1 = features, 2 = prediction.
    for (int i = 0; i < rows.length; i++) {
        Row row = rows[i];
        Vector features = (Vector) row.get(1);
        double sparkPrediction = row.getDouble(2);

        // The imported pipeline is given both the features and the label rendered as a string.
        Map<String, Object> record = new HashMap<>();
        record.put("features", features.toArray());
        record.put("label", row.get(0).toString());
        imported.transform(record);

        double importedPrediction = (Double) record.get("prediction");
        assertEquals(sparkPrediction, importedPrediction, EPSILON);
    }
}
开发者ID:flipkart-incubator,项目名称:spark-transformers,代码行数:48,代码来源:DecisionTreeClassificationModelBridgeTest.java