当前位置: 首页>>代码示例>>Java>>正文


Java CrossValidation类代码示例

本文整理汇总了Java中smile.validation.CrossValidation类的典型用法代码示例。如果您正苦于以下问题：Java CrossValidation类的具体用法是什么？Java CrossValidation怎么用？有没有Java CrossValidation的使用示例？那么恭喜您，这里精选的类代码示例或许可以为您提供帮助。


CrossValidation类属于smile.validation包，在下文中一共展示了CrossValidation类的15个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Java代码示例。

示例1: testWSJ

import smile.validation.CrossValidation; //导入依赖的package包/类
/**
 * Test of learn method, of class HMMPOSTagger.
 *
 * <p>Loads the tagged sentences found under the WSJ corpus directory, runs
 * 10-fold cross validation (training an HMM tagger on each training fold),
 * and prints the per-token tagging error rate, in percent, accumulated over
 * all held-out folds.
 */
@Test
public void testWSJ() {
    System.out.println("WSJ");
    // NOTE(review): machine-specific absolute path; this test can only run on
    // a machine where the corpus is installed at exactly this location.
    load("D:\\sourceforge\\corpora\\PennTreebank\\PennTreebank2\\TAGGED\\POS\\WSJ");

    String[][] x = sentences.toArray(new String[sentences.size()][]);
    PennTreebankPOS[][] y = labels.toArray(new PennTreebankPOS[labels.size()][]);

    int n = x.length;
    int k = 10;

    CrossValidation cv = new CrossValidation(n, k);
    int error = 0;  // mis-tagged tokens over all held-out folds
    int total = 0;  // tagged tokens over all held-out folds

    for (int i = 0; i < k; i++) {
        String[][] trainx = Math.slice(x, cv.train[i]);
        PennTreebankPOS[][] trainy = Math.slice(y, cv.train[i]);
        String[][] testx = Math.slice(x, cv.test[i]);
        PennTreebankPOS[][] testy = Math.slice(y, cv.test[i]);

        HMMPOSTagger tagger = HMMPOSTagger.learn(trainx, trainy);

        for (int j = 0; j < testx.length; j++) {
            PennTreebankPOS[] label = tagger.tag(testx[j]);
            total += label.length;
            for (int l = 0; l < label.length; l++) {
                if (label[l] != testy[j][l]) {
                    error++;
                }
            }
        }
    }

    // %n (rather than \n) emits the platform line separator, as in the other CV tests.
    System.out.format("Error rate = %.2f as %d of %d%n", 100.0 * error / total, error, total);
}
 
开发者ID:takun2s,项目名称:smile_1.5.0_java7,代码行数:40,代码来源:HMMPOSTaggerTest.java

示例2: testBrown

import smile.validation.CrossValidation; //导入依赖的package包/类
/**
 * Test of learn method, of class HMMPOSTagger.
 *
 * <p>Loads the tagged sentences found under the BROWN corpus directory, runs
 * 10-fold cross validation (training an HMM tagger on each training fold),
 * and prints the per-token tagging error rate, in percent, accumulated over
 * all held-out folds.
 */
@Test
public void testBrown() {
    System.out.println("BROWN");
    // NOTE(review): machine-specific absolute path; this test can only run on
    // a machine where the corpus is installed at exactly this location.
    load("D:\\sourceforge\\corpora\\PennTreebank\\PennTreebank2\\TAGGED\\POS\\BROWN");

    String[][] x = sentences.toArray(new String[sentences.size()][]);
    PennTreebankPOS[][] y = labels.toArray(new PennTreebankPOS[labels.size()][]);

    int n = x.length;
    int k = 10;

    CrossValidation cv = new CrossValidation(n, k);
    int error = 0;  // mis-tagged tokens over all held-out folds
    int total = 0;  // tagged tokens over all held-out folds

    for (int i = 0; i < k; i++) {
        String[][] trainx = Math.slice(x, cv.train[i]);
        PennTreebankPOS[][] trainy = Math.slice(y, cv.train[i]);
        String[][] testx = Math.slice(x, cv.test[i]);
        PennTreebankPOS[][] testy = Math.slice(y, cv.test[i]);

        HMMPOSTagger tagger = HMMPOSTagger.learn(trainx, trainy);

        for (int j = 0; j < testx.length; j++) {
            PennTreebankPOS[] label = tagger.tag(testx[j]);
            total += label.length;
            for (int l = 0; l < label.length; l++) {
                if (label[l] != testy[j][l]) {
                    error++;
                }
            }
        }
    }

    // %n (rather than \n) emits the platform line separator, as in the other CV tests.
    System.out.format("Error rate = %.2f as %d of %d%n", 100.0 * error / total, error, total);
}
 
开发者ID:takun2s,项目名称:smile_1.5.0_java7,代码行数:40,代码来源:HMMPOSTaggerTest.java

示例3: testLearnMultinomial

import smile.validation.CrossValidation; //导入依赖的package包/类
/**
 * Test of the batch learn method, of class NaiveBayes, with the MULTINOMIAL
 * model.
 *
 * <p>Runs 10-fold cross validation on {@code moviex}/{@code moviey}. Each
 * fold's classifier is trained with a single {@code learn(x[][], y[])} call.
 * Test samples for which {@code predict} returns -1 are skipped and are not
 * counted in {@code total}. Asserts that fewer than 265 of the counted
 * samples are misclassified.
 */
@Test
public void testLearnMultinomial() {
    System.out.println("batch learn Multinomial");

    double[][] x = moviex;
    int[] y = moviey;
    int n = x.length;
    int k = 10;
    CrossValidation cv = new CrossValidation(n, k);
    int error = 0;
    int total = 0;
    for (int i = 0; i < k; i++) {
        double[][] trainx = Math.slice(x, cv.train[i]);
        int[] trainy = Math.slice(y, cv.train[i]);
        // 2 classes; feature.length input dimensions (presumably one per
        // vocabulary term -- confirm against where feature is built).
        NaiveBayes bayes = new NaiveBayes(NaiveBayes.Model.MULTINOMIAL, 2, feature.length);

        bayes.learn(trainx, trainy);

        double[][] testx = Math.slice(x, cv.test[i]);
        int[] testy = Math.slice(y, cv.test[i]);
        for (int j = 0; j < testx.length; j++) {
            int label = bayes.predict(testx[j]);
            // -1 from predict is treated as "no decision" and excluded from scoring.
            if (label != -1) {
                total++;
                if (testy[j] != label) {
                    error++;
                }
            }
        }
    }

    System.out.format("Multinomial error = %d of %d%n", error, total);
    assertTrue(error < 265);
}
 
开发者ID:takun2s,项目名称:smile_1.5.0_java7,代码行数:38,代码来源:NaiveBayesTest.java

示例4: testLearnMultinomial2

import smile.validation.CrossValidation; //导入依赖的package包/类
/**
 * Test of the online learn method, of class NaiveBayes, with the MULTINOMIAL
 * model.
 *
 * <p>Runs 10-fold cross validation on {@code moviex}/{@code moviey}. Each
 * fold's classifier is trained incrementally, one sample at a time, via
 * {@code learn(x[], y)}. Test samples for which {@code predict} returns -1
 * are skipped and are not counted in {@code total}. Asserts that fewer than
 * 265 of the counted samples are misclassified.
 */
@Test
public void testLearnMultinomial2() {
    System.out.println("online learn Multinomial");

    double[][] x = moviex;
    int[] y = moviey;
    int n = x.length;
    int k = 10;
    CrossValidation cv = new CrossValidation(n, k);
    int error = 0;
    int total = 0;
    for (int i = 0; i < k; i++) {
        double[][] trainx = Math.slice(x, cv.train[i]);
        int[] trainy = Math.slice(y, cv.train[i]);
        // 2 classes; feature.length input dimensions (presumably one per
        // vocabulary term -- confirm against where feature is built).
        NaiveBayes bayes = new NaiveBayes(NaiveBayes.Model.MULTINOMIAL, 2, feature.length);

        // Online training: feed the fold one sample at a time.
        for (int j = 0; j < trainx.length; j++) {
            bayes.learn(trainx[j], trainy[j]);
        }

        double[][] testx = Math.slice(x, cv.test[i]);
        int[] testy = Math.slice(y, cv.test[i]);
        for (int j = 0; j < testx.length; j++) {
            int label = bayes.predict(testx[j]);
            // -1 from predict is treated as "no decision" and excluded from scoring.
            if (label != -1) {
                total++;
                if (testy[j] != label) {
                    error++;
                }
            }
        }
    }

    System.out.format("Multinomial error = %d of %d%n", error, total);
    assertTrue(error < 265);
}
 
开发者ID:takun2s,项目名称:smile_1.5.0_java7,代码行数:40,代码来源:NaiveBayesTest.java

示例5: testLearnBernoulli

import smile.validation.CrossValidation; //导入依赖的package包/类
/**
 * Test of the batch learn method, of class NaiveBayes, with the BERNOULLI
 * model.
 *
 * <p>Runs 10-fold cross validation on {@code moviex}/{@code moviey}. Each
 * fold's classifier is trained with a single {@code learn(x[][], y[])} call.
 * Test samples for which {@code predict} returns -1 are skipped and are not
 * counted in {@code total}. Asserts that fewer than 270 of the counted
 * samples are misclassified.
 */
@Test
public void testLearnBernoulli() {
    System.out.println("batch learn Bernoulli");

    double[][] x = moviex;
    int[] y = moviey;
    int n = x.length;
    int k = 10;
    CrossValidation cv = new CrossValidation(n, k);
    int error = 0;
    int total = 0;
    for (int i = 0; i < k; i++) {
        double[][] trainx = Math.slice(x, cv.train[i]);
        int[] trainy = Math.slice(y, cv.train[i]);
        // 2 classes; feature.length input dimensions (presumably one per
        // vocabulary term -- confirm against where feature is built).
        NaiveBayes bayes = new NaiveBayes(NaiveBayes.Model.BERNOULLI, 2, feature.length);

        bayes.learn(trainx, trainy);

        double[][] testx = Math.slice(x, cv.test[i]);
        int[] testy = Math.slice(y, cv.test[i]);

        for (int j = 0; j < testx.length; j++) {
            int label = bayes.predict(testx[j]);
            // -1 from predict is treated as "no decision" and excluded from scoring.
            if (label != -1) {
                total++;
                if (testy[j] != label) {
                    error++;
                }
            }
        }
    }

    System.out.format("Bernoulli error = %d of %d%n", error, total);
    assertTrue(error < 270);
}
 
开发者ID:takun2s,项目名称:smile_1.5.0_java7,代码行数:39,代码来源:NaiveBayesTest.java

示例6: testLearnBernoulli2

import smile.validation.CrossValidation; //导入依赖的package包/类
/**
 * Test of the online learn method, of class NaiveBayes, with the BERNOULLI
 * model.
 *
 * <p>Runs 10-fold cross validation on {@code moviex}/{@code moviey}. Each
 * fold's classifier is trained incrementally, one sample at a time, via
 * {@code learn(x[], y)}. Test samples for which {@code predict} returns -1
 * are skipped and are not counted in {@code total}. Asserts that fewer than
 * 270 of the counted samples are misclassified.
 */
@Test
public void testLearnBernoulli2() {
    System.out.println("online learn Bernoulli");

    double[][] x = moviex;
    int[] y = moviey;
    int n = x.length;
    int k = 10;
    CrossValidation cv = new CrossValidation(n, k);
    int error = 0;
    int total = 0;
    for (int i = 0; i < k; i++) {
        double[][] trainx = Math.slice(x, cv.train[i]);
        int[] trainy = Math.slice(y, cv.train[i]);
        // 2 classes; feature.length input dimensions (presumably one per
        // vocabulary term -- confirm against where feature is built).
        NaiveBayes bayes = new NaiveBayes(NaiveBayes.Model.BERNOULLI, 2, feature.length);

        // Online training: feed the fold one sample at a time.
        for (int j = 0; j < trainx.length; j++) {
            bayes.learn(trainx[j], trainy[j]);
        }

        double[][] testx = Math.slice(x, cv.test[i]);
        int[] testy = Math.slice(y, cv.test[i]);

        for (int j = 0; j < testx.length; j++) {
            int label = bayes.predict(testx[j]);
            // -1 from predict is treated as "no decision" and excluded from scoring.
            if (label != -1) {
                total++;
                if (testy[j] != label) {
                    error++;
                }
            }
        }
    }

    System.out.format("Bernoulli error = %d of %d%n", error, total);
    assertTrue(error < 270);
}
 
开发者ID:takun2s,项目名称:smile_1.5.0_java7,代码行数:41,代码来源:NaiveBayesTest.java

示例7: test

import smile.validation.CrossValidation; //导入依赖的package包/类
/**
 * Cross-validates a RegressionTree on one ARFF test dataset and prints the
 * 10-fold RMSE and mean absolute deviation over all held-out samples.
 *
 * <p>Any exception (for example missing test data) is printed to stderr
 * rather than propagated, so the caller does not fail on it.
 *
 * @param dataset  label printed before the run
 * @param url      test-data path handed to {@code IOUtils.getTestDataFile}
 * @param response column index of the response variable
 */
public void test(String dataset, String url, int response) {
    System.out.println(dataset);
    ArffParser arff = new ArffParser();
    arff.setResponseIndex(response);
    try {
        AttributeDataset table = arff.parse(smile.data.parser.IOUtils.getTestDataFile(url));
        double[] targets = table.toArray(new double[table.size()]);
        double[][] features = table.toArray(new double[table.size()][]);

        final int folds = 10;
        int size = features.length;
        CrossValidation split = new CrossValidation(size, folds);

        double sumSquared = 0.0;
        double sumAbsolute = 0.0;
        for (int fold = 0; fold < folds; fold++) {
            int[] trainIdx = split.train[fold];
            int[] testIdx = split.test[fold];
            double[][] fitX = Math.slice(features, trainIdx);
            double[] fitY = Math.slice(targets, trainIdx);
            double[][] evalX = Math.slice(features, testIdx);
            double[] evalY = Math.slice(targets, testIdx);

            // 20: size limit passed to the tree (presumably maximum leaf nodes).
            RegressionTree tree = new RegressionTree(table.attributes(), fitX, fitY, 20);

            for (int row = 0; row < evalX.length; row++) {
                double residual = evalY[row] - tree.predict(evalX[row]);
                sumSquared += residual * residual;
                sumAbsolute += Math.abs(residual);
            }
        }

        // Every sample is held out exactly once, so dividing by the dataset size averages over all of them.
        System.out.format("10-CV RMSE = %.4f \t AbsoluteDeviation = %.4f%n", Math.sqrt(sumSquared / size), sumAbsolute / size);
    } catch (Exception ex) {
        System.err.println(ex);
    }
}
 
开发者ID:takun2s,项目名称:smile_1.5.0_java7,代码行数:36,代码来源:RegressionTreeTest.java

示例8: test

import smile.validation.CrossValidation; //导入依赖的package包/类
/**
 * Cross-validates a GradientTreeBoost regressor with the given loss on one
 * ARFF test dataset and prints the 10-fold RMSE and mean absolute deviation
 * over all held-out samples.
 *
 * <p>Any exception (for example missing test data) is printed to stderr
 * rather than propagated, so the caller does not fail on it.
 *
 * @param loss     loss function used for boosting
 * @param dataset  label printed (with the loss) before the run
 * @param url      test-data path handed to {@code IOUtils.getTestDataFile}
 * @param response column index of the response variable
 */
public void test(GradientTreeBoost.Loss loss, String dataset, String url, int response) {
    System.out.println(dataset + "\t" + loss);
    ArffParser arff = new ArffParser();
    arff.setResponseIndex(response);
    try {
        AttributeDataset table = arff.parse(smile.data.parser.IOUtils.getTestDataFile(url));
        double[] targets = table.toArray(new double[table.size()]);
        double[][] features = table.toArray(new double[table.size()][]);

        final int folds = 10;
        int size = features.length;
        CrossValidation split = new CrossValidation(size, folds);

        double sumSquared = 0.0;
        double sumAbsolute = 0.0;
        for (int fold = 0; fold < folds; fold++) {
            int[] trainIdx = split.train[fold];
            int[] testIdx = split.test[fold];
            double[][] fitX = Math.slice(features, trainIdx);
            double[] fitY = Math.slice(targets, trainIdx);
            double[][] evalX = Math.slice(features, testIdx);
            double[] evalY = Math.slice(targets, testIdx);

            // Hyper-parameters 100, 6, 0.05, 0.7 -- presumably number of trees,
            // max nodes per tree, shrinkage and subsample rate; TODO confirm.
            GradientTreeBoost model = new GradientTreeBoost(table.attributes(), fitX, fitY, loss, 100, 6, 0.05, 0.7);

            for (int row = 0; row < evalX.length; row++) {
                double residual = evalY[row] - model.predict(evalX[row]);
                sumAbsolute += Math.abs(residual);
                sumSquared += residual * residual;
            }
        }

        // Every sample is held out exactly once, so dividing by the dataset size averages over all of them.
        System.out.format("10-CV RMSE = %.4f \t AbsoluteDeviation = %.4f%n", Math.sqrt(sumSquared / size), sumAbsolute / size);
    } catch (Exception ex) {
        System.err.println(ex);
    }
}
 
开发者ID:takun2s,项目名称:smile_1.5.0_java7,代码行数:36,代码来源:GradientTreeBoostTest.java

示例9: testCPU

import smile.validation.CrossValidation; //导入依赖的package包/类
/**
 * Test of learn method, of class SVR.
 *
 * <p>Runs 10-fold cross validation of an SVR with a polynomial kernel on
 * weka/cpu.arff (response column 6) and prints the cross-validated RMSE.
 * Features are standardized inside each fold, using statistics computed on
 * that fold's training rows only. Standardizing the whole dataset before
 * splitting would leak held-out information into training.
 *
 * <p>Any exception (for example missing test data) is printed to stderr
 * rather than failing the test.
 */
@Test
public void testCPU() {
    System.out.println("CPU");
    ArffParser parser = new ArffParser();
    parser.setResponseIndex(6);
    try {
        AttributeDataset data = parser.parse(smile.data.parser.IOUtils.getTestDataFile("weka/cpu.arff"));
        double[] datay = data.toArray(new double[data.size()]);
        double[][] datax = data.toArray(new double[data.size()][]);

        int n = datax.length;
        int k = 10;

        CrossValidation cv = new CrossValidation(n, k);
        double rss = 0.0;
        for (int i = 0; i < k; i++) {
            double[][] trainx = Math.slice(datax, cv.train[i]);
            double[] trainy = Math.slice(datay, cv.train[i]);
            double[][] testx = Math.slice(datax, cv.test[i]);
            double[] testy = Math.slice(datay, cv.test[i]);

            // Fit centering and scaling on the training fold only.
            int p = trainx[0].length;
            double[] center = new double[p];
            double[] scale = new double[p];
            for (double[] row : trainx) {
                for (int c = 0; c < p; c++) {
                    center[c] += row[c];
                }
            }
            for (int c = 0; c < p; c++) {
                center[c] /= trainx.length;
            }
            for (double[] row : trainx) {
                for (int c = 0; c < p; c++) {
                    double d = row[c] - center[c];
                    scale[c] += d * d;
                }
            }
            for (int c = 0; c < p; c++) {
                scale[c] = java.lang.Math.sqrt(scale[c] / (trainx.length - 1));
                if (scale[c] == 0.0) {
                    scale[c] = 1.0; // constant column in this fold: center it only
                }
            }
            // Put the scaled values into fresh row arrays. The slices may share
            // row arrays with datax, which must stay unscaled for the other folds.
            double[][][] parts = {trainx, testx};
            for (double[][] part : parts) {
                for (int r = 0; r < part.length; r++) {
                    double[] z = new double[p];
                    for (int c = 0; c < p; c++) {
                        z[c] = (part[r][c] - center[c]) / scale[c];
                    }
                    part[r] = z;
                }
            }

            SVR<double[]> svr = new SVR<>(trainx, trainy, new PolynomialKernel(3, 1.0, 1.0), 0.1, 1.0);

            for (int j = 0; j < testx.length; j++) {
                double r = testy[j] - svr.predict(testx[j]);
                rss += r * r;
            }
        }

        System.out.println("10-CV RMSE = " + Math.sqrt(rss / n));
    } catch (Exception ex) {
        System.err.println(ex);
    }
}
 
开发者ID:takun2s,项目名称:smile_1.5.0_java7,代码行数:39,代码来源:SVRTest.java

示例10: testCPU

import smile.validation.CrossValidation; //导入依赖的package包/类
/**
 * Test of learn method, of class RidgeRegression.
 *
 * <p>Runs 10-fold cross validation of ridge regression on weka/cpu.arff
 * (response column 6) and prints the cross-validated mean squared error.
 * Any exception (for example missing test data) is printed to stderr rather
 * than failing the test.
 */
@Test
public void testCPU() {
    System.out.println("CPU");
    ArffParser parser = new ArffParser();
    parser.setResponseIndex(6);
    try {
        AttributeDataset data = parser.parse(smile.data.parser.IOUtils.getTestDataFile("weka/cpu.arff"));
        double[][] datax = data.toArray(new double[data.size()][]);
        double[] datay = data.toArray(new double[data.size()]);

        int n = datax.length;
        int k = 10;

        CrossValidation cv = new CrossValidation(n, k);
        double rss = 0.0;
        for (int i = 0; i < k; i++) {
            double[][] trainx = Math.slice(datax, cv.train[i]);
            double[] trainy = Math.slice(datay, cv.train[i]);
            double[][] testx = Math.slice(datax, cv.test[i]);
            double[] testy = Math.slice(datay, cv.test[i]);

            // 10.0 is presumably the ridge shrinkage (lambda) -- confirm against RidgeRegression.
            RidgeRegression ridge = new RidgeRegression(trainx, trainy, 10.0);

            for (int j = 0; j < testx.length; j++) {
                double r = testy[j] - ridge.predict(testx[j]);
                rss += r * r;
            }
        }

        System.out.println("10-CV MSE = " + rss / n);
    } catch (Exception ex) {
        System.err.println(ex);
    }
}
 
开发者ID:takun2s,项目名称:smile_1.5.0_java7,代码行数:38,代码来源:RidgeRegressionTest.java

示例11: test

import smile.validation.CrossValidation; //导入依赖的package包/类
/**
 * Cross-validates a regression RandomForest on one ARFF test dataset and
 * prints the 10-fold RMSE and mean absolute deviation over all held-out
 * samples. Each fold also prints the forest's out-of-bag {@code error()}.
 *
 * <p>Any exception (for example missing test data) is printed to stderr
 * rather than propagated, so the caller does not fail on it.
 *
 * @param dataset  label printed before the run
 * @param url      test-data path handed to {@code IOUtils.getTestDataFile}
 * @param response column index of the response variable
 */
public void test(String dataset, String url, int response) {
    System.out.println(dataset);
    ArffParser parser = new ArffParser();
    parser.setResponseIndex(response);
    try {
        AttributeDataset data = parser.parse(smile.data.parser.IOUtils.getTestDataFile(url));
        double[] datay = data.toArray(new double[data.size()]);
        double[][] datax = data.toArray(new double[data.size()][]);

        int n = datax.length;
        int k = 10;

        CrossValidation cv = new CrossValidation(n, k);
        double rss = 0.0;
        double ad = 0.0;
        for (int i = 0; i < k; i++) {
            double[][] trainx = Math.slice(datax, cv.train[i]);
            double[] trainy = Math.slice(datay, cv.train[i]);
            double[][] testx = Math.slice(datax, cv.test[i]);
            double[] testy = Math.slice(datay, cv.test[i]);

            // Try a third of the predictors at each split, but never fewer than one.
            // Plain integer division yields 0 for datasets with fewer than 3 predictors.
            int mtry = java.lang.Math.max(1, trainx[0].length / 3);
            // 200 trees; n and 5 are presumably the per-tree node limits -- TODO confirm.
            RandomForest forest = new RandomForest(data.attributes(), trainx, trainy, 200, n, 5, mtry);
            System.out.format("OOB error rate = %.4f%n", forest.error());

            for (int j = 0; j < testx.length; j++) {
                double r = testy[j] - forest.predict(testx[j]);
                rss += r * r;
                ad += Math.abs(r);
            }
        }

        System.out.format("10-CV RMSE = %.4f \t AbsoluteDeviation = %.4f%n", Math.sqrt(rss/n), ad/n);
    } catch (Exception ex) {
        System.err.println(ex);
    }
}
 
开发者ID:takun2s,项目名称:smile_1.5.0_java7,代码行数:37,代码来源:RandomForestTest.java

示例12: testCPU

import smile.validation.CrossValidation; //导入依赖的package包/类
/**
 * Test of learn method, of class RBFNetwork.
 *
 * <p>Runs 10-fold cross validation on weka/cpu.arff (response column 6). In
 * each fold, 20 Gaussian radial basis centers are learned from the training
 * rows and an RBF network is trained on them. The cross-validated mean
 * squared error is printed. Features are standardized inside each fold, using
 * statistics computed on that fold's training rows only. Standardizing the
 * whole dataset before splitting would leak held-out information.
 *
 * <p>Any exception (for example missing test data) is printed to stderr
 * rather than failing the test.
 */
@Test
public void testCPU() {
    System.out.println("CPU");
    ArffParser parser = new ArffParser();
    parser.setResponseIndex(6);
    try {
        AttributeDataset data = parser.parse(smile.data.parser.IOUtils.getTestDataFile("weka/cpu.arff"));
        double[] datay = data.toArray(new double[data.size()]);
        double[][] datax = data.toArray(new double[data.size()][]);

        int n = datax.length;
        int k = 10;

        CrossValidation cv = new CrossValidation(n, k);
        double rss = 0.0;
        for (int i = 0; i < k; i++) {
            double[][] trainx = Math.slice(datax, cv.train[i]);
            double[] trainy = Math.slice(datay, cv.train[i]);
            double[][] testx = Math.slice(datax, cv.test[i]);
            double[] testy = Math.slice(datay, cv.test[i]);

            // Fit centering and scaling on the training fold only.
            int p = trainx[0].length;
            double[] center = new double[p];
            double[] scale = new double[p];
            for (double[] row : trainx) {
                for (int c = 0; c < p; c++) {
                    center[c] += row[c];
                }
            }
            for (int c = 0; c < p; c++) {
                center[c] /= trainx.length;
            }
            for (double[] row : trainx) {
                for (int c = 0; c < p; c++) {
                    double d = row[c] - center[c];
                    scale[c] += d * d;
                }
            }
            for (int c = 0; c < p; c++) {
                scale[c] = java.lang.Math.sqrt(scale[c] / (trainx.length - 1));
                if (scale[c] == 0.0) {
                    scale[c] = 1.0; // constant column in this fold: center it only
                }
            }
            // Put the scaled values into fresh row arrays. The slices may share
            // row arrays with datax, which must stay unscaled for the other folds.
            double[][][] parts = {trainx, testx};
            for (double[][] part : parts) {
                for (int r = 0; r < part.length; r++) {
                    double[] z = new double[p];
                    for (int c = 0; c < p; c++) {
                        z[c] = (part[r][c] - center[c]) / scale[c];
                    }
                    part[r] = z;
                }
            }

            double[][] centers = new double[20][];
            RadialBasisFunction[] basis = SmileUtils.learnGaussianRadialBasis(trainx, centers, 5.0);
            RBFNetwork<double[]> rbf = new RBFNetwork<>(trainx, trainy, new EuclideanDistance(), basis, centers);

            for (int j = 0; j < testx.length; j++) {
                double r = testy[j] - rbf.predict(testx[j]);
                rss += r * r;
            }
        }

        System.out.println("10-CV MSE = " + rss / n);
    } catch (Exception ex) {
        System.err.println(ex);
    }
}
 
开发者ID:takun2s,项目名称:smile_1.5.0_java7,代码行数:41,代码来源:RBFNetworkTest.java

示例13: test2DPlanes

import smile.validation.CrossValidation; //导入依赖的package包/类
/**
 * Test of learn method, of class RBFNetwork, on the 2dplanes regression data
 * (response column 10).
 *
 * <p>Features are used unscaled. In each of 10 folds, 20 Gaussian radial
 * basis centers are learned from the training rows and an RBF network is
 * trained on them. The cross-validated mean squared error is printed.
 * Exceptions are printed to stderr instead of failing the test.
 */
@Test
public void test2DPlanes() {
    System.out.println("2dplanes");
    ArffParser arff = new ArffParser();
    arff.setResponseIndex(10);
    try {
        AttributeDataset table = arff.parse(smile.data.parser.IOUtils.getTestDataFile("weka/regression/2dplanes.arff"));
        double[] response = table.toArray(new double[table.size()]);
        double[][] features = table.toArray(new double[table.size()][]);

        final int folds = 10;
        int size = features.length;
        CrossValidation split = new CrossValidation(size, folds);

        double sumSquared = 0.0;
        for (int fold = 0; fold < folds; fold++) {
            int[] trainIdx = split.train[fold];
            int[] testIdx = split.test[fold];
            double[][] fitX = Math.slice(features, trainIdx);
            double[] fitY = Math.slice(response, trainIdx);
            double[][] evalX = Math.slice(features, testIdx);
            double[] evalY = Math.slice(response, testIdx);

            double[][] centers = new double[20][];
            RadialBasisFunction[] basis = SmileUtils.learnGaussianRadialBasis(fitX, centers, 5.0);
            RBFNetwork<double[]> net = new RBFNetwork<>(fitX, fitY, new EuclideanDistance(), basis, centers);

            for (int row = 0; row < evalX.length; row++) {
                double residual = evalY[row] - net.predict(evalX[row]);
                sumSquared += residual * residual;
            }
        }

        System.out.println("10-CV MSE = " + sumSquared / size);
    } catch (Exception ex) {
        System.err.println(ex);
    }
}
 
开发者ID:takun2s,项目名称:smile_1.5.0_java7,代码行数:41,代码来源:RBFNetworkTest.java

示例14: testBank32nh

import smile.validation.CrossValidation; //导入依赖的package包/类
/**
 * Test of learn method, of class RBFNetwork, on the bank32nh regression data
 * (response column 31).
 *
 * <p>In each of 10 folds, 20 Gaussian radial basis centers are learned from
 * the training rows and an RBF network is trained on them. The
 * cross-validated mean squared error is printed. Features are standardized
 * inside each fold, using statistics computed on that fold's training rows
 * only. Standardizing the whole dataset before splitting would leak held-out
 * information.
 *
 * <p>Any exception (for example missing test data) is printed to stderr
 * rather than failing the test.
 */
@Test
public void testBank32nh() {
    System.out.println("bank32nh");
    ArffParser parser = new ArffParser();
    parser.setResponseIndex(31);
    try {
        AttributeDataset data = parser.parse(smile.data.parser.IOUtils.getTestDataFile("weka/regression/bank32nh.arff"));
        double[] datay = data.toArray(new double[data.size()]);
        double[][] datax = data.toArray(new double[data.size()][]);

        int n = datax.length;
        int k = 10;

        CrossValidation cv = new CrossValidation(n, k);
        double rss = 0.0;
        for (int i = 0; i < k; i++) {
            double[][] trainx = Math.slice(datax, cv.train[i]);
            double[] trainy = Math.slice(datay, cv.train[i]);
            double[][] testx = Math.slice(datax, cv.test[i]);
            double[] testy = Math.slice(datay, cv.test[i]);

            // Fit centering and scaling on the training fold only.
            int p = trainx[0].length;
            double[] center = new double[p];
            double[] scale = new double[p];
            for (double[] row : trainx) {
                for (int c = 0; c < p; c++) {
                    center[c] += row[c];
                }
            }
            for (int c = 0; c < p; c++) {
                center[c] /= trainx.length;
            }
            for (double[] row : trainx) {
                for (int c = 0; c < p; c++) {
                    double d = row[c] - center[c];
                    scale[c] += d * d;
                }
            }
            for (int c = 0; c < p; c++) {
                scale[c] = java.lang.Math.sqrt(scale[c] / (trainx.length - 1));
                if (scale[c] == 0.0) {
                    scale[c] = 1.0; // constant column in this fold: center it only
                }
            }
            // Put the scaled values into fresh row arrays. The slices may share
            // row arrays with datax, which must stay unscaled for the other folds.
            double[][][] parts = {trainx, testx};
            for (double[][] part : parts) {
                for (int r = 0; r < part.length; r++) {
                    double[] z = new double[p];
                    for (int c = 0; c < p; c++) {
                        z[c] = (part[r][c] - center[c]) / scale[c];
                    }
                    part[r] = z;
                }
            }

            double[][] centers = new double[20][];
            RadialBasisFunction[] basis = SmileUtils.learnGaussianRadialBasis(trainx, centers, 5.0);
            RBFNetwork<double[]> rbf = new RBFNetwork<>(trainx, trainy, new EuclideanDistance(), basis, centers);

            for (int j = 0; j < testx.length; j++) {
                double r = testy[j] - rbf.predict(testx[j]);
                rss += r * r;
            }
        }

        System.out.println("10-CV MSE = " + rss / n);
    } catch (Exception ex) {
        System.err.println(ex);
    }
}
 
开发者ID:takun2s,项目名称:smile_1.5.0_java7,代码行数:41,代码来源:RBFNetworkTest.java

示例15: testCPU

import smile.validation.CrossValidation; //导入依赖的package包/类
/**
 * Test of learn method, of class OLS.
 *
 * <p>Runs 10-fold cross validation of ordinary least squares regression on
 * weka/cpu.arff (response column 6) and prints the cross-validated mean
 * squared error. Any exception (for example missing test data) is printed to
 * stderr rather than failing the test.
 */
@Test
public void testCPU() {
    System.out.println("CPU");
    ArffParser parser = new ArffParser();
    parser.setResponseIndex(6);
    try {
        AttributeDataset data = parser.parse(smile.data.parser.IOUtils.getTestDataFile("weka/cpu.arff"));

        double[][] datax = data.toArray(new double[data.size()][]);
        double[] datay = data.toArray(new double[data.size()]);

        int n = datax.length;
        int k = 10;

        CrossValidation cv = new CrossValidation(n, k);
        double rss = 0.0;
        for (int i = 0; i < k; i++) {
            double[][] trainx = Math.slice(datax, cv.train[i]);
            double[] trainy = Math.slice(datay, cv.train[i]);
            double[][] testx = Math.slice(datax, cv.test[i]);
            double[] testy = Math.slice(datay, cv.test[i]);

            OLS linear = new OLS(trainx, trainy);

            for (int j = 0; j < testx.length; j++) {
                double r = testy[j] - linear.predict(testx[j]);
                rss += r * r;
            }
        }

        // Every sample is held out exactly once, so rss / n is the CV mean squared error.
        System.out.println("MSE = " + rss / n);
    } catch (Exception ex) {
        System.err.println(ex);
    }
}
 
开发者ID:takun2s,项目名称:smile_1.5.0_java7,代码行数:39,代码来源:OLSTest.java


注：本文中的smile.validation.CrossValidation类示例由纯净天空整理自GitHub/MSDocs等开源代码及文档管理平台，相关代码片段筛选自各路编程大神贡献的开源项目，源码版权归原作者所有，传播和使用请参考对应项目的License；未经允许，请勿转载。