

Java ListDataSetIterator Class Code Examples

This article collects typical usage examples of the Java class org.deeplearning4j.datasets.iterator.impl.ListDataSetIterator. If you are wondering what ListDataSetIterator is, how to use it, or what real-world usage looks like, the curated class code examples below should help.


The ListDataSetIterator class belongs to the org.deeplearning4j.datasets.iterator.impl package. Ten code examples of the class are shown below, ordered by popularity.
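
Before the project-specific examples, here is a minimal, self-contained sketch (not taken from any of the projects below) of the pattern that most of them share: build a List<DataSet> in memory, shuffle it, and wrap it in a ListDataSetIterator to obtain minibatches. The random data, the class name ListDataSetIteratorSketch, and the import paths (which assume a recent DL4J/ND4J release) are illustrative assumptions.

import java.util.Collections;
import java.util.List;
import java.util.Random;

import org.deeplearning4j.datasets.iterator.impl.ListDataSetIterator;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;

public class ListDataSetIteratorSketch {
    public static void main(String[] args) {
        // Build a small in-memory dataset: 8 examples with 3 features and 1 label each (random, purely illustrative).
        INDArray features = Nd4j.rand(8, 3);
        INDArray labels = Nd4j.rand(8, 1);
        DataSet allData = new DataSet(features, labels);

        // Split into single-example DataSets, shuffle them, and wrap the list in a
        // ListDataSetIterator that serves minibatches of 2 examples per call to next().
        List<DataSet> list = allData.asList();
        Collections.shuffle(list, new Random(123));
        DataSetIterator iterator = new ListDataSetIterator<>(list, 2);

        // Consume the minibatches; reset() rewinds the iterator so it can be reused for another epoch.
        while (iterator.hasNext()) {
            DataSet batch = iterator.next();
            System.out.println("Batch with " + batch.numExamples() + " examples");
        }
        iterator.reset();
    }
}

Examples 1 and 6 below use exactly this pattern to build regression training data, and Examples 2 to 5 apply it to data unwrapped from a FederatedDataSet.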

Example 1: getTrainingData

import org.deeplearning4j.datasets.iterator.impl.ListDataSetIterator; // import the required package/class
private static DataSetIterator getTrainingData(int batchSize, Random rand) {
    // nSamples is a class-level constant giving the number of training examples
    double[] sum = new double[nSamples];
    double[] input1 = new double[nSamples];
    double[] input2 = new double[nSamples];
    for (int i = 0; i < nSamples; i++) {
        int MIN_RANGE = 0;
        int MAX_RANGE = 3;
        input1[i] = MIN_RANGE + (MAX_RANGE - MIN_RANGE) * rand.nextDouble();
        input2[i] = MIN_RANGE + (MAX_RANGE - MIN_RANGE) * rand.nextDouble();
        sum[i] = input1[i] + input2[i];
    }
    INDArray inputNDArray1 = Nd4j.create(input1, new int[]{nSamples, 1});
    INDArray inputNDArray2 = Nd4j.create(input2, new int[]{nSamples, 1});
    INDArray inputNDArray = Nd4j.hstack(inputNDArray1, inputNDArray2);
    INDArray outPut = Nd4j.create(sum, new int[]{nSamples, 1});
    DataSet dataSet = new DataSet(inputNDArray, outPut);
    List<DataSet> listDs = dataSet.asList();
    Collections.shuffle(listDs, rand); // shuffle with the supplied Random for reproducibility
    return new ListDataSetIterator(listDs, batchSize);
}
 
Developer: IsaacChanghau, Project: NeuralNetworksLite, Lines: 22, Source: RegressionSum.java

Example 2: train

import org.deeplearning4j.datasets.iterator.impl.ListDataSetIterator; // import the required package/class
@Override
public void train(FederatedDataSet dataSource) {
    DataSet trainingData = (DataSet) dataSource.getNativeDataSet();
    List<DataSet> listDs = trainingData.asList();
    DataSetIterator iterator = new ListDataSetIterator(listDs, BATCH_SIZE);

    // Train the network on the full data set for N_EPOCHS epochs
    for (int i = 0; i < N_EPOCHS; i++) {
        iterator.reset();
        mNetwork.fit(iterator);
    }
}
 
Developer: mccorby, Project: FederatedAndroidTrainer, Lines: 13, Source: LinearModel.java

Example 3: evaluate

import org.deeplearning4j.datasets.iterator.impl.ListDataSetIterator; // import the required package/class
@Override
public String evaluate(FederatedDataSet federatedDataSet) {
    DataSet testData = (DataSet) federatedDataSet.getNativeDataSet();
    List<DataSet> listDs = testData.asList();
    DataSetIterator iterator = new ListDataSetIterator(listDs, BATCH_SIZE);

    return mNetwork.evaluate(iterator).stats();
}
 
Developer: mccorby, Project: FederatedAndroidTrainer, Lines: 9, Source: LinearModel.java

Example 4: evaluate

import org.deeplearning4j.datasets.iterator.impl.ListDataSetIterator; // import the required package/class
@Override
public String evaluate(FederatedDataSet federatedDataSet) {
    DataSet testData = (DataSet) federatedDataSet.getNativeDataSet();
    List<DataSet> listDs = testData.asList();
    DataSetIterator iterator = new ListDataSetIterator(listDs, BATCH_SIZE);

    Evaluation eval = new Evaluation(OUTPUT_NUM); //create an evaluation object with 10 possible classes
    while (iterator.hasNext()) {
        DataSet next = iterator.next();
        INDArray output = model.output(next.getFeatureMatrix()); // get the network's prediction
        eval.eval(next.getLabels(), output); //check the prediction against the true class
    }

    return eval.stats();
}
 
Developer: mccorby, Project: FederatedAndroidTrainer, Lines: 16, Source: MNISTModel.java

Example 5: train

import org.deeplearning4j.datasets.iterator.impl.ListDataSetIterator; // import the required package/class
@Override
public void train(FederatedDataSet federatedDataSet) {
    DataSet trainingData = (DataSet) federatedDataSet.getNativeDataSet();
    List<DataSet> listDs = trainingData.asList();
    DataSetIterator mnistTrain = new ListDataSetIterator(listDs, BATCH_SIZE);
    for (int i = 0; i < N_EPOCHS; i++) {
        model.fit(mnistTrain);
    }
}
 
Developer: mccorby, Project: FederatedAndroidTrainer, Lines: 10, Source: MNISTModel.java

Example 6: getTrainingData

import org.deeplearning4j.datasets.iterator.impl.ListDataSetIterator; // import the required package/class
/** Create a DataSetIterator for training
 * @param x X values
 * @param function Function to evaluate
 * @param batchSize Batch size (number of examples for every call of DataSetIterator.next())
 * @param rng Random number generator (for repeatability)
 */
private static DataSetIterator getTrainingData(final INDArray x, final MathFunction function, final int batchSize, final Random rng) {
    final INDArray y = function.getFunctionValues(x);
    final DataSet allData = new DataSet(x,y);

    final List<DataSet> list = allData.asList();
    Collections.shuffle(list,rng);
    return new ListDataSetIterator(list,batchSize);
}
 
Developer: IsaacChanghau, Project: NeuralNetworksLite, Lines: 15, Source: RegressionMathFunctions.java

Example 7: getDS

import org.deeplearning4j.datasets.iterator.impl.ListDataSetIterator; // import the required package/class
protected static DataSetIterator getDS() {
    List<DataSet> list = new ArrayList<>(5);
    for (int i = 0; i < 5; i++) {
        INDArray f = Nd4j.create(1, 32 * 32 * 3);
        INDArray l = Nd4j.create(1, 10);
        l.putScalar(i, 1.0);
        list.add(new DataSet(f, l));
    }
    return new ListDataSetIterator(list, 5);
}
 
Developer: deeplearning4j, Project: deeplearning4j, Lines: 12, Source: TestCompGraphCNN.java

Example 8: main

import org.deeplearning4j.datasets.iterator.impl.ListDataSetIterator; // import the required package/class
public static void main(String[] args) throws Exception {
    SentenceIterator docIter = new CollectionSentenceIterator(new SentenceToPhraseMapper(new ClassPathResource("/train.tsv").getFile()).sentences());
    TokenizerFactory factory = new DefaultTokenizerFactory();
    Word2Vec  vec = new Word2Vec.Builder().iterate(docIter).tokenizerFactory(factory).batchSize(100000)
            .learningRate(2.5e-2).iterations(1)
            .layerSize(100).windowSize(5).build();
    vec.fit();

    NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder().nIn(vec.getLayerSize()).nOut(vec.getLayerSize())
            .hiddenUnit(RBM.HiddenUnit.RECTIFIED).visibleUnit(RBM.VisibleUnit.GAUSSIAN).momentum(0.5f)
            .iterations(10).learningRate(1e-6f).build();

    InMemoryLookupCache l = (InMemoryLookupCache) vec.getCache();

    DBN d = new DBN.Builder()
            .configure(conf).hiddenLayerSizes(new int[]{250,100,2})
            .build();
    DataSet dPretrain = new DataSet(l.getSyn0(),l.getSyn0());
    DataSetIterator dPretrainIter =  new ListDataSetIterator(dPretrain.asList(),1000);
    while (dPretrainIter.hasNext()) {
        d.pretrain(dPretrainIter.next().getFeatureMatrix(), 1, 1e-6f, 10);
    }

    // d.pretrain(l.getSyn0(),1,1e-3f,1000);
    d.getOutputLayer().conf().setLossFunction(LossFunctions.LossFunction.RMSE_XENT);

    SemanticHashing s = new SemanticHashing.Builder().withEncoder(d)
            .build();

    d = null;

    dPretrainIter.reset();
    while (dPretrainIter.hasNext()) {
        s.fit(dPretrainIter.next());
    }

    Tsne t = new Tsne.Builder()
            .setMaxIter(100).stopLyingIteration(20).build();

    INDArray output = s.reconstruct(l.getSyn0(),4);
    l.getSyn0().data().flush();
    l.getSyn1().data().flush();
    s = null;
    System.out.println(Arrays.toString(output.shape()));
    t.plot(output,2,new ArrayList<>(vec.getCache().words()));
    vec.getCache().plotVocab(t);

}
 
Developer: ihuerga, Project: deeplearning4j-nlp-examples, Lines: 55, Source: VisualizationSemanticHashing.java

Example 9: testIris

import org.deeplearning4j.datasets.iterator.impl.ListDataSetIterator; // import the required package/class
@Test
public void testIris() {

    // Network config
    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .optimizationAlgo(OptimizationAlgorithm.LINE_GRADIENT_DESCENT).seed(42)
            .updater(new Sgd(1e-6)).list()
            .layer(0, new DenseLayer.Builder().nIn(4).nOut(2).activation(Activation.TANH)
                    .weightInit(WeightInit.XAVIER).build())
            .layer(1, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                    .nIn(2).nOut(3).weightInit(WeightInit.XAVIER)
                    .activation(Activation.SOFTMAX).build())
            .build();

    // Instantiate model
    MultiLayerNetwork model = new MultiLayerNetwork(conf);
    model.init();
    model.setListeners(Arrays.asList((IterationListener) new ScoreIterationListener(1)));

    // Train-test split
    DataSetIterator iter = new IrisDataSetIterator(150, 150);
    DataSet next = iter.next();
    next.shuffle();
    SplitTestAndTrain trainTest = next.splitTestAndTrain(5, new Random(42));

    // Train
    DataSet train = trainTest.getTrain();
    train.normalizeZeroMeanZeroUnitVariance();

    // Test
    DataSet test = trainTest.getTest();
    test.normalizeZeroMeanZeroUnitVariance();
    INDArray testFeature = test.getFeatureMatrix();
    INDArray testLabel = test.getLabels();

    // Fitting model
    model.fit(train);
    // Get predictions from test feature
    INDArray testPredictedLabel = model.output(testFeature);

    // Eval with class number
    Evaluation eval = new Evaluation(3); // specify class num here
    eval.eval(testLabel, testPredictedLabel);
    double eval1F1 = eval.f1();
    double eval1Acc = eval.accuracy();

    // Eval without class number
    Evaluation eval2 = new Evaluation(); // no class num
    eval2.eval(testLabel, testPredictedLabel);
    double eval2F1 = eval2.f1();
    double eval2Acc = eval2.accuracy();

    //Assert the two implementations give same f1 and accuracy (since one batch)
    assertTrue(eval1F1 == eval2F1 && eval1Acc == eval2Acc);

    Evaluation evalViaMethod = model.evaluate(new ListDataSetIterator<>(Collections.singletonList(test)));
    checkEvaluationEquality(eval, evalViaMethod);

    System.out.println(eval.getConfusionMatrix().toString());
    System.out.println(eval.getConfusionMatrix().toCSV());
    System.out.println(eval.getConfusionMatrix().toHTML());

    System.out.println(eval.confusionToString());
}
 
Developer: deeplearning4j, Project: deeplearning4j, Lines: 67, Source: EvalTest.java

Example 10: testEvaluationAndRoc

import org.deeplearning4j.datasets.iterator.impl.ListDataSetIterator; // import the required package/class
@Test
public void testEvaluationAndRoc() {
    DataSetIterator iter = new IrisDataSetIterator(5, 150);

    //Make a 2-class version of iris:
    List<DataSet> l = new ArrayList<>();
    iter.reset();
    while (iter.hasNext()) {
        DataSet ds = iter.next();
        INDArray newL = Nd4j.create(ds.getLabels().size(0), 2);
        newL.putColumn(0, ds.getLabels().getColumn(0));
        newL.putColumn(1, ds.getLabels().getColumn(1));
        newL.getColumn(1).addi(ds.getLabels().getColumn(2));
        ds.setLabels(newL);
        l.add(ds);
    }

    iter = new ListDataSetIterator<>(l);

    ComputationGraph cg = getBasicNetIris2Class();

    Evaluation e = cg.evaluate(iter);
    ROC roc = cg.evaluateROC(iter, 32);

    SparkComputationGraph scg = new SparkComputationGraph(sc, cg, null);

    JavaRDD<DataSet> rdd = sc.parallelize(l);
    rdd = rdd.repartition(20);

    Evaluation e2 = scg.evaluate(rdd);
    ROC roc2 = scg.evaluateROC(rdd);

    assertEquals(e2.accuracy(), e.accuracy(), 1e-3);
    assertEquals(e2.f1(), e.f1(), 1e-3);
    assertEquals(e2.getNumRowCounter(), e.getNumRowCounter(), 1e-3);
    assertEquals(e2.falseNegatives(), e.falseNegatives());
    assertEquals(e2.falsePositives(), e.falsePositives());
    assertEquals(e2.trueNegatives(), e.trueNegatives());
    assertEquals(e2.truePositives(), e.truePositives());
    assertEquals(e2.precision(), e.precision(), 1e-3);
    assertEquals(e2.recall(), e.recall(), 1e-3);
    assertEquals(e2.getConfusionMatrix(), e.getConfusionMatrix());

    assertEquals(roc.calculateAUC(), roc2.calculateAUC(), 1e-5);
    assertEquals(roc.calculateAUCPR(), roc2.calculateAUCPR(), 1e-5);
}
 
Developer: deeplearning4j, Project: deeplearning4j, Lines: 51, Source: TestSparkComputationGraph.java


Note: The org.deeplearning4j.datasets.iterator.impl.ListDataSetIterator class examples in this article were compiled by 纯净天空 from open-source code and documentation platforms such as GitHub and MSDocs. The code snippets were selected from open-source projects contributed by their respective authors, and copyright of the source code remains with the original authors. Refer to each project's license before distributing or using the code; do not repost without permission.