This article collects typical usage examples of the Java method org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator.hasNext. If you are wondering what MnistDataSetIterator.hasNext does, how to call it, or what it looks like in real code, the curated examples below may help. You can also read further about the enclosing class, org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator.
The following shows 3 code examples of MnistDataSetIterator.hasNext, sorted by popularity by default. You can upvote the examples you like or find useful; your feedback helps the system recommend better Java code examples.
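As a quick orientation before the examples, the sketch below shows the basic hasNext()/next() loop that all three examples build on. It is a minimal illustration and not taken from any of the projects below; the class name MnistIterationSketch and the batch size of 100 are arbitrary choices, and note that older DL4J releases expose the feature matrix as getFeatureMatrix() rather than getFeatures().

import java.io.IOException;

import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;

public class MnistIterationSketch {
    public static void main(String[] args) throws IOException {
        // 100 images per mini-batch, 60000 images in total (the full MNIST training set)
        MnistDataSetIterator iter = new MnistDataSetIterator(100, 60000);
        int batches = 0;
        while (iter.hasNext()) {             // true while unconsumed mini-batches remain
            DataSet batch = iter.next();     // fetch the next mini-batch of images and labels
            INDArray features = batch.getFeatures(); // pixel values, one 784-element row per image
            batches++;
        }
        System.out.println("Consumed " + batches + " mini-batches");
    }
}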
Example 1: main
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; // import the package/class this method depends on
/**
* @param args
*/
public static void main(String[] args) throws Exception {
    MnistDataSetIterator iter = new MnistDataSetIterator(60, 60000);
    @SuppressWarnings("unchecked")
    ObjectInputStream ois = new ObjectInputStream(new FileInputStream(args[0]));
    BasePretrainNetwork network = (BasePretrainNetwork) ois.readObject();
    DataSet test = null;
    while (iter.hasNext()) {
        test = iter.next(); // advance the iterator; without this call, test stays null
        INDArray reconstructed = network.transform(test.getFeatureMatrix());
        for (int i = 0; i < test.numExamples(); i++) {
            INDArray draw1 = test.get(i).getFeatureMatrix().mul(255);
            INDArray reconstructed2 = reconstructed.getRow(i);
            INDArray draw2 = Sampling.binomial(reconstructed2, 1, new MersenneTwister(123)).mul(255);
            DrawReconstruction d = new DrawReconstruction(draw1);
            d.title = "REAL";
            d.draw();
            DrawReconstruction d2 = new DrawReconstruction(draw2, 100, 100);
            d2.title = "TEST";
            d2.draw();
            Thread.sleep(10000);
            d.frame.dispose();
            d2.frame.dispose();
        }
    }
}
Example 2: main
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; // import the package/class this method depends on
/**
* @param args
*/
public static void main(String[] args) throws Exception {
    MnistDataSetIterator iter = new MnistDataSetIterator(60, 60000);
    @SuppressWarnings("unchecked")
    ObjectInputStream ois = new ObjectInputStream(new FileInputStream(args[0]));
    BasePretrainNetwork network = (BasePretrainNetwork) ois.readObject();
    try {
        ois.close();
    } catch (IOException e) {
    }
    DataSet test = null;
    while (iter.hasNext()) {
        test = iter.next();
        INDArray reconstructed = network.activate(test.getFeatureMatrix());
        for (int i = 0; i < test.numExamples(); i++) {
            INDArray draw1 = test.get(i).getFeatureMatrix().mul(255);
            INDArray reconstructed2 = reconstructed.getRow(i);
            INDArray draw2 = Nd4j.getDistributions().createBinomial(1, reconstructed2)
                    .sample(reconstructed2.shape()).mul(255);
            DrawReconstruction d = new DrawReconstruction(draw1);
            d.title = "REAL";
            d.draw();
            DrawReconstruction d2 = new DrawReconstruction(draw2, 100, 100);
            d2.title = "TEST";
            d2.draw();
            Thread.sleep(10000);
            d.frame.dispose();
            d2.frame.dispose();
        }
    }
}
Example 3: DeepAutoEncoderExample
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; // import the package/class this method depends on
public DeepAutoEncoderExample() {
    try {
        int seed = 123;
        int numberOfIterations = 1;
        iterator = new MnistDataSetIterator(1000, MnistDataFetcher.NUM_EXAMPLES, true);
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .seed(seed)
                .iterations(numberOfIterations)
                .optimizationAlgo(OptimizationAlgorithm.LINE_GRADIENT_DESCENT)
                .list()
                .layer(0, new RBM.Builder().nIn(numberOfRows * numberOfColumns)
                        .nOut(1000)
                        .lossFunction(LossFunctions.LossFunction.RMSE_XENT).build())
                .layer(1, new RBM.Builder().nIn(1000).nOut(500)
                        .lossFunction(LossFunctions.LossFunction.RMSE_XENT).build())
                .layer(2, new RBM.Builder().nIn(500).nOut(250)
                        .lossFunction(LossFunctions.LossFunction.RMSE_XENT).build())
                .layer(3, new RBM.Builder().nIn(250).nOut(100)
                        .lossFunction(LossFunctions.LossFunction.RMSE_XENT).build())
                .layer(4, new RBM.Builder().nIn(100).nOut(30)
                        .lossFunction(LossFunctions.LossFunction.RMSE_XENT).build()) // encoding stops
                .layer(5, new RBM.Builder().nIn(30).nOut(100)
                        .lossFunction(LossFunctions.LossFunction.RMSE_XENT).build()) // decoding starts
                .layer(6, new RBM.Builder().nIn(100).nOut(250)
                        .lossFunction(LossFunctions.LossFunction.RMSE_XENT).build())
                .layer(7, new RBM.Builder().nIn(250).nOut(500)
                        .lossFunction(LossFunctions.LossFunction.RMSE_XENT).build())
                .layer(8, new RBM.Builder().nIn(500).nOut(1000)
                        .lossFunction(LossFunctions.LossFunction.RMSE_XENT).build())
                .layer(9, new OutputLayer.Builder(
                        LossFunctions.LossFunction.RMSE_XENT).nIn(1000)
                        .nOut(numberOfRows * numberOfColumns).build())
                .pretrain(true).backprop(true)
                .build();
        model = new MultiLayerNetwork(conf);
        model.init();
        model.setListeners(Collections.singletonList(
                (IterationListener) new ScoreIterationListener()));
        while (iterator.hasNext()) {
            DataSet dataSet = iterator.next();
            // autoencoder training: the features serve as both input and target
            model.fit(new DataSet(dataSet.getFeatureMatrix(),
                    dataSet.getFeatureMatrix()));
        }
        modelFile = new File("savedModel");
        ModelSerializer.writeModel(model, modelFile, true);
    } catch (IOException ex) {
        ex.printStackTrace();
    }
}
Developer ID: PacktPublishing, Project: Machine-Learning-End-to-Endguide-for-Java-developers, Lines of code: 55, Source file: DeepAutoEncoderExample.java
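As a follow-up to Example 3, the sketch below shows one way the model written to "savedModel" could be loaded back and used to reconstruct a batch of digits. This is a hedged illustration rather than part of the original example: the class name LoadAutoEncoderSketch is made up, and it assumes the standard ModelSerializer.restoreMultiLayerNetwork and MultiLayerNetwork.output methods from mainstream DL4J releases.

import java.io.File;
import java.io.IOException;
import java.util.Arrays;

import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;

public class LoadAutoEncoderSketch {
    public static void main(String[] args) throws IOException {
        // Restore the network that the example above wrote to "savedModel"
        MultiLayerNetwork model = ModelSerializer.restoreMultiLayerNetwork(new File("savedModel"));

        // Pull one binarized mini-batch of MNIST digits and feed it through the autoencoder
        MnistDataSetIterator iter = new MnistDataSetIterator(100, 100, true);
        if (iter.hasNext()) {
            DataSet batch = iter.next();
            INDArray original = batch.getFeatures();          // [100, 784] input pixels
            INDArray reconstructed = model.output(original);  // [100, 784] network reconstruction
            System.out.println("Reconstruction shape: " + Arrays.toString(reconstructed.shape()));
        }
    }
}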