

Java SplitTestAndTrain Class Code Examples

This article collects typical usage examples of the Java class org.nd4j.linalg.dataset.SplitTestAndTrain. If you are wondering what SplitTestAndTrain does, how to use it, or what real-world usage looks like, the selected examples below may help.


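SplitTestAndTrain belongs to the org.nd4j.linalg.dataset package. It holds the two parts returned by DataSet.splitTestAndTrain(...): getTrain() returns the training DataSet and getTest() returns the held-out test DataSet. The double overload takes the fraction of examples used for training (0.80 means 80%); the int overload takes the number of training examples (Example 3 calls this argument trainingDataSize), even though some declarations name the parameter numHoldout. The 14 examples below are sorted by popularity by default.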

Example 1: createDataSource

import org.nd4j.linalg.dataset.SplitTestAndTrain; // import the required package/class
private void createDataSource() throws IOException, InterruptedException {
    //First: get the dataset using the record reader. CSVRecordReader handles loading/parsing
    int numLinesToSkip = 0;
    String delimiter = ",";
    RecordReader recordReader = new CSVRecordReader(numLinesToSkip, delimiter);
    recordReader.initialize(new InputStreamInputSplit(dataFile));

    //Second: the RecordReaderDataSetIterator handles conversion to DataSet objects, ready for use in neural network
    int labelIndex = 4;     //5 values in each row of the iris.txt CSV: 4 input features followed by an integer label (class) index. Labels are the 5th value (index 4) in each row
    int numClasses = 3;     //3 classes (types of iris flowers) in the iris data set. Classes have integer values 0, 1 or 2

    DataSetIterator iterator = new RecordReaderDataSetIterator(recordReader, batchSize, labelIndex, numClasses);
    DataSet allData = iterator.next();  //Reads a single batch: batchSize must be at least the number of rows to load the whole file
    allData.shuffle();

    SplitTestAndTrain testAndTrain = allData.splitTestAndTrain(0.80);  //Use 80% of data for training

    trainingData = testAndTrain.getTrain();
    testData = testAndTrain.getTest();

    //We need to normalize our data. We'll use NormalizerStandardize (which gives us mean 0, unit variance):
    DataNormalization normalizer = new NormalizerStandardize();
    normalizer.fit(trainingData);           //Collect the statistics (mean/stdev) from the training data. This does not modify the input data
    normalizer.transform(trainingData);     //Apply normalization to the training data
    normalizer.transform(testData);         //Apply normalization to the test data. This is using statistics calculated from the *training* set
}
 
Developer ID: mccorby, Project: FederatedAndroidTrainer, Lines of code: 27, Source: IrisFileDataSource.java
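
splitTestAndTrain(0.80) turns a training fraction into a row count. The sketch below mimics, under my understanding of ND4J's DataSet, how that works: fraction times the number of examples is truncated to an int, the first rows become the training set and the rest the test set, with a seeded shuffle standing in for allData.shuffle(). It works on row indices rather than an INDArray; FractionSplitSketch and shuffleAndSplit are hypothetical names, and the guard that keeps at least one row on each side is my own addition.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

/** Hypothetical helper: index-based sketch of a shuffled, fraction-based train/test split. */
final class FractionSplitSketch {

    /** Returns [trainRows, testRows] for row indices 0..numExamples-1. */
    static List<List<Integer>> shuffleAndSplit(int numExamples, double fractionTrain, long seed) {
        if (numExamples < 2) {
            throw new IllegalArgumentException("need at least 2 examples to split, got " + numExamples);
        }
        if (fractionTrain <= 0.0 || fractionTrain >= 1.0) {
            throw new IllegalArgumentException("fractionTrain must be in (0, 1), got " + fractionTrain);
        }
        List<Integer> rows = new ArrayList<>();
        for (int i = 0; i < numExamples; i++) {
            rows.add(i);
        }
        // Counterpart of allData.shuffle(): randomize the row order with a fixed seed
        Collections.shuffle(rows, new Random(seed));

        // Truncate fraction * n to an int, keeping at least one row on each side (guard added here)
        int numTrain = (int) (fractionTrain * numExamples);
        numTrain = Math.min(Math.max(numTrain, 1), numExamples - 1);

        List<List<Integer>> parts = new ArrayList<>();
        parts.add(new ArrayList<>(rows.subList(0, numTrain)));           // training rows
        parts.add(new ArrayList<>(rows.subList(numTrain, numExamples))); // held-out test rows
        return parts;
    }

    public static void main(String[] args) {
        List<List<Integer>> parts = shuffleAndSplit(150, 0.80, 42L);
        System.out.println("train=" + parts.get(0).size() + ", test=" + parts.get(1).size());
    }
}
```

For 150 rows (the size of the iris file) and a fraction of 0.80 this yields 120 training rows and 30 test rows. Shuffling first matters because the split itself just cuts the list at numTrain.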

Example 2: createDataSource

import org.nd4j.linalg.dataset.SplitTestAndTrain; // import the required package/class
private void createDataSource() throws IOException, InterruptedException {
    //First: get the dataset using the record reader. CSVRecordReader handles loading/parsing
    int numLinesToSkip = 0;
    String delimiter = ",";
    RecordReader recordReader = new CSVRecordReader(numLinesToSkip, delimiter);
    recordReader.initialize(new InputStreamInputSplit(dataFile));

    //Second: the RecordReaderDataSetIterator handles conversion to DataSet objects, ready for use in neural network
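    int labelIndex = 11;    //Column 11 holds the target value

    //labelIndexFrom = labelIndexTo = 11 with regression = true: column 11 is a real-valued target, not a class index
    DataSetIterator iterator = new RecordReaderDataSetIterator(recordReader, batchSize, labelIndex, labelIndex, true);
    DataSet allData = iterator.next();  //Reads a single batch: batchSize must be at least the number of rows to load the whole file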

    SplitTestAndTrain testAndTrain = allData.splitTestAndTrain(0.80);  //Use 80% of data for training

    trainingData = testAndTrain.getTrain();
    testData = testAndTrain.getTest();

    //We need to normalize our data. We'll use NormalizerStandardize (which gives us mean 0, unit variance):
    DataNormalization normalizer = new NormalizerStandardize();
    normalizer.fit(trainingData);           //Collect the statistics (mean/stdev) from the training data. This does not modify the input data
    normalizer.transform(trainingData);     //Apply normalization to the training data
    normalizer.transform(testData);         //Apply normalization to the test data. This is using statistics calculated from the *training* set
}
 
Developer ID: mccorby, Project: FederatedAndroidTrainer, Lines of code: 25, Source: DiabetesFileDataSource.java
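
Example 2 passes the same column twice plus true to RecordReaderDataSetIterator, which is the (recordReader, batchSize, labelIndexFrom, labelIndexTo, regression) form: with regression = true, column 11 is treated as a real-valued target rather than an integer class index. The stdlib-only sketch below illustrates that column split for a single CSV line. It is a simplified stand-in for what the iterator does, not DL4J code, and RegressionRowSketch and splitRow are hypothetical names.

```java
import java.util.Arrays;
import java.util.regex.Pattern;

/** Hypothetical illustration: splitting one CSV record into features and regression targets. */
final class RegressionRowSketch {

    /**
     * Returns {features, labels}. Columns labelFrom..labelTo (inclusive) are kept as real-valued
     * targets; every other column becomes a feature.
     */
    static double[][] splitRow(String csvLine, String delimiter, int labelFrom, int labelTo) {
        String[] cells = csvLine.split(Pattern.quote(delimiter));
        if (labelFrom < 0 || labelTo < labelFrom || labelTo >= cells.length) {
            throw new IllegalArgumentException("label columns out of range for " + cells.length + " cells");
        }
        int numLabels = labelTo - labelFrom + 1;
        double[] features = new double[cells.length - numLabels];
        double[] labels = new double[numLabels];
        int f = 0;
        for (int col = 0; col < cells.length; col++) {
            double value = Double.parseDouble(cells[col].trim());
            if (col >= labelFrom && col <= labelTo) {
                labels[col - labelFrom] = value; // regression target: used as-is, no one-hot encoding
            } else {
                features[f++] = value;
            }
        }
        return new double[][] {features, labels};
    }

    public static void main(String[] args) {
        // Synthetic 12-column row: columns 0..10 are features, column 11 is the target (labelIndex = 11)
        double[][] parts = splitRow("1,2,3,4,5,6,7,8,9,10,11,42", ",", 11, 11);
        System.out.println(Arrays.toString(parts[0]) + " -> " + Arrays.toString(parts[1]));
    }
}
```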

Example 3: split

import org.nd4j.linalg.dataset.SplitTestAndTrain; // import the required package/class
@Override
public IrisData split(int trainingDataSize)
{
    // trainingDataSize examples go to the training set; the remaining examples become the test set
    final SplitTestAndTrain splits = dataset.splitTestAndTrain(trainingDataSize, new Random(seed));
    trainingData = splits.getTrain();
    testingData = splits.getTest();

    return this;
}
 
Developer ID: amrabed, Project: DL4J, Lines of code: 10, Source: IrisData.java

Example 4: testIris2

import org.nd4j.linalg.dataset.SplitTestAndTrain; // import the required package/class
@Test
public void testIris2() {
    NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder()
                    .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                    .updater(new Sgd(1e-1))
                    .layer(new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder().nIn(4).nOut(3)
                                    .weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX)
                                    .lossFunction(LossFunctions.LossFunction.MCXENT).build())
                    .build();

    int numParams = conf.getLayer().initializer().numParams(conf);
    INDArray params = Nd4j.create(1, numParams);
    OutputLayer l = (OutputLayer) conf.getLayer().instantiate(conf,
                    Collections.<IterationListener>singletonList(new ScoreIterationListener(1)), 0, params, true);
    l.setBackpropGradientsViewArray(Nd4j.create(1, params.length()));
    DataSetIterator iter = new IrisDataSetIterator(150, 150);


    DataSet next = iter.next();
    next.shuffle();
    SplitTestAndTrain trainTest = next.splitTestAndTrain(110);  // 110 examples for training, the remaining 40 for testing
    trainTest.getTrain().normalizeZeroMeanZeroUnitVariance();
    for( int i=0; i<10; i++ ) {
        l.fit(trainTest.getTrain());
    }


    DataSet test = trainTest.getTest();
    test.normalizeZeroMeanZeroUnitVariance();
    Evaluation eval = new Evaluation();
    INDArray output = l.output(test.getFeatureMatrix());
    eval.eval(test.getLabels(), output);
    log.info("Score " + eval.stats());


}
 
Developer ID: deeplearning4j, Project: deeplearning4j, Lines of code: 37, Source: OutputLayerTest.java
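
Examples 4, 6 and 12 call normalizeZeroMeanZeroUnitVariance() separately on the training and test DataSets, so each subset is standardized with its own statistics, whereas Examples 1, 2, 9, 10 and 11 fit NormalizerStandardize on the training set only and reuse those statistics for the test set. The stdlib sketch below (a hypothetical single-column class using the population standard deviation, my own choice) shows why the distinction matters: a shifted test set looks perfectly centered when standardized with its own statistics.

```java
import java.util.Arrays;

/** Hypothetical single-column standardizer mirroring the fit-on-train / transform-both pattern. */
final class StandardizeSketch {
    private double mean;
    private double std;

    /** Collects mean and population standard deviation; does not modify the input. */
    StandardizeSketch fit(double[] column) {
        mean = Arrays.stream(column).average().orElse(0.0);
        double var = Arrays.stream(column).map(x -> (x - mean) * (x - mean)).average().orElse(0.0);
        std = Math.sqrt(var);
        return this;
    }

    /** Returns (x - mean) / std using the statistics collected by fit(). */
    double[] transform(double[] column) {
        double safeStd = std == 0.0 ? 1.0 : std; // avoid dividing by zero for a constant column
        return Arrays.stream(column).map(x -> (x - mean) / safeStd).toArray();
    }

    public static void main(String[] args) {
        double[] train = {1.0, 2.0, 3.0, 4.0, 5.0};
        double[] test = {6.0, 7.0, 8.0}; // synthetic values, deliberately shifted upward

        double[] withTrainStats = new StandardizeSketch().fit(train).transform(test);
        double[] withOwnStats = new StandardizeSketch().fit(test).transform(test);

        System.out.println("training statistics: " + Arrays.toString(withTrainStats));
        System.out.println("own statistics:      " + Arrays.toString(withOwnStats));
    }
}
```

With the training statistics the test values stay well above zero, revealing the shift; with their own statistics they are centered on zero and the shift disappears from view.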

Example 5: testWeightsDifferent

import org.nd4j.linalg.dataset.SplitTestAndTrain; // import the required package/class
@Test
public void testWeightsDifferent() {
    Nd4j.MAX_ELEMENTS_PER_SLICE = Integer.MAX_VALUE;
    Nd4j.MAX_SLICES_TO_PRINT = Integer.MAX_VALUE;

    NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder()
                    .miniBatch(false).seed(123)
                    .updater(new AdaGrad(1e-1))
                    .layer(new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder().nIn(4).nOut(3)
                                    .weightInit(WeightInit.XAVIER)
                                    .lossFunction(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                                    .activation(Activation.SOFTMAX).build())
                    .build();

    int numParams = conf.getLayer().initializer().numParams(conf);
    INDArray params = Nd4j.create(1, numParams);
    OutputLayer o = (OutputLayer) conf.getLayer().instantiate(conf, null, 0, params, true);
    o.setBackpropGradientsViewArray(Nd4j.create(1, params.length()));


    int numSamples = 150;
    int batchSize = 150;


    DataSetIterator iter = new IrisDataSetIterator(batchSize, numSamples);
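    DataSet iris = iter.next(); // Loads all 150 examples into one DataSet the network can consume
    iris.normalizeZeroMeanZeroUnitVariance(); // Note: these statistics include the rows that later become the test set
    o.setListeners(new ScoreIterationListener(1));
    SplitTestAndTrain t = iris.splitTestAndTrain(0.8); // 80% train / 20% test, no shuffle() beforehand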
    for( int i=0; i<1000; i++ ){
        o.fit(t.getTrain());
    }
    log.info("Evaluate model....");
    Evaluation eval = new Evaluation(3);
    eval.eval(t.getTest().getLabels(), o.output(t.getTest().getFeatureMatrix(), true));
    log.info(eval.stats());

}
 
Developer ID: deeplearning4j, Project: deeplearning4j, Lines of code: 39, Source: OutputLayerTest.java
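
Example 5 calls splitTestAndTrain(0.8) without shuffling first (Examples 7 and 8 do the same). Assuming the split takes rows in their current order, as the explicit shuffle() calls in the other examples suggest, class-ordered input produces a skewed test set. The standard UCI iris file lists its 150 rows grouped by class (50 each); whether IrisDataSetIterator returns them in that order is not shown in the snippet. The stdlib sketch below uses synthetic labels in class order; SortedSplitSketch and its methods are hypothetical.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

/** Hypothetical sketch: class counts in the test part of a first-N split, with and without shuffling. */
final class SortedSplitSketch {

    /** Counts how many test rows (positions numTrain..end) carry each class label. */
    static int[] testClassCounts(List<Integer> labels, int numTrain, int numClasses) {
        int[] counts = new int[numClasses];
        for (int label : labels.subList(numTrain, labels.size())) {
            counts[label]++;
        }
        return counts;
    }

    /** Synthetic labels ordered by class: perClass copies of 0, then of 1, and so on. */
    static List<Integer> sortedLabels(int numClasses, int perClass) {
        List<Integer> labels = new ArrayList<>();
        for (int c = 0; c < numClasses; c++) {
            for (int i = 0; i < perClass; i++) {
                labels.add(c);
            }
        }
        return labels;
    }

    public static void main(String[] args) {
        List<Integer> labels = sortedLabels(3, 50);
        int numTrain = (int) (0.8 * labels.size()); // 120 of 150

        int[] unshuffled = testClassCounts(labels, numTrain, 3);

        List<Integer> shuffled = new ArrayList<>(labels);
        Collections.shuffle(shuffled, new Random(123));
        int[] afterShuffle = testClassCounts(shuffled, numTrain, 3);

        System.out.printf("unshuffled test counts: %d %d %d%n", unshuffled[0], unshuffled[1], unshuffled[2]);
        System.out.printf("shuffled test counts:   %d %d %d%n", afterShuffle[0], afterShuffle[1], afterShuffle[2]);
    }
}
```

With class-ordered labels the first 120 rows hold 50, 50 and 20 examples of classes 0, 1 and 2, so all 30 test rows belong to class 2 and the evaluation never sees the other two classes. After shuffling, the exact counts depend on the seed but are typically spread across all three classes.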

Example 6: testIris

import org.nd4j.linalg.dataset.SplitTestAndTrain; // import the required package/class
@Test
public void testIris() {
    NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder()
                    .optimizationAlgo(OptimizationAlgorithm.LINE_GRADIENT_DESCENT).updater(new Sgd(1e-1))
                    .layer(new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder().nIn(4).nOut(3)
                                    .weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX)
                                    .lossFunction(LossFunctions.LossFunction.MCXENT).build())
                    .build();

    int numParams = conf.getLayer().initializer().numParams(conf);
    INDArray params = Nd4j.create(1, numParams);
    OutputLayer l = (OutputLayer) conf.getLayer().instantiate(conf,
                    Collections.<IterationListener>singletonList(new ScoreIterationListener(1)), 0, params, true);
    l.setBackpropGradientsViewArray(Nd4j.create(1, params.length()));
    DataSetIterator iter = new IrisDataSetIterator(150, 150);


    DataSet next = iter.next();
    next.shuffle();
    SplitTestAndTrain trainTest = next.splitTestAndTrain(110);
    trainTest.getTrain().normalizeZeroMeanZeroUnitVariance();
    for( int i=0; i<5; i++ ) {
        l.fit(trainTest.getTrain());
    }


    DataSet test = trainTest.getTest();
    test.normalizeZeroMeanZeroUnitVariance();
    Evaluation eval = new Evaluation();
    INDArray output = l.output(test.getFeatureMatrix());
    eval.eval(test.getLabels(), output);
    log.info("Score " + eval.stats());


}
 
Developer ID: deeplearning4j, Project: deeplearning4j, Lines of code: 36, Source: OutputLayerTest.java

Example 7: testBatchNorm

import org.nd4j.linalg.dataset.SplitTestAndTrain; // import the required package/class
@Test
public void testBatchNorm() {
    Nd4j.getRandom().setSeed(123);
    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                    .optimizationAlgo(OptimizationAlgorithm.LINE_GRADIENT_DESCENT).seed(123).list()
                    .layer(0, new DenseLayer.Builder().nIn(4).nOut(3).weightInit(WeightInit.XAVIER)
                                    .activation(Activation.TANH).build())
                    .layer(1, new DenseLayer.Builder().nIn(3).nOut(2).weightInit(WeightInit.XAVIER)
                                    .activation(Activation.TANH).build())
                    .layer(2, new BatchNormalization.Builder().nOut(2).build())
                    .layer(3, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(
                                    LossFunctions.LossFunction.MCXENT).weightInit(WeightInit.XAVIER)
                                                    .activation(Activation.SOFTMAX).nIn(2).nOut(3).build())
                    .backprop(true).pretrain(false).build();


    MultiLayerNetwork network = new MultiLayerNetwork(conf);
    network.init();
    network.setListeners(new ScoreIterationListener(1));

    DataSetIterator iter = new IrisDataSetIterator(150, 150);

    DataSet next = iter.next();
    next.normalizeZeroMeanZeroUnitVariance();
    SplitTestAndTrain trainTest = next.splitTestAndTrain(110);
    network.setLabels(trainTest.getTrain().getLabels());
    network.init();
    for( int i=0; i<5; i++ ) {
        network.fit(trainTest.getTrain());
    }

}
 
Developer ID: deeplearning4j, Project: deeplearning4j, Lines of code: 33, Source: MultiLayerTest.java
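
Example 7 places a BatchNormalization layer after the second dense layer. As background, the training-time forward pass of batch normalization standardizes each activation over the minibatch and then applies a learned scale gamma and shift beta. The stdlib sketch below covers one activation column; the epsilon value and names are assumptions, it omits the running mean and variance used at inference time, and it is not DL4J's implementation.

```java
import java.util.Arrays;

/** Hypothetical sketch: training-mode batch normalization for one activation across a minibatch. */
final class BatchNormSketch {

    /**
     * y_i = gamma * (x_i - mean) / sqrt(var + eps) + beta, with mean and (population) variance
     * taken over the minibatch values x.
     */
    static double[] forward(double[] x, double gamma, double beta, double eps) {
        double mean = Arrays.stream(x).average().orElse(0.0);
        double var = Arrays.stream(x).map(v -> (v - mean) * (v - mean)).average().orElse(0.0);
        double invStd = 1.0 / Math.sqrt(var + eps);
        return Arrays.stream(x).map(v -> gamma * (v - mean) * invStd + beta).toArray();
    }

    public static void main(String[] args) {
        // One hidden unit's activations for a minibatch of four examples (synthetic values)
        double[] activations = {2.0, 4.0, 6.0, 8.0};
        System.out.println(Arrays.toString(forward(activations, 1.0, 0.0, 1e-5)));
    }
}
```

With gamma = 1 and beta = 0 the output has mean 0 and variance just under 1 (because of eps); learned gamma and beta let the network undo that scaling where it helps.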

Example 8: testBackProp

import org.nd4j.linalg.dataset.SplitTestAndTrain; // import the required package/class
@Test
public void testBackProp() {
    Nd4j.getRandom().setSeed(123);
    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                    .optimizationAlgo(OptimizationAlgorithm.LINE_GRADIENT_DESCENT).seed(123).list()
                    .layer(0, new DenseLayer.Builder().nIn(4).nOut(3).weightInit(WeightInit.XAVIER)
                                    .activation(Activation.TANH).build())
                    .layer(1, new DenseLayer.Builder().nIn(3).nOut(2).weightInit(WeightInit.XAVIER)
                                    .activation(Activation.TANH).build())
                    .layer(2, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(
                                    LossFunctions.LossFunction.MCXENT).weightInit(WeightInit.XAVIER)
                                                    .activation(Activation.SOFTMAX).nIn(2).nOut(3).build())
                    .backprop(true).pretrain(false).build();


    MultiLayerNetwork network = new MultiLayerNetwork(conf);
    network.init();
    network.setListeners(new ScoreIterationListener(1));

    DataSetIterator iter = new IrisDataSetIterator(150, 150);

    DataSet next = iter.next();
    next.normalizeZeroMeanZeroUnitVariance();
    SplitTestAndTrain trainTest = next.splitTestAndTrain(110);
    network.setInput(trainTest.getTrain().getFeatureMatrix());
    network.setLabels(trainTest.getTrain().getLabels());
    network.init();
    for( int i=0; i<5; i++ ) {
        network.fit(trainTest.getTrain());
    }

    DataSet test = trainTest.getTest();
    Evaluation eval = new Evaluation();
    INDArray output = network.output(test.getFeatureMatrix());
    eval.eval(test.getLabels(), output);
    log.info("Score " + eval.stats());
}
 
Developer ID: deeplearning4j, Project: deeplearning4j, Lines of code: 38, Source: MultiLayerTest.java
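
Examples 7 and 8 pair a SOFTMAX output activation with the MCXENT (multi-class cross-entropy) loss. For that pairing, the gradient of the loss with respect to the pre-softmax outputs reduces to p - y, which is the error signal backpropagation sends into the hidden layers. The stdlib sketch below shows this for a single example; it is a textbook illustration with hypothetical names, not DL4J's loss implementation.

```java
/** Hypothetical sketch: softmax, multi-class cross-entropy, and its gradient for one example. */
final class SoftmaxXentSketch {

    /** Numerically stable softmax: subtract the largest logit before exponentiating. */
    static double[] softmax(double[] logits) {
        double max = Double.NEGATIVE_INFINITY;
        for (double z : logits) {
            max = Math.max(max, z);
        }
        double sum = 0.0;
        double[] p = new double[logits.length];
        for (int i = 0; i < logits.length; i++) {
            p[i] = Math.exp(logits[i] - max);
            sum += p[i];
        }
        for (int i = 0; i < p.length; i++) {
            p[i] /= sum;
        }
        return p;
    }

    /** Cross-entropy -sum_k y_k * log(p_k) for a one-hot (or probability) label vector y. */
    static double crossEntropy(double[] p, double[] y) {
        double loss = 0.0;
        for (int k = 0; k < p.length; k++) {
            if (y[k] != 0.0) {
                loss -= y[k] * Math.log(p[k]);
            }
        }
        return loss;
    }

    /** Gradient of crossEntropy(softmax(z), y) with respect to z; assumes the entries of y sum to 1. */
    static double[] gradient(double[] p, double[] y) {
        double[] g = new double[p.length];
        for (int k = 0; k < p.length; k++) {
            g[k] = p[k] - y[k];
        }
        return g;
    }

    public static void main(String[] args) {
        double[] p = softmax(new double[] {1.0, -0.5, 2.0});
        double[] y = {0.0, 0.0, 1.0};
        System.out.println("loss = " + crossEntropy(p, y));
    }
}
```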

Example 9: main

import org.nd4j.linalg.dataset.SplitTestAndTrain; // import the required package/class
public static void main(String[] args) throws Exception {

        //First: get the dataset using the record reader. CSVRecordReader handles loading/parsing
        int numLinesToSkip = 0;
        char delimiter = ',';
        RecordReader recordReader = new CSVRecordReader(numLinesToSkip,delimiter);
        recordReader.initialize(new FileSplit(new File("src/main/resources/DL4J_Resources/iris.txt")));

        //Second: the RecordReaderDataSetIterator handles conversion to DataSet objects, ready for use in neural network
        int labelIndex = 4;     //5 values in each row of the iris.txt CSV: 4 input features followed by an integer label (class) index. Labels are the 5th value (index 4) in each row
        int numClasses = 3;     //3 classes (types of iris flowers) in the iris data set. Classes have integer values 0, 1 or 2
        int batchSize = 150;    //Iris data set: 150 examples total. We are loading all of them into one DataSet (not recommended for large data sets)

        DataSetIterator iterator = new RecordReaderDataSetIterator(recordReader,batchSize,labelIndex,numClasses);
        DataSet allData = iterator.next();
        allData.shuffle();
        SplitTestAndTrain testAndTrain = allData.splitTestAndTrain(0.65);  //Use 65% of data for training

        DataSet trainingData = testAndTrain.getTrain();
        DataSet testData = testAndTrain.getTest();

        //We need to normalize our data. We'll use NormalizerStandardize (which gives us mean 0, unit variance):
        DataNormalization normalizer = new NormalizerStandardize();
        normalizer.fit(trainingData);           //Collect the statistics (mean/stdev) from the training data. This does not modify the input data
        normalizer.transform(trainingData);     //Apply normalization to the training data
        normalizer.transform(testData);         //Apply normalization to the test data. This is using statistics calculated from the *training* set


        final int numInputs = 4;
        int outputNum = 3;
        int iterations = 1000;
        long seed = 6;


        log.info("Build model....");
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .seed(seed)
            .iterations(iterations)
            .activation(Activation.TANH)
            .weightInit(WeightInit.XAVIER)
            .learningRate(0.1)
            .regularization(true).l2(1e-4)
            .list()
            .layer(0, new DenseLayer.Builder().nIn(numInputs).nOut(3)
                .build())
            .layer(1, new DenseLayer.Builder().nIn(3).nOut(3)
                .build())
            .layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                .activation(Activation.SOFTMAX)
                .nIn(3).nOut(outputNum).build())
            .backprop(true).pretrain(false)
            .build();

        //run the model
        MultiLayerNetwork model = new MultiLayerNetwork(conf);
        model.init();
        model.setListeners(new ScoreIterationListener(100));

        model.fit(trainingData);

        //evaluate the model on the test set
        Evaluation eval = new Evaluation(3);
        INDArray output = model.output(testData.getFeatureMatrix());
        eval.eval(testData.getLabels(), output);
        log.info(eval.stats());
        
        //Save the model
        File locationToSave = new File("src/main/resources/generatedModels/DL4J/DL4J_Iris_Model.zip");      //Where to save the network. Note: the file is in .zip format - can be opened externally
        boolean saveUpdater = true;                                             //Updater: i.e., the state for Momentum, RMSProp, Adagrad etc. Save this if you want to train your network more in the future
        ModelSerializer.writeModel(model, locationToSave, saveUpdater);

        //Load the model
        MultiLayerNetwork restored = ModelSerializer.restoreMultiLayerNetwork(locationToSave);

        System.out.println("Saved and loaded parameters are equal:      " + model.params().equals(restored.params()));
        System.out.println("Saved and loaded configurations are equal:  " + model.getLayerWiseConfigurations().equals(restored.getLayerWiseConfigurations()));
    }
 
Developer ID: kaiwaehner, Project: kafka-streams-machine-learning-examples, Lines of code: 78, Source: DeepLearning4J_CSV_Model.java
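
In Examples 1, 9 and 11, labelIndex = 4 and numClasses = 3 tell RecordReaderDataSetIterator that column 4 holds an integer class index, which ends up as a one-hot label row matching the 3-unit softmax output layer. The stdlib sketch below is a simplified illustration for a single comma-separated line, with hypothetical names; it is not DL4J's code path.

```java
import java.util.Arrays;

/** Hypothetical sketch: turning the class-index column of one CSV row into a one-hot label. */
final class OneHotLabelSketch {

    /** Returns a vector of length numClasses with 1.0 at classIndex and 0.0 elsewhere. */
    static double[] oneHot(int classIndex, int numClasses) {
        if (classIndex < 0 || classIndex >= numClasses) {
            throw new IllegalArgumentException("class index " + classIndex + " outside [0, " + numClasses + ")");
        }
        double[] label = new double[numClasses];
        label[classIndex] = 1.0;
        return label;
    }

    /** Reads column labelIndex of a comma-separated row as an integer class index and one-hot encodes it. */
    static double[] labelFromRow(String csvLine, int labelIndex, int numClasses) {
        String[] cells = csvLine.split(",");
        return oneHot(Integer.parseInt(cells[labelIndex].trim()), numClasses);
    }

    public static void main(String[] args) {
        // Synthetic row in the iris.txt layout: 4 features, then the class index (0, 1 or 2)
        System.out.println(Arrays.toString(labelFromRow("5.0,3.0,1.5,0.2,2", 4, 3)));
    }
}
```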

Example 10: main

import org.nd4j.linalg.dataset.SplitTestAndTrain; // import the required package/class
public static void main(String[] args) throws IOException, InterruptedException {
    // Data source: https://www.kaggle.com/uciml/glass
    int numLinesToSkip = 1;
    String delimiter = ",";
    RecordReader recordReader = new CSVRecordReader(numLinesToSkip,delimiter);
    recordReader.initialize(new FileSplit(new ClassPathResource("glass/glass.csv").getFile()));
    int labelIndex = 9;
    int numClasses = 7;
    int batchSize = 214; // 214 examples in total, so a single batch holds the whole file

    DataSetIterator iterator = new RecordReaderDataSetIterator(recordReader, batchSize, labelIndex, numClasses);
    DataSet allData = iterator.next();
    allData.shuffle();

    SplitTestAndTrain testAndTrain = allData.splitTestAndTrain(0.8);
    DataSet trainingData = testAndTrain.getTrain();

    DataSet testData = testAndTrain.getTest();

    DataNormalization normalizer = new NormalizerStandardize();
    normalizer.fit(trainingData);
    normalizer.transform(trainingData);
    normalizer.transform(testData);

    int seed = 123;
    int numInputs = 9;
    int iterations = 1000;
    int epochs = 1;

    log.info("Construct model...");
    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .seed(seed)
            .iterations(iterations)
            .weightInit(WeightInit.XAVIER)
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .updater(Updater.NESTEROVS)
            .momentum(0.9)
            .learningRate(0.2)
            .regularization(true)
            .l2(1e-4)
            .list()
            .layer(0, new DenseLayer.Builder()
                    .nIn(numInputs)
                    .nOut(50)
                    .activation(Activation.TANH)
                    .build())
            .layer(1, new DenseLayer.Builder()
                    .nOut(100)
                    .activation(Activation.TANH)
                    .build())
            .layer(2, new OutputLayer.Builder()
                    .nOut(numClasses)
                    .lossFunction(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                    .activation(Activation.SOFTMAX)
                    .build())
            .backprop(true)
            .pretrain(false)
            .build();

    MultiLayerNetwork model = new MultiLayerNetwork(conf);
    model.init();
    model.setListeners(new ScoreIterationListener(100));
    for (int epoch = 0; epoch < epochs; epoch++) {
        model.fit(trainingData);
        log.info("*** Completed epoch {} ***", epoch);
        log.info("Evaluate model....");
        Evaluation eval = new Evaluation(numClasses);
        INDArray output = model.output(testData.getFeatureMatrix());
        eval.eval(testData.getLabels(), output);
        log.info(eval.stats());
    }

}
 
Developer ID: IsaacChanghau, Project: NeuralNetworksLite, Lines of code: 74, Source: GlassClassification.java
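
Example 10 configures Updater.NESTEROVS with momentum(0.9), learningRate(0.2) and l2(1e-4). As a rough guide to what those settings do, the sketch below applies one common formulation of Nesterov momentum (the rewrite given in the CS231n course notes) to a single scalar weight, with an L2 term added to the gradient. This is an assumption-laden illustration: DL4J's exact update order and its scaling of the L2 term may differ, and NesterovSketch is a hypothetical name.

```java
import java.util.function.DoubleUnaryOperator;

/** Hypothetical sketch: Nesterov momentum with an L2 penalty, applied to a single scalar weight. */
final class NesterovSketch {

    /**
     * Runs `steps` updates:
     *   g = grad(w) + l2 * w               (gradient of loss + 0.5 * l2 * w^2; the L2 form is an assumption)
     *   v = mu * v - lr * g
     *   w = w - mu * vPrev + (1 + mu) * v  (CS231n-style rewrite of Nesterov momentum)
     */
    static double minimize(DoubleUnaryOperator grad, double w0, double lr, double mu, double l2, int steps) {
        double w = w0;
        double v = 0.0;
        for (int t = 0; t < steps; t++) {
            double g = grad.applyAsDouble(w) + l2 * w;
            double vPrev = v;
            v = mu * v - lr * g;
            w += -mu * vPrev + (1.0 + mu) * v;
        }
        return w;
    }

    public static void main(String[] args) {
        // Toy loss (w - 3)^2 with gradient 2 (w - 3); lr, momentum and l2 copied from Example 10
        double w = minimize(x -> 2.0 * (x - 3.0), 0.0, 0.2, 0.9, 1e-4, 200);
        System.out.println("w after 200 steps: " + w);
    }
}
```

With l2 = 0 the weight settles at 3; a positive l2 pulls the optimum toward zero (for l2 = 0.5 it moves to 6 / 2.5 = 2.4), which is the shrinking effect the regularization(true).l2(...) settings ask for.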

Example 11: main

import org.nd4j.linalg.dataset.SplitTestAndTrain; // import the required package/class
public static void main(String[] args) throws Exception {
    int numLinesToSkip = 0;
    String delimiter = ",";
    RecordReader recordReader = new CSVRecordReader(numLinesToSkip,delimiter);
    recordReader.initialize(new FileSplit(new ClassPathResource("iris.txt").getFile()));

    int labelIndex = 4;
    int numClasses = 3;
    int batchSize = 150;

    DataSetIterator iterator = new RecordReaderDataSetIterator(recordReader,batchSize,labelIndex,numClasses);
    DataSet allData = iterator.next();
    allData.shuffle();
    SplitTestAndTrain testAndTrain = allData.splitTestAndTrain(0.65);

    DataSet trainingData = testAndTrain.getTrain();
    DataSet testData = testAndTrain.getTest();

    //We need to normalize our data. We'll use NormalizerStandardize (which gives us mean 0, unit variance):
    DataNormalization normalizer = new NormalizerStandardize();
    normalizer.fit(trainingData);           //Collect the statistics (mean/stdev) from the training data. This does not modify the input data
    normalizer.transform(trainingData);     //Apply normalization to the training data
    normalizer.transform(testData);         //Apply normalization to the test data. This is using statistics calculated from the *training* set

    final int numInputs = 4;
    int outputNum = 3;
    int iterations = 1000;
    long seed = 6;
    int epochs = 100;

    log.info("Build model....");
    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .seed(seed)
            .iterations(iterations)
            .weightInit(WeightInit.XAVIER)
            .learningRate(0.1)
            .regularization(true)
            .l2(1e-4)
            .list()
            .layer(0, new DenseLayer.Builder()
                    .nIn(numInputs)
                    .nOut(20)
                    .activation(Activation.TANH)
                    .build())
            .layer(1, new DenseLayer.Builder()
                    .nOut(10)
                    .activation(Activation.TANH)
                    .build())
            .layer(2, new OutputLayer.Builder()
                    .lossFunction(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                    .activation(Activation.SOFTMAX)
                    .nOut(outputNum)
                    .build())
            .backprop(true)
            .pretrain(false)
            .build();

    //run the model
    MultiLayerNetwork model = new MultiLayerNetwork(conf);
    model.init();
    model.setListeners(new ScoreIterationListener(100));
    for (int epoch = 0; epoch < epochs; epoch++) {
        model.fit(trainingData);
    }

    //evaluate the model on the test set
    Evaluation eval = new Evaluation(numClasses);
    INDArray output = model.output(testData.getFeatureMatrix());
    eval.eval(testData.getLabels(), output);
    log.info(eval.stats());
}
 
Developer ID: IsaacChanghau, Project: NeuralNetworksLite, Lines of code: 72, Source: IrisClassification.java
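
Examples 4 through 12 initialize weights with WeightInit.XAVIER. As I understand DL4J's documentation, that draws weights from a Gaussian with mean 0 and variance 2 / (fanIn + fanOut), following Glorot and Bengio. The stdlib sketch below samples such a matrix; XavierInitSketch is a hypothetical name, and the exact variant DL4J uses should be checked against the WeightInit documentation for your version.

```java
import java.util.Random;

/** Hypothetical sketch of Xavier (Glorot) Gaussian initialization for an nIn x nOut weight matrix. */
final class XavierInitSketch {

    /** Draws each weight from N(0, 2 / (nIn + nOut)). */
    static double[][] init(int nIn, int nOut, long seed) {
        Random rng = new Random(seed);
        double std = Math.sqrt(2.0 / (nIn + nOut));
        double[][] w = new double[nIn][nOut];
        for (int i = 0; i < nIn; i++) {
            for (int j = 0; j < nOut; j++) {
                w[i][j] = rng.nextGaussian() * std;
            }
        }
        return w;
    }

    public static void main(String[] args) {
        // Shape of the first layer in Example 11: nIn(4) -> nOut(20)
        double[][] w = init(4, 20, 6L);
        System.out.printf("w[0][0] = %.4f (target std %.4f)%n", w[0][0], Math.sqrt(2.0 / 24));
    }
}
```

Scaling the variance by fan-in plus fan-out keeps activations and gradients at a similar magnitude from layer to layer, which suits the TANH layers used in these examples.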

Example 12: testIris

import org.nd4j.linalg.dataset.SplitTestAndTrain; // import the required package/class
@Test
public void testIris() {

    // Network config
    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()

                    .optimizationAlgo(OptimizationAlgorithm.LINE_GRADIENT_DESCENT).seed(42)
                    .updater(new Sgd(1e-6)).list()
                    .layer(0, new DenseLayer.Builder().nIn(4).nOut(2).activation(Activation.TANH)
                                    .weightInit(WeightInit.XAVIER).build())
                    .layer(1, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(
                                    LossFunctions.LossFunction.MCXENT).nIn(2).nOut(3).weightInit(WeightInit.XAVIER)
                                                    .activation(Activation.SOFTMAX).build())

                    .build();

    // Instantiate model
    MultiLayerNetwork model = new MultiLayerNetwork(conf);
    model.init();
    model.setListeners(Arrays.asList((IterationListener) new ScoreIterationListener(1)));

    // Train-test split
    DataSetIterator iter = new IrisDataSetIterator(150, 150);
    DataSet next = iter.next();
    next.shuffle();
    SplitTestAndTrain trainTest = next.splitTestAndTrain(5, new Random(42)); // 5 examples for training, the remaining 145 for testing

    // Train
    DataSet train = trainTest.getTrain();
    train.normalizeZeroMeanZeroUnitVariance();

    // Test
    DataSet test = trainTest.getTest();
    test.normalizeZeroMeanZeroUnitVariance();
    INDArray testFeature = test.getFeatureMatrix();
    INDArray testLabel = test.getLabels();

    // Fitting model
    model.fit(train);
    // Get predictions from test feature
    INDArray testPredictedLabel = model.output(testFeature);

    // Eval with class number
    Evaluation eval = new Evaluation(3); // Specify the number of classes here
    eval.eval(testLabel, testPredictedLabel);
    double eval1F1 = eval.f1();
    double eval1Acc = eval.accuracy();

    // Eval without class number
    Evaluation eval2 = new Evaluation(); // Number of classes not specified
    eval2.eval(testLabel, testPredictedLabel);
    double eval2F1 = eval2.f1();
    double eval2Acc = eval2.accuracy();

    //Assert the two implementations give same f1 and accuracy (since one batch)
    assertTrue(eval1F1 == eval2F1 && eval1Acc == eval2Acc);

    Evaluation evalViaMethod = model.evaluate(new ListDataSetIterator<>(Collections.singletonList(test)));
    checkEvaluationEquality(eval, evalViaMethod);

    System.out.println(eval.getConfusionMatrix().toString());
    System.out.println(eval.getConfusionMatrix().toCSV());
    System.out.println(eval.getConfusionMatrix().toHTML());

    System.out.println(eval.confusionToString());
}
 
Developer ID: deeplearning4j, Project: deeplearning4j, Lines of code: 67, Source: EvalTest.java
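
Example 12 compares accuracy() and f1() from two Evaluation objects and prints the confusion matrix. The stdlib sketch below shows one way such numbers can be derived from one-hot labels and predicted class probabilities: take the argmax of each row, count (actual, predicted) pairs, then compute accuracy and an unweighted (macro) average of per-class F1. How DL4J's Evaluation.f1() averages over classes is not stated in the source, so the macro average here is an assumption; EvalSketch and its methods are hypothetical.

```java
/** Hypothetical sketch: confusion matrix, accuracy and macro-F1 from one-hot labels and class scores. */
final class EvalSketch {

    static int argmax(double[] row) {
        int best = 0;
        for (int k = 1; k < row.length; k++) {
            if (row[k] > row[best]) {
                best = k;
            }
        }
        return best;
    }

    /** confusion[actual][predicted] = number of examples with that (actual, predicted) pair. */
    static int[][] confusion(double[][] labels, double[][] scores, int numClasses) {
        int[][] m = new int[numClasses][numClasses];
        for (int i = 0; i < labels.length; i++) {
            m[argmax(labels[i])][argmax(scores[i])]++;
        }
        return m;
    }

    static double accuracy(int[][] m) {
        int correct = 0;
        int total = 0;
        for (int a = 0; a < m.length; a++) {
            for (int p = 0; p < m.length; p++) {
                total += m[a][p];
                if (a == p) {
                    correct += m[a][p];
                }
            }
        }
        return total == 0 ? 0.0 : (double) correct / total;
    }

    /** Unweighted mean over classes of F1_k = 2TP / (2TP + FP + FN); a class with no TP, FP or FN scores 0. */
    static double macroF1(int[][] m) {
        int n = m.length;
        double sum = 0.0;
        for (int k = 0; k < n; k++) {
            int tp = m[k][k];
            int fp = 0;
            int fn = 0;
            for (int j = 0; j < n; j++) {
                if (j != k) {
                    fp += m[j][k]; // predicted k, actually j
                    fn += m[k][j]; // actually k, predicted j
                }
            }
            int denom = 2 * tp + fp + fn;
            sum += denom == 0 ? 0.0 : 2.0 * tp / denom;
        }
        return sum / n;
    }
}
```

For example, four rows where one class-2 example is predicted as class 0 (and the other three are correct) give accuracy 3/4 and macro-F1 (2/3 + 1 + 2/3) / 3 = 7/9.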

Example 13: splitTestAndTrain

import org.nd4j.linalg.dataset.SplitTestAndTrain; // import the required package/class
/**
 * Split the DataSet into two DataSets randomly
 * @param percentTrain    Fraction (between 0 and 1) of examples to be returned in the training DataSet object
 */
SplitTestAndTrain splitTestAndTrain(double percentTrain);
 
Developer ID: deeplearning4j, Project: nd4j, Lines of code: 6, Source: DataSet.java

Example 14: splitTestAndTrain

import org.nd4j.linalg.dataset.SplitTestAndTrain; // import the required package/class
SplitTestAndTrain splitTestAndTrain(int numHoldout);

Developer ID: wlin12, Project: JNN, Lines of code: 2, Source: DataSet.java


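Note: The org.nd4j.linalg.dataset.SplitTestAndTrain examples in this article were compiled by 纯净天空 from open-source code and documentation platforms such as GitHub and MSDocs. The code snippets were selected from open-source projects contributed by many developers; copyright in the source code belongs to the original authors. Please consult each project's License before distributing or using the code, and do not republish without permission.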