

Java ComputationGraph.evaluate Method Code Examples

This article collects typical usage examples of the Java method org.deeplearning4j.nn.graph.ComputationGraph.evaluate. If you are wondering how ComputationGraph.evaluate is used in practice, or what it does, the curated examples below may help. You can also explore further usage examples of org.deeplearning4j.nn.graph.ComputationGraph itself.


The sections below show 3 code examples of the ComputationGraph.evaluate method, sorted by popularity by default. You can upvote examples you like or find useful; your ratings help the system recommend better Java code samples.

Example 1: holdout

import org.deeplearning4j.nn.graph.ComputationGraph; // import the package/class this method depends on
/**
 * Perform a simple holdout evaluation with a given split percentage
 *
 * @param clf Classifier
 * @param data Full dataset
 * @param p Split percentage
 * @param aii Instance iterator used to build the test DataSetIterator
 * @throws Exception if evaluation fails
 */
public static void holdout(
    Dl4jMlpClassifier clf, Instances data, double p, AbstractInstanceIterator aii)
    throws Exception {

  holdout(clf, data, p);
  Instances[] split = splitTrainTest(data, p);

  Instances test = split[1];
  final DataSetIterator testIter = aii.getDataSetIterator(test, 42);
  final ComputationGraph model = clf.getModel();
  logger.info("DL4J Evaluation: ");
  org.deeplearning4j.eval.Evaluation evaluation = model.evaluate(testIter);
  logger.info(evaluation.stats());
}
 
Author (GitHub): Waikato, Project: wekaDeeplearning4j, Lines: 23, Source: TestUtil.java
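The `splitTrainTest` helper used above is not shown in this snippet. As a rough, framework-free illustration of what a percentage-based holdout split does, the core index arithmetic can be sketched in plain Java (the exact rounding, shuffling, and whether `p` is the test or train fraction are assumptions here; wekaDeeplearning4j's own implementation may differ):

```java
import java.util.Arrays;

public class HoldoutSketch {
    /**
     * Split an ordered set of rows into train/test portions.
     * Here p is assumed to be the test percentage (0-100): the last
     * round(n * p / 100) rows become the test set, the rest the train set.
     */
    public static int[][] splitTrainTest(int[] rows, double p) {
        int testSize = (int) Math.round(rows.length * p / 100.0);
        int trainSize = rows.length - testSize;
        int[] train = Arrays.copyOfRange(rows, 0, trainSize);
        int[] test = Arrays.copyOfRange(rows, trainSize, rows.length);
        return new int[][] {train, test};
    }
}
```

With `split[1]` as the test portion, this mirrors how the snippet above feeds only the held-out rows into `aii.getDataSetIterator(test, 42)` before calling `model.evaluate`.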

Example 2: main

import org.deeplearning4j.nn.graph.ComputationGraph; // import the package/class this method depends on
public static void main (String[] args) throws IOException {
    log.info("download and extract data...");
    CNNSentenceClassification.aclImdbDownloader(DATA_URL, DATA_PATH);

    // basic configuration
    int batchSize = 32;
    int vectorSize = 300;               //Size of the word vectors. 300 in the Google News model
    int nEpochs = 1;                    //Number of epochs (full passes of training data) to train on
    int truncateReviewsToLength = 256;  //Truncate reviews with length (# words) greater than this
    int cnnLayerFeatureMaps = 100;      //Number of feature maps / channels / depth for each CNN layer
    PoolingType globalPoolingType = PoolingType.MAX;
    Random rng = new Random(12345); //For shuffling repeatability

    log.info("construct cnn model...");
    ComputationGraph net = CNNSentenceClassification.buildCNNGraph(vectorSize, cnnLayerFeatureMaps, globalPoolingType);
    log.info("number of parameters by layer:");
    for (Layer l : net.getLayers()) {
        log.info("\t" + l.conf().getLayer().getLayerName() + "\t" + l.numParams());
    }

    // Load word vectors and get the DataSetIterators for training and testing
    log.info("loading word vectors and creating DataSetIterators...");
    WordVectors wordVectors = WordVectorSerializer.loadStaticModel(new File(WORD_VECTORS_PATH));
    DataSetIterator trainIter = CNNSentenceClassification.getDataSetIterator(DATA_PATH, true, wordVectors, batchSize,
            truncateReviewsToLength, rng);
    DataSetIterator testIter = CNNSentenceClassification.getDataSetIterator(DATA_PATH, false, wordVectors, batchSize,
            truncateReviewsToLength, rng);

    log.info("starting training...");
    for (int i = 0; i < nEpochs; i++) {
        net.fit(trainIter);
        log.info("Epoch " + i + " complete. Starting evaluation:");
        //Run evaluation. This is on 25k reviews, so can take some time
        Evaluation evaluation = net.evaluate(testIter);
        log.info(evaluation.stats());
    }

    // after training: load a single sentence and generate a prediction
    String pathFirstNegativeFile = FilenameUtils.concat(DATA_PATH, "aclImdb/test/neg/0_2.txt");
    String contentsFirstNegative = FileUtils.readFileToString(new File(pathFirstNegativeFile));
    INDArray featuresFirstNegative = ((CnnSentenceDataSetIterator)testIter).loadSingleSentence(contentsFirstNegative);
    INDArray predictionsFirstNegative = net.outputSingle(featuresFirstNegative);
    List<String> labels = testIter.getLabels();
    log.info("\n\nPredictions for first negative review:");
    for( int i=0; i<labels.size(); i++ ){
        log.info("P(" + labels.get(i) + ") = " + predictionsFirstNegative.getDouble(i));
    }
}
 
Author (GitHub): IsaacChanghau, Project: Word2VecfJava, Lines: 49, Source: DL4JCNNSentClassifyExample.java
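The final loop above prints P(label) for every class of the single review; the predicted class is simply the index with the highest probability. A minimal sketch of that argmax step in plain Java (the class and method names here are illustrative, not part of DL4J):

```java
public class ArgmaxSketch {
    /** Return the index of the largest value in a probability vector. */
    public static int argmax(double[] probs) {
        int best = 0;
        for (int i = 1; i < probs.length; i++) {
            if (probs[i] > probs[best]) {
                best = i;
            }
        }
        return best;
    }
}
```

Applied to `predictionsFirstNegative`, the winning index would be looked up in `testIter.getLabels()` to get the predicted sentiment.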

Example 3: testEvaluationAndRoc

import org.deeplearning4j.nn.graph.ComputationGraph; // import the package/class this method depends on
@Test
public void testEvaluationAndRoc() {
    DataSetIterator iter = new IrisDataSetIterator(5, 150);

    //Make a 2-class version of iris:
    List<DataSet> l = new ArrayList<>();
    iter.reset();
    while (iter.hasNext()) {
        DataSet ds = iter.next();
        INDArray newL = Nd4j.create(ds.getLabels().size(0), 2);
        newL.putColumn(0, ds.getLabels().getColumn(0));
        newL.putColumn(1, ds.getLabels().getColumn(1));
        newL.getColumn(1).addi(ds.getLabels().getColumn(2));
        ds.setLabels(newL);
        l.add(ds);
    }

    iter = new ListDataSetIterator<>(l);

    ComputationGraph cg = getBasicNetIris2Class();

    Evaluation e = cg.evaluate(iter);
    ROC roc = cg.evaluateROC(iter, 32);


    SparkComputationGraph scg = new SparkComputationGraph(sc, cg, null);



    JavaRDD<DataSet> rdd = sc.parallelize(l);
    rdd = rdd.repartition(20);

    Evaluation e2 = scg.evaluate(rdd);
    ROC roc2 = scg.evaluateROC(rdd);


    assertEquals(e2.accuracy(), e.accuracy(), 1e-3);
    assertEquals(e2.f1(), e.f1(), 1e-3);
    assertEquals(e2.getNumRowCounter(), e.getNumRowCounter(), 1e-3);
    assertEquals(e2.falseNegatives(), e.falseNegatives());
    assertEquals(e2.falsePositives(), e.falsePositives());
    assertEquals(e2.trueNegatives(), e.trueNegatives());
    assertEquals(e2.truePositives(), e.truePositives());
    assertEquals(e2.precision(), e.precision(), 1e-3);
    assertEquals(e2.recall(), e.recall(), 1e-3);
    assertEquals(e2.getConfusionMatrix(), e.getConfusionMatrix());

    assertEquals(roc.calculateAUC(), roc2.calculateAUC(), 1e-5);
    assertEquals(roc.calculateAUCPR(), roc2.calculateAUCPR(), 1e-5);
}
 
Author (GitHub): deeplearning4j, Project: deeplearning4j, Lines: 51, Source: TestSparkComputationGraph.java
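The 2-class conversion at the top of the test collapses Iris classes 1 and 2 into a single column by adding their one-hot indicators (`putColumn` copies columns 0 and 1, then `addi` folds column 2 into column 1). The same label arithmetic can be sketched row-by-row in plain Java, without ND4J (names here are illustrative):

```java
public class LabelMergeSketch {
    /**
     * Collapse a 3-class one-hot row [c0, c1, c2] into a 2-class row
     * [c0, c1 + c2], mirroring the putColumn/addi calls in the test.
     */
    public static double[] toTwoClass(double[] oneHot3) {
        return new double[] {oneHot3[0], oneHot3[1] + oneHot3[2]};
    }
}
```

Rows labeled with either of the last two Iris classes thus end up in the same second column, giving the 2-class labels that `evaluateROC` (a binary-classification metric) requires.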


Note: The org.deeplearning4j.nn.graph.ComputationGraph.evaluate examples in this article were compiled by 纯净天空 from open-source code and documentation platforms such as GitHub and MSDocs. The snippets were selected from open-source projects contributed by various developers; copyright remains with the original authors. Please consult each project's license before redistributing or using the code, and do not reproduce this article without permission.