This article collects typical usage examples of the Java class org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator. If you are wondering what the MnistDataSetIterator class is for, how to use it, or where to find usage examples, the curated code samples below may help.
The MnistDataSetIterator class belongs to the org.deeplearning4j.datasets.iterator.impl package. A total of 15 code examples of the class are shown below, sorted by popularity by default. You can upvote the examples you like or find useful; your feedback helps the system recommend better Java code examples.
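Before the examples, here is a minimal sketch of the three constructor overloads they rely on. This snippet is not taken from the examples below, and the wrapper class name MnistIteratorOverloads is purely illustrative; all three constructors throw IOException because the MNIST files may need to be downloaded on first use:

import java.io.IOException;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;

public class MnistIteratorOverloads {
    public static void main(String[] args) throws IOException {
        // batch size, train (true) or test (false) split, RNG seed for shuffling
        DataSetIterator train = new MnistDataSetIterator(64, true, 12345);
        // batch size and the total number of examples to iterate over
        DataSetIterator subset = new MnistDataSetIterator(100, 10000);
        // batch size, example count, and whether to binarize pixel values to 0/1
        DataSetIterator binarized = new MnistDataSetIterator(1000, 60000, true);
        System.out.println("Features per example: " + train.inputColumns()); // 784 = 28 * 28
    }
}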
Example 1: evalMnistTestSet
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; // import the required package/class
private static void evalMnistTestSet(MultiLayerNetwork leNetModel) throws Exception {
    log.info("Load test data....");
    int batchSize = 64;
    DataSetIterator mnistTest = new MnistDataSetIterator(batchSize, false, 12345);

    log.info("Evaluate model....");
    int outputNum = 10;
    Evaluation eval = new Evaluation(outputNum);
    while (mnistTest.hasNext()) {
        DataSet dataSet = mnistTest.next();
        INDArray output = leNetModel.output(dataSet.getFeatureMatrix(), false);
        eval.eval(dataSet.getLabels(), output);
    }
    log.info(eval.stats());
}
Example 2: evaluate
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; // import the required package/class
@Override
@SuppressWarnings("rawtypes")
public Model evaluate()
{
    final Evaluation evaluation = new Evaluation(parameters.getOutputSize());
    try
    {
        final DataSetIterator iterator = new MnistDataSetIterator(100, 10000);
        while (iterator.hasNext())
        {
            final DataSet testingData = iterator.next();
            evaluation.eval(testingData.getLabels(), model.output(testingData.getFeatureMatrix()));
        }
        System.out.println(evaluation.stats());
    }
    catch (IOException e)
    {
        e.printStackTrace();
    }
    return this;
}
Example 3: testMultiCNNLayer
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; // import the required package/class
@Test
public void testMultiCNNLayer() throws Exception {
    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .optimizationAlgo(OptimizationAlgorithm.LINE_GRADIENT_DESCENT).seed(123).list()
            .layer(0, new ConvolutionLayer.Builder().nIn(1).nOut(6).weightInit(WeightInit.XAVIER)
                    .activation(Activation.RELU).build())
            .layer(1, new LocalResponseNormalization.Builder().build())
            .layer(2, new DenseLayer.Builder().nOut(2).build())
            .layer(3, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                    .weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX).nIn(2).nOut(10)
                    .build())
            .backprop(true).pretrain(false).setInputType(InputType.convolutionalFlat(28, 28, 1)).build();

    MultiLayerNetwork network = new MultiLayerNetwork(conf);
    network.init();

    DataSetIterator iter = new MnistDataSetIterator(2, 2);
    DataSet next = iter.next();
    network.fit(next);
}
Example 4: testCNNBNActivationCombo
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; // import the required package/class
@Test
public void testCNNBNActivationCombo() throws Exception {
    DataSetIterator iter = new MnistDataSetIterator(2, 2);
    DataSet next = iter.next();

    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).seed(123)
            .list()
            .layer(0, new ConvolutionLayer.Builder().nIn(1).nOut(6).weightInit(WeightInit.XAVIER)
                    .activation(Activation.IDENTITY).build())
            .layer(1, new BatchNormalization.Builder().build())
            .layer(2, new ActivationLayer.Builder().activation(Activation.RELU).build())
            .layer(3, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                    .weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX).nOut(10).build())
            .backprop(true).pretrain(false).setInputType(InputType.convolutionalFlat(28, 28, 1)).build();

    MultiLayerNetwork network = new MultiLayerNetwork(conf);
    network.init();
    network.fit(next);

    assertNotEquals(null, network.getLayer(0).getParam("W"));
    assertNotEquals(null, network.getLayer(0).getParam("b"));
}
Example 5: testMNISTConfig
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; // import the required package/class
@Test
@Ignore //Should be run manually
public void testMNISTConfig() throws Exception {
    int batchSize = 64; // Test batch size
    DataSetIterator mnistTrain = new MnistDataSetIterator(batchSize, true, 12345);

    ComputationGraph net = getCNNMnistConfig();
    net.init();
    net.setListeners(new ScoreIterationListener(1));

    for (int i = 0; i < 50; i++) {
        net.fit(mnistTrain.next());
        Thread.sleep(1000);
    }
    Thread.sleep(100000);
}
Example 6: testPredict
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; // import the required package/class
@Test
public void testPredict() throws Exception {
    Nd4j.getRandom().setSeed(12345);
    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .weightInit(WeightInit.XAVIER).seed(12345L).list()
            .layer(0, new DenseLayer.Builder().nIn(784).nOut(50).activation(Activation.RELU).build())
            .layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                    .activation(Activation.SOFTMAX).nIn(50).nOut(10).build())
            .pretrain(false).backprop(true).setInputType(InputType.convolutional(28, 28, 1)).build();

    MultiLayerNetwork net = new MultiLayerNetwork(conf);
    net.init();

    DataSetIterator ds = new MnistDataSetIterator(10, 10);
    net.fit(ds);

    DataSetIterator testDs = new MnistDataSetIterator(1, 1);
    DataSet testData = testDs.next();
    testData.setLabelNames(Arrays.asList("0", "1", "2", "3", "4", "5", "6", "7", "8", "9"));
    String actualLabels = testData.getLabelName(0);
    List<String> prediction = net.predict(testData);
    assertTrue(actualLabels != null);
    assertTrue(prediction.get(0) != null);
}
Example 7: mnistTrainSetIterator
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; // import the required package/class
public static DataSetIterator mnistTrainSetIterator() {
    try {
        return new MnistDataSetIterator(BATCH_SIZE, true, 12345);
    } catch (IOException e) {
        throw new RuntimeException("Couldn't build the MnistDataSetIterator", e); // preserve the cause
    }
}
Example 8: mnistTestSetIterator
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; // import the required package/class
public static DataSetIterator mnistTestSetIterator() {
    try {
        return new MnistDataSetIterator(BATCH_SIZE, false, 12345);
    } catch (IOException e) {
        throw new RuntimeException("Couldn't build the MnistDataSetIterator", e); // preserve the cause
    }
}
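A minimal sketch of how these two helpers might be combined; the MultiLayerNetwork named model and the log field, like BATCH_SIZE, are assumptions about the enclosing class and are not part of the example above:

DataSetIterator train = mnistTrainSetIterator();
DataSetIterator test = mnistTestSetIterator();
model.fit(train);                       // one epoch over the MNIST training split
Evaluation eval = model.evaluate(test); // built-in evaluation over the test split
log.info(eval.stats());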
Example 9: main
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; // import the required package/class
public static void main(String[] args) throws Exception {
    final int numRows = 28;
    final int numColumns = 28;
    int seed = 123;
    int numSamples = MnistDataFetcher.NUM_EXAMPLES;
    int batchSize = 1000;
    int iterations = 1;
    int listenerFreq = Math.max(1, iterations / 5); // a zero frequency would break the score listener

    log.info("Load data....");
    DataSetIterator iter = new MnistDataSetIterator(batchSize, numSamples, true);

    log.info("Build model....");
    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .seed(seed)
            .iterations(iterations)
            .optimizationAlgo(OptimizationAlgorithm.LINE_GRADIENT_DESCENT)
            .list(8)
            .layer(0, new RBM.Builder().nIn(numRows * numColumns).nOut(2000).lossFunction(LossFunctions.LossFunction.RMSE_XENT).build())
            .layer(1, new RBM.Builder().nIn(2000).nOut(1000).lossFunction(LossFunctions.LossFunction.RMSE_XENT).build())
            .layer(2, new RBM.Builder().nIn(1000).nOut(500).lossFunction(LossFunctions.LossFunction.RMSE_XENT).build())
            .layer(3, new RBM.Builder().nIn(500).nOut(30).lossFunction(LossFunctions.LossFunction.RMSE_XENT).build())
            .layer(4, new RBM.Builder().nIn(30).nOut(500).lossFunction(LossFunctions.LossFunction.RMSE_XENT).build())
            .layer(5, new RBM.Builder().nIn(500).nOut(1000).lossFunction(LossFunctions.LossFunction.RMSE_XENT).build())
            .layer(6, new RBM.Builder().nIn(1000).nOut(2000).lossFunction(LossFunctions.LossFunction.RMSE_XENT).build())
            .layer(7, new OutputLayer.Builder(LossFunctions.LossFunction.MSE).activation(Activation.SIGMOID).nIn(2000).nOut(numRows * numColumns).build())
            .pretrain(true).backprop(true)
            .build();

    MultiLayerNetwork model = new MultiLayerNetwork(conf);
    model.init();
    model.setListeners(new ScoreIterationListener(listenerFreq));

    log.info("Train model....");
    while (iter.hasNext()) {
        DataSet next = iter.next();
        // Autoencoder: the features serve as both input and reconstruction target
        model.fit(new DataSet(next.getFeatureMatrix(), next.getFeatureMatrix()));
    }
}
Example 10: load
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; // import the required package/class
@Override
public MnistData load()
{
    try
    {
        iterator = new MnistDataSetIterator(batchSize, nSamples, binarize);
    }
    catch (IOException e)
    {
        e.printStackTrace();
    }
    return this;
}
Example 11: main
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; // import the required package/class
/**
 * @param args
 */
public static void main(String[] args) throws Exception {
    MnistDataSetIterator iter = new MnistDataSetIterator(60, 60000);
    @SuppressWarnings("unchecked")
    ObjectInputStream ois = new ObjectInputStream(new FileInputStream(args[0]));
    BasePretrainNetwork network = (BasePretrainNetwork) ois.readObject();

    DataSet test = null;
    while (iter.hasNext()) {
        test = iter.next(); // fetch the next batch of examples
        INDArray reconstructed = network.transform(test.getFeatureMatrix());
        for (int i = 0; i < test.numExamples(); i++) {
            INDArray draw1 = test.get(i).getFeatureMatrix().mul(255);
            INDArray reconstructed2 = reconstructed.getRow(i);
            INDArray draw2 = Sampling.binomial(reconstructed2, 1, new MersenneTwister(123)).mul(255);

            DrawReconstruction d = new DrawReconstruction(draw1);
            d.title = "REAL";
            d.draw();
            DrawReconstruction d2 = new DrawReconstruction(draw2, 100, 100);
            d2.title = "TEST";
            d2.draw();

            Thread.sleep(10000);
            d.frame.dispose();
            d2.frame.dispose();
        }
    }
}
Example 12: main
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; // import the required package/class
/**
 * @param args
 */
public static void main(String[] args) throws Exception {
    MnistDataSetIterator iter = new MnistDataSetIterator(60, 60000);
    @SuppressWarnings("unchecked")
    ObjectInputStream ois = new ObjectInputStream(new FileInputStream(args[0]));
    BasePretrainNetwork network = (BasePretrainNetwork) ois.readObject();
    try {
        ois.close();
    } catch (IOException e) {
        // ignore failures while closing the stream
    }

    DataSet test = null;
    while (iter.hasNext()) {
        test = iter.next();
        INDArray reconstructed = network.activate(test.getFeatureMatrix());
        for (int i = 0; i < test.numExamples(); i++) {
            INDArray draw1 = test.get(i).getFeatureMatrix().mul(255);
            INDArray reconstructed2 = reconstructed.getRow(i);
            INDArray draw2 = Nd4j.getDistributions().createBinomial(1, reconstructed2)
                    .sample(reconstructed2.shape()).mul(255);

            DrawReconstruction d = new DrawReconstruction(draw1);
            d.title = "REAL";
            d.draw();
            DrawReconstruction d2 = new DrawReconstruction(draw2, 100, 100);
            d2.title = "TEST";
            d2.draw();

            Thread.sleep(10000);
            d.frame.dispose();
            d2.frame.dispose();
        }
    }
}
Example 13: testClassificationScoreFunctionSimple
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; // import the required package/class
@Test
public void testClassificationScoreFunctionSimple() throws Exception {
    for (Evaluation.Metric metric : Evaluation.Metric.values()) {
        log.info("Metric: " + metric);

        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .list()
                .layer(new DenseLayer.Builder().nIn(784).nOut(32).build())
                .layer(new OutputLayer.Builder().nIn(32).nOut(10).activation(Activation.SOFTMAX).build())
                .build();

        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.init();

        DataSetIterator iter = new MnistDataSetIterator(32, false, 12345);
        List<DataSet> l = new ArrayList<>();
        for (int i = 0; i < 10; i++) {
            DataSet ds = iter.next();
            l.add(ds);
        }
        iter = new ExistingDataSetIterator(l);

        EarlyStoppingModelSaver<MultiLayerNetwork> saver = new InMemoryModelSaver<>();
        EarlyStoppingConfiguration<MultiLayerNetwork> esConf =
                new EarlyStoppingConfiguration.Builder<MultiLayerNetwork>()
                        .epochTerminationConditions(new MaxEpochsTerminationCondition(5))
                        .iterationTerminationConditions(
                                new MaxTimeIterationTerminationCondition(1, TimeUnit.MINUTES))
                        .scoreCalculator(new ClassificationScoreCalculator(metric, iter)).modelSaver(saver)
                        .build();

        EarlyStoppingTrainer trainer = new EarlyStoppingTrainer(esConf, net, iter);
        EarlyStoppingResult<MultiLayerNetwork> result = trainer.fit();

        assertNotNull(result.getBestModel());
    }
}
Example 14: testDBNBNMultiLayer
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; // import the required package/class
@Test
public void testDBNBNMultiLayer() throws Exception {
    DataSetIterator iter = new MnistDataSetIterator(2, 2);
    DataSet next = iter.next();

    // Run with separate activation layer
    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).seed(123)
            .list()
            .layer(0, new DenseLayer.Builder().nIn(28 * 28).nOut(10).weightInit(WeightInit.XAVIER)
                    .activation(Activation.RELU).build())
            .layer(1, new BatchNormalization.Builder().nOut(10).build())
            .layer(2, new ActivationLayer.Builder().activation(Activation.RELU).build())
            .layer(3, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                    .weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX).nIn(10).nOut(10)
                    .build())
            .backprop(true).pretrain(false).build();

    MultiLayerNetwork network = new MultiLayerNetwork(conf);
    network.init();
    network.setInput(next.getFeatureMatrix());
    INDArray activationsActual = network.preOutput(next.getFeatureMatrix());
    assertEquals(10, activationsActual.shape()[1], 1e-2);

    network.fit(next);
    INDArray actualGammaParam = network.getLayer(1).getParam(BatchNormalizationParamInitializer.GAMMA);
    INDArray actualBetaParam = network.getLayer(1).getParam(BatchNormalizationParamInitializer.BETA);
    assertTrue(actualGammaParam != null);
    assertTrue(actualBetaParam != null);
}
Example 15: checkSerialization
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; // import the required package/class
@Test
public void checkSerialization() throws Exception {
    //Serialize the batch norm network (after training), and make sure we get same activations out as before
    // i.e., make sure state is properly stored
    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .seed(12345)
            .list()
            .layer(0, new ConvolutionLayer.Builder().nIn(1).nOut(6).weightInit(WeightInit.XAVIER)
                    .activation(Activation.IDENTITY).build())
            .layer(1, new BatchNormalization.Builder().build())
            .layer(2, new ActivationLayer.Builder().activation(Activation.LEAKYRELU).build())
            .layer(3, new DenseLayer.Builder().nOut(10).activation(Activation.LEAKYRELU).build())
            .layer(4, new BatchNormalization.Builder().build())
            .layer(5, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                    .weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX).nOut(10).build())
            .backprop(true).pretrain(false).setInputType(InputType.convolutionalFlat(28, 28, 1)).build();

    MultiLayerNetwork net = new MultiLayerNetwork(conf);
    net.init();

    DataSetIterator iter = new MnistDataSetIterator(16, true, 12345);
    for (int i = 0; i < 20; i++) {
        net.fit(iter.next());
    }

    INDArray in = iter.next().getFeatureMatrix();
    INDArray out = net.output(in, false);
    INDArray out2 = net.output(in, false);
    assertEquals(out, out2);

    MultiLayerNetwork net2 = TestUtils.testModelSerialization(net);
    INDArray outDeser = net2.output(in, false);
    assertEquals(out, outDeser);
}