本文整理汇总了 Java 中 org.deeplearning4j.api.storage.StatsStorage 类的典型用法代码示例。如果您正苦于以下问题:Java StatsStorage 类的具体用法?Java StatsStorage 怎么用?Java StatsStorage 使用的例子?那么恭喜您,这里精选的类代码示例或许可以为您提供帮助。
StatsStorage 类属于 org.deeplearning4j.api.storage 包,在下文中一共展示了 StatsStorage 类的 15 个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的 Java 代码示例。
示例1: getDefaultSession
import org.deeplearning4j.api.storage.StatsStorage; //导入依赖的package包/类
/**
 * Picks a default value for {@code currentSessionID} when none is set.
 *
 * <p>Every entry of {@code knownSessionIDs} is inspected; for each one the first static-info
 * record of type {@link StatsListener#TYPE_ID} supplies a timestamp, and the session with the
 * greatest timestamp wins. Sessions without any static info are ignored. If no session
 * qualifies, {@code currentSessionID} is left untouched.
 */
private void getDefaultSession() {
    if (currentSessionID != null) {
        return;
    }

    String newestSessionId = null;
    long newestTimestamp = Long.MIN_VALUE;
    for (Map.Entry<String, StatsStorage> known : knownSessionIDs.entrySet()) {
        List<Persistable> infos = known.getValue().getAllStaticInfos(known.getKey(), StatsListener.TYPE_ID);
        boolean hasStaticInfo = infos != null && !infos.isEmpty();
        if (hasStaticInfo) {
            // Only the first static-info record is consulted for the session's timestamp.
            long timestamp = infos.get(0).getTimeStamp();
            if (timestamp > newestTimestamp) {
                newestTimestamp = timestamp;
                newestSessionId = known.getKey();
            }
        }
    }

    if (newestSessionId != null) {
        currentSessionID = newestSessionId;
    }
}
示例2: getModelGraph
import org.deeplearning4j.api.storage.StatsStorage; //导入依赖的package包/类
/**
 * Builds the response describing the model graph of the current session.
 *
 * <p>An empty {@code ok()} response is returned when there is no current session, when no
 * storage is registered for it in {@code knownSessionIDs}, when that storage holds no static
 * info of type {@link StatsListener#TYPE_ID}, or when {@code getGraphInfo()} yields
 * {@code null}. Otherwise the graph info is serialized with {@code Json.toJson}.
 *
 * @return an {@code ok} result, with the graph info as JSON body when available
 */
private Result getModelGraph() {
    // Read the field once so the null check and the lookups below use the same value.
    String sid = currentSessionID;
    if (sid == null) {
        return ok();
    }

    // The session may no longer be registered; treat that like "no data"
    // instead of dereferencing a null storage.
    StatsStorage ss = knownSessionIDs.get(sid);
    if (ss == null) {
        return ok();
    }

    List<Persistable> allStatic = ss.getAllStaticInfos(sid, StatsListener.TYPE_ID);
    if (allStatic == null || allStatic.isEmpty()) {
        return ok();
    }

    TrainModuleUtils.GraphInfo gi = getGraphInfo();
    if (gi == null) {
        return ok();
    }
    return ok(Json.toJson(gi));
}
示例3: attach
import org.deeplearning4j.api.storage.StatsStorage; //导入依赖的package包/类
/**
 * Attaches a {@link StatsStorage} instance to the UI.
 *
 * <p>A storage that is already present in {@code statsStorageInstances} is ignored. Otherwise
 * a {@code QueueStatsStorageListener} backed by {@code eventQueue} is registered on the
 * storage, the storage is recorded, and every UI module is notified via {@code onAttach}.
 *
 * @param statsStorage the storage to attach; must not be {@code null}
 * @throws IllegalArgumentException if {@code statsStorage} is {@code null}
 */
@Override
public synchronized void attach(StatsStorage statsStorage) {
    if (statsStorage == null) {
        throw new IllegalArgumentException("StatsStorage cannot be null");
    }

    boolean alreadyAttached = statsStorageInstances.contains(statsStorage);
    if (!alreadyAttached) {
        StatsStorageListener queueListener = new QueueStatsStorageListener(eventQueue);
        // Remember the (storage, listener) pairing so the listener can be deregistered later.
        listeners.add(new Pair<>(statsStorage, queueListener));
        statsStorage.registerStatsStorageListener(queueListener);
        statsStorageInstances.add(statsStorage);

        for (UIModule module : uiModules) {
            module.onAttach(statsStorage);
        }

        log.info("StatsStorage instance attached to UI: {}", statsStorage);
    }
}
示例4: detach
import org.deeplearning4j.api.storage.StatsStorage; //导入依赖的package包/类
/**
 * Detaches a {@link StatsStorage} instance from the UI.
 *
 * <p>Does nothing if the storage is not present in {@code statsStorageInstances}. Otherwise
 * every listener paired with this exact storage object is deregistered and dropped, the
 * storage is removed from {@code statsStorageInstances}, and every UI module is notified via
 * {@code onDetach}.
 *
 * @param statsStorage the storage to detach; must not be {@code null}
 * @throws IllegalArgumentException if {@code statsStorage} is {@code null}
 */
@Override
public synchronized void detach(StatsStorage statsStorage) {
    if (statsStorage == null) {
        throw new IllegalArgumentException("StatsStorage cannot be null");
    }
    if (!statsStorageInstances.contains(statsStorage)) {
        return; //No op
    }

    boolean found = false;
    for (Iterator<Pair<StatsStorage, StatsStorageListener>> iterator = listeners.iterator(); iterator.hasNext();) {
        Pair<StatsStorage, StatsStorageListener> p = iterator.next();
        if (p.getFirst() == statsStorage) { //Same object, not equality
            statsStorage.deregisterStatsStorageListener(p.getSecond());
            iterator.remove();
            found = true;
        }
    }

    // Forget the instance itself. Without this, attach() would keep treating the storage
    // as already attached and silently ignore any attempt to re-attach it.
    statsStorageInstances.remove(statsStorage);

    for (UIModule uiModule : uiModules) {
        uiModule.onDetach(statsStorage);
    }

    if (found) {
        log.info("StatsStorage instance detached from UI: {}", statsStorage);
    }
}
示例5: testListenersViaModel
import org.deeplearning4j.api.storage.StatsStorage; //导入依赖的package包/类
/**
 * Attaches a {@code TestListener} and a {@link StatsListener} to a single-layer
 * {@code MultiLayerNetwork}, runs {@code testListenersForModel}, and then checks that the
 * backing storage reports one session with two worker IDs.
 */
@Test
public void testListenersViaModel() {
    TestListener.clearCounts();

    MultiLayerConfiguration configuration = new NeuralNetConfiguration.Builder()
            .list()
            .layer(0, new OutputLayer.Builder(LossFunctions.LossFunction.MSE)
                    .nIn(10)
                    .nOut(10)
                    .activation(Activation.TANH)
                    .build())
            .build();

    MultiLayerNetwork network = new MultiLayerNetwork(configuration);
    network.init();

    StatsStorage storage = new InMemoryStatsStorage();
    network.setListeners(new TestListener(), new StatsListener(storage));

    testListenersForModel(network, null);

    int sessionCount = storage.listSessionIDs().size();
    assertEquals(1, sessionCount);

    int workerCount = storage.listWorkerIDsForSession(storage.listSessionIDs().get(0)).size();
    assertEquals(2, workerCount);
}
示例6: testListenersViaModelGraph
import org.deeplearning4j.api.storage.StatsStorage; //导入依赖的package包/类
/**
 * Attaches a {@code TestListener} and a {@link StatsListener} to a single-layer
 * {@code ComputationGraph}, runs {@code testListenersForModel}, and then checks that the
 * backing storage reports one session with two worker IDs.
 */
@Test
public void testListenersViaModelGraph() {
    TestListener.clearCounts();

    ComputationGraphConfiguration graphConfiguration = new NeuralNetConfiguration.Builder()
            .graphBuilder()
            .addInputs("in")
            .addLayer("0",
                    new OutputLayer.Builder(LossFunctions.LossFunction.MSE)
                            .nIn(10)
                            .nOut(10)
                            .activation(Activation.TANH)
                            .build(),
                    "in")
            .setOutputs("0")
            .build();

    ComputationGraph graph = new ComputationGraph(graphConfiguration);
    graph.init();

    StatsStorage storage = new InMemoryStatsStorage();
    graph.setListeners(new TestListener(), new StatsListener(storage));

    testListenersForModel(graph, null);

    int sessionCount = storage.listSessionIDs().size();
    assertEquals(1, sessionCount);

    int workerCount = storage.listWorkerIDsForSession(storage.listSessionIDs().get(0)).size();
    assertEquals(2, workerCount);
}
示例7: onAttach
import org.deeplearning4j.api.storage.StatsStorage; //导入依赖的package包/类
/**
 * Records the sessions of a newly attached storage.
 *
 * <p>Each session of {@code statsStorage} that lists the type ID
 * {@link StatsListener#TYPE_ID} is mapped to the storage in {@code knownSessionIDs}.
 * Afterwards, if no session is currently selected, {@code getDefaultSession()} is invoked.
 *
 * @param statsStorage the storage that was attached
 */
@Override
public synchronized void onAttach(StatsStorage statsStorage) {
    for (String sessionID : statsStorage.listSessionIDs()) {
        for (String typeID : statsStorage.listTypeIDsForSession(sessionID)) {
            boolean isStatsListenerType = StatsListener.TYPE_ID.equals(typeID);
            if (isStatsListenerType) {
                knownSessionIDs.put(sessionID, statsStorage);
            }
        }
    }

    if (currentSessionID == null) {
        getDefaultSession();
    }
}
示例8: onDetach
import org.deeplearning4j.api.storage.StatsStorage; //导入依赖的package包/类
@Override
/**
 * Forgets every session that is backed by the detached storage.
 *
 * <p>All entries of {@code knownSessionIDs} whose value is this exact storage object
 * (identity comparison, not {@code equals}) are removed.
 *
 * @param statsStorage the storage that was detached
 */
public void onDetach(StatsStorage statsStorage) {
    // Remove through the values view rather than calling map.remove() while iterating
    // keySet(): the latter is a structural modification during iteration and fails with
    // ConcurrentModificationException on the standard java.util map implementations.
    knownSessionIDs.values().removeIf(storage -> storage == statsStorage);
}
示例9: getWorkerIdForIndex
import org.deeplearning4j.api.storage.StatsStorage; //导入依赖的package包/类
/**
 * Resolves a numeric worker index to a worker ID for the current session.
 *
 * <p>Indices are assigned lazily: when the requested index is not yet known, all worker IDs
 * reported by the session's storage for {@link StatsListener#TYPE_ID} are sorted and each
 * previously unseen ID receives the next value of the per-session counter.
 *
 * @param workerIdx the worker index to resolve
 * @return the worker ID mapped to {@code workerIdx}, or {@code null} if there is no current
 *         session, no storage is registered for it, or the index is not assigned
 */
private synchronized String getWorkerIdForIndex(int workerIdx) {
    String sid = currentSessionID;
    if (sid == null) {
        return null;
    }

    Map<Integer, String> idxToId = workerIdxToName.get(sid);
    if (idxToId == null) {
        idxToId = Collections.synchronizedMap(new HashMap<>());
        workerIdxToName.put(sid, idxToId);
    }

    if (idxToId.containsKey(workerIdx)) {
        return idxToId.get(workerIdx);
    }

    //Need to record new worker...
    //Get counter
    AtomicInteger counter = workerIdxCount.get(sid);
    if (counter == null) {
        counter = new AtomicInteger(0);
        workerIdxCount.put(sid, counter);
    }

    //Get all worker IDs
    StatsStorage ss = knownSessionIDs.get(sid);
    if (ss == null) {
        // No storage is registered for this session (e.g. it has been removed from
        // knownSessionIDs); there is nothing to enumerate, so the index is unresolvable.
        return null;
    }
    List<String> allWorkerIds = new ArrayList<>(ss.listWorkerIDsForSessionAndType(sid, StatsListener.TYPE_ID));
    // Sorting keeps index assignment independent of the order the storage returns IDs in.
    Collections.sort(allWorkerIds);

    //Ensure all workers have been assigned an index
    for (String s : allWorkerIds) {
        if (idxToId.containsValue(s)) {
            continue;
        }
        //Unknown worker ID:
        idxToId.put(counter.getAndIncrement(), s);
    }

    //May still return null if index is wrong/too high...
    return idxToId.get(workerIdx);
}
示例10: enableRemoteListener
import org.deeplearning4j.api.storage.StatsStorage; //导入依赖的package包/类
/**
 * Enables the remote receiver module and routes received stats to the given router.
 *
 * @param statsStorage the router that the remote receiver module should store into
 * @param attach       if {@code true} and {@code statsStorage} is also a {@link StatsStorage},
 *                     the storage is additionally attached to this UI via {@code attach}
 */
@Override
public void enableRemoteListener(StatsStorageRouter statsStorage, boolean attach) {
    remoteReceiverModule.setEnabled(true);
    remoteReceiverModule.setStatsStorage(statsStorage);

    if (!attach) {
        return;
    }
    // Only a full StatsStorage (not a plain router) can be attached to the UI.
    if (statsStorage instanceof StatsStorage) {
        StatsStorage attachable = (StatsStorage) statsStorage;
        attach(attachable);
    }
}
示例11: testUIMultipleSessions
import org.deeplearning4j.api.storage.StatsStorage; //导入依赖的package包/类
/**
 * Manual (ignored) test: trains three small networks on Iris, each with its own
 * {@link StatsStorage} attached to the shared {@code UIServer} instance, then sleeps so the
 * UI can be inspected by hand.
 */
@Test
@Ignore
public void testUIMultipleSessions() throws Exception {
    int session = 0;
    while (session < 3) {
        StatsStorage storage = new InMemoryStatsStorage();

        UIServer server = UIServer.getInstance();
        server.attach(storage);

        MultiLayerConfiguration configuration = new NeuralNetConfiguration.Builder()
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                .list()
                .layer(0, new DenseLayer.Builder()
                        .activation(Activation.TANH)
                        .nIn(4)
                        .nOut(4)
                        .build())
                .layer(1, new OutputLayer.Builder()
                        .lossFunction(LossFunctions.LossFunction.MCXENT)
                        .activation(Activation.SOFTMAX)
                        .nIn(4)
                        .nOut(3)
                        .build())
                .pretrain(false)
                .backprop(true)
                .build();

        MultiLayerNetwork network = new MultiLayerNetwork(configuration);
        network.init();
        network.setListeners(new StatsListener(storage), new ScoreIterationListener(1));

        DataSetIterator irisData = new IrisDataSetIterator(150, 150);
        int pass = 0;
        while (pass < 20) {
            network.fit(irisData);
            Thread.sleep(100);
            pass++;
        }

        session++;
    }

    // Keep the JVM (and therefore the UI server) alive for manual inspection.
    Thread.sleep(1000000);
}
示例12: testUICompGraph
import org.deeplearning4j.api.storage.StatsStorage; //导入依赖的package包/类
/**
 * Manual (ignored) test: trains a two-layer {@code ComputationGraph} on Iris with a
 * {@link StatsListener} whose storage is attached to the {@code UIServer}, then sleeps so the
 * UI can be inspected by hand.
 */
@Test
@Ignore
public void testUICompGraph() throws Exception {
    StatsStorage storage = new InMemoryStatsStorage();

    UIServer server = UIServer.getInstance();
    server.attach(storage);

    ComputationGraphConfiguration graphConfiguration = new NeuralNetConfiguration.Builder()
            .graphBuilder()
            .addInputs("in")
            .addLayer("L0",
                    new DenseLayer.Builder()
                            .activation(Activation.TANH)
                            .nIn(4)
                            .nOut(4)
                            .build(),
                    "in")
            .addLayer("L1",
                    new OutputLayer.Builder()
                            .lossFunction(LossFunctions.LossFunction.MCXENT)
                            .activation(Activation.SOFTMAX)
                            .nIn(4)
                            .nOut(3)
                            .build(),
                    "L0")
            .pretrain(false)
            .backprop(true)
            .setOutputs("L1")
            .build();

    ComputationGraph graph = new ComputationGraph(graphConfiguration);
    graph.init();
    graph.setListeners(new StatsListener(storage), new ScoreIterationListener(1));

    DataSetIterator irisData = new IrisDataSetIterator(150, 150);
    int pass = 0;
    while (pass < 100) {
        graph.fit(irisData);
        Thread.sleep(100);
        pass++;
    }

    // Keep the JVM (and therefore the UI server) alive for manual inspection.
    Thread.sleep(100000);
}
示例13: ConvolutionalIterationListener
import org.deeplearning4j.api.storage.StatsStorage; //导入依赖的package包/类
public ConvolutionalIterationListener(StatsStorageRouter ssr, int iterations, boolean openBrowser, String sessionID,
String workerID) {
this.ssr = ssr;
if (sessionID == null) {
//TODO handle syncing session IDs across different listeners in the same model...
this.sessionID = UUID.randomUUID().toString();
} else {
this.sessionID = sessionID;
}
if (workerID == null) {
this.workerID = UIDProvider.getJVMUID() + "_" + Thread.currentThread().getId();
} else {
this.workerID = workerID;
}
String subPath = "activations";
this.freq = iterations;
this.openBrowser = openBrowser;
path = "http://localhost:" + UIServer.getInstance().getPort() + "/" + subPath;
if (openBrowser && ssr instanceof StatsStorage) {
UIServer.getInstance().attach((StatsStorage) ssr);
}
System.out.println("ConvolutionIterationListener path: " + path);
}
示例14: testParallelStatsListenerCompatibility
import org.deeplearning4j.api.storage.StatsStorage; //导入依赖的package包/类
/**
 * Manual (ignored) test: runs early-stopping parallel training on Iris with a
 * {@link StatsListener} attached to the network and its storage attached to the
 * {@code UIServer}, then asserts that training ended because of the epoch termination
 * condition.
 */
@Test
@Ignore //To be run manually
public void testParallelStatsListenerCompatibility() throws Exception {
    UIServer server = UIServer.getInstance();

    MultiLayerConfiguration configuration = new NeuralNetConfiguration.Builder()
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .updater(new Sgd())
            .weightInit(WeightInit.XAVIER)
            .list()
            .layer(0, new DenseLayer.Builder().nIn(4).nOut(3).build())
            .layer(1, new OutputLayer.Builder()
                    .nIn(3)
                    .nOut(3)
                    .lossFunction(LossFunctions.LossFunction.MCXENT)
                    .build())
            .pretrain(false)
            .backprop(true)
            .build();
    MultiLayerNetwork network = new MultiLayerNetwork(configuration);

    // it's important that the UI can report results from parallel training
    // there's potential for StatsListener to fail if certain properties aren't set in the model
    StatsStorage storage = new InMemoryStatsStorage();
    network.setListeners(new StatsListener(storage));
    server.attach(storage);

    DataSetIterator irisData = new IrisDataSetIterator(50, 500);
    EarlyStoppingModelSaver<MultiLayerNetwork> modelSaver = new InMemoryModelSaver<>();
    EarlyStoppingConfiguration<MultiLayerNetwork> earlyStopping =
            new EarlyStoppingConfiguration.Builder<MultiLayerNetwork>()
                    .epochTerminationConditions(new MaxEpochsTerminationCondition(500))
                    .scoreCalculator(new DataSetLossCalculator(irisData, true))
                    .evaluateEveryNEpochs(2)
                    .modelSaver(modelSaver)
                    .build();

    IEarlyStoppingTrainer<MultiLayerNetwork> parallelTrainer =
            new EarlyStoppingParallelTrainer<>(earlyStopping, network, irisData, null, 3, 6, 2);
    EarlyStoppingResult<MultiLayerNetwork> outcome = parallelTrainer.fit();
    System.out.println(outcome);

    assertEquals(EarlyStoppingResult.TerminationReason.EpochTerminationCondition, outcome.getTerminationReason());
}
示例15: main
import org.deeplearning4j.api.storage.StatsStorage; //导入依赖的package包/类
/**
 * Obtains the {@code UIServer} instance, attaches a fresh in-memory {@link StatsStorage}
 * to it, and enables the server's remote listener.
 *
 * @param args command-line arguments (unused)
 */
public static void main(String[] args) {
    UIServer uiServer = UIServer.getInstance();

    StatsStorage inMemoryStorage = new InMemoryStatsStorage();
    uiServer.attach(inMemoryStorage);

    uiServer.enableRemoteListener();
}