

Java RecordReaderDataSetIterator Class Code Examples

This article collects typical usage examples of the Java class org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator. If you are asking yourself what RecordReaderDataSetIterator is for, how to use it, or where to find working examples, the hand-picked class examples below may help.


The RecordReaderDataSetIterator class belongs to the org.deeplearning4j.datasets.datavec package. A total of 15 code examples of the class are shown below, sorted by popularity by default. You can upvote the examples you like or find useful; your feedback helps the system recommend better Java code examples.
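Before the individual examples, here is a minimal sketch of the most common construction pattern: wrap a DataVec RecordReader and let the iterator convert each batch of records into a DataSet ready for a classifier. The file name, column indices, batch size, and class name are illustrative placeholders chosen to mirror the iris CSV used in several examples below; they are not values required by the API.

import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
import org.datavec.api.split.FileSplit;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;

import java.io.File;

public class RecordReaderDataSetIteratorSketch {
    public static void main(String[] args) throws Exception {
        // Assumed CSV layout: 4 feature columns followed by an integer class label (as in iris.txt)
        RecordReader recordReader = new CSVRecordReader();
        recordReader.initialize(new FileSplit(new File("iris.txt")));

        int batchSize = 50;   // examples per DataSet returned by next() (illustrative value)
        int labelIndex = 4;   // column index of the label (illustrative value)
        int numClasses = 3;   // number of distinct classes (illustrative value)

        // Classification constructor: (recordReader, batchSize, labelIndex, numClasses)
        DataSetIterator iterator =
                new RecordReaderDataSetIterator(recordReader, batchSize, labelIndex, numClasses);

        while (iterator.hasNext()) {
            DataSet batch = iterator.next();
            System.out.println("Batch with " + batch.numExamples() + " examples");
        }
    }
}

For regression targets, see Example 2 below, where the label column index is passed twice and a final boolean marks the iterator as a regression iterator.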

Example 1: createDataSource

import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator; // import the required package/class
private void createDataSource() throws IOException, InterruptedException {
    //First: get the dataset using the record reader. CSVRecordReader handles loading/parsing
    int numLinesToSkip = 0;
    String delimiter = ",";
    RecordReader recordReader = new CSVRecordReader(numLinesToSkip, delimiter);
    recordReader.initialize(new InputStreamInputSplit(dataFile));

    //Second: the RecordReaderDataSetIterator handles conversion to DataSet objects, ready for use in a neural network
    int labelIndex = 4;     //5 values in each row of the iris.txt CSV: 4 input features followed by an integer label (class) index. Labels are the 5th value (index 4) in each row
    int numClasses = 3;     //3 classes (types of iris flowers) in the iris data set. Classes have integer values 0, 1 or 2

    DataSetIterator iterator = new RecordReaderDataSetIterator(recordReader, batchSize, labelIndex, numClasses);
    DataSet allData = iterator.next();
    allData.shuffle();

    SplitTestAndTrain testAndTrain = allData.splitTestAndTrain(0.80);  //Use 80% of data for training

    trainingData = testAndTrain.getTrain();
    testData = testAndTrain.getTest();

    //We need to normalize our data. We'll use NormalizerStandardize (which gives us mean 0, unit variance):
    DataNormalization normalizer = new NormalizerStandardize();
    normalizer.fit(trainingData);           //Collect the statistics (mean/stdev) from the training data. This does not modify the input data
    normalizer.transform(trainingData);     //Apply normalization to the training data
    normalizer.transform(testData);         //Apply normalization to the test data. This is using statistics calculated from the *training* set
}
 
Developer: mccorby, Project: FederatedAndroidTrainer, Lines of code: 27, Source file: IrisFileDataSource.java

Example 2: createDataSource

import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator; // import the required package/class
private void createDataSource() throws IOException, InterruptedException {
    //First: get the dataset using the record reader. CSVRecordReader handles loading/parsing
    int numLinesToSkip = 0;
    String delimiter = ",";
    RecordReader recordReader = new CSVRecordReader(numLinesToSkip, delimiter);
    recordReader.initialize(new InputStreamInputSplit(dataFile));

    //Second: the RecordReaderDataSetIterator handles conversion to DataSet objects, ready for use in a neural network
    int labelIndex = 11;

    DataSetIterator iterator = new RecordReaderDataSetIterator(recordReader, batchSize, labelIndex, labelIndex, true);
    DataSet allData = iterator.next();

    SplitTestAndTrain testAndTrain = allData.splitTestAndTrain(0.80);  //Use 80% of data for training

    trainingData = testAndTrain.getTrain();
    testData = testAndTrain.getTest();

    //We need to normalize our data. We'll use NormalizerStandardize (which gives us mean 0, unit variance):
    DataNormalization normalizer = new NormalizerStandardize();
    normalizer.fit(trainingData);           //Collect the statistics (mean/stdev) from the training data. This does not modify the input data
    normalizer.transform(trainingData);     //Apply normalization to the training data
    normalizer.transform(testData);         //Apply normalization to the test data. This is using statistics calculated from the *training* set
}
 
Developer: mccorby, Project: FederatedAndroidTrainer, Lines of code: 25, Source file: DiabetesFileDataSource.java

Example 3: getDataSetIterator

import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator; // import the required package/class
/**
 * This method returns the iterator. Scales all intensity values: it divides them by 255.
 *
 * @param data the dataset to use
 * @param seed the seed for the random number generator
 * @param batchSize the batch size to use
 * @return the iterator
 * @throws Exception
 */
@Override
public DataSetIterator getDataSetIterator(Instances data, int seed, int batchSize)
    throws Exception {

  batchSize = Math.min(data.numInstances(), batchSize);
  validate(data);
  ImageRecordReader reader = getImageRecordReader(data);

  final int labelIndex = 1; // Use explicit label index position
  final int numPossibleLabels = data.numClasses();
  DataSetIterator tmpIter =
      new RecordReaderDataSetIterator(reader, batchSize, labelIndex, numPossibleLabels);
  DataNormalization scaler = new ImagePreProcessingScaler(0, 1);
  scaler.fit(tmpIter);
  tmpIter.setPreProcessor(scaler);
  return tmpIter;
}
 
Developer: Waikato, Project: wekaDeeplearning4j, Lines of code: 27, Source file: ImageInstanceIterator.java

Example 4: testLRN

import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator; // import the required package/class
@Test
public void testLRN() throws Exception {
    List<String> labels = new ArrayList<>(Arrays.asList("Zico", "Ziwang_Xu"));
    String rootDir = new ClassPathResource("lfwtest").getFile().getAbsolutePath();

    RecordReader reader = new ImageRecordReader(28, 28, 3);
    reader.initialize(new FileSplit(new File(rootDir)));
    DataSetIterator recordReader = new RecordReaderDataSetIterator(reader, 10, 1, labels.size());
    labels.remove("lfwtest");
    NeuralNetConfiguration.ListBuilder builder = (NeuralNetConfiguration.ListBuilder) incompleteLRN();
    builder.setInputType(InputType.convolutional(28, 28, 3));

    MultiLayerConfiguration conf = builder.build();

    ConvolutionLayer layer2 = (ConvolutionLayer) conf.getConf(3).getLayer();
    assertEquals(6, layer2.getNIn());

}
 
Developer: deeplearning4j, Project: deeplearning4j, Lines of code: 19, Source file: ConvolutionLayerSetupTest.java

Example 5: testNextAndReset

import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator; // import the required package/class
@Test
public void testNextAndReset() throws Exception {
    int epochs = 3;

    RecordReader rr = new CSVRecordReader();
    rr.initialize(new FileSplit(new ClassPathResource("iris.txt").getFile()));
    DataSetIterator iter = new RecordReaderDataSetIterator(rr, 150);
    MultipleEpochsIterator multiIter = new MultipleEpochsIterator(epochs, iter);

    assertTrue(multiIter.hasNext());
    while (multiIter.hasNext()) {
        DataSet path = multiIter.next();
        assertFalse(path == null);
    }
    assertEquals(epochs, multiIter.epochs);
}
 
Developer: deeplearning4j, Project: deeplearning4j, Lines of code: 17, Source file: MultipleEpochsIteratorTest.java

Example 6: testLoadFullDataSet

import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator; // import the required package/class
@Test
public void testLoadFullDataSet() throws Exception {
    int epochs = 3;

    RecordReader rr = new CSVRecordReader();
    rr.initialize(new FileSplit(new ClassPathResource("iris.txt").getFile()));
    DataSetIterator iter = new RecordReaderDataSetIterator(rr, 150);
    DataSet ds = iter.next(50);

    assertEquals(50, ds.getFeatures().size(0));

    MultipleEpochsIterator multiIter = new MultipleEpochsIterator(epochs, ds);

    assertTrue(multiIter.hasNext());
    int count = 0;
    while (multiIter.hasNext()) {
        DataSet path = multiIter.next();
        assertNotNull(path);
        assertEquals(50, path.numExamples(), 0);
        count++;
    }
    assertEquals(epochs, count);
    assertEquals(epochs, multiIter.epochs);
}
 
Developer: deeplearning4j, Project: deeplearning4j, Lines of code: 25, Source file: MultipleEpochsIteratorTest.java

Example 7: testLoadBatchDataSet

import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator; // import the required package/class
@Test
public void testLoadBatchDataSet() throws Exception {
    int epochs = 2;

    RecordReader rr = new CSVRecordReader();
    rr.initialize(new FileSplit(new ClassPathResource("iris.txt").getFile()));
    DataSetIterator iter = new RecordReaderDataSetIterator(rr, 150, 4, 3);
    DataSet ds = iter.next(20);
    assertEquals(20, ds.getFeatures().size(0));
    MultipleEpochsIterator multiIter = new MultipleEpochsIterator(epochs, ds);

    while (multiIter.hasNext()) {
        DataSet path = multiIter.next(10);
        assertNotNull(path);
        assertEquals(path.numExamples(), 10, 0.0);
    }

    assertEquals(epochs, multiIter.epochs);
}
 
Developer: deeplearning4j, Project: deeplearning4j, Lines of code: 20, Source file: MultipleEpochsIteratorTest.java

Example 8: testMnist

import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator; // import the required package/class
@Test
public void testMnist() throws Exception {
    ClassPathResource cpr = new ClassPathResource("mnist_first_200.txt");
    CSVRecordReader rr = new CSVRecordReader(0, ',');
    rr.initialize(new FileSplit(cpr.getTempFileFromArchive()));
    RecordReaderDataSetIterator dsi = new RecordReaderDataSetIterator(rr, 10, 0, 10);

    MnistDataSetIterator iter = new MnistDataSetIterator(10, 200, false, true, false, 0);

    while (dsi.hasNext()) {
        DataSet dsExp = dsi.next();
        DataSet dsAct = iter.next();

        INDArray fExp = dsExp.getFeatureMatrix();
        fExp.divi(255);
        INDArray lExp = dsExp.getLabels();

        INDArray fAct = dsAct.getFeatureMatrix();
        INDArray lAct = dsAct.getLabels();

        assertEquals(fExp, fAct);
        assertEquals(lExp, lAct);
    }
    assertFalse(iter.hasNext());
}
 
Developer: deeplearning4j, Project: deeplearning4j, Lines of code: 26, Source file: DataSetIteratorTest.java

Example 9: processBatchIfRequired

import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator; // import the required package/class
private void processBatchIfRequired(List<List<Writable>> list, boolean finalRecord) throws Exception {
    if (list.isEmpty())
        return;
    if (list.size() < batchSize && !finalRecord)
        return;

    RecordReader rr = new CollectionRecordReader(list);
    RecordReaderDataSetIterator iter = new RecordReaderDataSetIterator(rr, new SelfWritableConverter(), batchSize,
                    labelIndex, numPossibleLabels, regression);

    DataSet ds = iter.next();

    String filename = "dataset_" + uid + "_" + (outputCount++) + ".bin";

    URI uri = new URI(outputDir.getPath() + "/" + filename);
    FileSystem file = FileSystem.get(uri, conf);
    try (FSDataOutputStream out = file.create(new Path(uri))) {
        ds.save(out);
    }

    list.clear();
}
 
Developer: deeplearning4j, Project: deeplearning4j, Lines of code: 23, Source file: StringToDataSetExportFunction.java

Example 10: doPredict

import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator; // import the required package/class
@Override
protected Object doPredict(List<String> line) {
    try {
        ListStringSplit input = new ListStringSplit(Collections.singletonList(line));
        ListStringRecordReader rr = new ListStringRecordReader();
        rr.initialize(input);
        DataSetIterator iterator = new RecordReaderDataSetIterator(rr, 1);

        DataSet ds = iterator.next();
        INDArray prediction = model.output(ds.getFeatures());

        DataType outputType = types.get(this.output);
        switch (outputType) {
            case _float : return prediction.getDouble(0);
            case _class: {
                int numClasses = 2;
                double max = 0;
                int maxIndex = -1;
                for (int i = 0; i < numClasses; i++) {
                    if (prediction.getDouble(i) > max) { maxIndex = i; max = prediction.getDouble(i); }
                }
                return maxIndex;
//                return prediction.getInt(0,1); // numberOfClasses
            }
            default: throw new IllegalArgumentException("Output type not yet supported " + outputType);
        }
    } catch (Exception e) {
        throw new RuntimeException(e);
    }
}
 
Developer: neo4j-contrib, Project: neo4j-ml-procedures, Lines of code: 31, Source file: DL4JMLModel.java

Example 11: check

import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator; // import the required package/class
private void check(BufferedImage image) throws Exception
{
    ImageIO.write(image, "png", new File("tmp.png")); //saves the image to the tmp.png file
    ImageRecordReader reader = new ImageRecordReader(150, 150, 3);
    reader.initialize(new FileSplit(new File("tmp.png"))); //reads the tmp.png file
    DataSetIterator dataIter = new RecordReaderDataSetIterator(reader, 1);
    while (dataIter.hasNext())
    {
        //Normalize the data from the file
        DataNormalization normalization = new NormalizerMinMaxScaler();
        DataSet set = dataIter.next();
        normalization.fit(set);
        normalization.transform(set);

        INDArray array = MainGUI.model.output(set.getFeatures(), false); //send the data to the model and get the results

        //Process the results and print them in an understandable format (percentage scores)
        String txt = "";

        DecimalFormat df = new DecimalFormat("#.00");

        for (int i = 0; i < array.length(); i++)
        {
            txt += MainGUI.labels.get(i) + ": " + (array.getDouble(i)*100 < 1 ? "0" : "") + df.format((array.getDouble(i)*100)) + "%\n";
        }

        probabilityArea.setText(txt);
    }

    reader.close();
}
 
Developer: maksgraczyk, Project: DeepID, Lines of code: 32, Source file: Identification.java

Example 12: createInternal

import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator; // import the required package/class
private DataSetIterator createInternal(InputSplit inputSplit) throws IOException {
    ImageTransform imageTransform = imageTransformFactory.create();
    int width = imageTransformConfigurationResource.getScaledWidth();
    int height = imageTransformConfigurationResource.getScaledHeight();
    int channels = imageTransformConfigurationResource.getChannels();
    int batchSize = networkConfigurationResource.getBatchSize();
    int outputs = networkConfigurationResource.getOutputs();
    ImageRecordReader recordReader = new ImageRecordReader(height, width, channels, pathLabelGenerator);
    recordReader.initialize(inputSplit, imageTransform);
    RecordReaderDataSetIterator recordReaderDataSetIterator = new RecordReaderDataSetIterator(recordReader, batchSize, 1, outputs);
    DataNormalization scaler = new ImagePreProcessingScaler(0, 1);
    scaler.fit(recordReaderDataSetIterator);
    recordReaderDataSetIterator.setPreProcessor(scaler);
    return recordReaderDataSetIterator;
}
 
Developer: scaliby, Project: ceidg-captcha, Lines of code: 16, Source file: DataSetIteratorFactoryImpl.java

Example 13: readCSVDataset

import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator; // import the required package/class
private static DataSetIterator readCSVDataset(String csvFileClasspath, int BATCH_SIZE, int LABEL_INDEX, int numClasses)
        throws IOException, InterruptedException {

    RecordReader rr = new CSVRecordReader();
    rr.initialize(new FileSplit(new File(csvFileClasspath)));
    DataSetIterator iterator = new RecordReaderDataSetIterator(rr, BATCH_SIZE, LABEL_INDEX, numClasses);

    return iterator;
}
 
Developer: emara-geek, Project: arabic-characters-recognition, Lines of code: 10, Source file: ModelGenerator.java

Example 14: irisCsv

import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator; // import the required package/class
static DataIterator<NormalizerStandardize> irisCsv(String name) {
    CSVRecordReader recordReader = new CSVRecordReader(0, ",");
    try {
        recordReader.initialize(new FileSplit(new File(name)));
    } catch (Exception e) {
        e.printStackTrace();
    }

    int labelIndex = 4;     //5 values in each row of the iris.txt CSV: 4 input features followed by an integer label (class) index. Labels are the 5th value (index 4) in each row
    int numClasses = 3;     //3 classes (types of iris flowers) in the iris data set. Classes have integer values 0, 1 or 2
    int batchSize = 50;     //Iris data set: 150 examples total.

    RecordReaderDataSetIterator iterator = new RecordReaderDataSetIterator(
            recordReader,
            batchSize,
            labelIndex,
            numClasses
    );

    NormalizerStandardize normalizer = new NormalizerStandardize();

    while (iterator.hasNext()) {
        normalizer.fit(iterator.next());
    }
    iterator.reset();

    iterator.setPreProcessor(normalizer);

    return new DataIterator<>(iterator, normalizer);
}
 
Developer: wmeddie, Project: dl4j-trainer-archetype, Lines of code: 31, Source file: DataIterator.java

Example 15: main

import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator; // import the required package/class
public static void main(String... args) throws Exception {
    Options options = new Options();

    options.addOption("i", "input", true, "The file with test data.");
    options.addOption("m", "model", true, "Name of trained model file.");

    CommandLine cmd = new BasicParser().parse(options, args);

    String input = cmd.getOptionValue("i");
    String modelName = cmd.getOptionValue("m");

    if (cmd.hasOption("i") && cmd.hasOption("m")) {
        MultiLayerNetwork model = ModelSerializer.restoreMultiLayerNetwork(modelName);
        DataIterator<NormalizerStandardize> it = DataIterator.irisCsv(input);
        RecordReaderDataSetIterator testData = it.getIterator();
        NormalizerStandardize normalizer = it.getNormalizer();
        normalizer.load(
                new File(modelName + ".norm1"),
                new File(modelName + ".norm2"),
                new File(modelName + ".norm3"),
                new File(modelName + ".norm4")
        );

        Evaluation eval = new Evaluation(3);
        while (testData.hasNext()) {
            DataSet ds = testData.next();
            INDArray output = model.output(ds.getFeatureMatrix());
            eval.eval(ds.getLabels(), output);
        }

        log.info(eval.stats());
    } else {
        log.error("Invalid arguments.");

        new HelpFormatter().printHelp("Evaluate", options);
    }
}
 
Developer: wmeddie, Project: dl4j-trainer-archetype, Lines of code: 38, Source file: Evaluate.java


Note: The org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator class examples in this article were compiled by 纯净天空 from open-source code and documentation platforms such as GitHub and MSDocs. The snippets are selected from open-source projects contributed by their respective developers, and copyright remains with the original authors. Please consult each project's license before distributing or using the code, and do not republish without permission.