当前位置: 首页>>代码示例>>Java>>正文


Java DataSetIterator.reset方法代码示例

本文整理汇总了Java中org.nd4j.linalg.dataset.api.iterator.DataSetIterator.reset方法的典型用法代码示例。如果您正苦于以下问题:Java DataSetIterator.reset方法的具体用法是什么?Java DataSetIterator.reset怎么用?有哪些Java DataSetIterator.reset的使用例子?那么,这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在类org.nd4j.linalg.dataset.api.iterator.DataSetIterator的用法示例。


在下文中一共展示了DataSetIterator.reset方法的15个代码示例,这些示例默认按受欢迎程度排序。您可以为喜欢或觉得有用的代码点赞,您的评价将有助于系统推荐出更棒的Java代码示例。

示例1: assertPreProcessingGetsCached

import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; //导入方法依赖的package包/类
/**
 * Checks that both iterators carry the given pre-processor, then drains the caching iterator
 * twice and asserts that the pre-processor's call count equals
 * {@code expectedNumberOfDataSets} after each pass (i.e. the second pass adds no calls).
 *
 * @param expectedNumberOfDataSets call count expected after every full pass
 * @param it                       the plain iterator; only reset here, never drained
 * @param cachedIt                 the caching iterator that is drained
 * @param preProcessor             the pre-processor whose call count is inspected
 */
private void assertPreProcessingGetsCached(int expectedNumberOfDataSets, DataSetIterator it,
                CachingDataSetIterator cachedIt, PreProcessor preProcessor) {

    assertSame(preProcessor, cachedIt.getPreProcessor());
    assertSame(preProcessor, it.getPreProcessor());

    // Two identical passes: the count must be the same after the first and after the second.
    for (int pass = 0; pass < 2; pass++) {
        cachedIt.reset();
        it.reset();

        // Drain the caching iterator; the returned data sets themselves are not needed.
        for (; cachedIt.hasNext(); cachedIt.next()) {
            // intentionally empty
        }

        assertEquals(expectedNumberOfDataSets, preProcessor.getCallCount());
    }
}
 
开发者ID:deeplearning4j,项目名称:nd4j,代码行数:25,代码来源:CachingDataSetIteratorTest.java

示例2: assertCachingDataSetIteratorHasAllTheData

import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; //导入方法依赖的package包/类
/**
 * Resets both iterators, overwrites the contents of {@code dataSet} (features become zeros,
 * labels become ones) and then walks the two iterators in lock-step. For every batch the
 * caching iterator is asserted to report a feature sum of 1000.0 and a label sum of 0.0,
 * while the plain iterator is asserted to report 0.0 and 20.0. Finally both iterators must be
 * exhausted.
 *
 * @param rows          number of rows for the replacement feature/label arrays
 * @param inputColumns  number of columns for the replacement features
 * @param outputColumns number of columns for the replacement labels
 * @param dataSet       the data set whose features and labels are overwritten
 * @param it            the plain iterator
 * @param cachedIt      the caching iterator
 */
private void assertCachingDataSetIteratorHasAllTheData(int rows, int inputColumns, int outputColumns,
                DataSet dataSet, DataSetIterator it, CachingDataSetIterator cachedIt) {
    cachedIt.reset();
    it.reset();

    // Mutate the data set after the resets, before any batch is read.
    dataSet.setFeatures(Nd4j.zeros(rows, inputColumns));
    dataSet.setLabels(Nd4j.ones(rows, outputColumns));

    while (it.hasNext()) {
        assertTrue(cachedIt.hasNext());

        // Batch served by the caching iterator.
        final DataSet fromCache = cachedIt.next();
        final Number cachedFeatureSum = fromCache.getFeatureMatrix().sumNumber();
        assertEquals(1000.0, cachedFeatureSum);
        final Number cachedLabelSum = fromCache.getLabels().sumNumber();
        assertEquals(0.0, cachedLabelSum);

        // Batch served by the plain iterator.
        final DataSet fromSource = it.next();
        final Number sourceFeatureSum = fromSource.getFeatureMatrix().sumNumber();
        assertEquals(0.0, sourceFeatureSum);
        final Number sourceLabelSum = fromSource.getLabels().sumNumber();
        assertEquals(20.0, sourceLabelSum);
    }

    // Both iterators must run out at the same time.
    assertFalse(cachedIt.hasNext());
    assertFalse(it.hasNext());
}
 
开发者ID:deeplearning4j,项目名称:nd4j,代码行数:24,代码来源:CachingDataSetIteratorTest.java

示例3: fit

import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; //导入方法依赖的package包/类
/**
 * Fits this normalizer by feeding every {@link DataSet} produced by the iterator into a
 * feature-statistics builder and, when {@code fitLabels} is set, into a label-statistics
 * builder. The iterator is reset before the pass and once more afterwards.
 *
 * @param iterator source of the data sets to gather statistics from
 */
@Override
public void fit(DataSetIterator iterator) {
    final S.Builder featureNormBuilder = newBuilder();
    final S.Builder labelNormBuilder = newBuilder();

    // Single pass over the data, starting from the beginning of the iterator.
    for (iterator.reset(); iterator.hasNext();) {
        final DataSet batch = iterator.next();
        featureNormBuilder.addFeatures(batch);
        if (fitLabels) {
            labelNormBuilder.addLabels(batch);
        }
    }

    featureStats = (S) featureNormBuilder.build();
    if (fitLabels) {
        labelStats = (S) labelNormBuilder.build();
    }

    // Rewind again after the statistics pass.
    iterator.reset();
}
 
开发者ID:deeplearning4j,项目名称:nd4j,代码行数:25,代码来源:AbstractDataSetNormalizer.java

示例4: testCGEvaluation

import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; //导入方法依赖的package包/类
/**
 * Fits a ComputationGraph and a MultiLayerNetwork (both initialised right after seeding the
 * RNG with 12345) on the same Iris iterator and asserts that the two evaluations report the
 * same accuracy.
 */
@Test
public void testCGEvaluation() {

    // Same seed immediately before each model is configured and initialised.
    Nd4j.getRandom().setSeed(12345);
    final ComputationGraph graph = new ComputationGraph(getIrisGraphConfiguration());
    graph.init();

    Nd4j.getRandom().setSeed(12345);
    final MultiLayerNetwork net = new MultiLayerNetwork(getIrisMLNConfiguration());
    net.init();

    final DataSetIterator iris = new IrisDataSetIterator(75, 150);

    // Train both models on the same data, rewinding the shared iterator in between.
    net.fit(iris);
    iris.reset();
    graph.fit(iris);

    // Evaluate both models, again rewinding before each use.
    iris.reset();
    final Evaluation fromNetwork = net.evaluate(iris);
    iris.reset();
    final Evaluation fromGraph = graph.evaluate(iris);

    final double expectedAccuracy = fromNetwork.accuracy();
    final double actualAccuracy = fromGraph.accuracy();

    // The literal 0e-4 is 0.0, so this demands exactly equal accuracies.
    assertEquals(expectedAccuracy, actualAccuracy, 0e-4);
}
 
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:27,代码来源:TestComputationGraphNetwork.java

示例5: testOptimizersBasicMLPBackprop

import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; //导入方法依赖的package包/类
/**
 * Smoke test: builds, initialises and fits one network per optimization algorithm on the Iris
 * iterator. There are no assertions; the test passes as long as nothing throws.
 */
@Test
public void testOptimizersBasicMLPBackprop() {
    final OptimizationAlgorithm[] algorithms = new OptimizationAlgorithm[] {
                    OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT,
                    OptimizationAlgorithm.LINE_GRADIENT_DESCENT,
                    OptimizationAlgorithm.CONJUGATE_GRADIENT,
                    OptimizationAlgorithm.LBFGS
                    // HESSIAN_FREE is deliberately left out: known to not work
    };

    final DataSetIterator iter = new IrisDataSetIterator(5, 50);

    for (int idx = 0; idx < algorithms.length; idx++) {
        final MultiLayerNetwork network = new MultiLayerNetwork(getMLPConfigIris(algorithms[idx]));
        network.init();

        // The iterator is shared across algorithms, so rewind it before every fit.
        iter.reset();
        network.fit(iter);
    }
}
 
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:21,代码来源:TestOptimizers.java

示例6: testNormalizerPrefetchReset

import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; //导入方法依赖的package包/类
/**
 * Regression test for the NPE reported in
 * https://github.com/deeplearning4j/deeplearning4j/issues/4214.
 * It has no assertions: the call sequence at the end simply must not throw.
 */
@Test
public void testNormalizerPrefetchReset() throws Exception {
    final int batchSize = 3;

    final RecordReader reader = new CSVRecordReader();
    reader.initialize(new FileSplit(new ClassPathResource("iris.txt").getTempFileFromArchive()));

    final DataSetIterator iter = new RecordReaderDataSetIterator(reader, batchSize, 4, 4, true);

    // Fit a min-max scaler on the iterator, then install it as the iterator's pre-processor.
    final DataNormalization minMax = new NormalizerMinMaxScaler(0, 1);
    minMax.fit(iter);
    iter.setPreProcessor(minMax);

    // The order below is the scenario under test: metadata queries and hasNext() first
    // (marked as the "prefetch" step in the original), then reset(), then next().
    iter.inputColumns();
    iter.totalOutcomes();
    iter.hasNext();
    iter.reset();
    iter.next();
}
 
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:21,代码来源:RecordReaderDataSetiteratorTest.java

示例7: testInitializeNoNextIter

import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; //导入方法依赖的package包/类
/**
 * Wraps an already exhausted Iris iterator in an AsyncDataSetIterator and checks that the
 * async wrapper reports no data until it is reset, after which it yields 150 / 10 batches.
 */
@Test
public void testInitializeNoNextIter() {

    // Exhaust the underlying iterator before it is wrapped.
    final DataSetIterator iter = new IrisDataSetIterator(10, 150);
    while (iter.hasNext()) {
        iter.next();
    }

    final DataSetIterator async = new AsyncDataSetIterator(iter, 2);

    assertFalse(iter.hasNext());
    assertFalse(async.hasNext());

    // Calling next() on the exhausted iterator must throw; any Exception is accepted here.
    try {
        iter.next();
        fail("Should have thrown NoSuchElementException");
    } catch (Exception expected) {
        //OK
    }

    // After a reset the async iterator must deliver every batch.
    async.reset();
    int batches = 0;
    for (; async.hasNext(); batches++) {
        async.next();
    }
    assertEquals(150 / 10, batches);
}
 
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:27,代码来源:TestAsyncIterator.java

示例8: train

import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; //导入方法依赖的package包/类
/**
 * Fits {@code mNetwork} on the native data set of {@code dataSource}, served in batches of
 * {@code BATCH_SIZE}, for {@code N_EPOCHS} passes.
 *
 * @param dataSource wrapper whose native data set is cast to {@link DataSet}
 */
@Override
public void train(FederatedDataSet dataSource) {
    final DataSet trainingData = (DataSet) dataSource.getNativeDataSet();
    final List<DataSet> examples = trainingData.asList();
    final DataSetIterator iterator = new ListDataSetIterator(examples, BATCH_SIZE);

    // One full pass over the data per epoch; rewind the iterator before each pass.
    int epoch = 0;
    while (epoch < N_EPOCHS) {
        iterator.reset();
        mNetwork.fit(iterator);
        epoch++;
    }
}
 
开发者ID:mccorby,项目名称:FederatedAndroidTrainer,代码行数:13,代码来源:LinearModel.java

示例9: main

import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; //导入方法依赖的package包/类
/**
 * Builds a network with one dense layer and one output layer, fits it for {@code nEpochs}
 * passes over the generated training data, and prints the network's output for one sample
 * pair of inputs.
 *
 * @param args not used
 */
public static void main(String[] args){

    //Generate the training data
    final DataSetIterator iterator = getTrainingData(batchSize,rng);

    //Layer sizes
    final int numInput = 2;
    final int nHidden = 10;
    final int numOutputs = 1;

    //Create the network
    final MultiLayerNetwork net = new MultiLayerNetwork(new NeuralNetConfiguration.Builder()
            .seed(seed)
            .iterations(iterations)
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .learningRate(learningRate)
            .weightInit(WeightInit.XAVIER)
            .updater(Updater.NESTEROVS).momentum(0.9)
            .list()
            .layer(0, new DenseLayer.Builder().nIn(numInput).nOut(nHidden)
                    .activation(Activation.TANH)
                    .build())
            .layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.MSE)
                    .activation(Activation.IDENTITY)
                    .nIn(nHidden).nOut(numOutputs).build())
            .pretrain(false).backprop(true).build()
    );
    net.init();
    net.setListeners(new ScoreIterationListener(1));

    //Train the network: one full pass over the data set per epoch
    int epoch = 0;
    while (epoch < nEpochs) {
        iterator.reset();
        net.fit(iterator);
        epoch++;
    }

    // Test the addition of 2 numbers (Try different numbers here)
    final INDArray input = Nd4j.create(new double[] { 0.111111, 0.3333333333333 }, new int[] { 1, 2 });
    final INDArray prediction = net.output(input, false);
    System.out.println(prediction);
}
 
开发者ID:IsaacChanghau,项目名称:NeuralNetworksLite,代码行数:41,代码来源:RegressionSum.java

示例10: main

import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; //导入方法依赖的package包/类
/**
 * Fits a network to a math function over {@code nEpochs} passes, captures the network's
 * output for {@code x} every {@code plotFrequency} epochs, and plots the captured predictions
 * against the target function values.
 *
 * @param args not used
 */
public static void main(final String[] args){

    //Switch these two options to do different functions with different networks
    final MathFunction fn = new SinXDivXMathFunction();
    final MultiLayerConfiguration conf = getDeepDenseLayerNetworkConfiguration();

    //Generate the training data
    final INDArray x = Nd4j.linspace(-10,10,nSamples).reshape(nSamples, 1);
    final DataSetIterator iterator = getTrainingData(x,fn,batchSize,rng);

    //Create the network
    final MultiLayerNetwork net = new MultiLayerNetwork(conf);
    net.init();
    net.setListeners(new ScoreIterationListener(1));

    //Train the network; one snapshot slot per plotFrequency epochs
    final INDArray[] networkPredictions = new INDArray[nEpochs/ plotFrequency];
    int epoch = 0;
    while (epoch < nEpochs) {
        iterator.reset();
        net.fit(iterator);

        //Capture a prediction whenever the 1-based epoch number is a multiple of plotFrequency
        final int epochNumber = epoch + 1;
        if (epochNumber % plotFrequency == 0) {
            networkPredictions[epoch / plotFrequency] = net.output(x, false);
        }
        epoch++;
    }

    //Plot the target data and the network predictions
    plot(fn,x,fn.getFunctionValues(x),networkPredictions);
}
 
开发者ID:IsaacChanghau,项目名称:NeuralNetworksLite,代码行数:28,代码来源:RegressionMathFunctions.java

示例11: getFirstBatchFeatures

import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; //导入方法依赖的package包/类
/**
 * Peeks at the features of the first batch produced by a fresh, uncached iterator over the
 * given instances. The iterator is reset after the batch has been read.
 *
 * @param data instances to build the iterator from
 * @return features of the first batch
 * @throws RuntimeException if the iterator yields no batch at all
 * @throws Exception declared by the signature; presumably raised while creating the iterator
 */
protected INDArray getFirstBatchFeatures(Instances data) throws Exception {
  final DataSetIterator it = getDataSetIterator(data, CacheMode.NONE);
  if (it.hasNext()) {
    final INDArray firstFeatures = it.next().getFeatures();
    it.reset();
    return firstFeatures;
  }
  throw new RuntimeException("Iterator was unexpectedly empty.");
}
 
开发者ID:Waikato,项目名称:wekaDeeplearning4j,代码行数:16,代码来源:Dl4jMlpClassifier.java

示例12: main

import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; //导入方法依赖的package包/类
/**
 * Continues training a previously saved model on sentiment data and, after every epoch,
 * evaluates it on the test split and writes the model to disk.
 *
 * <ul>
 *   <li>{@code args[0]} input: word2vec file name</li>
 *   <li>{@code args[1]} input: trained model file name</li>
 *   <li>{@code args[2]} input: parent folder of the train/test data</li>
 *   <li>{@code args[3]} output: trained model file name</li>
 * </ul>
 *
 * If fewer than four arguments are supplied (or any of the first four is {@code null}) a usage
 * line is printed to stderr and the JVM exits with status 1.
 *
 * @param args command-line arguments as listed above
 * @throws Exception if loading the vectors or model, training, evaluation, or saving fails
 */
public static void main (final String[] args) throws Exception {
  // Check the length first: indexing args[0..3] on a shorter array would throw
  // ArrayIndexOutOfBoundsException instead of exiting cleanly.
  if (args == null || args.length < 4
      || args[0] == null || args[1] == null || args[2] == null || args[3] == null) {
    System.err.println("Usage: <word2vec file> <input model> <train/test parent folder> <output model>");
    System.exit(1);
  }

  WordVectors wvec = WordVectorSerializer.loadTxtVectors(new File(args[0]));
  MultiLayerNetwork model = ModelSerializer.restoreMultiLayerNetwork(args[1],true);
  int batchSize   = 16;
  int testBatch   = 64;
  int nEpochs     = 1;

  System.out.println("Starting online training");
  DataSetIterator train = new AsyncDataSetIterator(
      new SentimentRecurrentIterator(args[2],wvec,batchSize,300,true),2);
  DataSetIterator test = new AsyncDataSetIterator(
      new SentimentRecurrentIterator(args[2],wvec,testBatch,300,false),2);
  for( int i=0; i<nEpochs; i++ ){
    model.fit(train);
    train.reset();

    System.out.println("Epoch " + i + " complete. Starting evaluation:");
    Evaluation evaluation = new Evaluation();
    while(test.hasNext()) {
      DataSet t = test.next();
      INDArray features = t.getFeatures();
      INDArray labels = t.getLabels();
      INDArray inMask = t.getFeaturesMaskArray();
      INDArray outMask = t.getLabelsMaskArray();
      INDArray predicted = model.output(features,false,inMask,outMask);
      evaluation.evalTimeSeries(labels,predicted,outMask);
    }
    test.reset();
    System.out.println(evaluation.stats());

    System.out.println("Save model");
    // Close the stream even when writing fails, so no file handle leaks per epoch.
    try (FileOutputStream modelOut = new FileOutputStream(args[3])) {
      ModelSerializer.writeModel(model, modelOut, true);
    }
  }
}
 
开发者ID:keigohtr,项目名称:sentiment-rnn,代码行数:47,代码来源:SentimentRecurrentTrainOnlineCmd.java

示例13: testRocMultiToHtml

import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; //导入方法依赖的package包/类
/**
 * Trains a small network on the standardized Iris data, then for step counts 20 and 0 builds
 * a ROCMultiClass from the network's output and prints the HTML chart. There are no
 * assertions; the test passes as long as nothing throws.
 */
@Test
public void testRocMultiToHtml() throws Exception {
    final DataSetIterator iter = new IrisDataSetIterator(150, 150);

    final MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                    .weightInit(WeightInit.XAVIER)
                    .list()
                    .layer(0, new DenseLayer.Builder().nIn(4).nOut(4).activation(Activation.TANH).build())
                    .layer(1, new OutputLayer.Builder().nIn(4).nOut(3).activation(Activation.SOFTMAX)
                                    .lossFunction(LossFunctions.LossFunction.MCXENT).build())
                    .build();
    final MultiLayerNetwork net = new MultiLayerNetwork(conf);
    net.init();

    // Fit the standardizer on the single batch and normalize that batch in place.
    final NormalizerStandardize ns = new NormalizerStandardize();
    final DataSet ds = iter.next();
    ns.fit(ds);
    ns.transform(ds);

    int round = 0;
    while (round < 30) {
        net.fit(ds);
        round++;
    }

    final int[] stepCounts = {20, 0};
    for (int idx = 0; idx < stepCounts.length; idx++) {
        final ROCMultiClass roc = new ROCMultiClass(stepCounts[idx]);
        // The iterator is rewound here, but the evaluation below reads from ds, not from iter.
        iter.reset();

        final INDArray features = ds.getFeatures();
        final INDArray labels = ds.getLabels();
        final INDArray predictions = net.output(features);
        roc.eval(labels, predictions);

        final String html = EvaluationTools.rocChartToHtml(roc, Arrays.asList("setosa", "versicolor", "virginica"));
        System.out.println(html);
    }
}
 
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:36,代码来源:EvaluationToolsTests.java

示例14: testCNNMLNPretrain

import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; //导入方法依赖的package包/类
/**
 * Fits two networks built by the same {@code getCNNMLNConfig(false, true)} call on the same
 * MNIST iterator and asserts that their F1 scores on the first batch agree to within 1e-4.
 */
@Test
public void testCNNMLNPretrain() throws Exception {
    // Note CNN does not do pretrain
    final int batchSize = 10;
    final int numSamples = 10;
    final DataSetIterator mnistIter = new MnistDataSetIterator(batchSize, numSamples, true);

    // First model.
    final MultiLayerNetwork firstModel = getCNNMLNConfig(false, true);
    firstModel.fit(mnistIter);
    mnistIter.reset();

    // Second model, fitted on the same rewound iterator.
    final MultiLayerNetwork secondModel = getCNNMLNConfig(false, true);
    secondModel.fit(mnistIter);
    mnistIter.reset();

    final DataSet test = mnistIter.next();

    final Evaluation firstEval = new Evaluation();
    final INDArray firstOutput = firstModel.output(test.getFeatureMatrix());
    firstEval.eval(test.getLabels(), firstOutput);
    final double firstF1 = firstEval.f1();

    final Evaluation secondEval = new Evaluation();
    final INDArray secondOutput = secondModel.output(test.getFeatureMatrix());
    secondEval.eval(test.getLabels(), secondOutput);
    final double secondF1 = secondEval.f1();

    assertEquals(firstF1, secondF1, 1e-4);
}
 
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:33,代码来源:ConvolutionLayerTest.java

示例15: testOutput

import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; //导入方法依赖的package包/类
/**
 * Fits a small network on two MNIST examples and asserts that calling {@code output} with the
 * iterator yields the same result as calling it with the feature matrix of
 * {@code fullData.next(2)}.
 */
@Test
public void testOutput() throws Exception {
    Nd4j.getRandom().setSeed(12345);
    final MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                    .weightInit(WeightInit.XAVIER)
                    .seed(12345L)
                    .list()
                    .layer(0, new DenseLayer.Builder().nIn(784).nOut(50).activation(Activation.RELU).build())
                    .layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                                    .activation(Activation.SOFTMAX).nIn(50).nOut(10).build())
                    .pretrain(false)
                    .backprop(true)
                    .setInputType(InputType.convolutional(28, 28, 1))
                    .build();

    final MultiLayerNetwork net = new MultiLayerNetwork(conf);
    net.init();

    final DataSetIterator fullData = new MnistDataSetIterator(1, 2);
    net.fit(fullData);

    // Expected: output computed from the features of a single next(2) call.
    fullData.reset();
    final DataSet bothExamples = fullData.next(2);
    final INDArray expectedOut = net.output(bothExamples.getFeatureMatrix(), false);

    // Actual: output computed by handing the rewound iterator to the network.
    fullData.reset();
    final INDArray actualOut = net.output(fullData);

    assertEquals(expectedOut, actualOut);
}
 
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:28,代码来源:MultiLayerTest.java


注:本文中的org.nd4j.linalg.dataset.api.iterator.DataSetIterator.reset方法示例由纯净天空整理自GitHub/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。