本文整理汇总了Java中org.deeplearning4j.nn.graph.ComputationGraph.scoreExamples方法的典型用法代码示例。如果您正苦于以下问题:Java ComputationGraph.scoreExamples方法的具体用法?Java ComputationGraph.scoreExamples怎么用?Java ComputationGraph.scoreExamples使用的例子?那么恭喜您, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在类org.deeplearning4j.nn.graph.ComputationGraph
的用法示例。
在下文中一共展示了ComputationGraph.scoreExamples方法的3个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Java代码示例。
示例1: call
import org.deeplearning4j.nn.graph.ComputationGraph; //导入方法依赖的package包/类
/**
 * Scores each (key, MultiDataSet) pair in this partition against a copy of the
 * network rebuilt from the broadcast configuration and parameters.
 * <p>
 * Each incoming MultiDataSet must contain exactly one example so that scores can
 * be paired one-to-one with keys. Data sets are merged into mini-batches of up to
 * {@code batchSize} examples before scoring, to amortize the forward-pass cost.
 *
 * @param iterator partition iterator of (key, single-example MultiDataSet) pairs
 * @return one (key, score) pair per input example; empty iterable for an empty partition
 * @throws IllegalStateException if the broadcast parameter count does not match the
 *         network, or if any data set contains more than one example
 */
@Override
public Iterable<Tuple2<K, Double>> call(Iterator<Tuple2<K, MultiDataSet>> iterator) throws Exception {
    if (!iterator.hasNext()) {
        return Collections.emptyList();
    }

    // Reconstruct the network locally on this executor from the broadcast JSON config
    ComputationGraph network =
            new ComputationGraph(ComputationGraphConfiguration.fromJson(jsonConfig.getValue()));
    network.init();

    // Copy broadcast parameters; sanity-check the count before applying them
    INDArray paramsCopy = params.value().unsafeDuplication();
    if (paramsCopy.length() != network.numParams(false))
        throw new IllegalStateException(
                "Network did not have same number of parameters as the broadcast set parameters");
    network.setParams(paramsCopy);

    List<Tuple2<K, Double>> scored = new ArrayList<>();
    List<MultiDataSet> batch = new ArrayList<>(batchSize);
    List<K> batchKeys = new ArrayList<>(batchSize);
    int scoredCount = 0;

    while (iterator.hasNext()) {
        batch.clear();
        batchKeys.clear();

        // Accumulate up to batchSize single-example data sets, remembering their keys
        int examplesInBatch = 0;
        while (iterator.hasNext() && examplesInBatch < batchSize) {
            Tuple2<K, MultiDataSet> pair = iterator.next();
            MultiDataSet mds = pair._2();
            int numExamples = mds.getFeatures(0).size(0);
            // One key per data set requires exactly one example per data set
            if (numExamples != 1)
                throw new IllegalStateException("Cannot score examples with one key per data set if "
                        + "data set contains more than 1 example (numExamples: " + numExamples + ")");
            batch.add(mds);
            batchKeys.add(pair._1());
            examplesInBatch += numExamples;
        }
        scoredCount += examplesInBatch;

        // Merge the batch and score; scores come back in the same order as the merged examples
        MultiDataSet merged = org.nd4j.linalg.dataset.MultiDataSet.merge(batch);
        INDArray scores = network.scoreExamples(merged, addRegularization);
        double[] scoreValues = scores.data().asDouble();
        for (int i = 0; i < scoreValues.length; i++) {
            scored.add(new Tuple2<>(batchKeys.get(i), scoreValues[i]));
        }
    }

    // Ensure any pending async ND4J ops complete before returning
    Nd4j.getExecutioner().commit();

    if (log.isDebugEnabled()) {
        log.debug("Scored {} examples ", scoredCount);
    }

    return scored;
}
示例2: call
import org.deeplearning4j.nn.graph.ComputationGraph; //导入方法依赖的package包/类
/**
 * Scores every example in this partition against a copy of the network rebuilt from
 * the broadcast configuration and parameters.
 * <p>
 * Unlike the keyed variant, incoming MultiDataSets may contain any number of
 * examples; data sets are merged into mini-batches of up to {@code batchSize}
 * examples before scoring.
 *
 * @param iterator partition iterator of MultiDataSets to score
 * @return one score per input example, in encounter order; empty iterable for an empty partition
 * @throws IllegalStateException if the broadcast parameter count does not match the network
 */
@Override
public Iterable<Double> call(Iterator<MultiDataSet> iterator) throws Exception {
    if (!iterator.hasNext()) {
        return Collections.emptyList();
    }

    // Reconstruct the network locally on this executor from the broadcast JSON config
    ComputationGraph network =
            new ComputationGraph(ComputationGraphConfiguration.fromJson(jsonConfig.getValue()));
    network.init();

    // Copy broadcast parameters; sanity-check the count before applying them
    INDArray paramsCopy = params.value().unsafeDuplication();
    if (paramsCopy.length() != network.numParams(false))
        throw new IllegalStateException(
                "Network did not have same number of parameters as the broadcast set parameters");
    network.setParams(paramsCopy);

    List<Double> scored = new ArrayList<>();
    List<MultiDataSet> batch = new ArrayList<>(batchSize);
    int scoredCount = 0;

    while (iterator.hasNext()) {
        batch.clear();

        // Accumulate data sets until the batch holds at least batchSize examples
        int examplesInBatch = 0;
        while (iterator.hasNext() && examplesInBatch < batchSize) {
            MultiDataSet mds = iterator.next();
            int numExamples = mds.getFeatures(0).size(0);
            batch.add(mds);
            examplesInBatch += numExamples;
        }
        scoredCount += examplesInBatch;

        // Merge the batch and score; one score per example, in order
        MultiDataSet merged = org.nd4j.linalg.dataset.MultiDataSet.merge(batch);
        INDArray scores = network.scoreExamples(merged, addRegularization);
        for (double score : scores.data().asDouble()) {
            scored.add(score);
        }
    }

    // Ensure any pending async ND4J ops complete before returning
    Nd4j.getExecutioner().commit();

    if (log.isDebugEnabled()) {
        log.debug("Scored {} examples ", scoredCount);
    }

    return scored;
}
示例3: testDistributedScoring
import org.deeplearning4j.nn.graph.ComputationGraph; //导入方法依赖的package包/类
/**
 * Verifies that Spark-distributed example scoring (SparkComputationGraph.scoreExamples)
 * matches local scoring (ComputationGraph.scoreExamples), both with and without
 * regularization, for the keyed and unkeyed RDD variants.
 */
@Test
public void testDistributedScoring() {
// Small dense->softmax graph with L1/L2 regularization so that the
// with-regularization and without-regularization scores actually differ.
// NOTE(review): nIn/nOut/sc/numExecutors() are fields/methods of the enclosing
// test class, not visible in this excerpt.
ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().l1(0.1).l2(0.1)
.seed(123).updater(new Nesterovs(0.1, 0.9)).graphBuilder()
.addInputs("in")
.addLayer("0", new org.deeplearning4j.nn.conf.layers.DenseLayer.Builder().nIn(nIn).nOut(3)
.activation(Activation.TANH).build(), "in")
.addLayer("1", new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(
LossFunctions.LossFunction.MCXENT).nIn(3).nOut(nOut)
.activation(Activation.SOFTMAX).build(),
"0")
.setOutputs("1").backprop(true).pretrain(false).build();
TrainingMaster tm = new ParameterAveragingTrainingMaster(true, numExecutors(), 1, 10, 1, 0);
SparkComputationGraph sparkNet = new SparkComputationGraph(sc, conf, tm);
// Clone the network so local scoring is independent of the Spark-side instance
ComputationGraph netCopy = sparkNet.getNetwork().clone();
// Random features with one-hot labels (one random class per row)
int nRows = 100;
INDArray features = Nd4j.rand(nRows, nIn);
INDArray labels = Nd4j.zeros(nRows, nOut);
Random r = new Random(12345);
for (int i = 0; i < nRows; i++) {
labels.putScalar(new int[] {i, r.nextInt(nOut)}, 1.0);
}
// Expected values: score each example locally, with and without regularization
INDArray localScoresWithReg = netCopy.scoreExamples(new DataSet(features, labels), true);
INDArray localScoresNoReg = netCopy.scoreExamples(new DataSet(features, labels), false);
// --- Keyed variant: key each single-example DataSet by its row index ---
List<Tuple2<String, DataSet>> dataWithKeys = new ArrayList<>();
for (int i = 0; i < nRows; i++) {
DataSet ds = new DataSet(features.getRow(i).dup(), labels.getRow(i).dup());
dataWithKeys.add(new Tuple2<>(String.valueOf(i), ds));
}
JavaPairRDD<String, DataSet> dataWithKeysRdd = sc.parallelizePairs(dataWithKeys);
// Distributed scoring with mini-batch size 4
JavaPairRDD<String, Double> sparkScoresWithReg = sparkNet.scoreExamples(dataWithKeysRdd, true, 4);
JavaPairRDD<String, Double> sparkScoresNoReg = sparkNet.scoreExamples(dataWithKeysRdd, false, 4);
Map<String, Double> sparkScoresWithRegMap = sparkScoresWithReg.collectAsMap();
Map<String, Double> sparkScoresNoRegMap = sparkScoresNoReg.collectAsMap();
// Keys let us compare each distributed score directly against its local counterpart
for (int i = 0; i < nRows; i++) {
double scoreRegExp = localScoresWithReg.getDouble(i);
double scoreRegAct = sparkScoresWithRegMap.get(String.valueOf(i));
assertEquals(scoreRegExp, scoreRegAct, 1e-5);
double scoreNoRegExp = localScoresNoReg.getDouble(i);
double scoreNoRegAct = sparkScoresNoRegMap.get(String.valueOf(i));
assertEquals(scoreNoRegExp, scoreNoRegAct, 1e-5);
// System.out.println(scoreRegExp + "\t" + scoreRegAct + "\t" + scoreNoRegExp + "\t" + scoreNoRegAct);
}
// --- Unkeyed variant: no keys, so Spark gives no ordering guarantee ---
List<DataSet> dataNoKeys = new ArrayList<>();
for (int i = 0; i < nRows; i++) {
dataNoKeys.add(new DataSet(features.getRow(i).dup(), labels.getRow(i).dup()));
}
JavaRDD<DataSet> dataNoKeysRdd = sc.parallelize(dataNoKeys);
List<Double> scoresWithReg = new ArrayList<>(sparkNet.scoreExamples(dataNoKeysRdd, true, 4).collect());
List<Double> scoresNoReg = new ArrayList<>(sparkNet.scoreExamples(dataNoKeysRdd, false, 4).collect());
// Sort both sides before comparing, since collect() order is not deterministic
Collections.sort(scoresWithReg);
Collections.sort(scoresNoReg);
double[] localScoresWithRegDouble = localScoresWithReg.data().asDouble();
double[] localScoresNoRegDouble = localScoresNoReg.data().asDouble();
Arrays.sort(localScoresWithRegDouble);
Arrays.sort(localScoresNoRegDouble);
for (int i = 0; i < localScoresWithRegDouble.length; i++) {
assertEquals(localScoresWithRegDouble[i], scoresWithReg.get(i), 1e-5);
assertEquals(localScoresNoRegDouble[i], scoresNoReg.get(i), 1e-5);
// System.out.println(localScoresWithRegDouble[i] + "\t" + scoresWithReg.get(i) + "\t" + localScoresNoRegDouble[i] + "\t" + scoresNoReg.get(i));
}
}