当前位置: 首页>>代码示例>>Java>>正文


Java ComputationGraph.setInput方法代码示例

本文整理汇总了Java中org.deeplearning4j.nn.graph.ComputationGraph.setInput方法的典型用法代码示例。如果您正苦于以下问题:Java ComputationGraph.setInput方法的具体用法?Java ComputationGraph.setInput怎么用?Java ComputationGraph.setInput使用的例子?那么恭喜您, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在org.deeplearning4j.nn.graph.ComputationGraph的用法示例。


在下文中一共展示了ComputationGraph.setInput方法的2个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Java代码示例。

示例1: testLambdaConf

import org.deeplearning4j.nn.graph.ComputationGraph; //导入方法依赖的package包/类
@Test
public void testLambdaConf() {
    double[] lambdas = new double[] {0.1, 0.01};
    double[] results = new double[2];
    int numClasses = 2;

    INDArray input = Nd4j.rand(150, 4);
    INDArray labels = Nd4j.zeros(150, numClasses);
    Random r = new Random(12345);
    for (int i = 0; i < 150; i++) {
        labels.putScalar(i, r.nextInt(numClasses), 1.0);
    }
    ComputationGraph graph;

    for (int i = 0; i < lambdas.length; i++) {
        graph = getGraph(numClasses, lambdas[i]);
        graph.setInput(0, input);
        graph.setLabel(0, labels);
        graph.computeGradientAndScore();
        results[i] = graph.score();
    }

    assertNotEquals(results[0], results[1]);
}
 
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:25,代码来源:CenterLossOutputLayerTest.java

示例2: testSerializationCompGraph

import org.deeplearning4j.nn.graph.ComputationGraph; //导入方法依赖的package包/类
@Test
public void testSerializationCompGraph() throws Exception {

    for(WorkspaceMode wsm : WorkspaceMode.values()) {
        log.info("*** Starting workspace mode: " + wsm);

        Nd4j.getRandom().setSeed(12345);

        ComputationGraphConfiguration conf1 = new NeuralNetConfiguration.Builder()
                .activation(Activation.TANH)
                .weightInit(WeightInit.XAVIER)
                .trainingWorkspaceMode(wsm)
                .inferenceWorkspaceMode(wsm)
                .updater(new Adam())
                .graphBuilder()
                .addInputs("in")
                .layer("0", new Bidirectional(Bidirectional.Mode.ADD, new GravesLSTM.Builder().nIn(10).nOut(10).build()), "in")
                .layer("1", new Bidirectional(Bidirectional.Mode.ADD, new GravesLSTM.Builder().nIn(10).nOut(10).build()), "0")
                .layer("2", new RnnOutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE)
                        .nIn(10).nOut(10).build(), "1")
                .setOutputs("2")
                .build();

        ComputationGraph net1 = new ComputationGraph(conf1);
        net1.init();

        INDArray in = Nd4j.rand(new int[]{3, 10, 5});
        INDArray labels = Nd4j.rand(new int[]{3, 10, 5});

        net1.fit(new DataSet(in, labels));

        byte[] bytes;
        try (ByteArrayOutputStream baos = new ByteArrayOutputStream()) {
            ModelSerializer.writeModel(net1, baos, true);
            bytes = baos.toByteArray();
        }


        ComputationGraph net2 = ModelSerializer.restoreComputationGraph(new ByteArrayInputStream(bytes), true);


        in = Nd4j.rand(new int[]{3, 10, 5});
        labels = Nd4j.rand(new int[]{3, 10, 5});

        INDArray out1 = net1.outputSingle(in);
        INDArray out2 = net2.outputSingle(in);

        assertEquals(out1, out2);

        net1.setInput(0, in);
        net2.setInput(0, in);
        net1.setLabels(labels);
        net2.setLabels(labels);

        net1.computeGradientAndScore();
        net2.computeGradientAndScore();

        assertEquals(net1.score(), net2.score(), 1e-6);
        assertEquals(net1.gradient().gradient(), net2.gradient().gradient());
    }
}
 
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:62,代码来源:BidirectionalTest.java


注:本文中的org.deeplearning4j.nn.graph.ComputationGraph.setInput方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。