Java RecordReader.initialize Method Code Examples

This article collects typical usage examples of the org.datavec.api.records.reader.RecordReader.initialize method in Java. If you are wondering how RecordReader.initialize is used in practice, or are looking for working examples of it, the curated code examples here may help. You can also explore further usage examples of the enclosing class, org.datavec.api.records.reader.RecordReader.


A total of 15 code examples of the RecordReader.initialize method are shown below, sorted by popularity by default.
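
The examples below share one lifecycle: construct the reader with its parsing options, call initialize with a description of where the data lives, then iterate with hasNext() and next(). The following is a minimal, standard-library-only sketch of that lifecycle. MiniCsvReader and its methods are hypothetical stand-ins written for illustration; they are not DataVec classes, and the real initialize takes an InputSplit rather than a java.io.Reader.

```java
import java.io.BufferedReader;
import java.io.IOException;
import java.io.Reader;
import java.io.StringReader;
import java.io.UncheckedIOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.NoSuchElementException;
import java.util.regex.Pattern;

// Hypothetical stand-in for a record reader; this is NOT a DataVec class.
class MiniCsvReader {
    private final int skipLines;
    private final String delimiter;
    private List<String> lines; // stays null until initialize() is called
    private int cursor;

    // Construction only stores parsing options; no data is touched yet.
    MiniCsvReader(int skipLines, String delimiter) {
        this.skipLines = skipLines;
        this.delimiter = delimiter;
    }

    // Binds the reader to a data source and positions it after the skipped lines.
    void initialize(Reader source) {
        List<String> loaded = new ArrayList<>();
        try (BufferedReader br = new BufferedReader(source)) {
            String line;
            while ((line = br.readLine()) != null) {
                loaded.add(line);
            }
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
        lines = loaded;
        cursor = Math.min(skipLines, lines.size());
    }

    boolean hasNext() {
        if (lines == null) {
            throw new IllegalStateException("initialize() must be called before reading");
        }
        return cursor < lines.size();
    }

    // Returns the next line split into its fields.
    List<String> next() {
        if (!hasNext()) {
            throw new NoSuchElementException();
        }
        return Arrays.asList(lines.get(cursor++).split(Pattern.quote(delimiter)));
    }
}

class MiniCsvReaderDemo {
    public static void main(String[] args) {
        MiniCsvReader reader = new MiniCsvReader(1, ","); // skip one header line
        reader.initialize(new StringReader("a,b,label\n5.1,3.5,0\n4.9,3.0,1\n"));
        while (reader.hasNext()) {
            System.out.println(reader.next());
        }
    }
}
```

Separating construction from initialization lets the same configured reader be created before the data location is known, which is the property several of the test examples below rely on.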

Example 1: createDataSource

import org.datavec.api.records.reader.RecordReader; // import the package/class this method depends on
private void createDataSource() throws IOException, InterruptedException {
    //First: get the dataset using the record reader. CSVRecordReader handles loading/parsing
    int numLinesToSkip = 0;
    String delimiter = ",";
    RecordReader recordReader = new CSVRecordReader(numLinesToSkip, delimiter);
    recordReader.initialize(new InputStreamInputSplit(dataFile));

    //Second: the RecordReaderDataSetIterator handles conversion to DataSet objects, ready for use in neural network
    int labelIndex = 4;     //5 values in each row of the iris.txt CSV: 4 input features followed by an integer label (class) index. Labels are the 5th value (index 4) in each row
    int numClasses = 3;     //3 classes (types of iris flowers) in the iris data set. Classes have integer values 0, 1 or 2

    DataSetIterator iterator = new RecordReaderDataSetIterator(recordReader, batchSize, labelIndex, numClasses);
    DataSet allData = iterator.next();
    allData.shuffle();

    SplitTestAndTrain testAndTrain = allData.splitTestAndTrain(0.80);  //Use 80% of data for training

    trainingData = testAndTrain.getTrain();
    testData = testAndTrain.getTest();

    //We need to normalize our data. We'll use NormalizerStandardize (which gives us mean 0, unit variance):
    DataNormalization normalizer = new NormalizerStandardize();
    normalizer.fit(trainingData);           //Collect the statistics (mean/stdev) from the training data. This does not modify the input data
    normalizer.transform(trainingData);     //Apply normalization to the training data
    normalizer.transform(testData);         //Apply normalization to the test data. This is using statistics calculated from the *training* set
}
 
Developer ID: mccorby, Project: FederatedAndroidTrainer, Lines of code: 27, Source file: IrisFileDataSource.java
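
The comments in the createDataSource method above stress that the mean and standard deviation are collected from the training split only and then applied to both splits. A minimal plain-Java sketch of that idea for a single feature column follows. ColumnStandardizer is a hypothetical helper, not ND4J's NormalizerStandardize; it uses the population standard deviation as an assumption and, unlike the transform calls above, returns a copy instead of modifying its input.

```java
import java.util.Arrays;

// Hypothetical single-column standardizer; NOT ND4J's NormalizerStandardize.
class ColumnStandardizer {
    private double mean;
    private double std = 1.0;

    // Collects statistics only; the input array is not modified.
    void fit(double[] train) {
        double sum = 0.0;
        for (double v : train) {
            sum += v;
        }
        mean = sum / train.length;

        double squaredDiffs = 0.0;
        for (double v : train) {
            squaredDiffs += (v - mean) * (v - mean);
        }
        std = Math.sqrt(squaredDiffs / train.length); // population std (assumption)
        if (std == 0.0) {
            std = 1.0; // constant column: avoid dividing by zero
        }
    }

    // Returns a standardized copy using the statistics gathered by fit().
    double[] transform(double[] data) {
        double[] out = new double[data.length];
        for (int i = 0; i < data.length; i++) {
            out[i] = (data[i] - mean) / std;
        }
        return out;
    }
}

class ColumnStandardizerDemo {
    public static void main(String[] args) {
        double[] train = {1.0, 2.0, 3.0, 4.0};
        double[] test = {2.5, 5.0};

        ColumnStandardizer standardizer = new ColumnStandardizer();
        standardizer.fit(train); // statistics come from the training split only

        System.out.println(Arrays.toString(standardizer.transform(train)));
        System.out.println(Arrays.toString(standardizer.transform(test)));
    }
}
```

Fitting on the training split alone keeps information about the test split out of preprocessing, so the test values are scaled with statistics the model could legitimately have seen.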

Example 2: createDataSource

import org.datavec.api.records.reader.RecordReader; // import the package/class this method depends on
private void createDataSource() throws IOException, InterruptedException {
    //First: get the dataset using the record reader. CSVRecordReader handles loading/parsing
    int numLinesToSkip = 0;
    String delimiter = ",";
    RecordReader recordReader = new CSVRecordReader(numLinesToSkip, delimiter);
    recordReader.initialize(new InputStreamInputSplit(dataFile));

    //Second: the RecordReaderDataSetIterator handles conversion to DataSet objects, ready for use in neural network
    int labelIndex = 11;

    DataSetIterator iterator = new RecordReaderDataSetIterator(recordReader, batchSize, labelIndex, labelIndex, true);
    DataSet allData = iterator.next();

    SplitTestAndTrain testAndTrain = allData.splitTestAndTrain(0.80);  //Use 80% of data for training

    trainingData = testAndTrain.getTrain();
    testData = testAndTrain.getTest();

    //We need to normalize our data. We'll use NormalizerStandardize (which gives us mean 0, unit variance):
    DataNormalization normalizer = new NormalizerStandardize();
    normalizer.fit(trainingData);           //Collect the statistics (mean/stdev) from the training data. This does not modify the input data
    normalizer.transform(trainingData);     //Apply normalization to the training data
    normalizer.transform(testData);         //Apply normalization to the test data. This is using statistics calculated from the *training* set
}
 
Developer ID: mccorby, Project: FederatedAndroidTrainer, Lines of code: 25, Source file: DiabetesFileDataSource.java

Example 3: testCsvRRSerializationResults

import org.datavec.api.records.reader.RecordReader; // import the package/class this method depends on
@Test
public void testCsvRRSerializationResults() throws Exception {
    int skipLines = 3;
    RecordReader r1 = new CSVRecordReader(skipLines, '\t');
    ByteArrayOutputStream baos = new ByteArrayOutputStream();
    ObjectOutputStream os = new ObjectOutputStream(baos);
    os.writeObject(r1);
    byte[] bytes = baos.toByteArray();
    ObjectInputStream ois = new ObjectInputStream(new ByteArrayInputStream(bytes));
    RecordReader r2 = (RecordReader) ois.readObject();

    File f = new ClassPathResource("iris_tab_delim.txt").getFile();

    r1.initialize(new FileSplit(f));
    r2.initialize(new FileSplit(f));

    int count = 0;
    while(r1.hasNext()){
        List<Writable> n1 = r1.next();
        List<Writable> n2 = r2.next();
        assertEquals(n1, n2);
        count++;
    }

    assertEquals(150-skipLines, count);
}
 
Developer ID: deeplearning4j, Project: DataVec, Lines of code: 27, Source file: TestSerialization.java
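
The testCsvRRSerializationResults test above serializes the reader before initialize is called, so whatever travels through the stream is the reader's pre-initialization state, and each copy is bound to the file afterwards. The same round trip can be sketched with standard Java serialization alone. ReaderConfig and SerializationRoundTrip are hypothetical classes for illustration and say nothing about which fields CSVRecordReader actually serializes.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;

// Hypothetical holder for reader options; NOT a DataVec class.
class ReaderConfig implements Serializable {
    private static final long serialVersionUID = 1L;
    final int skipLines;
    final char delimiter;

    ReaderConfig(int skipLines, char delimiter) {
        this.skipLines = skipLines;
        this.delimiter = delimiter;
    }
}

class SerializationRoundTrip {
    // Writes the object to an in-memory buffer and reads a fresh copy back.
    @SuppressWarnings("unchecked")
    static <T extends Serializable> T roundTrip(T original) {
        try {
            ByteArrayOutputStream buffer = new ByteArrayOutputStream();
            try (ObjectOutputStream out = new ObjectOutputStream(buffer)) {
                out.writeObject(original);
            } // closing the stream flushes all buffered bytes
            try (ObjectInputStream in =
                    new ObjectInputStream(new ByteArrayInputStream(buffer.toByteArray()))) {
                return (T) in.readObject();
            }
        } catch (IOException | ClassNotFoundException e) {
            throw new IllegalStateException("Serialization round trip failed", e);
        }
    }

    public static void main(String[] args) {
        ReaderConfig original = new ReaderConfig(3, '\t');
        ReaderConfig copy = roundTrip(original);
        System.out.println(copy != original);                    // a distinct object
        System.out.println(copy.skipLines == original.skipLines); // same options
    }
}
```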

Example 4: testCsvRecordReader

import org.datavec.api.records.reader.RecordReader; // import the package/class this method depends on
@Test
public void testCsvRecordReader() throws Exception {
    SerializerInstance si = sc.env().serializer().newInstance();
    assertTrue(si instanceof KryoSerializerInstance);

    RecordReader r1 = new CSVRecordReader(1,'\t');
    RecordReader r2 = serDe(r1, si);

    File f = new ClassPathResource("iris_tab_delim.txt").getFile();
    r1.initialize(new FileSplit(f));
    r2.initialize(new FileSplit(f));

    while(r1.hasNext()){
        assertEquals(r1.next(), r2.next());
    }
    assertFalse(r2.hasNext());
}
 
Developer ID: deeplearning4j, Project: DataVec, Lines of code: 18, Source file: TestKryoSerialization.java

Example 5: testReadingJson

import org.datavec.api.records.reader.RecordReader; // import the package/class this method depends on
@Test
public void testReadingJson() throws Exception {
    //Load 3 values from 3 JSON files
    //structure: a:value, b:value, c:x:value, c:y:value
    //And we want to load only a:value, b:value and c:x:value
    //For first JSON file: all values are present
    //For second JSON file: b:value is missing
    //For third JSON file: c:x:value is missing

    ClassPathResource cpr = new ClassPathResource("json/json_test_0.txt");
    String path = cpr.getFile().getAbsolutePath().replace("0", "%d");

    InputSplit is = new NumberedFileInputSplit(path, 0, 2);

    RecordReader rr = new JacksonRecordReader(getFieldSelection(), new ObjectMapper(new JsonFactory()));
    rr.initialize(is);

    testJacksonRecordReader(rr);
}
 
Developer ID: deeplearning4j, Project: DataVec, Lines of code: 20, Source file: JacksonRecordReaderTest.java
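
The testReadingJson test above turns a concrete file path into a pattern containing %d and passes it together with a minimum and maximum index (0 and 2). Note that String.replace substitutes every "0" in the absolute path, so a directory name containing a zero would be rewritten as well. The sketch below shows how such a pattern expands and one way to restrict the substitution to the last occurrence. NumberedPaths and its methods are hypothetical helpers, not DataVec API, and the expansion assumes String.format-style handling of %d.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical helpers for numbered file patterns; NOT part of DataVec.
class NumberedPaths {
    // Expands a pattern such as ".../json_test_%d.txt" for min..max inclusive.
    // Assumes the pattern contains no other '%' characters.
    static List<String> expand(String pattern, int minInclusive, int maxInclusive) {
        List<String> paths = new ArrayList<>();
        for (int i = minInclusive; i <= maxInclusive; i++) {
            paths.add(String.format(pattern, i));
        }
        return paths;
    }

    // Replaces only the LAST occurrence of token with "%d".
    static String toPattern(String path, String token) {
        int index = path.lastIndexOf(token);
        if (index < 0) {
            throw new IllegalArgumentException("Token not found in path: " + token);
        }
        return path.substring(0, index) + "%d" + path.substring(index + token.length());
    }

    public static void main(String[] args) {
        String path = "/tmp/run0/json_test_0.txt"; // made-up path for illustration
        System.out.println(path.replace("0", "%d")); // rewrites the directory name too
        String pattern = toPattern(path, "0");
        System.out.println(expand(pattern, 0, 2));
    }
}
```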

Example 6: testLRN

import org.datavec.api.records.reader.RecordReader; // import the package/class this method depends on
@Test
public void testLRN() throws Exception {
    List<String> labels = new ArrayList<>(Arrays.asList("Zico", "Ziwang_Xu"));
    String rootDir = new ClassPathResource("lfwtest").getFile().getAbsolutePath();

    RecordReader reader = new ImageRecordReader(28, 28, 3);
    reader.initialize(new FileSplit(new File(rootDir)));
    DataSetIterator recordReader = new RecordReaderDataSetIterator(reader, 10, 1, labels.size());
    labels.remove("lfwtest");
    NeuralNetConfiguration.ListBuilder builder = (NeuralNetConfiguration.ListBuilder) incompleteLRN();
    builder.setInputType(InputType.convolutional(28, 28, 3));

    MultiLayerConfiguration conf = builder.build();

    ConvolutionLayer layer2 = (ConvolutionLayer) conf.getConf(3).getLayer();
    assertEquals(6, layer2.getNIn());

}
 
Developer ID: deeplearning4j, Project: deeplearning4j, Lines of code: 19, Source file: ConvolutionLayerSetupTest.java

Example 7: testRRDSIwithAsync

import org.datavec.api.records.reader.RecordReader; // import the package/class this method depends on
@Test
public void testRRDSIwithAsync() throws Exception {
    RecordReader csv = new CSVRecordReader();
    csv.initialize(new FileSplit(new ClassPathResource("iris.txt").getTempFileFromArchive()));

    int batchSize = 10;
    int labelIdx = 4;
    int numClasses = 3;

    RecordReaderDataSetIterator rrdsi = new RecordReaderDataSetIterator(csv, batchSize, labelIdx, numClasses);
    AsyncDataSetIterator adsi = new AsyncDataSetIterator(rrdsi, 8, true);
    while (adsi.hasNext()) {
        DataSet ds = adsi.next();

    }

}
 
Developer ID: deeplearning4j, Project: deeplearning4j, Lines of code: 18, Source file: RecordReaderDataSetiteratorTest.java

Example 8: testNormalizerPrefetchReset

import org.datavec.api.records.reader.RecordReader; // import the package/class this method depends on
@Test
public void testNormalizerPrefetchReset() throws Exception {
    //Check NPE fix for: https://github.com/deeplearning4j/deeplearning4j/issues/4214
    RecordReader csv = new CSVRecordReader();
    csv.initialize(new FileSplit(new ClassPathResource("iris.txt").getTempFileFromArchive()));

    int batchSize = 3;

    DataSetIterator iter = new RecordReaderDataSetIterator(csv, batchSize, 4, 4, true);

    DataNormalization normalizer = new NormalizerMinMaxScaler(0, 1);
    normalizer.fit(iter);
    iter.setPreProcessor(normalizer);

    iter.inputColumns();    //Prefetch
    iter.totalOutcomes();
    iter.hasNext();
    iter.reset();
    iter.next();
}
 
Developer ID: deeplearning4j, Project: deeplearning4j, Lines of code: 21, Source file: RecordReaderDataSetiteratorTest.java

Example 9: testsBasicMeta

import org.datavec.api.records.reader.RecordReader; // import the package/class this method depends on
@Test
public void testsBasicMeta() throws Exception {
    //As per testBasic - but also loading metadata
    RecordReader rr2 = new CSVRecordReader(0, ',');
    rr2.initialize(new FileSplit(new ClassPathResource("iris.txt").getTempFileFromArchive()));

    RecordReaderMultiDataSetIterator rrmdsi = new RecordReaderMultiDataSetIterator.Builder(10)
                    .addReader("reader", rr2).addInput("reader", 0, 3).addOutputOneHot("reader", 4, 3).build();

    rrmdsi.setCollectMetaData(true);

    int count = 0;
    while (rrmdsi.hasNext()) {
        MultiDataSet mds = rrmdsi.next();
        MultiDataSet fromMeta = rrmdsi.loadFromMetaData(mds.getExampleMetaData(RecordMetaData.class));
        assertEquals(mds, fromMeta);
        count++;
    }
    assertEquals(150 / 10, count);
}
 
Developer ID: deeplearning4j, Project: deeplearning4j, Lines of code: 21, Source file: RecordReaderMultiDataSetIteratorTest.java
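
In the testsBasicMeta test above, columns 0 to 3 of each record become the input features and column 4 is expanded into a one-hot vector over 3 classes. A plain-Java sketch of that split for a single record follows. RecordSplitter is a hypothetical helper that only illustrates the column arithmetic; it is not how RecordReaderMultiDataSetIterator is implemented.

```java
import java.util.Arrays;

// Hypothetical helper illustrating the column split; NOT DL4J code.
class RecordSplitter {
    // Copies columns from..toInclusive (both inclusive) as the feature vector.
    static double[] features(double[] record, int from, int toInclusive) {
        return Arrays.copyOfRange(record, from, toInclusive + 1);
    }

    // Turns the integer class index stored in one column into a one-hot vector.
    static double[] oneHot(double[] record, int column, int numClasses) {
        int classIndex = (int) record[column];
        if (classIndex < 0 || classIndex >= numClasses) {
            throw new IllegalArgumentException("Class index out of range: " + classIndex);
        }
        double[] encoded = new double[numClasses]; // all zeros initially
        encoded[classIndex] = 1.0;
        return encoded;
    }

    public static void main(String[] args) {
        // Made-up row in the Iris layout: four measurements, then the class index.
        double[] record = {5.1, 3.5, 1.4, 0.2, 2.0};
        System.out.println(Arrays.toString(features(record, 0, 3)));
        System.out.println(Arrays.toString(oneHot(record, 4, 3)));
    }
}
```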

Example 10: testSplittingCSVMeta

import org.datavec.api.records.reader.RecordReader; // import the package/class this method depends on
@Test
public void testSplittingCSVMeta() throws Exception {
    //Here's the idea: take Iris, and split it up into 2 inputs and 2 output arrays
    //Inputs: columns 0 and 1-2
    //Outputs: columns 3, and 4->OneHot
    RecordReader rr2 = new CSVRecordReader(0, ',');
    rr2.initialize(new FileSplit(new ClassPathResource("iris.txt").getTempFileFromArchive()));

    RecordReaderMultiDataSetIterator rrmdsi = new RecordReaderMultiDataSetIterator.Builder(10)
                    .addReader("reader", rr2).addInput("reader", 0, 0).addInput("reader", 1, 2)
                    .addOutput("reader", 3, 3).addOutputOneHot("reader", 4, 3).build();
    rrmdsi.setCollectMetaData(true);

    int count = 0;
    while (rrmdsi.hasNext()) {
        MultiDataSet mds = rrmdsi.next();
        MultiDataSet fromMeta = rrmdsi.loadFromMetaData(mds.getExampleMetaData(RecordMetaData.class));
        assertEquals(mds, fromMeta);
        count++;
    }
    assertEquals(150 / 10, count);
}
 
Developer ID: deeplearning4j, Project: deeplearning4j, Lines of code: 23, Source file: RecordReaderMultiDataSetIteratorTest.java

Example 11: testNextAndReset

import org.datavec.api.records.reader.RecordReader; // import the package/class this method depends on
@Test
public void testNextAndReset() throws Exception {
    int epochs = 3;

    RecordReader rr = new CSVRecordReader();
    rr.initialize(new FileSplit(new ClassPathResource("iris.txt").getFile()));
    DataSetIterator iter = new RecordReaderDataSetIterator(rr, 150);
    MultipleEpochsIterator multiIter = new MultipleEpochsIterator(epochs, iter);

    assertTrue(multiIter.hasNext());
    while (multiIter.hasNext()) {
        DataSet path = multiIter.next();
        assertNotNull(path);
    }
    assertEquals(epochs, multiIter.epochs);
}
 
Developer ID: deeplearning4j, Project: deeplearning4j, Lines of code: 17, Source file: MultipleEpochsIteratorTest.java

Example 12: testLoadFullDataSet

import org.datavec.api.records.reader.RecordReader; // import the package/class this method depends on
@Test
public void testLoadFullDataSet() throws Exception {
    int epochs = 3;

    RecordReader rr = new CSVRecordReader();
    rr.initialize(new FileSplit(new ClassPathResource("iris.txt").getFile()));
    DataSetIterator iter = new RecordReaderDataSetIterator(rr, 150);
    DataSet ds = iter.next(50);

    assertEquals(50, ds.getFeatures().size(0));

    MultipleEpochsIterator multiIter = new MultipleEpochsIterator(epochs, ds);

    assertTrue(multiIter.hasNext());
    int count = 0;
    while (multiIter.hasNext()) {
        DataSet path = multiIter.next();
        assertNotNull(path);
        assertEquals(50, path.numExamples(), 0);
        count++;
    }
    assertEquals(epochs, count);
    assertEquals(epochs, multiIter.epochs);
}
 
Developer ID: deeplearning4j, Project: deeplearning4j, Lines of code: 25, Source file: MultipleEpochsIteratorTest.java

Example 13: testLoadBatchDataSet

import org.datavec.api.records.reader.RecordReader; // import the package/class this method depends on
@Test
public void testLoadBatchDataSet() throws Exception {
    int epochs = 2;

    RecordReader rr = new CSVRecordReader();
    rr.initialize(new FileSplit(new ClassPathResource("iris.txt").getFile()));
    DataSetIterator iter = new RecordReaderDataSetIterator(rr, 150, 4, 3);
    DataSet ds = iter.next(20);
    assertEquals(20, ds.getFeatures().size(0));
    MultipleEpochsIterator multiIter = new MultipleEpochsIterator(epochs, ds);

    while (multiIter.hasNext()) {
        DataSet path = multiIter.next(10);
        assertNotNull(path);
        assertEquals(10, path.numExamples(), 0.0);   //expected value first, then actual
    }

    assertEquals(epochs, multiIter.epochs);
}
 
Developer ID: deeplearning4j, Project: deeplearning4j, Lines of code: 20, Source file: MultipleEpochsIteratorTest.java

Example 14: readCSVDataset

import org.datavec.api.records.reader.RecordReader; // import the package/class this method depends on
private static DataSetIterator readCSVDataset(String csvFileClasspath, int BATCH_SIZE, int LABEL_INDEX, int numClasses)
        throws IOException, InterruptedException {

    RecordReader rr = new CSVRecordReader();
    rr.initialize(new FileSplit(new File(csvFileClasspath)));
    DataSetIterator iterator = new RecordReaderDataSetIterator(rr, BATCH_SIZE, LABEL_INDEX, numClasses);

    return iterator;
}
 
Developer ID: emara-geek, Project: arabic-characters-recognition, Lines of code: 10, Source file: ModelGenerator.java

Example 15: getRecordReader

import org.datavec.api.records.reader.RecordReader; // import the package/class this method depends on
public RecordReader getRecordReader(int batchSize, int numExamples, int[] imgDim, int numLabels,
                PathLabelGenerator labelGenerator, boolean train, double splitTrainTest, Random rng) {
    load(batchSize, numExamples, numLabels, labelGenerator, splitTrainTest, rng);
    RecordReader recordReader =
                    new ImageRecordReader(imgDim[0], imgDim[1], imgDim[2], labelGenerator, imageTransform);

    try {
        InputSplit data = train ? inputSplit[0] : inputSplit[1];
        recordReader.initialize(data);
    } catch (IOException | InterruptedException e) {
        e.printStackTrace();
    }
    return recordReader;
}
 
Developer ID: deeplearning4j, Project: DataVec, Lines of code: 15, Source file: LFWLoader.java


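Note: The org.datavec.api.records.reader.RecordReader.initialize examples in this article were compiled by 纯净天空 from open-source code and documentation platforms such as GitHub and MSDocs. The code snippets were selected from open-source projects contributed by their respective developers, and copyright of the source code remains with the original authors. For distribution and use, refer to the license of the corresponding project. Do not reproduce without permission.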