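This article collects typical usage examples of the Java class org.apache.mahout.classifier.evaluation.Auc. If you want to know what the Auc class is for, how to use it, or what real code that uses it looks like, the hand-picked examples below may help.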
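Auc belongs to the org.apache.mahout.classifier.evaluation package. It accumulates (true label, score) pairs from a binary classifier and reports the area under the ROC curve through auc(), along with a confusion matrix through confusion() and an entropy matrix through entropy(). Two code examples of the class are shown below, ordered by popularity by default.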
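To make the examples easier to follow, here is a minimal pure-Java sketch of what an AUC collector computes. It is not Mahout's implementation (Mahout's Auc has its own internals, for example for large inputs), and the class name SimpleAuc and its methods are hypothetical. AUC is computed as the fraction of (positive, negative) pairs in which the positive example receives the higher score, with ties counted as one half; the confusion matrix in this sketch is indexed [actual][predicted], with "predicted positive" meaning score > threshold.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;

/** Hypothetical AUC collector sketch; not Mahout's Auc class. */
public class SimpleAuc {
    private final double threshold;
    private final List<Double> positives = new ArrayList<>();
    private final List<Double> negatives = new ArrayList<>();
    // confusion[actual][predicted], where predicted = 1 when score > threshold
    private final long[][] confusion = new long[2][2];

    public SimpleAuc(double threshold) {
        this.threshold = threshold;
    }

    public SimpleAuc() {
        this(0.5);
    }

    /** Records one example; target must be 0 or 1. */
    public void add(int target, double score) {
        if (target != 0 && target != 1) {
            throw new IllegalArgumentException("target must be 0 or 1");
        }
        int predicted = score > threshold ? 1 : 0;
        confusion[target][predicted]++;
        (target == 1 ? positives : negatives).add(score);
    }

    /** Fraction of (positive, negative) pairs ranked correctly; ties count 1/2. O(P * N). */
    public double auc() {
        if (positives.isEmpty() || negatives.isEmpty()) {
            throw new IllegalStateException("need at least one example of each class");
        }
        double wins = 0.0;
        for (double p : positives) {
            for (double n : negatives) {
                if (p > n) {
                    wins += 1.0;
                } else if (p == n) {
                    wins += 0.5;
                }
            }
        }
        return wins / ((double) positives.size() * negatives.size());
    }

    /** Returns a copy of the 2x2 confusion counts. */
    public long[][] confusion() {
        long[][] copy = new long[2][];
        for (int i = 0; i < 2; i++) {
            copy[i] = confusion[i].clone();
        }
        return copy;
    }

    public static void main(String[] args) {
        SimpleAuc auc = new SimpleAuc();
        auc.add(1, 0.9);
        auc.add(1, 0.7);
        auc.add(1, 0.4);
        auc.add(0, 0.6);
        auc.add(0, 0.2);
        System.out.printf(Locale.ENGLISH, "AUC = %.2f%n", auc.auc()); // prints "AUC = 0.83"
        long[][] m = auc.confusion();
        // prints "confusion: [[1, 1], [1, 2]]"
        System.out.printf(Locale.ENGLISH, "confusion: [[%d, %d], [%d, %d]]%n", m[0][0], m[0][1], m[1][0], m[1][1]);
    }
}
```

Five of the six (positive, negative) pairs are ordered correctly, hence AUC = 5/6. The pairwise loop is quadratic; for large test sets a sort-based rank computation is the usual alternative.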
Example 1: main
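import org.apache.mahout.classifier.evaluation.Auc; // import the required package/class
// Other imports this snippet relies on:
import java.io.BufferedReader;
import java.io.File;
import java.io.OutputStreamWriter;
import java.io.PrintWriter;
import java.util.Locale;
import com.google.common.base.Charsets;
import org.apache.mahout.classifier.sgd.CsvRecordFactory;
import org.apache.mahout.classifier.sgd.LogisticModelParameters;
import org.apache.mahout.classifier.sgd.OnlineLogisticRegression;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.SequentialAccessSparseVector;
import org.apache.mahout.math.Vector;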
// modelFile, inputFile, showAuc and showConfusion are static fields of the enclosing class
// (not shown); open(...) is a helper in the same class that returns a BufferedReader.
public static void main(String[] args) throws Exception {
    showAuc = true;
    showConfusion = true;
    Auc collector = new Auc();  // collects (target, score) pairs; default threshold 0.5
    LogisticModelParameters lmp = LogisticModelParameters.loadFrom(new File(modelFile));
    CsvRecordFactory csv = lmp.getCsvRecordFactory();
    OnlineLogisticRegression lr = lmp.createRegression();
    PrintWriter output = new PrintWriter(new OutputStreamWriter(System.out, Charsets.UTF_8), true);
    try (BufferedReader in = OnlineLogisticRegressionTest.open(inputFile)) {
        String line = in.readLine();
        csv.firstLine(line);  // the first line is the CSV header
        line = in.readLine();
        output.println("\"target\",\"model-output\",\"log-likelihood\"");
        while (line != null) {
            Vector v = new SequentialAccessSparseVector(lmp.getNumFeatures());
            int target = csv.processLine(line, v);  // encodes the features into v, returns the label
            // classifyScalar returns P(target = 1); the confusion matrix below thresholds this value
            double score = lr.classifyScalar(v);
            output.printf(Locale.ENGLISH, "%d,%.3f,%.6f%n", target, score, lr.logLikelihood(target, v));
            collector.add(target, score);
            line = in.readLine();
        }
    }
    if (showAuc) {
        output.printf(Locale.ENGLISH, "AUC = %.2f%n", collector.auc());
    }
    if (showConfusion) {
        Matrix m = collector.confusion();
        output.printf(Locale.ENGLISH, "confusion: [[%.1f, %.1f], [%.1f, %.1f]]%n",
            m.get(0, 0), m.get(1, 0), m.get(0, 1), m.get(1, 1));
        m = collector.entropy();
        output.printf(Locale.ENGLISH, "entropy: [[%.1f, %.1f], [%.1f, %.1f]]%n",
            m.get(0, 0), m.get(1, 0), m.get(0, 1), m.get(1, 1));
    }
}
Developer ID: PacktPublishing, Project: Java-Data-Science-Cookbook, Lines of code: 32, Source: OnlineLogisticRegressionTest.java
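Mahout's OnlineLogisticRegression offers both classifyScalar, which returns the probability of class 1, and classifyScalarNoLink, which returns the raw linear score before the logistic link. Because AUC depends only on how scores are ranked and the logistic function is strictly increasing, AUC comes out the same either way; a fixed-threshold confusion matrix (Mahout's Auc thresholds at 0.5 by default) does not. The self-contained sketch below uses plain Java with made-up scores; LinkDemo, sigmoid, auc and countPredictedPositive are hypothetical helpers, not Mahout API.

```java
import java.util.Arrays;
import java.util.Locale;

/** Hypothetical demo: a monotone link leaves AUC unchanged but shifts a fixed threshold. */
public class LinkDemo {
    /** Logistic link: maps a raw linear score to a probability. */
    static double sigmoid(double z) {
        return 1.0 / (1.0 + Math.exp(-z));
    }

    /** Pairwise AUC: fraction of (positive, negative) pairs ranked correctly, ties count 1/2. */
    static double auc(int[] targets, double[] scores) {
        double wins = 0.0;
        long pairs = 0;
        for (int i = 0; i < targets.length; i++) {
            if (targets[i] != 1) {
                continue;
            }
            for (int j = 0; j < targets.length; j++) {
                if (targets[j] != 0) {
                    continue;
                }
                pairs++;
                if (scores[i] > scores[j]) {
                    wins += 1.0;
                } else if (scores[i] == scores[j]) {
                    wins += 0.5;
                }
            }
        }
        if (pairs == 0) {
            throw new IllegalArgumentException("need at least one example of each class");
        }
        return wins / pairs;
    }

    /** Number of examples predicted positive, i.e. with score > threshold. */
    static int countPredictedPositive(double[] scores, double threshold) {
        int count = 0;
        for (double s : scores) {
            if (s > threshold) {
                count++;
            }
        }
        return count;
    }

    public static void main(String[] args) {
        int[] targets = {1, 1, 0, 1, 0, 0};
        double[] raw = {2.0, 0.3, 0.1, -0.2, -1.0, -2.5};  // raw linear scores (logits)
        double[] prob = Arrays.stream(raw).map(LinkDemo::sigmoid).toArray();

        // prints "AUC raw = 0.8889, AUC prob = 0.8889"
        System.out.printf(Locale.ENGLISH, "AUC raw = %.4f, AUC prob = %.4f%n",
                auc(targets, raw), auc(targets, prob));
        // prints "predicted positive at 0.5: raw = 1, prob = 3"
        System.out.printf(Locale.ENGLISH, "predicted positive at 0.5: raw = %d, prob = %d%n",
                countPredictedPositive(raw, 0.5), countPredictedPositive(prob, 0.5));
    }
}
```

A threshold of 0.5 on probabilities corresponds to a threshold of 0 on the raw score, so applying 0.5 to raw scores only calls an example positive when P(target = 1) > sigmoid(0.5), roughly 0.62.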
Example 2: main
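import org.apache.mahout.classifier.evaluation.Auc; // import the required package/class
// Other imports this snippet relies on:
import java.util.Collections;
import java.util.List;
import com.google.common.collect.Lists;
import org.apache.mahout.classifier.sgd.L1;
import org.apache.mahout.classifier.sgd.OnlineLogisticRegression;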
// NUM_CATEGORIES is a constant of the enclosing class; TelephoneCall and
// TelephoneCallParser are classes from the same project (not shown).
public static void main(String[] args) throws Exception {
    List<TelephoneCall> calls = Lists.newArrayList(new TelephoneCallParser("bank-full.csv"));
    double heldOutPercentage = 0.10;  // fraction of the data held out for testing
    for (int run = 0; run < 20; run++) {
        // New random 90/10 split for every run
        Collections.shuffle(calls);
        int cutoff = (int) (heldOutPercentage * calls.size());
        List<TelephoneCall> test = calls.subList(0, cutoff);
        List<TelephoneCall> train = calls.subList(cutoff, calls.size());
        // Fresh model per run with an L1 prior; lambda sets the regularization strength,
        // while alpha, stepOffset and decayExponent shape the learning-rate schedule
        OnlineLogisticRegression lr = new OnlineLogisticRegression(NUM_CATEGORIES, TelephoneCall.FEATURES, new L1())
            .learningRate(1)
            .alpha(1)
            .lambda(0.000001)
            .stepOffset(10000)
            .decayExponent(0.2);
        for (int pass = 0; pass < 20; pass++) {
            for (TelephoneCall observation : train) {
                lr.train(observation.getTarget(), observation.asVector());
            }
            // Evaluate on the held-out set after every fifth pass
            if (pass % 5 == 0) {
                Auc eval = new Auc(0.5);
                for (TelephoneCall testCall : test) {
                    eval.add(testCall.getTarget(), lr.classifyScalar(testCall.asVector()));
                }
                System.out.printf("%d, %.4f, %.4f%n", pass, lr.currentLearningRate(), eval.auc());
            }
        }
    }
}
Developer ID: frankscholten, Project: mahout-sgd-bank-marketing, Lines of code: 34, Source: TelephoneCallClassificationMain.java
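Example 2 estimates performance with repeated random hold-out: each run shuffles the data, keeps the first 10% as the test set, trains on the rest, and measures AUC on the test set every fifth pass. The split step can be sketched in plain Java as follows. HoldoutSplit and split are hypothetical names, and a seeded Random is assumed here so that runs are reproducible (Example 2 calls Collections.shuffle without a seed). Since List.subList returns views backed by the original list, this sketch copies them so that a later shuffle cannot change an existing split.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

/** Hypothetical sketch of the shuffle-and-cut hold-out split used in Example 2. */
public class HoldoutSplit {
    /** Holds independent copies of the test and training portions. */
    public static final class Split<T> {
        public final List<T> test;
        public final List<T> train;

        Split(List<T> test, List<T> train) {
            this.test = test;
            this.train = train;
        }
    }

    /** Shuffles a copy of data and puts the first heldOutFraction of it into the test set. */
    public static <T> Split<T> split(List<T> data, double heldOutFraction, Random random) {
        List<T> shuffled = new ArrayList<>(data);  // leave the caller's list untouched
        Collections.shuffle(shuffled, random);
        int cutoff = (int) (heldOutFraction * shuffled.size());
        // Copy the subList views so later changes to 'shuffled' cannot affect them
        List<T> test = new ArrayList<>(shuffled.subList(0, cutoff));
        List<T> train = new ArrayList<>(shuffled.subList(cutoff, shuffled.size()));
        return new Split<>(test, train);
    }

    public static void main(String[] args) {
        List<Integer> data = new ArrayList<>();
        for (int i = 0; i < 1000; i++) {
            data.add(i);
        }
        Random random = new Random(42);
        for (int run = 0; run < 3; run++) {
            Split<Integer> s = split(data, 0.10, random);
            // prints "run 0: test=100, train=900", then the same sizes for runs 1 and 2
            System.out.println("run " + run + ": test=" + s.test.size() + ", train=" + s.train.size());
        }
    }
}
```

Example 2 can use the views directly because it rebuilds them at the start of every run. Note also that with 20 passes and the pass % 5 == 0 check, AUC is reported after pass indices 0, 5, 10 and 15, so the model as it stands after the final pass (index 19) is never evaluated.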