本文整理汇总了Java中org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator类的典型用法代码示例。如果您正苦于以下问题：Java MulticlassClassificationEvaluator类的具体用法？Java MulticlassClassificationEvaluator怎么用？Java MulticlassClassificationEvaluator使用的例子？那么，这里精选的类代码示例或许可以为您提供帮助。
MulticlassClassificationEvaluator类属于org.apache.spark.ml.evaluation包，在下文中一共展示了MulticlassClassificationEvaluator类的1个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Java代码示例。
示例1: train
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator; //导入依赖的package包/类
/**
 * Builds a whitespace classifier from the given sentences and writes the
 * fitted pipeline model to a file.
 * <p>
 * Every space or underscore character of a trimmed sentence yields one
 * training sample: its context is produced by {@code extractContext} and its
 * label is 0 for a space and 1 for an underscore. The samples are fed to a
 * tokenizer, hashing TF and logistic regression pipeline. The fitted pipeline
 * is stored in the {@code model} field and then saved with overwrite enabled.
 * An {@code IOException} raised while saving is only printed; the method
 * carries on with the evaluation afterwards. Finally the model is scored on
 * the training data itself, and that score together with the objective
 * history of the logistic regression stage is printed to standard output.
 *
 * @param sentences a list of tokenized sentences.
 * @param pipelineModelFileName location the fitted pipeline model is saved to.
 * @param numFeatures number of features handed to the hashing TF stage.
 */
public void train(List<String> sentences, String pipelineModelFileName, int numFeatures) {
    // Collect one labelled sample per space/underscore occurrence.
    List<WhitespaceContext> samples = new ArrayList<WhitespaceContext>(sentences.size());
    int nextId = 0;
    for (String rawSentence : sentences) {
        String text = rawSentence.trim();
        int length = text.length();
        for (int pos = 0; pos < length; pos++) {
            char ch = text.charAt(pos);
            if (ch != ' ' && ch != '_') {
                continue;
            }
            WhitespaceContext sample = new WhitespaceContext();
            sample.setId(nextId);
            nextId++;
            sample.setContext(extractContext(text, pos));
            if (ch == ' ') {
                sample.setLabel(0d);
            } else {
                sample.setLabel(1d);
            }
            samples.add(sample);
        }
    }

    // Turn the samples into a DataFrame and print some diagnostics about it.
    JavaRDD<WhitespaceContext> sampleRdd = jsc.parallelize(samples);
    DataFrame trainingData = sqlContext.createDataFrame(sampleRdd, WhitespaceContext.class);
    trainingData.show(false);
    long sampleCount = trainingData.count();
    System.out.println("N = " + sampleCount);
    trainingData.groupBy("label").count().show();

    // Assemble the three pipeline stages; the setters configure each stage in place.
    org.apache.spark.ml.feature.Tokenizer tokenizer = new Tokenizer();
    tokenizer.setInputCol("context");
    tokenizer.setOutputCol("words");

    HashingTF hashingTF = new HashingTF();
    hashingTF.setNumFeatures(numFeatures);
    hashingTF.setInputCol(tokenizer.getOutputCol());
    hashingTF.setOutputCol("features");

    LogisticRegression classifier = new LogisticRegression();
    classifier.setMaxIter(100);
    classifier.setRegParam(0.01);

    Pipeline pipeline = new Pipeline();
    PipelineStage[] stages = { tokenizer, hashingTF, classifier };
    pipeline.setStages(stages);

    // Fit, keep the result in the field, and persist it.
    model = pipeline.fit(trainingData);
    try {
        model.write().overwrite().save(pipelineModelFileName);
    } catch (IOException saveFailure) {
        // A failed save does not abort training; the evaluation below still runs.
        saveFailure.printStackTrace();
    }

    // Score the fitted model on the data it was trained on.
    DataFrame scored = model.transform(trainingData);
    scored.show();
    MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator();
    evaluator.setMetricName("precision");
    System.out.println("training accuracy = " + evaluator.evaluate(scored));

    // The logistic regression model is the third stage (index 2) of the pipeline.
    LogisticRegressionModel fittedClassifier = (LogisticRegressionModel) model.stages()[2];
    double[] lossHistory = fittedClassifier.summary().objectiveHistory();
    System.out.println("#(iterations) = " + lossHistory.length);
    for (int i = 0; i < lossHistory.length; i++) {
        System.out.println(lossHistory[i]);
    }
}