This article collects typical usage examples of the Java class org.nd4j.linalg.dataset.api.iterator.DataSetIterator. If you are wondering what the DataSetIterator class does, how to use it, or where to find usage examples, the curated code samples below may help.
The DataSetIterator class belongs to the org.nd4j.linalg.dataset.api.iterator package. Fifteen code examples of the class are shown below, sorted by popularity by default. You can upvote the examples you like or find useful; your feedback helps the system recommend better Java code examples.
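Before the individual examples, here is a minimal sketch of the consumption pattern that all of them ultimately feed into: iterate mini-batches with hasNext()/next() and rewind with reset() between epochs. The method name trainForEpochs and the MultiLayerNetwork parameter are placeholders for illustration and are not taken from any example below.

// Hedged sketch: a generic epoch loop over any DataSetIterator (model and method name are placeholders)
static void trainForEpochs(DataSetIterator someIterator, MultiLayerNetwork model, int nEpochs) {
    for (int epoch = 0; epoch < nEpochs; epoch++) {
        while (someIterator.hasNext()) {
            DataSet batch = someIterator.next();   // one mini-batch of features + labels
            model.fit(batch);                      // fit(DataSet) trains on a single mini-batch
        }
        someIterator.reset();                      // rewind the iterator before the next epoch
    }
}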
Example 1: createDataSource
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; // import the required package/class
private void createDataSource() throws IOException, InterruptedException {
    // First: get the dataset using the record reader. CSVRecordReader handles loading/parsing
    int numLinesToSkip = 0;
    String delimiter = ",";
    RecordReader recordReader = new CSVRecordReader(numLinesToSkip, delimiter);
    recordReader.initialize(new InputStreamInputSplit(dataFile));
    // Second: the RecordReaderDataSetIterator handles conversion to DataSet objects, ready for use in a neural network
    int labelIndex = 4;  // 5 values in each row of the iris.txt CSV: 4 input features followed by an integer label (class) index. Labels are the 5th value (index 4) in each row
    int numClasses = 3;  // 3 classes (types of iris flowers) in the iris data set. Classes have integer values 0, 1 or 2
    DataSetIterator iterator = new RecordReaderDataSetIterator(recordReader, batchSize, labelIndex, numClasses);
    DataSet allData = iterator.next();
    allData.shuffle();
    SplitTestAndTrain testAndTrain = allData.splitTestAndTrain(0.80); // Use 80% of data for training
    trainingData = testAndTrain.getTrain();
    testData = testAndTrain.getTest();
    // We need to normalize our data. We'll use NormalizerStandardize (which gives us mean 0, unit variance):
    DataNormalization normalizer = new NormalizerStandardize();
    normalizer.fit(trainingData);       // Collect the statistics (mean/stdev) from the training data. This does not modify the input data
    normalizer.transform(trainingData); // Apply normalization to the training data
    normalizer.transform(testData);     // Apply normalization to the test data, using statistics calculated from the *training* set
}
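As a follow-up to Example 1, the sketch below shows how the normalized trainingData and testData fields might then be consumed. The buildModel() helper, the method name trainAndEvaluate, and the epoch count are hypothetical and not part of the original example.

// Hedged follow-up sketch: train on the normalized training split and evaluate on the held-out test split
private void trainAndEvaluate() {
    MultiLayerNetwork model = buildModel();   // assumed helper that configures and initializes the network
    for (int epoch = 0; epoch < 100; epoch++) {
        model.fit(trainingData);              // fit(DataSet) trains on the whole in-memory training split
    }
    Evaluation eval = new Evaluation(3);      // 3 iris classes, matching numClasses above
    INDArray predictions = model.output(testData.getFeatures());
    eval.eval(testData.getLabels(), predictions);
    System.out.println(eval.stats());
}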
Example 2: createDataSource
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; // import the required package/class
private void createDataSource() throws IOException, InterruptedException {
    // First: get the dataset using the record reader. CSVRecordReader handles loading/parsing
    int numLinesToSkip = 0;
    String delimiter = ",";
    RecordReader recordReader = new CSVRecordReader(numLinesToSkip, delimiter);
    recordReader.initialize(new InputStreamInputSplit(dataFile));
    // Second: the RecordReaderDataSetIterator handles conversion to DataSet objects, ready for use in a neural network
    int labelIndex = 11;
    DataSetIterator iterator = new RecordReaderDataSetIterator(recordReader, batchSize, labelIndex, labelIndex, true);
    DataSet allData = iterator.next();
    SplitTestAndTrain testAndTrain = allData.splitTestAndTrain(0.80); // Use 80% of data for training
    trainingData = testAndTrain.getTrain();
    testData = testAndTrain.getTest();
    // We need to normalize our data. We'll use NormalizerStandardize (which gives us mean 0, unit variance):
    DataNormalization normalizer = new NormalizerStandardize();
    normalizer.fit(trainingData);       // Collect the statistics (mean/stdev) from the training data. This does not modify the input data
    normalizer.transform(trainingData); // Apply normalization to the training data
    normalizer.transform(testData);     // Apply normalization to the test data, using statistics calculated from the *training* set
}
Example 3: getDataSetIterator
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; // import the required package/class
public static DataSetIterator getDataSetIterator(String DATA_PATH, boolean isTraining, WordVectors wordVectors,
                                                 int minibatchSize, int maxSentenceLength, Random rng) {
    String path = FilenameUtils.concat(DATA_PATH, (isTraining ? "aclImdb/train/" : "aclImdb/test/"));
    String positiveBaseDir = FilenameUtils.concat(path, "pos");
    String negativeBaseDir = FilenameUtils.concat(path, "neg");
    File filePositive = new File(positiveBaseDir);
    File fileNegative = new File(negativeBaseDir);
    Map<String, List<File>> reviewFilesMap = new HashMap<>();
    reviewFilesMap.put("Positive", Arrays.asList(filePositive.listFiles()));
    reviewFilesMap.put("Negative", Arrays.asList(fileNegative.listFiles()));
    LabeledSentenceProvider sentenceProvider = new FileLabeledSentenceProvider(reviewFilesMap, rng);
    return new CnnSentenceDataSetIterator.Builder()
            .sentenceProvider(sentenceProvider)
            .wordVectors(wordVectors)
            .minibatchSize(minibatchSize)
            .maxSentenceLength(maxSentenceLength)
            .useNormalizedWordVectors(false)
            .build();
}
Example 4: getTrainingData
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; // import the required package/class
private static DataSetIterator getTrainingData(int batchSize, Random rand) {
    double[] sum = new double[nSamples];
    double[] input1 = new double[nSamples];
    double[] input2 = new double[nSamples];
    for (int i = 0; i < nSamples; i++) {
        int MIN_RANGE = 0;
        int MAX_RANGE = 3;
        input1[i] = MIN_RANGE + (MAX_RANGE - MIN_RANGE) * rand.nextDouble();
        input2[i] = MIN_RANGE + (MAX_RANGE - MIN_RANGE) * rand.nextDouble();
        sum[i] = input1[i] + input2[i];
    }
    INDArray inputNDArray1 = Nd4j.create(input1, new int[]{nSamples, 1});
    INDArray inputNDArray2 = Nd4j.create(input2, new int[]{nSamples, 1});
    INDArray inputNDArray = Nd4j.hstack(inputNDArray1, inputNDArray2);
    INDArray outPut = Nd4j.create(sum, new int[]{nSamples, 1});
    DataSet dataSet = new DataSet(inputNDArray, outPut);
    List<DataSet> listDs = dataSet.asList();
    Collections.shuffle(listDs, rand);
    return new ListDataSetIterator(listDs, batchSize);
}
Example 5: evalMnistTestSet
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; // import the required package/class
private static void evalMnistTestSet(MultiLayerNetwork leNetModel) throws Exception {
    log.info("Load test data....");
    int batchSize = 64;
    DataSetIterator mnistTest = new MnistDataSetIterator(batchSize, false, 12345);
    log.info("Evaluate model....");
    int outputNum = 10;
    Evaluation eval = new Evaluation(outputNum);
    while (mnistTest.hasNext()) {
        DataSet dataSet = mnistTest.next();
        INDArray output = leNetModel.output(dataSet.getFeatureMatrix(), false);
        eval.eval(dataSet.getLabels(), output);
    }
    log.info(eval.stats());
}
Example 6: getDataSetIterator
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; // import the required package/class
/**
 * Returns the actual iterator.
 *
 * @param data      the dataset to use
 * @param seed      the seed for the random number generator
 * @param batchSize the batch size to use
 * @return the DataSetIterator
 */
@Override
public DataSetIterator getDataSetIterator(Instances data, int seed, int batchSize)
        throws InvalidInputDataException, IOException {
    validate(data);
    initWordVectors();
    final LabeledSentenceProvider prov = getSentenceProvider(data);
    return new RnnTextEmbeddingDataSetIterator(
            data,
            wordVectors,
            tokenizerFactory,
            tokenPreProcess,
            stopwords,
            prov,
            batchSize,
            truncateLength);
}
Example 7: getDataSetIterator
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; // import the required package/class
@Override
public DataSetIterator getDataSetIterator(Instances data, int seed, int batchSize)
        throws InvalidInputDataException, IOException {
    validate(data);
    initWordVectors();
    final LabeledSentenceProvider sentenceProvider = getSentenceProvider(data);
    return new RnnTextEmbeddingDataSetIterator(
            data,
            wordVectors,
            tokenizerFactory,
            tokenPreProcess,
            stopwords,
            sentenceProvider,
            batchSize,
            truncateLength);
}
Example 8: getDataSetIterator
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; // import the required package/class
/**
 * This method returns the iterator. Scales all intensity values: it divides them by 255.
 *
 * @param data      the dataset to use
 * @param seed      the seed for the random number generator
 * @param batchSize the batch size to use
 * @return the iterator
 * @throws Exception
 */
@Override
public DataSetIterator getDataSetIterator(Instances data, int seed, int batchSize)
        throws Exception {
    batchSize = Math.min(data.numInstances(), batchSize);
    validate(data);
    ImageRecordReader reader = getImageRecordReader(data);
    final int labelIndex = 1; // Use explicit label index position
    final int numPossibleLabels = data.numClasses();
    DataSetIterator tmpIter =
            new RecordReaderDataSetIterator(reader, batchSize, labelIndex, numPossibleLabels);
    DataNormalization scaler = new ImagePreProcessingScaler(0, 1);
    scaler.fit(tmpIter);
    tmpIter.setPreProcessor(scaler);
    return tmpIter;
}
Example 9: initEarlyStopping
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; // import the required package/class
/**
 * Initialize early stopping with the given data.
 *
 * @param data Data
 * @return Augmented data - if early stopping applies, return the train set without the validation set
 * @throws Exception
 */
protected Instances initEarlyStopping(Instances data) throws Exception {
    // Split train/validation
    double valSplit = earlyStopping.getValidationSetPercentage();
    Instances trainData;
    Instances valData;
    if (useEarlyStopping()) {
        // Split into train and validation sets
        Instances[] insts = splitTrainVal(data, valSplit);
        trainData = insts[0];
        valData = insts[1];
        validateSplit(trainData, valData);
        DataSetIterator valIterator = getDataSetIterator(valData, cacheMode, "val");
        earlyStopping.init(valIterator);
    } else {
        // Keep the full data
        trainData = data;
    }
    return trainData;
}
Example 10: testGetIteratorNominalClass
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; // import the required package/class
/** Test getDataSetIterator */
@Test
public void testGetIteratorNominalClass() throws Exception {
    final Instances data = DatasetLoader.loadAngerMetaClassification();
    final int batchSize = 1;
    final DataSetIterator it = this.cteii.getDataSetIterator(data, SEED, batchSize);
    Set<Integer> labels = new HashSet<>();
    for (int i = 0; i < data.size(); i++) {
        Instance inst = data.get(i);
        int label = Integer.parseInt(inst.stringValue(data.classIndex()));
        final DataSet next = it.next();
        int itLabel = next.getLabels().argMax().getInt(0);
        Assert.assertEquals(label, itLabel);
        labels.add(label);
    }
    final Set<Integer> collect =
            it.getLabels().stream().map(s -> Double.valueOf(s).intValue()).collect(Collectors.toSet());
    Assert.assertEquals(2, labels.size());
    Assert.assertTrue(labels.containsAll(collect));
    Assert.assertTrue(collect.containsAll(labels));
}
Example 11: testGetIteratorNumericClass
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; // import the required package/class
/** Test getDataSetIterator */
@Test
public void testGetIteratorNumericClass() throws Exception {
    final Instances data = DatasetLoader.loadAngerMeta();
    final int batchSize = 1;
    final DataSetIterator it = this.cteii.getDataSetIterator(data, SEED, batchSize);
    Set<Double> labels = new HashSet<>();
    for (int i = 0; i < data.size(); i++) {
        Instance inst = data.get(i);
        double label = inst.value(data.classIndex());
        final DataSet next = it.next();
        double itLabel = next.getLabels().getDouble(0);
        Assert.assertEquals(label, itLabel, 1e-5);
        labels.add(label);
    }
}
Example 12: testGetIteratorNominalClass
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; // import the required package/class
/** Test getDataSetIterator */
@Test
public void testGetIteratorNominalClass() throws Exception {
    final Instances data = DatasetLoader.loadReutersMinimal();
    final int batchSize = 1;
    final DataSetIterator it = this.cteii.getDataSetIterator(data, SEED, batchSize);
    Set<Integer> labels = new HashSet<>();
    for (Instance inst : data) {
        int label = Integer.parseInt(inst.stringValue(data.classIndex()));
        final DataSet next = it.next();
        int itLabel = next.getLabels().argMax().getInt(0);
        Assert.assertEquals(label, itLabel);
        labels.add(label);
    }
    final Set<Integer> collect =
            it.getLabels().stream().map(s -> Double.valueOf(s).intValue()).collect(Collectors.toSet());
    Assert.assertEquals(2, labels.size());
    Assert.assertTrue(labels.containsAll(collect));
    Assert.assertTrue(collect.containsAll(labels));
}
Example 13: testGetIteratorNumericClass
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; // import the required package/class
/** Test getDataSetIterator */
@Test
public void testGetIteratorNumericClass() throws Exception {
    final Instances data = makeData();
    final int batchSize = 1;
    final DataSetIterator it = this.cteii.getDataSetIterator(data, SEED, batchSize);
    Set<Double> labels = new HashSet<>();
    for (int i = 0; i < data.size(); i++) {
        Instance inst = data.get(i);
        double label = inst.value(data.classIndex());
        final DataSet next = it.next();
        double itLabel = next.getLabels().getDouble(0);
        Assert.assertEquals(label, itLabel, 1e-5);
        labels.add(label);
    }
}
Example 14: testGetIterator
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; // import the required package/class
/** Test getDataSetIterator */
@Test
public void testGetIterator() throws Exception {
    final Instances metaData = DatasetLoader.loadMiniMnistMeta();
    this.idi.setImagesLocation(new File("datasets/nominal/mnist-minimal"));
    final int batchSize = 1;
    final DataSetIterator it = this.idi.getDataSetIterator(metaData, SEED, batchSize);
    Set<Integer> labels = new HashSet<>();
    for (Instance inst : metaData) {
        int label = Integer.parseInt(inst.stringValue(1));
        final DataSet next = it.next();
        int itLabel = next.getLabels().argMax().getInt(0);
        Assert.assertEquals(label, itLabel);
        labels.add(label);
    }
    final List<Integer> collect =
            it.getLabels().stream().map(Integer::valueOf).collect(Collectors.toList());
    Assert.assertEquals(10, labels.size());
    Assert.assertTrue(labels.containsAll(collect));
    Assert.assertTrue(collect.containsAll(labels));
}
Example 15: testOutputFormat
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; // import the required package/class
@Test
public void testOutputFormat() throws Exception {
    for (int tl : Arrays.asList(10, 50, 200)) {
        rii.setTruncateLength(tl);
        for (int bs : Arrays.asList(1, 4, 8, 16)) {
            final DataSetIterator it = rii.getDataSetIterator(data, TestUtil.SEED, bs);
            assertEquals(bs, it.batch());
            assertEquals(Arrays.asList("0.0", "1.0"), it.getLabels());
            final DataSet next = it.next();
            // Check feature shape, expect: (batchsize x wordvecsize x sequencelength)
            final int[] shapeFeats = next.getFeatures().shape();
            final int[] expShapeFeats = {bs, 6, tl};
            assertEquals(expShapeFeats[0], shapeFeats[0]);
            assertEquals(expShapeFeats[1], shapeFeats[1]);
            assertTrue(expShapeFeats[2] >= shapeFeats[2]);
            // Check label shape, expect: (batchsize x numclasses x sequencelength)
            final int[] shapeLabels = next.getLabels().shape();
            final int[] expShapeLabels = {bs, data.numClasses(), tl};
            assertEquals(expShapeLabels[0], shapeLabels[0]);
            assertEquals(expShapeLabels[1], shapeLabels[1]);
            assertTrue(expShapeLabels[2] >= shapeLabels[2]);
        }
    }
}