
Java DataSetIterator Class Code Examples

This article collects typical usage examples of org.nd4j.linalg.dataset.api.iterator.DataSetIterator in Java. If you are wondering how the DataSetIterator class is used in practice, the selected examples below may help.


The DataSetIterator class belongs to the org.nd4j.linalg.dataset.api.iterator package. The 15 code examples below are sorted by popularity by default.

Example 1: createDataSource

import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; // import the required package/class
private void createDataSource() throws IOException, InterruptedException {
    //First: get the dataset using the record reader. CSVRecordReader handles loading/parsing
    int numLinesToSkip = 0;
    String delimiter = ",";
    RecordReader recordReader = new CSVRecordReader(numLinesToSkip, delimiter);
    recordReader.initialize(new InputStreamInputSplit(dataFile));

    //Second: the RecordReaderDataSetIterator handles conversion to DataSet objects, ready for use in neural network
    int labelIndex = 4;     //5 values in each row of the iris.txt CSV: 4 input features followed by an integer label (class) index. Labels are the 5th value (index 4) in each row
    int numClasses = 3;     //3 classes (types of iris flowers) in the iris data set. Classes have integer values 0, 1 or 2

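    DataSetIterator iterator = new RecordReaderDataSetIterator(recordReader, batchSize, labelIndex, numClasses);
    DataSet allData = iterator.next();      //Returns one minibatch; this is the whole data set only if batchSize is at least the number of rows
    allData.shuffle();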

    SplitTestAndTrain testAndTrain = allData.splitTestAndTrain(0.80);  //Use 80% of data for training

    trainingData = testAndTrain.getTrain();
    testData = testAndTrain.getTest();

    //We need to normalize our data. We'll use NormalizerStandardize (which gives us mean 0, unit variance):
    DataNormalization normalizer = new NormalizerStandardize();
    normalizer.fit(trainingData);           //Collect the statistics (mean/stdev) from the training data. This does not modify the input data
    normalizer.transform(trainingData);     //Apply normalization to the training data
    normalizer.transform(testData);         //Apply normalization to the test data. This is using statistics calculated from the *training* set
}
 
Developer: mccorby, Project: FederatedAndroidTrainer, Lines of code: 27, Source: IrisFileDataSource.java
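The comments in Example 1 describe a fit-then-transform pattern: statistics are collected from the training data only, and the same statistics are then applied to both the training and the test data. The block below is a minimal plain-Java sketch of that idea. `StandardizeSketch` is a hypothetical class written for illustration and is not the ND4J `NormalizerStandardize` implementation; in particular, the use of the population standard deviation and the handling of constant columns are assumptions.

```java
import java.util.Arrays;

/** Plain-Java sketch (hypothetical class, not the ND4J API): fit on train, transform train and test. */
final class StandardizeSketch {
    private double[] mean;
    private double[] std;

    /** Collects per-column mean and population standard deviation; the input is not modified. */
    void fit(double[][] rows) {
        int cols = rows[0].length;
        mean = new double[cols];
        std = new double[cols];
        for (double[] row : rows) {
            for (int j = 0; j < cols; j++) {
                mean[j] += row[j];
            }
        }
        for (int j = 0; j < cols; j++) {
            mean[j] /= rows.length;
        }
        for (double[] row : rows) {
            for (int j = 0; j < cols; j++) {
                double d = row[j] - mean[j];
                std[j] += d * d;
            }
        }
        for (int j = 0; j < cols; j++) {
            std[j] = Math.sqrt(std[j] / rows.length);
            if (std[j] == 0.0) {
                std[j] = 1.0; // assumption: leave constant columns unscaled to avoid dividing by zero
            }
        }
    }

    /** Standardizes the rows in place with the statistics collected by fit. */
    void transform(double[][] rows) {
        for (double[] row : rows) {
            for (int j = 0; j < row.length; j++) {
                row[j] = (row[j] - mean[j]) / std[j];
            }
        }
    }

    public static void main(String[] args) {
        double[][] train = {{1, 10}, {3, 30}};
        double[][] test = {{5, 20}};
        StandardizeSketch normalizer = new StandardizeSketch();
        normalizer.fit(train);
        normalizer.transform(train);
        normalizer.transform(test); // reuses the training statistics
        System.out.println(Arrays.deepToString(train));
        System.out.println(Arrays.deepToString(test));
    }
}
```

Fitting on the training split alone keeps information about the test rows out of the preprocessing step, which is why the test data is transformed with the training statistics rather than its own.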

Example 2: createDataSource

import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; // import the required package/class
private void createDataSource() throws IOException, InterruptedException {
    //First: get the dataset using the record reader. CSVRecordReader handles loading/parsing
    int numLinesToSkip = 0;
    String delimiter = ",";
    RecordReader recordReader = new CSVRecordReader(numLinesToSkip, delimiter);
    recordReader.initialize(new InputStreamInputSplit(dataFile));

    //Second: the RecordReaderDataSetIterator handles conversion to DataSet objects, ready for use in neural network
    int labelIndex = 11;

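    //Regression mode: labelIndexFrom == labelIndexTo selects a single label column
    DataSetIterator iterator = new RecordReaderDataSetIterator(recordReader, batchSize, labelIndex, labelIndex, true);
    DataSet allData = iterator.next();      //Returns one minibatch; this is the whole data set only if batchSize is at least the number of rows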

    SplitTestAndTrain testAndTrain = allData.splitTestAndTrain(0.80);  //Use 80% of data for training

    trainingData = testAndTrain.getTrain();
    testData = testAndTrain.getTest();

    //We need to normalize our data. We'll use NormalizerStandardize (which gives us mean 0, unit variance):
    DataNormalization normalizer = new NormalizerStandardize();
    normalizer.fit(trainingData);           //Collect the statistics (mean/stdev) from the training data. This does not modify the input data
    normalizer.transform(trainingData);     //Apply normalization to the training data
    normalizer.transform(testData);         //Apply normalization to the test data. This is using statistics calculated from the *training* set
}
 
Developer: mccorby, Project: FederatedAndroidTrainer, Lines of code: 25, Source: DiabetesFileDataSource.java
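Unlike Example 1, Example 2 splits the data 80/20 without shuffling first, so the result depends on the row order of the input file. The following plain-Java sketch shows a seeded shuffle followed by a split. `SplitSketch` and `shuffleAndSplit` are hypothetical names used for illustration, not ND4J calls, and truncating the training size toward zero is an assumption.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Random;

/** Plain-Java sketch (hypothetical helper, not the ND4J API) of a seeded shuffle-then-split. */
final class SplitSketch {

    /** Returns [train, test]; the input list itself is left untouched. */
    static <T> List<List<T>> shuffleAndSplit(List<T> data, double trainFraction, long seed) {
        List<T> copy = new ArrayList<>(data);
        Collections.shuffle(copy, new Random(seed)); // a fixed seed makes the split reproducible
        int nTrain = (int) (copy.size() * trainFraction); // assumption: truncate toward zero
        List<List<T>> parts = new ArrayList<>();
        parts.add(new ArrayList<>(copy.subList(0, nTrain)));
        parts.add(new ArrayList<>(copy.subList(nTrain, copy.size())));
        return parts;
    }

    public static void main(String[] args) {
        List<Integer> rows = Arrays.asList(0, 1, 2, 3, 4, 5, 6, 7, 8, 9);
        List<List<Integer>> parts = shuffleAndSplit(rows, 0.80, 42L);
        System.out.println("train=" + parts.get(0) + " test=" + parts.get(1));
    }
}
```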

Example 3: getDataSetIterator

import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; // import the required package/class
public static DataSetIterator getDataSetIterator(String DATA_PATH, boolean isTraining, WordVectors wordVectors, int minibatchSize,
                                                  int maxSentenceLength, Random rng ){
    String path = FilenameUtils.concat(DATA_PATH, (isTraining ? "aclImdb/train/" : "aclImdb/test/"));
    String positiveBaseDir = FilenameUtils.concat(path, "pos");
    String negativeBaseDir = FilenameUtils.concat(path, "neg");

    File filePositive = new File(positiveBaseDir);
    File fileNegative = new File(negativeBaseDir);

    Map<String,List<File>> reviewFilesMap = new HashMap<>();
    reviewFilesMap.put("Positive", Arrays.asList(filePositive.listFiles()));
    reviewFilesMap.put("Negative", Arrays.asList(fileNegative.listFiles()));

    LabeledSentenceProvider sentenceProvider = new FileLabeledSentenceProvider(reviewFilesMap, rng);

    return new CnnSentenceDataSetIterator.Builder()
            .sentenceProvider(sentenceProvider)
            .wordVectors(wordVectors)
            .minibatchSize(minibatchSize)
            .maxSentenceLength(maxSentenceLength)
            .useNormalizedWordVectors(false)
            .build();
}
 
Developer: IsaacChanghau, Project: Word2VecfJava, Lines of code: 24, Source: CNNSentenceClassification.java

Example 4: getTrainingData

import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; // import the required package/class
private static DataSetIterator getTrainingData(int batchSize, Random rand){
    double [] sum = new double[nSamples];
    double [] input1 = new double[nSamples];
    double [] input2 = new double[nSamples];
    for (int i= 0; i< nSamples; i++) {
        int MIN_RANGE = 0;
        int MAX_RANGE = 3;
        input1[i] = MIN_RANGE + (MAX_RANGE - MIN_RANGE) * rand.nextDouble();
        input2[i] =  MIN_RANGE + (MAX_RANGE - MIN_RANGE) * rand.nextDouble();
        sum[i] = input1[i] + input2[i];
    }
    INDArray inputNDArray1 = Nd4j.create(input1, new int[]{nSamples,1});
    INDArray inputNDArray2 = Nd4j.create(input2, new int[]{nSamples,1});
    INDArray inputNDArray = Nd4j.hstack(inputNDArray1,inputNDArray2);
    INDArray outPut = Nd4j.create(sum, new int[]{nSamples, 1});
    DataSet dataSet = new DataSet(inputNDArray, outPut);
    List<DataSet> listDs = dataSet.asList();
    Collections.shuffle(listDs, rand);   // use the Random passed in as a parameter
    return new ListDataSetIterator(listDs,batchSize);

}
 
Developer: IsaacChanghau, Project: NeuralNetworksLite, Lines of code: 22, Source: RegressionSum.java
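Example 4 wraps a shuffled list of single-row data sets in a `ListDataSetIterator` with a given batch size. The sketch below illustrates the underlying batching idea in plain Java: a list is cut into consecutive chunks of at most `batchSize` items, with a shorter final chunk if the list does not divide evenly. `MiniBatchSketch` and `toBatches` are hypothetical names, and this is not how `ListDataSetIterator` is implemented internally.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/** Plain-Java sketch (hypothetical helper, not the ND4J API) of cutting a list into minibatches. */
final class MiniBatchSketch {

    /** Splits items into consecutive batches; the last batch may be smaller than batchSize. */
    static <T> List<List<T>> toBatches(List<T> items, int batchSize) {
        if (batchSize <= 0) {
            throw new IllegalArgumentException("batchSize must be positive");
        }
        List<List<T>> batches = new ArrayList<>();
        for (int start = 0; start < items.size(); start += batchSize) {
            int end = Math.min(start + batchSize, items.size());
            batches.add(new ArrayList<>(items.subList(start, end)));
        }
        return batches;
    }

    public static void main(String[] args) {
        List<String> samples = Arrays.asList("a", "b", "c", "d", "e", "f", "g");
        for (List<String> batch : toBatches(samples, 3)) {
            System.out.println(batch);
        }
    }
}
```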

Example 5: evalMnistTestSet

import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; // import the required package/class
private static void evalMnistTestSet(MultiLayerNetwork leNetModel) throws Exception {

    log.info("Load test data....");
    int batchSize = 64;
    DataSetIterator mnistTest = new MnistDataSetIterator(batchSize, false, 12345);

    log.info("Evaluate model....");
    int outputNum = 10;
    Evaluation eval = new Evaluation(outputNum);

    while (mnistTest.hasNext()) {
        DataSet dataSet = mnistTest.next();
        INDArray output = leNetModel.output(dataSet.getFeatureMatrix(), false);
        eval.eval(dataSet.getLabels(), output);
    }

    log.info(eval.stats());
}
 
Developer: matthiaszimmermann, Project: ml_demo, Lines of code: 19, Source: LeNetMnistTester.java
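The loop in Example 5 feeds each minibatch to the model and lets an `Evaluation` object accumulate results across batches. As a rough plain-Java illustration of that accumulation, the sketch below counts how often the arg-max of a one-hot label row matches the arg-max of the predicted row. `AccuracySketch` is a hypothetical class; the real `Evaluation` class reports more than accuracy, and its internals are not shown here.

```java
/** Plain-Java sketch (hypothetical class, not the DL4J Evaluation API) of batch-wise accuracy. */
final class AccuracySketch {
    private int correct;
    private int total;

    /** Index of the largest value; ties resolve to the lowest index. */
    static int argMax(double[] values) {
        int best = 0;
        for (int i = 1; i < values.length; i++) {
            if (values[i] > values[best]) {
                best = i;
            }
        }
        return best;
    }

    /** Accumulates one minibatch: one label row and one prediction row per example. */
    void eval(double[][] labels, double[][] predictions) {
        for (int i = 0; i < labels.length; i++) {
            if (argMax(labels[i]) == argMax(predictions[i])) {
                correct++;
            }
            total++;
        }
    }

    /** Fraction of correctly classified examples seen so far (0.0 before any batch). */
    double accuracy() {
        return total == 0 ? 0.0 : (double) correct / total;
    }

    public static void main(String[] args) {
        AccuracySketch eval = new AccuracySketch();
        eval.eval(new double[][] {{1, 0}, {0, 1}}, new double[][] {{0.9, 0.1}, {0.6, 0.4}});
        eval.eval(new double[][] {{0, 1}}, new double[][] {{0.2, 0.8}});
        System.out.println("accuracy = " + eval.accuracy());
    }
}
```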

Example 6: getDataSetIterator

import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; // import the required package/class
/**
 * Returns the actual iterator.
 *
 * @param data the dataset to use
 * @param seed the seed for the random number generator
 * @param batchSize the batch size to use
 * @return the DataSetIterator
 */
@Override
public DataSetIterator getDataSetIterator(Instances data, int seed, int batchSize)
    throws InvalidInputDataException, IOException {
  validate(data);
  initWordVectors();
  final LabeledSentenceProvider prov = getSentenceProvider(data);
  return new RnnTextEmbeddingDataSetIterator(
      data,
      wordVectors,
      tokenizerFactory,
      tokenPreProcess,
      stopwords,
      prov,
      batchSize,
      truncateLength);
}
 
Developer: Waikato, Project: wekaDeeplearning4j, Lines of code: 25, Source: RnnTextEmbeddingInstanceIterator.java

Example 7: getDataSetIterator

import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; // import the required package/class
@Override
public DataSetIterator getDataSetIterator(Instances data, int seed, int batchSize)
    throws InvalidInputDataException, IOException {
  validate(data);
  initWordVectors();
  final LabeledSentenceProvider sentenceProvider = getSentenceProvider(data);
  return new RnnTextEmbeddingDataSetIterator(
      data,
      wordVectors,
      tokenizerFactory,
      tokenPreProcess,
      stopwords,
      sentenceProvider,
      batchSize,
      truncateLength);
}
 
Developer: Waikato, Project: wekaDeeplearning4j, Lines of code: 17, Source: RnnTextFilesEmbeddingInstanceIterator.java

Example 8: getDataSetIterator

import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; // import the required package/class
/**
 * This method returns the iterator. Scales all intensity values: it divides them by 255.
 *
 * @param data the dataset to use
 * @param seed the seed for the random number generator
 * @param batchSize the batch size to use
 * @return the iterator
 * @throws Exception
 */
@Override
public DataSetIterator getDataSetIterator(Instances data, int seed, int batchSize)
    throws Exception {

  batchSize = Math.min(data.numInstances(), batchSize);
  validate(data);
  ImageRecordReader reader = getImageRecordReader(data);

  final int labelIndex = 1; // Use explicit label index position
  final int numPossibleLabels = data.numClasses();
  DataSetIterator tmpIter =
      new RecordReaderDataSetIterator(reader, batchSize, labelIndex, numPossibleLabels);
  DataNormalization scaler = new ImagePreProcessingScaler(0, 1);
  scaler.fit(tmpIter);
  tmpIter.setPreProcessor(scaler);
  return tmpIter;
}
 
Developer: Waikato, Project: wekaDeeplearning4j, Lines of code: 27, Source: ImageInstanceIterator.java
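The Javadoc in Example 8 says that intensity values are divided by 255, and the code attaches an `ImagePreProcessingScaler(0, 1)` as the iterator's pre-processor. The arithmetic behind scaling 8-bit pixel values into a target range can be sketched in plain Java as follows. `PixelScaleSketch` is a hypothetical helper, and the assumption that inputs lie in 0..255 is made for this sketch, not taken from the library.

```java
import java.util.Arrays;

/** Plain-Java sketch (hypothetical helper, not the ND4J API) of scaling 8-bit pixels to [min, max]. */
final class PixelScaleSketch {

    /** Maps each pixel from the assumed range 0..255 linearly onto [min, max]. */
    static double[] scale(int[] pixels, double min, double max) {
        double[] out = new double[pixels.length];
        for (int i = 0; i < pixels.length; i++) {
            out[i] = min + (pixels[i] / 255.0) * (max - min);
        }
        return out;
    }

    public static void main(String[] args) {
        int[] pixels = {0, 51, 255};
        System.out.println(Arrays.toString(scale(pixels, 0.0, 1.0)));
    }
}
```

With `min = 0` and `max = 1` the mapping reduces to a plain division by 255, which matches the description in the Javadoc.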

Example 9: initEarlyStopping

import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; // import the required package/class
/**
 * Initialize early stopping with the given data
 *
 * @param data Data
 * @return Augmented data - if early stopping applies, return train set without validation set
 * @throws Exception
 */
protected Instances initEarlyStopping(Instances data) throws Exception {
  // Split train/validation
  double valSplit = earlyStopping.getValidationSetPercentage();
  Instances trainData;
  Instances valData;
  if (useEarlyStopping()) {
    // Split in train and validation
    Instances[] insts = splitTrainVal(data, valSplit);
    trainData = insts[0];
    valData = insts[1];
    validateSplit(trainData, valData);
    DataSetIterator valIterator = getDataSetIterator(valData, cacheMode, "val");
    earlyStopping.init(valIterator);
  } else {
    // Keep the full data
    trainData = data;
  }

  return trainData;
}
 
Developer: Waikato, Project: wekaDeeplearning4j, Lines of code: 28, Source: Dl4jMlpClassifier.java

Example 10: testGetIteratorNominalClass

import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; // import the required package/class
/** Test getDataSetIterator */
@Test
public void testGetIteratorNominalClass() throws Exception {
  final Instances data = DatasetLoader.loadAngerMetaClassification();
  final int batchSize = 1;
  final DataSetIterator it = this.cteii.getDataSetIterator(data, SEED, batchSize);

  Set<Integer> labels = new HashSet<>();
  for (int i = 0; i < data.size(); i++) {
    Instance inst = data.get(i);

    int label = Integer.parseInt(inst.stringValue(data.classIndex()));
    final DataSet next = it.next();
    int itLabel = next.getLabels().argMax().getInt(0);
    Assert.assertEquals(label, itLabel);
    labels.add(label);
  }
  final Set<Integer> collect =
      it.getLabels().stream().map(s -> Double.valueOf(s).intValue()).collect(Collectors.toSet());
  Assert.assertEquals(2, labels.size());
  Assert.assertTrue(labels.containsAll(collect));
  Assert.assertTrue(collect.containsAll(labels));
}
 
Developer: Waikato, Project: wekaDeeplearning4j, Lines of code: 24, Source: CnnTextFilesEmbeddingInstanceIteratorTest.java

Example 11: testGetIteratorNumericClass

import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; // import the required package/class
/** Test getDataSetIterator */
@Test
public void testGetIteratorNumericClass() throws Exception {
  final Instances data = DatasetLoader.loadAngerMeta();
  final int batchSize = 1;
  final DataSetIterator it = this.cteii.getDataSetIterator(data, SEED, batchSize);

  Set<Double> labels = new HashSet<>();
  for (int i = 0; i < data.size(); i++) {
    Instance inst = data.get(i);
    double label = inst.value(data.classIndex());
    final DataSet next = it.next();
    double itLabel = next.getLabels().getDouble(0);
    Assert.assertEquals(label, itLabel, 1e-5);
    labels.add(label);
  }
}
 
Developer: Waikato, Project: wekaDeeplearning4j, Lines of code: 18, Source: CnnTextFilesEmbeddingInstanceIteratorTest.java

Example 12: testGetIteratorNominalClass

import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; // import the required package/class
/** Test getDataSetIterator */
@Test
public void testGetIteratorNominalClass() throws Exception {
  final Instances data = DatasetLoader.loadReutersMinimal();
  final int batchSize = 1;
  final DataSetIterator it = this.cteii.getDataSetIterator(data, SEED, batchSize);

  Set<Integer> labels = new HashSet<>();
  for (Instance inst : data) {
    int label = Integer.parseInt(inst.stringValue(data.classIndex()));
    final DataSet next = it.next();
    int itLabel = next.getLabels().argMax().getInt(0);
    Assert.assertEquals(label, itLabel);
    labels.add(label);
  }
  final Set<Integer> collect =
      it.getLabels().stream().map(s -> Double.valueOf(s).intValue()).collect(Collectors.toSet());
  Assert.assertEquals(2, labels.size());
  Assert.assertTrue(labels.containsAll(collect));
  Assert.assertTrue(collect.containsAll(labels));
}
 
Developer: Waikato, Project: wekaDeeplearning4j, Lines of code: 22, Source: CnnTextEmbeddingInstanceIteratorTest.java

Example 13: testGetIteratorNumericClass

import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; // import the required package/class
/** Test getDataSetIterator */
@Test
public void testGetIteratorNumericClass() throws Exception {
  final Instances data = makeData();
  final int batchSize = 1;
  final DataSetIterator it = this.cteii.getDataSetIterator(data, SEED, batchSize);

  Set<Double> labels = new HashSet<>();
  for (int i = 0; i < data.size(); i++) {
    Instance inst = data.get(i);
    double label = inst.value(data.classIndex());
    final DataSet next = it.next();
    double itLabel = next.getLabels().getDouble(0);
    Assert.assertEquals(label, itLabel, 1e-5);
    labels.add(label);
  }
}
 
Developer: Waikato, Project: wekaDeeplearning4j, Lines of code: 18, Source: CnnTextEmbeddingInstanceIteratorTest.java

Example 14: testGetIterator

import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; // import the required package/class
/** Test getDataSetIterator */
@Test
public void testGetIterator() throws Exception {
  final Instances metaData = DatasetLoader.loadMiniMnistMeta();
  this.idi.setImagesLocation(new File("datasets/nominal/mnist-minimal"));
  final int batchSize = 1;
  final DataSetIterator it = this.idi.getDataSetIterator(metaData, SEED, batchSize);

  Set<Integer> labels = new HashSet<>();
  for (Instance inst : metaData) {
    int label = Integer.parseInt(inst.stringValue(1));
    final DataSet next = it.next();
    int itLabel = next.getLabels().argMax().getInt(0);
    Assert.assertEquals(label, itLabel);
    labels.add(label);
  }
  final List<Integer> collect =
      it.getLabels().stream().map(Integer::valueOf).collect(Collectors.toList());
  Assert.assertEquals(10, labels.size());
  Assert.assertTrue(labels.containsAll(collect));
  Assert.assertTrue(collect.containsAll(labels));
}
 
Developer: Waikato, Project: wekaDeeplearning4j, Lines of code: 23, Source: ImageInstanceIteratorTest.java

Example 15: testOutputFormat

import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; // import the required package/class
@Test
public void testOutputFormat() throws Exception {
  for (int tl : Arrays.asList(10, 50, 200)) {
    rii.setTruncateLength(tl);
    for (int bs : Arrays.asList(1, 4, 8, 16)) {
      final DataSetIterator it = rii.getDataSetIterator(data, TestUtil.SEED, bs);
      assertEquals(bs, it.batch());
      assertEquals(Arrays.asList("0.0", "1.0"), it.getLabels());
      final DataSet next = it.next();

      // Check feature shape, expect: (batchsize x wordvecsize x sequencelength)
      final int[] shapeFeats = next.getFeatures().shape();
      final int[] expShapeFeats = {bs, 6, tl};
      assertEquals(expShapeFeats[0],shapeFeats[0]);
      assertEquals(expShapeFeats[1],shapeFeats[1]);
      assertTrue(expShapeFeats[2] >= shapeFeats[2]);

      // Check label shape, expect: (batchsize x numclasses x sequencelength)
      final int[] shapeLabels = next.getLabels().shape();
      final int[] expShapeLabels = {bs, data.numClasses(), tl};
      assertEquals(expShapeLabels[0], shapeLabels[0]);
      assertEquals(expShapeLabels[1], shapeLabels[1]);
      assertTrue(expShapeLabels[2] >= shapeLabels[2]);
    }
  }
}
 
Developer: Waikato, Project: wekaDeeplearning4j, Lines of code: 27, Source: RelationalInstanceIteratorTest.java
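The test in Example 15 checks the batch and feature dimensions for equality but only requires the sequence dimension to be at most the truncate length. One plausible reading, stated here as an assumption and not as a description of the wekaDeeplearning4j implementation, is that a batch is padded to its longest sequence and capped at the truncate length. The hypothetical helper below sketches that rule in plain Java.

```java
import java.util.Arrays;
import java.util.List;

/** Plain-Java sketch (hypothetical helper) of an assumed rule for the sequence dimension of a batch. */
final class TruncateSketch {

    /** Assumed rule: length of the longest sequence in the batch, capped at truncateLength. */
    static int batchSequenceLength(List<Integer> tokenCounts, int truncateLength) {
        int longest = 0;
        for (int count : tokenCounts) {
            longest = Math.max(longest, count);
        }
        return Math.min(longest, truncateLength);
    }

    public static void main(String[] args) {
        List<Integer> tokenCounts = Arrays.asList(3, 12, 7);
        for (int truncateLength : Arrays.asList(10, 50, 200)) {
            System.out.println(truncateLength + " -> " + batchSequenceLength(tokenCounts, truncateLength));
        }
    }
}
```

Under this assumed rule the sequence dimension never exceeds the truncate length but can be shorter, which is consistent with the `>=` comparison used in the test.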


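Note: The org.nd4j.linalg.dataset.api.iterator.DataSetIterator class examples in this article were compiled by 纯净天空 from open-source code and documentation platforms such as GitHub and MSDocs. The snippets were selected from open-source projects, and copyright of the source code belongs to the original authors. For distribution and use, refer to the license of the corresponding project. Do not reproduce without permission.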