当前位置: 首页>>代码示例>>Java>>正文


Java ComputationGraph.init方法代码示例

本文整理汇总了Java中org.deeplearning4j.nn.graph.ComputationGraph.init方法的典型用法代码示例。如果您正苦于以下问题:Java ComputationGraph.init方法的具体用法?Java ComputationGraph.init怎么用?Java ComputationGraph.init使用的例子?那么恭喜您, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在org.deeplearning4j.nn.graph.ComputationGraph的用法示例。


在下文中一共展示了ComputationGraph.init方法的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Java代码示例。

示例1: getOutput

import org.deeplearning4j.nn.graph.ComputationGraph; //导入方法依赖的package包/类
/**
 * Loads a DL4J model from the given stream and runs a single forward pass on the image.
 * Supports both {@code MultiLayerNetwork} and {@code ComputationGraph} models; for a
 * graph model the first output array is returned.
 *
 * @param isModel stream containing the serialized DL4J model
 * @param image   input tensor to feed through the network
 * @return the network's output for the given image
 * @throws IllegalArgumentException if the model cannot be loaded
 */
private INDArray getOutput(InputStream isModel, INDArray image) {
	final org.deeplearning4j.nn.api.Model loaded;
	try {
		// The model guesser is deliberately not used for now because it attempts
		// to load a Keras model: ModelGuesser.loadModelGuess(isModel)
		loaded = loadModel(isModel);
	} catch (Exception e) {
		throw new IllegalArgumentException("Not able to load model.", e);
	}
	if (loaded instanceof MultiLayerNetwork) {
		MultiLayerNetwork net = (MultiLayerNetwork) loaded;
		net.init();
		return net.output(image);
	}
	// Anything else is expected to be a computation graph
	ComputationGraph graph = (ComputationGraph) loaded;
	graph.init();
	return graph.output(image)[0];
}
 
开发者ID:jesuino,项目名称:kie-ml,代码行数:20,代码来源:DL4JKieMLProvider.java

示例2: getVaeLayer

import org.deeplearning4j.nn.graph.ComputationGraph; //导入方法依赖的package包/类
@Override
public VariationalAutoencoder getVaeLayer() {
    // Rebuild the network from the broadcast JSON configuration.
    String json = (String) jsonConfig.getValue();
    ComputationGraph net = new ComputationGraph(ComputationGraphConfiguration.fromJson(json));
    net.init();

    // Copy the broadcast parameter view into the freshly initialised network.
    INDArray paramView = ((INDArray) params.value()).unsafeDuplication();
    if (paramView.length() != net.numParams(false)) {
        throw new IllegalStateException(
                        "Network did not have same number of parameters as the broadcasted set parameters");
    }
    net.setParams(paramView);

    // The VAE must be the very first layer of the graph.
    Layer first = net.getLayer(0);
    if (!(first instanceof VariationalAutoencoder)) {
        throw new RuntimeException(
                        "Cannot use CGVaeReconstructionProbWithKeyFunction on network that doesn't have a VAE "
                                        + "layer as layer 0. Layer type: " + first.getClass());
    }
    return (VariationalAutoencoder) first;
}
 
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:20,代码来源:CGVaeReconstructionProbWithKeyFunction.java

示例3: testWriteCGModelInputStream

import org.deeplearning4j.nn.graph.ComputationGraph; //导入方法依赖的package包/类
@Test
public void testWriteCGModelInputStream() throws Exception {
    // Simple two-layer graph: in -> dense -> out
    ComputationGraphConfiguration config = new NeuralNetConfiguration.Builder()
                    .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).updater(new Sgd(0.1))
                    .graphBuilder().addInputs("in")
                    .addLayer("dense", new DenseLayer.Builder().nIn(4).nOut(2).build(), "in").addLayer("out",
                                    new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).nIn(2).nOut(3)
                                                    .build(),
                                    "dense")
                    .setOutputs("out").pretrain(false).backprop(true).build();

    ComputationGraph cg = new ComputationGraph(config);
    cg.init();

    File tempFile = File.createTempFile("tsfs", "fdfsdf");
    tempFile.deleteOnExit();

    ModelSerializer.writeModel(cg, tempFile, true);

    // Restore via an InputStream; try-with-resources guarantees the stream is
    // closed (previously the FileInputStream was leaked).
    ComputationGraph network;
    try (FileInputStream fis = new FileInputStream(tempFile)) {
        network = ModelSerializer.restoreComputationGraph(fis);
    }

    // Configuration, parameters and updater state must all survive the round trip.
    assertEquals(network.getConfiguration().toJson(), cg.getConfiguration().toJson());
    assertEquals(cg.params(), network.params());
    assertEquals(cg.getUpdater().getStateViewArray(), network.getUpdater().getStateViewArray());
}
 
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:27,代码来源:ModelSerializerTest.java

示例4: createModel

import org.deeplearning4j.nn.graph.ComputationGraph; //导入方法依赖的package包/类
/**
 * Build the computation graph defined by the network configuration and the list of layers.
 *
 * @throws Exception if the model cannot be constructed
 */
protected void createModel() throws Exception {
  final INDArray firstBatch = getFirstBatchFeatures(trainData);
  ComputationGraphConfiguration.GraphBuilder builder = netConfig
      .builder()
      .seed(getSeed())
      .inferenceWorkspaceMode(WorkspaceMode.SEPARATE)
      .trainingWorkspaceMode(WorkspaceMode.SEPARATE)
      .graphBuilder();

  // Fix the output layer size to the number of classes in the training data
  final int numClasses = trainData.numClasses();
  final Layer outputLayer = layers[layers.length - 1];
  if (outputLayer instanceof BaseOutputLayer) {
    ((BaseOutputLayer) outputLayer).setNOut(numClasses);
  }

  // CNN-on-text needs a dedicated layer wiring; everything else uses the default setup
  if (getInstanceIterator() instanceof CnnTextEmbeddingInstanceIterator) {
    makeCnnTextLayerSetup(builder);
  } else {
    makeDefaultLayerSetup(builder);
  }

  builder.setInputTypes(InputType.inferInputType(firstBatch));
  final ComputationGraphConfiguration conf = builder.pretrain(false).backprop(true).build();
  final ComputationGraph graph = new ComputationGraph(conf);
  graph.init();
  this.model = graph;
}
 
开发者ID:Waikato,项目名称:wekaDeeplearning4j,代码行数:35,代码来源:Dl4jMlpClassifier.java

示例5: setZooModel

import org.deeplearning4j.nn.graph.ComputationGraph; //导入方法依赖的package包/类
/**
 * Set the modelzoo zooModel
 *
 * @param zooModel The predefined zooModel
 */
@OptionMetadata(
  displayName = "zooModel",
  description = "The model-architecture to choose from the modelzoo " + "(default = no model).",
  commandLineParamName = "zooModel",
  commandLineParamSynopsis = "-zooModel <string>",
  displayOrder = 11
)
public void setZooModel(ZooModel zooModel) {
  // GoogLeNet and FaceNetNN4Small2 are explicitly rejected here.
  if (zooModel instanceof GoogLeNet || zooModel instanceof FaceNetNN4Small2) {
    throw new RuntimeException(
        "The zoomodel you have selected is currently"
            + " not supported! Please select another one.");
  }

  this.zooModel = zooModel;

  try {
    // Instantiate the architecture with a dummy label count so its layer
    // configurations can be extracted and edited by the user afterwards.
    final int dummyNumLabels = 2;
    ComputationGraph graph = zooModel.init(dummyNumLabels, getSeed(), zooModel.getShape());
    graph.init();
    final Layer[] parsedLayers = new Layer[graph.getLayers().length];
    for (int i = 0; i < parsedLayers.length; i++) {
      parsedLayers[i] = graph.getLayers()[i].conf().getLayer();
    }
    layers = parsedLayers;
  } catch (Exception e) {
    // Best effort: a CustomNet may legitimately fail here, so only log for predefined models
    if (!(zooModel instanceof CustomNet)) {
      log.error("Could not set layers from zoomodel.", e);
    }
  }
}
 
开发者ID:Waikato,项目名称:wekaDeeplearning4j,代码行数:38,代码来源:Dl4jMlpClassifier.java

示例6: createModel

import org.deeplearning4j.nn.graph.ComputationGraph; //导入方法依赖的package包/类
@Override
protected void createModel() throws Exception {
  final INDArray firstBatch = getFirstBatchFeatures(trainData);
  log.info("Feature shape: {}", firstBatch.shape());

  // Graph builder with truncated BPTT configured for sequence training
  ComputationGraphConfiguration.GraphBuilder builder = netConfig
      .builder()
      .seed(getSeed())
      .inferenceWorkspaceMode(WorkspaceMode.SEPARATE)
      .trainingWorkspaceMode(WorkspaceMode.SEPARATE)
      .graphBuilder()
      .backpropType(BackpropType.TruncatedBPTT)
      .tBPTTBackwardLength(tBPTTbackwardLength)
      .tBPTTForwardLength(tBPTTforwardLength);

  // Fix the output layer size to the number of classes in the training data
  final int numClasses = trainData.numClasses();
  final Layer outputLayer = layers[layers.length - 1];
  if (outputLayer instanceof RnnOutputLayer) {
    ((RnnOutputLayer) outputLayer).setNOut(numClasses);
  }

  // Chain all layers sequentially, starting from the single input
  String previous = "input";
  builder.addInputs(previous);
  for (Layer layer : layers) {
    final String layerName = layer.getLayerName();
    builder.addLayer(layerName, layer, previous);
    previous = layerName;
  }
  builder.setOutputs(previous);
  builder.setInputTypes(InputType.inferInputType(firstBatch));

  final ComputationGraphConfiguration conf = builder.pretrain(false).backprop(true).build();
  final ComputationGraph graph = new ComputationGraph(conf);
  graph.init();
  this.model = graph;
}
 
开发者ID:Waikato,项目名称:wekaDeeplearning4j,代码行数:39,代码来源:RnnSequenceClassifier.java

示例7: testAddOutput

import org.deeplearning4j.nn.graph.ComputationGraph; //导入方法依赖的package包/类
@Test
public void testAddOutput() {
    // Identity activations keep the transfer-learning graph easy to reason about
    NeuralNetConfiguration.Builder baseConf = new NeuralNetConfiguration.Builder().updater(new Sgd(0.9))
                    .activation(Activation.IDENTITY);

    // Original graph: two inputs merged into a single output ("outRight")
    ComputationGraphConfiguration conf = baseConf.graphBuilder().addInputs("inCentre", "inRight")
                    .addLayer("denseCentre0", new DenseLayer.Builder().nIn(2).nOut(2).build(), "inCentre")
                    .addLayer("denseRight0", new DenseLayer.Builder().nIn(2).nOut(2).build(), "inRight")
                    .addVertex("mergeRight", new MergeVertex(), "denseCentre0", "denseRight0")
                    .addLayer("outRight",
                                    new OutputLayer.Builder(LossFunctions.LossFunction.MSE).nIn(4).nOut(2).build(),
                                    "mergeRight")
                    .setOutputs("outRight").build();
    ComputationGraph baseModel = new ComputationGraph(conf);
    baseModel.init();

    // Transfer learning: attach a second output ("outCentre") to the existing graph
    ComputationGraph tunedModel = new TransferLearning.GraphBuilder(baseModel)
                    .addLayer("outCentre",
                                    new OutputLayer.Builder(LossFunctions.LossFunction.MSE).nIn(2)
                                                    .nOut(3).build(),
                                    "denseCentre0")
                    .setOutputs("outRight", "outCentre").build();

    assertEquals(2, tunedModel.getNumOutputArrays());

    // Fit one random mini-batch so both outputs are exercised
    MultiDataSet randomData = new MultiDataSet(new INDArray[] {Nd4j.rand(2, 2), Nd4j.rand(2, 2)},
                    new INDArray[] {Nd4j.rand(2, 2), Nd4j.rand(2, 3)});
    tunedModel.fit(randomData);
    log.info(tunedModel.summary());
    log.info(tunedModel.summary(InputType.feedForward(2),InputType.feedForward(2)));

}
 
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:33,代码来源:TransferLearningComplex.java

示例8: init

import org.deeplearning4j.nn.graph.ComputationGraph; //导入方法依赖的package包/类
@Override
public ComputationGraph init() {
    // Build the graph from this zoo model's configuration and initialise its parameters
    final ComputationGraph net = new ComputationGraph(conf());
    net.init();
    return net;
}
 
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:8,代码来源:GoogLeNet.java

示例9: testLSTMWithSubset

import org.deeplearning4j.nn.graph.ComputationGraph; //导入方法依赖的package包/类
@Test
public void testLSTMWithSubset() {
    Nd4j.getRandom().setSeed(1234);
    // Architecture: GravesLSTM (8 units) -> SubsetVertex keeping units 0..3 -> softmax RNN output
    ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().seed(1234)
                    .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                    .weightInit(WeightInit.DISTRIBUTION).dist(new NormalDistribution(0, 1))
                    .updater(new NoOp()).graphBuilder().addInputs("input").setOutputs("out")
                    .addLayer("lstm1", new GravesLSTM.Builder().nIn(3).nOut(8).activation(Activation.TANH).build(),
                                    "input")
                    .addVertex("subset", new SubsetVertex(0, 3), "lstm1")
                    .addLayer("out", new RnnOutputLayer.Builder().nIn(4).nOut(3).activation(Activation.SOFTMAX)
                                    .lossFunction(LossFunctions.LossFunction.MCXENT).build(), "subset")
                    .pretrain(false).backprop(true).build();

    ComputationGraph net = new ComputationGraph(conf);
    net.init();

    // Random input and one-hot labels at every time step
    Random rng = new Random(12345);
    INDArray input = Nd4j.rand(new int[] {3, 3, 5});
    INDArray labels = Nd4j.zeros(3, 3, 5);
    for (int example = 0; example < 3; example++) {
        for (int step = 0; step < 5; step++) {
            labels.putScalar(new int[] {example, rng.nextInt(3), step}, 1.0);
        }
    }

    if (PRINT_RESULTS) {
        System.out.println("testLSTMWithSubset()");
        for (int j = 0; j < net.getNumLayers(); j++)
            System.out.println("Layer " + j + " # params: " + net.getLayer(j).numParams());
    }

    // Numerical gradient check over the whole graph
    boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
                    DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, new INDArray[] {input},
                    new INDArray[] {labels});

    assertTrue("testLSTMWithSubset()", gradOK);
}
 
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:40,代码来源:GradientCheckTestsComputationGraph.java

示例10: testElementWiseVertexForwardSubtract

import org.deeplearning4j.nn.graph.ComputationGraph; //导入方法依赖的package包/类
@Test
public void testElementWiseVertexForwardSubtract() {
    int batchSize = 24;
    int featureSize = 17;
    ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().graphBuilder()
                    .addInputs("input1", "input2")
                    .addLayer("denselayer",
                                    new DenseLayer.Builder().nIn(featureSize).nOut(1).activation(Activation.IDENTITY)
                                                    .build(),
                                    "input1")
                    /* denselayer is not actually used, but it seems that you _need_ to have trainable parameters, otherwise, you get
                     * Invalid shape: Requested INDArray shape [1, 0] contains dimension size values < 1 (all dimensions must be 1 or more)
                     * at org.nd4j.linalg.factory.Nd4j.checkShapeValues(Nd4j.java:4877)
                     * at org.nd4j.linalg.factory.Nd4j.create(Nd4j.java:4867)
                     * at org.nd4j.linalg.factory.Nd4j.create(Nd4j.java:4820)
                     * at org.nd4j.linalg.factory.Nd4j.create(Nd4j.java:3948)
                     * at org.deeplearning4j.nn.graph.ComputationGraph.init(ComputationGraph.java:409)
                     * at org.deeplearning4j.nn.graph.ComputationGraph.init(ComputationGraph.java:341)
                     */
                    .addVertex("elementwiseSubtract", new ElementWiseVertex(ElementWiseVertex.Op.Subtract),
                                    "input1", "input2")
                    .addLayer("Subtract", new ActivationLayer.Builder().activation(Activation.IDENTITY).build(),
                                    "elementwiseSubtract")
                    .setOutputs("Subtract", "denselayer").build();

    ComputationGraph graph = new ComputationGraph(conf);
    graph.init();

    // Expected output is the plain element-wise difference of the two inputs
    INDArray a = Nd4j.rand(batchSize, featureSize);
    INDArray b = Nd4j.rand(batchSize, featureSize);
    INDArray expected = a.dup().subi(b);

    // First output of the graph is the subtraction branch
    INDArray actual = graph.output(a, b)[0];

    // Root of the summed squared error between actual and expected must be ~0
    INDArray diff = actual.sub(expected);
    double error = Math.sqrt(diff.mul(diff).sumNumber().doubleValue());
    Assert.assertEquals(0.0, error, this.epsilon);
}
 
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:40,代码来源:ElementWiseVertexTest.java

示例11: testWithPreprocessorsCG

import org.deeplearning4j.nn.graph.ComputationGraph; //导入方法依赖的package包/类
@Test
public void testWithPreprocessorsCG(){
    // Regression test for https://github.com/deeplearning4j/deeplearning4j/issues/4347:
    // layerVertex.setInput() applied the preprocessor without the result being
    // properly detached from the workspace.
    for (WorkspaceMode mode : WorkspaceMode.values()) {
        System.out.println(mode);
        ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
                .trainingWorkspaceMode(mode)
                .inferenceWorkspaceMode(mode)
                .graphBuilder()
                .addInputs("in")
                // The DupPreProcessor is the crucial part; without a preprocessor the issue did not trigger
                .addLayer("e", new GravesLSTM.Builder().nIn(10).nOut(5).build(), new DupPreProcessor(), "in")
                .addLayer("rnn", new GravesLSTM.Builder().nIn(5).nOut(8).build(), "e")
                .addLayer("out", new RnnOutputLayer.Builder(LossFunctions.LossFunction.MSE)
                        .activation(Activation.SIGMOID).nOut(3).build(), "rnn")
                .setInputTypes(InputType.recurrent(10))
                .setOutputs("out")
                .build();

        ComputationGraph graph = new ComputationGraph(conf);
        graph.init();

        INDArray[] input = new INDArray[]{Nd4j.zeros(1, 10, 5)};

        // Both inference and training forward passes must run without workspace errors
        for (boolean train : new boolean[]{false, true}) {
            graph.clear();
            graph.feedForward(input, train);
        }

        graph.setInputs(input);
        graph.setLabels(Nd4j.rand(1, 3, 5));
        graph.computeGradientAndScore();
    }
}
 
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:39,代码来源:WorkspaceTests.java

示例12: buildCNNGraph

import org.deeplearning4j.nn.graph.ComputationGraph; //导入方法依赖的package包/类
/**
 * Builds the sentence-classification CNN from Kim (2014): parallel convolution layers
 * with kernel widths 3, 4 and 5, depth-concatenated, globally pooled, and classified
 * into two classes (positive/negative).
 *
 * @param vectorSize          word-vector dimensionality (kernel/stride width)
 * @param cnnLayerFeatureMaps number of feature maps per convolution layer
 * @param globalPoolingType   pooling applied after the merge vertex
 * @return the initialized computation graph
 */
public static ComputationGraph buildCNNGraph (int vectorSize, int cnnLayerFeatureMaps, PoolingType globalPoolingType) {
    ComputationGraphConfiguration.GraphBuilder gb = new NeuralNetConfiguration.Builder()
            .weightInit(WeightInit.RELU)
            .activation(Activation.LEAKYRELU)
            .updater(Updater.ADAM)
            .convolutionMode(ConvolutionMode.Same)      //This is important so we can 'stack' the results later
            .regularization(true).l2(0.0001)
            .learningRate(0.01)
            .graphBuilder()
            .addInputs("input");

    // One convolution layer per kernel width, all reading from the raw input
    int[] kernelWidths = {3, 4, 5};
    String[] convNames = new String[kernelWidths.length];
    for (int i = 0; i < kernelWidths.length; i++) {
        convNames[i] = "cnn" + kernelWidths[i];
        gb.addLayer(convNames[i], new ConvolutionLayer.Builder()
                .kernelSize(kernelWidths[i], vectorSize)
                .stride(1, vectorSize)
                .nIn(1)
                .nOut(cnnLayerFeatureMaps)
                .build(), "input");
    }

    ComputationGraphConfiguration config = gb
            //Perform depth concatenation
            .addVertex("merge", new MergeVertex(), convNames)
            .addLayer("globalPool", new GlobalPoolingLayer.Builder()
                    .poolingType(globalPoolingType)
                    .build(), "merge")
            .addLayer("out", new OutputLayer.Builder()
                    .lossFunction(LossFunctions.LossFunction.MCXENT)
                    .activation(Activation.SOFTMAX)
                    .nIn(3 * cnnLayerFeatureMaps)
                    .nOut(2)    //2 classes: positive or negative
                    .build(), "globalPool")
            .setOutputs("out")
            .build();

    ComputationGraph net = new ComputationGraph(config);
    net.init();
    return net;
}
 
开发者ID:IsaacChanghau,项目名称:Word2VecfJava,代码行数:49,代码来源:CNNSentenceClassification.java

示例13: createComputationGraph

import org.deeplearning4j.nn.graph.ComputationGraph; //导入方法依赖的package包/类
/**
 * Configure and initialize the computation graph. This is done once in the
 * beginning to prepare the computation graph for training.
 *
 * Builds a sequence-to-sequence (encoder/decoder) graph: an embedding + LSTM
 * encoder compresses the input line into a "thought vector", which is then
 * duplicated across decoder time steps and merged with the decoder input
 * before the decoder LSTM and softmax output.
 *
 * @param dict vocabulary map (token to value); only its size is used here,
 *             fixing the input/output dimensions of the network
 * @return an initialized, untrained ComputationGraph
 */
public static ComputationGraph createComputationGraph (Map<String, Double> dict) {
    final NeuralNetConfiguration.Builder builder = new NeuralNetConfiguration.Builder()
            .iterations(1)
            .learningRate(LEARNING_RATE)
            .rmsDecay(RMS_DECAY)
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .miniBatch(true)
            .updater(Updater.RMSPROP)
            .weightInit(WeightInit.XAVIER)
            .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer);

    final ComputationGraphConfiguration.GraphBuilder graphBuilder = builder.graphBuilder()
            // Two recurrent inputs: the source line and the (shifted) decoder input
            .addInputs("inputLine", "decoderInput")
            .setInputTypes(InputType.recurrent(dict.size()), InputType.recurrent(dict.size()))
            // Token-id -> dense embedding lookup for the encoder
            .addLayer("embeddingEncoder",
                    new EmbeddingLayer.Builder()
                            .nIn(dict.size())
                            .nOut(EMBEDDING_WIDTH)
                            .build(),
                    "inputLine")
            .addLayer("encoder",
                    new GravesLSTM.Builder()
                            .nIn(EMBEDDING_WIDTH)
                            .nOut(HIDDEN_LAYER_WIDTH)
                            .activation(Activation.TANH)
                            .gateActivationFunction(Activation.HARDSIGMOID)
                            .build(),
                    "embeddingEncoder")
            // The encoder's final time step is the fixed-size "thought vector"
            .addVertex("thoughtVector",
                    new LastTimeStepVertex("inputLine"),
                    "encoder")
            // Broadcast the thought vector across all decoder time steps
            .addVertex("dup",
                    new DuplicateToTimeSeriesVertex("decoderInput"),
                    "thoughtVector")
            // Concatenate the decoder input with the duplicated thought vector
            .addVertex("merge",
                    new MergeVertex(),
                    "decoderInput",
                    "dup")
            .addLayer("decoder",
                    new GravesLSTM.Builder()
                            .nIn(dict.size() + HIDDEN_LAYER_WIDTH)
                            .nOut(HIDDEN_LAYER_WIDTH)
                            .activation(Activation.TANH)
                            .gateActivationFunction(Activation.HARDSIGMOID) // always be a (hard) sigmoid function
                            .build(),
                    "merge")
            .addLayer("output",
                    new RnnOutputLayer.Builder()
                            .nIn(HIDDEN_LAYER_WIDTH)
                            .nOut(dict.size())
                            .activation(Activation.SOFTMAX)
                            .lossFunction(LossFunctions.LossFunction.MCXENT) // multi-class cross entropy
                            .build(),
                    "decoder")
            .setOutputs("output")
            // NOTE(review): backprop type is Standard, so the tBPTT lengths below
            // appear to have no effect unless BackpropType.TruncatedBPTT is used
            // instead — confirm whether truncated BPTT was intended here.
            .backpropType(BackpropType.Standard) // why not BackpropType.TruncatedBPTT
            .tBPTTForwardLength(TBPTT_SIZE)
            .tBPTTBackwardLength(TBPTT_SIZE)
            .pretrain(false)
            .backprop(true);

    ComputationGraph net = new ComputationGraph(graphBuilder.build());
    net.init();
    return net;
}
 
开发者ID:IsaacChanghau,项目名称:NeuralNetworksLite,代码行数:70,代码来源:ConstructGraph.java

示例14: testRnnTimeStep

import org.deeplearning4j.nn.graph.ComputationGraph; //导入方法依赖的package包/类
@Test
public void testRnnTimeStep(){
    // For every workspace mode and three recurrent layer types (SimpleRnn, LSTM,
    // GravesLSTM), build equivalent MultiLayerNetwork and ComputationGraph models
    // and verify that rnnTimeStep() runs repeatedly on both.
    for (WorkspaceMode workspaceMode : WorkspaceMode.values()) {
        for (int layerType = 0; layerType < 3; layerType++) {

            System.out.println("Starting test: " + workspaceMode + " - " + layerType);

            NeuralNetConfiguration.ListBuilder listBuilder = new NeuralNetConfiguration.Builder()
                    .weightInit(WeightInit.XAVIER)
                    .activation(Activation.TANH)
                    .inferenceWorkspaceMode(workspaceMode)
                    .trainingWorkspaceMode(workspaceMode)
                    .list();

            ComputationGraphConfiguration.GraphBuilder graphBuilder = new NeuralNetConfiguration.Builder()
                    .weightInit(WeightInit.XAVIER)
                    .activation(Activation.TANH)
                    .inferenceWorkspaceMode(workspaceMode)
                    .trainingWorkspaceMode(workspaceMode)
                    .graphBuilder()
                    .addInputs("in");

            // Two stacked recurrent layers of the selected type, identical in both models
            if (layerType == 0) {
                listBuilder.layer(new SimpleRnn.Builder().nIn(10).nOut(10).build());
                listBuilder.layer(new SimpleRnn.Builder().nIn(10).nOut(10).build());

                graphBuilder.addLayer("0", new SimpleRnn.Builder().nIn(10).nOut(10).build(), "in");
                graphBuilder.addLayer("1", new SimpleRnn.Builder().nIn(10).nOut(10).build(), "0");
            } else if (layerType == 1) {
                listBuilder.layer(new LSTM.Builder().nIn(10).nOut(10).build());
                listBuilder.layer(new LSTM.Builder().nIn(10).nOut(10).build());

                graphBuilder.addLayer("0", new LSTM.Builder().nIn(10).nOut(10).build(), "in");
                graphBuilder.addLayer("1", new LSTM.Builder().nIn(10).nOut(10).build(), "0");
            } else if (layerType == 2) {
                listBuilder.layer(new GravesLSTM.Builder().nIn(10).nOut(10).build());
                listBuilder.layer(new GravesLSTM.Builder().nIn(10).nOut(10).build());

                graphBuilder.addLayer("0", new GravesLSTM.Builder().nIn(10).nOut(10).build(), "in");
                graphBuilder.addLayer("1", new GravesLSTM.Builder().nIn(10).nOut(10).build(), "0");
            } else {
                throw new RuntimeException();
            }

            listBuilder.layer(new RnnOutputLayer.Builder().nIn(10).nOut(10).build());
            graphBuilder.addLayer("out", new RnnOutputLayer.Builder().nIn(10).nOut(10).build(), "1");
            graphBuilder.setOutputs("out");

            MultiLayerConfiguration mlnConf = listBuilder.build();
            ComputationGraphConfiguration cgConf = graphBuilder.build();


            MultiLayerNetwork net = new MultiLayerNetwork(mlnConf);
            net.init();

            ComputationGraph net2 = new ComputationGraph(cgConf);
            net2.init();

            // Repeated stateful single-step inference on both model types
            for (int step = 0; step < 3; step++) {
                net.rnnTimeStep(Nd4j.rand(new int[]{3, 10, 5}));
            }

            for (int step = 0; step < 3; step++) {
                net2.rnnTimeStep(Nd4j.rand(new int[]{3, 10, 5}));
            }
        }
    }
}
 
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:73,代码来源:WorkspaceTests.java

示例15: testSerializationCompGraph

import org.deeplearning4j.nn.graph.ComputationGraph; //导入方法依赖的package包/类
@Test
public void testSerializationCompGraph() throws Exception {

    for (WorkspaceMode wsm : WorkspaceMode.values()) {
        log.info("*** Starting workspace mode: " + wsm);

        Nd4j.getRandom().setSeed(12345);

        // Bidirectional GravesLSTM stack with an MSE RNN output layer
        ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
                .activation(Activation.TANH)
                .weightInit(WeightInit.XAVIER)
                .trainingWorkspaceMode(wsm)
                .inferenceWorkspaceMode(wsm)
                .updater(new Adam())
                .graphBuilder()
                .addInputs("in")
                .layer("0", new Bidirectional(Bidirectional.Mode.ADD, new GravesLSTM.Builder().nIn(10).nOut(10).build()), "in")
                .layer("1", new Bidirectional(Bidirectional.Mode.ADD, new GravesLSTM.Builder().nIn(10).nOut(10).build()), "0")
                .layer("2", new RnnOutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE)
                        .nIn(10).nOut(10).build(), "1")
                .setOutputs("2")
                .build();

        ComputationGraph original = new ComputationGraph(conf);
        original.init();

        INDArray in = Nd4j.rand(new int[]{3, 10, 5});
        INDArray labels = Nd4j.rand(new int[]{3, 10, 5});

        // One fit so the updater state is non-trivial before serialization
        original.fit(new DataSet(in, labels));

        // Round-trip through ModelSerializer entirely in memory
        byte[] serialized;
        try (ByteArrayOutputStream baos = new ByteArrayOutputStream()) {
            ModelSerializer.writeModel(original, baos, true);
            serialized = baos.toByteArray();
        }


        ComputationGraph restored = ModelSerializer.restoreComputationGraph(new ByteArrayInputStream(serialized), true);


        // Fresh data: outputs, score and gradients must match exactly after restore
        in = Nd4j.rand(new int[]{3, 10, 5});
        labels = Nd4j.rand(new int[]{3, 10, 5});

        INDArray outOriginal = original.outputSingle(in);
        INDArray outRestored = restored.outputSingle(in);

        assertEquals(outOriginal, outRestored);

        original.setInput(0, in);
        restored.setInput(0, in);
        original.setLabels(labels);
        restored.setLabels(labels);

        original.computeGradientAndScore();
        restored.computeGradientAndScore();

        assertEquals(original.score(), restored.score(), 1e-6);
        assertEquals(original.gradient().gradient(), restored.gradient().gradient());
    }
}
 
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:62,代码来源:BidirectionalTest.java


注:本文中的org.deeplearning4j.nn.graph.ComputationGraph.init方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。