当前位置: 首页>>代码示例>>Java>>正文


Java Model类代码示例

本文整理汇总了Java中org.deeplearning4j.nn.api.Model的典型用法代码示例。如果您正苦于以下问题：Java Model类的具体用法？Java Model怎么用？Java Model使用的例子？那么恭喜您，这里精选的类代码示例或许可以为您提供帮助。


Model类属于org.deeplearning4j.nn.api包,在下文中一共展示了Model类的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Java代码示例。

示例1: fromFile

import org.deeplearning4j.nn.api.Model; //导入依赖的package包/类
/**
 * Restores a serialized model from {@code file} and wraps it in a {@code DLModel}.
 *
 * <p>The file is first read as a computation graph. If that attempt throws, it is read
 * as a multi-layer network instead. Progress is reported on standard output.
 *
 * @param file the serialized model file to load
 * @return a {@code DLModel} wrapping whichever model could be restored
 * @throws Exception the computation-graph failure when both attempts fail; the
 *         multi-layer-network failure is attached to it as a suppressed exception
 */
public static DLModel fromFile(File file) throws Exception {
	Model model = null;
	try {
		System.out.println("Trying to load file as computation graph: " + file);
		model = ModelSerializer.restoreComputationGraph(file);
		System.out.println("Loaded Computation Graph.");
	} catch (Exception graphFailure) {
		try {
			System.out.println("Failed to load computation graph. Trying to load model.");
			model = ModelSerializer.restoreMultiLayerNetwork(file);
			System.out.println("Loaded Multilayernetwork");
		} catch (Exception networkFailure) {
			System.out.println("Give up trying to load file: " + file);
			// Callers still receive the first failure, but the second one is no longer
			// lost: it travels along as a suppressed exception for diagnosis.
			graphFailure.addSuppressed(networkFailure);
			throw graphFailure;
		}
	}
	return new DLModel(model);
}
 
开发者ID:jesuino,项目名称:java-ml-projects,代码行数:19,代码来源:DLModel.java

示例2: onEpochEnd

import org.deeplearning4j.nn.api.Model; //导入依赖的package包/类
/**
 * Epoch-end hook: advances the epoch counter and, on every {@code n}-th epoch, logs a
 * report headed by the current epoch out of {@code numEpochs}.
 *
 * <p>When intermediate evaluations are enabled the report also contains the evaluation
 * of the training iterator and, if a validation iterator is set, of that one too.
 *
 * @param model the model handed to the evaluation helper
 */
@Override
public void onEpochEnd(Model model) {
  currentEpoch++;

  // Only every n-th epoch produces a report
  final boolean isEvaluationEpoch = currentEpoch % n == 0;
  if (!isEvaluationEpoch) {
    return;
  }

  String report = "Epoch [" + currentEpoch + "/" + numEpochs + "]\n";

  if (enableIntermediateEvaluations) {
    report = report + "Train Set:      \n" + evaluateDataSetIterator(model, trainIterator, true);
    if (validationIterator != null) {
      report = report + "Validation Set: \n"
          + evaluateDataSetIterator(model, validationIterator, false);
    }
  }

  log(report);
}
 
开发者ID:Waikato,项目名称:wekaDeeplearning4j,代码行数:21,代码来源:EpochListener.java

示例3: iterationDone

import org.deeplearning4j.nn.api.Model; //导入依赖的package包/类
/**
 * Every {@code printIterations} calls, re-scores the held-out iterator with the given
 * model and logs the summed example scores divided by the summed label-mask values.
 *
 * @param model the model being trained; cast to {@code MultiLayerNetwork} when scoring
 * @param i     iteration number supplied by the caller (unused; an internal counter is kept)
 */
@Override
public void iterationDone(Model model, int i) {
    // A non-positive reporting interval is coerced to "every iteration"
    if (printIterations <= 0) {
        printIterations = 1;
    }

    final boolean shouldReport = iterCount % printIterations == 0;
    if (shouldReport) {
        iter.reset();
        double scoreSum = 0;
        double maskSum = 0;
        while (iter.hasNext()) {
            DataSet batch = iter.next(miniBatchSize);
            MultiLayerNetwork network = (MultiLayerNetwork) model;
            scoreSum += network.scoreExamples(batch, false).sumNumber().doubleValue();
            // NOTE(review): assumes every batch carries a labels mask array — confirm
            maskSum += batch.getLabelsMaskArray().sumNumber().doubleValue();
        }
        log.info(String.format("Iteration %5d test set score: %.4f", iterCount, scoreSum / maskSum));
    }

    iterCount++;
}
 
开发者ID:jpatanooga,项目名称:strata-2016-nyc-dl4j-rnn,代码行数:18,代码来源:HeldoutScoreIterationListener.java

示例4: iterationDone

import org.deeplearning4j.nn.api.Model; //导入依赖的package包/类
/**
 * Per-iteration early-stopping check: records the model's latest score on the trainer,
 * evaluates the configured iteration termination conditions against it, and stops the
 * wrapper's fit when the trainer is flagged for termination.
 *
 * @param model     the model whose current score is inspected
 * @param iteration iteration number supplied by the caller (unused here)
 * @param epoch     epoch number supplied by the caller (unused here)
 */
@Override
public void iterationDone(Model model, int iteration, int epoch) {
    final double score = model.score();
    trainer.setLatestScore(score);

    // The first condition that fires is recorded as the reason; later ones are not evaluated
    for (IterationTerminationCondition condition : esConfig.getIterationTerminationConditions()) {
        if (!condition.terminate(score)) {
            continue;
        }
        trainer.setTermination(true);
        trainer.setTerminationReason(condition);
        break;
    }

    // Rely on the wrapper's own kill switch to end the fit operation
    if (trainer.getTermination()) {
        wrapper.stopFit();
    }

    trainer.incrementIteration();
}
 
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:20,代码来源:EarlyStoppingParallelTrainer.java

示例5: testListenersForModel

import org.deeplearning4j.nn.api.Model; //导入依赖的package包/类
/**
 * Fits the model through a two-worker {@code ParallelWrapper} on one random data set per
 * worker and asserts the counters recorded by {@code TestListener}.
 *
 * @param model     the model to wrap and fit
 * @param listeners listeners to install on the wrapper; {@code null} leaves the wrapper untouched
 */
private static void testListenersForModel(Model model, List<IterationListener> listeners) {

        final int workerCount = 2;
        ParallelWrapper wrapper = new ParallelWrapper.Builder(model)
                        .workers(workerCount)
                        .averagingFrequency(1)
                        .reportScoreAfterAveraging(true)
                        .build();

        if (listeners != null) {
            wrapper.setListeners(listeners);
        }

        // One single-row random data set per worker
        List<DataSet> dataSets = new ArrayList<>();
        while (dataSets.size() < workerCount) {
            dataSets.add(new DataSet(Nd4j.rand(1, 10), Nd4j.rand(1, 10)));
        }
        DataSetIterator iterator = new ExistingDataSetIterator(dataSets);

        TestListener.clearCounts();
        wrapper.fit(iterator);

        assertEquals(2, TestListener.workerIDs.size());
        assertEquals(1, TestListener.sessionIDs.size());
        assertEquals(2, TestListener.forwardPassCount.get());
        assertEquals(2, TestListener.backwardPassCount.get());
    }
 
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:26,代码来源:TestListeners.java

示例6: updateGradientAccordingToParams

import org.deeplearning4j.nn.api.Model; //导入依赖的package包/类
/**
 * Applies the updater to {@code gradient}, creating the updater on first use.
 *
 * <p>A {@code ComputationGraph} gets a {@code ComputationGraphUpdater}; any other model
 * is treated as a {@code Layer} and uses the updater from {@code UpdaterCreator}. In
 * both cases the updater is constructed inside a scope obtained from
 * {@code scopeOutOfWorkspaces()}.
 *
 * @param gradient  the gradient passed to the updater
 * @param model     the model the gradient belongs to
 * @param batchSize the batch size passed through to the updater
 */
@Override
public void updateGradientAccordingToParams(Gradient gradient, Model model, int batchSize) {
    if (model instanceof ComputationGraph) {
        if (computationGraphUpdater == null) {
            try (MemoryWorkspace ignored = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
                computationGraphUpdater = new ComputationGraphUpdater((ComputationGraph) model);
            }
        }
        computationGraphUpdater.update(gradient, getIterationCount(model), getEpochCount(model), batchSize);
        return;
    }

    // Not a graph: lazily create the plain updater, then treat the model as a layer
    if (updater == null) {
        try (MemoryWorkspace ignored = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
            updater = UpdaterCreator.getUpdater(model);
        }
    }
    Layer asLayer = (Layer) model;
    updater.update(asLayer, gradient, getIterationCount(model), getEpochCount(model), batchSize);
}
 
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:22,代码来源:BaseOptimizer.java

示例7: onForwardPass

import org.deeplearning4j.nn.api.Model; //导入依赖的package包/类
/**
 * Forward-pass hook: on reporting iterations, collects whichever activation statistics
 * the update configuration asks for (histograms, mean, standard deviation, mean magnitude).
 *
 * <p>Nothing is collected unless {@code calcFromActivations()} holds, the reporting
 * frequency is positive, and the model's iteration count is zero or a multiple of that
 * frequency.
 *
 * @param model       the model whose iteration count decides whether to report
 * @param activations the activations to summarise
 */
@Override
public void onForwardPass(Model model, Map<String, INDArray> activations) {
    int iterCount = getModelInfo(model).iterCount;

    final boolean reportingEnabled = calcFromActivations() && updateConfig.reportingFrequency() > 0;
    if (!reportingEnabled) {
        return;
    }
    final boolean dueThisIteration = iterCount == 0 || iterCount % updateConfig.reportingFrequency() == 0;
    if (!dueThisIteration) {
        return;
    }

    if (updateConfig.collectHistograms(StatsType.Activations)) {
        int bins = updateConfig.numHistogramBins(StatsType.Activations);
        activationHistograms = getHistograms(activations, bins);
    }
    if (updateConfig.collectMean(StatsType.Activations)) {
        meanActivations = calculateSummaryStats(activations, StatType.Mean);
    }
    if (updateConfig.collectStdev(StatsType.Activations)) {
        stdevActivations = calculateSummaryStats(activations, StatType.Stdev);
    }
    if (updateConfig.collectMeanMagnitudes(StatsType.Activations)) {
        meanMagActivations = calculateSummaryStats(activations, StatType.MeanMagnitude);
    }
}
 
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:20,代码来源:BaseStatsListener.java

示例8: onGradientCalculation

import org.deeplearning4j.nn.api.Model; //导入依赖的package包/类
/**
 * Gradient hook: on reporting iterations, collects whichever gradient statistics the
 * update configuration asks for (histograms, mean, standard deviation, mean magnitude).
 *
 * <p>Nothing is collected unless {@code calcFromGradients()} holds, the reporting
 * frequency is positive, and the model's iteration count is zero or a multiple of that
 * frequency.
 *
 * @param model the model whose gradient is summarised
 */
@Override
public void onGradientCalculation(Model model) {
    int iterCount = getModelInfo(model).iterCount;

    final boolean reportingEnabled = calcFromGradients() && updateConfig.reportingFrequency() > 0;
    if (!reportingEnabled) {
        return;
    }
    final boolean dueThisIteration = iterCount == 0 || iterCount % updateConfig.reportingFrequency() == 0;
    if (!dueThisIteration) {
        return;
    }

    Gradient current = model.gradient();

    if (updateConfig.collectHistograms(StatsType.Gradients)) {
        gradientHistograms = getHistograms(current.gradientForVariable(),
                updateConfig.numHistogramBins(StatsType.Gradients));
    }
    if (updateConfig.collectMean(StatsType.Gradients)) {
        meanGradients = calculateSummaryStats(current.gradientForVariable(), StatType.Mean);
    }
    if (updateConfig.collectStdev(StatsType.Gradients)) {
        stdevGradient = calculateSummaryStats(current.gradientForVariable(), StatType.Stdev);
    }
    if (updateConfig.collectMeanMagnitudes(StatsType.Gradients)) {
        meanMagGradients = calculateSummaryStats(current.gradientForVariable(), StatType.MeanMagnitude);
    }
}
 
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:22,代码来源:BaseStatsListener.java

示例9: configureListeners

import org.deeplearning4j.nn.api.Model; //导入依赖的package包/类
/**
 * Installs the configured iteration listeners on the given network.
 *
 * <p>When a router provider is available, each {@code RoutingIterationListener} is first
 * given that provider's router and a worker ID built from the JVM UID and {@code counter}.
 * Does nothing when no iteration listeners are configured.
 *
 * @param m       the network receiving the listeners; anything that is not a
 *                {@code MultiLayerNetwork} is cast to {@code ComputationGraph}
 * @param counter suffix appended to the JVM UID to form the worker ID
 */
private void configureListeners(Model m, int counter) {
    if (iterationListeners == null) {
        return;
    }

    List<IterationListener> configured = new ArrayList<>(iterationListeners.size());
    for (IterationListener listener : iterationListeners) {
        boolean routable = listenerRouterProvider != null && listener instanceof RoutingIterationListener;
        if (routable) {
            RoutingIterationListener routing = (RoutingIterationListener) listener;
            routing.setStorageRouter(listenerRouterProvider.getRouter());
            routing.setWorkerID(UIDProvider.getJVMUID() + "_" + counter);
        }
        // Listeners are reused as-is: not from broadcast, so deserialization handles copying
        configured.add(listener);
    }

    if (m instanceof MultiLayerNetwork) {
        ((MultiLayerNetwork) m).setListeners(configured);
    } else {
        ((ComputationGraph) m).setListeners(configured);
    }
}
 
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:19,代码来源:ParameterAveragingTrainingWorker.java

示例10: testLoadNormalizersFile

import org.deeplearning4j.nn.api.Model; //导入依赖的package包/类
/**
 * Writes a network plus a fitted min-max normalizer to a temp file, then checks that
 * {@code ModelGuesser} loads back an equal model and an equal normalizer when given the
 * file path.
 *
 * @throws Exception if the temp file cannot be created or any save/load step fails
 */
@Test
public void testLoadNormalizersFile() throws Exception {
    MultiLayerNetwork net = getNetwork();

    File tempFile = File.createTempFile("tsfs", "fdfsdf");
    tempFile.deleteOnExit();

    ModelSerializer.writeModel(net, tempFile, true);

    NormalizerMinMaxScaler normalizer = new NormalizerMinMaxScaler(0, 1);
    normalizer.fit(new DataSet(Nd4j.rand(new int[] {2, 2}), Nd4j.rand(new int[] {2, 2})));
    ModelSerializer.addNormalizerToModel(tempFile, normalizer);
    Model model = ModelGuesser.loadModelGuess(tempFile.getAbsolutePath());
    Normalizer<?> normalizer1 = ModelGuesser.loadNormalizer(tempFile.getAbsolutePath());
    // Expected value first, loaded value second, so a failure message labels them correctly
    assertEquals(net, model);
    assertEquals(normalizer, normalizer1);

}
 
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:19,代码来源:ModelGuesserTest.java

示例11: testLoadNormalizersInputStream

import org.deeplearning4j.nn.api.Model; //导入依赖的package包/类
/**
 * Writes a network plus a fitted min-max normalizer to a temp file, then checks that
 * {@code ModelGuesser} loads back an equal model from the file path and an equal
 * normalizer from an input stream over the same file.
 *
 * @throws Exception if the temp file cannot be created or any save/load step fails
 */
@Test
public void testLoadNormalizersInputStream() throws Exception {
    MultiLayerNetwork net = getNetwork();

    File tempFile = File.createTempFile("tsfs", "fdfsdf");
    tempFile.deleteOnExit();

    ModelSerializer.writeModel(net, tempFile, true);

    NormalizerMinMaxScaler normalizer = new NormalizerMinMaxScaler(0, 1);
    normalizer.fit(new DataSet(Nd4j.rand(new int[] {2, 2}), Nd4j.rand(new int[] {2, 2})));
    ModelSerializer.addNormalizerToModel(tempFile, normalizer);
    Model model = ModelGuesser.loadModelGuess(tempFile.getAbsolutePath());
    try (InputStream inputStream = new FileInputStream(tempFile)) {
        Normalizer<?> normalizer1 = ModelGuesser.loadNormalizer(inputStream);
        // Expected value first, loaded value second, so a failure message labels them correctly
        assertEquals(net, model);
        assertEquals(normalizer, normalizer1);
    }

}
 
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:21,代码来源:ModelGuesserTest.java

示例12: iterationDone

import org.deeplearning4j.nn.api.Model; //导入依赖的package包/类
/**
 * Every 100th iteration, logs the model's score and appends it to the on-screen logging area.
 *
 * <p>The iteration filter runs before anything is handed to the UI thread, so the other
 * 99 iterations no longer post a do-nothing task. The score is read on the calling
 * thread, so the reported value is the one the model held when this callback fired
 * rather than whenever the UI thread gets round to running the task.
 *
 * @param model     the model whose score is reported
 * @param iteration the iteration number; only multiples of 100 are reported
 */
@Override
public void iterationDone(final Model model, final int iteration) {
    if (iteration % 100 != 0) {
        return;
    }

    final double result = model.score();
    final String message = "\nScore at iteration " + iteration + " is " + result;

    // Only the view update and its accompanying log line are handed to the UI thread
    runOnUiThread(new Runnable() {
        @Override
        public void run() {
            Log.d(TAG, message);

            loggingArea.append(message);
        }
    });
}
 
开发者ID:mccorby,项目名称:FederatedAndroidTrainer,代码行数:16,代码来源:MainActivity.java

示例13: iterationDone

import org.deeplearning4j.nn.api.Model; //导入依赖的package包/类
/**
 * Every {@code m_printIterations} calls, prints the internal iteration counter and the
 * model's score to the progress bar.
 *
 * @param model     the model whose score is printed
 * @param iteration iteration number supplied by the caller (unused; an internal counter is kept)
 */
@Override
public void iterationDone(Model model, int iteration) {
    // A non-positive reporting interval is coerced to "every iteration"
    if (m_printIterations <= 0) {
        m_printIterations = 1;
    }

    final boolean reportNow = m_iterCount % m_printIterations == 0;
    if (reportNow) {
        invoke();
        final double score = model.score();
        m_progressBar.printProgress("Iteration: " + m_iterCount + ", Score: " + score);
    }

    m_iterCount++;
}
 
开发者ID:braeunlich,项目名称:anagnostes,代码行数:12,代码来源:TrainProgressIterationListener.java

示例14: iterationDone

import org.deeplearning4j.nn.api.Model; //导入依赖的package包/类
/**
 * Advances the internal iteration counter and, whenever it is a multiple of the
 * configured listener period, logs the model's score and forwards it to the display.
 *
 * @param model     the model whose score is reported
 * @param iteration iteration number supplied by the caller (unused; an internal counter is kept)
 */
@Override
public void iterationDone (Model model,
                           int iteration)
{
    iterCount++;

    final boolean onPeriod = (iterCount % constants.listenerPeriod.getValue()) == 0;
    if (!onPeriod) {
        return;
    }

    invoke();

    final double score = model.score();
    final int count = (int) iterCount;
    final String line = String.format("Score at iteration %d is %.5f", count, score);
    logger.info(line);
    display(epoch, count, score);
}
 
开发者ID:Audiveris,项目名称:audiveris,代码行数:16,代码来源:TrainingPanel.java

示例15: iterationDone

import org.deeplearning4j.nn.api.Model; //导入依赖的package包/类
/**
 * Every {@code printIterations} calls, saves the model to {@code modelSavePath}.
 *
 * @param model the model to save; cast to {@code MultiLayerNetwork}
 * @param i     iteration number supplied by the caller (unused; an internal counter is kept)
 */
@Override
public void iterationDone(Model model, int i) {
    // A non-positive save interval is coerced to "every iteration"
    if (printIterations <= 0) {
        printIterations = 1;
    }

    final boolean saveNow = iterCount % printIterations == 0;
    if (saveNow) {
        MultiLayerNetwork network = (MultiLayerNetwork) model;
        saveModel(network, this.modelSavePath);
    }

    iterCount++;
}
 
开发者ID:jpatanooga,项目名称:strata-2016-nyc-dl4j-rnn,代码行数:10,代码来源:ModelSaver.java


注:本文中的org.deeplearning4j.nn.api.Model类示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。