当前位置: 首页>>代码示例>>Java>>正文


Java StatsListener类代码示例

本文整理汇总了Java中org.deeplearning4j.ui.stats.StatsListener的典型用法代码示例。如果您正苦于以下问题：Java StatsListener类的具体用法是什么？Java StatsListener怎么用？有哪些Java StatsListener的使用示例？那么恭喜您，这里精选的类代码示例或许可以为您提供帮助。


StatsListener类属于org.deeplearning4j.ui.stats包，下文共展示了StatsListener类的15个代码示例，这些示例默认按受欢迎程度排序。您可以为喜欢或者觉得有用的代码点赞，您的评价将有助于系统推荐出更好的Java代码示例。

示例1: reportStorageEvents

import org.deeplearning4j.ui.stats.StatsListener; //导入依赖的package包/类
/**
 * Records storage events that belong to {@link StatsListener} sessions.
 *
 * <p>Events whose type ID is not {@link StatsListener#TYPE_ID} are ignored. For the remaining
 * events, a {@code PostStaticInfo} event registers its session in {@code knownSessionIDs}
 * (mapped to the storage that produced it). Every such event also advances the session's entry in
 * {@code lastUpdateForSession} when its timestamp is newer than the stored one. Afterwards, a
 * default session is selected if none is current yet.
 *
 * @param events the storage events to process
 */
@Override
public synchronized void reportStorageEvents(Collection<StatsStorageEvent> events) {
    for (StatsStorageEvent sse : events) {
        // Only StatsListener events are relevant to this module.
        if (!StatsListener.TYPE_ID.equals(sse.getTypeID())) {
            continue;
        }

        String sessionId = sse.getSessionID();
        if (sse.getEventType() == StatsStorageListener.EventType.PostStaticInfo) {
            knownSessionIDs.put(sessionId, sse.getStatsStorage());
        }

        // Keep the newest timestamp seen for this session.
        // NOTE(review): original author noted this is thread safe because the map is only read elsewhere.
        Long lastUpdate = lastUpdateForSession.get(sessionId);
        if (lastUpdate == null || sse.getTimestamp() > lastUpdate) {
            lastUpdateForSession.put(sessionId, sse.getTimestamp());
        }
    }

    if (currentSessionID == null) {
        getDefaultSession();
    }
}
 
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:22,代码来源:TrainModule.java

示例2: getDefaultSession

import org.deeplearning4j.ui.stats.StatsListener; //导入依赖的package包/类
/**
 * Selects a default session when none is current.
 *
 * <p>Among the known sessions that have {@link StatsListener} static info, this picks the one whose
 * first static-info record has the largest timestamp. {@code currentSessionID} is left unchanged
 * when no session qualifies.
 */
private void getDefaultSession() {
    // A session is already selected: keep it.
    if (currentSessionID != null) {
        return;
    }

    String newestSession = null;
    long newestTimestamp = Long.MIN_VALUE;
    for (Map.Entry<String, StatsStorage> candidate : knownSessionIDs.entrySet()) {
        String candidateId = candidate.getKey();
        List<Persistable> infos = candidate.getValue().getAllStaticInfos(candidateId, StatsListener.TYPE_ID);
        boolean hasInfo = infos != null && !infos.isEmpty();
        if (!hasInfo) {
            continue;
        }

        long firstInfoTime = infos.get(0).getTimeStamp();
        if (firstInfoTime > newestTimestamp) {
            newestTimestamp = firstInfoTime;
            newestSession = candidateId;
        }
    }

    if (newestSession != null) {
        currentSessionID = newestSession;
    }
}
 
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:23,代码来源:TrainModule.java

示例3: getModelGraph

import org.deeplearning4j.ui.stats.StatsListener; //导入依赖的package包/类
/**
 * Returns the model-graph description of the current session as JSON.
 *
 * <p>Responds with an empty {@code ok()} in any of these cases:
 * <ul>
 *   <li>there is no current session;</li>
 *   <li>the session has no registered storage;</li>
 *   <li>the session has no {@link StatsListener} static info;</li>
 *   <li>{@code getGraphInfo()} returns null.</li>
 * </ul>
 *
 * @return the HTTP result, carrying a JSON body only when graph info is available
 */
private Result getModelGraph() {
    String sessionId = currentSessionID;
    // The storage lookup can miss (no session, or session not registered); treat that as "no data"
    // instead of dereferencing a null StatsStorage.
    StatsStorage ss = (sessionId == null ? null : knownSessionIDs.get(sessionId));
    if (ss == null) {
        return ok();
    }

    // getAllStaticInfos may return null; treat that the same as an empty list.
    List<Persistable> allStatic = ss.getAllStaticInfos(sessionId, StatsListener.TYPE_ID);
    if (allStatic == null || allStatic.isEmpty()) {
        return ok();
    }

    TrainModuleUtils.GraphInfo gi = getGraphInfo();
    return gi == null ? ok() : ok(Json.toJson(gi));
}
 
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:18,代码来源:TrainModule.java

示例4: testRemoteFull

import org.deeplearning4j.ui.stats.StatsListener; //导入依赖的package包/类
/**
 * Manually-run test: trains a small two-layer MLP on Iris while a {@link StatsListener} routes
 * statistics to a remote UI at {@code http://localhost:9000}.
 *
 * <p>Use this together with {@code startRemoteUI()}. The test is marked {@code @Ignore}, so it
 * does not run automatically.
 */
@Test
@Ignore
public void testRemoteFull() throws Exception {
    final int fitRounds = 500;
    final long pauseMillis = 100;

    DenseLayer hidden = new DenseLayer.Builder().activation(Activation.TANH).nIn(4).nOut(4).build();
    OutputLayer output = new OutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MCXENT)
                    .activation(Activation.SOFTMAX).nIn(4).nOut(3).build();
    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                    .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                    .list()
                    .layer(0, hidden)
                    .layer(1, output)
                    .pretrain(false)
                    .backprop(true)
                    .build();

    MultiLayerNetwork network = new MultiLayerNetwork(conf);
    network.init();

    // Stats are posted over HTTP to the remote UI rather than kept locally.
    StatsStorageRouter remoteRouter = new RemoteUIStatsStorageRouter("http://localhost:9000");
    network.setListeners(new StatsListener(remoteRouter), new ScoreIterationListener(1));

    DataSetIterator irisData = new IrisDataSetIterator(150, 150);

    int round = 0;
    while (round < fitRounds) {
        network.fit(irisData);
        Thread.sleep(pauseMillis);
        round++;
    }
}
 
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:27,代码来源:TestRemoteReceiver.java

示例5: testListenersViaModel

import org.deeplearning4j.ui.stats.StatsListener; //导入依赖的package包/类
/**
 * Sets a {@code TestListener} and a {@link StatsListener} (backed by in-memory storage) on a
 * one-layer {@link MultiLayerNetwork}, then runs {@code testListenersForModel} on it.
 *
 * <p>Afterwards it checks that the stats storage holds exactly one session with two worker IDs.
 */
@Test
public void testListenersViaModel() {
    TestListener.clearCounts();

    OutputLayer outLayer = new OutputLayer.Builder(LossFunctions.LossFunction.MSE).nIn(10).nOut(10)
                    .activation(Activation.TANH).build();
    MultiLayerConfiguration netConf = new NeuralNetConfiguration.Builder()
                    .list()
                    .layer(0, outLayer)
                    .build();

    MultiLayerNetwork network = new MultiLayerNetwork(netConf);
    network.init();

    StatsStorage storage = new InMemoryStatsStorage();
    network.setListeners(new TestListener(), new StatsListener(storage));

    testListenersForModel(network, null);

    assertEquals(1, storage.listSessionIDs().size());
    String onlySession = storage.listSessionIDs().get(0);
    assertEquals(2, storage.listWorkerIDsForSession(onlySession).size());
}
 
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:21,代码来源:TestListeners.java

示例6: testListenersViaModelGraph

import org.deeplearning4j.ui.stats.StatsListener; //导入依赖的package包/类
/**
 * Sets a {@code TestListener} and a {@link StatsListener} (backed by in-memory storage) on a
 * single-layer {@link ComputationGraph}, then runs {@code testListenersForModel} on it.
 *
 * <p>Afterwards it checks that the stats storage holds exactly one session with two worker IDs.
 */
@Test
public void testListenersViaModelGraph() {
    TestListener.clearCounts();

    OutputLayer outLayer = new OutputLayer.Builder(LossFunctions.LossFunction.MSE).nIn(10).nOut(10)
                    .activation(Activation.TANH).build();
    ComputationGraphConfiguration graphConf = new NeuralNetConfiguration.Builder()
                    .graphBuilder()
                    .addInputs("in")
                    .addLayer("0", outLayer, "in")
                    .setOutputs("0")
                    .build();

    ComputationGraph graph = new ComputationGraph(graphConf);
    graph.init();

    StatsStorage storage = new InMemoryStatsStorage();
    graph.setListeners(new TestListener(), new StatsListener(storage));

    testListenersForModel(graph, null);

    assertEquals(1, storage.listSessionIDs().size());
    String onlySession = storage.listSessionIDs().get(0);
    assertEquals(2, storage.listWorkerIDsForSession(onlySession).size());
}
 
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:23,代码来源:TestListeners.java

示例7: onAttach

import org.deeplearning4j.ui.stats.StatsListener; //导入依赖的package包/类
/**
 * Registers every session of a newly attached storage that contains {@link StatsListener} data.
 * Then, if no session is current, selects a default one.
 *
 * @param statsStorage the storage being attached
 */
@Override
public synchronized void onAttach(StatsStorage statsStorage) {
    for (String session : statsStorage.listSessionIDs()) {
        boolean hasStatsListenerData =
                statsStorage.listTypeIDsForSession(session).contains(StatsListener.TYPE_ID);
        if (hasStatsListenerData) {
            knownSessionIDs.put(session, statsStorage);
        }
    }

    if (currentSessionID == null) {
        getDefaultSession();
    }
}
 
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:14,代码来源:TrainModule.java

示例8: getWorkerIdForIndex

import org.deeplearning4j.ui.stats.StatsListener; //导入依赖的package包/类
/**
 * Resolves a worker index (as used by the UI) to a worker ID of the current session.
 *
 * <p>Indices are assigned lazily. On a cache miss, each worker ID that the session's storage lists
 * for {@link StatsListener#TYPE_ID}, and that has no index yet, receives the next value of the
 * per-session counter. IDs are processed in sorted order.
 *
 * @param workerIdx the worker index to resolve
 * @return the worker ID, or {@code null} in any of these cases: there is no current session, the
 *         session has no registered storage, or no worker has the given index
 */
private synchronized String getWorkerIdForIndex(int workerIdx) {
    String sid = currentSessionID;
    if (sid == null)
        return null;

    Map<Integer, String> idxToId =
            workerIdxToName.computeIfAbsent(sid, k -> Collections.synchronizedMap(new HashMap<>()));
    if (idxToId.containsKey(workerIdx)) {
        return idxToId.get(workerIdx);
    }

    // The session's storage may be absent from knownSessionIDs; resolve to null instead of
    // throwing a NullPointerException below.
    StatsStorage ss = knownSessionIDs.get(sid);
    if (ss == null)
        return null;

    // Need to record new worker(s): get (or create) this session's index counter
    AtomicInteger counter = workerIdxCount.computeIfAbsent(sid, k -> new AtomicInteger(0));

    // Sorting makes index assignment order deterministic for a given set of worker IDs
    List<String> allWorkerIds = new ArrayList<>(ss.listWorkerIDsForSessionAndType(sid, StatsListener.TYPE_ID));
    Collections.sort(allWorkerIds);

    // Ensure all workers have been assigned an index
    for (String s : allWorkerIds) {
        if (!idxToId.containsValue(s)) {
            idxToId.put(counter.getAndIncrement(), s);
        }
    }

    // May still return null if the index is wrong/too high
    return idxToId.get(workerIdx);
}
 
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:40,代码来源:TrainModule.java

示例9: testUIMultipleSessions

import org.deeplearning4j.ui.stats.StatsListener; //导入依赖的package包/类
/**
 * Manually-run test: creates three independent in-memory stats sessions.
 *
 * <p>Each session is attached to the UI server and gets its own short Iris training run. The test
 * then sleeps for a long time, presumably so the sessions can be inspected in the UI. It is marked
 * {@code @Ignore}, so it does not run automatically.
 */
@Test
@Ignore
public void testUIMultipleSessions() throws Exception {
    final int sessionCount = 3;
    final int fitsPerSession = 20;

    for (int s = 0; s < sessionCount; s++) {
        StatsStorage storage = new InMemoryStatsStorage();
        UIServer.getInstance().attach(storage);

        DenseLayer hidden = new DenseLayer.Builder().activation(Activation.TANH).nIn(4).nOut(4).build();
        OutputLayer output = new OutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MCXENT)
                        .activation(Activation.SOFTMAX).nIn(4).nOut(3).build();
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                        .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                        .list()
                        .layer(0, hidden)
                        .layer(1, output)
                        .pretrain(false)
                        .backprop(true)
                        .build();

        MultiLayerNetwork network = new MultiLayerNetwork(conf);
        network.init();
        network.setListeners(new StatsListener(storage), new ScoreIterationListener(1));

        DataSetIterator irisData = new IrisDataSetIterator(150, 150);
        int fit = 0;
        while (fit < fitsPerSession) {
            network.fit(irisData);
            Thread.sleep(100);
            fit++;
        }
    }

    Thread.sleep(1000000);
}
 
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:34,代码来源:TestPlayUI.java

示例10: testUICompGraph

import org.deeplearning4j.ui.stats.StatsListener; //导入依赖的package包/类
/**
 * Manually-run test: trains a two-layer {@link ComputationGraph} on Iris while a
 * {@link StatsListener} reports into in-memory storage attached to the UI server.
 *
 * <p>After training it sleeps for a long time, presumably to keep the UI available for manual
 * inspection. It is marked {@code @Ignore}, so it does not run automatically.
 */
@Test
@Ignore
public void testUICompGraph() throws Exception {
    StatsStorage storage = new InMemoryStatsStorage();
    UIServer.getInstance().attach(storage);

    DenseLayer hidden = new DenseLayer.Builder().activation(Activation.TANH).nIn(4).nOut(4).build();
    OutputLayer output = new OutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MCXENT)
                    .activation(Activation.SOFTMAX).nIn(4).nOut(3).build();
    ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
                    .graphBuilder()
                    .addInputs("in")
                    .addLayer("L0", hidden, "in")
                    .addLayer("L1", output, "L0")
                    .pretrain(false)
                    .backprop(true)
                    .setOutputs("L1")
                    .build();

    ComputationGraph graph = new ComputationGraph(conf);
    graph.init();
    graph.setListeners(new StatsListener(storage), new ScoreIterationListener(1));

    DataSetIterator irisData = new IrisDataSetIterator(150, 150);
    int round = 0;
    while (round < 100) {
        graph.fit(irisData);
        Thread.sleep(100);
        round++;
    }

    Thread.sleep(100000);
}
 
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:31,代码来源:TestPlayUI.java

示例11: testParallelStatsListenerCompatibility

import org.deeplearning4j.ui.stats.StatsListener; //导入依赖的package包/类
/**
 * Manually-run test: checks that a {@link StatsListener} attached to a network does not break
 * early-stopping parallel training.
 *
 * <p>It also checks that training terminates through the epoch termination condition. The test is
 * marked {@code @Ignore}, so it does not run automatically.
 */
@Test
@Ignore //To be run manually
public void testParallelStatsListenerCompatibility() throws Exception {
    UIServer ui = UIServer.getInstance();

    DenseLayer hidden = new DenseLayer.Builder().nIn(4).nOut(3).build();
    OutputLayer output = new OutputLayer.Builder().nIn(3).nOut(3)
                    .lossFunction(LossFunctions.LossFunction.MCXENT).build();
    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                    .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                    .updater(new Sgd())
                    .weightInit(WeightInit.XAVIER)
                    .list()
                    .layer(0, hidden)
                    .layer(1, output)
                    .pretrain(false)
                    .backprop(true)
                    .build();
    MultiLayerNetwork network = new MultiLayerNetwork(conf);

    // The UI must be able to report results from parallel training; StatsListener could fail
    // if certain model properties are not set during parallel training.
    StatsStorage storage = new InMemoryStatsStorage();
    network.setListeners(new StatsListener(storage));
    ui.attach(storage);

    DataSetIterator irisData = new IrisDataSetIterator(50, 500);
    EarlyStoppingModelSaver<MultiLayerNetwork> modelSaver = new InMemoryModelSaver<>();
    EarlyStoppingConfiguration.Builder<MultiLayerNetwork> esBuilder =
                    new EarlyStoppingConfiguration.Builder<MultiLayerNetwork>();
    EarlyStoppingConfiguration<MultiLayerNetwork> esConf = esBuilder
                    .epochTerminationConditions(new MaxEpochsTerminationCondition(500))
                    .scoreCalculator(new DataSetLossCalculator(irisData, true))
                    .evaluateEveryNEpochs(2)
                    .modelSaver(modelSaver)
                    .build();

    IEarlyStoppingTrainer<MultiLayerNetwork> parallelTrainer =
                    new EarlyStoppingParallelTrainer<>(esConf, network, irisData, null, 3, 6, 2);

    EarlyStoppingResult<MultiLayerNetwork> outcome = parallelTrainer.fit();
    System.out.println(outcome);

    assertEquals(EarlyStoppingResult.TerminationReason.EpochTerminationCondition, outcome.getTerminationReason());
}
 
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:37,代码来源:TestParallelEarlyStoppingUI.java

示例12: addStatsListener

import org.deeplearning4j.ui.stats.StatsListener; //导入依赖的package包/类
/**
 * Sets, as the classifier's iteration listener, a {@link StatsListener} that reports to the given
 * {@code FileStatsStorage}.
 *
 * @param clf          classifier that receives the listener
 * @param statsStorage storage the listener reports to
 */
public static void addStatsListener(Dl4jMlpClassifier clf, FileStatsStorage statsStorage) {
  StatsListener listener = new StatsListener(statsStorage);
  clf.setIterationListener(listener);
}
 
开发者ID:Waikato,项目名称:wekaDeeplearning4j,代码行数:4,代码来源:TestUtil.java

示例13: LSTMTrainer

import org.deeplearning4j.ui.stats.StatsListener; //导入依赖的package包/类
/**
 * Builds the character-level LSTM network and its training-data iterator.
 *
 * <p>Hyper-parameters are fixed in this constructor. The network has three {@code GravesLSTM}
 * layers followed by an {@code RnnOutputLayer} and uses truncated BPTT. A {@link StatsListener}
 * backed by an in-memory {@link StatsStorage} is attached, and that storage is registered with the
 * {@link UIServer}. When {@code ExecutionParameters.verbose} is set, per-layer and total parameter
 * counts are printed.
 *
 * @param trainingSet text file containing several ABC music files, passed to {@code ABCIterator}
 * @param seed        seed for the iterator's {@link Random} and for the network configuration
 * @throws IOException propagated from constructing the training-set iterator
 *                     (presumably when the file cannot be read)
 */
public LSTMTrainer(String trainingSet, int seed) throws IOException {
    // Hyper-parameters; trailing comments record original/alternative values
    lstmLayerSize_ = 200; // original: 200
    batchSize_ = 32; // original: 32
    truncatedBackPropThroughTimeLength_ = 50;
    nbEpochs_ = 100;
    learningRate_ = 0.04; // original: 0.1; noted as best: 0.05 (3 epochs)
    generateSamplesEveryNMinibatches_ = 200;
    generationInitialization_ = "X";
    seed_ = seed;
    random_ = new Random(seed);
    output_ = null;

    // StandardCharsets.US_ASCII is the charset Charset.forName("ASCII") resolves to, without a lookup
    trainingSetIterator_ = new ABCIterator(trainingSet, java.nio.charset.StandardCharsets.US_ASCII,
            batchSize_, random_);
    charToInt_ = trainingSetIterator_.getCharToInt();
    intToChar_ = trainingSetIterator_.getIntToChar();
    exampleLength_ = trainingSetIterator_.getExampleLength();

    int nOut = trainingSetIterator_.totalOutcomes();

    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).iterations(1)
            .learningRate(learningRate_)
            .rmsDecay(0.95) // original: 0.95
            .seed(seed_)
            .regularization(true) // original: true
            .l2(0.001)
            .weightInit(WeightInit.XAVIER)
            .updater(Updater.RMSPROP)
            .list()
            .layer(0, new GravesLSTM.Builder().nIn(trainingSetIterator_.inputColumns()).nOut(lstmLayerSize_)
                    .activation("tanh").build())
            .layer(1, new GravesLSTM.Builder().nIn(lstmLayerSize_).nOut(lstmLayerSize_)
                    .activation("tanh").build())
            .layer(2, new GravesLSTM.Builder().nIn(lstmLayerSize_).nOut(lstmLayerSize_)
                    .activation("tanh").build())
            .layer(3, new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation("softmax")
                    .nIn(lstmLayerSize_).nOut(nOut).build())
            .backpropType(BackpropType.TruncatedBPTT)
                .tBPTTForwardLength(truncatedBackPropThroughTimeLength_)
                .tBPTTBackwardLength(truncatedBackPropThroughTimeLength_)
            .pretrain(false).backprop(true)
            .build();

    lstmNet_ = new MultiLayerNetwork(conf);
    lstmNet_.init();

    // Report training statistics to the DL4J UI through an in-memory stats storage
    UIServer uiServer = UIServer.getInstance();
    StatsStorage statsStorage = new InMemoryStatsStorage();
    uiServer.attach(statsStorage);
    lstmNet_.setListeners(new StatsListener(statsStorage));

    if (ExecutionParameters.verbose) {
        Layer[] layers = lstmNet_.getLayers();
        // long accumulator: cannot overflow for large nets and compiles whether numParams() returns int or long
        long totalNumParams = 0;
        for (int i = 0; i < layers.length; i++) {
            long nParams = layers[i].numParams();
            System.out.println("Number of parameters in layer " + i + ": " + nParams);
            totalNumParams += nParams;
        }
        System.out.println("Total number of network parameters: " + totalNumParams);
    }
}
 
开发者ID:paveyry,项目名称:LyreLand,代码行数:69,代码来源:LSTMTrainer.java

示例14: train

import org.deeplearning4j.ui.stats.StatsListener; //导入依赖的package包/类
/**
 * Trains an Iris classifier configured from command-line options, then saves it with its normalizer.
 *
 * <p>Options read:
 * <ul>
 *   <li>{@code e}: number of epochs;</li>
 *   <li>{@code o}: output model file name;</li>
 *   <li>{@code i}: input CSV, passed to {@code DataIterator.irisCsv}.</li>
 * </ul>
 * Progress goes to the log, and to the DL4J UI through a {@link StatsListener}. The normalizer is
 * saved to files named {@code <model>.norm1} through {@code <model>.norm4}. A failure while saving
 * is logged and does not abort the program.
 *
 * @param c parsed command line
 */
private static void train(CommandLine c) {
    int nEpochs = Integer.parseInt(c.getOptionValue("e"));
    String modelName = c.getOptionValue("o");
    DataIterator<NormalizerStandardize> it = DataIterator.irisCsv(c.getOptionValue("i"));
    RecordReaderDataSetIterator trainData = it.getIterator();
    NormalizerStandardize normalizer = it.getNormalizer();

    log.info("Data Loaded");

    MultiLayerConfiguration conf = net(4, 3);
    MultiLayerNetwork model = new MultiLayerNetwork(conf);

    model.init();

    UIServer uiServer = UIServer.getInstance();

    StatsStorage statsStorage = new InMemoryStatsStorage();
    uiServer.attach(statsStorage);
    model.setListeners(Arrays.asList(new ScoreIterationListener(1), new StatsListener(statsStorage)));

    for (int i = 0; i < nEpochs; i++) {
        log.info("Starting epoch {} of {}", i, nEpochs);

        while (trainData.hasNext()) {
            model.fit(trainData.next());
        }

        log.info("Finished epoch {}", i);
        trainData.reset();
    }

    try {
        ModelSerializer.writeModel(model, modelName, true);

        normalizer.save(
                new File(modelName + ".norm1"),
                new File(modelName + ".norm2"),
                new File(modelName + ".norm3"),
                new File(modelName + ".norm4")
        );
        // Only report success once both the model and the normalizer were written.
        log.info("Model saved to: {}", modelName);
    } catch (IOException e) {
        // Best-effort save: log the failure with its cause instead of claiming success.
        log.error("Failed to save model to: {}", modelName, e);
    }
}
 
开发者ID:wmeddie,项目名称:dl4j-trainer-archetype,代码行数:47,代码来源:Train.java

示例15: getCallbackTypeIDs

import org.deeplearning4j.ui.stats.StatsListener; //导入依赖的package包/类
@Override
public List<String> getCallbackTypeIDs() {
    return Collections.singletonList(StatsListener.TYPE_ID);
}
 
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:5,代码来源:TrainModule.java


注:本文中的org.deeplearning4j.ui.stats.StatsListener类示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。