本文整理汇总了Java中weka.classifiers.trees.RandomForest.buildClassifier方法的典型用法代码示例。如果您正苦于以下问题：Java RandomForest.buildClassifier方法的具体用法？Java RandomForest.buildClassifier怎么用？Java RandomForest.buildClassifier使用的例子？那么，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在类weka.classifiers.trees.RandomForest的用法示例。
在下文中一共展示了RandomForest.buildClassifier方法的5个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Java代码示例。
示例1: trainRandomForest
import weka.classifiers.trees.RandomForest; //导入方法依赖的package包/类
/**
 * Builds a {@link RandomForest} from the given instances and logs how it scores on them.
 *
 * <p>The forest is evaluated against the very instances it was trained on; no hold-out set or
 * cross-validation is used, so the logged figures describe the fit to the training data.
 *
 * @param trainingSet instances used both to build and to evaluate the forest
 * @throws Exception if building or evaluating the classifier fails
 */
public static void trainRandomForest(final Instances trainingSet) throws Exception {
    // Fit the forest with its default options.
    final RandomForest forest = new RandomForest();
    forest.buildClassifier(trainingSet);

    // Score the fitted forest on the training instances themselves.
    final Evaluation evaluation = new Evaluation(trainingSet);
    evaluation.evaluateModel(forest, trainingSet);

    // Log the result in the style of the Weka explorer: summary, confusion matrix, then model.
    final String summary = evaluation.toSummaryString();
    logger.info(summary);
    final String confusionMatrix = evaluation.toMatrixString();
    logger.info(confusionMatrix);
    final String modelDescription = forest.toString();
    logger.info(modelDescription);
}
示例2: train
import weka.classifiers.trees.RandomForest; //导入方法依赖的package包/类
/**
 * Builds a random forest with {@code numTrees} trees from the supplied instances.
 *
 * @param instances training data handed to the forest
 * @return the forest after it has been built
 * @throws ClassifierBuildingException if building the classifier throws; the original
 *         exception is attached as the cause
 */
@Override
public RandomForest train(Instances instances) {
    final RandomForest forest = new RandomForest();
    forest.setNumTrees(numTrees);
    try {
        forest.buildClassifier(instances);
        return forest;
    } catch (Exception buildFailure) {
        // Wrap whatever was thrown so callers only deal with one exception type.
        final String reason = "Exception occured while building classifier: " + buildFailure.getMessage();
        throw new ClassifierBuildingException(reason, buildFailure);
    }
}
示例3: trainModel
import weka.classifiers.trees.RandomForest; //导入方法依赖的package包/类
/**
 * Trains a random forest on an ARFF file and saves the serialized model to the file system.
 *
 * <p>The training data is read from {@code <user.dir>/data/Arffs/<trainArffFileName>.arff},
 * using the last attribute as the class. The forest is configured with 30 features and
 * 1000 trees, and the trained model is written to
 * {@code <user.dir>/data/Models/<trainArffFileName>.model}.
 *
 * <p>Any failure (loading, training or saving) is printed via
 * {@link Throwable#printStackTrace()} and otherwise ignored, so the caller is not notified.
 *
 * @param trainArffFileName base name (without extension) of the ARFF file; also used as the
 *        base name of the saved model file
 */
private static void trainModel(String trainArffFileName) {
    try {
        final String dataDir = System.getProperty("user.dir") + "/data/";

        // Close the reader as soon as the data set is loaded, even when parsing fails.
        final Instances structure;
        try (FileReader arffReader = new FileReader(new File(dataDir + "Arffs/" + trainArffFileName + ".arff"))) {
            structure = new Instances(arffReader);
        }
        structure.setClassIndex(structure.numAttributes() - 1);
        System.out.println("Loaded data from arff file...");

        RandomForest randomForest = new RandomForest();
        randomForest.setNumFeatures(30);
        randomForest.setNumTrees(1000);
        System.out.println("Training...");
        randomForest.buildClassifier(structure);

        System.out.println("Saving trained model to '" + trainArffFileName + "'.");
        // Write the trained model to file; try-with-resources flushes and closes the stream
        // even if serialization throws, so the file handle is never leaked.
        try (ObjectOutputStream objectOutputStream = new ObjectOutputStream(
                new FileOutputStream(new File(dataDir + "Models/" + trainArffFileName + ".model")))) {
            objectOutputStream.writeObject(randomForest);
        }
    } catch (Exception e) {
        // Deliberate best-effort: report the problem and carry on.
        e.printStackTrace();
    }
}
示例4: trainClassifier
import weka.classifiers.trees.RandomForest; //导入方法依赖的package包/类
/**
 * Cross-validates a default {@link RandomForest} on the 40-topic sections ARFF file, prints the
 * evaluation, and returns a forest built on the whole (resampled) data set.
 *
 * <p>Steps, in order: load {@code <baseDir>inputFiles/ARFFFiles/Sections_40.ARFF}; use the last
 * attribute as the class if no class index is set; pass the data once through a
 * {@code Resample} filter left at its default options; shuffle with a fixed seed and, for a
 * nominal class, stratify; run cross-validation; print the summary, per-class details and
 * confusion matrix to standard output; finally build the returned classifier on the full set.
 *
 * @return the forest built on the complete resampled training set
 * @throws Exception if loading, filtering, evaluating or building fails
 */
public RandomForest trainClassifier() throws Exception
{
    final int numTopics = 40;
    final int seed = 1;
    final int folds = 10;

    DataSource trainSource = new DataSource(baseDir + "inputFiles/ARFFFiles/Sections_" + numTopics + ".ARFF");
    Instances trainingSet = trainSource.getDataSet();
    // -1 means the source did not designate a class attribute; fall back to the last one.
    if (trainingSet.classIndex() == -1) {
        trainingSet.setClassIndex(trainingSet.numAttributes() - 1);
    }

    // Resample once with the filter's default options.
    // NOTE(review): the original comment says this is meant to help the minority class;
    // which Resample variant is imported is not visible here -- confirm.
    Resample reSample = new Resample();
    reSample.setInputFormat(trainingSet);
    trainingSet = Filter.useFilter(trainingSet, reSample);

    trainingSet.randomize(new Random(seed));
    if (trainingSet.classAttribute().isNominal()) {
        trainingSet.stratify(folds);
    }

    RandomForest classifier = new RandomForest();
    System.out.println("Training with " + classifier.getClass().getName());
    System.out.println(trainingSet.numInstances());

    // Use the declared fold count and seed rather than repeating the literals, so the
    // header printed below can never disagree with the cross-validation actually run.
    Evaluation eval = new Evaluation(trainingSet);
    eval.crossValidateModel(classifier, trainingSet, folds, new Random(seed), new Object[] { });

    System.out.println(eval.toSummaryString("=== " + folds + "-fold Cross-validation ===\n", false));
    System.out.println(eval.toClassDetailsString() + "\n" + eval.toMatrixString() + "\n");

    // Build the model that is actually returned on the full data set.
    classifier.buildClassifier(trainingSet);
    return classifier;
}
示例5: getRFBestClassifier
import weka.classifiers.trees.RandomForest; //导入方法依赖的package包/类
/**
 * Cross-validates a default {@link RandomForest} for one category/topic-count combination,
 * writes the evaluation report next to the ARFF file, and returns a forest built on the whole
 * (resampled) data set.
 *
 * <p>Input is read from
 * {@code inputFiles/rawFiles/ARFF-files/<categoryName>-ARFF/<categoryName>-<numTopics>.ARFF};
 * the report (summary, per-class details, confusion matrix) is written as UTF-8 to the same
 * path with a {@code .txt} extension. The data is passed once through a {@code Resample}
 * filter left at its default options, shuffled with a fixed seed and, for a nominal class,
 * stratified before cross-validation.
 *
 * @param numTopics topic count that selects the ARFF file
 * @param categoryName category that selects the ARFF directory and file
 * @return the forest built on the complete resampled training set
 * @throws Exception if loading, filtering, evaluating, report writing or building fails
 */
public static RandomForest getRFBestClassifier(int numTopics, String categoryName) throws Exception
{
    final int seed = 1;
    final int folds = 10;
    // Shared prefix of the input ARFF file and the report file written beside it.
    final String basePath = "inputFiles/rawFiles/ARFF-files/" + categoryName + "-ARFF/"
            + categoryName + "-" + numTopics;

    DataSource trainSource = new DataSource(basePath + ".ARFF");
    Instances trainingSet = trainSource.getDataSet();
    // -1 means the source did not designate a class attribute; fall back to the last one.
    if (trainingSet.classIndex() == -1) {
        trainingSet.setClassIndex(trainingSet.numAttributes() - 1);
    }

    // Resample once with the filter's default options.
    // NOTE(review): the original comment says this is meant to help the minority class;
    // which Resample variant is imported is not visible here -- confirm.
    Resample reSample = new Resample();
    reSample.setInputFormat(trainingSet);
    trainingSet = Filter.useFilter(trainingSet, reSample);

    trainingSet.randomize(new Random(seed));
    if (trainingSet.classAttribute().isNominal()) {
        trainingSet.stratify(folds);
    }

    RandomForest classifier = new RandomForest();

    // Use the declared fold count and seed rather than repeating the literals, so the
    // report header can never disagree with the cross-validation actually run.
    Evaluation eval = new Evaluation(trainingSet);
    eval.crossValidateModel(classifier, trainingSet, folds, new Random(seed), new Object[] { });

    // try-with-resources closes the report file even if a write fails.
    File f = new File(basePath + ".txt");
    try (BufferedWriter bw2 = new BufferedWriter(
            new OutputStreamWriter(new FileOutputStream(f.getAbsolutePath()), "UTF-8"))) {
        bw2.write(eval.toSummaryString("=== " + folds + "-fold Cross-validation ===\n", false));
        bw2.write(eval.toClassDetailsString() + "\n" + eval.toMatrixString() + "\n");
    }

    // Build the model that is actually returned on the full data set.
    classifier.buildClassifier(trainingSet);
    System.err.println("Loaded classifier.....");
    return classifier;
}