当前位置: 首页>>代码示例>>Java>>正文


Java StatsStorage类代码示例

本文整理汇总了 Java 中 org.deeplearning4j.api.storage.StatsStorage 类的典型用法代码示例。如果您想了解 StatsStorage 类的具体用法、使用方式或实际示例，下面精选的代码示例或许能为您提供帮助。


StatsStorage 类属于 org.deeplearning4j.api.storage 包，下文共展示了该类的 15 个代码示例，这些示例默认按受欢迎程度排序。您可以为喜欢或觉得有用的代码点赞，您的评价将有助于系统推荐更好的 Java 代码示例。

示例1: getDefaultSession

import org.deeplearning4j.api.storage.StatsStorage; //导入依赖的package包/类
/**
 * Selects a default session when none has been chosen yet.
 * <p>
 * Every known session that has at least one {@code StatsListener.TYPE_ID} static-info record is a candidate.
 * The candidate whose first static-info record carries the largest timestamp becomes
 * {@code currentSessionID}. Ties keep the candidate seen first. If a session is already selected, or no
 * candidate exists, nothing changes.
 */
private void getDefaultSession() {
    if (currentSessionID != null) {
        return; // a session has already been selected
    }

    String bestSession = null;
    long bestTimestamp = Long.MIN_VALUE;
    for (Map.Entry<String, StatsStorage> e : knownSessionIDs.entrySet()) {
        String candidate = e.getKey();
        List<Persistable> infos = e.getValue().getAllStaticInfos(candidate, StatsListener.TYPE_ID);
        boolean hasInfo = infos != null && !infos.isEmpty();
        if (hasInfo) {
            long ts = infos.get(0).getTimeStamp();
            if (ts > bestTimestamp) {
                bestTimestamp = ts;
                bestSession = candidate;
            }
        }
    }

    if (bestSession != null) {
        currentSessionID = bestSession;
    }
}
 
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:23,代码来源:TrainModule.java

示例2: getModelGraph

import org.deeplearning4j.api.storage.StatsStorage; //导入依赖的package包/类
/**
 * Returns the model graph of the currently selected session as JSON.
 * <p>
 * Responds with an empty OK result in any of these cases:
 * <ul>
 *   <li>no session is selected;</li>
 *   <li>the selected session has no entry in {@code knownSessionIDs};</li>
 *   <li>the session has no {@code StatsListener.TYPE_ID} static info;</li>
 *   <li>{@code getGraphInfo()} returns null.</li>
 * </ul>
 */
private Result getModelGraph() {
    // Read the selection once so every check below refers to the same session
    String sid = currentSessionID;
    // The session may no longer be backed by a storage; guard instead of dereferencing null
    StatsStorage ss = (sid == null ? null : knownSessionIDs.get(sid));
    if (ss == null) {
        return ok();
    }

    List<Persistable> allStatic = ss.getAllStaticInfos(sid, StatsListener.TYPE_ID);
    if (allStatic == null || allStatic.isEmpty()) {
        return ok();
    }

    TrainModuleUtils.GraphInfo gi = getGraphInfo();
    if (gi == null) {
        return ok();
    }
    return ok(Json.toJson(gi));
}
 
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:18,代码来源:TrainModule.java

示例3: attach

import org.deeplearning4j.api.storage.StatsStorage; //导入依赖的package包/类
/**
 * Attaches a {@link StatsStorage} to the UI.
 * <p>
 * The steps are:
 * <ol>
 *   <li>create a {@code QueueStatsStorageListener} bound to {@code eventQueue};</li>
 *   <li>remember the (storage, listener) pair;</li>
 *   <li>register the listener with the storage;</li>
 *   <li>notify every UI module.</li>
 * </ol>
 * Attaching a storage that is already attached is a no-op.
 *
 * @param statsStorage the storage to attach; must not be null
 * @throws IllegalArgumentException if {@code statsStorage} is null
 */
@Override
public synchronized void attach(StatsStorage statsStorage) {
    if (statsStorage == null) {
        throw new IllegalArgumentException("StatsStorage cannot be null");
    }

    if (!statsStorageInstances.contains(statsStorage)) {
        StatsStorageListener queueListener = new QueueStatsStorageListener(eventQueue);
        listeners.add(new Pair<>(statsStorage, queueListener));
        statsStorage.registerStatsStorageListener(queueListener);
        statsStorageInstances.add(statsStorage);

        for (UIModule module : uiModules) {
            module.onAttach(statsStorage);
        }

        log.info("StatsStorage instance attached to UI: {}", statsStorage);
    }
}
 
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:18,代码来源:PlayUIServer.java

示例4: detach

import org.deeplearning4j.api.storage.StatsStorage; //导入依赖的package包/类
/**
 * Detaches a previously attached {@link StatsStorage} from the UI.
 * <p>
 * Every listener registered for this exact instance (matched by identity) is deregistered and removed. The
 * instance is forgotten, and all UI modules are notified. Detaching a storage that is not attached is a
 * no-op.
 *
 * @param statsStorage the storage to detach; must not be null
 * @throws IllegalArgumentException if {@code statsStorage} is null
 */
@Override
public synchronized void detach(StatsStorage statsStorage) {
    if (statsStorage == null)
        throw new IllegalArgumentException("StatsStorage cannot be null");
    if (!statsStorageInstances.contains(statsStorage))
        return; //No op
    boolean found = false;
    for (Iterator<Pair<StatsStorage, StatsStorageListener>> iterator = listeners.iterator(); iterator.hasNext();) {
        Pair<StatsStorage, StatsStorageListener> p = iterator.next();
        if (p.getFirst() == statsStorage) { //Same object, not equality
            statsStorage.deregisterStatsStorageListener(p.getSecond());
            iterator.remove();
            found = true;
        }
    }
    // Forget the instance. Otherwise a later attach(statsStorage) would hit its contains() check and return
    // early without re-registering a listener, leaving the storage silently disconnected from the UI.
    statsStorageInstances.remove(statsStorage);
    for (UIModule uiModule : uiModules) {
        uiModule.onDetach(statsStorage);
    }
    if (found) {
        log.info("StatsStorage instance detached from UI: {}", statsStorage);
    }
}
 
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:23,代码来源:PlayUIServer.java

示例5: testListenersViaModel

import org.deeplearning4j.api.storage.StatsStorage; //导入依赖的package包/类
/**
 * Runs the shared listener checks on a single-layer MultiLayerNetwork.
 * Then verifies that the attached StatsListener recorded exactly one session with two workers.
 */
@Test
public void testListenersViaModel() {
    TestListener.clearCounts();

    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().list()
                    .layer(0, new OutputLayer.Builder(LossFunctions.LossFunction.MSE).nIn(10).nOut(10)
                                    .activation(Activation.TANH).build())
                    .build();

    MultiLayerNetwork network = new MultiLayerNetwork(conf);
    network.init();

    StatsStorage storage = new InMemoryStatsStorage();
    network.setListeners(new TestListener(), new StatsListener(storage));

    testListenersForModel(network, null);

    assertEquals(1, storage.listSessionIDs().size());
    String onlySession = storage.listSessionIDs().get(0);
    assertEquals(2, storage.listWorkerIDsForSession(onlySession).size());
}
 
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:21,代码来源:TestListeners.java

示例6: testListenersViaModelGraph

import org.deeplearning4j.api.storage.StatsStorage; //导入依赖的package包/类
/**
 * Runs the shared listener checks on a single-layer ComputationGraph.
 * Then verifies that the attached StatsListener recorded exactly one session with two workers.
 */
@Test
public void testListenersViaModelGraph() {
    TestListener.clearCounts();

    ComputationGraphConfiguration graphConf = new NeuralNetConfiguration.Builder()
                    .graphBuilder()
                    .addInputs("in")
                    .addLayer("0", new OutputLayer.Builder(LossFunctions.LossFunction.MSE).nIn(10).nOut(10)
                                    .activation(Activation.TANH).build(), "in")
                    .setOutputs("0")
                    .build();

    ComputationGraph graph = new ComputationGraph(graphConf);
    graph.init();

    StatsStorage storage = new InMemoryStatsStorage();
    graph.setListeners(new TestListener(), new StatsListener(storage));

    testListenersForModel(graph, null);

    assertEquals(1, storage.listSessionIDs().size());
    String onlySession = storage.listSessionIDs().get(0);
    assertEquals(2, storage.listWorkerIDsForSession(onlySession).size());
}
 
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:23,代码来源:TestListeners.java

示例7: onAttach

import org.deeplearning4j.api.storage.StatsStorage; //导入依赖的package包/类
/**
 * Registers every session in the newly attached storage that contains {@code StatsListener.TYPE_ID} data.
 * If no session is selected yet, a default one is then chosen.
 *
 * @param statsStorage the storage that was just attached
 */
@Override
public synchronized void onAttach(StatsStorage statsStorage) {
    for (String sessionID : statsStorage.listSessionIDs()) {
        boolean hasStatsListenerData =
                        statsStorage.listTypeIDsForSession(sessionID).contains(StatsListener.TYPE_ID);
        if (hasStatsListenerData) {
            knownSessionIDs.put(sessionID, statsStorage);
        }
    }

    if (currentSessionID == null) {
        getDefaultSession();
    }
}
 
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:14,代码来源:TrainModule.java

示例8: onDetach

import org.deeplearning4j.api.storage.StatsStorage; //导入依赖的package包/类
/**
 * Forgets every session whose data lives in the given (now detached) storage.
 * <p>
 * If the currently selected session was one of them, the selection is cleared and a new default session is
 * chosen from the remaining storages.
 *
 * @param statsStorage the storage being detached; sessions are matched by identity, not equals()
 */
@Override
public synchronized void onDetach(StatsStorage statsStorage) {
    String selected = currentSessionID;
    boolean selectedRemoved = selected != null && knownSessionIDs.get(selected) == statsStorage;

    // Remove through the values view. Calling knownSessionIDs.remove(...) inside a keySet() loop throws
    // ConcurrentModificationException on fail-fast maps such as HashMap/LinkedHashMap.
    knownSessionIDs.values().removeIf(storage -> storage == statsStorage);

    if (selectedRemoved) {
        // Do not leave the UI pointed at a session whose storage is gone (later lookups would yield null)
        currentSessionID = null;
        getDefaultSession();
    }
}
 
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:9,代码来源:TrainModule.java

示例9: getWorkerIdForIndex

import org.deeplearning4j.api.storage.StatsStorage; //导入依赖的package包/类
/**
 * Maps a numeric worker index to a worker ID for the currently selected session.
 * <p>
 * Indices are assigned lazily. When an unknown index is requested, each worker ID that the session's
 * storage reports (sorted) and that has no index yet receives the next value of a per-session counter.
 * Indices that were already assigned keep their worker ID.
 *
 * @param workerIdx the worker index to resolve
 * @return the worker ID, or null in any of these cases:
 *         <ul>
 *           <li>no session is selected;</li>
 *           <li>the session's storage is no longer known;</li>
 *           <li>the index has not been assigned, e.g. it is too high.</li>
 *         </ul>
 */
private synchronized String getWorkerIdForIndex(int workerIdx) {
    String sid = currentSessionID;
    if (sid == null)
        return null;

    Map<Integer, String> idxToId =
                    workerIdxToName.computeIfAbsent(sid, k -> Collections.synchronizedMap(new HashMap<>()));

    if (idxToId.containsKey(workerIdx)) {
        return idxToId.get(workerIdx);
    }

    //Need to record new worker: get (or create) this session's index counter
    AtomicInteger counter = workerIdxCount.computeIfAbsent(sid, k -> new AtomicInteger(0));

    StatsStorage ss = knownSessionIDs.get(sid);
    if (ss == null) {
        // The session's storage is no longer known (e.g. detached), so no new workers can be discovered.
        // The index was already checked above and is not assigned.
        return null;
    }

    //Get all worker IDs, sorted so that index assignment is deterministic
    List<String> allWorkerIds = new ArrayList<>(ss.listWorkerIDsForSessionAndType(sid, StatsListener.TYPE_ID));
    Collections.sort(allWorkerIds);

    //Ensure all workers have been assigned an index
    for (String s : allWorkerIds) {
        if (idxToId.containsValue(s))
            continue;
        //Unknown worker ID:
        idxToId.put(counter.getAndIncrement(), s);
    }

    //May still return null if index is wrong/too high...
    return idxToId.get(workerIdx);
}
 
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:40,代码来源:TrainModule.java

示例10: enableRemoteListener

import org.deeplearning4j.api.storage.StatsStorage; //导入依赖的package包/类
/**
 * Enables the remote receiver module and routes its incoming data to {@code statsStorage}.
 * When {@code attach} is true and the router is also a {@link StatsStorage}, the storage is also attached
 * to the UI.
 *
 * @param statsStorage the router that will receive remotely posted data
 * @param attach       whether to also attach the router to the UI (only possible if it is a StatsStorage)
 */
@Override
public void enableRemoteListener(StatsStorageRouter statsStorage, boolean attach) {
    remoteReceiverModule.setEnabled(true);
    remoteReceiverModule.setStatsStorage(statsStorage);

    boolean isStorage = statsStorage instanceof StatsStorage;
    if (!attach || !isStorage) {
        return;
    }
    attach((StatsStorage) statsStorage);
}
 
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:9,代码来源:PlayUIServer.java

示例11: testUIMultipleSessions

import org.deeplearning4j.api.storage.StatsStorage; //导入依赖的package包/类
/**
 * Manual test (ignored by default) that trains three networks on Iris.
 * Each network uses its own InMemoryStatsStorage attached to the UI server. The test then sleeps for a
 * long time, presumably so the sessions can be inspected in the UI by hand.
 */
@Test
@Ignore
public void testUIMultipleSessions() throws Exception {
    final int numSessions = 3;
    final int fitRounds = 20;

    for (int sessionNum = 0; sessionNum < numSessions; sessionNum++) {
        StatsStorage storage = new InMemoryStatsStorage();

        UIServer server = UIServer.getInstance();
        server.attach(storage);

        MultiLayerConfiguration netConf = new NeuralNetConfiguration.Builder()
                        .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                        .list()
                        .layer(0, new DenseLayer.Builder().activation(Activation.TANH).nIn(4).nOut(4).build())
                        .layer(1, new OutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MCXENT)
                                        .activation(Activation.SOFTMAX).nIn(4).nOut(3).build())
                        .pretrain(false)
                        .backprop(true)
                        .build();

        MultiLayerNetwork network = new MultiLayerNetwork(netConf);
        network.init();
        network.setListeners(new StatsListener(storage), new ScoreIterationListener(1));

        DataSetIterator irisData = new IrisDataSetIterator(150, 150);

        for (int round = 0; round < fitRounds; round++) {
            network.fit(irisData);
            Thread.sleep(100);
        }
    }

    // Long pause: presumably keeps the JVM (and UI server) alive for manual inspection
    Thread.sleep(1000000);
}
 
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:34,代码来源:TestPlayUI.java

示例12: testUICompGraph

import org.deeplearning4j.api.storage.StatsStorage; //导入依赖的package包/类
/**
 * Manual test (ignored by default) that trains a two-layer ComputationGraph on Iris.
 * An InMemoryStatsStorage attached to the UI server records the stats. The test then sleeps,
 * presumably so the UI can be inspected by hand.
 */
@Test
@Ignore
public void testUICompGraph() throws Exception {
    StatsStorage storage = new InMemoryStatsStorage();

    UIServer server = UIServer.getInstance();
    server.attach(storage);

    ComputationGraphConfiguration graphConf = new NeuralNetConfiguration.Builder()
                    .graphBuilder()
                    .addInputs("in")
                    .addLayer("L0", new DenseLayer.Builder().activation(Activation.TANH).nIn(4).nOut(4).build(),
                                    "in")
                    .addLayer("L1", new OutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MCXENT)
                                    .activation(Activation.SOFTMAX).nIn(4).nOut(3).build(), "L0")
                    .pretrain(false)
                    .backprop(true)
                    .setOutputs("L1")
                    .build();

    ComputationGraph graph = new ComputationGraph(graphConf);
    graph.init();
    graph.setListeners(new StatsListener(storage), new ScoreIterationListener(1));

    DataSetIterator irisData = new IrisDataSetIterator(150, 150);

    final int fitRounds = 100;
    for (int round = 0; round < fitRounds; round++) {
        graph.fit(irisData);
        Thread.sleep(100);
    }

    // Long pause: presumably keeps the UI server alive for manual inspection
    Thread.sleep(100000);
}
 
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:31,代码来源:TestPlayUI.java

示例13: ConvolutionalIterationListener

import org.deeplearning4j.api.storage.StatsStorage; //导入依赖的package包/类
public ConvolutionalIterationListener(StatsStorageRouter ssr, int iterations, boolean openBrowser, String sessionID,
                String workerID) {
    this.ssr = ssr;
    if (sessionID == null) {
        //TODO handle syncing session IDs across different listeners in the same model...
        this.sessionID = UUID.randomUUID().toString();
    } else {
        this.sessionID = sessionID;
    }
    if (workerID == null) {
        this.workerID = UIDProvider.getJVMUID() + "_" + Thread.currentThread().getId();
    } else {
        this.workerID = workerID;
    }

    String subPath = "activations";

    this.freq = iterations;
    this.openBrowser = openBrowser;
    path = "http://localhost:" + UIServer.getInstance().getPort() + "/" + subPath;

    if (openBrowser && ssr instanceof StatsStorage) {
        UIServer.getInstance().attach((StatsStorage) ssr);
    }

    System.out.println("ConvolutionIterationListener path: " + path);
}
 
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:28,代码来源:ConvolutionalIterationListener.java

示例14: testParallelStatsListenerCompatibility

import org.deeplearning4j.api.storage.StatsStorage; //导入依赖的package包/类
/**
 * Manual test (ignored by default) checking that StatsListener can report from early-stopping parallel
 * training. It expects training to stop because of the epoch termination condition.
 */
@Test
@Ignore //To be run manually
public void testParallelStatsListenerCompatibility() throws Exception {
    UIServer server = UIServer.getInstance();

    MultiLayerConfiguration netConf = new NeuralNetConfiguration.Builder()
                    .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                    .updater(new Sgd())
                    .weightInit(WeightInit.XAVIER)
                    .list()
                    .layer(0, new DenseLayer.Builder().nIn(4).nOut(3).build())
                    .layer(1, new OutputLayer.Builder().nIn(3).nOut(3)
                                    .lossFunction(LossFunctions.LossFunction.MCXENT).build())
                    .pretrain(false)
                    .backprop(true)
                    .build();
    MultiLayerNetwork network = new MultiLayerNetwork(netConf);

    // The UI must be able to report results from parallel training; StatsListener could fail
    // if certain model properties are not set during parallel training.
    StatsStorage storage = new InMemoryStatsStorage();
    network.setListeners(new StatsListener(storage));
    server.attach(storage);

    DataSetIterator irisData = new IrisDataSetIterator(50, 500);
    EarlyStoppingModelSaver<MultiLayerNetwork> modelSaver = new InMemoryModelSaver<>();
    EarlyStoppingConfiguration<MultiLayerNetwork> stoppingConf =
                    new EarlyStoppingConfiguration.Builder<MultiLayerNetwork>()
                                    .epochTerminationConditions(new MaxEpochsTerminationCondition(500))
                                    .scoreCalculator(new DataSetLossCalculator(irisData, true))
                                    .evaluateEveryNEpochs(2)
                                    .modelSaver(modelSaver)
                                    .build();

    IEarlyStoppingTrainer<MultiLayerNetwork> parallelTrainer =
                    new EarlyStoppingParallelTrainer<>(stoppingConf, network, irisData, null, 3, 6, 2);

    EarlyStoppingResult<MultiLayerNetwork> outcome = parallelTrainer.fit();
    System.out.println(outcome);

    assertEquals(EarlyStoppingResult.TerminationReason.EpochTerminationCondition, outcome.getTerminationReason());
}
 
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:37,代码来源:TestParallelEarlyStoppingUI.java

示例15: main

import org.deeplearning4j.api.storage.StatsStorage; //导入依赖的package包/类
/**
 * Starts the UI server with a fresh in-memory stats storage attached.
 * It then calls the no-argument {@code enableRemoteListener()} overload.
 */
public static void main(String[] args) {
    UIServer ui = UIServer.getInstance();
    StatsStorage inMemory = new InMemoryStatsStorage();
    ui.attach(inMemory);
    // No-arg overload; which storage remote data is routed to is decided inside UIServer (not visible here)
    ui.enableRemoteListener();
}
 
开发者ID:buybrain,项目名称:docker-dl4j-ui,代码行数:7,代码来源:Server.java


注：本文中的 org.deeplearning4j.api.storage.StatsStorage 类示例由纯净天空整理自 GitHub/MSDocs 等开源代码及文档管理平台，相关代码片段筛选自各路开发者贡献的开源项目，源码版权归原作者所有，传播和使用请参考对应项目的许可证（License）；未经允许，请勿转载。