本文整理汇总了Java中weka.classifiers.Evaluation.crossValidateModel方法的典型用法代码示例。如果您正苦于以下问题：Java Evaluation.crossValidateModel方法的具体用法？Java Evaluation.crossValidateModel怎么用？Java Evaluation.crossValidateModel使用的例子？那么，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在类weka.classifiers.Evaluation的用法示例。
在下文中一共展示了Evaluation.crossValidateModel方法的15个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Java代码示例。
示例1: useClassifier
import weka.classifiers.Evaluation; //导入方法依赖的package包/类
/**
 * Cross-validates an {@link AttributeSelectedClassifier} meta-classifier on {@code data}
 * and prints the evaluation summary to standard output.
 *
 * <p>The meta-classifier wraps a {@link RandomForest} base learner and uses a
 * {@link CfsSubsetEval} attribute evaluator. No search method is set explicitly, so the
 * meta-classifier's default search is used. The search printed to the log is the one the
 * classifier actually holds.
 *
 * @param data the dataset to evaluate (presumably with its class index already set —
 *        confirm with callers)
 * @throws Exception if building or evaluating the classifier fails
 */
protected static void useClassifier(Instances data) throws Exception {
    System.out.println("\n1. Meta-classifier");

    AttributeSelectedClassifier classifier = new AttributeSelectedClassifier();
    CfsSubsetEval eval = new CfsSubsetEval();
    RandomForest base = new RandomForest();

    classifier.setClassifier(base);
    System.out.println("Set the classifier : " + base.toString());
    classifier.setEvaluator(eval);
    System.out.println("Set the evaluator : " + eval.toString());
    // Report the search the classifier really uses (its default, since setSearch() is
    // not called) rather than an unrelated search object that was never applied.
    System.out.println("Set the search : " + classifier.getSearch().toString());

    // 10-fold cross-validation with a fixed seed so fold assignment is reproducible.
    Evaluation evaluation = new Evaluation(data);
    evaluation.crossValidateModel(classifier, data, 10, new Random(1));
    System.out.println(evaluation.toSummaryString());
}
示例2: call
import weka.classifiers.Evaluation; //导入方法依赖的package包/类
/**
 * Cross-validates {@code model} on random data generated from its variables. It then
 * stores the resulting {@link Evaluation} on the model, registers the model with its lab,
 * and returns a follow-up task that saves the model.
 *
 * @return a {@link ModelSaverCallable} for the evaluated model, or {@code null} when the
 *         lab already contains this model and the model already has an evaluation
 * @throws Exception if data generation or cross-validation fails
 */
public RecursiveCallable call() throws Exception {
    // NOTE(review): the lab is taken from the first variable; assumes the model has at
    // least one variable and that all its variables share a lab — confirm.
    Lab lab = model.getVariables().get(0).getLab();
    // Skip models that were already evaluated and registered.
    if (lab.getModels().contains(model) && model.getEvaluation() != null) {
        return null;
    }
    Instances wekaData = new Util().getRandomWekaData(model.getVariables());
    evaluation = new Evaluation(wekaData);
    // A SecureRandom seeded by the runtime makes fold assignment differ between runs.
    evaluation.crossValidateModel(model, wekaData, lab.getCvFolds(), new SecureRandom());
    model.setEvaluation(evaluation);
    lab.addModel(model);
    return new ModelSaverCallable(model);
}
示例3: evaluate
import weka.classifiers.Evaluation; //导入方法依赖的package包/类
/**
 * Cross-validates {@code model} once per Orange small-dataset label (churn, appetency,
 * upselling) and reports the area under the ROC curve for each label.
 *
 * @param model the classifier to evaluate
 * @return one AUC per label, in the order of {@code labelFiles}, followed by their mean
 *         as the last element
 * @throws Exception if loading, preprocessing or evaluation fails
 */
public static double[] evaluate(Classifier model) throws Exception {
    String[] labelFiles = new String[] { "churn", "appetency", "upselling" };
    // One slot per label plus a trailing slot for the average. Both the size and the
    // divisor come from labelFiles, so adding a label needs no other changes.
    double[] results = new double[labelFiles.length + 1];
    double overallScore = 0.0;
    for (int i = 0; i < labelFiles.length; i++) {
        // Load data for this label and preprocess it.
        Instances trainData = loadData("data/orange_small_train.data",
                "data/orange_small_train_" + labelFiles[i] + ".labels.txt");
        trainData = preProcessData(trainData);
        // 5-fold cross-validation with a fixed seed for reproducible folds.
        Evaluation eval = new Evaluation(trainData);
        eval.crossValidateModel(model, trainData, 5, new Random(1), new Object[] {});
        // AUC with respect to the class value labelled "1".
        results[i] = eval.areaUnderROC(trainData.classAttribute().indexOfValue("1"));
        overallScore += results[i];
        System.out.println(labelFiles[i] + "\t-->\t" + results[i]);
    }
    results[labelFiles.length] = overallScore / labelFiles.length;
    return results;
}
开发者ID:PacktPublishing,项目名称:Machine-Learning-End-to-Endguide-for-Java-developers,代码行数:29,代码来源:KddCup.java
示例4: runSVMRegression
import weka.classifiers.Evaluation; //导入方法依赖的package包/类
/**
 * Loads {@code rawData.arff}, configures a LibSVM model, and passes its 10-fold
 * cross-validation results to {@code evaluateResults}.
 *
 * <p>The last attribute of the file is used as the class.
 *
 * @throws Exception if the data cannot be read, the classifier cannot be loaded or
 *         configured, or evaluation fails
 */
public static void runSVMRegression() throws Exception {
    int numFolds = 10;
    Instances trainData;
    // try-with-resources closes the reader even if parsing the ARFF file fails.
    try (BufferedReader br = new BufferedReader(new FileReader("rawData.arff"))) {
        trainData = new Instances(br);
    }
    trainData.setClassIndex(trainData.numAttributes() - 1);

    // LibSVM is distributed as a Weka package, so packages are loaded before it is looked
    // up by name. "weka.classifiers.functions.supportVector" is a package, not a class,
    // so Class.forName could never resolve it.
    WekaPackageManager.loadPackages(false, true, false);
    AbstractClassifier classifier = (AbstractClassifier) Class.forName(
            "weka.classifiers.functions.LibSVM").getDeclaredConstructor().newInstance();
    // "-S 3" selects epsilon-SVR. NOTE(review): confirm "-V 10" and "-T 0" against the
    // installed LibSVM wrapper's option list.
    String[] optionsArray = "-S 3 -V 10 -T 0".split(" ");
    classifier.setOptions(optionsArray);
    classifier.buildClassifier(trainData);

    Evaluation evaluation = new Evaluation(trainData);
    evaluation.crossValidateModel(classifier, trainData, numFolds, new Random(1));
    evaluateResults(evaluation);
}
示例5: getErrorPercent
import weka.classifiers.Evaluation; //导入方法依赖的package包/类
/**
 * Estimates the error rate of {@code getClassifier()} on {@code getInstances()} using
 * {@code getFolds()}-fold cross-validation.
 *
 * @return the percentage of incorrectly classified instances, or {@code 100} if the
 *         evaluation throws (the stack trace is printed)
 */
@Override
public double getErrorPercent() {
    double errorPercent = 100;
    try {
        Evaluation evaluation = new Evaluation(getInstances());
        // Unseeded Random: fold assignment, and therefore the estimate, varies per call.
        evaluation.crossValidateModel(getClassifier(), getInstances(), getFolds(), new Random());
        errorPercent = evaluation.pctIncorrect();
    } catch (Exception e) {
        e.printStackTrace();
    }
    return errorPercent;
}
示例6: trainClassifier
import weka.classifiers.Evaluation; //导入方法依赖的package包/类
/**
 * Trains the selected classifier type on a CSV dataset, optionally cross-validates it,
 * and serializes the trained model to {@code trainingModel}.
 *
 * <p>When cross-validation runs, {@code kappa}, {@code fMeasure} and
 * {@code confusionMatrix} are updated with its results.
 *
 * @param classifier which algorithm to train ({@code KNN} or {@code NB})
 * @param trainingDataset CSV file to load; if no class attribute is set, the last column
 *        is used as the class
 * @param trainingModel stream that receives the serialized classifier; this method closes it
 * @param crossValidationFoldNumber number of cross-validation folds; cross-validation is
 *        skipped when this is {@code null} or not positive
 * @throws IllegalArgumentException if {@code classifier} is not a supported type
 * @throws Exception if loading, training, evaluation or serialization fails
 */
public void trainClassifier(Classifier classifier, File trainingDataset,
        FileOutputStream trainingModel, Integer crossValidationFoldNumber) throws Exception {
    CSVLoader csvLoader = new CSVLoader();
    csvLoader.setSource(trainingDataset);
    Instances instances = csvLoader.getDataSet();
    switch (classifier) {
        case KNN:
            // Rule-of-thumb neighbour count: k = ceil(sqrt(number of instances)).
            int k = (int) Math.ceil(Math.sqrt(instances.numInstances()));
            this.classifier = new IBk(k);
            break;
        case NB:
            this.classifier = new NaiveBayes();
            break;
        default:
            // Fail loudly instead of silently re-training whatever this.classifier held.
            throw new IllegalArgumentException("Unsupported classifier: " + classifier);
    }
    if (instances.classIndex() == -1) {
        instances.setClassIndex(instances.numAttributes() - 1);
    }
    this.classifier.buildClassifier(instances);
    if (crossValidationFoldNumber != null && crossValidationFoldNumber > 0) {
        Evaluation evaluation = new Evaluation(instances);
        evaluation.crossValidateModel(this.classifier, instances, crossValidationFoldNumber,
                new Random(1));
        kappa = evaluation.kappa();
        fMeasure = evaluation.weightedFMeasure();
        confusionMatrix = evaluation.toMatrixString("Confusion matrix: ");
    }
    // try-with-resources closes the stream (and the wrapped trainingModel) on failure too.
    try (ObjectOutputStream outputStream = new ObjectOutputStream(trainingModel)) {
        outputStream.writeObject(this.classifier);
        outputStream.flush();
    }
}
示例7: main
import weka.classifiers.Evaluation; //导入方法依赖的package包/类
/**
 * Runs 10-fold cross-validation of Naive Bayes on {@code src/files/letter.arff} with
 * seeds 1 through 30, printing one summary line per seed.
 *
 * @param args the command line arguments (unused)
 * @throws java.lang.Exception if the dataset cannot be loaded or evaluation fails
 */
public static void main(String[] args) throws Exception {
    final int folds = 10;
    final int runs = 30;
    Instances data = new DataSource("src/files/letter.arff").getDataSet();
    // NOTE(review): attribute index 16 is assumed to be the class attribute of this file.
    data.setClassIndex(16);
    Classifier cls = new NaiveBayes();

    System.out.println("#seed \t correctly instances \t percentage of corrects\n");
    int seed = 1;
    while (seed <= runs) {
        Evaluation eval = new Evaluation(data);
        eval.crossValidateModel(cls, data, folds, new Random(seed));
        System.out.println("#" + seed + "\t" + summary(eval));
        seed++;
    }
}
示例8: trainLibSvm
import weka.classifiers.Evaluation; //导入方法依赖的package包/类
/**
 * Builds a LibSVM classifier on {@code trainingSet}, then logs a 10-fold
 * cross-validation summary computed with seed 1.
 *
 * @param trainingSet the labelled training data
 * @throws Exception if training or evaluation fails
 */
public static void trainLibSvm(final Instances trainingSet) throws Exception {
    final LibSVM svm = new LibSVM();
    svm.buildClassifier(trainingSet);

    final Evaluation evaluation = new Evaluation(trainingSet);
    final Random foldRandom = new Random(1);
    evaluation.crossValidateModel(svm, trainingSet, 10, foldRandom);
    // Summary text in the same format the Weka Explorer shows.
    logger.info(evaluation.toSummaryString());
}
示例9: trainJ48BySet
import weka.classifiers.Evaluation; //导入方法依赖的package包/类
/**
 * Trains a J48 tree on {@code trainingSet} with reduced-error pruning and a minimum
 * leaf size of 1. Then logs a 10-fold cross-validation summary computed with seed 1.
 *
 * @param trainingSet the labelled training data
 * @throws Exception if training or evaluation fails
 */
public static void trainJ48BySet(final Instances trainingSet) throws Exception {
    final J48 j48 = new J48();
    j48.setMinNumObj(1);
    j48.setReducedErrorPruning(true);
    j48.buildClassifier(trainingSet);

    final Evaluation evaluation = new Evaluation(trainingSet);
    evaluation.crossValidateModel(j48, trainingSet, 10, new Random(1));
    // Summary text in the same format the Weka Explorer shows.
    logger.info(evaluation.toSummaryString());
}
示例10: trainJ48CrossValidation
import weka.classifiers.Evaluation; //导入方法依赖的package包/类
/**
 * Cross-validates a J48 tree (reduced-error pruning, minimum leaf size 1) with 10 folds.
 * Logs the summary, writes it to {@code prediction_cv_summary.txt}, and finally trains
 * the tree on the full {@code trainingSet}.
 *
 * @param trainingSet the labelled training data
 * @throws Exception if evaluation, writing the summary, or training fails
 */
public static void trainJ48CrossValidation(final Instances trainingSet) throws Exception {
    final J48 j48 = new J48();
    j48.setMinNumObj(1);
    j48.setReducedErrorPruning(true);

    final Evaluation evaluation = new Evaluation(trainingSet);
    // Seeded from the wall clock, so fold assignment differs on every run.
    final long seed = System.currentTimeMillis();
    evaluation.crossValidateModel(j48, trainingSet, 10, new Random(seed));

    final String summary = evaluation.toSummaryString();
    logger.info(summary);
    Utils.writeFile(GitHubPublisher.localPath + RetailSalePrediction.predict_retail_sales
            + File.separator + "prediction_cv_summary.txt", summary);

    j48.buildClassifier(trainingSet);
}
示例11: ensembleVote
import weka.classifiers.Evaluation; //导入方法依赖的package包/类
/**
 * Combines {@code newCfsArray} into a majority-voting {@code Vote} ensemble and trains it
 * on {@code train}. Then estimates the ensemble's accuracy with 5-fold cross-validation
 * (seed 1000).
 *
 * @param train the labelled training data
 * @param newCfsArray member classifiers of the ensemble
 * @return {@code 1 - errorRate} from cross-validation, or {@code 0} if anything throws
 *         (the stack trace is printed)
 */
public static double ensembleVote(Instances train, Classifier[] newCfsArray) {
    try {
        Vote ensemble = new Vote();
        ensemble.setCombinationRule(
                new SelectedTag(Vote.MAJORITY_VOTING_RULE, Vote.TAGS_RULES));
        ensemble.setClassifiers(newCfsArray);
        ensemble.setSeed(2);
        // NOTE(review): this may train the caller's classifier instances in place;
        // confirm before removing it as redundant with cross-validation.
        ensemble.buildClassifier(train);

        Evaluation eval = new Evaluation(train);
        eval.crossValidateModel(ensemble, train, 5, new Random(1000));
        return 1 - eval.errorRate();
    } catch (Exception e) {
        e.printStackTrace();
        return 0;
    }
}
示例12: main
import weka.classifiers.Evaluation; //导入方法依赖的package包/类
/**
 * For every {@code .arff} file in {@code ../../lab1/reuters/}, cross-validates three
 * models with 10 folds (seed 1) and prints per-class details for each:
 * an {@code RNBtree}, an unpruned J48, and a J48 with confidence factor 0.25.
 *
 * @param args unused
 */
static public void main(String[] args) {
    final String dir = "../../lab1/reuters/";
    File folder = new File(dir);
    String[] files = folder.list(new FilenameFilter() {
        public boolean accept(File parent, String name) {
            return name.endsWith(".arff");
        }
    });
    // File.list returns null when the directory is missing or cannot be read.
    if (files == null) {
        System.err.println("Cannot list directory: " + folder.getAbsolutePath());
        return;
    }
    try {
        for (String f : files) {
            System.out.println(f);
            RNBtree r = new RNBtree(dir + f);
            System.out.printf("numAttributes: %d\nnumClasses: %d\n", r.numAttribute, r.numClass);
            r.buildClassifier(r.data);
            Evaluation eval = new Evaluation(r.data);
            eval.crossValidateModel(r, r.data, 10, new Random(1));
            System.out.println(eval.toClassDetailsString());

            System.out.println("Unpruned tree");
            Evaluation e2 = new Evaluation(r.data);
            Classifier cls = new J48();
            // -U is an argument-less flag. A trailing "true" is not a J48 option and may be
            // rejected as an illegal leftover option, aborting the whole loop.
            cls.setOptions(new String[] {"-U"});
            e2.crossValidateModel(cls, r.data, 10, new Random(1));
            System.out.println(e2.toClassDetailsString());

            System.out.println("Pruned tree");
            Evaluation e3 = new Evaluation(r.data);
            Classifier cls2 = new J48();
            cls2.setOptions(new String[] {"-C", "0.25"});
            e3.crossValidateModel(cls2, r.data, 10, new Random(1));
            System.out.println(e3.toClassDetailsString());
        }
    } catch (Exception e) {
        e.printStackTrace();
    }
}
示例13: crossValidate
import weka.classifiers.Evaluation; //导入方法依赖的package包/类
/**
 * Cross-validates {@code MLclass} on {@code traindata} with the given number of folds,
 * then prints the results.
 *
 * <p>The seed is fixed at 1, so the same training data always yields the same folds.
 *
 * @param foldNum number of cross-validation folds
 * @throws Exception if evaluation fails
 */
public void crossValidate(int foldNum) throws Exception {
    final String banner = "WekaWrapper: " + foldNum + "-fold cross validation over train data.";
    System.out.println(banner);
    System.err.println(banner);

    Evaluation eTest = new Evaluation(traindata);
    eTest.crossValidateModel(this.MLclass, traindata, foldNum, new Random(1));
    printClassifierResults(eTest);
}
示例14: getClassifierFScore
import weka.classifiers.Evaluation; //导入方法依赖的package包/类
/**
 * Loads the ARFF file for {@code categoryName} with {@code numTopics} topics, resamples
 * it, and returns the weighted F-measure of a random forest under cross-validation.
 *
 * @param numTopics topic count encoded in the ARFF file name
 * @param categoryName category whose ARFF directory and file are read
 * @return the weighted F-measure from {@code folds}-fold cross-validation
 * @throws Exception if loading, filtering or evaluation fails
 */
public static double getClassifierFScore(int numTopics, String categoryName) throws Exception
{
    final int seed = 1;
    final int folds = 10;
    DataSource trainSource = new DataSource("inputFiles/rawFiles/ARFF-files/" + categoryName
            + "-ARFF/" + categoryName + "-" + numTopics + ".ARFF");
    Instances trainingSet = trainSource.getDataSet();
    if (trainingSet.classIndex() == -1) {
        trainingSet.setClassIndex(trainingSet.numAttributes() - 1);
    }

    // Resample (intended to help the minority class).
    // NOTE(review): resampling the whole dataset before cross-validation can put copies of
    // the same instance in both training and test folds, which may inflate the score.
    // Consider applying the filter inside a FilteredClassifier instead.
    Resample reSample = new Resample();
    reSample.setInputFormat(trainingSet);
    trainingSet = Filter.useFilter(trainingSet, reSample);

    trainingSet.randomize(new Random(seed));
    if (trainingSet.classAttribute().isNominal()) {
        trainingSet.stratify(folds);
    }

    RandomForest classifier = new RandomForest();
    Evaluation eval = new Evaluation(trainingSet);
    // Use the same fold count and seed as the stratification above, so that changing
    // `folds` or `seed` affects both consistently.
    eval.crossValidateModel(classifier, trainingSet, folds, new Random(seed), new Object[] {});
    return eval.weightedFMeasure();
}
示例15: debugEvaluateClassifier
import weka.classifiers.Evaluation; //导入方法依赖的package包/类
/**
 * Debugging aid: cross-validates {@code wekaClassifier} with 5 folds (seed 1) on the
 * data set produced by {@code wekaDataSetCreator}. The returned evaluation exposes
 * precision, recall and other statistics.
 *
 * @return the completed cross-validation {@link Evaluation}
 * @throws Exception if building the data set or evaluating fails
 * @throws IOException declared for callers; it is a subtype of {@code Exception}
 */
public Evaluation debugEvaluateClassifier() throws Exception, IOException {
    final Instances dataSet = wekaDataSetCreator.getDataSet();
    final Evaluation result = new Evaluation(dataSet);
    result.crossValidateModel(wekaClassifier, dataSet, 5, new Random(1));
    return result;
}