

Java ComputationGraph.rnnTimeStep Method Code Examples

This page collects typical usage examples of the Java method org.deeplearning4j.nn.graph.ComputationGraph.rnnTimeStep, drawn from open-source projects. If you are unsure what ComputationGraph.rnnTimeStep does or how to use it in practice, the selected examples below may help. You can also explore other usage examples of the enclosing class, org.deeplearning4j.nn.graph.ComputationGraph.


Two code examples of the ComputationGraph.rnnTimeStep method are shown below, ordered by popularity.
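Before the DL4J examples, here is a hypothetical plain-Java toy (no DL4J dependency; all names are illustrative, not DL4J API) that captures the contract rnnTimeStep relies on: the network keeps its hidden state between calls, so feeding a sequence one timestep at a time produces the same outputs as a single full-sequence forward pass from a cleared state.

```java
public class ToyRnnCell {
    // Fixed toy weights; a real RNN learns these.
    private final double w = 0.5, u = 0.25, b = 0.1;
    private double h = 0.0; // hidden state carried across timeStep calls

    // Analogous to rnnTimeStep: consume one input, update internal state.
    public double timeStep(double x) {
        h = Math.tanh(w * x + u * h + b);
        return h;
    }

    // Analogous to rnnClearPreviousState: reset before a new sequence.
    public void clearPreviousState() {
        h = 0.0;
    }

    // Analogous to a full-sequence forward pass starting from zero state.
    public double[] forwardSequence(double[] xs) {
        double state = 0.0;
        double[] out = new double[xs.length];
        for (int t = 0; t < xs.length; t++) {
            state = Math.tanh(w * xs[t] + u * state + b);
            out[t] = state;
        }
        return out;
    }

    public static void main(String[] args) {
        double[] xs = {0.3, -0.7, 1.2};
        ToyRnnCell cell = new ToyRnnCell();
        double[] full = cell.forwardSequence(xs);
        cell.clearPreviousState();
        for (int t = 0; t < xs.length; t++) {
            double stepped = cell.timeStep(xs[t]);
            // Step-by-step output matches the full-sequence pass.
            if (Math.abs(stepped - full[t]) > 1e-12) throw new AssertionError();
        }
        System.out.println("step-by-step matches full sequence");
    }
}
```

The same state-carrying behavior is what makes rnnTimeStep suitable for streaming or generative inference, where inputs arrive one timestep at a time.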

Example 1: testRnnTimeStepWithPreprocessorGraph

import org.deeplearning4j.nn.graph.ComputationGraph; // import required by this method
@Test
public void testRnnTimeStepWithPreprocessorGraph() {

    ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
                    .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                    .graphBuilder().addInputs("in")
                    .addLayer("0", new org.deeplearning4j.nn.conf.layers.GravesLSTM.Builder().nIn(10).nOut(10)
                                    .activation(Activation.TANH).build(), "in")
                    .addLayer("1", new org.deeplearning4j.nn.conf.layers.GravesLSTM.Builder().nIn(10).nOut(10)
                                    .activation(Activation.TANH).build(), "0")
                    .addLayer("2", new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                                    .activation(Activation.SOFTMAX).nIn(10).nOut(10).build(), "1")
                    .setOutputs("2").inputPreProcessor("0", new FeedForwardToRnnPreProcessor()).pretrain(false)
                    .backprop(true).build();

    ComputationGraph net = new ComputationGraph(conf);
    net.init();

    // 2d (feed-forward) input; the FeedForwardToRnnPreProcessor on layer "0"
    // reshapes it to the 3d RNN format [miniBatch, size, 1] before the LSTM.
    INDArray in = Nd4j.rand(1, 10);
    net.rnnTimeStep(in);
}
 
Developer: deeplearning4j, project: deeplearning4j, lines of code: 22, source file: MultiLayerTestRNN.java
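The example above calls rnnTimeStep once; after the call, the network retains that step's internal state, and DL4J's rnnClearPreviousState() should be invoked before feeding an unrelated sequence. A hypothetical plain-Java sketch (illustrative names, not DL4J API) of why the reset matters:

```java
public class StatefulStepDemo {
    private double state = 0.0;

    // One timestep of a toy recurrence: the state decays and absorbs the input.
    double timeStep(double x) {
        state = 0.5 * state + x;
        return state;
    }

    // Discard all carried state, as before starting a new sequence.
    void clearPreviousState() {
        state = 0.0;
    }

    public static void main(String[] args) {
        StatefulStepDemo net = new StatefulStepDemo();
        double first = net.timeStep(1.0);   // 1.0: fresh state
        double second = net.timeStep(1.0);  // 1.5: leftover state leaks in
        net.clearPreviousState();
        double fresh = net.timeStep(1.0);   // 1.0 again after the reset
        System.out.println(first + " " + second + " " + fresh);
    }
}
```

Without the reset, the second "sequence" would silently start from the first one's final state and produce different outputs for identical inputs.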

Example 2: testRnnTimeStep

import org.deeplearning4j.nn.graph.ComputationGraph; // import required by this method
@Test
public void testRnnTimeStep(){
    for(WorkspaceMode ws : WorkspaceMode.values()) {
        for (int i = 0; i < 3; i++) {

            System.out.println("Starting test: " + ws + " - " + i);

            NeuralNetConfiguration.ListBuilder b = new NeuralNetConfiguration.Builder()
                    .weightInit(WeightInit.XAVIER)
                    .activation(Activation.TANH)
                    .inferenceWorkspaceMode(ws)
                    .trainingWorkspaceMode(ws)
                    .list();

            ComputationGraphConfiguration.GraphBuilder gb = new NeuralNetConfiguration.Builder()
                    .weightInit(WeightInit.XAVIER)
                    .activation(Activation.TANH)
                    .inferenceWorkspaceMode(ws)
                    .trainingWorkspaceMode(ws)
                    .graphBuilder()
                    .addInputs("in");

            switch (i) {
                case 0:
                    b.layer(new SimpleRnn.Builder().nIn(10).nOut(10).build());
                    b.layer(new SimpleRnn.Builder().nIn(10).nOut(10).build());

                    gb.addLayer("0", new SimpleRnn.Builder().nIn(10).nOut(10).build(), "in");
                    gb.addLayer("1", new SimpleRnn.Builder().nIn(10).nOut(10).build(), "0");
                    break;
                case 1:
                    b.layer(new LSTM.Builder().nIn(10).nOut(10).build());
                    b.layer(new LSTM.Builder().nIn(10).nOut(10).build());

                    gb.addLayer("0", new LSTM.Builder().nIn(10).nOut(10).build(), "in");
                    gb.addLayer("1", new LSTM.Builder().nIn(10).nOut(10).build(), "0");
                    break;
                case 2:
                    b.layer(new GravesLSTM.Builder().nIn(10).nOut(10).build());
                    b.layer(new GravesLSTM.Builder().nIn(10).nOut(10).build());

                    gb.addLayer("0", new GravesLSTM.Builder().nIn(10).nOut(10).build(), "in");
                    gb.addLayer("1", new GravesLSTM.Builder().nIn(10).nOut(10).build(), "0");
                    break;
                default:
                    throw new RuntimeException();
            }

            b.layer(new RnnOutputLayer.Builder().nIn(10).nOut(10).build());
            gb.addLayer("out", new RnnOutputLayer.Builder().nIn(10).nOut(10).build(), "1");
            gb.setOutputs("out");

            MultiLayerConfiguration conf = b.build();
            ComputationGraphConfiguration conf2 = gb.build();


            MultiLayerNetwork net = new MultiLayerNetwork(conf);
            net.init();

            ComputationGraph net2 = new ComputationGraph(conf2);
            net2.init();

            // Input shape: [miniBatch = 3, nIn = 10, timeSeriesLength = 5];
            // each call advances the stored RNN state by 5 timesteps.
            for (int j = 0; j < 3; j++) {
                net.rnnTimeStep(Nd4j.rand(new int[]{3, 10, 5}));
            }

            for (int j = 0; j < 3; j++) {
                net2.rnnTimeStep(Nd4j.rand(new int[]{3, 10, 5}));
            }
        }
    }
}
 
Developer: deeplearning4j, project: deeplearning4j, lines of code: 73, source file: WorkspaceTests.java


Note: the org.deeplearning4j.nn.graph.ComputationGraph.rnnTimeStep examples on this page were compiled by 纯净天空 from open-source code and documentation platforms such as GitHub and MSDocs, and the snippets were selected from open-source projects contributed by their respective authors. Copyright of the source code remains with the original authors; consult each project's license before redistributing or using the code. Do not republish without permission.