

Java MultiLayerNetwork.fit Method Code Examples

This article collects typical usage examples of the Java method org.deeplearning4j.nn.multilayer.MultiLayerNetwork.fit. If you are asking how MultiLayerNetwork.fit works or how to use it in practice, the curated code examples below should help. You can also explore further usage examples of the enclosing class, org.deeplearning4j.nn.multilayer.MultiLayerNetwork.


The following presents 15 code examples of the MultiLayerNetwork.fit method, sorted by popularity by default.
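Before the examples, note that MultiLayerNetwork.fit has several overloads, all of which appear below: fit(DataSet), fit(DataSetIterator), and fit(INDArray features, INDArray labels). For orientation, here is a minimal, self-contained sketch of the most common pattern (an epoch loop over a DataSetIterator), written against the same 0.x-era DL4J builder API the examples use. It is illustrative only, not taken from any of the listed projects, and the class name MultiLayerNetworkFitSketch is a placeholder.

import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.lossfunctions.LossFunctions;

public class MultiLayerNetworkFitSketch {
    public static void main(String[] args) {
        //Iris: 4 input features, 3 classes; batch size 10 over all 150 examples
        DataSetIterator iter = new IrisDataSetIterator(10, 150);

        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .seed(123)
                .weightInit(WeightInit.XAVIER)
                .list()
                .layer(0, new DenseLayer.Builder().nIn(4).nOut(10)
                        .activation(Activation.TANH).build())
                .layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                        .activation(Activation.SOFTMAX).nIn(10).nOut(3).build())
                .build();

        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.init();

        //fit(DataSetIterator) makes one pass over the iterator; reset and
        //call it again for further epochs, as most of the examples below do
        for (int epoch = 0; epoch < 30; epoch++) {
            iter.reset();
            net.fit(iter);
        }
    }
}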

Example 1: testUpdaters

import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; //import the package/class the method depends on
@Test
public void testUpdaters() {
    SparkDl4jMultiLayer sparkNet = getBasicNetwork();
    MultiLayerNetwork netCopy = sparkNet.getNetwork().clone();

    netCopy.fit(data);
    Updater expectedUpdater = netCopy.conf().getLayer().getUpdater();
    double expectedLR = netCopy.conf().getLayer().getLearningRate();
    double expectedMomentum = netCopy.conf().getLayer().getMomentum();

    Updater actualUpdater = sparkNet.getNetwork().conf().getLayer().getUpdater();
    sparkNet.fit(sparkData);
    double actualLR = sparkNet.getNetwork().conf().getLayer().getLearningRate();
    double actualMomentum = sparkNet.getNetwork().conf().getLayer().getMomentum();

    assertEquals(expectedUpdater, actualUpdater);
    assertEquals(expectedLR, actualLR, 0.01);
    assertEquals(expectedMomentum, actualMomentum, 0.01);

}
 
Developer: PacktPublishing, Project: Deep-Learning-with-Hadoop, Lines: 21, Source: TestSparkMultiLayerParameterAveraging.java

Example 2: testUpdaters

import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; //import the package/class the method depends on
@Test
public void testUpdaters() {
    SparkDl4jMultiLayer sparkNet = getBasicNetwork();
    MultiLayerNetwork netCopy = sparkNet.getNetwork().clone();

    netCopy.fit(data);
    IUpdater expectedUpdater = ((BaseLayer) netCopy.conf().getLayer()).getIUpdater();
    double expectedLR = ((Nesterovs)((BaseLayer) netCopy.conf().getLayer()).getIUpdater()).getLearningRate();
    double expectedMomentum = ((Nesterovs)((BaseLayer) netCopy.conf().getLayer()).getIUpdater()).getMomentum();

    IUpdater actualUpdater = ((BaseLayer) sparkNet.getNetwork().conf().getLayer()).getIUpdater();
    sparkNet.fit(sparkData);
    double actualLR = ((Nesterovs)((BaseLayer) sparkNet.getNetwork().conf().getLayer()).getIUpdater()).getLearningRate();
    double actualMomentum = ((Nesterovs)((BaseLayer) sparkNet.getNetwork().conf().getLayer()).getIUpdater()).getMomentum();

    assertEquals(expectedUpdater, actualUpdater);
    assertEquals(expectedLR, actualLR, 0.01);
    assertEquals(expectedMomentum, actualMomentum, 0.01);

}
 
Developer: deeplearning4j, Project: deeplearning4j, Lines: 21, Source: TestSparkMultiLayerParameterAveraging.java

Example 3: testOptimizersBasicMLPBackprop

import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; //import the package/class the method depends on
@Test
public void testOptimizersBasicMLPBackprop() {
    //Basic tests of the 'does it throw an exception' variety.

    DataSetIterator iter = new IrisDataSetIterator(5, 50);

    OptimizationAlgorithm[] toTest =
                    {OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT, OptimizationAlgorithm.LINE_GRADIENT_DESCENT,
                                    OptimizationAlgorithm.CONJUGATE_GRADIENT, OptimizationAlgorithm.LBFGS
                    //OptimizationAlgorithm.HESSIAN_FREE	//Known to not work
                    };

    for (OptimizationAlgorithm oa : toTest) {
        MultiLayerNetwork network = new MultiLayerNetwork(getMLPConfigIris(oa));
        network.init();

        iter.reset();
        network.fit(iter);
    }
}
 
Developer: deeplearning4j, Project: deeplearning4j, Lines: 21, Source: TestOptimizers.java

Example 4: main

import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; //import the package/class the method depends on
public static void main(final String[] args){

        //Switch these two options to do different functions with different networks
        final MathFunction fn = new SinXDivXMathFunction();
        final MultiLayerConfiguration conf = getDeepDenseLayerNetworkConfiguration();

        //Generate the training data
        final INDArray x = Nd4j.linspace(-10,10,nSamples).reshape(nSamples, 1);
        final DataSetIterator iterator = getTrainingData(x,fn,batchSize,rng);

        //Create the network
        final MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.init();
        net.setListeners(new ScoreIterationListener(1));


        //Train the network on the full data set, and evaluate it periodically
        final INDArray[] networkPredictions = new INDArray[nEpochs / plotFrequency];
        for( int i=0; i<nEpochs; i++ ){
            iterator.reset();
            net.fit(iterator);
            if((i+1) % plotFrequency == 0) networkPredictions[i / plotFrequency] = net.output(x, false);
        }
        }

        //Plot the target data and the network predictions
        plot(fn,x,fn.getFunctionValues(x),networkPredictions);
    }
 
Developer: IsaacChanghau, Project: NeuralNetworksLite, Lines: 28, Source: RegressionMathFunctions.java

Example 5: main

import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; //import the package/class the method depends on
public static void main(String[] args) throws Exception {
    final int numRows = 28;
    final int numColumns = 28;
    int seed = 123;
    int numSamples = MnistDataFetcher.NUM_EXAMPLES;
    int batchSize = 1000;
    int iterations = 1;
    int listenerFreq = Math.max(1, iterations/5); //guard against a listener frequency of 0 when iterations < 5

    log.info("Load data....");
    DataSetIterator iter = new MnistDataSetIterator(batchSize,numSamples,true);

    log.info("Build model....");
    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .seed(seed)
            .iterations(iterations)
            .optimizationAlgo(OptimizationAlgorithm.LINE_GRADIENT_DESCENT)
            .list(8)
            .layer(0, new RBM.Builder().nIn(numRows * numColumns).nOut(2000).lossFunction(LossFunctions.LossFunction.RMSE_XENT).build())
            .layer(1, new RBM.Builder().nIn(2000).nOut(1000).lossFunction(LossFunctions.LossFunction.RMSE_XENT).build())
            .layer(2, new RBM.Builder().nIn(1000).nOut(500).lossFunction(LossFunctions.LossFunction.RMSE_XENT).build())
            .layer(3, new RBM.Builder().nIn(500).nOut(30).lossFunction(LossFunctions.LossFunction.RMSE_XENT).build())
            .layer(4, new RBM.Builder().nIn(30).nOut(500).lossFunction(LossFunctions.LossFunction.RMSE_XENT).build()) 
            .layer(5, new RBM.Builder().nIn(500).nOut(1000).lossFunction(LossFunctions.LossFunction.RMSE_XENT).build())
            .layer(6, new RBM.Builder().nIn(1000).nOut(2000).lossFunction(LossFunctions.LossFunction.RMSE_XENT).build())
            .layer(7, new OutputLayer.Builder(LossFunctions.LossFunction.MSE).activation(Activation.SIGMOID).nIn(2000).nOut(numRows*numColumns).build())
            .pretrain(true).backprop(true)
            .build();

    MultiLayerNetwork model = new MultiLayerNetwork(conf);
    model.init();

    model.setListeners(new ScoreIterationListener(listenerFreq));

    log.info("Train model....");
    while(iter.hasNext()) {
        DataSet next = iter.next();
        model.fit(new DataSet(next.getFeatureMatrix(),next.getFeatureMatrix()));
    }
}
 
Developer: PacktPublishing, Project: Deep-Learning-with-Hadoop, Lines: 41, Source: DeepAutoEncoder.java

Example 6: main

import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; //import the package/class the method depends on
public static void main(String[] args){

        //Generate the training data
        DataSetIterator iterator = getTrainingData(batchSize,rng);

        //Create the network
        int numInput = 2;
        int numOutputs = 1;
        int nHidden = 10;
        MultiLayerNetwork net = new MultiLayerNetwork(new NeuralNetConfiguration.Builder()
                .seed(seed)
                .iterations(iterations)
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                .learningRate(learningRate)
                .weightInit(WeightInit.XAVIER)
                .updater(Updater.NESTEROVS).momentum(0.9)
                .list()
                .layer(0, new DenseLayer.Builder().nIn(numInput).nOut(nHidden)
                        .activation(Activation.TANH)
                        .build())
                .layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.MSE)
                        .activation(Activation.IDENTITY)
                        .nIn(nHidden).nOut(numOutputs).build())
                .pretrain(false).backprop(true).build()
        );
        net.init();
        net.setListeners(new ScoreIterationListener(1));


        //Train the network on the full data set, and evaluate it periodically
        for( int i=0; i<nEpochs; i++ ){
            iterator.reset();
            net.fit(iterator);
        }
        // Test the addition of 2 numbers (Try different numbers here)
        final INDArray input = Nd4j.create(new double[] { 0.111111, 0.3333333333333 }, new int[] { 1, 2 });
        INDArray out = net.output(input, false);
        System.out.println(out);

    }
 
Developer: IsaacChanghau, Project: NeuralNetworksLite, Lines: 41, Source: RegressionSum.java

Example 7: trainNet

import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; //import the package/class the method depends on
private static MultiLayerNetwork trainNet(MultiLayerNetwork net, List<INDArray> featuresTrain) {
	int nEpochs = 5;
	for( int epoch = 0; epoch < nEpochs; epoch++ ){
		for(INDArray data : featuresTrain){
			net.fit(data, data);
		}
	}

	return net;
}
 
Developer: matthiaszimmermann, Project: ml_demo, Lines: 11, Source: MammographyAutoencoder.java

Example 8: main

import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; //import the package/class the method depends on
/**
 * args[0] input: word2vec file name
 * args[1] input: trained model file to load
 * args[2] input: parent directory of the train/test data
 * args[3] output: model file to save
 *
 * @param args
 * @throws Exception
 */
public static void main (final String[] args) throws Exception {
  if (args.length < 4)
    System.exit(1);

  WordVectors wvec = WordVectorSerializer.loadTxtVectors(new File(args[0]));
  MultiLayerNetwork model = ModelSerializer.restoreMultiLayerNetwork(args[1],true);
  int batchSize   = 16;//100;
  int testBatch   = 64;
  int nEpochs     = 1;

  System.out.println("Starting online training");
  DataSetIterator train = new AsyncDataSetIterator(
      new SentimentRecurrentIterator(args[2],wvec,batchSize,300,true),2);
  DataSetIterator test = new AsyncDataSetIterator(
      new SentimentRecurrentIterator(args[2],wvec,testBatch,300,false),2);
  for( int i=0; i<nEpochs; i++ ){
    model.fit(train);
    train.reset();

    System.out.println("Epoch " + i + " complete. Starting evaluation:");
    Evaluation evaluation = new Evaluation();
    while(test.hasNext()) {
      DataSet t = test.next();
      INDArray features = t.getFeatures();
      INDArray labels = t.getLabels();
      INDArray inMask = t.getFeaturesMaskArray();
      INDArray outMask = t.getLabelsMaskArray();
      INDArray predicted = model.output(features,false,inMask,outMask);
      evaluation.evalTimeSeries(labels,predicted,outMask);
    }
    test.reset();
    System.out.println(evaluation.stats());

    System.out.println("Save model");
    ModelSerializer.writeModel(model, new FileOutputStream(args[3]), true);
  }
}
 
Developer: keigohtr, Project: sentiment-rnn, Lines: 47, Source: SentimentRecurrentTrainOnlineCmd.java

Example 9: testCNNDBNMultiLayer

import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; //import the package/class the method depends on
@Test
public void testCNNDBNMultiLayer() throws Exception {
    DataSetIterator iter = new MnistDataSetIterator(2, 2);
    DataSet next = iter.next();

    // Run with separate activation layer
    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                    .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).seed(123)
                    .weightInit(WeightInit.XAVIER).list()
                    .layer(0, new ConvolutionLayer.Builder(new int[] {1, 1}, new int[] {1, 1}).nIn(1).nOut(6)
                                    .activation(Activation.IDENTITY).build())
                    .layer(1, new BatchNormalization.Builder().build())
                    .layer(2, new ActivationLayer.Builder().activation(Activation.RELU).build())
                    .layer(3, new DenseLayer.Builder().nIn(28 * 28 * 6).nOut(10).activation(Activation.IDENTITY)
                                    .build())
                    .layer(4, new BatchNormalization.Builder().nOut(10).build())
                    .layer(5, new ActivationLayer.Builder().activation(Activation.RELU).build())
                    .layer(6, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                                    .activation(Activation.SOFTMAX).nOut(10).build())
                    .backprop(true).pretrain(false).setInputType(InputType.convolutional(28, 28, 1)).build();

    MultiLayerNetwork network = new MultiLayerNetwork(conf);
    network.init();

    network.setInput(next.getFeatureMatrix());
    INDArray activationsActual = network.preOutput(next.getFeatureMatrix());
    assertEquals(10, activationsActual.shape()[1], 1e-2);

    network.fit(next);
    INDArray actualGammaParam = network.getLayer(1).getParam(BatchNormalizationParamInitializer.GAMMA);
    INDArray actualBetaParam = network.getLayer(1).getParam(BatchNormalizationParamInitializer.BETA);
    assertTrue(actualGammaParam != null);
    assertTrue(actualBetaParam != null);
}
 
Developer: deeplearning4j, Project: deeplearning4j, Lines: 35, Source: ConvolutionLayerSetupTest.java

Example 10: testGateActivationFnsSanityCheck

import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; //import the package/class the method depends on
@Test
public void testGateActivationFnsSanityCheck() {
    for (String gateAfn : new String[] {"sigmoid", "hardsigmoid"}) {

        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                        .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                        .seed(12345).list()
                        .layer(0, new org.deeplearning4j.nn.conf.layers.GravesLSTM.Builder()
                                        .gateActivationFunction(gateAfn).activation(Activation.TANH).nIn(2).nOut(2)
                                        .build())
                        .layer(1, new org.deeplearning4j.nn.conf.layers.RnnOutputLayer.Builder()
                                        .lossFunction(LossFunctions.LossFunction.MSE).nIn(2).nOut(2)
                                        .activation(Activation.TANH).build())
                        .build();

        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.init();

        assertEquals(gateAfn, ((org.deeplearning4j.nn.conf.layers.GravesLSTM) net.getLayer(0).conf().getLayer())
                        .getGateActivationFn().toString());

        INDArray in = Nd4j.rand(new int[] {3, 2, 5});
        INDArray labels = Nd4j.rand(new int[] {3, 2, 5});

        net.fit(in, labels);
    }
}
 
Developer: deeplearning4j, Project: deeplearning4j, Lines: 28, Source: GravesLSTMTest.java

Example 11: testMLPMultiLayerPretrain

import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; //import the package/class the method depends on
@Test
public void testMLPMultiLayerPretrain() {
    // Note CNN does not do pretrain
    MultiLayerNetwork model = getDenseMLNConfig(false, true);
    model.fit(iter);

    MultiLayerNetwork model2 = getDenseMLNConfig(false, true);
    model2.fit(iter);
    iter.reset();

    DataSet test = iter.next();

    assertEquals(model.params(), model2.params());

    Evaluation eval = new Evaluation();
    INDArray output = model.output(test.getFeatureMatrix());
    eval.eval(test.getLabels(), output);
    double f1Score = eval.f1();

    Evaluation eval2 = new Evaluation();
    INDArray output2 = model2.output(test.getFeatureMatrix());
    eval2.eval(test.getLabels(), output2);
    double f1Score2 = eval2.f1();

    assertEquals(f1Score, f1Score2, 1e-4);

}
 
Developer: deeplearning4j, Project: deeplearning4j, Lines: 28, Source: DenseTest.java

Example 12: checkSerialization

import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; //import the package/class the method depends on
@Test
public void checkSerialization() throws Exception {
    //Serialize the batch norm network (after training), and make sure we get same activations out as before
    // i.e., make sure state is properly stored

    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                    .seed(12345)
                    .list()
                    .layer(0, new ConvolutionLayer.Builder().nIn(1).nOut(6).weightInit(WeightInit.XAVIER)
                                    .activation(Activation.IDENTITY).build())
                    .layer(1, new BatchNormalization.Builder().build())
                    .layer(2, new ActivationLayer.Builder().activation(Activation.LEAKYRELU).build())
                    .layer(3, new DenseLayer.Builder().nOut(10).activation(Activation.LEAKYRELU).build())
                    .layer(4, new BatchNormalization.Builder().build())
                    .layer(5, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                                    .weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX).nOut(10).build())
                    .backprop(true).pretrain(false).setInputType(InputType.convolutionalFlat(28, 28, 1)).build();

    MultiLayerNetwork net = new MultiLayerNetwork(conf);
    net.init();

    DataSetIterator iter = new MnistDataSetIterator(16, true, 12345);
    for (int i = 0; i < 20; i++) {
        net.fit(iter.next());
    }

    INDArray in = iter.next().getFeatureMatrix();

    INDArray out = net.output(in, false);
    INDArray out2 = net.output(in, false);

    assertEquals(out, out2);

    MultiLayerNetwork net2 = TestUtils.testModelSerialization(net);

    INDArray outDeser = net2.output(in, false);

    assertEquals(out, outDeser);
}
 
Developer: deeplearning4j, Project: deeplearning4j, Lines: 40, Source: BatchNormalizationTest.java

Example 13: main

import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; //import the package/class the method depends on
public static void main(String[] args) throws Exception {
		final int numRows = 28;
		final int numColumns = 28;
		int outputNum = 10;
		int numSamples = 60000;
		int batchSize = 100;
		int iterations = 10;
		int seed = 123;
		int listenerFreq = batchSize / 5;

		log.info("Load data....");
		DataSetIterator iter = new MnistDataSetIterator(batchSize, numSamples,
				true);

		log.info("Build model....");
		MultiLayerNetwork model = softMaxRegression(seed, iterations, numRows, numColumns, outputNum);
		// MultiLayerNetwork model = deepBeliefNetwork(seed, iterations, numRows, numColumns, outputNum);
		// MultiLayerNetwork model = deepConvNetwork(seed, iterations, numRows, numColumns, outputNum);

		model.init();
		model.setListeners(Collections
				.singletonList((IterationListener) new ScoreIterationListener(
						listenerFreq)));

		log.info("Train model....");
		model.fit(iter); // achieves end to end pre-training

		log.info("Evaluate model....");
		Evaluation eval = new Evaluation(outputNum);

		DataSetIterator testIter = new MnistDataSetIterator(100, 10000);
		while (testIter.hasNext()) {
			DataSet testMnist = testIter.next();
			INDArray predict2 = model.output(testMnist.getFeatureMatrix());
			eval.eval(testMnist.getLabels(), predict2);
		}

		log.info(eval.stats());
		log.info("****************Example finished********************");

	}
 
Developer: PacktPublishing, Project: Machine-Learning-End-to-Endguide-for-Java-developers, Lines: 44, Source: NeuralNetworks.java

Example 14: main

import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; //import the package/class the method depends on
public static void main(String[] args) throws Exception {
    int numLinesToSkip = 0;
    String delimiter = ",";
    RecordReader recordReader = new CSVRecordReader(numLinesToSkip,delimiter);
    recordReader.initialize(new FileSplit(new ClassPathResource("iris.txt").getFile()));

    int labelIndex = 4;
    int numClasses = 3;
    int batchSize = 150;

    DataSetIterator iterator = new RecordReaderDataSetIterator(recordReader,batchSize,labelIndex,numClasses);
    DataSet allData = iterator.next();
    allData.shuffle();
    SplitTestAndTrain testAndTrain = allData.splitTestAndTrain(0.65);

    DataSet trainingData = testAndTrain.getTrain();
    DataSet testData = testAndTrain.getTest();

    //We need to normalize our data. We'll use NormalizerStandardize (which gives us mean 0, unit variance):
    DataNormalization normalizer = new NormalizerStandardize();
    normalizer.fit(trainingData);//Collect the statistics (mean/stdev) from the training data. This does not modify the input data
    normalizer.transform(trainingData);     //Apply normalization to the training data
    normalizer.transform(testData);         //Apply normalization to the test data. This is using statistics calculated from the *training* set

    final int numInputs = 4;
    int outputNum = 3;
    int iterations = 1000;
    long seed = 6;
    int epochs = 100;

    log.info("Build model....");
    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .seed(seed)
            .iterations(iterations)
            .weightInit(WeightInit.XAVIER)
            .learningRate(0.1)
            .regularization(true)
            .l2(1e-4)
            .list()
            .layer(0, new DenseLayer.Builder()
                    .nIn(numInputs)
                    .nOut(20)
                    .activation(Activation.TANH)
                    .build())
            .layer(1, new DenseLayer.Builder()
                    .nOut(10)
                    .activation(Activation.TANH)
                    .build())
            .layer(2, new OutputLayer.Builder()
                    .lossFunction(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                    .activation(Activation.SOFTMAX)
                    .nOut(outputNum)
                    .build())
            .backprop(true)
            .pretrain(false)
            .build();

    //run the model
    MultiLayerNetwork model = new MultiLayerNetwork(conf);
    model.init();
    model.setListeners(new ScoreIterationListener(100));
    for (int epoch = 0; epoch < epochs; epoch++) {
        model.fit(trainingData);
    }

    //evaluate the model on the test set
    Evaluation eval = new Evaluation(numClasses);
    INDArray output = model.output(testData.getFeatureMatrix());
    eval.eval(testData.getLabels(), output);
    log.info(eval.stats());
}
 
Developer: IsaacChanghau, Project: NeuralNetworksLite, Lines: 72, Source: IrisClassification.java

Example 15: DeepAutoEncoderExample

import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; //import the package/class the method depends on
public DeepAutoEncoderExample() {
    try {
        int seed = 123;
        int numberOfIterations = 1;
        iterator = new MnistDataSetIterator(1000, MnistDataFetcher.NUM_EXAMPLES, true);
        
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .seed(seed)
                .iterations(numberOfIterations)
                .optimizationAlgo(OptimizationAlgorithm.LINE_GRADIENT_DESCENT)
                .list()
                .layer(0, new RBM.Builder().nIn(numberOfRows * numberOfColumns)
                        .nOut(1000)
                        .lossFunction(LossFunctions.LossFunction.RMSE_XENT).build())
                .layer(1, new RBM.Builder().nIn(1000).nOut(500)
                        .lossFunction(LossFunctions.LossFunction.RMSE_XENT).build())
                .layer(2, new RBM.Builder().nIn(500).nOut(250)
                        .lossFunction(LossFunctions.LossFunction.RMSE_XENT).build())
                .layer(3, new RBM.Builder().nIn(250).nOut(100)
                        .lossFunction(LossFunctions.LossFunction.RMSE_XENT).build())
                .layer(4, new RBM.Builder().nIn(100).nOut(30)
                        .lossFunction(LossFunctions.LossFunction.RMSE_XENT).build()) //encoding stops
                .layer(5, new RBM.Builder().nIn(30).nOut(100)
                        .lossFunction(LossFunctions.LossFunction.RMSE_XENT).build()) //decoding starts
                .layer(6, new RBM.Builder().nIn(100).nOut(250)
                        .lossFunction(LossFunctions.LossFunction.RMSE_XENT).build())
                .layer(7, new RBM.Builder().nIn(250).nOut(500)
                        .lossFunction(LossFunctions.LossFunction.RMSE_XENT).build())
                .layer(8, new RBM.Builder().nIn(500).nOut(1000)
                        .lossFunction(LossFunctions.LossFunction.RMSE_XENT).build())
                .layer(9, new OutputLayer.Builder(
                                LossFunctions.LossFunction.RMSE_XENT).nIn(1000)
                        .nOut(numberOfRows * numberOfColumns).build())
                .pretrain(true).backprop(true)
                .build();

        model = new MultiLayerNetwork(conf);
        model.init();

        model.setListeners(Collections.singletonList(
                (IterationListener) new ScoreIterationListener()));

        while (iterator.hasNext()) {
            DataSet dataSet = iterator.next();
            model.fit(new DataSet(dataSet.getFeatureMatrix(),
                    dataSet.getFeatureMatrix()));
        }

        modelFile = new File("savedModel");
        ModelSerializer.writeModel(model, modelFile, true);
    } catch (IOException ex) {
        ex.printStackTrace();
    }
}
 
Developer: PacktPublishing, Project: Machine-Learning-End-to-Endguide-for-Java-developers, Lines: 55, Source: DeepAutoEncoderExample.java


Note: The org.deeplearning4j.nn.multilayer.MultiLayerNetwork.fit examples in this article were compiled by 纯净天空 from open-source code and documentation platforms such as GitHub and MSDocs. The snippets were selected from open-source projects contributed by their authors; copyright remains with the original authors, and distribution and use are subject to each project's license. Do not repost without permission.