

Java DataSetIterator.hasNext Method Code Examples

This article collects typical usage examples of the Java method org.nd4j.linalg.dataset.api.iterator.DataSetIterator.hasNext. If you are wondering how DataSetIterator.hasNext is used in practice, the curated examples below should help. You can also browse further usage examples of the containing interface, org.nd4j.linalg.dataset.api.iterator.DataSetIterator.


A total of 15 code examples of the DataSetIterator.hasNext method are shown below, sorted by popularity by default. You can upvote the examples you find useful; your ratings help the system recommend better Java code examples.
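
Before the individual examples, here is a minimal, self-contained sketch of the basic hasNext()/next()/reset() loop. This sketch is not taken from any of the projects below; it assumes DL4J's MnistDataSetIterator (which several of the examples also use) purely for illustration.

import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;

public class HasNextBasicsSketch {
    public static void main(String[] args) throws Exception {
        // Mini-batch iterator over the MNIST test split (batch size 64, fixed seed).
        DataSetIterator iter = new MnistDataSetIterator(64, false, 12345);

        int batches = 0;
        int examples = 0;
        while (iter.hasNext()) {          // another mini-batch is still available
            DataSet batch = iter.next();  // fetch it
            batches++;
            examples += batch.numExamples();
        }
        System.out.println(batches + " mini-batches, " + examples + " examples");

        // Once exhausted, reset() rewinds the iterator so the data can be traversed again.
        iter.reset();
    }
}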

Example 1: evalMnistTestSet

import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; // import the package/class this method depends on
private static void evalMnistTestSet(MultiLayerNetwork leNetModel) throws Exception {
    log.info("Load test data....");
    int batchSize = 64;
    DataSetIterator mnistTest = new MnistDataSetIterator(batchSize, false, 12345);

    log.info("Evaluate model....");
    int outputNum = 10;
    Evaluation eval = new Evaluation(outputNum);

    while (mnistTest.hasNext()) {
        DataSet dataSet = mnistTest.next();
        INDArray output = leNetModel.output(dataSet.getFeatureMatrix(), false);
        eval.eval(dataSet.getLabels(), output);
    }

    log.info(eval.stats());
}
 
Developer: matthiaszimmermann, Project: ml_demo, Lines: 19, Source: LeNetMnistTester.java

Example 2: main

import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; // import the package/class this method depends on
/**
 * args[0] input: word2vec file name
 * args[1] input: sentiment model name
 * args[2] input: test parent folder name
 *
 * @param args
 * @throws Exception
 */
public static void main (final String[] args) throws Exception {
  if (args.length < 3)
    System.exit(1);

  WordVectors wvec = WordVectorSerializer.loadTxtVectors(new File(args[0]));
  MultiLayerNetwork model = ModelSerializer.restoreMultiLayerNetwork(args[1],false);

  DataSetIterator test = new AsyncDataSetIterator(
      new SentimentRecurrentIterator(args[2],wvec,100,300,false),1);
  Evaluation evaluation = new Evaluation();
  while(test.hasNext()) {
    DataSet t = test.next();
    INDArray features = t.getFeatures();
    INDArray labels = t.getLabels();
    INDArray inMask = t.getFeaturesMaskArray();
    INDArray outMask = t.getLabelsMaskArray();
    INDArray predicted = model.output(features,false,inMask,outMask);
    evaluation.evalTimeSeries(labels,predicted,outMask);
  }
  System.out.println(evaluation.stats());
}
 
Developer: keigohtr, Project: sentiment-rnn, Lines: 30, Source: SentimentRecurrentTestCmd.java

Example 3: assertCachingDataSetIteratorHasAllTheData

import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; // import the package/class this method depends on
private void assertCachingDataSetIteratorHasAllTheData(int rows, int inputColumns, int outputColumns,
                DataSet dataSet, DataSetIterator it, CachingDataSetIterator cachedIt) {
    cachedIt.reset();
    it.reset();

    dataSet.setFeatures(Nd4j.zeros(rows, inputColumns));
    dataSet.setLabels(Nd4j.ones(rows, outputColumns));

    while (it.hasNext()) {
        assertTrue(cachedIt.hasNext());

        DataSet cachedDs = cachedIt.next();
        assertEquals(1000.0, cachedDs.getFeatureMatrix().sumNumber());
        assertEquals(0.0, cachedDs.getLabels().sumNumber());

        DataSet ds = it.next();
        assertEquals(0.0, ds.getFeatureMatrix().sumNumber());
        assertEquals(20.0, ds.getLabels().sumNumber());
    }

    assertFalse(cachedIt.hasNext());
    assertFalse(it.hasNext());
}
 
Developer: deeplearning4j, Project: nd4j, Lines: 24, Source: CachingDataSetIteratorTest.java

Example 4: testItervsDataset

import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; // import the package/class this method depends on
public float testItervsDataset(DataNormalization preProcessor) {
    DataSet dataCopy = data.copy();
    DataSetIterator dataIter = new TestDataSetIterator(dataCopy, batchSize);
    preProcessor.fit(dataCopy);
    preProcessor.transform(dataCopy);
    INDArray transformA = dataCopy.getFeatures();

    preProcessor.fit(dataIter);
    dataIter.setPreProcessor(preProcessor);
    DataSet next = dataIter.next();
    INDArray transformB = next.getFeatures();

    while (dataIter.hasNext()) {
        next = dataIter.next();
        INDArray transformb = next.getFeatures();
        transformB = Nd4j.vstack(transformB, transformb);
    }

    return Transforms.abs(transformB.div(transformA).rsub(1)).maxNumber().floatValue();
}
 
Developer: deeplearning4j, Project: nd4j, Lines: 21, Source: NormalizerTests.java

Example 5: fit

import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; // import the package/class this method depends on
/**
 * Fit the given model
 *
 * @param iterator for the data to iterate over
 */
@Override
public void fit(DataSetIterator iterator) {
    S.Builder featureNormBuilder = newBuilder();
    S.Builder labelNormBuilder = newBuilder();

    iterator.reset();
    while (iterator.hasNext()) {
        DataSet next = iterator.next();
        featureNormBuilder.addFeatures(next);
        if (fitLabels) {
            labelNormBuilder.addLabels(next);
        }
    }
    featureStats = (S) featureNormBuilder.build();
    if (fitLabels) {
        labelStats = (S) labelNormBuilder.build();
    }
    iterator.reset();
}
 
Developer: deeplearning4j, Project: nd4j, Lines: 25, Source: AbstractDataSetNormalizer.java

Example 6: testNormalizerPrefetchReset

import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; // import the package/class this method depends on
@Test
public void testNormalizerPrefetchReset() throws Exception {
    //Check NPE fix for: https://github.com/deeplearning4j/deeplearning4j/issues/4214
    RecordReader csv = new CSVRecordReader();
    csv.initialize(new FileSplit(new ClassPathResource("iris.txt").getTempFileFromArchive()));

    int batchSize = 3;

    DataSetIterator iter = new RecordReaderDataSetIterator(csv, batchSize, 4, 4, true);

    DataNormalization normalizer = new NormalizerMinMaxScaler(0, 1);
    normalizer.fit(iter);
    iter.setPreProcessor(normalizer);

    iter.inputColumns();    //Prefetch
    iter.totalOutcomes();
    iter.hasNext();
    iter.reset();
    iter.next();
}
 
Developer: deeplearning4j, Project: deeplearning4j, Lines: 21, Source: RecordReaderDataSetiteratorTest.java

Example 7: testInitializeNoNextIter

import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; // import the package/class this method depends on
@Test
public void testInitializeNoNextIter() {

    DataSetIterator iter = new IrisDataSetIterator(10, 150);
    while (iter.hasNext())
        iter.next();

    DataSetIterator async = new AsyncDataSetIterator(iter, 2);

    assertFalse(iter.hasNext());
    assertFalse(async.hasNext());
    try {
        iter.next();
        fail("Should have thrown NoSuchElementException");
    } catch (Exception e) {
        //OK
    }

    async.reset();
    int count = 0;
    while (async.hasNext()) {
        async.next();
        count++;
    }
    assertEquals(150 / 10, count);
}
 
Developer: deeplearning4j, Project: deeplearning4j, Lines: 27, Source: TestAsyncIterator.java

Example 8: check

import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; // import the package/class this method depends on
private void check(BufferedImage image) throws Exception
{
    ImageIO.write(image, "png", new File("tmp.png")); //saves the image to the tmp.png file
    ImageRecordReader reader = new ImageRecordReader(150, 150, 3);
    reader.initialize(new FileSplit(new File("tmp.png"))); //reads the tmp.png file
    DataSetIterator dataIter = new RecordReaderDataSetIterator(reader, 1);
    while (dataIter.hasNext())
    {
        //Normalize the data from the file
        DataNormalization normalization = new NormalizerMinMaxScaler();
        DataSet set = dataIter.next();
        normalization.fit(set);
        normalization.transform(set);

        INDArray array = MainGUI.model.output(set.getFeatures(), false); //send the data to the model and get the results

        //Process the results and print them in an understandable format (percentage scores)
        String txt = "";

        DecimalFormat df = new DecimalFormat("#.00");

        for (int i = 0; i < array.length(); i++)
        {
            txt += MainGUI.labels.get(i) + ": " + (array.getDouble(i)*100 < 1 ? "0" : "") + df.format((array.getDouble(i)*100)) + "%\n";
        }

        probabilityArea.setText(txt);
    }

    reader.close();
}
 
Developer: maksgraczyk, Project: DeepID, Lines: 32, Source: Identification.java

Example 9: evaluate

import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; // import the package/class this method depends on
@Override
public String evaluate(FederatedDataSet federatedDataSet) {
    DataSet testData = (DataSet) federatedDataSet.getNativeDataSet();
    List<DataSet> listDs = testData.asList();
    DataSetIterator iterator = new ListDataSetIterator(listDs, BATCH_SIZE);

    Evaluation eval = new Evaluation(OUTPUT_NUM); //create an evaluation object with 10 possible classes
    while (iterator.hasNext()) {
        DataSet next = iterator.next();
        INDArray output = model.output(next.getFeatureMatrix()); //get the networks prediction
        eval.eval(next.getLabels(), output); //check the prediction against the true class
    }

    return eval.stats();
}
 
Developer: mccorby, Project: FederatedAndroidTrainer, Lines: 16, Source: MNISTModel.java

Example 10: toCsv

import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; // import the package/class this method depends on
private String toCsv(DataSetIterator it, List<Integer> labels, int[] shape) {
    if (it.numExamples() != labels.size()) {
        throw new IllegalStateException(
                String.format("numExamples == %d != labels.size() == %d",
                        it.numExamples(), labels.size()));
    }

    StringBuffer sb = new StringBuffer();
    int l = 0;

    while (it.hasNext()) {
        INDArray features = it.next(1).getFeatures();

        if (!(Arrays.equals(features.shape(), shape))) {
            throw new IllegalStateException(String.format("wrong shape: got %s, expected",
                    Arrays.toString(features.shape()), Arrays.toString(shape)));
        }

        // Prepend the label
        sb.append(labels.get(l)).append(": ");
        l++;

        for (int i=0; i<features.columns(); i++) {
            sb.append(features.getColumn(i));

            if (i < features.columns()-1) {
                sb.append(", ");
            }
        }

        sb.append("\n");
    }

    return sb.toString();
}
 
Developer: SkymindIO, Project: SKIL_Examples, Lines: 36, Source: NormalizeUciData.java

Example 11: main

import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; // import the package/class this method depends on
public static void main(String[] args) throws Exception {
    final int numRows = 28;
    final int numColumns = 28;
    int seed = 123;
    int numSamples = MnistDataFetcher.NUM_EXAMPLES;
    int batchSize = 1000;
    int iterations = 1;
    int listenerFreq = iterations/5;

    log.info("Load data....");
    DataSetIterator iter = new MnistDataSetIterator(batchSize,numSamples,true);

    log.info("Build model....");
    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .seed(seed)
            .iterations(iterations)
            .optimizationAlgo(OptimizationAlgorithm.LINE_GRADIENT_DESCENT)
            .list(8)
            .layer(0, new RBM.Builder().nIn(numRows * numColumns).nOut(2000).lossFunction(LossFunctions.LossFunction.RMSE_XENT).build())
            .layer(1, new RBM.Builder().nIn(2000).nOut(1000).lossFunction(LossFunctions.LossFunction.RMSE_XENT).build())
            .layer(2, new RBM.Builder().nIn(1000).nOut(500).lossFunction(LossFunctions.LossFunction.RMSE_XENT).build())
            .layer(3, new RBM.Builder().nIn(500).nOut(30).lossFunction(LossFunctions.LossFunction.RMSE_XENT).build())
            .layer(4, new RBM.Builder().nIn(30).nOut(500).lossFunction(LossFunctions.LossFunction.RMSE_XENT).build()) 
            .layer(5, new RBM.Builder().nIn(500).nOut(1000).lossFunction(LossFunctions.LossFunction.RMSE_XENT).build())
            .layer(6, new RBM.Builder().nIn(1000).nOut(2000).lossFunction(LossFunctions.LossFunction.RMSE_XENT).build())
            .layer(7, new OutputLayer.Builder(LossFunctions.LossFunction.MSE).activation(Activation.SIGMOID).nIn(2000).nOut(numRows*numColumns).build())
            .pretrain(true).backprop(true)
            .build();

    MultiLayerNetwork model = new MultiLayerNetwork(conf);
    model.init();

    model.setListeners(new ScoreIterationListener(listenerFreq));

    log.info("Train model....");
    while(iter.hasNext()) {
        DataSet next = iter.next();
        model.fit(new DataSet(next.getFeatureMatrix(),next.getFeatureMatrix()));
    }
}
 
Developer: PacktPublishing, Project: Deep-Learning-with-Hadoop, Lines: 41, Source: DeepAutoEncoder.java

Example 12: getFirstBatchFeatures

import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; // import the package/class this method depends on
/**
 * Get a peek at the features of the {@code iterator}'s first batch using the given instances.
 *
 * @return Features of the first batch
 * @throws Exception
 */
protected INDArray getFirstBatchFeatures(Instances data) throws Exception {
  final DataSetIterator it = getDataSetIterator(data, CacheMode.NONE);
  if (!it.hasNext()) {
    throw new RuntimeException("Iterator was unexpectedly empty.");
  }
  final INDArray features = it.next().getFeatures();
  it.reset();
  return features;
}
 
Developer: Waikato, Project: wekaDeeplearning4j, Lines: 16, Source: Dl4jMlpClassifier.java

Example 13: countIterations

import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; // import the package/class this method depends on
/**
 * Counts the number of iterations
 *
 * @param data Instances to iterate
 * @param iter iterator to be tested
 * @param seed Seed
 * @param batchsize Size of the batch which is returned in {@see DataSetIterator#next}
 * @return Number of iterations
 * @throws Exception
 */
private int countIterations(
    Instances data, AbstractInstanceIterator iter, int seed, int batchsize) throws Exception {
  DataSetIterator it = iter.getDataSetIterator(data, seed, batchsize);
  int count = 0;
  while (it.hasNext()) {
    count++;
    it.next();
  }
  return count;
}
 
Developer: Waikato, Project: wekaDeeplearning4j, Lines: 21, Source: CnnTextFilesEmbeddingInstanceIteratorTest.java

Example 14: countIterations

import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; // import the package/class this method depends on
/**
 * Counts the number of iterations of an {@see ImageInstanceIterator}
 *
 * @param data Instances to iterate
 * @param imgIter ImageInstanceIterator to be tested
 * @param seed Seed
 * @param batchsize Size of the batch which is returned in {@see DataSetIterator#next}
 * @return Number of iterations
 * @throws Exception
 */
private int countIterations(
    Instances data, ImageInstanceIterator imgIter, int seed, int batchsize) throws Exception {
  DataSetIterator it = imgIter.getDataSetIterator(data, seed, batchsize);
  int count = 0;
  while (it.hasNext()) {
    count++;
    DataSet dataset = it.next();
  }
  return count;
}
 
Developer: Waikato, Project: wekaDeeplearning4j, Lines: 21, Source: ImageInstanceIteratorTest.java

Example 15: main

import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; // import the package/class this method depends on
/**
 * args[0] input: word2vec file name
 * args[1] input: training model name
 * args[2] input: train/test parent folder name
 * args[3] output: trained model name
 *
 * @param args
 * @throws Exception
 */
public static void main (final String[] args) throws Exception {
  if (args.length < 4)
    System.exit(1);

  WordVectors wvec = WordVectorSerializer.loadTxtVectors(new File(args[0]));
  MultiLayerNetwork model = ModelSerializer.restoreMultiLayerNetwork(args[1],true);
  int batchSize   = 16;//100;
  int testBatch   = 64;
  int nEpochs     = 1;

  System.out.println("Starting online training");
  DataSetIterator train = new AsyncDataSetIterator(
      new SentimentRecurrentIterator(args[2],wvec,batchSize,300,true),2);
  DataSetIterator test = new AsyncDataSetIterator(
      new SentimentRecurrentIterator(args[2],wvec,testBatch,300,false),2);
  for( int i=0; i<nEpochs; i++ ){
    model.fit(train);
    train.reset();

    System.out.println("Epoch " + i + " complete. Starting evaluation:");
    Evaluation evaluation = new Evaluation();
    while(test.hasNext()) {
      DataSet t = test.next();
      INDArray features = t.getFeatures();
      INDArray labels = t.getLabels();
      INDArray inMask = t.getFeaturesMaskArray();
      INDArray outMask = t.getLabelsMaskArray();
      INDArray predicted = model.output(features,false,inMask,outMask);
      evaluation.evalTimeSeries(labels,predicted,outMask);
    }
    test.reset();
    System.out.println(evaluation.stats());

    System.out.println("Save model");
    ModelSerializer.writeModel(model, new FileOutputStream(args[3]), true);
  }
}
 
Developer: keigohtr, Project: sentiment-rnn, Lines: 47, Source: SentimentRecurrentTrainOnlineCmd.java


Note: The org.nd4j.linalg.dataset.api.iterator.DataSetIterator.hasNext method examples in this article were compiled by 纯净天空 from open-source code and documentation platforms such as GitHub and MSDocs. The snippets were selected from open-source projects contributed by their respective developers; copyright of the source code remains with the original authors, and distribution and use are subject to each project's license. Do not reproduce without permission.