本文整理汇总了Java中org.nd4j.linalg.dataset.api.iterator.DataSetIterator.hasNext方法的典型用法代码示例。如果您正苦于以下问题：Java DataSetIterator.hasNext方法的具体用法？Java DataSetIterator.hasNext怎么用？Java DataSetIterator.hasNext使用的例子？那么，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在类org.nd4j.linalg.dataset.api.iterator.DataSetIterator的用法示例。
在下文中一共展示了DataSetIterator.hasNext方法的15个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Java代码示例。
示例1: evalMnistTestSet
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; //导入方法依赖的package包/类
/**
 * Evaluates a network on the MNIST test split (batches of 64, seed 12345) and logs the
 * resulting 10-class evaluation statistics.
 *
 * @param leNetModel the trained network to evaluate
 * @throws Exception propagated from building the MNIST iterator or running the model
 */
private static void evalMnistTestSet(MultiLayerNetwork leNetModel) throws Exception {
    log.info("Load test data....");
    final int testBatchSize = 64;
    final DataSetIterator testBatches = new MnistDataSetIterator(testBatchSize, false, 12345);

    log.info("Evaluate model....");
    final int numClasses = 10;
    final Evaluation evaluation = new Evaluation(numClasses);
    while (testBatches.hasNext()) {
        final DataSet batch = testBatches.next();
        // Inference only: the second argument disables training-mode behaviour.
        final INDArray predictions = leNetModel.output(batch.getFeatureMatrix(), false);
        evaluation.eval(batch.getLabels(), predictions);
    }
    log.info(evaluation.stats());
}
示例2: main
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; //导入方法依赖的package包/类
/**
 * Evaluates a saved sentiment model on a test corpus and prints the evaluation statistics.
 *
 * <ul>
 *   <li>args[0] input: word2vec file name</li>
 *   <li>args[1] input: sentiment model file name</li>
 *   <li>args[2] input: parent folder of the test data</li>
 * </ul>
 *
 * <p>Exits with status 1 (after printing a usage line) when fewer than three non-null
 * arguments are supplied.
 *
 * @param args command-line arguments as listed above
 * @throws Exception propagated from loading the vectors/model or iterating the test data
 */
public static void main(final String[] args) throws Exception {
    // The JVM never passes null elements; the real failure mode is a short array, which the
    // null-only check turned into an ArrayIndexOutOfBoundsException instead of a clean exit.
    if (args == null || args.length < 3 || args[0] == null || args[1] == null || args[2] == null) {
        System.err.println("Usage: <word2vec file> <sentiment model> <test parent folder>");
        System.exit(1);
    }
    WordVectors wvec = WordVectorSerializer.loadTxtVectors(new File(args[0]));
    MultiLayerNetwork model = ModelSerializer.restoreMultiLayerNetwork(args[1], false);
    DataSetIterator test = new AsyncDataSetIterator(
            new SentimentRecurrentIterator(args[2], wvec, 100, 300, false), 1);
    Evaluation evaluation = new Evaluation();
    while (test.hasNext()) {
        DataSet t = test.next();
        INDArray features = t.getFeatures();
        INDArray labels = t.getLabels();
        INDArray inMask = t.getFeaturesMaskArray();
        INDArray outMask = t.getLabelsMaskArray();
        // Masks let variable-length sequences be evaluated only on their valid time steps.
        INDArray predicted = model.output(features, false, inMask, outMask);
        evaluation.evalTimeSeries(labels, predicted, outMask);
    }
    System.out.println(evaluation.stats());
}
示例3: assertCachingDataSetIteratorHasAllTheData
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; //导入方法依赖的package包/类
/**
 * Rewinds both iterators and overwrites {@code dataSet} with zero features and one-valued labels.
 * It then walks {@code it} and {@code cachedIt} in lockstep and asserts two things:
 * the cached batches still carry the earlier values (feature sum 1000, label sum 0), while the
 * source iterator yields the new values (feature sum 0, label sum 20).
 * Finally it asserts that both iterators run out together.
 *
 * <p>NOTE(review): presumably {@code it} is backed by {@code dataSet}; confirm against the caller.
 */
private void assertCachingDataSetIteratorHasAllTheData(int rows, int inputColumns, int outputColumns,
        DataSet dataSet, DataSetIterator it, CachingDataSetIterator cachedIt) {
    cachedIt.reset();
    it.reset();

    // Replace the underlying data after the rewind.
    dataSet.setFeatures(Nd4j.zeros(rows, inputColumns));
    dataSet.setLabels(Nd4j.ones(rows, outputColumns));

    while (it.hasNext()) {
        assertTrue(cachedIt.hasNext());

        final DataSet fromCache = cachedIt.next();
        assertEquals(1000.0, fromCache.getFeatureMatrix().sumNumber());
        assertEquals(0.0, fromCache.getLabels().sumNumber());

        final DataSet fromSource = it.next();
        assertEquals(0.0, fromSource.getFeatureMatrix().sumNumber());
        assertEquals(20.0, fromSource.getLabels().sumNumber());
    }

    // Neither iterator may have leftover batches.
    assertFalse(cachedIt.hasNext());
    assertFalse(it.hasNext());
}
示例4: testItervsDataset
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; //导入方法依赖的package包/类
/**
 * Compares two normalization paths and returns the largest element-wise relative deviation
 * {@code max |1 - viaIterator / viaDataSet|} between their normalized features.
 * <ul>
 *   <li>Path A fits and transforms a copy of {@code data} as a whole.</li>
 *   <li>Path B fits on an iterator over that copy, attaches the normalizer as the iterator's
 *       pre-processor, and vertically stacks the features of every batch.</li>
 * </ul>
 *
 * <p>NOTE(review): {@code data} and {@code batchSize} are fields of the enclosing class (not
 * visible here). The statement order below is significant because both paths share
 * {@code dataCopy}.
 *
 * @param preProcessor the normalizer under test
 * @return the maximum relative deviation between the two paths
 */
public float testItervsDataset(DataNormalization preProcessor) {
    DataSet dataCopy = data.copy();
    DataSetIterator batches = new TestDataSetIterator(dataCopy, batchSize);

    // Path A: whole-DataSet fit + transform.
    preProcessor.fit(dataCopy);
    preProcessor.transform(dataCopy);
    INDArray viaDataSet = dataCopy.getFeatures();

    // Path B: iterator fit, then normalized batches stacked back together.
    preProcessor.fit(batches);
    batches.setPreProcessor(preProcessor);
    INDArray viaIterator = null;
    do {
        INDArray batchFeatures = batches.next().getFeatures();
        viaIterator = (viaIterator == null) ? batchFeatures : Nd4j.vstack(viaIterator, batchFeatures);
    } while (batches.hasNext());

    INDArray ratio = viaIterator.div(viaDataSet);
    INDArray deviation = Transforms.abs(ratio.rsub(1));
    return deviation.maxNumber().floatValue();
}
示例5: fit
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; //导入方法依赖的package包/类
/**
 * Fits the normalizer's statistics in one pass over {@code iterator}.
 *
 * <p>Feature statistics are always collected; label statistics only when {@code fitLabels} is
 * set. The iterator is reset both before and after the pass.
 *
 * @param iterator for the data to iterate over
 */
@Override
public void fit(DataSetIterator iterator) {
    final S.Builder featureBuilder = newBuilder();
    final S.Builder labelBuilder = newBuilder();

    iterator.reset();
    while (iterator.hasNext()) {
        final DataSet batch = iterator.next();
        featureBuilder.addFeatures(batch);
        if (!fitLabels) {
            continue;
        }
        labelBuilder.addLabels(batch);
    }

    featureStats = (S) featureBuilder.build();
    if (fitLabels) {
        labelStats = (S) labelBuilder.build();
    }
    // Leave the iterator rewound for the caller.
    iterator.reset();
}
示例6: testNormalizerPrefetchReset
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; //导入方法依赖的package包/类
/**
 * Regression test for the NullPointerException in
 * https://github.com/deeplearning4j/deeplearning4j/issues/4214.
 *
 * <p>It triggers prefetching on a normalized CSV iterator, then resets it and reads a batch.
 */
@Test
public void testNormalizerPrefetchReset() throws Exception {
    final RecordReader recordReader = new CSVRecordReader();
    recordReader.initialize(new FileSplit(new ClassPathResource("iris.txt").getTempFileFromArchive()));

    // batch size 3, label at column 4, 4 possible labels, regression flag true
    final DataSetIterator iterator = new RecordReaderDataSetIterator(recordReader, 3, 4, 4, true);

    final DataNormalization minMax = new NormalizerMinMaxScaler(0, 1);
    minMax.fit(iterator);
    iterator.setPreProcessor(minMax);

    // These calls force a prefetch before the reset below.
    iterator.inputColumns();
    iterator.totalOutcomes();
    iterator.hasNext();

    iterator.reset();
    iterator.next();
}
示例7: testInitializeNoNextIter
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; //导入方法依赖的package包/类
/**
 * Checks three things about an {@code AsyncDataSetIterator} that wraps an already-exhausted
 * iterator:
 * <ul>
 *   <li>it reports no elements;</li>
 *   <li>the exhausted source refuses {@code next()};</li>
 *   <li>after {@code reset()}, the async wrapper yields all 150 / 10 batches.</li>
 * </ul>
 */
@Test
public void testInitializeNoNextIter() {
    final DataSetIterator source = new IrisDataSetIterator(10, 150);
    // Drain the source completely before wrapping it.
    while (source.hasNext()) {
        source.next();
    }

    final DataSetIterator async = new AsyncDataSetIterator(source, 2);
    assertFalse(source.hasNext());
    assertFalse(async.hasNext());

    try {
        source.next();
        fail("Should have thrown NoSuchElementException");
    } catch (Exception expected) {
        // Any exception is accepted: an exhausted iterator must not yield another element.
    }

    async.reset();
    int batches = 0;
    for (; async.hasNext(); batches++) {
        async.next();
    }
    assertEquals(150 / 10, batches);
}
示例8: check
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; //导入方法依赖的package包/类
/**
 * Classifies {@code image} with {@code MainGUI.model} and shows one percentage score per label
 * in {@code probabilityArea}.
 *
 * <p>Processing steps:
 * <ol>
 *   <li>The image is saved as {@code tmp.png} (relative to the working directory).</li>
 *   <li>It is read back through an {@code ImageRecordReader(150, 150, 3)}.</li>
 *   <li>It is min-max normalized with statistics fitted on that same data set.</li>
 *   <li>It is fed to the model.</li>
 * </ol>
 * The reader is closed even if a step fails.
 *
 * @param image the image to classify
 * @throws Exception if writing, reading or evaluating the image fails
 */
private void check(BufferedImage image) throws Exception {
    ImageIO.write(image, "png", new File("tmp.png")); //saves the image to the tmp.png file
    ImageRecordReader reader = new ImageRecordReader(150, 150, 3);
    try {
        reader.initialize(new FileSplit(new File("tmp.png"))); //reads the tmp.png file
        DataSetIterator dataIter = new RecordReaderDataSetIterator(reader, 1);
        // "0.00" always prints the integer digit (e.g. "0.42", "1.00", "-0.50"). The "#.00" pattern
        // needs a manual "0" prefix, which breaks for values that round up to 1.00 or are negative.
        DecimalFormat df = new DecimalFormat("0.00");
        while (dataIter.hasNext()) {
            //Normalize the data from the file
            DataNormalization normalization = new NormalizerMinMaxScaler();
            DataSet set = dataIter.next();
            normalization.fit(set);
            normalization.transform(set);
            INDArray array = MainGUI.model.output(set.getFeatures(), false); //send the data to the model and get the results
            //Process the results and print them in an understandable format (percentage scores)
            StringBuilder txt = new StringBuilder();
            for (int i = 0; i < array.length(); i++) {
                txt.append(MainGUI.labels.get(i)).append(": ")
                        .append(df.format(array.getDouble(i) * 100)).append("%\n");
            }
            probabilityArea.setText(txt.toString());
        }
    } finally {
        // Release the reader even when initialization or inference throws.
        reader.close();
    }
}
示例9: evaluate
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; //导入方法依赖的package包/类
/**
 * Evaluates the model on the test data held by {@code federatedDataSet} and returns the
 * evaluation statistics.
 *
 * <p>The native data set is split into single examples and re-batched in groups of
 * {@code BATCH_SIZE}. The {@code Evaluation} is built for {@code OUTPUT_NUM} classes (both are
 * constants defined outside this method).
 *
 * @param federatedDataSet wrapper whose native data set is expected to be a {@code DataSet}
 * @return the textual statistics from {@code Evaluation.stats()}
 */
@Override
public String evaluate(FederatedDataSet federatedDataSet) {
    final DataSet testData = (DataSet) federatedDataSet.getNativeDataSet();
    final DataSetIterator batches = new ListDataSetIterator(testData.asList(), BATCH_SIZE);
    final Evaluation evaluation = new Evaluation(OUTPUT_NUM);

    while (batches.hasNext()) {
        final DataSet batch = batches.next();
        // Compare the network's predictions against the true labels of this batch.
        evaluation.eval(batch.getLabels(), model.output(batch.getFeatureMatrix()));
    }
    return evaluation.stats();
}
示例10: toCsv
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; //导入方法依赖的package包/类
/**
 * Renders every example of {@code it} as one line of the form
 * {@code "<label>: <col 0>, <col 1>, ..."}.
 *
 * <p>Examples are pulled one at a time with {@code it.next(1)}; the i-th example is prefixed
 * with {@code labels.get(i)}.
 *
 * @param it     iterator over the examples; its {@code numExamples()} must equal {@code labels.size()}
 * @param labels label for each example, in iteration order
 * @param shape  expected shape of each single-example feature array
 * @return the rendered lines, each terminated by a newline
 * @throws IllegalStateException if the example and label counts differ, or if an example's
 *                               features do not have the expected shape
 */
private String toCsv(DataSetIterator it, List<Integer> labels, int[] shape) {
    if (it.numExamples() != labels.size()) {
        throw new IllegalStateException(
                String.format("numExamples == %d != labels.size() == %d",
                        it.numExamples(), labels.size()));
    }
    // Method-local buffer: StringBuffer's synchronization buys nothing here.
    StringBuilder sb = new StringBuilder();
    int row = 0;
    while (it.hasNext()) {
        INDArray features = it.next(1).getFeatures();
        if (!Arrays.equals(features.shape(), shape)) {
            // Report both shapes; the expected shape needs its own %s placeholder.
            throw new IllegalStateException(String.format("wrong shape: got %s, expected %s",
                    Arrays.toString(features.shape()), Arrays.toString(shape)));
        }
        // Prepend the label
        sb.append(labels.get(row)).append(": ");
        row++;
        // NOTE(review): appends each column's INDArray toString(); confirm it renders as a bare number.
        for (int i = 0; i < features.columns(); i++) {
            if (i > 0) {
                sb.append(", ");
            }
            sb.append(features.getColumn(i));
        }
        sb.append("\n");
    }
    return sb.toString();
}
示例11: main
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; //导入方法依赖的package包/类
/**
 * Trains a deep autoencoder on MNIST and uses each batch's features as both input and
 * reconstruction target.
 *
 * <p>The network is built from stacked RBM layers with an output layer:
 * 784-2000-1000-500-30-500-1000-2000-784.
 *
 * @param args unused
 * @throws Exception propagated from loading MNIST or training the network
 */
public static void main(String[] args) throws Exception {
    final int numRows = 28;
    final int numColumns = 28;
    int seed = 123;
    int numSamples = MnistDataFetcher.NUM_EXAMPLES;
    int batchSize = 1000;
    int iterations = 1;
    // iterations / 5 is 0 whenever iterations < 5, and a listener frequency of 0 is meaningless.
    // Clamp so the score is reported at least once per iteration.
    int listenerFreq = Math.max(1, iterations / 5);
    log.info("Load data....");
    DataSetIterator iter = new MnistDataSetIterator(batchSize, numSamples, true);
    log.info("Build model....");
    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .seed(seed)
            .iterations(iterations)
            .optimizationAlgo(OptimizationAlgorithm.LINE_GRADIENT_DESCENT)
            .list(8)
            .layer(0, new RBM.Builder().nIn(numRows * numColumns).nOut(2000).lossFunction(LossFunctions.LossFunction.RMSE_XENT).build())
            .layer(1, new RBM.Builder().nIn(2000).nOut(1000).lossFunction(LossFunctions.LossFunction.RMSE_XENT).build())
            .layer(2, new RBM.Builder().nIn(1000).nOut(500).lossFunction(LossFunctions.LossFunction.RMSE_XENT).build())
            .layer(3, new RBM.Builder().nIn(500).nOut(30).lossFunction(LossFunctions.LossFunction.RMSE_XENT).build())
            .layer(4, new RBM.Builder().nIn(30).nOut(500).lossFunction(LossFunctions.LossFunction.RMSE_XENT).build())
            .layer(5, new RBM.Builder().nIn(500).nOut(1000).lossFunction(LossFunctions.LossFunction.RMSE_XENT).build())
            .layer(6, new RBM.Builder().nIn(1000).nOut(2000).lossFunction(LossFunctions.LossFunction.RMSE_XENT).build())
            .layer(7, new OutputLayer.Builder(LossFunctions.LossFunction.MSE).activation(Activation.SIGMOID).nIn(2000).nOut(numRows * numColumns).build())
            .pretrain(true).backprop(true)
            .build();
    MultiLayerNetwork model = new MultiLayerNetwork(conf);
    model.init();
    model.setListeners(new ScoreIterationListener(listenerFreq));
    log.info("Train model....");
    while (iter.hasNext()) {
        DataSet next = iter.next();
        // Autoencoder: the input is also the reconstruction target.
        model.fit(new DataSet(next.getFeatureMatrix(), next.getFeatureMatrix()));
    }
}
示例12: getFirstBatchFeatures
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; //导入方法依赖的package包/类
/**
 * Peeks at the features of the first batch produced for the given instances.
 *
 * <p>The iterator is obtained with {@code CacheMode.NONE} and reset after the peek.
 *
 * @param data instances to build the iterator from
 * @return features of the first batch
 * @throws RuntimeException if the iterator yields no batches
 * @throws Exception propagated from building the iterator
 */
protected INDArray getFirstBatchFeatures(Instances data) throws Exception {
    final DataSetIterator iterator = getDataSetIterator(data, CacheMode.NONE);
    if (iterator.hasNext()) {
        final INDArray firstFeatures = iterator.next().getFeatures();
        // Rewind after peeking. NOTE(review): presumably this matters if getDataSetIterator can
        // hand out a reused iterator; confirm.
        iterator.reset();
        return firstFeatures;
    }
    throw new RuntimeException("Iterator was unexpectedly empty.");
}
示例13: countIterations
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; //导入方法依赖的package包/类
/**
 * Counts how many batches the iterator built by {@code iter} yields.
 *
 * @param data Instances to iterate
 * @param iter iterator to be tested
 * @param seed Seed
 * @param batchsize Size of the batch returned by {@link DataSetIterator#next()}
 * @return Number of iterations
 * @throws Exception propagated from building or advancing the iterator
 */
private int countIterations(
        Instances data, AbstractInstanceIterator iter, int seed, int batchsize) throws Exception {
    final DataSetIterator batches = iter.getDataSetIterator(data, seed, batchsize);
    int total;
    for (total = 0; batches.hasNext(); total++) {
        batches.next();
    }
    return total;
}
示例14: countIterations
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; //导入方法依赖的package包/类
/**
 * Counts how many batches the iterator built by an {@code ImageInstanceIterator} yields.
 *
 * @param data Instances to iterate
 * @param imgIter ImageInstanceIterator to be tested
 * @param seed Seed
 * @param batchsize Size of the batch returned by {@link DataSetIterator#next()}
 * @return Number of iterations
 * @throws Exception propagated from building or advancing the iterator
 */
private int countIterations(
        Instances data, ImageInstanceIterator imgIter, int seed, int batchsize) throws Exception {
    DataSetIterator it = imgIter.getDataSetIterator(data, seed, batchsize);
    int count = 0;
    while (it.hasNext()) {
        // Only advance; the batch contents are not needed for counting.
        it.next();
        count++;
    }
    return count;
}
示例15: main
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; //导入方法依赖的package包/类
/**
 * Continues training a saved sentiment model, evaluates it after each epoch and saves it.
 *
 * <ul>
 *   <li>args[0] input: word2vec file name</li>
 *   <li>args[1] input: file name of the model to continue training</li>
 *   <li>args[2] input: parent folder of the train/test data</li>
 *   <li>args[3] output: file name for the trained model</li>
 * </ul>
 *
 * <p>Exits with status 1 (after printing a usage line) when fewer than four non-null
 * arguments are supplied.
 *
 * @param args command-line arguments as listed above
 * @throws Exception propagated from loading, training, evaluating or saving the model
 */
public static void main(final String[] args) throws Exception {
    // The JVM never passes null elements; the real failure mode is a short array, which the
    // null-only check turned into an ArrayIndexOutOfBoundsException instead of a clean exit.
    if (args == null || args.length < 4
            || args[0] == null || args[1] == null || args[2] == null || args[3] == null) {
        System.err.println("Usage: <word2vec file> <input model> <train/test parent folder> <output model>");
        System.exit(1);
    }
    WordVectors wvec = WordVectorSerializer.loadTxtVectors(new File(args[0]));
    MultiLayerNetwork model = ModelSerializer.restoreMultiLayerNetwork(args[1], true);
    int batchSize = 16;
    int testBatch = 64;
    int nEpochs = 1;
    System.out.println("Starting online training");
    DataSetIterator train = new AsyncDataSetIterator(
            new SentimentRecurrentIterator(args[2], wvec, batchSize, 300, true), 2);
    DataSetIterator test = new AsyncDataSetIterator(
            new SentimentRecurrentIterator(args[2], wvec, testBatch, 300, false), 2);
    for (int i = 0; i < nEpochs; i++) {
        model.fit(train);
        train.reset();
        System.out.println("Epoch " + i + " complete. Starting evaluation:");
        Evaluation evaluation = new Evaluation();
        while (test.hasNext()) {
            DataSet t = test.next();
            INDArray features = t.getFeatures();
            INDArray labels = t.getLabels();
            INDArray inMask = t.getFeaturesMaskArray();
            INDArray outMask = t.getLabelsMaskArray();
            INDArray predicted = model.output(features, false, inMask, outMask);
            evaluation.evalTimeSeries(labels, predicted, outMask);
        }
        test.reset();
        System.out.println(evaluation.stats());
        System.out.println("Save model");
        // Always release the file handle. Closeable.close() is a no-op on an already-closed
        // stream, so this is safe even if writeModel closes it itself.
        try (FileOutputStream out = new FileOutputStream(args[3])) {
            ModelSerializer.writeModel(model, out, true);
        }
    }
}