当前位置: 首页>>代码示例>>Java>>正文


Java Trial类代码示例

本文整理汇总了Java中cc.mallet.classify.Trial的典型用法代码示例。如果您正苦于以下问题：Java Trial类的具体用法是什么？Java Trial怎么用？有哪些Java Trial的使用例子？那么恭喜您，这里精选的类代码示例或许可以为您提供帮助。


Trial类属于cc.mallet.classify包,在下文中一共展示了Trial类的11个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Java代码示例。

示例1: AddClassifierTokenPredictions

import cc.mallet.classify.Trial; //导入依赖的package包/类
/**
 * Initializes this instance from a set of token classifiers and registers the
 * token-prediction feature names in a private copy of the classifiers' data alphabet.
 *
 * <p>One feature named {@code TOK_PRED=<label>_@_RANK_<rank>} is added for every
 * combination of requested rank and label in the classifiers' label alphabet.
 *
 * @param tokenClassifiers the token classifiers whose alphabets are used
 * @param predRanks2add    the prediction ranks for which features are registered
 * @param binary           stored as-is in {@code m_binary}
 * @param testList         optional instances on which the token classifiers' accuracy is
 *                         logged; skipped when {@code null}
 */
public AddClassifierTokenPredictions(TokenClassifiers tokenClassifiers, int[] predRanks2add, 
		boolean binary, InstanceList testList)
{
	m_predRanks2add = predRanks2add;
	m_binary = binary;
	m_tokenClassifiers = tokenClassifiers;
	m_inProduction = false;
	// Work on a clone so the classifiers' own alphabet is not grown by the new features.
	m_dataAlphabet = (Alphabet) tokenClassifiers.getAlphabet().clone();
	Alphabet labelAlphabet = tokenClassifiers.getLabelAlphabet();
	
	// add the token prediction features to the alphabet
	for (int i = 0; i < m_predRanks2add.length; i++) {
		for (int j = 0; j < labelAlphabet.size(); j++) {
			// NOTE(review): the "_@_RANK_" separator must match the feature names produced
			// wherever the predictions are actually attached to tokens - confirm.
			String featName = "TOK_PRED=" + labelAlphabet.lookupObject(j).toString() + "_@_RANK_" + m_predRanks2add[i];
			m_dataAlphabet.lookupIndex(featName, true);
		}
	}
	
	// evaluate token classifier  
	if (testList != null) {
		Trial trial = new Trial(m_tokenClassifiers, testList);
		logger.info("Token classifier accuracy on test set = " + trial.getAccuracy());
	}
}
 
开发者ID:kostagiolasn,项目名称:NucleosomePatternClassifier,代码行数:25,代码来源:AddClassifierTokenPredictions.java

示例2: ConfusionMatrix

import cc.mallet.classify.Trial; //导入依赖的package包/类
/**
 * Builds the confusion matrix of a trial: {@code values[actual][predicted]} counts
 * the classifications whose instance carries label index {@code actual} and whose
 * best predicted label index is {@code predicted}.
 *
 * @param t the trial to build matrix from; its first classification determines
 *          the number of classes
 */
public ConfusionMatrix(Trial t)
{
	this.trial = t;
	this.classifications = t;

	// The label alphabet of the first classification fixes the matrix dimension.
	Classification first = (Classification) classifications.get(0);
	this.numClasses = first.getLabeling().getLabelAlphabet().size();
	values = new int[numClasses][numClasses];

	for (int idx = 0; idx < classifications.size(); idx++)
	{
		Classification current = (Classification) classifications.get(idx);
		LabelVector predictedDistribution = current.getLabelVector();
		Instance instance = current.getInstance();
		int predicted = predictedDistribution.getBestIndex();
		int actual = instance.getLabeling().getBestIndex();
		assert actual != -1;
		values[actual][predicted]++;
	}
}
 
开发者ID:kostagiolasn,项目名称:NucleosomePatternClassifier,代码行数:25,代码来源:ConfusionMatrix.java

示例3: AccuracyCoverage

import cc.mallet.classify.Trial; //导入依赖的package包/类
/**
	 * Creates the accuracy/coverage data for a trial: stores the trial's
	 * classifications, sorts them, fills the accuracy array and builds the graph
	 * holding the resulting data series.
	 *
	 * @param t          trial to get data from
	 * @param numBuckets number of x-axis measurements to find accuracy
	 * @param title      title passed to the graph
	 * @param dataName   name under which the accuracy values are added to the graph
	 */
	public AccuracyCoverage(Trial t, int numBuckets, String title, String dataName)
	{
		this.frame = null;
		this.classifications = t;
		this.numBuckets = numBuckets;
		// Width of one bucket on the x axis.
		this.step = (double) DEFAULT_MAX_X / numBuckets;
		this.accuracyValues = new double[numBuckets];

		logger.info("Constructing AccCov with " + this.classifications.size());

		// The accuracy array is computed from the sorted classifications.
		sortClassifications();
		createAccuracyArray();

		this.graph = new Graph2(title, 0, 100, "Coverage", "Accuracy");
		addDataToGraph(this.accuracyValues, numBuckets, dataName);
	}
 
开发者ID:kostagiolasn,项目名称:NucleosomePatternClassifier,代码行数:30,代码来源:AccuracyCoverage.java

示例4: ConfusionMatrix

import cc.mallet.classify.Trial; //导入依赖的package包/类
/**
 * Tallies, for every classification in the trial, the pair
 * (true label index, best predicted label index) into a square count matrix
 * indexed as {@code values[trueIndex][predictedIndex]}.
 *
 * @param t the trial to build matrix from; the label alphabet of its first
 *          classification gives the matrix size
 */
public ConfusionMatrix(Trial t) {
	this.trial = t;
	this.classifications = t;

	Classification head = (Classification) classifications.get(0);
	Labeling headLabeling = head.getLabeling();
	this.numClasses = headLabeling.getLabelAlphabet().size();
	values = new int[numClasses][numClasses];

	int position = 0;
	while (position < classifications.size()) {
		Classification entry = (Classification) classifications.get(position);
		LabelVector scores = entry.getLabelVector();
		Instance source = entry.getInstance();
		int predictedIndex = scores.getBestIndex();
		int trueIndex = source.getLabeling().getBestIndex();
		assert trueIndex != -1;

		values[trueIndex][predictedIndex]++;
		position++;
	}
}
 
开发者ID:iamxiatian,项目名称:wikit,代码行数:25,代码来源:ConfusionMatrix.java

示例5: doTraining

import cc.mallet.classify.Trial; //导入依赖的package包/类
/**
 * Trains the token classifier on the full training list and, when cross validation
 * is enabled ({@code m_numCV != 0}), additionally trains one classifier per fold.
 *
 * <p>For each fold the classifier is trained on the first part of the split and
 * evaluated on both parts; every instance of the second (held-out) part is then
 * mapped by name to that fold's classifier in {@code m_table}.
 *
 * @param trainList the instances to train on
 */
private void doTraining(InstanceList trainList)
{
	// train a classifier on the entire training set
	logger.info("Training token classifier on entire data set (size=" + trainList.size() + ")...");
	m_tokenClassifier = m_trainer.train(trainList);

	Trial t = new Trial(m_tokenClassifier, trainList);
	logger.info("Training set accuracy = " + t.getAccuracy());
	
	if (m_numCV == 0)
		return;

	// train classifiers using cross validation
	InstanceList.CrossValidationIterator cvIter = trainList.new CrossValidationIterator(m_numCV, m_randSeed);
	// Start at 0 so that the pre-increment below reports folds as 1..m_numCV
	// (starting at 1 logged "2 / N" for the first fold).
	int foldNumber = 0;

	while (cvIter.hasNext()) {
		foldNumber++;
		InstanceList[] fold = cvIter.nextSplit();

		logger.info("Training token classifier on cv fold " + foldNumber + " / " + m_numCV + " (size=" + fold[0].size() + ")...");
		
		Classifier foldClassifier = m_trainer.train(fold[0]);
		Trial t1 = new Trial(foldClassifier, fold[0]);
		Trial t2 = new Trial(foldClassifier, fold[1]);

		logger.info("Within-fold accuracy = " + t1.getAccuracy());
		logger.info("Out-of-fold accuracy = " + t2.getAccuracy());

		// Remember, for each held-out instance, the classifier trained on the other part
		// of this split.
		for (int i = 0; i < fold[1].size(); i++) {
			Instance inst = fold[1].get(i);
			m_table.put(inst.getName(), foldClassifier);
		}
	}
}
 
开发者ID:kostagiolasn,项目名称:NucleosomePatternClassifier,代码行数:40,代码来源:AddClassifierTokenPredictions.java

示例6: printTrial

import cc.mallet.classify.Trial; //导入依赖的package包/类
/**
 * Prints a summary of a trial to standard output: the overall accuracy, the F1
 * score of every class in the classifier's label alphabet, and the unweighted
 * mean of those F1 scores ("Macro").
 *
 * @param trial the trial to summarize
 */
public void printTrial(Trial trial) {
    System.out.println("Accuracy(Micro): " + trial.getAccuracy());
    LabelAlphabet labelAlphabet = trial.getClassifier().getLabelAlphabet();
    int numLabels = labelAlphabet.size();
    double f1Sum = 0;
    for (int i = 0; i < numLabels; i++) {
        // Compute each class's F1 once and reuse it for both the report line and the sum.
        double f1 = trial.getF1(i);
        System.out.println("F1 for class '" +
                labelAlphabet.lookupLabel(i) + "': " +
                f1);
        f1Sum += f1;
    }
    // With an empty label alphabet this prints NaN (0.0 / 0).
    System.out.println("Macro:" + f1Sum / numLabels);
}
 
开发者ID:iamxiatian,项目名称:wikit,代码行数:14,代码来源:EspmClassify.java

示例7: testRandomTrainedOn

import cc.mallet.classify.Trial; //导入依赖的package包/类
/**
 * Fills the given list with randomly generated token sequences, trains a
 * maximum-entropy classifier on it, and reports the accuracy on the training data
 * and on a second randomly generated list that shares the same pipe.
 *
 * @param training the (pipe-equipped) list that receives the training instances
 * @return the accuracy measured on the generated testing list
 */
private double testRandomTrainedOn (InstanceList training)
{
  ClassifierTrainer maxEnt = new MaxEntTrainer();

  Alphabet vocabulary = dictOfSize (3);
  String[] labels = {"class0", "class1", "class2"};

  // One shared random source: the training data is drawn first, the testing data second.
  Randoms rng = new Randoms (1);
  Iterator<Instance> trainingSource = new RandomTokenSequenceIterator (
        rng, new Dirichlet(vocabulary, 2.0), 30, 0, 10, 200, labels);
  training.addThruPipe (trainingSource);

  InstanceList testing = new InstanceList (training.getPipe ());
  testing.addThruPipe (new RandomTokenSequenceIterator (
        rng, new Dirichlet(vocabulary, 2.0), 30, 0, 10, 200, labels));

  System.out.println ("Training set size = "+training.size());
  System.out.println ("Testing set size = "+testing.size());

  Classifier classifier = maxEnt.train (training);

  System.out.println ("Accuracy on training set:");
  double trainAcc = new Trial (classifier, training).getAccuracy();
  System.out.println (classifier.getClass().getName() + ": " + trainAcc);

  System.out.println ("Accuracy on testing set:");
  double testAcc = new Trial (classifier, testing).getAccuracy();
  System.out.println (classifier.getClass().getName() + ": " + testAcc);

  return testAcc;
}
 
开发者ID:mimno,项目名称:Mallet,代码行数:33,代码来源:TestPagedInstanceList.java

示例8: testTrainSplit

import cc.mallet.classify.Trial; //导入依赖的package包/类
/**
 * Randomly splits the instances with proportions 0.9 / 0.1 / 0.0, trains a
 * maximum-entropy classifier on the {@code TRAINING} part and evaluates it on
 * the {@code TESTING} part.
 *
 * @param instances the instances to split, train and test on
 * @return the trial of the trained classifier on the testing part
 */
public static Trial testTrainSplit(InstanceList instances) {
    double[] proportions = { 0.9, 0.1, 0.0 };
    InstanceList[] parts = instances.split(new Randoms(), proportions);

    @SuppressWarnings("rawtypes")
    ClassifierTrainer maxEnt = new MaxEntTrainer();
    Classifier trained = maxEnt.train(parts[TRAINING]);

    Trial evaluation = new Trial(trained, parts[TESTING]);
    return evaluation;
}
 
开发者ID:BlueBrain,项目名称:bluima,代码行数:14,代码来源:ReferencesClassifierTrainer.java

示例9: trainClassifier

import cc.mallet.classify.Trial; //导入依赖的package包/类
/**
 * Trains a linear-kernel SVM classifier on the given instances and prints a
 * report of its performance on those same instances: the first rho value,
 * the confusion matrix, and the F1 and precision of the classes "Yes" and "No".
 *
 * @param trainingInstances the instances to train (and self-evaluate) on
 * @return the trained classifier
 */
public Classifier trainClassifier(InstanceList trainingInstances){
	
	System.out.println("train classifier start...");
	ClassifierTrainer trainer = new SVMClassifierTrainer(new LinearKernel(),false);
	
	Classifier temp = trainer.train(trainingInstances);
	// Evaluated on the training data itself, so the figures below are training-set scores.
	Trial trial = new Trial(temp, trainingInstances);
	ConfusionMatrix matrix = new ConfusionMatrix(trial);
	double rho = ((SVMClassifier)temp).getRho()[0];
	System.out.println("rho:"+rho);
	System.out.println(matrix.toString());

	// Report each class under the same name that is used to look up its scores, so the
	// printed label cannot drift from the value when the label alphabet order differs.
	for (String label : new String[] {"Yes", "No"}) {
		System.out.println("F1 for class '" + label + "': " + trial.getF1(label));

		System.out.println("Precision for class '" +
		                   label + "': " +
		                   trial.getPrecision(label));
	}
	System.out.println("train classifer end...");
	return temp;
}
 
开发者ID:guanxin0520,项目名称:dhnowFilter,代码行数:30,代码来源:EvaluateClassifier.java

示例10: main

import cc.mallet.classify.Trial; //导入依赖的package包/类
/**
 * Trains a naive Bayes classifier on the text files below the training folder
 * and prints its accuracy, the F1 for label "spam", and the precision and recall
 * of the label at index 1 on the files below the test folder.
 *
 * <p>Both folders are read with the last directory name of each file as its target.
 *
 * @param args unused
 */
public static void main(String[] args){

	final String stopListFilePath = "data/stoplists/en.txt";
	final String dataFolderPath = "data/ex6DataEmails/train";
	final String testFolderPath = "data/ex6DataEmails/test";

	// Text -> tokens -> lowercase -> stop-word removal -> feature vector, target -> label.
	ArrayList<Pipe> steps = new ArrayList<Pipe>();
	steps.add(new Input2CharSequence("UTF-8"));
	Pattern wordPattern = Pattern.compile("[\\p{L}\\p{N}_]+");
	steps.add(new CharSequence2TokenSequence(wordPattern));
	steps.add(new TokenSequenceLowercase());
	File stopListFile = new File(stopListFilePath);
	steps.add(new TokenSequenceRemoveStopwords(stopListFile, "utf-8", false, false, false));
	steps.add(new TokenSequence2FeatureSequence());
	steps.add(new FeatureSequence2FeatureVector());
	steps.add(new Target2Label());
	SerialPipes pipeline = new SerialPipes(steps);

	// Training data.
	File[] trainingRoots = {new File(dataFolderPath)};
	FileIterator trainingFiles =
			new FileIterator(trainingRoots, new TxtFilter(), FileIterator.LAST_DIRECTORY);
	InstanceList instances = new InstanceList(pipeline);
	instances.addThruPipe(trainingFiles);

	ClassifierTrainer classifierTrainer = new NaiveBayesTrainer();
	Classifier classifier = classifierTrainer.train(instances);

	// Test data goes through the pipe the classifier was trained with.
	InstanceList testInstances = new InstanceList(classifier.getInstancePipe());
	File[] testRoots = {new File(testFolderPath)};
	FileIterator testFiles =
			new FileIterator(testRoots, new TxtFilter(), FileIterator.LAST_DIRECTORY);
	testInstances.addThruPipe(testFiles);

	Trial trial = new Trial(classifier, testInstances);

	System.out.println("Accuracy: " + trial.getAccuracy());
	System.out.println("F1 for class 'spam': " + trial.getF1("spam"));

	String precisionLine = "Precision for class '"
			+ classifier.getLabelAlphabet().lookupLabel(1) + "': "
			+ trial.getPrecision(1);
	System.out.println(precisionLine);

	String recallLine = "Recall for class '"
			+ classifier.getLabelAlphabet().lookupLabel(1) + "': "
			+ trial.getRecall(1);
	System.out.println(recallLine);
}
 
开发者ID:PacktPublishing,项目名称:Machine-Learning-End-to-Endguide-for-Java-developers,代码行数:55,代码来源:SpamDetector.java

示例11: add

import cc.mallet.classify.Trial; //导入依赖的package包/类
/**
 * Feeds every classification contained in a trial into the ROC data, one at a
 * time and in the trial's iteration order.
 *
 * @param trial Trial results to add to ROC data
 */
public void add(Trial trial) {
    for (final Classification result : trial) {
        this.add(result);
    }
}
 
开发者ID:kostagiolasn,项目名称:NucleosomePatternClassifier,代码行数:11,代码来源:ROCData.java


注:本文中的cc.mallet.classify.Trial类示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。