

Java DataSetIterator.next Method Code Examples

This article compiles typical usage examples of the Java method org.nd4j.linalg.dataset.api.iterator.DataSetIterator.next. If you are wondering what DataSetIterator.next does, how to use it, or where to find examples, the curated code samples below should help. You can also explore further usage examples of the containing class, org.nd4j.linalg.dataset.api.iterator.DataSetIterator.


The following shows 15 code examples of the DataSetIterator.next method, sorted by popularity by default. You can upvote the examples you like or find useful; your feedback helps the system recommend better Java code examples.
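Before diving into the examples, here is a minimal, dependency-free sketch of the hasNext()/next() mini-batch pattern that every example below relies on. Note that MiniBatchIterator is a hypothetical stand-in written for illustration only; it is not an ND4J class, and real code should use DataSetIterator implementations such as RecordReaderDataSetIterator or MnistDataSetIterator instead.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;

/**
 * Minimal sketch of the hasNext()/next() mini-batch pattern that
 * DataSetIterator follows. Hypothetical illustration, not an ND4J class.
 */
public class MiniBatchIterator implements Iterator<List<double[]>> {
    private final List<double[]> rows;   // one feature vector per row
    private final int batchSize;
    private int cursor = 0;

    public MiniBatchIterator(List<double[]> rows, int batchSize) {
        this.rows = rows;
        this.batchSize = batchSize;
    }

    @Override
    public boolean hasNext() {
        return cursor < rows.size();
    }

    /** Returns the next batch of up to batchSize rows, like DataSetIterator.next(). */
    @Override
    public List<double[]> next() {
        int end = Math.min(cursor + batchSize, rows.size());
        List<double[]> batch = new ArrayList<>(rows.subList(cursor, end));
        cursor = end;
        return batch;
    }

    /** Rewinds to the first batch, mirroring DataSetIterator.reset(). */
    public void reset() {
        cursor = 0;
    }

    public static void main(String[] args) {
        List<double[]> data = Arrays.asList(
                new double[]{1, 2}, new double[]{3, 4},
                new double[]{5, 6}, new double[]{7, 8}, new double[]{9, 10});
        MiniBatchIterator it = new MiniBatchIterator(data, 2);
        while (it.hasNext()) {
            System.out.println("batch size = " + it.next().size());
        }
    }
}
```

With 5 rows and batchSize = 2, the loop yields batches of sizes 2, 2, and 1 — the same behavior DataSetIterator.next() shows when the dataset size is not a multiple of the batch size, which is why several examples below guard each next() call with hasNext().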

Example 1: createDataSource

import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; // import the package/class this method depends on
private void createDataSource() throws IOException, InterruptedException {
    //First: get the dataset using the record reader. CSVRecordReader handles loading/parsing
    int numLinesToSkip = 0;
    String delimiter = ",";
    RecordReader recordReader = new CSVRecordReader(numLinesToSkip, delimiter);
    recordReader.initialize(new InputStreamInputSplit(dataFile));

    //Second: the RecordReaderDataSetIterator handles conversion to DataSet objects, ready for use in a neural network
    int labelIndex = 4;     //5 values in each row of the iris.txt CSV: 4 input features followed by an integer label (class) index. Labels are the 5th value (index 4) in each row
    int numClasses = 3;     //3 classes (types of iris flowers) in the iris data set. Classes have integer values 0, 1 or 2

    DataSetIterator iterator = new RecordReaderDataSetIterator(recordReader, batchSize, labelIndex, numClasses);
    DataSet allData = iterator.next();
    allData.shuffle();

    SplitTestAndTrain testAndTrain = allData.splitTestAndTrain(0.80);  //Use 80% of data for training

    trainingData = testAndTrain.getTrain();
    testData = testAndTrain.getTest();

    //We need to normalize our data. We'll use NormalizerStandardize (which gives us mean 0, unit variance):
    DataNormalization normalizer = new NormalizerStandardize();
    normalizer.fit(trainingData);           //Collect the statistics (mean/stdev) from the training data. This does not modify the input data
    normalizer.transform(trainingData);     //Apply normalization to the training data
    normalizer.transform(testData);         //Apply normalization to the test data. This is using statistics calculated from the *training* set
}
 
Author: mccorby, Project: FederatedAndroidTrainer, Lines of code: 27, Source file: IrisFileDataSource.java

Example 2: createDataSource

import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; // import the package/class this method depends on
private void createDataSource() throws IOException, InterruptedException {
    //First: get the dataset using the record reader. CSVRecordReader handles loading/parsing
    int numLinesToSkip = 0;
    String delimiter = ",";
    RecordReader recordReader = new CSVRecordReader(numLinesToSkip, delimiter);
    recordReader.initialize(new InputStreamInputSplit(dataFile));

    //Second: the RecordReaderDataSetIterator handles conversion to DataSet objects, ready for use in a neural network
    int labelIndex = 11;

    DataSetIterator iterator = new RecordReaderDataSetIterator(recordReader, batchSize, labelIndex, labelIndex, true);
    DataSet allData = iterator.next();

    SplitTestAndTrain testAndTrain = allData.splitTestAndTrain(0.80);  //Use 80% of data for training

    trainingData = testAndTrain.getTrain();
    testData = testAndTrain.getTest();

    //We need to normalize our data. We'll use NormalizerStandardize (which gives us mean 0, unit variance):
    DataNormalization normalizer = new NormalizerStandardize();
    normalizer.fit(trainingData);           //Collect the statistics (mean/stdev) from the training data. This does not modify the input data
    normalizer.transform(trainingData);     //Apply normalization to the training data
    normalizer.transform(testData);         //Apply normalization to the test data. This is using statistics calculated from the *training* set
}
 
Author: mccorby, Project: FederatedAndroidTrainer, Lines of code: 25, Source file: DiabetesFileDataSource.java

Example 3: evalMnistTestSet

import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; // import the package/class this method depends on
private static void evalMnistTestSet(MultiLayerNetwork leNetModel) throws Exception {
	
       log.info("Load test data....");
       int batchSize = 64;
       DataSetIterator mnistTest = new MnistDataSetIterator(batchSize,false,12345);
	
       log.info("Evaluate model....");
       int outputNum = 10;
       Evaluation eval = new Evaluation(outputNum);
	
       while(mnistTest.hasNext()){
           DataSet dataSet = mnistTest.next();
           INDArray output = leNetModel.output(dataSet.getFeatureMatrix(), false);
           eval.eval(dataSet.getLabels(), output);
       }
	
       log.info(eval.stats());
}
 
Author: matthiaszimmermann, Project: ml_demo, Lines of code: 19, Source file: LeNetMnistTester.java

Example 4: testGetIteratorNominalClass

import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; // import the package/class this method depends on
/** Test getDataSetIterator */
@Test
public void testGetIteratorNominalClass() throws Exception {
  final Instances data = DatasetLoader.loadAngerMetaClassification();
  final int batchSize = 1;
  final DataSetIterator it = this.cteii.getDataSetIterator(data, SEED, batchSize);

  Set<Integer> labels = new HashSet<>();
  for (int i = 0; i < data.size(); i++) {
    Instance inst = data.get(i);

    int label = Integer.parseInt(inst.stringValue(data.classIndex()));
    final DataSet next = it.next();
    int itLabel = next.getLabels().argMax().getInt(0);
    Assert.assertEquals(label, itLabel);
    labels.add(label);
  }
  final Set<Integer> collect =
      it.getLabels().stream().map(s -> Double.valueOf(s).intValue()).collect(Collectors.toSet());
  Assert.assertEquals(2, labels.size());
  Assert.assertTrue(labels.containsAll(collect));
  Assert.assertTrue(collect.containsAll(labels));
}
 
Author: Waikato, Project: wekaDeeplearning4j, Lines of code: 24, Source file: CnnTextFilesEmbeddingInstanceIteratorTest.java

Example 5: testGetIteratorNumericClass

import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; // import the package/class this method depends on
/** Test getDataSetIterator */
@Test
public void testGetIteratorNumericClass() throws Exception {
  final Instances data = DatasetLoader.loadAngerMeta();
  final int batchSize = 1;
  final DataSetIterator it = this.cteii.getDataSetIterator(data, SEED, batchSize);

  Set<Double> labels = new HashSet<>();
  for (int i = 0; i < data.size(); i++) {
    Instance inst = data.get(i);
    double label = inst.value(data.classIndex());
    final DataSet next = it.next();
    double itLabel = next.getLabels().getDouble(0);
    Assert.assertEquals(label, itLabel, 1e-5);
    labels.add(label);
  }
}
 
Author: Waikato, Project: wekaDeeplearning4j, Lines of code: 18, Source file: CnnTextFilesEmbeddingInstanceIteratorTest.java

Example 6: testGetIteratorNominalClass

import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; // import the package/class this method depends on
/** Test getDataSetIterator */
@Test
public void testGetIteratorNominalClass() throws Exception {
  final Instances data = DatasetLoader.loadReutersMinimal();
  final int batchSize = 1;
  final DataSetIterator it = this.cteii.getDataSetIterator(data, SEED, batchSize);

  Set<Integer> labels = new HashSet<>();
  for (Instance inst : data) {
    int label = Integer.parseInt(inst.stringValue(data.classIndex()));
    final DataSet next = it.next();
    int itLabel = next.getLabels().argMax().getInt(0);
    Assert.assertEquals(label, itLabel);
    labels.add(label);
  }
  final Set<Integer> collect =
      it.getLabels().stream().map(s -> Double.valueOf(s).intValue()).collect(Collectors.toSet());
  Assert.assertEquals(2, labels.size());
  Assert.assertTrue(labels.containsAll(collect));
  Assert.assertTrue(collect.containsAll(labels));
}
 
Author: Waikato, Project: wekaDeeplearning4j, Lines of code: 22, Source file: CnnTextEmbeddingInstanceIteratorTest.java

Example 7: testGetIteratorNumericClass

import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; // import the package/class this method depends on
/** Test getDataSetIterator */
@Test
public void testGetIteratorNumericClass() throws Exception {
  final Instances data = makeData();
  final int batchSize = 1;
  final DataSetIterator it = this.cteii.getDataSetIterator(data, SEED, batchSize);

  Set<Double> labels = new HashSet<>();
  for (int i = 0; i < data.size(); i++) {
    Instance inst = data.get(i);
    double label = inst.value(data.classIndex());
    final DataSet next = it.next();
    double itLabel = next.getLabels().getDouble(0);
    Assert.assertEquals(label, itLabel, 1e-5);
    labels.add(label);
  }
}
 
Author: Waikato, Project: wekaDeeplearning4j, Lines of code: 18, Source file: CnnTextEmbeddingInstanceIteratorTest.java

Example 8: testGetIterator

import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; // import the package/class this method depends on
/** Test getDataSetIterator */
@Test
public void testGetIterator() throws Exception {
  final Instances metaData = DatasetLoader.loadMiniMnistMeta();
  this.idi.setImagesLocation(new File("datasets/nominal/mnist-minimal"));
  final int batchSize = 1;
  final DataSetIterator it = this.idi.getDataSetIterator(metaData, SEED, batchSize);

  Set<Integer> labels = new HashSet<>();
  for (Instance inst : metaData) {
    int label = Integer.parseInt(inst.stringValue(1));
    final DataSet next = it.next();
    int itLabel = next.getLabels().argMax().getInt(0);
    Assert.assertEquals(label, itLabel);
    labels.add(label);
  }
  final List<Integer> collect =
      it.getLabels().stream().map(Integer::valueOf).collect(Collectors.toList());
  Assert.assertEquals(10, labels.size());
  Assert.assertTrue(labels.containsAll(collect));
  Assert.assertTrue(collect.containsAll(labels));
}
 
Author: Waikato, Project: wekaDeeplearning4j, Lines of code: 23, Source file: ImageInstanceIteratorTest.java

Example 9: testOutputFormat

import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; // import the package/class this method depends on
@Test
public void testOutputFormat() throws Exception {
  for (int tl : Arrays.asList(10, 50, 200)) {
    rii.setTruncateLength(tl);
    for (int bs : Arrays.asList(1, 4, 8, 16)) {
      final DataSetIterator it = rii.getDataSetIterator(data, TestUtil.SEED, bs);
      assertEquals(bs, it.batch());
      assertEquals(Arrays.asList("0.0", "1.0"), it.getLabels());
      final DataSet next = it.next();

      // Check feature shape, expect: (batchsize x wordvecsize x sequencelength)
      final int[] shapeFeats = next.getFeatures().shape();
      final int[] expShapeFeats = {bs, 6, tl};
      assertEquals(expShapeFeats[0],shapeFeats[0]);
      assertEquals(expShapeFeats[1],shapeFeats[1]);
      assertTrue(expShapeFeats[2] >= shapeFeats[2]);

      // Check label shape, expect: (batchsize x numclasses x sequencelength)
      final int[] shapeLabels = next.getLabels().shape();
      final int[] expShapeLabels = {bs, data.numClasses(), tl};
      assertEquals(expShapeLabels[0], shapeLabels[0]);
      assertEquals(expShapeLabels[1], shapeLabels[1]);
      assertTrue(expShapeLabels[2] >= shapeLabels[2]);
    }
  }
}
 
Author: Waikato, Project: wekaDeeplearning4j, Lines of code: 27, Source file: RelationalInstanceIteratorTest.java

Example 10: main

import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; // import the package/class this method depends on
/**
 * args[0] input: word2vec file name
 * args[1] input: sentiment model name
 * args[2] input: test parent folder name
 *
 * @param args
 * @throws Exception
 */
public static void main (final String[] args) throws Exception {
  if (args.length < 3)  // command-line args are never null; check the count instead
    System.exit(1);

  WordVectors wvec = WordVectorSerializer.loadTxtVectors(new File(args[0]));
  MultiLayerNetwork model = ModelSerializer.restoreMultiLayerNetwork(args[1],false);

  DataSetIterator test = new AsyncDataSetIterator(
      new SentimentRecurrentIterator(args[2],wvec,100,300,false),1);
  Evaluation evaluation = new Evaluation();
  while(test.hasNext()) {
    DataSet t = test.next();
    INDArray features = t.getFeatures();
    INDArray labels = t.getLabels();
    INDArray inMask = t.getFeaturesMaskArray();
    INDArray outMask = t.getLabelsMaskArray();
    INDArray predicted = model.output(features,false,inMask,outMask);
    evaluation.evalTimeSeries(labels,predicted,outMask);
  }
  System.out.println(evaluation.stats());
}
 
Author: keigohtr, Project: sentiment-rnn, Lines of code: 30, Source file: SentimentRecurrentTestCmd.java

Example 11: assertCachingDataSetIteratorHasAllTheData

import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; // import the package/class this method depends on
private void assertCachingDataSetIteratorHasAllTheData(int rows, int inputColumns, int outputColumns,
                DataSet dataSet, DataSetIterator it, CachingDataSetIterator cachedIt) {
    cachedIt.reset();
    it.reset();

    dataSet.setFeatures(Nd4j.zeros(rows, inputColumns));
    dataSet.setLabels(Nd4j.ones(rows, outputColumns));

    while (it.hasNext()) {
        assertTrue(cachedIt.hasNext());

        DataSet cachedDs = cachedIt.next();
        assertEquals(1000.0, cachedDs.getFeatureMatrix().sumNumber());
        assertEquals(0.0, cachedDs.getLabels().sumNumber());

        DataSet ds = it.next();
        assertEquals(0.0, ds.getFeatureMatrix().sumNumber());
        assertEquals(20.0, ds.getLabels().sumNumber());
    }

    assertFalse(cachedIt.hasNext());
    assertFalse(it.hasNext());
}
 
Author: deeplearning4j, Project: nd4j, Lines of code: 24, Source file: CachingDataSetIteratorTest.java

Example 12: testItervsDataset

import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; // import the package/class this method depends on
public float testItervsDataset(DataNormalization preProcessor) {
    DataSet dataCopy = data.copy();
    DataSetIterator dataIter = new TestDataSetIterator(dataCopy, batchSize);
    preProcessor.fit(dataCopy);
    preProcessor.transform(dataCopy);
    INDArray transformA = dataCopy.getFeatures();

    preProcessor.fit(dataIter);
    dataIter.setPreProcessor(preProcessor);
    DataSet next = dataIter.next();
    INDArray transformB = next.getFeatures();

    while (dataIter.hasNext()) {
        next = dataIter.next();
        INDArray transformb = next.getFeatures();
        transformB = Nd4j.vstack(transformB, transformb);
    }

    return Transforms.abs(transformB.div(transformA).rsub(1)).maxNumber().floatValue();
}
 
Author: deeplearning4j, Project: nd4j, Lines of code: 21, Source file: NormalizerTests.java

Example 13: fit

import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; // import the package/class this method depends on
/**
 * Fit the given model
 *
 * @param iterator for the data to iterate over
 */
@Override
public void fit(DataSetIterator iterator) {
    S.Builder featureNormBuilder = newBuilder();
    S.Builder labelNormBuilder = newBuilder();

    iterator.reset();
    while (iterator.hasNext()) {
        DataSet next = iterator.next();
        featureNormBuilder.addFeatures(next);
        if (fitLabels) {
            labelNormBuilder.addLabels(next);
        }
    }
    featureStats = (S) featureNormBuilder.build();
    if (fitLabels) {
        labelStats = (S) labelNormBuilder.build();
    }
    iterator.reset();
}
 
Author: deeplearning4j, Project: nd4j, Lines of code: 25, Source file: AbstractDataSetNormalizer.java

Example 14: testPredict

import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; // import the package/class this method depends on
@Test
public void testPredict() throws Exception {

    Nd4j.getRandom().setSeed(12345);
    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                    .weightInit(WeightInit.XAVIER).seed(12345L).list()
                    .layer(0, new DenseLayer.Builder().nIn(784).nOut(50).activation(Activation.RELU).build())
                    .layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                                    .activation(Activation.SOFTMAX).nIn(50).nOut(10).build())
                    .pretrain(false).backprop(true).setInputType(InputType.convolutional(28, 28, 1)).build();

    MultiLayerNetwork net = new MultiLayerNetwork(conf);
    net.init();

    DataSetIterator ds = new MnistDataSetIterator(10, 10);
    net.fit(ds);

    DataSetIterator testDs = new MnistDataSetIterator(1, 1);
    DataSet testData = testDs.next();
    testData.setLabelNames(Arrays.asList("0", "1", "2", "3", "4", "5", "6", "7", "8", "9"));
    String actualLabels = testData.getLabelName(0);
    List<String> prediction = net.predict(testData);
    assertTrue(actualLabels != null);
    assertTrue(prediction.get(0) != null);
}
 
Author: deeplearning4j, Project: deeplearning4j, Lines of code: 26, Source file: MultiLayerTest.java

Example 15: output

import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; // import the package/class this method depends on
/**
 * Label the probabilities of the input
 *
 * @param iterator test data to evaluate
 * @return a vector of probabilities
 * given each label.
 * <p>
 * This is typically of the form:
 * [0.5, 0.5] or some other probability distribution summing to one
 */
public INDArray output(DataSetIterator iterator, boolean train) {
    List<INDArray> outList = new ArrayList<>();
    while (iterator.hasNext()) {
        DataSet next = iterator.next();

        if (next.getFeatureMatrix() == null || next.getLabels() == null)
            break;

        INDArray features = next.getFeatures();

        if (next.hasMaskArrays()) {
            INDArray fMask = next.getFeaturesMaskArray();
            INDArray lMask = next.getLabelsMaskArray();
            outList.add(this.output(features, train, fMask, lMask));

        } else {
            outList.add(output(features, train));
        }
    }
    return Nd4j.vstack(outList.toArray(new INDArray[0]));
}
 
Author: deeplearning4j, Project: deeplearning4j, Lines of code: 32, Source file: MultiLayerNetwork.java


Note: The org.nd4j.linalg.dataset.api.iterator.DataSetIterator.next method examples in this article were compiled by 纯净天空 from GitHub, MSDocs, and other open-source code and documentation platforms. The snippets are selected from open-source projects contributed by many developers; copyright of the source code remains with the original authors. Please consult each project's license before using or redistributing the code, and do not republish this article without permission.