本文整理汇总了Java中org.deeplearning4j.nn.graph.ComputationGraph.setInput方法的典型用法代码示例。如果您正苦于以下问题：Java ComputationGraph.setInput方法的具体用法？Java ComputationGraph.setInput怎么用？Java ComputationGraph.setInput使用的例子？那么恭喜您，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在类org.deeplearning4j.nn.graph.ComputationGraph的用法示例。
在下文中一共展示了ComputationGraph.setInput方法的2个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Java代码示例。
示例1: testLambdaConf
import org.deeplearning4j.nn.graph.ComputationGraph; //导入方法依赖的package包/类
/**
 * Scores one graph per lambda value on the same input/label pair and expects the two
 * resulting scores to differ.
 *
 * <p>NOTE(review): {@code getGraph} is not visible in this snippet; presumably the second
 * argument is a regularization coefficient applied to the graph -- confirm against the helper.
 */
@Test
public void testLambdaConf() {
    final int exampleCount = 150;
    final int featureCount = 4;
    final int classCount = 2;
    final double[] lambdaValues = {0.1, 0.01};
    final double[] scores = new double[lambdaValues.length];

    INDArray features = Nd4j.rand(exampleCount, featureCount);
    INDArray oneHotLabels = Nd4j.zeros(exampleCount, classCount);

    // Fixed seed: every run marks the same column in each label row.
    Random classPicker = new Random(12345);
    int row = 0;
    while (row < exampleCount) {
        oneHotLabels.putScalar(row, classPicker.nextInt(classCount), 1.0);
        row++;
    }

    // A fresh graph per lambda; input and labels are shared across both graphs.
    int slot = 0;
    for (double lambda : lambdaValues) {
        ComputationGraph net = getGraph(classCount, lambda);
        net.setInput(0, features);
        net.setLabel(0, oneHotLabels);
        net.computeGradientAndScore();
        scores[slot++] = net.score();
    }

    assertNotEquals(scores[0], scores[1]);
}
示例2: testSerializationCompGraph
import org.deeplearning4j.nn.graph.ComputationGraph; //导入方法依赖的package包/类
/**
 * For every {@code WorkspaceMode}: trains a small graph for one fit call, writes it to a byte
 * array with {@code ModelSerializer}, restores it, and checks that the restored graph agrees
 * with the original on output, score and gradient for fresh random data.
 *
 * @throws Exception if any step in the body (e.g. model serialization or restoration) fails
 */
@Test
public void testSerializationCompGraph() throws Exception {
    for (WorkspaceMode mode : WorkspaceMode.values()) {
        log.info("*** Starting workspace mode: " + mode);

        // Same seed at the start of every workspace-mode pass.
        Nd4j.getRandom().setSeed(12345);

        // Layer configurations, named so the graph wiring below reads as a plain chain.
        Bidirectional firstRecurrent = new Bidirectional(Bidirectional.Mode.ADD,
                new GravesLSTM.Builder().nIn(10).nOut(10).build());
        Bidirectional secondRecurrent = new Bidirectional(Bidirectional.Mode.ADD,
                new GravesLSTM.Builder().nIn(10).nOut(10).build());
        RnnOutputLayer outputLayer = new RnnOutputLayer.Builder()
                .lossFunction(LossFunctions.LossFunction.MSE)
                .nIn(10)
                .nOut(10)
                .build();

        ComputationGraphConfiguration configuration = new NeuralNetConfiguration.Builder()
                .activation(Activation.TANH)
                .weightInit(WeightInit.XAVIER)
                .trainingWorkspaceMode(mode)
                .inferenceWorkspaceMode(mode)
                .updater(new Adam())
                .graphBuilder()
                .addInputs("in")
                .layer("0", firstRecurrent, "in")
                .layer("1", secondRecurrent, "0")
                .layer("2", outputLayer, "1")
                .setOutputs("2")
                .build();

        ComputationGraph original = new ComputationGraph(configuration);
        original.init();

        // One fit call on random data before the graph is serialized.
        INDArray features = Nd4j.rand(new int[] {3, 10, 5});
        INDArray targets = Nd4j.rand(new int[] {3, 10, 5});
        original.fit(new DataSet(features, targets));

        // In-memory round trip; the boolean flag 'true' is passed to both write and restore.
        byte[] serialized;
        try (ByteArrayOutputStream sink = new ByteArrayOutputStream()) {
            ModelSerializer.writeModel(original, sink, true);
            serialized = sink.toByteArray();
        }
        ComputationGraph restored =
                ModelSerializer.restoreComputationGraph(new ByteArrayInputStream(serialized), true);

        // Fresh random data for the comparison.
        features = Nd4j.rand(new int[] {3, 10, 5});
        targets = Nd4j.rand(new int[] {3, 10, 5});

        INDArray originalOutput = original.outputSingle(features);
        INDArray restoredOutput = restored.outputSingle(features);
        assertEquals(originalOutput, restoredOutput);

        // Each step is applied to the original first, then the restored copy.
        ComputationGraph[] pair = {original, restored};
        for (ComputationGraph net : pair) {
            net.setInput(0, features);
        }
        for (ComputationGraph net : pair) {
            net.setLabels(targets);
        }
        for (ComputationGraph net : pair) {
            net.computeGradientAndScore();
        }

        double originalScore = original.score();
        double restoredScore = restored.score();
        assertEquals(originalScore, restoredScore, 1e-6);
        assertEquals(original.gradient().gradient(), restored.gradient().gradient());
    }
}