This article collects typical usage examples of the Java method org.datavec.api.records.reader.RecordReader.initialize. If you are unsure what RecordReader.initialize does, how to call it, or want to see it in context, the curated code samples below may help. You can also explore further usage examples of the containing class, org.datavec.api.records.reader.RecordReader.
The following shows 15 code examples of RecordReader.initialize, sorted by popularity by default. You can upvote the examples you like or find useful; your feedback helps the system recommend better Java code samples.
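Before the individual examples, here is a minimal, self-contained sketch of the basic pattern: construct a reader, bind it to an InputSplit via initialize, then iterate. This is a sketch rather than one of the collected examples; the file name iris.txt is an illustrative assumption.

import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
import org.datavec.api.split.FileSplit;
import org.datavec.api.writable.Writable;
import java.io.File;
import java.util.List;

public class RecordReaderInitializeSketch {
    public static void main(String[] args) throws Exception {
        RecordReader reader = new CSVRecordReader(0, ',');  //skip 0 header lines, comma-delimited
        //initialize binds the reader to its data source; it must be called before hasNext()/next()
        reader.initialize(new FileSplit(new File("iris.txt")));  //illustrative file name
        while (reader.hasNext()) {
            List<Writable> record = reader.next();  //one parsed row per call
            System.out.println(record);
        }
        reader.close();
    }
}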
Example 1: createDataSource
import org.datavec.api.records.reader.RecordReader; //import the package/class the method depends on
private void createDataSource() throws IOException, InterruptedException {
    //First: get the dataset using the record reader. CSVRecordReader handles loading/parsing
    int numLinesToSkip = 0;
    String delimiter = ",";
    RecordReader recordReader = new CSVRecordReader(numLinesToSkip, delimiter);
    recordReader.initialize(new InputStreamInputSplit(dataFile));

    //Second: the RecordReaderDataSetIterator handles conversion to DataSet objects, ready for use in the neural network
    int labelIndex = 4;   //5 values in each row of the iris.txt CSV: 4 input features followed by an integer label (class) index. Labels are the 5th value (index 4) in each row
    int numClasses = 3;   //3 classes (types of iris flowers) in the iris data set. Classes have integer values 0, 1 or 2
    DataSetIterator iterator = new RecordReaderDataSetIterator(recordReader, batchSize, labelIndex, numClasses);
    DataSet allData = iterator.next();
    allData.shuffle();
    SplitTestAndTrain testAndTrain = allData.splitTestAndTrain(0.80);  //Use 80% of data for training
    trainingData = testAndTrain.getTrain();
    testData = testAndTrain.getTest();

    //We need to normalize our data. We'll use NormalizerStandardize (which gives us mean 0, unit variance):
    DataNormalization normalizer = new NormalizerStandardize();
    normalizer.fit(trainingData);        //Collect the statistics (mean/stdev) from the training data. This does not modify the input data
    normalizer.transform(trainingData);  //Apply normalization to the training data
    normalizer.transform(testData);      //Apply normalization to the test data, using statistics calculated from the *training* set
}
Example 2: createDataSource
import org.datavec.api.records.reader.RecordReader; //import the package/class the method depends on
private void createDataSource() throws IOException, InterruptedException {
    //First: get the dataset using the record reader. CSVRecordReader handles loading/parsing
    int numLinesToSkip = 0;
    String delimiter = ",";
    RecordReader recordReader = new CSVRecordReader(numLinesToSkip, delimiter);
    recordReader.initialize(new InputStreamInputSplit(dataFile));

    //Second: the RecordReaderDataSetIterator handles conversion to DataSet objects, ready for use in the neural network
    int labelIndex = 11;  //The label is the 12th value (index 11) in each row
    //Regression form of the constructor: (reader, batchSize, labelIndexFrom, labelIndexTo, regression)
    DataSetIterator iterator = new RecordReaderDataSetIterator(recordReader, batchSize, labelIndex, labelIndex, true);
    DataSet allData = iterator.next();
    SplitTestAndTrain testAndTrain = allData.splitTestAndTrain(0.80);  //Use 80% of data for training
    trainingData = testAndTrain.getTrain();
    testData = testAndTrain.getTest();

    //We need to normalize our data. We'll use NormalizerStandardize (which gives us mean 0, unit variance):
    DataNormalization normalizer = new NormalizerStandardize();
    normalizer.fit(trainingData);        //Collect the statistics (mean/stdev) from the training data. This does not modify the input data
    normalizer.transform(trainingData);  //Apply normalization to the training data
    normalizer.transform(testData);      //Apply normalization to the test data, using statistics calculated from the *training* set
}
Example 3: testCsvRRSerializationResults
import org.datavec.api.records.reader.RecordReader; //import the package/class the method depends on
@Test
public void testCsvRRSerializationResults() throws Exception {
    int skipLines = 3;
    RecordReader r1 = new CSVRecordReader(skipLines, '\t');

    //Round-trip the reader through Java serialization
    ByteArrayOutputStream baos = new ByteArrayOutputStream();
    ObjectOutputStream os = new ObjectOutputStream(baos);
    os.writeObject(r1);
    byte[] bytes = baos.toByteArray();
    ObjectInputStream ois = new ObjectInputStream(new ByteArrayInputStream(bytes));
    RecordReader r2 = (RecordReader) ois.readObject();

    File f = new ClassPathResource("iris_tab_delim.txt").getFile();
    r1.initialize(new FileSplit(f));
    r2.initialize(new FileSplit(f));

    //The original and deserialized readers should produce identical records
    int count = 0;
    while (r1.hasNext()) {
        List<Writable> n1 = r1.next();
        List<Writable> n2 = r2.next();
        assertEquals(n1, n2);
        count++;
    }
    assertEquals(150 - skipLines, count);
}
Example 4: testCsvRecordReader
import org.datavec.api.records.reader.RecordReader; //import the package/class the method depends on
@Test
public void testCsvRecordReader() throws Exception {
    SerializerInstance si = sc.env().serializer().newInstance();
    assertTrue(si instanceof KryoSerializerInstance);

    RecordReader r1 = new CSVRecordReader(1, '\t');
    RecordReader r2 = serDe(r1, si);  //Round-trip through Kryo serialization

    File f = new ClassPathResource("iris_tab_delim.txt").getFile();
    r1.initialize(new FileSplit(f));
    r2.initialize(new FileSplit(f));
    while (r1.hasNext()) {
        assertEquals(r1.next(), r2.next());
    }
    assertFalse(r2.hasNext());
}
Example 5: testReadingJson
import org.datavec.api.records.reader.RecordReader; //import the package/class the method depends on
@Test
public void testReadingJson() throws Exception {
    //Load 3 values from 3 JSON files
    //Structure: a:value, b:value, c:x:value, c:y:value
    //And we want to load only a:value, b:value and c:x:value
    //For the first JSON file: all values are present
    //For the second JSON file: b:value is missing
    //For the third JSON file: c:x:value is missing
    ClassPathResource cpr = new ClassPathResource("json/json_test_0.txt");
    String path = cpr.getFile().getAbsolutePath().replace("0", "%d");

    InputSplit is = new NumberedFileInputSplit(path, 0, 2);
    RecordReader rr = new JacksonRecordReader(getFieldSelection(), new ObjectMapper(new JsonFactory()));
    rr.initialize(is);
    testJacksonRecordReader(rr);
}
Example 6: testLRN
import org.datavec.api.records.reader.RecordReader; //import the package/class the method depends on
@Test
public void testLRN() throws Exception {
    List<String> labels = new ArrayList<>(Arrays.asList("Zico", "Ziwang_Xu"));
    String rootDir = new ClassPathResource("lfwtest").getFile().getAbsolutePath();

    RecordReader reader = new ImageRecordReader(28, 28, 3);
    reader.initialize(new FileSplit(new File(rootDir)));
    DataSetIterator recordReader = new RecordReaderDataSetIterator(reader, 10, 1, labels.size());
    labels.remove("lfwtest");

    NeuralNetConfiguration.ListBuilder builder = (NeuralNetConfiguration.ListBuilder) incompleteLRN();
    builder.setInputType(InputType.convolutional(28, 28, 3));
    MultiLayerConfiguration conf = builder.build();
    ConvolutionLayer layer2 = (ConvolutionLayer) conf.getConf(3).getLayer();
    assertEquals(6, layer2.getNIn());
}
Example 7: testRRDSIwithAsync
import org.datavec.api.records.reader.RecordReader; //import the package/class the method depends on
@Test
public void testRRDSIwithAsync() throws Exception {
    RecordReader csv = new CSVRecordReader();
    csv.initialize(new FileSplit(new ClassPathResource("iris.txt").getTempFileFromArchive()));

    int batchSize = 10;
    int labelIdx = 4;
    int numClasses = 3;

    RecordReaderDataSetIterator rrdsi = new RecordReaderDataSetIterator(csv, batchSize, labelIdx, numClasses);
    AsyncDataSetIterator adsi = new AsyncDataSetIterator(rrdsi, 8, true);  //prefetch batches on a background thread
    while (adsi.hasNext()) {
        DataSet ds = adsi.next();  //Drain the iterator; the test passes if no exception is thrown
    }
}
Example 8: testNormalizerPrefetchReset
import org.datavec.api.records.reader.RecordReader; //import the package/class the method depends on
@Test
public void testNormalizerPrefetchReset() throws Exception {
    //Check NPE fix for: https://github.com/deeplearning4j/deeplearning4j/issues/4214
    RecordReader csv = new CSVRecordReader();
    csv.initialize(new FileSplit(new ClassPathResource("iris.txt").getTempFileFromArchive()));

    int batchSize = 3;
    DataSetIterator iter = new RecordReaderDataSetIterator(csv, batchSize, 4, 4, true);

    DataNormalization normalizer = new NormalizerMinMaxScaler(0, 1);
    normalizer.fit(iter);
    iter.setPreProcessor(normalizer);

    iter.inputColumns();   //Prefetch
    iter.totalOutcomes();
    iter.hasNext();
    iter.reset();
    iter.next();
}
Example 9: testsBasicMeta
import org.datavec.api.records.reader.RecordReader; //import the package/class the method depends on
@Test
public void testsBasicMeta() throws Exception {
    //As per testBasic - but also loading metadata
    RecordReader rr2 = new CSVRecordReader(0, ',');
    rr2.initialize(new FileSplit(new ClassPathResource("iris.txt").getTempFileFromArchive()));

    RecordReaderMultiDataSetIterator rrmdsi = new RecordReaderMultiDataSetIterator.Builder(10)
            .addReader("reader", rr2).addInput("reader", 0, 3).addOutputOneHot("reader", 4, 3).build();
    rrmdsi.setCollectMetaData(true);

    //Every example rebuilt from its metadata should match the original
    int count = 0;
    while (rrmdsi.hasNext()) {
        MultiDataSet mds = rrmdsi.next();
        MultiDataSet fromMeta = rrmdsi.loadFromMetaData(mds.getExampleMetaData(RecordMetaData.class));
        assertEquals(mds, fromMeta);
        count++;
    }
    assertEquals(150 / 10, count);
}
Example 10: testSplittingCSVMeta
import org.datavec.api.records.reader.RecordReader; //import the package/class the method depends on
@Test
public void testSplittingCSVMeta() throws Exception {
    //Here's the idea: take Iris, and split it up into 2 input and 2 output arrays
    //Inputs: columns 0 and 1-2
    //Outputs: columns 3, and 4->OneHot
    RecordReader rr2 = new CSVRecordReader(0, ',');
    rr2.initialize(new FileSplit(new ClassPathResource("iris.txt").getTempFileFromArchive()));

    RecordReaderMultiDataSetIterator rrmdsi = new RecordReaderMultiDataSetIterator.Builder(10)
            .addReader("reader", rr2).addInput("reader", 0, 0).addInput("reader", 1, 2)
            .addOutput("reader", 3, 3).addOutputOneHot("reader", 4, 3).build();
    rrmdsi.setCollectMetaData(true);

    int count = 0;
    while (rrmdsi.hasNext()) {
        MultiDataSet mds = rrmdsi.next();
        MultiDataSet fromMeta = rrmdsi.loadFromMetaData(mds.getExampleMetaData(RecordMetaData.class));
        assertEquals(mds, fromMeta);
        count++;
    }
    assertEquals(150 / 10, count);
}
Example 11: testNextAndReset
import org.datavec.api.records.reader.RecordReader; //import the package/class the method depends on
@Test
public void testNextAndReset() throws Exception {
    int epochs = 3;

    RecordReader rr = new CSVRecordReader();
    rr.initialize(new FileSplit(new ClassPathResource("iris.txt").getFile()));
    DataSetIterator iter = new RecordReaderDataSetIterator(rr, 150);

    MultipleEpochsIterator multiIter = new MultipleEpochsIterator(epochs, iter);
    assertTrue(multiIter.hasNext());
    while (multiIter.hasNext()) {
        DataSet path = multiIter.next();
        assertNotNull(path);
    }
    assertEquals(epochs, multiIter.epochs);
}
Example 12: testLoadFullDataSet
import org.datavec.api.records.reader.RecordReader; //import the package/class the method depends on
@Test
public void testLoadFullDataSet() throws Exception {
    int epochs = 3;

    RecordReader rr = new CSVRecordReader();
    rr.initialize(new FileSplit(new ClassPathResource("iris.txt").getFile()));
    DataSetIterator iter = new RecordReaderDataSetIterator(rr, 150);
    DataSet ds = iter.next(50);
    assertEquals(50, ds.getFeatures().size(0));

    MultipleEpochsIterator multiIter = new MultipleEpochsIterator(epochs, ds);
    assertTrue(multiIter.hasNext());

    int count = 0;
    while (multiIter.hasNext()) {
        DataSet path = multiIter.next();
        assertNotNull(path);
        assertEquals(50, path.numExamples(), 0);
        count++;
    }
    assertEquals(epochs, count);
    assertEquals(epochs, multiIter.epochs);
}
Example 13: testLoadBatchDataSet
import org.datavec.api.records.reader.RecordReader; //import the package/class the method depends on
@Test
public void testLoadBatchDataSet() throws Exception {
    int epochs = 2;

    RecordReader rr = new CSVRecordReader();
    rr.initialize(new FileSplit(new ClassPathResource("iris.txt").getFile()));
    DataSetIterator iter = new RecordReaderDataSetIterator(rr, 150, 4, 3);
    DataSet ds = iter.next(20);
    assertEquals(20, ds.getFeatures().size(0));

    MultipleEpochsIterator multiIter = new MultipleEpochsIterator(epochs, ds);
    while (multiIter.hasNext()) {
        DataSet path = multiIter.next(10);
        assertNotNull(path);
        assertEquals(10, path.numExamples(), 0.0);
    }
    assertEquals(epochs, multiIter.epochs);
}
Example 14: readCSVDataset
import org.datavec.api.records.reader.RecordReader; //import the package/class the method depends on
private static DataSetIterator readCSVDataset(String csvFileClasspath, int batchSize, int labelIndex, int numClasses)
        throws IOException, InterruptedException {
    RecordReader rr = new CSVRecordReader();
    rr.initialize(new FileSplit(new File(csvFileClasspath)));
    return new RecordReaderDataSetIterator(rr, batchSize, labelIndex, numClasses);
}
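A hypothetical call to this helper (the file name and argument values are illustrative), producing batches of 50 iris-style rows with 4 features and 3 classes:

DataSetIterator irisIter = readCSVDataset("iris.txt", 50, 4, 3);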
Example 15: getRecordReader
import org.datavec.api.records.reader.RecordReader; //import the package/class the method depends on
public RecordReader getRecordReader(int batchSize, int numExamples, int[] imgDim, int numLabels,
        PathLabelGenerator labelGenerator, boolean train, double splitTrainTest, Random rng) {
    load(batchSize, numExamples, numLabels, labelGenerator, splitTrainTest, rng);
    RecordReader recordReader =
            new ImageRecordReader(imgDim[0], imgDim[1], imgDim[2], labelGenerator, imageTransform);
    try {
        InputSplit data = train ? inputSplit[0] : inputSplit[1];  //index 0: training split, index 1: test split
        recordReader.initialize(data);
    } catch (IOException | InterruptedException e) {
        e.printStackTrace();
    }
    return recordReader;
}