当前位置: 首页>>代码示例>>Java>>正文


Java MnistDataSetIterator类代码示例

本文整理汇总了Java中org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator的典型用法代码示例。如果您正苦于以下问题:Java MnistDataSetIterator类的具体用法?Java MnistDataSetIterator怎么用?Java MnistDataSetIterator使用的例子?那么, 这里精选的类代码示例或许可以为您提供帮助。


MnistDataSetIterator类属于org.deeplearning4j.datasets.iterator.impl包,在下文中一共展示了MnistDataSetIterator类的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Java代码示例。

示例1: evalMnistTestSet

import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; //导入依赖的package包/类
/**
 * Runs a trained LeNet model over the MNIST test split and logs the evaluation statistics.
 *
 * @param leNetModel the trained network to score
 * @throws Exception if the MNIST test data cannot be loaded
 */
private static void evalMnistTestSet(MultiLayerNetwork leNetModel) throws Exception {
    log.info("Load test data....");
    final int testBatchSize = 64;
    // train = false selects the test split; 12345 is the fixed seed
    DataSetIterator testBatches = new MnistDataSetIterator(testBatchSize, false, 12345);

    log.info("Evaluate model....");
    final int numClasses = 10;
    Evaluation evaluation = new Evaluation(numClasses);

    while (testBatches.hasNext()) {
        DataSet batch = testBatches.next();
        // false = inference mode (no training-time behavior)
        INDArray predictions = leNetModel.output(batch.getFeatureMatrix(), false);
        evaluation.eval(batch.getLabels(), predictions);
    }

    log.info(evaluation.stats());
}
 
开发者ID:matthiaszimmermann,项目名称:ml_demo,代码行数:19,代码来源:LeNetMnistTester.java

示例2: evaluate

import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; //导入依赖的package包/类
/**
 * Evaluates the trained model on the MNIST test split and prints the statistics to stdout.
 *
 * <p>The {@code (batch, numExamples)} iterator constructor reads the <em>training</em> split.
 * Evaluating with it therefore reported training-set accuracy instead of held-out accuracy.
 * The {@code (batchSize, train, seed)} constructor with {@code train = false} reads the
 * 10,000-image test split.
 *
 * <p>Loading is best effort: an {@link IOException} is printed and evaluation is skipped.
 *
 * @return this model, for call chaining
 */
@Override
@SuppressWarnings("rawtypes")
public Model evaluate()
{
    final Evaluation evaluation = new Evaluation(parameters.getOutputSize());
    try
    {
        // train = false -> MNIST test split; the seed only affects example order,
        // which does not change the aggregated evaluation
        final DataSetIterator iterator = new MnistDataSetIterator(100, false, 12345);
        while (iterator.hasNext())
        {
            final DataSet testingData = iterator.next();
            evaluation.eval(testingData.getLabels(), model.output(testingData.getFeatureMatrix()));
        }

        System.out.println(evaluation.stats());
    }
    catch (IOException e)
    {
        // Best effort: report the loading failure without aborting the caller
        e.printStackTrace();
    }
    return this;
}
 
开发者ID:amrabed,项目名称:DL4J,代码行数:23,代码来源:StackedAutoEncoderModel.java

示例3: testMultiCNNLayer

import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; //导入依赖的package包/类
/**
 * Smoke test: a conv -> local response normalization -> dense -> softmax network can be
 * initialized and fit on a two-example MNIST batch.
 */
@Test
public void testMultiCNNLayer() throws Exception {
    MultiLayerConfiguration config = new NeuralNetConfiguration.Builder()
                    .optimizationAlgo(OptimizationAlgorithm.LINE_GRADIENT_DESCENT)
                    .seed(123)
                    .list()
                    .layer(0, new ConvolutionLayer.Builder()
                                    .nIn(1)
                                    .nOut(6)
                                    .weightInit(WeightInit.XAVIER)
                                    .activation(Activation.RELU)
                                    .build())
                    .layer(1, new LocalResponseNormalization.Builder().build())
                    .layer(2, new DenseLayer.Builder().nOut(2).build())
                    .layer(3, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                                    .weightInit(WeightInit.XAVIER)
                                    .activation(Activation.SOFTMAX)
                                    .nIn(2)
                                    .nOut(10)
                                    .build())
                    .backprop(true)
                    .pretrain(false)
                    // convolutionalFlat: input presumably arrives as flattened 28x28x1 images
                    .setInputType(InputType.convolutionalFlat(28, 28, 1))
                    .build();

    MultiLayerNetwork net = new MultiLayerNetwork(config);
    net.init();

    DataSetIterator mnist = new MnistDataSetIterator(2, 2);
    net.fit(mnist.next());
}
 
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:22,代码来源:LocalResponseTest.java

示例4: testCNNBNActivationCombo

import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; //导入依赖的package包/类
/**
 * Checks that a conv (identity) -> batch norm -> ReLU activation layer -> softmax stack can be
 * fit on a tiny MNIST batch, and that layer 0 exposes "W" and "b" parameters afterwards.
 */
@Test
public void testCNNBNActivationCombo() throws Exception {
    DataSetIterator mnist = new MnistDataSetIterator(2, 2);
    DataSet miniBatch = mnist.next();

    MultiLayerConfiguration config = new NeuralNetConfiguration.Builder()
                    .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                    .seed(123)
                    .list()
                    .layer(0, new ConvolutionLayer.Builder()
                                    .nIn(1)
                                    .nOut(6)
                                    .weightInit(WeightInit.XAVIER)
                                    .activation(Activation.IDENTITY)
                                    .build())
                    .layer(1, new BatchNormalization.Builder().build())
                    .layer(2, new ActivationLayer.Builder().activation(Activation.RELU).build())
                    .layer(3, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                                    .weightInit(WeightInit.XAVIER)
                                    .activation(Activation.SOFTMAX)
                                    .nOut(10)
                                    .build())
                    .backprop(true)
                    .pretrain(false)
                    .setInputType(InputType.convolutionalFlat(28, 28, 1))
                    .build();

    MultiLayerNetwork net = new MultiLayerNetwork(config);
    net.init();
    net.fit(miniBatch);

    // The convolution layer must still own its weight and bias parameters after training
    assertNotEquals(null, net.getLayer(0).getParam("W"));
    assertNotEquals(null, net.getLayer(0).getParam("b"));
}
 
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:24,代码来源:BatchNormalizationTest.java

示例5: testMNISTConfig

import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; //导入依赖的package包/类
/**
 * Manual test: trains the CNN MNIST computation graph for 50 minibatches, logging the score
 * every iteration and pausing between steps.
 */
@Test
@Ignore //Should be run manually
public void testMNISTConfig() throws Exception {
    final int batchSize = 64; // Test batch size
    final int trainingSteps = 50;
    final long pauseBetweenStepsMs = 1000;
    final long finalPauseMs = 100000;

    DataSetIterator trainBatches = new MnistDataSetIterator(batchSize, true, 12345);

    ComputationGraph graph = getCNNMnistConfig();
    graph.init();
    graph.setListeners(new ScoreIterationListener(1));

    int step = 0;
    while (step < trainingSteps) {
        graph.fit(trainBatches.next());
        Thread.sleep(pauseBetweenStepsMs);
        step++;
    }

    // NOTE(review): long final pause, presumably so the run can be inspected manually — confirm
    Thread.sleep(finalPauseMs);
}
 
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:18,代码来源:CenterLossOutputLayerTest.java

示例6: testPredict

import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; //导入依赖的package包/类
/**
 * Fits a small dense network on a handful of MNIST examples. Then checks that
 * {@code predict} returns a non-null label for a single example that has label names attached.
 */
@Test
public void testPredict() throws Exception {
    Nd4j.getRandom().setSeed(12345);

    MultiLayerConfiguration config = new NeuralNetConfiguration.Builder()
                    .weightInit(WeightInit.XAVIER)
                    .seed(12345L)
                    .list()
                    .layer(0, new DenseLayer.Builder().nIn(784).nOut(50).activation(Activation.RELU).build())
                    .layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                                    .activation(Activation.SOFTMAX).nIn(50).nOut(10).build())
                    .pretrain(false)
                    .backprop(true)
                    .setInputType(InputType.convolutional(28, 28, 1))
                    .build();

    MultiLayerNetwork net = new MultiLayerNetwork(config);
    net.init();

    // Train on a tiny slice: batch size 10, 10 examples
    DataSetIterator trainBatches = new MnistDataSetIterator(10, 10);
    net.fit(trainBatches);

    DataSetIterator singleExample = new MnistDataSetIterator(1, 1);
    DataSet sample = singleExample.next();
    sample.setLabelNames(Arrays.asList("0", "1", "2", "3", "4", "5", "6", "7", "8", "9"));

    String actualLabel = sample.getLabelName(0);
    List<String> predicted = net.predict(sample);

    assertTrue(actualLabel != null);
    assertTrue(predicted.get(0) != null);
}
 
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:26,代码来源:MultiLayerTest.java

示例7: mnistTrainSetIterator

import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; //导入依赖的package包/类
/**
 * Creates an iterator over the MNIST training split with {@code BATCH_SIZE} and seed 12345.
 *
 * @return a new training-set iterator
 * @throws RuntimeException if the MNIST data cannot be loaded; the underlying
 *         {@link IOException} is kept as the cause
 */
public static DataSetIterator mnistTrainSetIterator() {
	try {
		return new MnistDataSetIterator(BATCH_SIZE, true, 12345);
	} catch (IOException e) {
		// Chain the cause so the actual I/O failure is visible in the stack trace
		throw new RuntimeException("Couldn't build the MnistDataSetIterator", e);
	}
}
 
开发者ID:braeunlich,项目名称:anagnostes,代码行数:8,代码来源:ConfigurationFactory.java

示例8: mnistTestSetIterator

import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; //导入依赖的package包/类
/**
 * Creates an iterator over the MNIST test split with {@code BATCH_SIZE} and seed 12345.
 *
 * @return a new test-set iterator
 * @throws RuntimeException if the MNIST data cannot be loaded; the underlying
 *         {@link IOException} is kept as the cause
 */
public static DataSetIterator mnistTestSetIterator() {
	try {
		return new MnistDataSetIterator(BATCH_SIZE, false, 12345);
	} catch (IOException e) {
		// Chain the cause so the actual I/O failure is visible in the stack trace
		throw new RuntimeException("Couldn't build the MnistDataSetIterator", e);
	}
}
 
开发者ID:braeunlich,项目名称:anagnostes,代码行数:8,代码来源:ConfigurationFactory.java

示例9: main

import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; //导入依赖的package包/类
/**
 * Trains a deep autoencoder on binarized MNIST.
 *
 * <p>The network is a stack of RBMs (784 -> 2000 -> 1000 -> 500 -> 30 -> 500 -> 1000 -> 2000)
 * followed by a sigmoid output layer that maps back to 784 values. Each image is used as its
 * own reconstruction target.
 *
 * @param args unused
 * @throws Exception if data loading or training fails
 */
public static void main(String[] args) throws Exception {
    final int numRows = 28;
    final int numColumns = 28;
    int seed = 123;
    int numSamples = MnistDataFetcher.NUM_EXAMPLES;
    int batchSize = 1000;
    int iterations = 1;
    // iterations / 5 is 0 whenever iterations < 5. Clamp to 1 so the score listener always
    // gets a positive print frequency instead of relying on it to tolerate 0.
    int listenerFreq = Math.max(1, iterations / 5);

    log.info("Load data....");
    // (batchSize, numExamples, binarize = true)
    DataSetIterator iter = new MnistDataSetIterator(batchSize,numSamples,true);

    log.info("Build model....");
    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .seed(seed)
            .iterations(iterations)
            .optimizationAlgo(OptimizationAlgorithm.LINE_GRADIENT_DESCENT)
            .list(8)
            .layer(0, new RBM.Builder().nIn(numRows * numColumns).nOut(2000).lossFunction(LossFunctions.LossFunction.RMSE_XENT).build())
            .layer(1, new RBM.Builder().nIn(2000).nOut(1000).lossFunction(LossFunctions.LossFunction.RMSE_XENT).build())
            .layer(2, new RBM.Builder().nIn(1000).nOut(500).lossFunction(LossFunctions.LossFunction.RMSE_XENT).build())
            .layer(3, new RBM.Builder().nIn(500).nOut(30).lossFunction(LossFunctions.LossFunction.RMSE_XENT).build())
            .layer(4, new RBM.Builder().nIn(30).nOut(500).lossFunction(LossFunctions.LossFunction.RMSE_XENT).build())
            .layer(5, new RBM.Builder().nIn(500).nOut(1000).lossFunction(LossFunctions.LossFunction.RMSE_XENT).build())
            .layer(6, new RBM.Builder().nIn(1000).nOut(2000).lossFunction(LossFunctions.LossFunction.RMSE_XENT).build())
            .layer(7, new OutputLayer.Builder(LossFunctions.LossFunction.MSE).activation(Activation.SIGMOID).nIn(2000).nOut(numRows*numColumns).build())
            .pretrain(true).backprop(true)
            .build();

    MultiLayerNetwork model = new MultiLayerNetwork(conf);
    model.init();

    model.setListeners(new ScoreIterationListener(listenerFreq));

    log.info("Train model....");
    while(iter.hasNext()) {
        DataSet next = iter.next();
        // Autoencoder: the input features are also the training target
        model.fit(new DataSet(next.getFeatureMatrix(),next.getFeatureMatrix()));
    }
}
 
开发者ID:PacktPublishing,项目名称:Deep-Learning-with-Hadoop,代码行数:41,代码来源:DeepAutoEncoder.java

示例10: load

import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; //导入依赖的package包/类
/**
 * Builds the MNIST iterator from {@code batchSize}, {@code nSamples} and {@code binarize}.
 *
 * <p>Best effort: an {@link IOException} is printed to stderr and {@code iterator} keeps its
 * previous value.
 *
 * @return this instance, for call chaining
 */
@Override
public MnistData load() {
    try {
        iterator = new MnistDataSetIterator(batchSize, nSamples, binarize);
    } catch (IOException loadFailure) {
        loadFailure.printStackTrace();
    }
    return this;
}
 
开发者ID:amrabed,项目名称:DL4J,代码行数:14,代码来源:MnistData.java

示例11: main

import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; //导入依赖的package包/类
/**
 * Loads a serialized pretrain network and displays each MNIST digit next to a binomially
 * sampled version of its reconstruction, one pair at a time.
 *
 * @param args {@code args[0]} is the path to a Java-serialized {@code BasePretrainNetwork}
 * @throws Exception if the model file or MNIST data cannot be read
 */
public static void main(String[] args) throws Exception {
	MnistDataSetIterator iter = new MnistDataSetIterator(60,60000);

	// NOTE(review): Java native deserialization of an arbitrary file; only load trusted models.
	// try-with-resources closes both streams, which the original code never did.
	BasePretrainNetwork network;
	try (FileInputStream fis = new FileInputStream(args[0]);
			ObjectInputStream ois = new ObjectInputStream(fis)) {
		network = (BasePretrainNetwork) ois.readObject();
	}

	while(iter.hasNext()) {
		// Fetch the next batch. Previously 'test' stayed null, so the first iteration threw an
		// NPE, and the iterator was never advanced.
		DataSet test = iter.next();
		INDArray reconstructed = network.transform(test.getFeatureMatrix());
		for(int i = 0; i < test.numExamples(); i++) {
			// Scale to 0-255 for display
			INDArray draw1 = test.get(i).getFeatureMatrix().mul(255);
			INDArray reconstructed2 = reconstructed.getRow(i);
			INDArray draw2 = Sampling.binomial(reconstructed2, 1, new MersenneTwister(123)).mul(255);

			DrawReconstruction d = new DrawReconstruction(draw1);
			d.title = "REAL";
			d.draw();
			DrawReconstruction d2 = new DrawReconstruction(draw2,100,100);
			d2.title = "TEST";
			d2.draw();
			Thread.sleep(10000);
			d.frame.dispose();
			d2.frame.dispose();
		}
	}
}
 
开发者ID:jpatanooga,项目名称:Canova,代码行数:34,代码来源:LoadAndDraw.java

示例12: main

import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; //导入依赖的package包/类
/**
 * Loads a serialized pretrain network and displays each MNIST digit next to a binomially
 * sampled version of the network's activation for it, one pair every 10 seconds.
 *
 * @param args {@code args[0]} is the path to a Java-serialized {@code BasePretrainNetwork}
 * @throws Exception if the model file or MNIST data cannot be read
 */
public static void main(String[] args) throws Exception {
    MnistDataSetIterator mnist = new MnistDataSetIterator(60, 60000);
    // NOTE(review): Java native deserialization of an arbitrary file; only load trusted models.
    @SuppressWarnings("unchecked")
    ObjectInputStream modelIn = new ObjectInputStream(new FileInputStream(args[0]));

    BasePretrainNetwork network = (BasePretrainNetwork) modelIn.readObject();
    try {
        modelIn.close();
    } catch (IOException ignored) {
        // Closing is best effort; the network has already been read.
    }

    while (mnist.hasNext()) {
        DataSet batch = mnist.next();
        INDArray activations = network.activate(batch.getFeatureMatrix());
        int exampleCount = batch.numExamples();
        for (int row = 0; row < exampleCount; row++) {
            // Scale pixel and sampled values to 0-255 for display
            INDArray original = batch.get(row).getFeatureMatrix().mul(255);
            INDArray probabilities = activations.getRow(row);
            INDArray sampled = Nd4j.getDistributions()
                            .createBinomial(1, probabilities)
                            .sample(probabilities.shape())
                            .mul(255);

            DrawReconstruction realView = new DrawReconstruction(original);
            realView.title = "REAL";
            realView.draw();
            DrawReconstruction testView = new DrawReconstruction(sampled, 100, 100);
            testView.title = "TEST";
            testView.draw();

            Thread.sleep(10000);
            realView.frame.dispose();
            testView.frame.dispose();
        }
    }
}
 
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:39,代码来源:LoadAndDraw.java

示例13: testClassificationScoreFunctionSimple

import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; //导入依赖的package包/类
/**
 * For every {@code Evaluation.Metric}, runs early-stopping training on 10 cached MNIST test
 * batches. The score is computed by a {@code ClassificationScoreCalculator} for that metric,
 * and the test asserts that a best model is produced.
 */
@Test
public void testClassificationScoreFunctionSimple() throws Exception {
    final int cachedBatches = 10;

    for (Evaluation.Metric metric : Evaluation.Metric.values()) {
        log.info("Metric: " + metric);

        MultiLayerConfiguration config = new NeuralNetConfiguration.Builder()
                .list()
                .layer(new DenseLayer.Builder().nIn(784).nOut(32).build())
                .layer(new OutputLayer.Builder().nIn(32).nOut(10).activation(Activation.SOFTMAX).build())
                .build();

        MultiLayerNetwork net = new MultiLayerNetwork(config);
        net.init();

        // Cache a fixed number of test-split batches so training and scoring use the same data
        DataSetIterator mnist = new MnistDataSetIterator(32, false, 12345);
        List<DataSet> batches = new ArrayList<>();
        while (batches.size() < cachedBatches) {
            batches.add(mnist.next());
        }
        DataSetIterator cached = new ExistingDataSetIterator(batches);

        EarlyStoppingModelSaver<MultiLayerNetwork> saver = new InMemoryModelSaver<>();
        EarlyStoppingConfiguration<MultiLayerNetwork> esConf =
                new EarlyStoppingConfiguration.Builder<MultiLayerNetwork>()
                        .epochTerminationConditions(new MaxEpochsTerminationCondition(5))
                        .iterationTerminationConditions(
                                new MaxTimeIterationTerminationCondition(1, TimeUnit.MINUTES))
                        .scoreCalculator(new ClassificationScoreCalculator(metric, cached))
                        .modelSaver(saver)
                        .build();

        EarlyStoppingResult<MultiLayerNetwork> result = new EarlyStoppingTrainer(esConf, net, cached).fit();

        assertNotNull(result.getBestModel());
    }
}
 
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:41,代码来源:TestEarlyStopping.java

示例14: testDBNBNMultiLayer

import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; //导入依赖的package包/类
/**
 * Dense -> batch norm -> separate ReLU activation layer -> softmax network. Checks the
 * pre-output width, then fits one tiny MNIST batch and checks that the batch-norm layer
 * exposes gamma and beta parameters.
 */
@Test
public void testDBNBNMultiLayer() throws Exception {
    DataSetIterator mnist = new MnistDataSetIterator(2, 2);
    DataSet batch = mnist.next();
    INDArray features = batch.getFeatureMatrix();

    // Run with separate activation layer
    MultiLayerConfiguration config = new NeuralNetConfiguration.Builder()
                    .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                    .seed(123)
                    .list()
                    .layer(0, new DenseLayer.Builder()
                                    .nIn(28 * 28)
                                    .nOut(10)
                                    .weightInit(WeightInit.XAVIER)
                                    .activation(Activation.RELU)
                                    .build())
                    .layer(1, new BatchNormalization.Builder().nOut(10).build())
                    .layer(2, new ActivationLayer.Builder().activation(Activation.RELU).build())
                    .layer(3, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                                    .weightInit(WeightInit.XAVIER)
                                    .activation(Activation.SOFTMAX)
                                    .nIn(10)
                                    .nOut(10)
                                    .build())
                    .backprop(true)
                    .pretrain(false)
                    .build();

    MultiLayerNetwork net = new MultiLayerNetwork(config);
    net.init();

    net.setInput(features);
    INDArray preOut = net.preOutput(features);
    // Column count of the pre-output must match the output layer's nOut (10)
    assertEquals(10, preOut.shape()[1], 1e-2);

    net.fit(batch);
    INDArray gamma = net.getLayer(1).getParam(BatchNormalizationParamInitializer.GAMMA);
    INDArray beta = net.getLayer(1).getParam(BatchNormalizationParamInitializer.BETA);
    assertTrue(gamma != null);
    assertTrue(beta != null);
}
 
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:33,代码来源:BatchNormalizationTest.java

示例15: checkSerialization

import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; //导入依赖的package包/类
/**
 * Trains a network containing batch-normalization layers, then checks two things: inference
 * output is repeatable, and a serialization round trip reproduces the same output, i.e. the
 * batch-norm state is persisted.
 */
@Test
public void checkSerialization() throws Exception {
    MultiLayerConfiguration config = new NeuralNetConfiguration.Builder()
                    .seed(12345)
                    .list()
                    .layer(0, new ConvolutionLayer.Builder()
                                    .nIn(1)
                                    .nOut(6)
                                    .weightInit(WeightInit.XAVIER)
                                    .activation(Activation.IDENTITY)
                                    .build())
                    .layer(1, new BatchNormalization.Builder().build())
                    .layer(2, new ActivationLayer.Builder().activation(Activation.LEAKYRELU).build())
                    .layer(3, new DenseLayer.Builder().nOut(10).activation(Activation.LEAKYRELU).build())
                    .layer(4, new BatchNormalization.Builder().build())
                    .layer(5, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                                    .weightInit(WeightInit.XAVIER)
                                    .activation(Activation.SOFTMAX)
                                    .nOut(10)
                                    .build())
                    .backprop(true)
                    .pretrain(false)
                    .setInputType(InputType.convolutionalFlat(28, 28, 1))
                    .build();

    MultiLayerNetwork trained = new MultiLayerNetwork(config);
    trained.init();

    DataSetIterator trainBatches = new MnistDataSetIterator(16, true, 12345);
    int step = 0;
    while (step < 20) {
        trained.fit(trainBatches.next());
        step++;
    }

    INDArray input = trainBatches.next().getFeatureMatrix();

    // Two inference passes over the same input must agree
    INDArray first = trained.output(input, false);
    INDArray second = trained.output(input, false);
    assertEquals(first, second);

    MultiLayerNetwork restored = TestUtils.testModelSerialization(trained);
    INDArray restoredOut = restored.output(input, false);
    assertEquals(first, restoredOut);
}
 
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:40,代码来源:BatchNormalizationTest.java


注:本文中的org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator类示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。