This article collects typical usage examples of the Java method org.nd4j.linalg.dataset.api.iterator.DataSetIterator.reset. If you are wondering what DataSetIterator.reset does, how to call it, or want to see it used in context, the curated examples below should help. You can also explore the containing class, org.nd4j.linalg.dataset.api.iterator.DataSetIterator, for further usage examples.
The following 15 code examples of DataSetIterator.reset are shown, sorted by popularity by default.
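Before diving into the examples, here is the core contract of reset() in one minimal sketch: it rewinds the iterator to the first batch so the same data can be traversed again, which is why it appears before re-fitting a network or starting a fresh evaluation pass. The sketch uses the IrisDataSetIterator that appears throughout the examples below; the class name ResetBasics is just a placeholder:

import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;

public class ResetBasics {
    public static void main(String[] args) {
        // Batch size 10 over the 150-example Iris dataset
        DataSetIterator iter = new IrisDataSetIterator(10, 150);

        // First pass: consume every batch
        while (iter.hasNext()) {
            iter.next();
        }
        System.out.println(iter.hasNext()); // false: the iterator is exhausted

        // reset() rewinds to the first batch so the data can be traversed again
        iter.reset();
        System.out.println(iter.hasNext()); // true
    }
}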
Example 1: assertPreProcessingGetsCached
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; //import the package/class the method depends on
private void assertPreProcessingGetsCached(int expectedNumberOfDataSets, DataSetIterator it,
        CachingDataSetIterator cachedIt, PreProcessor preProcessor) {
    assertSame(preProcessor, cachedIt.getPreProcessor());
    assertSame(preProcessor, it.getPreProcessor());

    cachedIt.reset();
    it.reset();
    while (cachedIt.hasNext()) {
        cachedIt.next();
    }
    assertEquals(expectedNumberOfDataSets, preProcessor.getCallCount());

    // Second pass over the cached iterator: the call count must not grow,
    // because pre-processed DataSets are now served from the cache
    cachedIt.reset();
    it.reset();
    while (cachedIt.hasNext()) {
        cachedIt.next();
    }
    assertEquals(expectedNumberOfDataSets, preProcessor.getCallCount());
}
Example 2: assertCachingDataSetIteratorHasAllTheData
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; //import the package/class the method depends on
private void assertCachingDataSetIteratorHasAllTheData(int rows, int inputColumns, int outputColumns,
        DataSet dataSet, DataSetIterator it, CachingDataSetIterator cachedIt) {
    cachedIt.reset();
    it.reset();

    // Mutate the source DataSet: the cached iterator must still serve the original
    // values, while the plain iterator reflects the changes
    dataSet.setFeatures(Nd4j.zeros(rows, inputColumns));
    dataSet.setLabels(Nd4j.ones(rows, outputColumns));

    while (it.hasNext()) {
        assertTrue(cachedIt.hasNext());

        DataSet cachedDs = cachedIt.next();
        assertEquals(1000.0, cachedDs.getFeatureMatrix().sumNumber());
        assertEquals(0.0, cachedDs.getLabels().sumNumber());

        DataSet ds = it.next();
        assertEquals(0.0, ds.getFeatureMatrix().sumNumber());
        assertEquals(20.0, ds.getLabels().sumNumber());
    }

    assertFalse(cachedIt.hasNext());
    assertFalse(it.hasNext());
}
Example 3: fit
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; //import the package/class the method depends on
/**
 * Fit the normalizer's statistics to the data in the given iterator
 *
 * @param iterator the data to iterate over
 */
@Override
public void fit(DataSetIterator iterator) {
    S.Builder featureNormBuilder = newBuilder();
    S.Builder labelNormBuilder = newBuilder();

    // Reset first so the statistics are collected over the full dataset
    iterator.reset();
    while (iterator.hasNext()) {
        DataSet next = iterator.next();
        featureNormBuilder.addFeatures(next);
        if (fitLabels) {
            labelNormBuilder.addLabels(next);
        }
    }
    featureStats = (S) featureNormBuilder.build();
    if (fitLabels) {
        labelStats = (S) labelNormBuilder.build();
    }

    // Leave the iterator rewound for the caller
    iterator.reset();
}
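Example 3 is the normalizer side of the reset contract: fitting statistics requires a full pass over the data, so the iterator is reset before the pass and rewound again for the caller afterwards. As a usage sketch, this is how such a fitted normalizer is typically attached to the iterator as a pre-processor, mirroring the API calls in Example 6 below (the class name NormalizerUsageSketch is a placeholder):

import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
import org.nd4j.linalg.dataset.api.preprocessor.NormalizerMinMaxScaler;

public class NormalizerUsageSketch {
    public static void main(String[] args) {
        DataSetIterator iter = new IrisDataSetIterator(10, 150);

        // fit() iterates over all batches to collect min/max statistics,
        // resetting the iterator before and after the pass (as in Example 3)
        DataNormalization normalizer = new NormalizerMinMaxScaler(0, 1);
        normalizer.fit(iter);

        // Once attached, every next() returns batches scaled to [0, 1]
        iter.setPreProcessor(normalizer);
        while (iter.hasNext()) {
            System.out.println(iter.next().getFeatures().maxNumber());
        }
    }
}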
Example 4: testCGEvaluation
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; //import the package/class the method depends on
@Test
public void testCGEvaluation() {
    Nd4j.getRandom().setSeed(12345);
    ComputationGraphConfiguration configuration = getIrisGraphConfiguration();
    ComputationGraph graph = new ComputationGraph(configuration);
    graph.init();

    Nd4j.getRandom().setSeed(12345);
    MultiLayerConfiguration mlnConfig = getIrisMLNConfiguration();
    MultiLayerNetwork net = new MultiLayerNetwork(mlnConfig);
    net.init();

    DataSetIterator iris = new IrisDataSetIterator(75, 150);

    // Reset between consumers so each one sees the full dataset from the start
    net.fit(iris);
    iris.reset();
    graph.fit(iris);
    iris.reset();

    Evaluation evalExpected = net.evaluate(iris);
    iris.reset();
    Evaluation evalActual = graph.evaluate(iris);

    assertEquals(evalExpected.accuracy(), evalActual.accuracy(), 0e-4);
}
Example 5: testOptimizersBasicMLPBackprop
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; //import the package/class the method depends on
@Test
public void testOptimizersBasicMLPBackprop() {
    //Basic tests of the 'does it throw an exception' variety.
    DataSetIterator iter = new IrisDataSetIterator(5, 50);

    OptimizationAlgorithm[] toTest =
            {OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT, OptimizationAlgorithm.LINE_GRADIENT_DESCENT,
             OptimizationAlgorithm.CONJUGATE_GRADIENT, OptimizationAlgorithm.LBFGS
             //OptimizationAlgorithm.HESSIAN_FREE //Known to not work
            };

    for (OptimizationAlgorithm oa : toTest) {
        MultiLayerNetwork network = new MultiLayerNetwork(getMLPConfigIris(oa));
        network.init();

        iter.reset();
        network.fit(iter);
    }
}
Example 6: testNormalizerPrefetchReset
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; //import the package/class the method depends on
@Test
public void testNormalizerPrefetchReset() throws Exception {
    //Check NPE fix for: https://github.com/deeplearning4j/deeplearning4j/issues/4214
    RecordReader csv = new CSVRecordReader();
    csv.initialize(new FileSplit(new ClassPathResource("iris.txt").getTempFileFromArchive()));

    int batchSize = 3;
    DataSetIterator iter = new RecordReaderDataSetIterator(csv, batchSize, 4, 4, true);

    DataNormalization normalizer = new NormalizerMinMaxScaler(0, 1);
    normalizer.fit(iter);
    iter.setPreProcessor(normalizer);

    iter.inputColumns();    //Prefetch
    iter.totalOutcomes();
    iter.hasNext();

    iter.reset();
    iter.next();
}
Example 7: testInitializeNoNextIter
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; //import the package/class the method depends on
@Test
public void testInitializeNoNextIter() {
    DataSetIterator iter = new IrisDataSetIterator(10, 150);
    while (iter.hasNext())
        iter.next();

    DataSetIterator async = new AsyncDataSetIterator(iter, 2);
    assertFalse(iter.hasNext());
    assertFalse(async.hasNext());
    try {
        iter.next();
        fail("Should have thrown NoSuchElementException");
    } catch (Exception e) {
        //OK
    }

    // Resetting the async wrapper also rewinds the exhausted backing iterator
    async.reset();
    int count = 0;
    while (async.hasNext()) {
        async.next();
        count++;
    }
    assertEquals(150 / 10, count);
}
Example 8: train
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; //import the package/class the method depends on
@Override
public void train(FederatedDataSet dataSource) {
    DataSet trainingData = (DataSet) dataSource.getNativeDataSet();

    List<DataSet> listDs = trainingData.asList();
    DataSetIterator iterator = new ListDataSetIterator(listDs, BATCH_SIZE);

    //Train the network on the full data set, resetting the iterator before each epoch
    for (int i = 0; i < N_EPOCHS; i++) {
        iterator.reset();
        mNetwork.fit(iterator);
    }
}
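A caveat on this per-epoch pattern: not every DataSetIterator implementation can rewind, and the interface exposes resetSupported() for exactly this check. A defensive variant of the loop above, as a sketch (ListDataSetIterator does support reset in practice, so the guard here is illustrative):

// Fail fast if the iterator cannot be rewound between epochs
if (!iterator.resetSupported()) {
    throw new IllegalStateException("Iterator must support reset() for multi-epoch training");
}
for (int i = 0; i < N_EPOCHS; i++) {
    iterator.reset();
    mNetwork.fit(iterator);
}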
Example 9: main
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; //import the package/class the method depends on
public static void main(String[] args) {
    //Generate the training data
    DataSetIterator iterator = getTrainingData(batchSize, rng);

    //Create the network
    int numInput = 2;
    int numOutputs = 1;
    int nHidden = 10;
    MultiLayerNetwork net = new MultiLayerNetwork(new NeuralNetConfiguration.Builder()
            .seed(seed)
            .iterations(iterations)
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .learningRate(learningRate)
            .weightInit(WeightInit.XAVIER)
            .updater(Updater.NESTEROVS).momentum(0.9)
            .list()
            .layer(0, new DenseLayer.Builder().nIn(numInput).nOut(nHidden)
                    .activation(Activation.TANH)
                    .build())
            .layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.MSE)
                    .activation(Activation.IDENTITY)
                    .nIn(nHidden).nOut(numOutputs).build())
            .pretrain(false).backprop(true).build()
    );
    net.init();
    net.setListeners(new ScoreIterationListener(1));

    //Train the network on the full data set, resetting the iterator before each epoch
    for (int i = 0; i < nEpochs; i++) {
        iterator.reset();
        net.fit(iterator);
    }

    //Test the addition of 2 numbers (try different numbers here)
    final INDArray input = Nd4j.create(new double[] {0.111111, 0.3333333333333}, new int[] {1, 2});
    INDArray out = net.output(input, false);
    System.out.println(out);
}
Example 10: main
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; //import the package/class the method depends on
public static void main(final String[] args) {
    //Switch these two options to do different functions with different networks
    final MathFunction fn = new SinXDivXMathFunction();
    final MultiLayerConfiguration conf = getDeepDenseLayerNetworkConfiguration();

    //Generate the training data
    final INDArray x = Nd4j.linspace(-10, 10, nSamples).reshape(nSamples, 1);
    final DataSetIterator iterator = getTrainingData(x, fn, batchSize, rng);

    //Create the network
    final MultiLayerNetwork net = new MultiLayerNetwork(conf);
    net.init();
    net.setListeners(new ScoreIterationListener(1));

    //Train the network on the full data set, recording predictions every plotFrequency epochs
    final INDArray[] networkPredictions = new INDArray[nEpochs / plotFrequency];
    for (int i = 0; i < nEpochs; i++) {
        iterator.reset();
        net.fit(iterator);
        if ((i + 1) % plotFrequency == 0) networkPredictions[i / plotFrequency] = net.output(x, false);
    }

    //Plot the target data and the network predictions
    plot(fn, x, fn.getFunctionValues(x), networkPredictions);
}
Example 11: getFirstBatchFeatures
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; //import the package/class the method depends on
/**
 * Get a peek at the features of the {@code iterator}'s first batch using the given instances.
 *
 * @return Features of the first batch
 * @throws Exception
 */
protected INDArray getFirstBatchFeatures(Instances data) throws Exception {
    final DataSetIterator it = getDataSetIterator(data, CacheMode.NONE);
    if (!it.hasNext()) {
        throw new RuntimeException("Iterator was unexpectedly empty.");
    }

    // Peek at the first batch, then rewind so later consumers start from the beginning
    final INDArray features = it.next().getFeatures();
    it.reset();
    return features;
}
Example 12: main
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; //import the package/class the method depends on
/**
 * args[0] input: word2vec file name
 * args[1] input: trained model file name
 * args[2] input: parent folder of the train/test data
 * args[3] output: file name for the updated model
 *
 * @param args
 * @throws Exception
 */
public static void main(final String[] args) throws Exception {
    if (args.length < 4)
        System.exit(1);

    WordVectors wvec = WordVectorSerializer.loadTxtVectors(new File(args[0]));
    MultiLayerNetwork model = ModelSerializer.restoreMultiLayerNetwork(args[1], true);

    int batchSize = 16; //100;
    int testBatch = 64;
    int nEpochs = 1;
    System.out.println("Starting online training");
    DataSetIterator train = new AsyncDataSetIterator(
            new SentimentRecurrentIterator(args[2], wvec, batchSize, 300, true), 2);
    DataSetIterator test = new AsyncDataSetIterator(
            new SentimentRecurrentIterator(args[2], wvec, testBatch, 300, false), 2);

    for (int i = 0; i < nEpochs; i++) {
        model.fit(train);
        train.reset();

        System.out.println("Epoch " + i + " complete. Starting evaluation:");
        Evaluation evaluation = new Evaluation();
        while (test.hasNext()) {
            DataSet t = test.next();
            INDArray features = t.getFeatures();
            INDArray labels = t.getLabels();
            INDArray inMask = t.getFeaturesMaskArray();
            INDArray outMask = t.getLabelsMaskArray();
            INDArray predicted = model.output(features, false, inMask, outMask);
            evaluation.evalTimeSeries(labels, predicted, outMask);
        }
        // Rewind the test iterator so it can be reused in the next epoch
        test.reset();
        System.out.println(evaluation.stats());

        System.out.println("Save model");
        ModelSerializer.writeModel(model, new FileOutputStream(args[3]), true);
    }
}
Example 13: testRocMultiToHtml
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; //import the package/class the method depends on
@Test
public void testRocMultiToHtml() throws Exception {
    DataSetIterator iter = new IrisDataSetIterator(150, 150);

    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().weightInit(WeightInit.XAVIER).list()
            .layer(0, new DenseLayer.Builder().nIn(4).nOut(4).activation(Activation.TANH).build())
            .layer(1, new OutputLayer.Builder().nIn(4).nOut(3).activation(Activation.SOFTMAX)
                    .lossFunction(LossFunctions.LossFunction.MCXENT).build())
            .build();
    MultiLayerNetwork net = new MultiLayerNetwork(conf);
    net.init();

    NormalizerStandardize ns = new NormalizerStandardize();
    DataSet ds = iter.next();
    ns.fit(ds);
    ns.transform(ds);

    for (int i = 0; i < 30; i++) {
        net.fit(ds);
    }

    for (int numSteps : new int[] {20, 0}) {
        ROCMultiClass roc = new ROCMultiClass(numSteps);
        iter.reset();

        INDArray f = ds.getFeatures();
        INDArray l = ds.getLabels();
        INDArray out = net.output(f);
        roc.eval(l, out);

        String str = EvaluationTools.rocChartToHtml(roc, Arrays.asList("setosa", "versicolor", "virginica"));
        System.out.println(str);
    }
}
Example 14: testCNNMLNPretrain
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; //import the package/class the method depends on
@Test
public void testCNNMLNPretrain() throws Exception {
    // Note CNN does not do pretrain
    int numSamples = 10;
    int batchSize = 10;
    DataSetIterator mnistIter = new MnistDataSetIterator(batchSize, numSamples, true);

    MultiLayerNetwork model = getCNNMLNConfig(false, true);
    model.fit(mnistIter);

    mnistIter.reset();
    MultiLayerNetwork model2 = getCNNMLNConfig(false, true);
    model2.fit(mnistIter);

    mnistIter.reset();
    DataSet test = mnistIter.next();

    Evaluation eval = new Evaluation();
    INDArray output = model.output(test.getFeatureMatrix());
    eval.eval(test.getLabels(), output);
    double f1Score = eval.f1();

    Evaluation eval2 = new Evaluation();
    INDArray output2 = model2.output(test.getFeatureMatrix());
    eval2.eval(test.getLabels(), output2);
    double f1Score2 = eval2.f1();

    assertEquals(f1Score, f1Score2, 1e-4);
}
Example 15: testOutput
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; //import the package/class the method depends on
@Test
public void testOutput() throws Exception {
    Nd4j.getRandom().setSeed(12345);
    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .weightInit(WeightInit.XAVIER).seed(12345L).list()
            .layer(0, new DenseLayer.Builder().nIn(784).nOut(50).activation(Activation.RELU).build())
            .layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                    .activation(Activation.SOFTMAX).nIn(50).nOut(10).build())
            .pretrain(false).backprop(true).setInputType(InputType.convolutional(28, 28, 1)).build();
    MultiLayerNetwork net = new MultiLayerNetwork(conf);
    net.init();

    DataSetIterator fullData = new MnistDataSetIterator(1, 2);
    net.fit(fullData);

    fullData.reset();
    DataSet expectedSet = fullData.next(2);
    INDArray expectedOut = net.output(expectedSet.getFeatureMatrix(), false);

    fullData.reset();
    INDArray actualOut = net.output(fullData);
    assertEquals(expectedOut, actualOut);
}