This article collects typical usage examples of the Java method org.nd4j.linalg.dataset.api.iterator.DataSetIterator.getLabels, which returns the iterator's class label names as a List<String>. If you want to know how DataSetIterator.getLabels is used in practice, the selected code examples below may help. You can also look at further usage examples of the class that declares this method, org.nd4j.linalg.dataset.api.iterator.DataSetIterator.
Three code examples of DataSetIterator.getLabels are shown below, sorted by popularity by default.
Example 1: evaluate
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; // package/class this method depends on
/**
 * Evaluate the network (for classification) on the provided data set, with top N accuracy in addition to standard accuracy.
 * For standard accuracy evaluation only, use topN = 1.
 *
 * @param iterator   Iterator (data) to evaluate on
 * @param labelsList List of label names. May be null, in which case iterator.getLabels() is used
 * @param topN       N value for top N accuracy evaluation
 * @return Evaluation object, summarizing the results of the evaluation on the provided DataSetIterator
 */
public Evaluation evaluate(DataSetIterator iterator, List<String> labelsList, int topN) {
    // Evaluation needs an output layer that produces predictions
    if (layers == null || !(getOutputLayer() instanceof IOutputLayer)) {
        throw new IllegalStateException("Cannot evaluate network with no output layer");
    }
    if (labelsList == null) {
        // Fall back to the label names reported by the iterator
        labelsList = iterator.getLabels();
    }
    Evaluation e = new Evaluation(labelsList, topN);
    doEvaluation(iterator, e);
    return e;
}
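The Javadoc above mentions top N accuracy without spelling it out. As a rough illustration, the plain-Java sketch below counts a prediction as correct when the true class is among the N highest-scoring outputs; with topN = 1 this reduces to ordinary accuracy. It is a conceptual sketch with hypothetical names (TopNAccuracy, isTopNCorrect, topNAccuracy), not DL4J's Evaluation implementation, and it resolves ties in the true class's favor, which may differ from the library's behavior.

```java
/**
 * Conceptual sketch of top-N accuracy (hypothetical class, not DL4J's Evaluation).
 * A prediction is correct if the true class index is among the N highest outputs.
 */
public class TopNAccuracy {

    /** Returns true if trueClass is among the topN highest entries of probs. */
    static boolean isTopNCorrect(double[] probs, int trueClass, int topN) {
        // Count classes that score strictly higher than the true class
        // (so ties are resolved in the true class's favor).
        int higher = 0;
        for (int i = 0; i < probs.length; i++) {
            if (probs[i] > probs[trueClass]) {
                higher++;
            }
        }
        // The true class is in the top N if fewer than N classes beat it.
        return higher < topN;
    }

    /** Fraction of examples whose true class is within the top N predictions. */
    static double topNAccuracy(double[][] probs, int[] trueClasses, int topN) {
        int correct = 0;
        for (int r = 0; r < probs.length; r++) {
            if (isTopNCorrect(probs[r], trueClasses[r], topN)) {
                correct++;
            }
        }
        return (double) correct / probs.length;
    }

    public static void main(String[] args) {
        // Made-up probabilities for three examples over three classes
        double[][] probs = {
            {0.1, 0.7, 0.2},
            {0.5, 0.3, 0.2},
            {0.2, 0.3, 0.5}
        };
        int[] truth = {1, 1, 0};
        System.out.println("top-1 accuracy = " + topNAccuracy(probs, truth, 1));
        System.out.println("top-2 accuracy = " + topNAccuracy(probs, truth, 2));
    }
}
```

In this toy data the second example is a miss at top-1 but a hit at top-2, which is exactly the extra information a topN greater than 1 adds.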
Example 2: main
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; // package/class this method depends on
public static void main(String[] args) throws IOException {
    log.info("download and extract data...");
    CNNSentenceClassification.aclImdbDownloader(DATA_URL, DATA_PATH);
    // Basic configuration
    int batchSize = 32;
    int vectorSize = 300;               // Size of the word vectors: 300 in the Google News model
    int nEpochs = 1;                    // Number of epochs (full passes over the training data)
    int truncateReviewsToLength = 256;  // Truncate reviews longer than this many words
    int cnnLayerFeatureMaps = 100;      // Number of feature maps (channels/depth) per CNN layer
    PoolingType globalPoolingType = PoolingType.MAX;
    Random rng = new Random(12345);     // Fixed seed so shuffling is repeatable
    log.info("construct cnn model...");
    ComputationGraph net = CNNSentenceClassification.buildCNNGraph(vectorSize, cnnLayerFeatureMaps, globalPoolingType);
    log.info("number of parameters by layer:");
    for (Layer l : net.getLayers()) {
        log.info("\t" + l.conf().getLayer().getLayerName() + "\t" + l.numParams());
    }
    // Load word vectors and create the DataSetIterators for training and testing
    log.info("loading word vectors and creating DataSetIterators...");
    WordVectors wordVectors = WordVectorSerializer.loadStaticModel(new File(WORD_VECTORS_PATH));
    DataSetIterator trainIter = CNNSentenceClassification.getDataSetIterator(DATA_PATH, true, wordVectors, batchSize,
            truncateReviewsToLength, rng);
    DataSetIterator testIter = CNNSentenceClassification.getDataSetIterator(DATA_PATH, false, wordVectors, batchSize,
            truncateReviewsToLength, rng);
    log.info("starting training...");
    for (int i = 0; i < nEpochs; i++) {
        net.fit(trainIter);
        log.info("Epoch " + i + " complete. Starting evaluation:");
        // Run evaluation. This covers 25k reviews, so it can take some time
        Evaluation evaluation = net.evaluate(testIter);
        log.info(evaluation.stats());
    }
    // After training: load a single review and generate a prediction
    String pathFirstNegativeFile = FilenameUtils.concat(DATA_PATH, "aclImdb/test/neg/0_2.txt");
    // Read with an explicit charset (the single-argument overload is deprecated)
    String contentsFirstNegative = FileUtils.readFileToString(new File(pathFirstNegativeFile),
            java.nio.charset.StandardCharsets.UTF_8);
    INDArray featuresFirstNegative = ((CnnSentenceDataSetIterator) testIter).loadSingleSentence(contentsFirstNegative);
    INDArray predictionsFirstNegative = net.outputSingle(featuresFirstNegative);
    // Class names from the iterator; the loop assumes label i corresponds to output column i
    List<String> labels = testIter.getLabels();
    log.info("\n\nPredictions for first negative review:");
    for (int i = 0; i < labels.size(); i++) {
        log.info("P(" + labels.get(i) + ") = " + predictionsFirstNegative.getDouble(i));
    }
}
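Example 2 prints one probability per label by reusing the same index i for testIter.getLabels() and the prediction vector, which only works if label i names output column i. The sketch below isolates that pairing in plain Java under that assumption. A double[] stands in for the INDArray returned by net.outputSingle, and the class name, method names, and the two label strings are all hypothetical.

```java
import java.util.Arrays;
import java.util.List;

/**
 * Sketch of pairing getLabels()-style label names with a prediction vector.
 * Assumption: label i names output column i. A double[] stands in for an INDArray.
 */
public class LabelPredictions {

    /** Index of the largest probability (the first one wins on ties). */
    static int argMax(double[] probs) {
        int best = 0;
        for (int i = 1; i < probs.length; i++) {
            if (probs[i] > probs[best]) {
                best = i;
            }
        }
        return best;
    }

    /** Name of the most probable class; fails fast if the sizes disagree. */
    static String predictedLabel(List<String> labels, double[] probs) {
        if (labels.size() != probs.length) {
            throw new IllegalArgumentException(
                "labels (" + labels.size() + ") and outputs (" + probs.length + ") differ in size");
        }
        return labels.get(argMax(probs));
    }

    public static void main(String[] args) {
        // Hypothetical label names and probabilities for a two-class sentiment model
        List<String> labels = Arrays.asList("negative", "positive");
        double[] probs = {0.83, 0.17};
        for (int i = 0; i < labels.size(); i++) {
            System.out.println("P(" + labels.get(i) + ") = " + probs[i]);
        }
        System.out.println("predicted: " + predictedLabel(labels, probs));
    }
}
```

The size check is a cheap guard: if an iterator reported a different number of labels than the network has outputs, the index-based loop in Example 2 would silently mislabel or overrun.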
Example 3: evaluate
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; // package/class this method depends on
/**
 * Evaluate the network (for classification) on the provided data set, with top N accuracy in addition to standard accuracy.
 * For standard accuracy evaluation only, use topN = 1.
 *
 * @param iterator   Iterator (data) to evaluate on
 * @param labelsList List of label names. May be null, in which case iterator.getLabels() is used
 * @param topN       N value for top N accuracy evaluation
 * @return Evaluation object, summarizing the results of the evaluation on the provided DataSetIterator
 */
public Evaluation evaluate(DataSetIterator iterator, List<String> labelsList, int topN) {
    if (labelsList == null) {
        // Fall back to the label names reported by the iterator
        labelsList = iterator.getLabels();
    }
    // doEvaluation returns an array of evaluations; only one was passed in, so take element 0
    return doEvaluation(iterator, new Evaluation(labelsList, topN))[0];
}
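Examples 1 and 3 share the same pattern: use the caller's label list if one is given, otherwise ask the iterator via getLabels(). The sketch below models only that fallback. LabelSource is a hypothetical one-method stand-in for DataSetIterator (the real interface declares many more methods), and resolveLabels is an illustrative name. Since an iterator may not report label names at all, the resolved list can still be null.

```java
import java.util.Arrays;
import java.util.List;

/**
 * Sketch of the "explicit labels, else iterator labels" fallback.
 * LabelSource is a hypothetical stand-in for DataSetIterator; only getLabels() is modeled.
 */
public class LabelFallback {

    /** Hypothetical minimal interface, not the real ND4J DataSetIterator. */
    interface LabelSource {
        List<String> getLabels();
    }

    /** Prefer caller-supplied labels; otherwise ask the source. May still return null. */
    static List<String> resolveLabels(List<String> labelsList, LabelSource source) {
        return labelsList != null ? labelsList : source.getLabels();
    }

    public static void main(String[] args) {
        LabelSource source = () -> Arrays.asList("cat", "dog");
        // No explicit list: falls back to the source's labels
        System.out.println(resolveLabels(null, source));
        // Explicit list given: it takes precedence over the source
        System.out.println(resolveLabels(Arrays.asList("a", "b"), source));
    }
}
```

Keeping the fallback in one expression mirrors the two-line null check in the examples while making the precedence (caller first, iterator second) explicit.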