

Java Evaluation.crossValidateModel Method Code Examples

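This page collects typical usage examples of the Java method weka.classifiers.Evaluation.crossValidateModel. The method performs a k-fold cross-validation of a classifier on a dataset. For each fold it trains a fresh copy of the classifier on the remaining folds, tests it on the held-out fold, and accumulates the statistics in the Evaluation object. The classifier object you pass in is therefore not trained itself, which is why several examples call buildClassifier separately. If you are looking for how Evaluation.crossValidateModel is used in practice, the selected examples below may help. You can also look at further usage examples of its class, weka.classifiers.Evaluation.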


Fifteen code examples of Evaluation.crossValidateModel are shown below, sorted by popularity by default.
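
All fifteen examples hand the fold bookkeeping to crossValidateModel. As we understand it, the method copies the data, shuffles it with the Random you pass in, stratifies it when the class is nominal, and then cuts it into k contiguous folds, each used once as the test set. The plain-Java sketch below shows only that last cutting step: n instances go into k folds, and the first n % k folds receive one extra instance. We believe this matches the split of Weka's Instances.testCV, but treat that as an assumption; FoldSplit and its methods are hypothetical names.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical helper: contiguous k-fold split of the positions 0..n-1. */
final class FoldSplit {

    /** Positions that form the test set of fold {@code fold} (0-based) out of {@code k}. */
    static List<Integer> testIndices(int n, int k, int fold) {
        if (k < 2 || k > n) throw new IllegalArgumentException("need 2 <= k <= n");
        if (fold < 0 || fold >= k) throw new IllegalArgumentException("fold out of range");
        int base = n / k;   // every fold gets at least this many instances
        int extra = n % k;  // the first `extra` folds get one more
        int size = base + (fold < extra ? 1 : 0);
        int first = fold * base + Math.min(fold, extra);
        List<Integer> test = new ArrayList<>();
        for (int i = first; i < first + size; i++) test.add(i);
        return test;
    }

    /** The training set is every position outside the test fold. */
    static List<Integer> trainIndices(int n, int k, int fold) {
        List<Integer> test = testIndices(n, k, fold);
        List<Integer> train = new ArrayList<>();
        for (int i = 0; i < n; i++) {
            if (!test.contains(i)) train.add(i);
        }
        return train;
    }

    public static void main(String[] args) {
        // 10 instances in 3 folds: test sets [0..3], [4..6], [7..9]
        for (int f = 0; f < 3; f++) {
            System.out.println("fold " + f + ": test=" + testIndices(10, 3, f)
                    + " train size=" + trainIndices(10, 3, f).size());
        }
    }
}
```

Because the data is shuffled before this step, "contiguous" refers to positions in the shuffled order, not in the original file.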

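Example 1: useClassifier

import java.util.Random;
import weka.attributeSelection.CfsSubsetEval;
import weka.attributeSelection.GeneticSearch;
import weka.classifiers.Evaluation; // the package/class this method depends on
import weka.classifiers.meta.AttributeSelectedClassifier;
import weka.classifiers.trees.RandomForest;
import weka.core.Instances;
/**
 * Uses the AttributeSelectedClassifier meta-classifier: CFS subset evaluation with a
 * genetic search selects the attributes, and a RandomForest is trained on the selection.
 */
protected static void useClassifier(Instances data) throws Exception {
    System.out.println("\n1. Meta-classifier");
    AttributeSelectedClassifier classifier = new AttributeSelectedClassifier();
    CfsSubsetEval eval = new CfsSubsetEval();
    GeneticSearch search = new GeneticSearch(); // alternative: new GreedyStepwise()
    RandomForest base = new RandomForest();
    classifier.setClassifier(base);
    System.out.println("Set the classifier : " + base.toString());
    classifier.setEvaluator(eval);
    System.out.println("Set the evaluator : " + eval.toString());
    classifier.setSearch(search); // apply the search method that is reported below
    System.out.println("Set the search : " + search.toString());
    // Attribute selection runs inside each training fold, so the test fold never influences it
    Evaluation evaluation = new Evaluation(data);
    evaluation.crossValidateModel(classifier, data, 10, new Random(1));
    System.out.println(evaluation.toSummaryString());
}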
 
Developer ID: ajaybhat, Project: Essay-Grading-System, Lines of code: 22, Source file: AttributeSelectionRunner.java

Example 2: call

import java.io.IOException;
import java.security.SecureRandom;
import com.fasterxml.jackson.databind.ObjectMapper;
import weka.classifiers.Evaluation; // the package/class this method depends on
import weka.core.Instances;
public RecursiveCallable call() throws Exception {
    Lab lab = model.getVariables().get(0).getLab();

    // Skip models that were already evaluated and registered with the lab
    if (lab.getModels().contains(model) && model.getEvaluation() != null) {
        return null;
    }

    Instances wekaData = new Util().getRandomWekaData(model.getVariables());
    evaluation = new Evaluation(wekaData);

    // Sanity check that the model survives a Jackson round trip
    ObjectMapper om = new ObjectMapper();
    try {
        om.readValue(om.writeValueAsBytes(model), Model.class);
    } catch (IOException e) {
        // a failed round trip is deliberately ignored here
    }

    // SecureRandom is not seeded explicitly, so the folds (and the scores) differ on every call
    evaluation.crossValidateModel(model, wekaData, lab.getCvFolds(), new SecureRandom());
    model.setEvaluation(evaluation);
    lab.addModel(model);

    return new ModelSaverCallable(model);
}
 
Developer ID: williamClanton, Project: jbossBA, Lines of code: 27, Source file: ModelEvaluation.java

Example 3: evaluate

import java.util.Random;
import weka.classifiers.Classifier;
import weka.classifiers.Evaluation; // the package/class this method depends on
import weka.core.Instances;
public static double[] evaluate(Classifier model) throws Exception {
    String[] labelFiles = new String[] { "churn", "appetency", "upselling" };
    // one AUC per problem, plus their average in the last slot
    double[] results = new double[labelFiles.length + 1];

    double overallScore = 0.0;
    for (int i = 0; i < labelFiles.length; i++) {
        // Load data (loadData and preProcessData are helpers from the same project)
        Instances train_data = loadData("data/orange_small_train.data",
                "data/orange_small_train_" + labelFiles[i] + ".labels.txt");
        train_data = preProcessData(train_data);

        // 5-fold cross-validation; the trailing empty array requests no prediction printing
        Evaluation eval = new Evaluation(train_data);
        eval.crossValidateModel(model, train_data, 5, new Random(1), new Object[] {});

        // Save the area under the ROC curve for the positive class "1"
        results[i] = eval.areaUnderROC(train_data.classAttribute().indexOfValue("1"));
        overallScore += results[i];
        System.out.println(labelFiles[i] + "\t-->\t" + results[i]);
    }
    // Average over all three problems
    results[labelFiles.length] = overallScore / labelFiles.length;
    return results;
}
 
Developer ID: PacktPublishing, Project: Machine-Learning-End-to-Endguide-for-Java-developers, Lines of code: 29, Source file: KddCup.java
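
Example 3 scores each problem with areaUnderROC. For intuition, the area under the ROC curve equals the probability that a randomly chosen positive instance receives a higher score than a randomly chosen negative one, with ties counting one half. The O(n²) sketch below computes it that way from raw scores. Weka derives the value from a ROC curve instead, and the two should agree, but this is only an illustration under that assumption; RankAuc is a hypothetical name.

```java
/** Hypothetical helper: AUC as the probability that a random positive outranks a random negative. */
final class RankAuc {

    /**
     * @param scores predicted score for the positive class (higher means "more positive")
     * @param isPos  true if the instance actually belongs to the positive class
     */
    static double auc(double[] scores, boolean[] isPos) {
        if (scores.length != isPos.length) throw new IllegalArgumentException("length mismatch");
        double wins = 0;  // positive/negative pairs ranked correctly (ties count 0.5)
        long pairs = 0;
        for (int i = 0; i < scores.length; i++) {
            if (!isPos[i]) continue;
            for (int j = 0; j < scores.length; j++) {
                if (isPos[j]) continue;
                pairs++;
                if (scores[i] > scores[j]) wins += 1.0;
                else if (scores[i] == scores[j]) wins += 0.5;
            }
        }
        return pairs == 0 ? Double.NaN : wins / pairs;  // undefined without both classes
    }

    public static void main(String[] args) {
        double[] scores = {0.9, 0.8, 0.4, 0.3};
        boolean[] positive = {true, false, true, false};
        System.out.println(auc(scores, positive)); // 0.75 (3 of 4 pairs ranked correctly)
    }
}
```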

Example 4: runSVMRegression

import java.io.BufferedReader;
import java.io.FileReader;
import java.util.Random;
import weka.classifiers.AbstractClassifier;
import weka.classifiers.Classifier;
import weka.classifiers.Evaluation; // the package/class this method depends on
import weka.core.Instances;
import weka.core.WekaPackageManager;
public static void runSVMRegression() throws Exception {
    int numFolds = 10;
    Instances trainData;
    try (BufferedReader br = new BufferedReader(new FileReader("rawData.arff"))) {
        trainData = new Instances(br);
    }
    trainData.setClassIndex(trainData.numAttributes() - 1);

    // Load installed Weka packages (LibSVM is distributed as a package, not in core Weka)
    WekaPackageManager.loadPackages(false, true, false);
    // "weka.classifiers.functions.supportVector" is a package, not a class. The LibSVM
    // wrapper is assumed here because "-S 3" selects epsilon-SVR in LibSVM.
    // AbstractClassifier.forName also resolves classes that come from Weka packages.
    String[] optionsArray = "-S 3 -V 10 -T 0".split(" ");
    Classifier classifier = AbstractClassifier.forName(
            "weka.classifiers.functions.LibSVM", optionsArray);

    // crossValidateModel trains its own copies, so no buildClassifier call is needed here
    Evaluation evaluation = new Evaluation(trainData);
    evaluation.crossValidateModel(classifier, trainData, numFolds, new Random(1));

    evaluateResults(evaluation);
}
 
Developer ID: gizemsogancioglu, Project: biosses, Lines of code: 28, Source file: svmRegressor.java
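
Example 4 passes the finished Evaluation to a project helper, evaluateResults, whose body is not shown. For a numeric class, Weka's Evaluation reports measures such as meanAbsoluteError() and rootMeanSquaredError(). The standalone sketch below computes the textbook definitions of those two measures from paired actual and predicted values. RegressionErrors is a hypothetical name, and the sketch ignores instance weights, which Weka takes into account.

```java
/** Hypothetical helper: two error measures commonly reported for numeric-class evaluation. */
final class RegressionErrors {

    static double meanAbsoluteError(double[] actual, double[] predicted) {
        check(actual, predicted);
        double sum = 0;
        for (int i = 0; i < actual.length; i++) sum += Math.abs(actual[i] - predicted[i]);
        return sum / actual.length;
    }

    static double rootMeanSquaredError(double[] actual, double[] predicted) {
        check(actual, predicted);
        double sum = 0;
        for (int i = 0; i < actual.length; i++) {
            double d = actual[i] - predicted[i];
            sum += d * d;  // squaring penalizes large errors more than MAE does
        }
        return Math.sqrt(sum / actual.length);
    }

    private static void check(double[] actual, double[] predicted) {
        if (actual.length == 0 || actual.length != predicted.length)
            throw new IllegalArgumentException("arrays must be non-empty and equally long");
    }

    public static void main(String[] args) {
        double[] y = {3.0, 5.0, 2.0};
        double[] yHat = {2.0, 5.0, 4.0};
        System.out.println(meanAbsoluteError(y, yHat)); // 1.0
        System.out.println(rootMeanSquaredError(y, yHat));
    }
}
```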

Example 5: getErrorPercent

import java.util.Random;
import weka.classifiers.Evaluation; // the package/class this method depends on
@Override
public double getErrorPercent() {
    try {
        Evaluation eval = new Evaluation(getInstances());

        // new Random() is not seeded explicitly, so repeated calls may return different error rates
        eval.crossValidateModel(getClassifier(), getInstances(),
                getFolds(), new Random()
        );

        return eval.pctIncorrect();
    } catch (Exception e) {
        e.printStackTrace();
        return 100;
    }
}
 
Developer ID: garciparedes, Project: java-examples, Lines of code: 16, Source file: CrossValidation.java
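
Example 5 returns pctIncorrect() after cross-validation. To make concrete what that percentage measures, the sketch below runs a k-fold loop in plain Java with a toy majority-class learner (the idea behind Weka's ZeroR) standing in for a real classifier: every instance is predicted exactly once, by a model trained on the other folds. Unlike crossValidateModel, it does not stratify; CvErrorSketch and its methods are hypothetical.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.TreeMap;

/** Hypothetical sketch: k-fold CV error of a toy majority-class learner. */
final class CvErrorSketch {

    /** "Trains" on the given rows: returns their most frequent label (ties go to the smallest label). */
    static int majorityLabel(int[] labels, List<Integer> trainRows) {
        Map<Integer, Integer> counts = new TreeMap<>();  // sorted keys give deterministic ties
        for (int row : trainRows) counts.merge(labels[row], 1, Integer::sum);
        int best = -1;
        int bestCount = -1;
        for (Map.Entry<Integer, Integer> e : counts.entrySet()) {
            if (e.getValue() > bestCount) {
                best = e.getKey();
                bestCount = e.getValue();
            }
        }
        return best;
    }

    /** Percentage misclassified when each instance is predicted by a model trained without it. */
    static double pctIncorrect(int[] labels, int k, Random random) {
        int n = labels.length;
        if (k < 2 || k > n) throw new IllegalArgumentException("need 2 <= k <= n");
        List<Integer> order = new ArrayList<>();
        for (int i = 0; i < n; i++) order.add(i);
        Collections.shuffle(order, random);  // shuffle once, then cut into contiguous folds
        int incorrect = 0;
        for (int fold = 0; fold < k; fold++) {
            int first = fold * (n / k) + Math.min(fold, n % k);
            int size = n / k + (fold < n % k ? 1 : 0);
            List<Integer> test = order.subList(first, first + size);
            List<Integer> train = new ArrayList<>(order.subList(0, first));
            train.addAll(order.subList(first + size, n));
            int prediction = majorityLabel(labels, train);  // a fresh "model" per fold
            for (int row : test) {
                if (labels[row] != prediction) incorrect++;
            }
        }
        return 100.0 * incorrect / n;  // every instance was tested exactly once
    }

    public static void main(String[] args) {
        System.out.println(pctIncorrect(new int[] {0, 0, 1, 1}, 4, new Random(1))); // 100.0
    }
}
```

The main method also shows a known pitfall: under leave-one-out on a perfectly balanced dataset, removing the test instance always makes its own class the minority, so a majority-class learner is wrong every time.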

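Example 6: trainClassifier

import java.io.File;
import java.io.FileOutputStream;
import java.io.ObjectOutputStream;
import java.util.Random;
import weka.classifiers.Evaluation; // the package/class this method depends on
import weka.classifiers.bayes.NaiveBayes;
import weka.classifiers.lazy.IBk;
import weka.core.Instances;
import weka.core.converters.CSVLoader;
// Note: the Classifier parameter type is a project-specific enum (KNN, NB), not
// weka.classifiers.Classifier; the Weka model itself is stored in this.classifier.
public void trainClassifier(Classifier classifier, File trainingDataset,
                            FileOutputStream trainingModel,
                            Integer crossValidationFoldNumber) throws Exception {

    CSVLoader csvLoader = new CSVLoader();
    csvLoader.setSource(trainingDataset);

    Instances instances = csvLoader.getDataSet();

    switch (classifier) {
        case KNN:
            // rule of thumb: k = ceil(sqrt(number of instances))
            int k = (int) Math.ceil(Math.sqrt(instances.numInstances()));
            this.classifier = new IBk(k);
            break;
        case NB:
            this.classifier = new NaiveBayes();
            break;
        default:
            throw new IllegalArgumentException("Unsupported classifier: " + classifier);
    }

    if (instances.classIndex() == -1) {
        instances.setClassIndex(instances.numAttributes() - 1);
    }

    this.classifier.buildClassifier(instances);

    if (crossValidationFoldNumber != null && crossValidationFoldNumber > 0) {
        Evaluation evaluation = new Evaluation(instances);
        evaluation.crossValidateModel(this.classifier, instances, crossValidationFoldNumber,
                new Random(1));
        kappa = evaluation.kappa();
        fMeasure = evaluation.weightedFMeasure();
        confusionMatrix = evaluation.toMatrixString("Confusion matrix: ");
    }

    // Serialize the model trained on the full dataset (closing the stream also closes trainingModel)
    try (ObjectOutputStream outputStream = new ObjectOutputStream(trainingModel)) {
        outputStream.writeObject(this.classifier);
    }
}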
 
Developer ID: FlorentinTh, Project: SpeakerAuthentication, Lines of code: 39, Source file: Learning.java
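
Example 6 stores evaluation.kappa(). Cohen's kappa compares the observed accuracy with the accuracy expected by chance from the row and column totals of the confusion matrix: kappa = (observed - chance) / (1 - chance). The sketch below computes it from a matrix laid out like the double[][] returned by Evaluation.confusionMatrix() (rows = actual class, columns = predicted class). Returning 1 when chance agreement is already perfect is an assumed convention, and Kappa is a hypothetical name.

```java
/** Hypothetical helper: Cohen's kappa from a confusion matrix (rows = actual, columns = predicted). */
final class Kappa {

    static double kappa(double[][] m) {
        int k = m.length;
        double total = 0;
        double diagonal = 0;
        double[] rowSum = new double[k];
        double[] colSum = new double[k];
        for (int i = 0; i < k; i++) {
            if (m[i].length != k) throw new IllegalArgumentException("matrix must be square");
            for (int j = 0; j < k; j++) {
                total += m[i][j];
                rowSum[i] += m[i][j];
                colSum[j] += m[i][j];
                if (i == j) diagonal += m[i][j];
            }
        }
        if (total <= 0) throw new IllegalArgumentException("empty matrix");
        double observed = diagonal / total;  // plain accuracy
        double chance = 0;                   // agreement expected from the marginals alone
        for (int i = 0; i < k; i++) chance += (rowSum[i] / total) * (colSum[i] / total);
        // Assumed convention: report 1 when chance agreement is already perfect
        return chance < 1 ? (observed - chance) / (1 - chance) : 1.0;
    }

    public static void main(String[] args) {
        double[][] m = {{20, 5}, {10, 15}};  // 35 of 50 correct, chance agreement 0.5
        System.out.println(kappa(m));         // approximately 0.4
    }
}
```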

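Example 7: main

import java.util.Random;
import weka.classifiers.Classifier;
import weka.classifiers.Evaluation; // the package/class this method depends on
import weka.classifiers.bayes.NaiveBayes;
import weka.core.Instances;
import weka.core.converters.ConverterUtils.DataSource;
/**
 * Runs 10-fold cross-validation of NaiveBayes 30 times, once per random seed.
 *
 * @param args the command line arguments
 * @throws java.lang.Exception
 */
public static void main(String[] args) throws Exception {
    DataSource source = new DataSource("src/files/letter.arff");

    int folds = 10;
    int runs = 30;

    Classifier cls = new NaiveBayes();
    Instances data = source.getDataSet();
    data.setClassIndex(16); // 0-based: the 17th attribute is the class

    System.out.println("#seed \t correct instances \t percentage correct\n");
    for (int i = 1; i <= runs; i++) {
        // a different seed per run gives a different fold assignment
        Evaluation eval = new Evaluation(data);
        eval.crossValidateModel(cls, data, folds, new Random(i));

        System.out.println("#" + i + "\t" + summary(eval)); // summary() is a helper from the same project
    }
}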
 
Developer ID: Unisep, Project: weka-algorithms, Lines of code: 23, Source file: WekaAlgorithms.java
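
Example 7 prints one line per seed but does not aggregate the runs. A common follow-up is to collect a score such as eval.pctCorrect() from each run into an array and report its mean and standard deviation; the helper below sketches that in plain Java. RunSummary is a hypothetical name. Because all runs reuse the same data, their spread mostly reflects how the fold assignment varies with the seed, not an independent-sample error bar.

```java
/** Hypothetical helper: mean and sample standard deviation of per-run CV scores. */
final class RunSummary {

    static double mean(double[] xs) {
        if (xs.length == 0) throw new IllegalArgumentException("no runs");
        double sum = 0;
        for (double x : xs) sum += x;
        return sum / xs.length;
    }

    /** Sample standard deviation (divides by n - 1). */
    static double sampleStdDev(double[] xs) {
        if (xs.length < 2) throw new IllegalArgumentException("need at least two runs");
        double m = mean(xs);
        double ss = 0;
        for (double x : xs) ss += (x - m) * (x - m);
        return Math.sqrt(ss / (xs.length - 1));
    }

    public static void main(String[] args) {
        double[] toyScores = {1.0, 2.0, 3.0, 4.0};  // placeholder values, not real results
        System.out.println(mean(toyScores));          // 2.5
        System.out.println(sampleStdDev(toyScores));
    }
}
```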

Example 8: trainLibSvm

import java.util.Random;
import weka.classifiers.Evaluation; // the package/class this method depends on
import weka.classifiers.functions.LibSVM;
import weka.core.Instances;
public static void trainLibSvm(final Instances trainingSet) throws Exception {
    // Create a classifier
    final LibSVM svm = new LibSVM();
    svm.buildClassifier(trainingSet);

    // Estimate performance with 10-fold cross-validation (trains separate copies of svm)
    final Evaluation eval = new Evaluation(trainingSet);
    eval.crossValidateModel(svm, trainingSet, 10, new Random(1));

    // Print the result à la Weka Explorer
    logger.info(eval.toSummaryString());
}
 
Developer ID: cobr123, Project: VirtaMarketAnalyzer, Lines of code: 15, Source file: RetailSalePrediction.java

Example 9: trainJ48BySet

import java.util.Random;
import weka.classifiers.Evaluation; // the package/class this method depends on
import weka.classifiers.trees.J48;
import weka.core.Instances;
public static void trainJ48BySet(final Instances trainingSet) throws Exception {
    // Create a classifier
    final J48 tree = new J48();
    tree.setMinNumObj(1);
    tree.setReducedErrorPruning(true);
    tree.buildClassifier(trainingSet);

    // Test the model with 10-fold cross-validation
    final Evaluation eval = new Evaluation(trainingSet);
    eval.crossValidateModel(tree, trainingSet, 10, new Random(1));

    // Print the result à la Weka Explorer
    logger.info(eval.toSummaryString());
}
 
Developer ID: cobr123, Project: VirtaMarketAnalyzer, Lines of code: 28, Source file: RetailSalePrediction.java

Example 10: trainJ48CrossValidation

import java.io.File;
import java.util.Random;
import weka.classifiers.Evaluation; // the package/class this method depends on
import weka.classifiers.trees.J48;
import weka.core.Instances;
public static void trainJ48CrossValidation(final Instances trainingSet) throws Exception {
    // Create a classifier
    final J48 tree = new J48();
    tree.setMinNumObj(1);
    tree.setReducedErrorPruning(true);

    // Evaluate J48 with 10-fold cross-validation. Arguments: classifier, data, number of
    // folds, random generator (time-based seed, so runs are not reproducible)
    final Evaluation eval = new Evaluation(trainingSet);
    eval.crossValidateModel(tree, trainingSet, 10, new Random(System.currentTimeMillis()));
    logger.info(eval.toSummaryString());
    // GitHubPublisher and RetailSalePrediction are classes from the same project
    Utils.writeFile(GitHubPublisher.localPath + RetailSalePrediction.predict_retail_sales
            + File.separator + "prediction_cv_summary.txt", eval.toSummaryString());

    // Train the final model on all of the data
    tree.buildClassifier(trainingSet);
}
 
Developer ID: cobr123, Project: VirtaMarketAnalyzer, Lines of code: 30, Source file: RetailSalePrediction.java

Example 11: ensembleVote

import java.util.Random;
import weka.classifiers.Classifier;
import weka.classifiers.Evaluation; // the package/class this method depends on
import weka.classifiers.meta.Vote;
import weka.core.Instances;
import weka.core.SelectedTag;
public static double ensembleVote(Instances train, Classifier[] newCfsArray) {

    double correctRate = 0;

    try {
        // Majority-voting ensemble over the supplied base classifiers
        Vote ensemble = new Vote();
        SelectedTag tag = new SelectedTag(Vote.MAJORITY_VOTING_RULE, Vote.TAGS_RULES);
        ensemble.setCombinationRule(tag);
        ensemble.setClassifiers(newCfsArray);
        ensemble.setSeed(2);
        ensemble.buildClassifier(train);

        // 5-fold cross-validation with a fixed seed
        Evaluation eval = new Evaluation(train);
        Random random = new Random(1000);
        eval.crossValidateModel(ensemble, train, 5, random);

        // for a nominal class, errorRate() is the fraction of misclassified instances
        correctRate = 1 - eval.errorRate();
    } catch (Exception e) {
        e.printStackTrace();
    }

    return correctRate;
}
 
Developer ID: guojiasheng, Project: LibD3C-1.1, Lines of code: 28, Source file: D3CVoter.java
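
Example 11 combines its members with Vote.MAJORITY_VOTING_RULE: for each instance every base classifier predicts a class, and the ensemble outputs the class predicted most often. The sketch below applies that rule to precomputed predictions. Ties go to the lowest class index here purely for determinism; Weka's own tie-breaking may differ, so treat MajorityVote as a hypothetical illustration.

```java
/** Hypothetical helper: majority vote over the class indices predicted by ensemble members. */
final class MajorityVote {

    static int vote(int[] predictions, int numClasses) {
        if (predictions.length == 0 || numClasses < 1)
            throw new IllegalArgumentException("nothing to vote on");
        int[] counts = new int[numClasses];
        for (int p : predictions) {
            if (p < 0 || p >= numClasses) throw new IllegalArgumentException("bad class index: " + p);
            counts[p]++;
        }
        int best = 0;
        for (int c = 1; c < numClasses; c++) {
            if (counts[c] > counts[best]) best = c;  // strict '>' sends ties to the lower index
        }
        return best;
    }

    public static void main(String[] args) {
        // three members predict classes 2, 1 and 2 for one instance
        System.out.println(vote(new int[] {2, 1, 2}, 3)); // 2
    }
}
```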

Example 12: main

import java.io.File;
import java.util.Random;
import weka.classifiers.Evaluation; // the package/class this method depends on
import weka.classifiers.trees.J48;
static public void main(String[] args) {
	File folder = new File("../../lab1/reuters/");
	String[] files = folder.list((dir, name) -> name.matches(".*\\.arff$"));
	if (files == null) { // the folder does not exist or cannot be read
		System.err.println("Cannot list " + folder);
		return;
	}

	try {
		for (String f : files) {
			System.out.println(f);
			RNBtree r = new RNBtree("../../lab1/reuters/" + f); // RNBtree is the project's own classifier
			System.out.printf("numAttributes: %d\nnumClasses: %d\n", r.numAttribute, r.numClass);
			r.buildClassifier(r.data);
			// the same seed gives identical folds for all three models, so they are compared fairly
			Evaluation eval = new Evaluation(r.data);
			eval.crossValidateModel(r, r.data, 10, new Random(1));
			System.out.println(eval.toClassDetailsString());

			System.out.println("Unpruned tree");
			Evaluation e2 = new Evaluation(r.data);
			J48 cls = new J48();
			cls.setUnpruned(true); // equivalent to the flag "-U", which takes no value
			e2.crossValidateModel(cls, r.data, 10, new Random(1));
			System.out.println(e2.toClassDetailsString());

			System.out.println("Pruned tree");
			Evaluation e3 = new Evaluation(r.data);
			J48 cls2 = new J48();
			cls2.setConfidenceFactor(0.25f); // equivalent to "-C 0.25"
			e3.crossValidateModel(cls2, r.data, 10, new Random(1));
			System.out.println(e3.toClassDetailsString());
		}
	} catch (Exception e) {
		e.printStackTrace();
	}
}
 
Developer ID: thekingofkings, Project: RNBL-MN, Lines of code: 37, Source file: RNBtree.java
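
Example 12 prints toClassDetailsString(), which lists per-class figures such as precision, recall and F-measure, and Examples 6 and 14 read weightedFMeasure(). The sketch below computes the textbook definitions from a confusion matrix laid out like the one returned by Evaluation.confusionMatrix() (rows = actual class, columns = predicted class); the weighted F-measure averages the per-class values by how many instances actually belong to each class. Returning 0 for an empty row or column is an assumed convention, and ClassDetails is a hypothetical name.

```java
/** Hypothetical helper: per-class precision/recall/F-measure from a confusion matrix (rows = actual). */
final class ClassDetails {

    static double precision(double[][] m, int c) {
        double predicted = 0;  // column sum: instances predicted as c
        for (double[] row : m) predicted += row[c];
        return predicted == 0 ? 0 : m[c][c] / predicted;
    }

    static double recall(double[][] m, int c) {
        double actual = 0;  // row sum: instances that really are c
        for (double v : m[c]) actual += v;
        return actual == 0 ? 0 : m[c][c] / actual;
    }

    static double fMeasure(double[][] m, int c) {
        double p = precision(m, c);
        double r = recall(m, c);
        return p + r == 0 ? 0 : 2 * p * r / (p + r);
    }

    /** F-measure of each class, weighted by the number of instances that actually belong to it. */
    static double weightedFMeasure(double[][] m) {
        double total = 0;
        double weighted = 0;
        for (int c = 0; c < m.length; c++) {
            double actual = 0;
            for (double v : m[c]) actual += v;
            total += actual;
            weighted += actual * fMeasure(m, c);
        }
        return total == 0 ? 0 : weighted / total;
    }

    public static void main(String[] args) {
        double[][] m = {{8, 2}, {4, 6}};
        for (int c = 0; c < m.length; c++) {
            System.out.printf("class %d: P=%.3f R=%.3f F=%.3f%n",
                    c, precision(m, c), recall(m, c), fMeasure(m, c));
        }
        System.out.printf("weighted F=%.3f%n", weightedFMeasure(m));
    }
}
```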

Example 13: crossValidate

import java.util.Random;
import weka.classifiers.Evaluation; // the package/class this method depends on
/**
 * Performs a cross-validation evaluation of the classifier with the given number of folds.
 * @param foldNum number of folds
 * @throws Exception
 */
public void crossValidate(int foldNum) throws Exception
{
	System.out.println("WekaWrapper: "+foldNum+"-fold cross validation over train data.");
	System.err.println("WekaWrapper: "+foldNum+"-fold cross validation over train data.");
	Evaluation eTest = new Evaluation(traindata);
	eTest.crossValidateModel(this.MLclass, traindata, foldNum, new Random(1)); // seed = 1
	/* java.util.Random is deterministic for a given seed, so the same seed applied to the
	 * same sequence of instances always produces the same folds. To get different folds
	 * on each run, use a varying seed, for example:
	 */
	//eTest.crossValidateModel(this.MLclass, traindata, foldNum, new Random((int)(Math.random()*traindata.numInstances())));
	printClassifierResults(eTest);
}
 
Developer ID: Elhuyar, Project: Elixa, Lines of code: 19, Source file: WekaWrapper.java
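
The seed behaviour discussed in Example 13 is easy to check directly: java.util.Random is a deterministic generator, so the same seed always yields the same sequence of numbers and therefore the same shuffle of the same instance order. The sketch below demonstrates this with Collections.shuffle as a stand-in for the shuffling that crossValidateModel performs; SeedDemo is a hypothetical name.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

/** Hypothetical demo: a fixed seed reproduces the same shuffle. */
final class SeedDemo {

    /** Positions 0..n-1 shuffled with a Random created from {@code seed}. */
    static List<Integer> shuffled(int n, long seed) {
        List<Integer> positions = new ArrayList<>();
        for (int i = 0; i < n; i++) positions.add(i);
        Collections.shuffle(positions, new Random(seed));
        return positions;
    }

    public static void main(String[] args) {
        System.out.println(shuffled(10, 1L).equals(shuffled(10, 1L))); // true
        System.out.println(shuffled(10, 1L));
        System.out.println(shuffled(10, 2L)); // another seed, almost certainly another order
    }
}
```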

Example 14: getClassifierFScore

import java.util.Random;
import weka.classifiers.Evaluation; // the package/class this method depends on
import weka.classifiers.trees.RandomForest;
import weka.core.Instances;
import weka.core.converters.ConverterUtils.DataSource;
import weka.filters.Filter;
import weka.filters.supervised.instance.Resample; // assumed; an unsupervised Resample also exists
public static double getClassifierFScore(int numTopics, String categoryName) throws Exception
{
	int seed  = 1;
	int folds = 10;
	DataSource trainSource = new DataSource("inputFiles/rawFiles/ARFF-files/" + categoryName + "-ARFF/"
					+ categoryName + "-" + numTopics + ".ARFF");
	Instances trainingSet = trainSource.getDataSet();
	if (trainingSet.classIndex() == -1)
		trainingSet.setClassIndex(trainingSet.numAttributes() - 1);

	// Resample, intended to help the minority class. Caution: sampling with replacement
	// before cross-validation can put copies of one instance in both a training fold and
	// the test fold, which makes the estimate optimistic.
	Resample reSample = new Resample();
	reSample.setInputFormat(trainingSet);
	trainingSet = Filter.useFilter(trainingSet, reSample);

	// crossValidateModel shuffles and (for a nominal class) stratifies its own copy of
	// the data, so these calls only reorder trainingSet
	Random rand = new Random(seed);
	trainingSet.randomize(rand);
	if (trainingSet.classAttribute().isNominal())
		trainingSet.stratify(folds);

	RandomForest classifier = new RandomForest();

	// Perform cross-validation (the classifier is trained inside, once per fold)
	Evaluation eval = new Evaluation(trainingSet);
	eval.crossValidateModel(classifier, trainingSet, folds, new Random(seed), new Object[] { });
	return eval.weightedFMeasure();
}
 
Developer ID: siddBanPsu, Project: WikiKreator, Lines of code: 37, Source file: TestClassifierPerformance.java
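
Example 14 calls randomize and stratify before cross-validating. Stratification means each fold keeps roughly the class proportions of the whole dataset, which matters most for skewed classes like the ones this example resamples. The sketch below shows one common way to build stratified folds: group instances by class, shuffle each group, and deal them round-robin into the folds. As we understand it, Weka's Instances.stratify reaches the same goal by reordering the instances so that a later contiguous split is stratified; this is not its implementation, and StratifiedFolds is a hypothetical name.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.TreeMap;

/** Hypothetical helper: stratified assignment of instances to k folds. */
final class StratifiedFolds {

    /** Returns fold[i] in 0..k-1 for instance i: each class is shuffled, then dealt round-robin. */
    static int[] assign(int[] labels, int k, Random random) {
        if (k < 2) throw new IllegalArgumentException("need k >= 2");
        Map<Integer, List<Integer>> byClass = new TreeMap<>();
        for (int i = 0; i < labels.length; i++) {
            byClass.computeIfAbsent(labels[i], c -> new ArrayList<>()).add(i);
        }
        int[] fold = new int[labels.length];
        int next = 0;  // continue where the previous class stopped so fold sizes stay balanced
        for (List<Integer> members : byClass.values()) {
            Collections.shuffle(members, random);
            for (int row : members) {
                fold[row] = next;
                next = (next + 1) % k;
            }
        }
        return fold;
    }

    /** Counts how many instances of class {@code label} landed in each fold. */
    static int[] classCountsPerFold(int[] labels, int[] fold, int k, int label) {
        int[] counts = new int[k];
        for (int i = 0; i < labels.length; i++) {
            if (labels[i] == label) counts[fold[i]]++;
        }
        return counts;
    }

    public static void main(String[] args) {
        int[] labels = {0, 0, 0, 0, 0, 0, 1, 1, 1};  // 2:1 class ratio
        int[] fold = assign(labels, 3, new Random(1));
        System.out.println(Arrays.toString(classCountsPerFold(labels, fold, 3, 0))); // [2, 2, 2]
        System.out.println(Arrays.toString(classCountsPerFold(labels, fold, 3, 1))); // [1, 1, 1]
    }
}
```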

Example 15: debugEvaluateClassifier

import java.util.Random;
import weka.classifiers.Evaluation; // the package/class this method depends on
import weka.core.Instances;
/**
 * Evaluates a classifier using 5-fold cross-validation and returns the
 * evaluation object. Use this method for debugging purposes, to get
 * information about precision, recall, etc.
 */
public Evaluation debugEvaluateClassifier() throws Exception {
	Instances data = wekaDataSetCreator.getDataSet();
	Evaluation eval = new Evaluation(data);
	eval.crossValidateModel(wekaClassifier, data, 5, new Random(1));
	return eval;
}
 
Developer ID: vimaier, Project: conqat, Lines of code: 12, Source file: BaseWekaClassifier.java


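Note: The weka.classifiers.Evaluation.crossValidateModel examples in this article were compiled by 纯净天空 from open-source code and documentation platforms such as GitHub and MSDocs. The snippets are taken from open-source projects contributed by their respective developers; copyright of the source code remains with the original authors, and distribution and use are subject to each project's License. Do not reproduce without permission.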