本文整理汇总了Java中org.deeplearning4j.nn.graph.ComputationGraph.rnnTimeStep方法的典型用法代码示例。如果您正苦于以下问题:Java ComputationGraph.rnnTimeStep方法的具体用法?Java ComputationGraph.rnnTimeStep怎么用?Java ComputationGraph.rnnTimeStep使用的例子?那么恭喜您, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在类org.deeplearning4j.nn.graph.ComputationGraph
的用法示例。
在下文中一共展示了ComputationGraph.rnnTimeStep方法的2个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Java代码示例。
示例1: testRnnTimeStepWithPreprocessorGraph
import org.deeplearning4j.nn.graph.ComputationGraph; //导入方法依赖的package包/类
@Test
public void testRnnTimeStepWithPreprocessorGraph() {
    // Builds a two-layer GravesLSTM graph whose first layer is fed through a
    // FeedForwardToRnnPreProcessor, then checks that a single rnnTimeStep()
    // call on feed-forward-shaped input completes without throwing.
    ComputationGraphConfiguration configuration = new NeuralNetConfiguration.Builder()
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .graphBuilder()
            .addInputs("in")
            .addLayer("0", new org.deeplearning4j.nn.conf.layers.GravesLSTM.Builder().nIn(10).nOut(10)
                    .activation(Activation.TANH).build(), "in")
            .addLayer("1", new org.deeplearning4j.nn.conf.layers.GravesLSTM.Builder().nIn(10).nOut(10)
                    .activation(Activation.TANH).build(), "0")
            .addLayer("2", new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                    .activation(Activation.SOFTMAX).nIn(10).nOut(10).build(), "1")
            .setOutputs("2")
            .inputPreProcessor("0", new FeedForwardToRnnPreProcessor())
            .pretrain(false)
            .backprop(true)
            .build();

    ComputationGraph graph = new ComputationGraph(configuration);
    graph.init();

    // Single example, 1 x 10 feed-forward shape; presumably the preprocessor
    // converts it to the rank-3 recurrent layout — judging by its name.
    INDArray input = Nd4j.rand(1, 10);
    graph.rnnTimeStep(input);
}
示例2: testRnnTimeStep
import org.deeplearning4j.nn.graph.ComputationGraph; //导入方法依赖的package包/类
@Test
public void testRnnTimeStep(){
    // Exercises rnnTimeStep() for every WorkspaceMode across three recurrent
    // cell types (SimpleRnn, LSTM, GravesLSTM), building an equivalent network
    // with both the MultiLayerNetwork and the ComputationGraph APIs.
    for (WorkspaceMode workspaceMode : WorkspaceMode.values()) {
        for (int cellType = 0; cellType < 3; cellType++) {
            System.out.println("Starting test: " + workspaceMode + " - " + cellType);

            NeuralNetConfiguration.ListBuilder listBuilder = new NeuralNetConfiguration.Builder()
                    .weightInit(WeightInit.XAVIER)
                    .activation(Activation.TANH)
                    .inferenceWorkspaceMode(workspaceMode)
                    .trainingWorkspaceMode(workspaceMode)
                    .list();

            ComputationGraphConfiguration.GraphBuilder graphBuilder = new NeuralNetConfiguration.Builder()
                    .weightInit(WeightInit.XAVIER)
                    .activation(Activation.TANH)
                    .inferenceWorkspaceMode(workspaceMode)
                    .trainingWorkspaceMode(workspaceMode)
                    .graphBuilder()
                    .addInputs("in");

            // Insert the two recurrent layers for the cell type under test,
            // mirrored between the list-based and graph-based builders.
            if (cellType == 0) {
                listBuilder.layer(new SimpleRnn.Builder().nIn(10).nOut(10).build());
                listBuilder.layer(new SimpleRnn.Builder().nIn(10).nOut(10).build());
                graphBuilder.addLayer("0", new SimpleRnn.Builder().nIn(10).nOut(10).build(), "in");
                graphBuilder.addLayer("1", new SimpleRnn.Builder().nIn(10).nOut(10).build(), "0");
            } else if (cellType == 1) {
                listBuilder.layer(new LSTM.Builder().nIn(10).nOut(10).build());
                listBuilder.layer(new LSTM.Builder().nIn(10).nOut(10).build());
                graphBuilder.addLayer("0", new LSTM.Builder().nIn(10).nOut(10).build(), "in");
                graphBuilder.addLayer("1", new LSTM.Builder().nIn(10).nOut(10).build(), "0");
            } else if (cellType == 2) {
                listBuilder.layer(new GravesLSTM.Builder().nIn(10).nOut(10).build());
                listBuilder.layer(new GravesLSTM.Builder().nIn(10).nOut(10).build());
                graphBuilder.addLayer("0", new GravesLSTM.Builder().nIn(10).nOut(10).build(), "in");
                graphBuilder.addLayer("1", new GravesLSTM.Builder().nIn(10).nOut(10).build(), "0");
            } else {
                // Unreachable given the loop bound; kept as a guard.
                throw new RuntimeException();
            }

            listBuilder.layer(new RnnOutputLayer.Builder().nIn(10).nOut(10).build());
            graphBuilder.addLayer("out", new RnnOutputLayer.Builder().nIn(10).nOut(10).build(), "1");
            graphBuilder.setOutputs("out");

            MultiLayerConfiguration listConf = listBuilder.build();
            ComputationGraphConfiguration graphConf = graphBuilder.build();

            MultiLayerNetwork listNetwork = new MultiLayerNetwork(listConf);
            listNetwork.init();
            ComputationGraph graphNetwork = new ComputationGraph(graphConf);
            graphNetwork.init();

            // Three successive time steps on each network with random
            // (batch=3, features=10, timeSeriesLength=5) input.
            for (int step = 0; step < 3; step++) {
                listNetwork.rnnTimeStep(Nd4j.rand(new int[]{3, 10, 5}));
            }
            for (int step = 0; step < 3; step++) {
                graphNetwork.rnnTimeStep(Nd4j.rand(new int[]{3, 10, 5}));
            }
        }
    }
}