当前位置: 首页>>代码示例>>Java>>正文


Java ComputationGraph.getLayer方法代码示例

本文整理汇总了Java中org.deeplearning4j.nn.graph.ComputationGraph.getLayer方法的典型用法代码示例。如果您正苦于以下问题:Java ComputationGraph.getLayer方法的具体用法?Java ComputationGraph.getLayer怎么用?Java ComputationGraph.getLayer使用的例子?那么恭喜您, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在org.deeplearning4j.nn.graph.ComputationGraph的用法示例。


在下文中一共展示了ComputationGraph.getLayer方法的3个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Java代码示例。

示例1: getVaeLayer

import org.deeplearning4j.nn.graph.ComputationGraph; //导入方法依赖的package包/类
@Override
public VariationalAutoencoder getVaeLayer() {
    ComputationGraph network =
                    new ComputationGraph(ComputationGraphConfiguration.fromJson((String) jsonConfig.getValue()));
    network.init();
    INDArray val = ((INDArray) params.value()).unsafeDuplication();
    if (val.length() != network.numParams(false))
        throw new IllegalStateException(
                        "Network did not have same number of parameters as the broadcasted set parameters");
    network.setParams(val);

    Layer l = network.getLayer(0);
    if (!(l instanceof VariationalAutoencoder)) {
        throw new RuntimeException(
                        "Cannot use CGVaeReconstructionErrorWithKeyFunction on network that doesn't have a VAE "
                                        + "layer as layer 0. Layer type: " + l.getClass());
    }
    return (VariationalAutoencoder) l;
}
 
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:20,代码来源:CGVaeReconstructionErrorWithKeyFunction.java

示例2: getVaeLayer

import org.deeplearning4j.nn.graph.ComputationGraph; //导入方法依赖的package包/类
@Override
public VariationalAutoencoder getVaeLayer() {
    ComputationGraph network =
                    new ComputationGraph(ComputationGraphConfiguration.fromJson((String) jsonConfig.getValue()));
    network.init();
    INDArray val = ((INDArray) params.value()).unsafeDuplication();
    if (val.length() != network.numParams(false))
        throw new IllegalStateException(
                        "Network did not have same number of parameters as the broadcasted set parameters");
    network.setParams(val);

    Layer l = network.getLayer(0);
    if (!(l instanceof VariationalAutoencoder)) {
        throw new RuntimeException(
                        "Cannot use CGVaeReconstructionProbWithKeyFunction on network that doesn't have a VAE "
                                        + "layer as layer 0. Layer type: " + l.getClass());
    }
    return (VariationalAutoencoder) l;
}
 
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:20,代码来源:CGVaeReconstructionProbWithKeyFunction.java

示例3: testLastTimeStepVertex

import org.deeplearning4j.nn.graph.ComputationGraph; //导入方法依赖的package包/类
@Test
public void testLastTimeStepVertex() {

    ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().graphBuilder().addInputs("in")
            .addLayer("lastTS", new LastTimeStep(new SimpleRnn.Builder()
                    .nIn(5).nOut(6).build()), "in")
            .setOutputs("lastTS")
            .build();

    ComputationGraph graph = new ComputationGraph(conf);
    graph.init();

    //First: test without input mask array
    Nd4j.getRandom().setSeed(12345);
    Layer l = graph.getLayer("lastTS");
    INDArray in = Nd4j.rand(new int[]{3, 5, 6});
    INDArray outUnderlying = ((LastTimeStepLayer)l).getUnderlying().activate(in);
    INDArray expOut = outUnderlying.get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.point(5));


    //Forward pass:
    INDArray outFwd = l.activate(in);
    assertEquals(expOut, outFwd);

    //Second: test with input mask array
    INDArray inMask = Nd4j.zeros(3, 6);
    inMask.putRow(0, Nd4j.create(new double[]{1, 1, 1, 0, 0, 0}));
    inMask.putRow(1, Nd4j.create(new double[]{1, 1, 1, 1, 0, 0}));
    inMask.putRow(2, Nd4j.create(new double[]{1, 1, 1, 1, 1, 0}));
    graph.setLayerMaskArrays(new INDArray[]{inMask}, null);

    expOut = Nd4j.zeros(3, 6);
    expOut.putRow(0, outUnderlying.get(NDArrayIndex.point(0), NDArrayIndex.all(), NDArrayIndex.point(2)));
    expOut.putRow(1, outUnderlying.get(NDArrayIndex.point(1), NDArrayIndex.all(), NDArrayIndex.point(3)));
    expOut.putRow(2, outUnderlying.get(NDArrayIndex.point(2), NDArrayIndex.all(), NDArrayIndex.point(4)));

    outFwd = l.activate(in);
    assertEquals(expOut, outFwd);

    TestUtils.testModelSerialization(graph);
}
 
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:42,代码来源:TestLastTimeStepLayer.java


注:本文中的org.deeplearning4j.nn.graph.ComputationGraph.getLayer方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。