

Java DataSetIterator.getLabels Method Code Examples

This article collects typical usage examples of the Java method org.nd4j.linalg.dataset.api.iterator.DataSetIterator.getLabels. If you are wondering what DataSetIterator.getLabels does or how it is used in practice, the selected examples below may help. You can also look at usage examples for the enclosing interface, org.nd4j.linalg.dataset.api.iterator.DataSetIterator.


Three code examples of DataSetIterator.getLabels are shown below, sorted by popularity by default.

Example 1: evaluate

import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; // import the package/class the method depends on
/**
 * Evaluate the network (for classification) on the provided data set, with top N accuracy in addition to standard accuracy.
 * For 'standard' accuracy evaluation only, use topN = 1
 *
 * @param iterator   Iterator (data) to evaluate on
 * @param labelsList List of labels. May be null.
 * @param topN       N value for top N accuracy evaluation
 * @return Evaluation object, summarizing the results of the evaluation on the provided DataSetIterator
 */
public Evaluation evaluate(DataSetIterator iterator, List<String> labelsList, int topN) {
    if (layers == null || !(getOutputLayer() instanceof IOutputLayer)) {
        throw new IllegalStateException("Cannot evaluate network with no output layer");
    }
    if (labelsList == null)
        labelsList = iterator.getLabels();

    Evaluation e = new Evaluation(labelsList, topN);
    doEvaluation(iterator, e);

    return e;
}
 
Developer ID: deeplearning4j, Project: deeplearning4j, Lines of code: 22, Source: MultiLayerNetwork.java
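
The null check in Example 1 is a simple fallback: labels passed by the caller win, otherwise the iterator is asked for them. The sketch below isolates that pattern in plain Java so it runs without DL4J on the classpath. `LabelSource` and `resolveLabels` are hypothetical names introduced only for illustration; `LabelSource` stands in for the one `DataSetIterator` method used here, and the label strings are made up.

```java
import java.util.Arrays;
import java.util.List;

public class LabelFallbackSketch {

    /** Hypothetical stand-in for the single DataSetIterator method used in Example 1. */
    public interface LabelSource {
        List<String> getLabels();
    }

    /**
     * Returns the explicitly supplied labels when present; otherwise asks the source.
     * The result can still be null if the source has no labels to report.
     */
    public static List<String> resolveLabels(List<String> explicitLabels, LabelSource source) {
        return explicitLabels != null ? explicitLabels : source.getLabels();
    }

    public static void main(String[] args) {
        // Illustrative labels only; a real iterator reports whatever its data set defines.
        LabelSource source = () -> Arrays.asList("negative", "positive");

        System.out.println(resolveLabels(null, source));                     // falls back to the source
        System.out.println(resolveLabels(Arrays.asList("a", "b"), source));  // explicit labels win
    }
}
```

Because the fallback can itself yield null, callers that later index into the list may want to guard against that case.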

Example 2: main

import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; // import the package/class the method depends on
public static void main (String[] args) throws IOException {
    log.info("download and extract data...");
    CNNSentenceClassification.aclImdbDownloader(DATA_URL, DATA_PATH);

    // basic configuration
    int batchSize = 32;
    int vectorSize = 300;               //Size of the word vectors. 300 in the Google News model
    int nEpochs = 1;                    //Number of epochs (full passes of training data) to train on
    int truncateReviewsToLength = 256;  //Truncate reviews with length (# words) greater than this
    int cnnLayerFeatureMaps = 100;      //Number of feature maps / channels / depth for each CNN layer
    PoolingType globalPoolingType = PoolingType.MAX;
    Random rng = new Random(12345); //For shuffling repeatability

    log.info("construct cnn model...");
    ComputationGraph net = CNNSentenceClassification.buildCNNGraph(vectorSize, cnnLayerFeatureMaps, globalPoolingType);
    log.info("number of parameters by layer:");
    for (Layer l : net.getLayers()) {
        log.info("\t" + l.conf().getLayer().getLayerName() + "\t" + l.numParams());
    }

    // Load word vectors and get the DataSetIterators for training and testing
    log.info("loading word vectors and creating DataSetIterators...");
    WordVectors wordVectors = WordVectorSerializer.loadStaticModel(new File(WORD_VECTORS_PATH));
    DataSetIterator trainIter = CNNSentenceClassification.getDataSetIterator(DATA_PATH, true, wordVectors, batchSize,
            truncateReviewsToLength, rng);
    DataSetIterator testIter = CNNSentenceClassification.getDataSetIterator(DATA_PATH, false, wordVectors, batchSize,
            truncateReviewsToLength, rng);

    log.info("starting training...");
    for (int i = 0; i < nEpochs; i++) {
        net.fit(trainIter);
        log.info("Epoch " + i + " complete. Starting evaluation:");
        //Run evaluation. This is on 25k reviews, so can take some time
        Evaluation evaluation = net.evaluate(testIter);
        log.info(evaluation.stats());
    }

    // after training: load a single sentence and generate a prediction
    String pathFirstNegativeFile = FilenameUtils.concat(DATA_PATH, "aclImdb/test/neg/0_2.txt");
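    // Read with an explicit charset; the single-argument overload is deprecated and uses the platform default
    String contentsFirstNegative = FileUtils.readFileToString(new File(pathFirstNegativeFile), java.nio.charset.StandardCharsets.UTF_8);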
    INDArray featuresFirstNegative = ((CnnSentenceDataSetIterator)testIter).loadSingleSentence(contentsFirstNegative);
    INDArray predictionsFirstNegative = net.outputSingle(featuresFirstNegative);
    List<String> labels = testIter.getLabels();
    log.info("\n\nPredictions for first negative review:");
    for( int i=0; i<labels.size(); i++ ){
        log.info("P(" + labels.get(i) + ") = " + predictionsFirstNegative.getDouble(i));
    }
}
 
Developer ID: IsaacChanghau, Project: Word2VecfJava, Lines of code: 49, Source: DL4JCNNSentClassifyExample.java
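
The last loop of Example 2 pairs each label with the probability at the same index. A common next step is to pick the single most likely label. The sketch below shows that lookup with a plain `double[]` in place of an `INDArray`, so it runs without ND4J. `mostLikelyLabel` is a hypothetical helper, and the labels and probabilities are illustrative values, not output from the model above.

```java
import java.util.Arrays;
import java.util.List;

public class LabelLookupSketch {

    /**
     * Returns the label at the index of the highest probability.
     * Assumes the labels are in the same order as the output units.
     */
    public static String mostLikelyLabel(List<String> labels, double[] probabilities) {
        if (labels.isEmpty() || labels.size() != probabilities.length) {
            throw new IllegalArgumentException("labels and probabilities must be non-empty and the same length");
        }
        int best = 0;
        for (int i = 1; i < probabilities.length; i++) {
            if (probabilities[i] > probabilities[best]) {
                best = i;
            }
        }
        return labels.get(best);
    }

    public static void main(String[] args) {
        List<String> labels = Arrays.asList("negative", "positive"); // illustrative labels
        double[] probabilities = {0.83, 0.17};                       // illustrative values
        System.out.println("Predicted: " + mostLikelyLabel(labels, probabilities));
    }
}
```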

Example 3: evaluate

import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; // import the package/class the method depends on
/**
 * Evaluate the network (for classification) on the provided data set, with top N accuracy in addition to standard accuracy.
 * For 'standard' accuracy evaluation only, use topN = 1
 *
 * @param iterator   Iterator (data) to evaluate on
 * @param labelsList List of labels. May be null.
 * @param topN       N value for top N accuracy evaluation
 * @return Evaluation object, summarizing the results of the evaluation on the provided DataSetIterator
 */
public Evaluation evaluate(DataSetIterator iterator, List<String> labelsList, int topN) {
    if (labelsList == null)
        labelsList = iterator.getLabels();

    return doEvaluation(iterator, new Evaluation(labelsList, topN))[0];
}
 
Developer ID: deeplearning4j, Project: deeplearning4j, Lines of code: 16, Source: ComputationGraph.java
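
Both evaluate methods mention top N accuracy without defining it. Under the usual definition, a prediction counts as correct when the true class is among the N highest-scoring classes, so N = 1 reduces to standard accuracy. The sketch below computes that in plain Java. It is an illustration of the general idea, not DL4J's implementation: the class and method names are hypothetical, the scores are made up, and ties are resolved in favour of the true class, which may differ from how the library handles them.

```java
public class TopNAccuracySketch {

    /** True if fewer than n classes score strictly higher than the true class. */
    public static boolean inTopN(double[] scores, int trueClass, int n) {
        int strictlyHigher = 0;
        for (double score : scores) {
            if (score > scores[trueClass]) {
                strictlyHigher++;
            }
        }
        return strictlyHigher < n;
    }

    /** Fraction of rows whose true class is within the top n scores; 0.0 for empty input. */
    public static double topNAccuracy(double[][] scores, int[] trueClasses, int n) {
        if (scores.length == 0) {
            return 0.0;
        }
        int correct = 0;
        for (int i = 0; i < scores.length; i++) {
            if (inTopN(scores[i], trueClasses[i], n)) {
                correct++;
            }
        }
        return (double) correct / scores.length;
    }

    public static void main(String[] args) {
        // Three illustrative examples over three classes.
        double[][] scores = {
            {0.6, 0.3, 0.1},
            {0.2, 0.5, 0.3},
            {0.1, 0.2, 0.7}
        };
        int[] trueClasses = {0, 2, 1};

        System.out.println("top-1 accuracy: " + topNAccuracy(scores, trueClasses, 1));
        System.out.println("top-2 accuracy: " + topNAccuracy(scores, trueClasses, 2));
    }
}
```

With these made-up scores only the first row is right at N = 1, while all three rows have the true class within the top two.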


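Note: The org.nd4j.linalg.dataset.api.iterator.DataSetIterator.getLabels examples in this article were compiled by 纯净天空 from open-source code and documentation platforms such as GitHub and MSDocs. The snippets were selected from open-source projects, and copyright of the source code belongs to the original authors; consult each project's License before distributing or using the code. Do not reproduce this article without permission.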