本文整理汇总了Java中org.nd4j.linalg.dataset.api.iterator.DataSetIterator.next方法的典型用法代码示例。如果您正苦于以下问题：Java DataSetIterator.next方法的具体用法？Java DataSetIterator.next怎么用？Java DataSetIterator.next使用的例子？那么，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在类org.nd4j.linalg.dataset.api.iterator.DataSetIterator的用法示例。
在下文中一共展示了DataSetIterator.next方法的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Java代码示例。
示例1: createDataSource
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; //导入方法依赖的package包/类
private void createDataSource() throws IOException, InterruptedException {
    // Step 1: parse the CSV input. No leading lines are skipped; fields are comma-separated.
    RecordReader csvReader = new CSVRecordReader(0, ",");
    csvReader.initialize(new InputStreamInputSplit(dataFile));

    // Step 2: turn parsed rows into DataSet objects. Each row of the iris CSV holds
    // 4 feature values followed by the integer class index (0, 1 or 2) in column 4.
    final int labelColumn = 4;
    final int classCount = 3;
    DataSetIterator rows = new RecordReaderDataSetIterator(csvReader, batchSize, labelColumn, classCount);

    // NOTE(review): only the first batch is taken, so this covers the whole file only when
    // batchSize is at least the number of rows — confirm against the batchSize field.
    DataSet everything = rows.next();
    everything.shuffle();

    // Step 3: 80% of the rows go to training, the rest to testing.
    SplitTestAndTrain split = everything.splitTestAndTrain(0.80);
    trainingData = split.getTrain();
    testData = split.getTest();

    // Step 4: standardize (mean 0, unit variance). Statistics come from the training split
    // only (fit does not modify its input); the same statistics are then applied to both splits.
    DataNormalization standardizer = new NormalizerStandardize();
    standardizer.fit(trainingData);
    standardizer.transform(trainingData);
    standardizer.transform(testData);
}
示例2: createDataSource
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; //导入方法依赖的package包/类
private void createDataSource() throws IOException, InterruptedException {
    // Step 1: parse the CSV input. No leading lines are skipped; fields are comma-separated.
    RecordReader csvReader = new CSVRecordReader(0, ",");
    csvReader.initialize(new InputStreamInputSplit(dataFile));

    // Step 2: turn parsed rows into DataSet objects. Column 11 is used as both the first and
    // last label column, and the trailing 'true' selects regression mode.
    final int targetColumn = 11;
    DataSetIterator rows = new RecordReaderDataSetIterator(csvReader, batchSize, targetColumn, targetColumn, true);
    DataSet everything = rows.next();

    // Step 3: 80% of the rows go to training, the rest to testing.
    // NOTE(review): the rows are not shuffled first, so the split follows file order — confirm intended.
    SplitTestAndTrain split = everything.splitTestAndTrain(0.80);
    trainingData = split.getTrain();
    testData = split.getTest();

    // Step 4: standardize with NormalizerStandardize (mean 0, unit variance). Statistics come from
    // the training split only (fit does not modify its input) and are applied to both splits.
    DataNormalization standardizer = new NormalizerStandardize();
    standardizer.fit(trainingData);
    standardizer.transform(trainingData);
    standardizer.transform(testData);
}
示例3: evalMnistTestSet
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; //导入方法依赖的package包/类
private static void evalMnistTestSet(MultiLayerNetwork leNetModel) throws Exception {
    log.info("Load test data....");
    // MNIST test split (second argument false), 64 images per batch, seed 12345.
    DataSetIterator testBatches = new MnistDataSetIterator(64, false, 12345);

    log.info("Evaluate model....");
    Evaluation evaluation = new Evaluation(10); // one class per digit
    while (testBatches.hasNext()) {
        DataSet batch = testBatches.next();
        // Inference only: the model is run with train=false.
        INDArray predicted = leNetModel.output(batch.getFeatureMatrix(), false);
        evaluation.eval(batch.getLabels(), predicted);
    }
    log.info(evaluation.stats());
}
示例4: testGetIteratorNominalClass
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; //导入方法依赖的package包/类
/**
 * Checks that the iterator yields one instance per batch, in data order, whose label argmax
 * matches the instance's nominal class, and that the iterator's label names match the classes seen.
 */
@Test
public void testGetIteratorNominalClass() throws Exception {
    final Instances data = DatasetLoader.loadAngerMetaClassification();
    final DataSetIterator it = this.cteii.getDataSetIterator(data, SEED, 1);
    final Set<Integer> seenLabels = new HashSet<>();
    for (Instance instance : data) {
        final int expected = Integer.parseInt(instance.stringValue(data.classIndex()));
        final int actual = it.next().getLabels().argMax().getInt(0);
        Assert.assertEquals(expected, actual);
        seenLabels.add(expected);
    }
    // Label names come back as strings such as "0.0"; convert them to ints for comparison.
    final Set<Integer> iteratorLabels = it.getLabels().stream()
            .map(name -> Double.valueOf(name).intValue())
            .collect(Collectors.toSet());
    Assert.assertEquals(2, seenLabels.size());
    Assert.assertTrue(seenLabels.containsAll(iteratorLabels));
    Assert.assertTrue(iteratorLabels.containsAll(seenLabels));
}
示例5: testGetIteratorNumericClass
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; //导入方法依赖的package包/类
/**
 * Checks that, with a batch size of 1, each batch's first label value equals the corresponding
 * instance's numeric class value (within 1e-5), in data order.
 */
@Test
public void testGetIteratorNumericClass() throws Exception {
    final Instances data = DatasetLoader.loadAngerMeta();
    final int batchSize = 1;
    final DataSetIterator it = this.cteii.getDataSetIterator(data, SEED, batchSize);
    for (Instance inst : data) {
        final double label = inst.value(data.classIndex());
        final DataSet next = it.next();
        final double itLabel = next.getLabels().getDouble(0);
        Assert.assertEquals(label, itLabel, 1e-5);
    }
}
示例6: testGetIteratorNominalClass
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; //导入方法依赖的package包/类
/**
 * Checks on the minimal Reuters data that each single-instance batch's label argmax equals the
 * instance's nominal class, and that the iterator reports exactly the classes that were seen.
 */
@Test
public void testGetIteratorNominalClass() throws Exception {
    final Instances data = DatasetLoader.loadReutersMinimal();
    final DataSetIterator it = this.cteii.getDataSetIterator(data, SEED, 1);
    final Set<Integer> seenLabels = new HashSet<>();
    for (int row = 0; row < data.size(); row++) {
        final Instance instance = data.get(row);
        final int expected = Integer.parseInt(instance.stringValue(data.classIndex()));
        final int actual = it.next().getLabels().argMax().getInt(0);
        Assert.assertEquals(expected, actual);
        seenLabels.add(expected);
    }
    // Label names come back as strings such as "0.0"; convert them to ints for comparison.
    final Set<Integer> iteratorLabels = it.getLabels().stream()
            .map(name -> Double.valueOf(name).intValue())
            .collect(Collectors.toSet());
    Assert.assertEquals(2, seenLabels.size());
    Assert.assertTrue(seenLabels.containsAll(iteratorLabels));
    Assert.assertTrue(iteratorLabels.containsAll(seenLabels));
}
示例7: testGetIteratorNumericClass
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; //导入方法依赖的package包/类
/**
 * Checks that, with a batch size of 1, each batch's first label value equals the corresponding
 * instance's numeric class value (within 1e-5), in data order, for the data from makeData().
 */
@Test
public void testGetIteratorNumericClass() throws Exception {
    final Instances data = makeData();
    final int batchSize = 1;
    final DataSetIterator it = this.cteii.getDataSetIterator(data, SEED, batchSize);
    for (Instance inst : data) {
        final double label = inst.value(data.classIndex());
        final DataSet next = it.next();
        final double itLabel = next.getLabels().getDouble(0);
        Assert.assertEquals(label, itLabel, 1e-5);
    }
}
示例8: testGetIterator
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; //导入方法依赖的package包/类
/**
 * Checks the image iterator against the minimal MNIST meta data. With one image per batch,
 * each batch's label argmax must equal the label in attribute 1 of the matching meta-data
 * instance. Across the run, exactly 10 labels must be seen, matching the iterator's label names.
 */
@Test
public void testGetIterator() throws Exception {
    final Instances metaData = DatasetLoader.loadMiniMnistMeta();
    this.idi.setImagesLocation(new File("datasets/nominal/mnist-minimal"));
    final DataSetIterator it = this.idi.getDataSetIterator(metaData, SEED, 1);
    final Set<Integer> observed = new HashSet<>();
    for (int row = 0; row < metaData.size(); row++) {
        final Instance instance = metaData.get(row);
        // Attribute index 1 holds the expected label.
        final int expected = Integer.parseInt(instance.stringValue(1));
        final int actual = it.next().getLabels().argMax().getInt(0);
        Assert.assertEquals(expected, actual);
        observed.add(expected);
    }
    final List<Integer> iteratorLabels = it.getLabels().stream()
            .map(name -> Integer.valueOf(name))
            .collect(Collectors.toList());
    Assert.assertEquals(10, observed.size());
    Assert.assertTrue(observed.containsAll(iteratorLabels));
    Assert.assertTrue(iteratorLabels.containsAll(observed));
}
示例9: testOutputFormat
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; //导入方法依赖的package包/类
/**
 * Checks the shape of the first batch for several truncate lengths and batch sizes. Features
 * must be (batch size x 6 x sequence length) and labels (batch size x numClasses x sequence
 * length), with the sequence length never exceeding the truncate length.
 */
@Test
public void testOutputFormat() throws Exception {
    for (int truncateLength : new int[] {10, 50, 200}) {
        rii.setTruncateLength(truncateLength);
        for (int batchSize : new int[] {1, 4, 8, 16}) {
            final DataSetIterator it = rii.getDataSetIterator(data, TestUtil.SEED, batchSize);
            assertEquals(batchSize, it.batch());
            assertEquals(Arrays.asList("0.0", "1.0"), it.getLabels());
            final DataSet batch = it.next();

            // Features: (batch size x word-vector size 6 x sequence length <= truncateLength)
            final int[] featureShape = batch.getFeatures().shape();
            assertEquals(batchSize, featureShape[0]);
            assertEquals(6, featureShape[1]);
            assertTrue(truncateLength >= featureShape[2]);

            // Labels: (batch size x number of classes x sequence length <= truncateLength)
            final int[] labelShape = batch.getLabels().shape();
            assertEquals(batchSize, labelShape[0]);
            assertEquals(data.numClasses(), labelShape[1]);
            assertTrue(truncateLength >= labelShape[2]);
        }
    }
}
示例10: main
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; //导入方法依赖的package包/类
/**
 * Evaluates a trained sentiment model on a test set and prints the evaluation statistics.
 *
 * <ul>
 *   <li>args[0] input: word2vec file name (loaded with {@code loadTxtVectors})</li>
 *   <li>args[1] input: sentiment model file name</li>
 *   <li>args[2] input: parent folder of the test data</li>
 * </ul>
 *
 * <p>Exits with status 1 and prints a usage line to stderr when fewer than three arguments
 * are given.
 *
 * @param args command-line arguments as described above
 * @throws Exception if loading the word vectors, the model, or the test data fails
 */
public static void main(final String[] args) throws Exception {
    // The JVM never passes null elements in args, so the meaningful check is the count;
    // indexing args[2] with fewer arguments would throw ArrayIndexOutOfBoundsException.
    if (args == null || args.length < 3 || args[0] == null || args[1] == null || args[2] == null) {
        System.err.println("Usage: <word2vec file> <sentiment model file> <test parent folder>");
        System.exit(1);
    }
    WordVectors wvec = WordVectorSerializer.loadTxtVectors(new File(args[0]));
    MultiLayerNetwork model = ModelSerializer.restoreMultiLayerNetwork(args[1], false);
    // NOTE(review): the meaning of (100, 300, false) depends on SentimentRecurrentIterator's
    // constructor, which is not visible here — confirm.
    DataSetIterator test = new AsyncDataSetIterator(
            new SentimentRecurrentIterator(args[2], wvec, 100, 300, false), 1);
    Evaluation evaluation = new Evaluation();
    while (test.hasNext()) {
        DataSet t = test.next();
        INDArray features = t.getFeatures();
        INDArray labels = t.getLabels();
        INDArray inMask = t.getFeaturesMaskArray();
        INDArray outMask = t.getLabelsMaskArray();
        // Inference (train=false) using the batch's feature and label masks.
        INDArray predicted = model.output(features, false, inMask, outMask);
        evaluation.evalTimeSeries(labels, predicted, outMask);
    }
    System.out.println(evaluation.stats());
}
示例11: assertCachingDataSetIteratorHasAllTheData
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; //导入方法依赖的package包/类
/**
 * Checks that {@code cachedIt} keeps returning the batches it cached earlier after the shared
 * {@code dataSet} has been overwritten, while {@code it} reflects the overwritten values.
 *
 * <p>Both iterators are reset first. Then {@code dataSet}'s features are replaced with zeros of
 * shape (rows x inputColumns) and its labels with ones of shape (rows x outputColumns).
 *
 * <p>Each batch from {@code cachedIt} must have feature sum 1000.0 and label sum 0.0. Each batch
 * from {@code it} must have feature sum 0.0 and label sum 20.0. Both iterators must run out at
 * the same time.
 *
 * <p>NOTE(review): this presumably assumes {@code it} is backed by {@code dataSet} and that the
 * hard-coded sums match the callers' batch size and original data — confirm against callers.
 *
 * @param rows          row count of the replacement feature and label arrays
 * @param inputColumns  column count of the replacement feature array
 * @param outputColumns column count of the replacement label array
 * @param dataSet       data set whose features and labels are overwritten in place
 * @param it            uncached iterator, expected to return the overwritten values
 * @param cachedIt      caching iterator, expected to return the previously cached values
 */
private void assertCachingDataSetIteratorHasAllTheData(int rows, int inputColumns, int outputColumns,
DataSet dataSet, DataSetIterator it, CachingDataSetIterator cachedIt) {
cachedIt.reset();
it.reset();
// Overwrite the shared data after resetting; the cached batches are expected to be unaffected.
dataSet.setFeatures(Nd4j.zeros(rows, inputColumns));
dataSet.setLabels(Nd4j.ones(rows, outputColumns));
while (it.hasNext()) {
// The cached iterator must have a batch for every batch of the uncached one.
assertTrue(cachedIt.hasNext());
DataSet cachedDs = cachedIt.next();
// Cached batch: still the original values, not the zeros/ones written above.
assertEquals(1000.0, cachedDs.getFeatureMatrix().sumNumber());
assertEquals(0.0, cachedDs.getLabels().sumNumber());
DataSet ds = it.next();
// Uncached batch: reflects the zeros/ones written above.
assertEquals(0.0, ds.getFeatureMatrix().sumNumber());
assertEquals(20.0, ds.getLabels().sumNumber());
}
// Both iterators must be exhausted together.
assertFalse(cachedIt.hasNext());
assertFalse(it.hasNext());
}
示例12: testItervsDataset
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; //导入方法依赖的package包/类
/**
 * Compares normalizing a whole copy of {@code data} directly with normalizing it batch by batch
 * through an iterator.
 *
 * @param preProcessor normalizer to fit and apply; it is re-fitted on the iterator, so its
 *                     state changes
 * @return the largest element-wise |1 - iteratorResult / dataSetResult| over all features
 */
public float testItervsDataset(DataNormalization preProcessor) {
    // Path A: fit and transform an in-memory copy of the data directly.
    DataSet fullCopy = data.copy();
    DataSetIterator batches = new TestDataSetIterator(fullCopy, batchSize);
    preProcessor.fit(fullCopy);
    preProcessor.transform(fullCopy);
    INDArray fromDataSet = fullCopy.getFeatures();

    // Path B: re-fit on the iterator, let it apply the normalizer per batch,
    // then stack the normalized batches back into one matrix.
    preProcessor.fit(batches);
    batches.setPreProcessor(preProcessor);
    INDArray fromIterator = batches.next().getFeatures();
    while (batches.hasNext()) {
        fromIterator = Nd4j.vstack(fromIterator, batches.next().getFeatures());
    }

    // Worst-case relative deviation between the two paths: max |1 - B / A|.
    INDArray relativeError = Transforms.abs(fromIterator.div(fromDataSet).rsub(1));
    return relativeError.maxNumber().floatValue();
}
示例13: fit
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; //导入方法依赖的package包/类
/**
 * Gathers normalization statistics in a single pass over {@code iterator}.
 *
 * <p>Feature statistics are always collected. Label statistics are collected only when
 * {@code fitLabels} is set. The iterator is reset both before and after the pass, so callers
 * get it back at the start.
 *
 * @param iterator source of the data sets to gather statistics from
 */
@Override
public void fit(DataSetIterator iterator) {
    S.Builder featureStatsBuilder = newBuilder();
    S.Builder labelStatsBuilder = newBuilder();

    iterator.reset();
    while (iterator.hasNext()) {
        DataSet batch = iterator.next();
        featureStatsBuilder.addFeatures(batch);
        if (fitLabels) {
            labelStatsBuilder.addLabels(batch);
        }
    }

    featureStats = (S) featureStatsBuilder.build();
    if (fitLabels) {
        labelStats = (S) labelStatsBuilder.build();
    }
    // Leave the iterator rewound for whoever consumes it next.
    iterator.reset();
}
示例14: testPredict
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; //导入方法依赖的package包/类
/**
 * Trains a small dense network (784 -> 50 ReLU -> 10 softmax) on a few MNIST examples.
 * It then checks that predicting a single example yields a non-null label name, and that the
 * example's own label name is also non-null.
 */
@Test
public void testPredict() throws Exception {
    Nd4j.getRandom().setSeed(12345);
    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .weightInit(WeightInit.XAVIER)
            .seed(12345L)
            .list()
            .layer(0, new DenseLayer.Builder().nIn(784).nOut(50).activation(Activation.RELU).build())
            .layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                    .activation(Activation.SOFTMAX).nIn(50).nOut(10).build())
            .pretrain(false)
            .backprop(true)
            .setInputType(InputType.convolutional(28, 28, 1))
            .build();

    MultiLayerNetwork network = new MultiLayerNetwork(conf);
    network.init();
    network.fit(new MnistDataSetIterator(10, 10));

    DataSet single = new MnistDataSetIterator(1, 1).next();
    single.setLabelNames(Arrays.asList("0", "1", "2", "3", "4", "5", "6", "7", "8", "9"));
    String actualLabel = single.getLabelName(0);
    List<String> predictions = network.predict(single);
    assertTrue(actualLabel != null);
    assertTrue(predictions.get(0) != null);
}
示例15: output
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; //导入方法依赖的package包/类
/**
 * Runs the network over every batch of {@code iterator} and stacks the per-batch outputs.
 *
 * <p>The output is typically of the form [0.5, 0.5], or some other probability distribution
 * summing to one, per row.
 *
 * <p>Only features and, when present, mask arrays are read from each batch; labels are not
 * required. Batches without a feature array are skipped.
 *
 * @param iterator test data to run through the network
 * @param train    whether to run the network in training mode
 * @return the outputs of all batches, stacked vertically in iteration order
 * @throws IllegalStateException if the iterator yields no batch with features
 */
public INDArray output(DataSetIterator iterator, boolean train) {
    List<INDArray> outList = new ArrayList<>();
    while (iterator.hasNext()) {
        DataSet next = iterator.next();
        INDArray features = next.getFeatures();
        // Labels are not needed to compute outputs, so only missing features disqualify a batch.
        // Skip it (rather than break) so that later batches are not silently dropped.
        if (features == null) {
            continue;
        }
        if (next.hasMaskArrays()) {
            INDArray fMask = next.getFeaturesMaskArray();
            INDArray lMask = next.getLabelsMaskArray();
            outList.add(this.output(features, train, fMask, lMask));
        } else {
            outList.add(output(features, train));
        }
    }
    if (outList.isEmpty()) {
        // Stacking zero arrays would fail with an unhelpful error; say what actually went wrong.
        throw new IllegalStateException("Cannot compute output: iterator provided no batches with features");
    }
    return Nd4j.vstack(outList.toArray(new INDArray[0]));
}