当前位置: 首页>>代码示例>>Java>>正文


Java ComputationGraph.scoreExamples方法代码示例

本文整理汇总了Java中org.deeplearning4j.nn.graph.ComputationGraph.scoreExamples方法的典型用法代码示例。如果您正苦于以下问题：Java ComputationGraph.scoreExamples方法的具体用法？Java ComputationGraph.scoreExamples怎么用？Java ComputationGraph.scoreExamples使用的例子？那么恭喜您，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在类org.deeplearning4j.nn.graph.ComputationGraph的用法示例。


在下文中一共展示了ComputationGraph.scoreExamples方法的3个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Java代码示例。

示例1: call

import org.deeplearning4j.nn.graph.ComputationGraph; //导入方法依赖的package包/类
/**
 * Scores every example in a Spark partition and pairs each score with that example's key.
 *
 * <p>The network is rebuilt from the JSON configuration held in {@code jsonConfig} and the
 * parameters held in {@code params}. Data sets are then merged into mini-batches of at most
 * {@code batchSize} examples and scored with {@code ComputationGraph.scoreExamples}, passing
 * {@code addRegularization} through.
 *
 * @param iterator the partition contents; every {@code MultiDataSet} must hold exactly one
 *     example, because only one key is available per data set
 * @return one {@code (key, score)} tuple per input element; within each batch the i-th score
 *     returned by the network is paired with the i-th key collected for that batch
 * @throws IllegalStateException if {@code batchSize} is not positive, if the parameter count does
 *     not match the network, if a data set does not hold exactly one example, or if the network
 *     returns a different number of scores than there are keys in the batch
 */
@Override
public Iterable<Tuple2<K, Double>> call(Iterator<Tuple2<K, MultiDataSet>> iterator) throws Exception {
    if (!iterator.hasNext()) {
        return Collections.emptyList();
    }
    // With a non-positive batchSize the inner loop below would never collect an example, leaving
    // an empty batch to merge; reject it up front with a clear message instead.
    if (batchSize <= 0) {
        throw new IllegalStateException("Batch size must be positive (batchSize: " + batchSize + ")");
    }

    ComputationGraph network = new ComputationGraph(ComputationGraphConfiguration.fromJson(jsonConfig.getValue()));
    network.init();
    // Duplicate the parameter array before handing it to this network.
    INDArray val = params.value().unsafeDuplication();
    if (val.length() != network.numParams(false)) {
        throw new IllegalStateException(
                        "Network did not have same number of parameters as the broadcast set parameters");
    }
    network.setParams(val);

    List<Tuple2<K, Double>> ret = new ArrayList<>();

    // Buffers are reused across batches; index i of collect corresponds to index i of collectKey.
    List<MultiDataSet> collect = new ArrayList<>(batchSize);
    List<K> collectKey = new ArrayList<>(batchSize);
    int totalCount = 0;
    while (iterator.hasNext()) {
        collect.clear();
        collectKey.clear();
        int nExamples = 0;
        while (iterator.hasNext() && nExamples < batchSize) {
            Tuple2<K, MultiDataSet> t2 = iterator.next();
            MultiDataSet ds = t2._2();
            int n = ds.getFeatures(0).size(0);
            // Only one key exists per data set, so each data set must carry exactly one example
            // (zero examples is rejected too, not just more than one).
            if (n != 1) {
                throw new IllegalStateException("Cannot score examples with one key per data set unless "
                                + "each data set contains exactly 1 example (numExamples: " + n + ")");
            }
            collect.add(ds);
            collectKey.add(t2._1());
            nExamples += n;
        }
        totalCount += nExamples;

        MultiDataSet data = org.nd4j.linalg.dataset.MultiDataSet.merge(collect);

        INDArray scores = network.scoreExamples(data, addRegularization);
        double[] doubleScores = scores.data().asDouble();
        // Without this check a mismatch would either throw IndexOutOfBoundsException on
        // collectKey.get(i) or silently drop the trailing keys.
        if (doubleScores.length != collectKey.size()) {
            throw new IllegalStateException("Number of scores (" + doubleScores.length
                            + ") does not match number of keys in batch (" + collectKey.size() + ")");
        }

        for (int i = 0; i < doubleScores.length; i++) {
            ret.add(new Tuple2<>(collectKey.get(i), doubleScores[i]));
        }
    }

    // NOTE(review): presumably flushes any operations still queued in the ND4J executioner — confirm.
    Nd4j.getExecutioner().commit();

    if (log.isDebugEnabled()) {
        log.debug("Scored {} examples ", totalCount);
    }

    return ret;
}
 
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:56,代码来源:ScoreExamplesWithKeyFunction.java

示例2: call

import org.deeplearning4j.nn.graph.ComputationGraph; //导入方法依赖的package包/类
/**
 * Scores every example in a Spark partition, returning the scores without keys.
 *
 * <p>The network is rebuilt from the JSON configuration held in {@code jsonConfig} and the
 * parameters held in {@code params}. Data sets are then merged into mini-batches and scored with
 * {@code ComputationGraph.scoreExamples}, passing {@code addRegularization} through. A batch is
 * filled until it holds at least {@code batchSize} examples or the partition runs out. A batch can
 * therefore exceed {@code batchSize} when individual data sets contain several examples.
 *
 * @param iterator the partition contents; each {@code MultiDataSet} may hold one or more examples
 * @return all per-example scores, appended batch by batch in the order the network returns them
 * @throws IllegalStateException if {@code batchSize} is not positive, or if the parameter count
 *     does not match the network
 */
@Override
public Iterable<Double> call(Iterator<MultiDataSet> iterator) throws Exception {
    if (!iterator.hasNext()) {
        return Collections.emptyList();
    }
    // With a non-positive batchSize the inner loop below would never collect a data set, leaving
    // an empty batch to merge; reject it up front with a clear message instead.
    if (batchSize <= 0) {
        throw new IllegalStateException("Batch size must be positive (batchSize: " + batchSize + ")");
    }

    ComputationGraph network = new ComputationGraph(ComputationGraphConfiguration.fromJson(jsonConfig.getValue()));
    network.init();
    // Duplicate the parameter array before handing it to this network.
    INDArray val = params.value().unsafeDuplication();
    if (val.length() != network.numParams(false)) {
        throw new IllegalStateException(
                        "Network did not have same number of parameters as the broadcast set parameters");
    }
    network.setParams(val);

    List<Double> ret = new ArrayList<>();

    // Reused across batches to avoid reallocating the buffer for every mini-batch.
    List<MultiDataSet> collect = new ArrayList<>(batchSize);
    int totalCount = 0;
    while (iterator.hasNext()) {
        collect.clear();
        int nExamples = 0;
        while (iterator.hasNext() && nExamples < batchSize) {
            MultiDataSet ds = iterator.next();
            int n = ds.getFeatures(0).size(0);
            collect.add(ds);
            nExamples += n;
        }
        totalCount += nExamples;

        MultiDataSet data = org.nd4j.linalg.dataset.MultiDataSet.merge(collect);

        INDArray scores = network.scoreExamples(data, addRegularization);
        double[] doubleScores = scores.data().asDouble();

        for (double doubleScore : doubleScores) {
            ret.add(doubleScore);
        }
    }

    // NOTE(review): presumably flushes any operations still queued in the ND4J executioner — confirm.
    Nd4j.getExecutioner().commit();

    if (log.isDebugEnabled()) {
        log.debug("Scored {} examples ", totalCount);
    }

    return ret;
}
 
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:50,代码来源:ScoreExamplesFunction.java

示例3: testDistributedScoring

import org.deeplearning4j.nn.graph.ComputationGraph; //导入方法依赖的package包/类
/**
 * Checks that distributed per-example scoring through {@code SparkComputationGraph} agrees with
 * local {@code ComputationGraph.scoreExamples}, both with and without regularization.
 *
 * <p>Keyed scores are compared key by key. Un-keyed scores carry nothing to align on, so both the
 * local and the Spark results are sorted and then compared element-wise.
 */
@Test
public void testDistributedScoring() {

    ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().l1(0.1).l2(0.1)
                    .seed(123).updater(new Nesterovs(0.1, 0.9)).graphBuilder()
                    .addInputs("in")
                    .addLayer("0", new org.deeplearning4j.nn.conf.layers.DenseLayer.Builder().nIn(nIn).nOut(3)
                                    .activation(Activation.TANH).build(), "in")
                    .addLayer("1", new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(
                                    LossFunctions.LossFunction.MCXENT).nIn(3).nOut(nOut)
                                                    .activation(Activation.SOFTMAX).build(),
                                    "0")
                    .setOutputs("1").backprop(true).pretrain(false).build();

    TrainingMaster tm = new ParameterAveragingTrainingMaster(true, numExecutors(), 1, 10, 1, 0);

    SparkComputationGraph sparkNet = new SparkComputationGraph(sc, conf, tm);
    // Local clone of the same network, used to compute the expected scores.
    ComputationGraph netCopy = sparkNet.getNetwork().clone();

    int nRows = 100;

    INDArray features = Nd4j.rand(nRows, nIn);
    // One-hot labels: each row gets a single seeded-random class set to 1.0.
    INDArray labels = Nd4j.zeros(nRows, nOut);
    Random r = new Random(12345);
    for (int i = 0; i < nRows; i++) {
        labels.putScalar(new int[] {i, r.nextInt(nOut)}, 1.0);
    }

    INDArray localScoresWithReg = netCopy.scoreExamples(new DataSet(features, labels), true);
    INDArray localScoresNoReg = netCopy.scoreExamples(new DataSet(features, labels), false);

    // Keyed scoring: each row becomes its own single-example DataSet, keyed by its row index.
    List<Tuple2<String, DataSet>> dataWithKeys = new ArrayList<>();
    for (int i = 0; i < nRows; i++) {
        DataSet ds = new DataSet(features.getRow(i).dup(), labels.getRow(i).dup());
        dataWithKeys.add(new Tuple2<>(String.valueOf(i), ds));
    }
    JavaPairRDD<String, DataSet> dataWithKeysRdd = sc.parallelizePairs(dataWithKeys);

    // NOTE(review): the third argument is presumably the scoring mini-batch size — confirm.
    JavaPairRDD<String, Double> sparkScoresWithReg = sparkNet.scoreExamples(dataWithKeysRdd, true, 4);
    JavaPairRDD<String, Double> sparkScoresNoReg = sparkNet.scoreExamples(dataWithKeysRdd, false, 4);

    Map<String, Double> sparkScoresWithRegMap = sparkScoresWithReg.collectAsMap();
    Map<String, Double> sparkScoresNoRegMap = sparkScoresNoReg.collectAsMap();

    // Fail with a clear message rather than an NPE on unboxing when scores are missing.
    assertEquals("Unexpected number of keyed scores (with regularization)", nRows,
                    sparkScoresWithRegMap.size());
    assertEquals("Unexpected number of keyed scores (no regularization)", nRows,
                    sparkScoresNoRegMap.size());

    for (int i = 0; i < nRows; i++) {
        double scoreRegExp = localScoresWithReg.getDouble(i);
        double scoreRegAct = sparkScoresWithRegMap.get(String.valueOf(i));
        assertEquals(scoreRegExp, scoreRegAct, 1e-5);

        double scoreNoRegExp = localScoresNoReg.getDouble(i);
        double scoreNoRegAct = sparkScoresNoRegMap.get(String.valueOf(i));
        assertEquals(scoreNoRegExp, scoreNoRegAct, 1e-5);
    }

    // Un-keyed scoring: same rows, no keys; order is not guaranteed, so compare sorted values.
    List<DataSet> dataNoKeys = new ArrayList<>();
    for (int i = 0; i < nRows; i++) {
        dataNoKeys.add(new DataSet(features.getRow(i).dup(), labels.getRow(i).dup()));
    }
    JavaRDD<DataSet> dataNoKeysRdd = sc.parallelize(dataNoKeys);

    List<Double> scoresWithReg = new ArrayList<>(sparkNet.scoreExamples(dataNoKeysRdd, true, 4).collect());
    List<Double> scoresNoReg = new ArrayList<>(sparkNet.scoreExamples(dataNoKeysRdd, false, 4).collect());
    // Without these checks, extra Spark results would pass unnoticed and missing ones would
    // surface as an IndexOutOfBoundsException rather than an assertion failure.
    assertEquals("Unexpected number of un-keyed scores (with regularization)", nRows, scoresWithReg.size());
    assertEquals("Unexpected number of un-keyed scores (no regularization)", nRows, scoresNoReg.size());
    Collections.sort(scoresWithReg);
    Collections.sort(scoresNoReg);
    double[] localScoresWithRegDouble = localScoresWithReg.data().asDouble();
    double[] localScoresNoRegDouble = localScoresNoReg.data().asDouble();
    Arrays.sort(localScoresWithRegDouble);
    Arrays.sort(localScoresNoRegDouble);

    for (int i = 0; i < localScoresWithRegDouble.length; i++) {
        assertEquals(localScoresWithRegDouble[i], scoresWithReg.get(i), 1e-5);
        assertEquals(localScoresNoRegDouble[i], scoresNoReg.get(i), 1e-5);
    }
}
 
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:79,代码来源:TestSparkComputationGraph.java


注：本文中的org.deeplearning4j.nn.graph.ComputationGraph.scoreExamples方法示例由纯净天空整理自GitHub/MSDocs等开源代码及文档管理平台，相关代码片段筛选自各路编程大神贡献的开源项目，源码版权归原作者所有，传播和使用请参考对应项目的License；未经允许，请勿转载。