This article collects typical usage examples of the Java class org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator. If you have been wondering what RecordReaderDataSetIterator is for, how to use it, or want to see it in real code, the curated examples below may help.
The RecordReaderDataSetIterator class belongs to the org.deeplearning4j.datasets.datavec package. Fifteen code examples are presented below, sorted by popularity by default. You can upvote the examples you like or find useful; your feedback helps the system recommend better Java code examples.
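Before diving in, here is a minimal, self-contained sketch of the two constructors that recur throughout the examples: the classification form (reader, batchSize, labelIndex, numClasses) and the regression form (reader, batchSize, labelIndexFrom, labelIndexTo, regression). The file names and column positions are illustrative assumptions, not taken from any single example below.

    import org.datavec.api.records.reader.RecordReader;
    import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
    import org.datavec.api.split.FileSplit;
    import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
    import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;

    import java.io.File;

    public class IteratorSketch {
        public static void main(String[] args) throws Exception {
            // Classification: the label is a single integer class index (here column 4, 3 classes).
            RecordReader classReader = new CSVRecordReader(0, ',');
            classReader.initialize(new FileSplit(new File("iris.csv")));  // hypothetical file
            DataSetIterator classIter = new RecordReaderDataSetIterator(classReader, 50, 4, 3);

            // Regression: label columns run from labelIndexFrom to labelIndexTo (inclusive).
            RecordReader regReader = new CSVRecordReader(0, ',');
            regReader.initialize(new FileSplit(new File("housing.csv")));  // hypothetical file
            DataSetIterator regIter = new RecordReaderDataSetIterator(regReader, 50, 11, 11, true);
        }
    }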
Example 1: createDataSource

import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator; // import the required package/class

private void createDataSource() throws IOException, InterruptedException {
    // First: get the dataset using the record reader. CSVRecordReader handles loading/parsing
    int numLinesToSkip = 0;
    String delimiter = ",";
    RecordReader recordReader = new CSVRecordReader(numLinesToSkip, delimiter);
    recordReader.initialize(new InputStreamInputSplit(dataFile));

    // Second: the RecordReaderDataSetIterator handles conversion to DataSet objects, ready for use in a neural network
    int labelIndex = 4;  // 5 values in each row of the iris.txt CSV: 4 input features followed by an integer label (class) index. Labels are the 5th value (index 4) in each row
    int numClasses = 3;  // 3 classes (types of iris flowers) in the iris data set. Classes have integer values 0, 1 or 2
    DataSetIterator iterator = new RecordReaderDataSetIterator(recordReader, batchSize, labelIndex, numClasses);
    DataSet allData = iterator.next();
    allData.shuffle();
    SplitTestAndTrain testAndTrain = allData.splitTestAndTrain(0.80);  // Use 80% of data for training
    trainingData = testAndTrain.getTrain();
    testData = testAndTrain.getTest();

    // We need to normalize our data. We'll use NormalizerStandardize (which gives us mean 0, unit variance):
    DataNormalization normalizer = new NormalizerStandardize();
    normalizer.fit(trainingData);        // Collect the statistics (mean/stdev) from the training data. This does not modify the input data
    normalizer.transform(trainingData);  // Apply normalization to the training data
    normalizer.transform(testData);      // Apply normalization to the test data, using statistics calculated from the *training* set
}
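A note on this pattern: iterator.next() returns a single batch, so loading "all" the data this way assumes batchSize is at least the number of rows in the file (150 for iris). Also, the shuffle before splitTestAndTrain is unseeded; if you need a reproducible 80/20 split, a seeded shuffle is one option (a sketch, assuming your ND4J version exposes DataSet.shuffle(long)):

    allData.shuffle(42);  // fixed seed for a repeatable split
    SplitTestAndTrain testAndTrain = allData.splitTestAndTrain(0.80);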
Example 2: createDataSource

import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator; // import the required package/class

private void createDataSource() throws IOException, InterruptedException {
    // First: get the dataset using the record reader. CSVRecordReader handles loading/parsing
    int numLinesToSkip = 0;
    String delimiter = ",";
    RecordReader recordReader = new CSVRecordReader(numLinesToSkip, delimiter);
    recordReader.initialize(new InputStreamInputSplit(dataFile));

    // Second: the RecordReaderDataSetIterator handles conversion to DataSet objects, ready for use in a neural network
    int labelIndex = 11;
    DataSetIterator iterator = new RecordReaderDataSetIterator(recordReader, batchSize, labelIndex, labelIndex, true);
    DataSet allData = iterator.next();
    SplitTestAndTrain testAndTrain = allData.splitTestAndTrain(0.80);  // Use 80% of data for training
    trainingData = testAndTrain.getTrain();
    testData = testAndTrain.getTest();

    // We need to normalize our data. We'll use NormalizerStandardize (which gives us mean 0, unit variance):
    DataNormalization normalizer = new NormalizerStandardize();
    normalizer.fit(trainingData);        // Collect the statistics (mean/stdev) from the training data. This does not modify the input data
    normalizer.transform(trainingData);  // Apply normalization to the training data
    normalizer.transform(testData);      // Apply normalization to the test data, using statistics calculated from the *training* set
}
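This variant uses the five-argument regression constructor: labelIndexFrom and labelIndexTo are both 11, so column 11 is the single regression target and the remaining columns become features. Unlike Example 1, there is no shuffle before splitTestAndTrain, so the first 80% of rows in file order become the training set; if row order carries structure (e.g. sorted data), shuffling first avoids a biased split (sketch, same caveat as in Example 1):

    DataSet allData = iterator.next();
    allData.shuffle();  // avoid an order-dependent train/test split
    SplitTestAndTrain testAndTrain = allData.splitTestAndTrain(0.80);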
Example 3: getDataSetIterator

import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator; // import the required package/class

/**
 * This method returns the iterator. Scales all intensity values: it divides them by 255.
 *
 * @param data the dataset to use
 * @param seed the seed for the random number generator
 * @param batchSize the batch size to use
 * @return the iterator
 * @throws Exception
 */
@Override
public DataSetIterator getDataSetIterator(Instances data, int seed, int batchSize)
        throws Exception {
    batchSize = Math.min(data.numInstances(), batchSize);
    validate(data);
    ImageRecordReader reader = getImageRecordReader(data);
    final int labelIndex = 1;  // Use explicit label index position
    final int numPossibleLabels = data.numClasses();
    DataSetIterator tmpIter =
            new RecordReaderDataSetIterator(reader, batchSize, labelIndex, numPossibleLabels);
    DataNormalization scaler = new ImagePreProcessingScaler(0, 1);
    scaler.fit(tmpIter);
    tmpIter.setPreProcessor(scaler);
    return tmpIter;
}
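Note the labelIndex of 1 here: unlike the CSV examples, an ImageRecordReader emits records of the form [image, label], so the label writable sits at position 1 rather than at a CSV column index. A minimal sketch of the same wiring with directory-name labels (the dimensions and path are illustrative assumptions):

    int numClasses = 3;  // hypothetical
    ImageRecordReader reader = new ImageRecordReader(28, 28, 3, new ParentPathLabelGenerator());
    reader.initialize(new FileSplit(new File("images/")));  // one subdirectory per class
    DataSetIterator it = new RecordReaderDataSetIterator(reader, 32, 1, numClasses);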
Example 4: testLRN

import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator; // import the required package/class

@Test
public void testLRN() throws Exception {
    List<String> labels = new ArrayList<>(Arrays.asList("Zico", "Ziwang_Xu"));
    String rootDir = new ClassPathResource("lfwtest").getFile().getAbsolutePath();
    RecordReader reader = new ImageRecordReader(28, 28, 3);
    reader.initialize(new FileSplit(new File(rootDir)));
    DataSetIterator recordReader = new RecordReaderDataSetIterator(reader, 10, 1, labels.size());
    labels.remove("lfwtest");
    NeuralNetConfiguration.ListBuilder builder = (NeuralNetConfiguration.ListBuilder) incompleteLRN();
    builder.setInputType(InputType.convolutional(28, 28, 3));
    MultiLayerConfiguration conf = builder.build();
    ConvolutionLayer layer2 = (ConvolutionLayer) conf.getConf(3).getLayer();
    assertEquals(6, layer2.getNIn());
}
Example 5: testNextAndReset

import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator; // import the required package/class

@Test
public void testNextAndReset() throws Exception {
    int epochs = 3;
    RecordReader rr = new CSVRecordReader();
    rr.initialize(new FileSplit(new ClassPathResource("iris.txt").getFile()));
    DataSetIterator iter = new RecordReaderDataSetIterator(rr, 150);
    MultipleEpochsIterator multiIter = new MultipleEpochsIterator(epochs, iter);
    assertTrue(multiIter.hasNext());
    while (multiIter.hasNext()) {
        DataSet path = multiIter.next();
        assertNotNull(path);
    }
    assertEquals(epochs, multiIter.epochs);
}
Example 6: testLoadFullDataSet

import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator; // import the required package/class

@Test
public void testLoadFullDataSet() throws Exception {
    int epochs = 3;
    RecordReader rr = new CSVRecordReader();
    rr.initialize(new FileSplit(new ClassPathResource("iris.txt").getFile()));
    DataSetIterator iter = new RecordReaderDataSetIterator(rr, 150);
    DataSet ds = iter.next(50);
    assertEquals(50, ds.getFeatures().size(0));
    MultipleEpochsIterator multiIter = new MultipleEpochsIterator(epochs, ds);
    assertTrue(multiIter.hasNext());
    int count = 0;
    while (multiIter.hasNext()) {
        DataSet path = multiIter.next();
        assertNotNull(path);
        assertEquals(50, path.numExamples(), 0);
        count++;
    }
    assertEquals(epochs, count);
    assertEquals(epochs, multiIter.epochs);
}
Example 7: testLoadBatchDataSet

import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator; // import the required package/class

@Test
public void testLoadBatchDataSet() throws Exception {
    int epochs = 2;
    RecordReader rr = new CSVRecordReader();
    rr.initialize(new FileSplit(new ClassPathResource("iris.txt").getFile()));
    DataSetIterator iter = new RecordReaderDataSetIterator(rr, 150, 4, 3);
    DataSet ds = iter.next(20);
    assertEquals(20, ds.getFeatures().size(0));
    MultipleEpochsIterator multiIter = new MultipleEpochsIterator(epochs, ds);
    while (multiIter.hasNext()) {
        DataSet path = multiIter.next(10);
        assertNotNull(path);
        assertEquals(10, path.numExamples(), 0.0);
    }
    assertEquals(epochs, multiIter.epochs);
}
Example 8: testMnist

import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator; // import the required package/class

@Test
public void testMnist() throws Exception {
    ClassPathResource cpr = new ClassPathResource("mnist_first_200.txt");
    CSVRecordReader rr = new CSVRecordReader(0, ',');
    rr.initialize(new FileSplit(cpr.getTempFileFromArchive()));
    RecordReaderDataSetIterator dsi = new RecordReaderDataSetIterator(rr, 10, 0, 10);
    MnistDataSetIterator iter = new MnistDataSetIterator(10, 200, false, true, false, 0);
    while (dsi.hasNext()) {
        DataSet dsExp = dsi.next();
        DataSet dsAct = iter.next();
        INDArray fExp = dsExp.getFeatureMatrix();
        fExp.divi(255);  // CSV pixels are 0-255; the MNIST iterator already scales to [0, 1]
        INDArray lExp = dsExp.getLabels();
        INDArray fAct = dsAct.getFeatureMatrix();
        INDArray lAct = dsAct.getLabels();
        assertEquals(fExp, fAct);
        assertEquals(lExp, lAct);
    }
    assertFalse(iter.hasNext());
}
Example 9: processBatchIfRequired

import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator; // import the required package/class

private void processBatchIfRequired(List<List<Writable>> list, boolean finalRecord) throws Exception {
    if (list.isEmpty())
        return;
    if (list.size() < batchSize && !finalRecord)
        return;

    RecordReader rr = new CollectionRecordReader(list);
    RecordReaderDataSetIterator iter = new RecordReaderDataSetIterator(rr, new SelfWritableConverter(), batchSize,
            labelIndex, numPossibleLabels, regression);
    DataSet ds = iter.next();

    String filename = "dataset_" + uid + "_" + (outputCount++) + ".bin";
    URI uri = new URI(outputDir.getPath() + "/" + filename);
    FileSystem file = FileSystem.get(uri, conf);
    try (FSDataOutputStream out = file.create(new Path(uri))) {
        ds.save(out);
    }
    list.clear();
}
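The saved .bin files can be read back with the matching DataSet.load(InputStream) call. A hedged sketch of the reading side, reusing the same Hadoop Configuration that was used for writing:

    FileSystem fs = FileSystem.get(uri, conf);
    DataSet ds = new DataSet();
    try (FSDataInputStream in = fs.open(new Path(uri))) {
        ds.load(in);  // restores the features and labels written by ds.save(out)
    }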
Example 10: doPredict

import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator; // import the required package/class

@Override
protected Object doPredict(List<String> line) {
    try {
        ListStringSplit input = new ListStringSplit(Collections.singletonList(line));
        ListStringRecordReader rr = new ListStringRecordReader();
        rr.initialize(input);
        DataSetIterator iterator = new RecordReaderDataSetIterator(rr, 1);
        DataSet ds = iterator.next();
        INDArray prediction = model.output(ds.getFeatures());
        DataType outputType = types.get(this.output);
        switch (outputType) {
            case _float:
                return prediction.getDouble(0);
            case _class: {
                int numClasses = 2;
                double max = 0;
                int maxIndex = -1;
                for (int i = 0; i < numClasses; i++) {
                    if (prediction.getDouble(i) > max) {
                        maxIndex = i;
                        max = prediction.getDouble(i);
                    }
                }
                return maxIndex;
                // return prediction.getInt(0,1); // numberOfClasses
            }
            default:
                throw new IllegalArgumentException("Output type not yet supported " + outputType);
        }
    } catch (Exception e) {
        throw new RuntimeException(e);
    }
}
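Two details in the _class branch are worth flagging. First, max starts at 0, so if every class score were negative the method would return -1. Second, ND4J can do the argmax directly; a more compact equivalent (assuming prediction has shape [1, numClasses]) would be:

    int maxIndex = Nd4j.argMax(prediction, 1).getInt(0);  // index of the highest class score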
Example 11: check

import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator; // import the required package/class

private void check(BufferedImage image) throws Exception {
    ImageIO.write(image, "png", new File("tmp.png"));  // saves the image to the tmp.png file
    ImageRecordReader reader = new ImageRecordReader(150, 150, 3);
    reader.initialize(new FileSplit(new File("tmp.png")));  // reads the tmp.png file
    DataSetIterator dataIter = new RecordReaderDataSetIterator(reader, 1);
    while (dataIter.hasNext()) {
        // Normalize the data from the file
        DataNormalization normalization = new NormalizerMinMaxScaler();
        DataSet set = dataIter.next();
        normalization.fit(set);
        normalization.transform(set);
        INDArray array = MainGUI.model.output(set.getFeatures(), false);  // send the data to the model and get the results
        // Process the results and print them in an understandable format (percentage scores)
        String txt = "";
        DecimalFormat df = new DecimalFormat("#.00");
        for (int i = 0; i < array.length(); i++) {
            txt += MainGUI.labels.get(i) + ": " + (array.getDouble(i) * 100 < 1 ? "0" : "") + df.format(array.getDouble(i) * 100) + "%\n";
        }
        probabilityArea.setText(txt);
    }
    reader.close();
}
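One caveat about this example: NormalizerMinMaxScaler is fitted on the single test image, so each image is rescaled by its own min and max rather than by statistics from training. Examples 3 and 12 instead use ImagePreProcessingScaler(0, 1), which maps the fixed 0-255 pixel range into [0, 1]; if the model was trained that way, a matching inference-side sketch would be:

    DataNormalization scaler = new ImagePreProcessingScaler(0, 1);
    DataSet set = dataIter.next();
    scaler.transform(set);  // divides pixel values by 255; no per-image fit() statistics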
Example 12: createInternal

import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator; // import the required package/class

private DataSetIterator createInternal(InputSplit inputSplit) throws IOException {
    ImageTransform imageTransform = imageTransformFactory.create();
    int width = imageTransformConfigurationResource.getScaledWidth();
    int height = imageTransformConfigurationResource.getScaledHeight();
    int channels = imageTransformConfigurationResource.getChannels();
    int batchSize = networkConfigurationResource.getBatchSize();
    int outputs = networkConfigurationResource.getOutputs();
    ImageRecordReader recordReader = new ImageRecordReader(height, width, channels, pathLabelGenerator);
    recordReader.initialize(inputSplit, imageTransform);
    RecordReaderDataSetIterator recordReaderDataSetIterator = new RecordReaderDataSetIterator(recordReader, batchSize, 1, outputs);
    DataNormalization scaler = new ImagePreProcessingScaler(0, 1);
    scaler.fit(recordReaderDataSetIterator);
    recordReaderDataSetIterator.setPreProcessor(scaler);
    return recordReaderDataSetIterator;
}
Example 13: readCSVDataset

import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator; // import the required package/class

private static DataSetIterator readCSVDataset(String csvFileClasspath, int batchSize, int labelIndex, int numClasses)
        throws IOException, InterruptedException {
    RecordReader rr = new CSVRecordReader();
    rr.initialize(new FileSplit(new File(csvFileClasspath)));
    return new RecordReaderDataSetIterator(rr, batchSize, labelIndex, numClasses);
}
Example 14: irisCsv

import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator; // import the required package/class

static DataIterator<NormalizerStandardize> irisCsv(String name) {
    CSVRecordReader recordReader = new CSVRecordReader(0, ",");
    try {
        recordReader.initialize(new FileSplit(new File(name)));
    } catch (Exception e) {
        e.printStackTrace();
    }
    int labelIndex = 4;  // 5 values in each row of the iris.txt CSV: 4 input features followed by an integer label (class) index. Labels are the 5th value (index 4) in each row
    int numClasses = 3;  // 3 classes (types of iris flowers) in the iris data set. Classes have integer values 0, 1 or 2
    int batchSize = 50;  // Iris data set: 150 examples total
    RecordReaderDataSetIterator iterator = new RecordReaderDataSetIterator(
            recordReader,
            batchSize,
            labelIndex,
            numClasses
    );
    NormalizerStandardize normalizer = new NormalizerStandardize();
    while (iterator.hasNext()) {
        normalizer.fit(iterator.next());
    }
    iterator.reset();
    iterator.setPreProcessor(normalizer);
    return new DataIterator<>(iterator, normalizer);
}
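A version-dependent caveat: depending on your ND4J release, repeated calls to fit(DataSet) may recompute the statistics from each batch rather than accumulate them, leaving the normalizer fitted only to the last batch. The iterator overload avoids that ambiguity (a sketch, assuming fit(DataSetIterator) is available in your version):

    NormalizerStandardize normalizer = new NormalizerStandardize();
    normalizer.fit(iterator);  // streams over all batches to collect mean/stdev
    iterator.reset();
    iterator.setPreProcessor(normalizer);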
Example 15: main

import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator; // import the required package/class

public static void main(String... args) throws Exception {
    Options options = new Options();
    options.addOption("i", "input", true, "The file with test data.");
    options.addOption("m", "model", true, "Name of trained model file.");
    CommandLine cmd = new BasicParser().parse(options, args);
    String input = cmd.getOptionValue("i");
    String modelName = cmd.getOptionValue("m");
    if (cmd.hasOption("i") && cmd.hasOption("m")) {
        MultiLayerNetwork model = ModelSerializer.restoreMultiLayerNetwork(modelName);
        DataIterator<NormalizerStandardize> it = DataIterator.irisCsv(input);
        RecordReaderDataSetIterator testData = it.getIterator();
        NormalizerStandardize normalizer = it.getNormalizer();
        normalizer.load(
                new File(modelName + ".norm1"),
                new File(modelName + ".norm2"),
                new File(modelName + ".norm3"),
                new File(modelName + ".norm4")
        );
        Evaluation eval = new Evaluation(3);
        while (testData.hasNext()) {
            DataSet ds = testData.next();
            INDArray output = model.output(ds.getFeatureMatrix());
            eval.eval(ds.getLabels(), output);
        }
        log.info(eval.stats());
    } else {
        log.error("Invalid arguments.");
        new HelpFormatter().printHelp("Evaluate", options);
    }
}
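For completeness, the four .norm files loaded above have to be written at training time with the matching save call. A hedged sketch of the training-side counterpart (note that NormalizerStandardize.save(File...) is deprecated in recent releases in favor of NormalizerSerializer):

    normalizer.save(
            new File(modelName + ".norm1"),
            new File(modelName + ".norm2"),
            new File(modelName + ".norm3"),
            new File(modelName + ".norm4")
    );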