当前位置: 首页>>代码示例>>Java>>正文


Java RandomForest.buildClassifier方法代码示例

本文整理汇总了Java中weka.classifiers.trees.RandomForest.buildClassifier方法的典型用法代码示例。如果您想了解Java中RandomForest.buildClassifier方法的具体用法、调用方式或实际使用示例，这里精选的代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在类weka.classifiers.trees.RandomForest的用法示例。


在下文中一共展示了RandomForest.buildClassifier方法的5个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Java代码示例。

示例1: trainRandomForest

import weka.classifiers.trees.RandomForest; //导入方法依赖的package包/类
/**
 * Builds a default {@link RandomForest} on {@code trainingSet}, evaluates it, and logs the
 * evaluation summary, the confusion matrix and the textual model, in that order.
 *
 * <p>NOTE(review): the model is evaluated on the very instances it was trained on, so the logged
 * figures are training-set (resubstitution) metrics and will be optimistic. A cross-validated
 * variant was previously left commented out here.
 *
 * @param trainingSet data used both to train and to evaluate the forest (presumably with its
 *     class index already set — TODO confirm at call sites)
 * @throws Exception if Weka fails to build or evaluate the classifier
 */
public static void trainRandomForest(final Instances trainingSet) throws Exception {
    final RandomForest forest = new RandomForest();
    forest.buildClassifier(trainingSet);

    final Evaluation evaluation = new Evaluation(trainingSet);
    evaluation.evaluateModel(forest, trainingSet);

    // Same report layout as the Weka Explorer: summary, confusion matrix, then the model itself.
    logger.info(evaluation.toSummaryString());
    logger.info(evaluation.toMatrixString());
    logger.info(forest.toString());
}
 
开发者ID:cobr123,项目名称:VirtaMarketAnalyzer,代码行数:16,代码来源:RetailSalePrediction.java

示例2: train

import weka.classifiers.trees.RandomForest; //导入方法依赖的package包/类
/**
 * Trains a {@link RandomForest} configured with this trainer's {@code numTrees} setting.
 *
 * @param instances the training data handed to Weka
 * @return the freshly built forest
 * @throws ClassifierBuildingException wrapping any exception Weka raises while building
 */
@Override
public RandomForest train(Instances instances) {
    final RandomForest forest = new RandomForest();
    forest.setNumTrees(numTrees); // numTrees is a field of the enclosing class (not visible here)
    try {
        forest.buildClassifier(instances);
        return forest;
    } catch (Exception cause) {
        // Convert Weka's checked Exception into the project's exception, keeping the cause.
        throw new ClassifierBuildingException(
                "Exception occured while building classifier: " + cause.getMessage(), cause);
    }
}
 
开发者ID:NLeSC,项目名称:eEcology-Classification,代码行数:12,代码来源:RandomForestTrainer.java

示例3: trainModel

import weka.classifiers.trees.RandomForest; //导入方法依赖的package包/类
/**
 * Trains a random forest on an ARFF file and saves the serialized model to the filesystem.
 *
 * <p>The data is read from {@code <user.dir>/data/Arffs/<trainArffFileName>.arff}, and its last
 * attribute is used as the class. The forest is configured with {@code setNumFeatures(30)} and
 * {@code setNumTrees(1000)}. The trained model is written with Java serialization to
 * {@code <user.dir>/data/Models/<trainArffFileName>.model}.
 *
 * <p>Errors are not propagated. Any exception is printed via {@code printStackTrace()} and the
 * method returns normally.
 *
 * @param trainArffFileName base name (without extension) of the ARFF file; also used as the base
 *     name of the model file
 */
private static void trainModel(String trainArffFileName) {
    try {
        final String workingDir = System.getProperty("user.dir");

        // try-with-resources: the reader was previously never closed.
        // NOTE(review): FileReader decodes with the platform default charset.
        Instances structure;
        try (FileReader reader = new FileReader(
                new File(workingDir + "/data/Arffs/" + trainArffFileName + ".arff"))) {
            structure = new Instances(reader);
        }
        structure.setClassIndex(structure.numAttributes() - 1);
        System.out.println("Loaded data from arff file...");

        RandomForest randomForest = new RandomForest();
        randomForest.setNumFeatures(30);
        randomForest.setNumTrees(1000);

        System.out.println("Training...");
        randomForest.buildClassifier(structure);

        System.out.println("Saving trained model to '" + trainArffFileName + "'.");

        // Write the trained model to file. Both streams are closed even if the ObjectOutputStream
        // header write or writeObject fails. close() also flushes.
        try (FileOutputStream fileOut = new FileOutputStream(
                     new File(workingDir + "/data/Models/" + trainArffFileName + ".model"));
             ObjectOutputStream objectOut = new ObjectOutputStream(fileOut)) {
            objectOut.writeObject(randomForest);
        }
    } catch (Exception e) {
        // Best-effort by design: report the failure and return, as before.
        e.printStackTrace();
    }
}
 
开发者ID:ajaybhat,项目名称:Essay-Grading-System,代码行数:32,代码来源:Classifier.java

示例4: trainClassifier

import weka.classifiers.trees.RandomForest; //导入方法依赖的package包/类
/**
 * Loads the section data set, resamples it, and cross-validates a default {@link RandomForest}.
 * The cross-validation report goes to standard output. The method then returns a forest trained
 * on the whole (resampled) data set.
 *
 * <p>The data is read from {@code baseDir + "inputFiles/ARFFFiles/Sections_40.ARFF"}. If the
 * file declares no class attribute, the last attribute is used.
 *
 * <p>NOTE(review): resampling is applied before cross-validation. Duplicated instances can
 * therefore appear in both training and test folds, which likely inflates the reported metrics.
 *
 * <p>NOTE(review): {@code Resample} is used with default options. Per the Weka documentation,
 * those do not bias toward a uniform class distribution. Confirm this actually helps the
 * minority class.
 *
 * @return a {@link RandomForest} built on the full resampled training set
 * @throws Exception if the data cannot be loaded, or Weka fails to filter, evaluate or build
 */
public RandomForest trainClassifier() throws Exception
{
	final int numTopics = 40;
	final int seed = 1;
	final int folds = 10;

	DataSource trainSource = new DataSource(baseDir + "inputFiles/ARFFFiles/Sections_" + numTopics + ".ARFF");
	Instances trainingSet = trainSource.getDataSet();
	if (trainingSet.classIndex() == -1) {
		// Default to the last attribute as the class when the file does not declare one.
		trainingSet.setClassIndex(trainingSet.numAttributes() - 1);
	}

	// Resample (intended for the minority class, per the original comment).
	Resample reSample = new Resample();
	reSample.setInputFormat(trainingSet);
	trainingSet = Filter.useFilter(trainingSet, reSample);

	trainingSet.randomize(new Random(seed));
	if (trainingSet.classAttribute().isNominal()) {
		trainingSet.stratify(folds);
	}

	RandomForest classifier = new RandomForest();
	System.out.println("Training with " + classifier.getClass().getName());
	System.out.println(trainingSet.numInstances());

	// Use the `folds` and `seed` locals instead of repeating the literals 10 and 1, so the
	// fold count in the report header always matches the one actually used.
	Evaluation eval = new Evaluation(trainingSet);
	eval.crossValidateModel(classifier, trainingSet, folds, new Random(seed), new Object[] {});

	System.out.println(eval.toSummaryString("=== " + folds + "-fold Cross-validation ===\n", false));
	System.out.println(eval.toClassDetailsString() + "\n" + eval.toMatrixString() + "\n");

	// The cross-validation above only evaluates; the returned model is built here on all the data.
	classifier.buildClassifier(trainingSet);
	return classifier;
}
 
开发者ID:siddBanPsu,项目名称:WikiKreator,代码行数:53,代码来源:PassageClassifier.java

示例5: getRFBestClassifier

import weka.classifiers.trees.RandomForest; //导入方法依赖的package包/类
/**
 * Loads a per-category data set, resamples it, and cross-validates a default
 * {@link RandomForest}. The report is written to a text file, and the method returns a forest
 * trained on the whole (resampled) data set.
 *
 * <p>Both files share the prefix
 * {@code inputFiles/rawFiles/ARFF-files/<categoryName>-ARFF/<categoryName>-<numTopics>}
 * (relative to the working directory):
 * <ul>
 *   <li>data is read from that prefix with {@code .ARFF} appended;</li>
 *   <li>the UTF-8 cross-validation report is written to that prefix with {@code .txt}
 *       appended.</li>
 * </ul>
 * If the data file declares no class attribute, the last attribute is used.
 *
 * <p>NOTE(review): resampling is applied before cross-validation. Duplicated instances can
 * therefore appear in both training and test folds, which likely inflates the reported metrics.
 *
 * @param numTopics number embedded in the data and report file names
 * @param categoryName category name embedded in the directory and file names
 * @return a {@link RandomForest} built on the full resampled training set
 * @throws Exception if the data cannot be loaded, the report cannot be written, or Weka fails
 */
public static RandomForest getRFBestClassifier(int numTopics, String categoryName) throws Exception
{
	final int seed = 1;
	final int folds = 10;
	final String basePath = "inputFiles/rawFiles/ARFF-files/" + categoryName + "-ARFF/"
			+ categoryName + "-" + numTopics;

	DataSource trainSource = new DataSource(basePath + ".ARFF");
	Instances trainingSet = trainSource.getDataSet();
	if (trainingSet.classIndex() == -1) {
		// Default to the last attribute as the class when the file does not declare one.
		trainingSet.setClassIndex(trainingSet.numAttributes() - 1);
	}

	// Resample (intended for the minority class, per the original comment).
	Resample reSample = new Resample();
	reSample.setInputFormat(trainingSet);
	trainingSet = Filter.useFilter(trainingSet, reSample);

	trainingSet.randomize(new Random(seed));
	if (trainingSet.classAttribute().isNominal()) {
		trainingSet.stratify(folds);
	}

	RandomForest classifier = new RandomForest();

	// Use the `folds` and `seed` locals instead of repeating the literals 10 and 1, so the
	// fold count in the report header always matches the one actually used.
	Evaluation eval = new Evaluation(trainingSet);
	eval.crossValidateModel(classifier, trainingSet, folds, new Random(seed), new Object[] {});

	// try-with-resources: the writer (and its FileOutputStream) previously leaked if a write threw.
	File f = new File(basePath + ".txt");
	try (BufferedWriter bw2 = new BufferedWriter(new OutputStreamWriter(
			new FileOutputStream(f.getAbsolutePath()), "UTF-8"))) {
		bw2.write(eval.toSummaryString("=== " + folds + "-fold Cross-validation ===\n", false));
		bw2.write(eval.toClassDetailsString() + "\n" + eval.toMatrixString() + "\n");
	}

	// The cross-validation above only evaluates; the returned model is built here on all the data.
	classifier.buildClassifier(trainingSet);
	System.err.println("Loaded classifier.....");
	return classifier;
}
 
开发者ID:siddBanPsu,项目名称:WikiKreator,代码行数:59,代码来源:PassageClassifier.java


注：本文中的weka.classifiers.trees.RandomForest.buildClassifier方法示例由纯净天空整理自GitHub/MSDocs等开源代码及文档管理平台，相关代码片段筛选自各路编程大神贡献的开源项目，源码版权归原作者所有，传播和使用请参考对应项目的License；未经允许，请勿转载。