本文整理汇总了Java中cc.mallet.classify.Trial类的典型用法代码示例。如果您正苦于以下问题：Java Trial类的具体用法？Java Trial怎么用？Java Trial使用的例子？那么恭喜您，这里精选的类代码示例或许可以为您提供帮助。
Trial类属于cc.mallet.classify包，在下文中一共展示了Trial类的11个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Java代码示例。
示例1: AddClassifierTokenPredictions
import cc.mallet.classify.Trial; //导入依赖的package包/类
/**
 * Stores the token classifiers and settings, builds the extended data
 * alphabet, and optionally logs the classifiers' accuracy on a test list.
 *
 * <p>The data alphabet is a clone of the token classifiers' alphabet to which
 * one feature named {@code TOK_PRED=<label>_@_RANK_<rank>} is added for every
 * combination of label and requested prediction rank.
 *
 * @param tokenClassifiers classifiers supplying the data and label alphabets;
 *                         also evaluated on {@code testList} when it is given
 * @param predRanks2add    prediction ranks for which features are added
 * @param binary           stored in {@code m_binary}
 * @param testList         instances used to log accuracy; {@code null} skips
 *                         the evaluation
 */
public AddClassifierTokenPredictions(TokenClassifiers tokenClassifiers, int[] predRanks2add,
        boolean binary, InstanceList testList)
{
    m_predRanks2add = predRanks2add;
    m_binary = binary;
    m_tokenClassifiers = tokenClassifiers;
    m_inProduction = false;
    // Work on a copy so the classifiers' own alphabet is not grown.
    m_dataAlphabet = (Alphabet) tokenClassifiers.getAlphabet().clone();
    Alphabet labelAlphabet = tokenClassifiers.getLabelAlphabet();

    // add the token prediction features to the alphabet
    for (int i = 0; i < m_predRanks2add.length; i++) {
        for (int j = 0; j < labelAlphabet.size(); j++) {
            // "_@_RANK_" separates the label from the rank; the page this
            // snippet was scraped from had mangled it into an e-mail placeholder.
            String featName = "TOK_PRED=" + labelAlphabet.lookupObject(j).toString()
                    + "_@_RANK_" + m_predRanks2add[i];
            m_dataAlphabet.lookupIndex(featName, true);
        }
    }

    // evaluate token classifier
    if (testList != null) {
        Trial trial = new Trial(m_tokenClassifiers, testList);
        logger.info("Token classifier accuracy on test set = " + trial.getAccuracy());
    }
}
示例2: ConfusionMatrix
import cc.mallet.classify.Trial; //导入依赖的package包/类
/**
 * Builds the confusion matrix of a trial by counting, for each of its
 * classifications, the pair (true label index, predicted label index).
 *
 * @param t the trial whose classifications are tallied
 */
public ConfusionMatrix(Trial t)
{
    this.trial = t;
    this.classifications = t;

    // The class count comes from the label alphabet of the first classification.
    Classification first = (Classification) classifications.get(0);
    this.numClasses = first.getLabeling().getLabelAlphabet().size();
    values = new int[numClasses][numClasses];

    for (int n = 0; n < classifications.size(); n++)
    {
        Classification current = (Classification) classifications.get(n);
        int predicted = current.getLabelVector().getBestIndex();
        int actual = current.getInstance().getLabeling().getBestIndex();
        assert (actual != -1);
        // Rows are indexed by the true label, columns by the predicted label.
        values[actual][predicted]++;
    }
}
示例3: AccuracyCoverage
import cc.mallet.classify.Trial; //导入依赖的package包/类
/**
* Constructs object, sorts classifications, and creates
* accuracyValues array
* @param t trial to get data from
* @param numBuckets number of x-axis measurements to find accuracy
*/
public AccuracyCoverage(Trial t, int numBuckets, String title, String dataName)
{
this.classifications = t;
this.numBuckets = numBuckets;
this.step = (double)DEFAULT_MAX_X/numBuckets;
this.accuracyValues = new double[numBuckets];
this.frame = null;
logger.info("Constructing AccCov with " +
this.classifications.size());
sortClassifications();
/* for(int i=0; i<classifications.size(); i++)
{
Classification c = (Classification)this.classifications.get(i);
LabelVector distr = c.getLabelVector();
System.out.println(distr.getBestValue());
}
*/
createAccuracyArray();
this.graph = new Graph2(
title, 0, 100,
"Coverage", "Accuracy");
addDataToGraph(this.accuracyValues, numBuckets, dataName);
}
示例4: ConfusionMatrix
import cc.mallet.classify.Trial; //导入依赖的package包/类
/**
 * Creates the matrix for a trial and fills it with the number of times
 * each (true label, predicted label) pair occurs among its classifications.
 *
 * @param t the trial to build the matrix from
 */
public ConfusionMatrix(Trial t) {
    this.trial = t;
    this.classifications = t;

    // Size the square matrix from the label alphabet seen on the first result.
    Labeling firstLabeling = ((Classification) classifications.get(0)).getLabeling();
    this.numClasses = firstLabeling.getLabelAlphabet().size();
    values = new int[numClasses][numClasses];

    for (int idx = 0; idx < classifications.size(); idx++) {
        Classification result = (Classification) classifications.get(idx);
        LabelVector scores = result.getLabelVector();
        Instance source = result.getInstance();

        int guessedIndex = scores.getBestIndex();
        int trueIndex = source.getLabeling().getBestIndex();
        assert (trueIndex != -1);

        // First index: true label; second index: predicted label.
        values[trueIndex][guessedIndex]++;
    }
}
示例5: doTraining
import cc.mallet.classify.Trial; //导入依赖的package包/类
/**
 * Trains the token classifier on the whole training list and, when
 * cross-validation is enabled, one additional classifier per fold.
 *
 * <p>The classifier trained on all data is stored in {@code m_tokenClassifier}.
 * If {@code m_numCV} is non-zero, each fold classifier is trained on the first
 * list of the split and registered in {@code m_table} under the name of every
 * instance in the second list of that split.
 *
 * @param trainList instances used for training and for the cross-validation splits
 */
private void doTraining(InstanceList trainList)
{
    // train a classifier on the entire training set
    logger.info("Training token classifier on entire data set (size=" + trainList.size() + ")...");
    m_tokenClassifier = m_trainer.train(trainList);

    Trial t = new Trial(m_tokenClassifier, trainList);
    logger.info("Training set accuracy = " + t.getAccuracy());

    if (m_numCV == 0) {
        return;
    }

    // train classifiers using cross validation
    InstanceList.CrossValidationIterator cvIter = trainList.new CrossValidationIterator(m_numCV, m_randSeed);

    // Start at 0 and increment before logging so folds are reported as
    // 1..m_numCV rather than 2..m_numCV+1.
    int f = 0;
    while (cvIter.hasNext()) {
        f++;
        InstanceList[] fold = cvIter.nextSplit();

        logger.info("Training token classifier on cv fold " + f + " / " + m_numCV + " (size=" + fold[0].size() + ")...");
        Classifier foldClassifier = m_trainer.train(fold[0]);

        Trial t1 = new Trial(foldClassifier, fold[0]);
        Trial t2 = new Trial(foldClassifier, fold[1]);
        logger.info("Within-fold accuracy = " + t1.getAccuracy());
        logger.info("Out-of-fold accuracy = " + t2.getAccuracy());

        // Register the fold classifier under the name of each instance in the
        // second list of the split, i.e. instances it was not trained on.
        for (int i = 0; i < fold[1].size(); i++) {
            Instance inst = fold[1].get(i);
            m_table.put(inst.getName(), foldClassifier);
        }
    }
}
示例6: printTrial
import cc.mallet.classify.Trial; //导入依赖的package包/类
/**
 * Prints the trial's accuracy, the F1 score of every class in the
 * classifier's label alphabet, and the unweighted mean of those F1 scores
 * to standard output.
 *
 * @param trial evaluation results to report
 */
public void printTrial(Trial trial) {
    System.out.println("Accuracy(Micro): " + trial.getAccuracy());

    LabelAlphabet labelAlphabet = trial.getClassifier().getLabelAlphabet();
    double macro = 0;
    for (int i = 0; i < labelAlphabet.size(); i++) {
        // Evaluate F1 once per class and reuse it for both the report line
        // and the running sum.
        double f1 = trial.getF1(i);
        System.out.println("F1 for class '" +
                labelAlphabet.lookupLabel(i) + "': " +
                f1);
        macro += f1;
    }
    System.out.println("Macro:" + macro / labelAlphabet.size());
}
示例7: testRandomTrainedOn
import cc.mallet.classify.Trial; //导入依赖的package包/类
/**
 * Adds randomly generated instances to the given list, generates a second
 * random list through the same pipe for testing, trains a MaxEnt classifier
 * on the first list and prints its accuracy on both.
 *
 * @param training list that receives the generated training instances; its
 *                 pipe is reused for the testing list
 * @return accuracy of the trained classifier on the generated testing list
 */
private double testRandomTrainedOn (InstanceList training)
{
    Alphabet vocabulary = dictOfSize (3);
    String[] classNames = {"class0", "class1", "class2"};
    // One seeded generator feeds both iterators, so the training data must be
    // drawn before the testing iterator is created.
    Randoms random = new Randoms (1);

    training.addThruPipe (new RandomTokenSequenceIterator (random, new Dirichlet(vocabulary, 2.0),
            30, 0, 10, 200, classNames));

    InstanceList testing = new InstanceList (training.getPipe ());
    Iterator<Instance> testSource = new RandomTokenSequenceIterator (random, new Dirichlet(vocabulary, 2.0),
            30, 0, 10, 200, classNames);
    testing.addThruPipe (testSource);

    System.out.println ("Training set size = "+training.size());
    System.out.println ("Testing set size = "+testing.size());

    ClassifierTrainer trainer = new MaxEntTrainer();
    Classifier classifier = trainer.train (training);
    String classifierName = classifier.getClass().getName();

    System.out.println ("Accuracy on training set:");
    double trainAcc = new Trial(classifier, training).getAccuracy();
    System.out.println (classifierName + ": " + trainAcc);

    System.out.println ("Accuracy on testing set:");
    double testAcc = new Trial (classifier, testing).getAccuracy();
    System.out.println (classifierName + ": " + testAcc);

    return testAcc;
}
示例8: testTrainSplit
import cc.mallet.classify.Trial; //导入依赖的package包/类
/**
 * Randomly splits the instances in proportions 0.9 / 0.1 / 0.0, trains a
 * MaxEnt classifier on the part at index {@code TRAINING} and evaluates it
 * on the part at index {@code TESTING}.
 *
 * @param instances the instances to split
 * @return the trial of the trained classifier on the {@code TESTING} part
 */
public static Trial testTrainSplit(InstanceList instances) {
    final double[] proportions = { 0.9, 0.1, 0.0 };
    final InstanceList[] parts = instances.split(new Randoms(), proportions);

    @SuppressWarnings("rawtypes")
    ClassifierTrainer maxEnt = new MaxEntTrainer();
    final Classifier model = maxEnt.train(parts[TRAINING]);

    return new Trial(model, parts[TESTING]);
}
示例9: trainClassifier
import cc.mallet.classify.Trial; //导入依赖的package包/类
/**
 * Trains a linear-kernel SVM classifier on the given instances and prints
 * diagnostics measured on those same instances: the first rho value, the
 * confusion matrix, and F1 and precision for the classes "Yes" and "No".
 *
 * @param trainingInstances instances used both for training and for the
 *                          printed diagnostics
 * @return the trained classifier
 */
public Classifier trainClassifier(InstanceList trainingInstances){
    System.out.println("train classifier start...");

    ClassifierTrainer trainer = new SVMClassifierTrainer(new LinearKernel(),false);
    Classifier temp = trainer.train(trainingInstances);

    // Metrics below are computed on the training data itself.
    Trial trial = new Trial(temp, trainingInstances);
    ConfusionMatrix matrix = new ConfusionMatrix(trial);

    double rho = ((SVMClassifier)temp).getRho()[0];
    System.out.println("rho:"+rho);
    System.out.println(matrix.toString());

    // Print each metric under the same label it is computed for, so the
    // output stays correct regardless of the order of the label alphabet.
    for (String label : new String[] {"Yes", "No"}) {
        System.out.println("F1 for class '" + label + "': " + trial.getF1(label));
        System.out.println("Precision for class '" + label + "': " + trial.getPrecision(label));
    }

    System.out.println("train classifier end...");
    return temp;
}
示例10: main
import cc.mallet.classify.Trial; //导入依赖的package包/类
/**
 * Trains a naive Bayes classifier on the .txt files under the training
 * folder and prints accuracy, F1, precision and recall measured on the
 * files under the test folder.
 *
 * @param args command-line arguments; not used
 */
public static void main(String[] args){
    final File stopwordFile = new File("data/stoplists/en.txt");
    final File trainDir = new File("data/ex6DataEmails/train");
    final File testDir = new File("data/ex6DataEmails/test");

    // Text-to-feature-vector pipeline: tokenize, lowercase, drop stopwords,
    // then map tokens to features and targets to labels.
    ArrayList<Pipe> steps = new ArrayList<Pipe>();
    steps.add(new Input2CharSequence("UTF-8"));
    steps.add(new CharSequence2TokenSequence(Pattern.compile("[\\p{L}\\p{N}_]+")));
    steps.add(new TokenSequenceLowercase());
    steps.add(new TokenSequenceRemoveStopwords(stopwordFile, "utf-8", false, false, false));
    steps.add(new TokenSequence2FeatureSequence());
    steps.add(new FeatureSequence2FeatureVector());
    steps.add(new Target2Label());
    SerialPipes pipeline = new SerialPipes(steps);

    // Training data; FileIterator.LAST_DIRECTORY is passed as the target option.
    FileIterator trainFiles = new FileIterator(
            new File[] {trainDir},
            new TxtFilter(),
            FileIterator.LAST_DIRECTORY);
    InstanceList trainSet = new InstanceList(pipeline);
    trainSet.addThruPipe(trainFiles);

    ClassifierTrainer nbTrainer = new NaiveBayesTrainer();
    Classifier classifier = nbTrainer.train(trainSet);

    // Test data goes through the pipe the classifier reports as its own.
    InstanceList testSet = new InstanceList(classifier.getInstancePipe());
    FileIterator testFiles = new FileIterator(
            new File[] {testDir},
            new TxtFilter(),
            FileIterator.LAST_DIRECTORY);
    testSet.addThruPipe(testFiles);

    Trial trial = new Trial(classifier, testSet);
    final int labelIndex = 1;

    System.out.println("Accuracy: " + trial.getAccuracy());
    System.out.println("F1 for class 'spam': " + trial.getF1("spam"));
    System.out.println("Precision for class '"
            + classifier.getLabelAlphabet().lookupLabel(labelIndex) + "': "
            + trial.getPrecision(labelIndex));
    System.out.println("Recall for class '"
            + classifier.getLabelAlphabet().lookupLabel(labelIndex) + "': "
            + trial.getRecall(labelIndex));
}
开发者ID:PacktPublishing,项目名称:Machine-Learning-End-to-Endguide-for-Java-developers,代码行数:55,代码来源:SpamDetector.java
示例11: add
import cc.mallet.classify.Trial; //导入依赖的package包/类
/**
 * Feeds every classification of the given trial into the ROC data by
 * delegating to the single-classification {@code add} overload.
 *
 * @param trial trial whose classifications are added, in iteration order
 */
public void add(Trial trial) {
    final java.util.Iterator<Classification> results = trial.iterator();
    while (results.hasNext()) {
        add(results.next());
    }
}