本文整理汇总了Java中org.apache.spark.ml.tuning.CrossValidatorModel类的典型用法代码示例。如果您正苦于以下问题：Java CrossValidatorModel类的具体用法是什么？Java CrossValidatorModel怎么用？有哪些Java CrossValidatorModel的使用示例？那么，这里精选的类代码示例或许可以为您提供帮助。
CrossValidatorModel类属于org.apache.spark.ml.tuning包，下文一共展示了CrossValidatorModel类的2个代码示例，这些示例默认按受欢迎程度排序。您可以为喜欢或者觉得有用的代码点赞，您的评价将有助于系统推荐出更好的Java代码示例。
示例1: trainWithCrossValidation
import org.apache.spark.ml.tuning.CrossValidatorModel; //导入依赖的package包/类
/**
 * Tunes {@code pipeline} with 10-fold cross-validation over {@code paramGrid}.
 *
 * <p>Equivalent to {@code trainWithCrossValidation(train, pipeline, paramGrid, 10)}.
 *
 * @param train     training data to cross-validate on
 * @param pipeline  pipeline treated as a single Estimator, so the parameters of all of its
 *                  stages are tuned jointly
 * @param paramGrid candidate parameter maps to evaluate
 * @return the fitted {@link CrossValidatorModel}, exposed as a {@link Transformer}
 */
private static Transformer trainWithCrossValidation(DataFrame train, Pipeline pipeline, ParamMap[] paramGrid) {
    return trainWithCrossValidation(train, pipeline, paramGrid, 10);
}

/**
 * Tunes {@code pipeline} with k-fold cross-validation over {@code paramGrid} and returns the
 * resulting model.
 *
 * <p>A CrossValidator needs an Estimator (here the whole Pipeline), a set of Estimator
 * ParamMaps and an Evaluator. The evaluator is a {@link BinaryClassificationEvaluator}, whose
 * default metric is areaUnderROC. Every ParamMap is fit {@code numFolds} times, so training
 * cost grows with {@code paramGrid.length * numFolds}. After the best ParamMap is chosen, Spark
 * refits the pipeline with it on the whole of {@code train}.
 *
 * <p>NOTE: {@code DataFrame} is the Spark 1.x API; in Spark 2.x Java code the equivalent type is
 * {@code Dataset<Row>}.
 *
 * @param train     training data to cross-validate on
 * @param pipeline  pipeline treated as a single Estimator
 * @param paramGrid candidate parameter maps to evaluate
 * @param numFolds  number of folds; Spark's CrossValidator requires at least 2
 * @return the fitted {@link CrossValidatorModel}, exposed as a {@link Transformer}
 */
private static Transformer trainWithCrossValidation(DataFrame train, Pipeline pipeline, ParamMap[] paramGrid,
        int numFolds) {
    CrossValidator cv = new CrossValidator()
            .setEstimator(pipeline)
            .setEvaluator(new BinaryClassificationEvaluator())
            .setEstimatorParamMaps(paramGrid)
            .setNumFolds(numFolds);
    // Run cross-validation and keep the model trained with the best parameter set.
    CrossValidatorModel model = cv.fit(train);
    return model;
}
示例2: predictRUL
import org.apache.spark.ml.tuning.CrossValidatorModel; //导入依赖的package包/类
/**
 * Stored procedure that repeatedly scores sensor readings with a saved cross-validated model.
 *
 * <p>Each iteration:
 * <ol>
 *   <li>reads {@code sensorTableName} through Spark's JDBC source;</li>
 *   <li>replaces nulls with 0;</li>
 *   <li>runs the model loaded from {@code <savedModelPath>/model/};</li>
 *   <li>appends ENGINE_TYPE, UNIT, TIME and PREDICTION as CSV under
 *       {@code /tmp/data_pred/predictions};</li>
 *   <li>deletes from {@code IOT.TO_PROCESS_SENSOR} every row that already has a matching
 *       (engine_type, unit, time) row in {@code IOT.PREDICTIONS_EXT}.</li>
 * </ol>
 *
 * <p>The loop has no termination condition of its own. It stops only when the thread is
 * interrupted or an exception escapes. A {@link SQLException} is logged and not rethrown.
 *
 * @param sensorTableName  table or view to score; {@code IOT.SENSOR_AGG_1_VIEW} when null/empty
 * @param resultsTableName defaults to {@code IOT.PREDICTION_EXT} when null/empty.
 *                         NOTE(review): this value is not used anywhere; the cleanup SQL hard-codes
 *                         {@code IOT.PREDICTIONS_EXT} (note the extra "S") — confirm intent
 * @param savedModelPath   base directory holding the saved model; {@code /tmp} when null/empty;
 *                         {@code model/} is appended
 * @param loopinterval     pause between polling iterations; values {@code <= 0} skip the loop.
 *                         NOTE(review): passed to {@link Thread#sleep(long)}, i.e. assumed to be
 *                         milliseconds — confirm with callers
 */
public static void predictRUL(String sensorTableName,
        String resultsTableName, String savedModelPath, int loopinterval) {
    sensorTableName = valueOrDefault(sensorTableName, "IOT.SENSOR_AGG_1_VIEW");
    resultsTableName = valueOrDefault(resultsTableName, "IOT.PREDICTION_EXT");
    savedModelPath = valueOrDefault(savedModelPath, "/tmp");
    if (!savedModelPath.endsWith("/")) {
        savedModelPath = savedModelPath + "/";
    }
    savedModelPath += "model/";

    // NOTE(review): credentials are hard-coded in the URL; consider passing them in or using
    // the procedure's default connection.
    String jdbcUrl = "jdbc:splice://localhost:1527/splicedb;user=splice;password=admin;useSpark=true";

    // Options for Spark's JDBC source; load() re-reads the table on every call.
    Map<String, String> options = new HashMap<String, String>();
    options.put("driver", "com.splicemachine.db.jdbc.ClientDriver");
    options.put("url", jdbcUrl);
    options.put("dbtable", sensorTableName);

    // Marks records as processed once a prediction for them is visible in the results table.
    String markProcessedSql = "delete from IOT.TO_PROCESS_SENSOR s where exists (select 1 from IOT.PREDICTIONS_EXT p where p.engine_type = s.engine_type and p.unit= s.unit and p.time=s.time )";

    // try-with-resources closes the statement and connection on every exit path.
    try (Connection conn = DriverManager.getConnection(jdbcUrl);
         PreparedStatement pStmtDel = conn.prepareStatement(markProcessedSql)) {
        SparkSession sparkSession = SpliceSpark.getSession();
        // Load the model once; it is reused for every polling iteration.
        CrossValidatorModel cvModel = CrossValidatorModel.load(savedModelPath);
        while (loopinterval > 0) {
            Dataset<Row> sensords = sparkSession.read().format("jdbc")
                    .options(options).load();
            sensords = sensords.na().fill(0);
            Dataset<Row> predictions = cvModel.transform(sensords)
                    .select("ENGINE_TYPE", "UNIT", "TIME", "prediction")
                    .withColumnRenamed("prediction", "PREDICTION");
            predictions.write().mode(SaveMode.Append)
                    .csv("/tmp/data_pred/predictions");
            pStmtDel.execute();
            // Without this pause the loop would re-read the table and re-write CSV continuously.
            Thread.sleep(loopinterval);
        }
    } catch (SQLException sqle) {
        LOG.error("Exception in predictRUL", sqle);
    } catch (InterruptedException ie) {
        // Restore the interrupt flag so the caller can observe the cancellation.
        Thread.currentThread().interrupt();
        LOG.warn("predictRUL interrupted; stopping prediction loop");
    }
}

/** Returns {@code fallback} when {@code value} is null or empty, otherwise {@code value}. */
private static String valueOrDefault(String value, String fallback) {
    return (value == null || value.isEmpty()) ? fallback : value;
}