本文整理汇总了Java中org.deeplearning4j.nn.graph.ComputationGraph.evaluate方法的典型用法代码示例。如果您正苦于以下问题:Java ComputationGraph.evaluate方法的具体用法?Java ComputationGraph.evaluate怎么用?Java ComputationGraph.evaluate使用的例子?那么恭喜您，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在类org.deeplearning4j.nn.graph.ComputationGraph的用法示例。
在下文中一共展示了ComputationGraph.evaluate方法的3个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Java代码示例。
示例1: holdout
import org.deeplearning4j.nn.graph.ComputationGraph; //导入方法依赖的package包/类
/**
 * Runs a holdout evaluation and then logs DL4J's own evaluation statistics for the test
 * portion of the split.
 *
 * <p>The three-argument {@code holdout(clf, data, p)} overload is invoked first. Afterwards
 * the data is split with {@code splitTrainTest}, the second element of the split is turned
 * into a {@code DataSetIterator} by the supplied instance iterator, and the classifier's
 * underlying {@link ComputationGraph} is evaluated on it.
 *
 * @param clf classifier whose underlying model is evaluated
 * @param data full dataset to split
 * @param p split percentage handed to {@code splitTrainTest}
 * @param aii instance iterator used to build the {@code DataSetIterator} over the test split
 * @throws Exception propagated from the delegated holdout run, the split or the iterator
 */
public static void holdout(
    Dl4jMlpClassifier clf, Instances data, double p, AbstractInstanceIterator aii)
    throws Exception {
  holdout(clf, data, p);

  // Element 1 of the split is the portion used for testing here.
  final Instances testSet = splitTrainTest(data, p)[1];
  // NOTE(review): 42 is presumably a fixed seed -- confirm against getDataSetIterator.
  final DataSetIterator testData = aii.getDataSetIterator(testSet, 42);
  final ComputationGraph graph = clf.getModel();

  logger.info("DL4J Evaluation: ");
  // Fully qualified name kept as in the original (likely to avoid a clash with another
  // Evaluation type in scope -- TODO confirm).
  final org.deeplearning4j.eval.Evaluation dl4jEval = graph.evaluate(testData);
  logger.info(dl4jEval.stats());
}
示例2: main
import org.deeplearning4j.nn.graph.ComputationGraph; //导入方法依赖的package包/类
/**
 * Downloads the review data, builds and trains the CNN sentence classifier, logs an
 * evaluation after every epoch and finally logs the class probabilities predicted for a
 * single review file.
 *
 * @param args command-line arguments (unused)
 * @throws IOException if the single review file cannot be read, or propagated from the
 *     download / data-loading helpers
 */
public static void main (String[] args) throws IOException {
    log.info("download and extract data...");
    CNNSentenceClassification.aclImdbDownloader(DATA_URL, DATA_PATH);

    // basic configuration
    final int batchSize = 32;
    final int vectorSize = 300;               // Size of the word vectors. 300 in the Google News model
    final int nEpochs = 1;                    // Number of epochs (full passes of training data) to train on
    final int truncateReviewsToLength = 256;  // Truncate reviews with length (# words) greater than this
    final int cnnLayerFeatureMaps = 100;      // Number of feature maps / channels / depth for each CNN layer
    final PoolingType globalPoolingType = PoolingType.MAX;
    final Random rng = new Random(12345);     // For shuffling repeatability

    log.info("construct cnn model...");
    final ComputationGraph net =
            CNNSentenceClassification.buildCNNGraph(vectorSize, cnnLayerFeatureMaps, globalPoolingType);

    log.info("number of parameters by layer:");
    for (Layer layer : net.getLayers()) {
        final String layerName = layer.conf().getLayer().getLayerName();
        log.info("\t" + layerName + "\t" + layer.numParams());
    }

    // Load word vectors and get the DataSetIterators for training and testing
    log.info("loading word vectors and creating DataSetIterators...");
    final File wordVectorsFile = new File(WORD_VECTORS_PATH);
    final WordVectors wordVectors = WordVectorSerializer.loadStaticModel(wordVectorsFile);
    final DataSetIterator trainIter = CNNSentenceClassification.getDataSetIterator(
            DATA_PATH, true, wordVectors, batchSize, truncateReviewsToLength, rng);
    final DataSetIterator testIter = CNNSentenceClassification.getDataSetIterator(
            DATA_PATH, false, wordVectors, batchSize, truncateReviewsToLength, rng);

    log.info("starting training...");
    for (int epoch = 0; epoch < nEpochs; epoch++) {
        net.fit(trainIter);
        log.info("Epoch " + epoch + " complete. Starting evaluation:");

        // Run evaluation. This is on 25k reviews, so can take some time
        final Evaluation epochEvaluation = net.evaluate(testIter);
        log.info(epochEvaluation.stats());
    }

    // after training: load a single sentence and generate a prediction
    final File firstNegativeFile =
            new File(FilenameUtils.concat(DATA_PATH, "aclImdb/test/neg/0_2.txt"));
    final String firstNegativeText = FileUtils.readFileToString(firstNegativeFile);
    final CnnSentenceDataSetIterator sentenceIter = (CnnSentenceDataSetIterator) testIter;
    final INDArray firstNegativeFeatures = sentenceIter.loadSingleSentence(firstNegativeText);
    final INDArray firstNegativePredictions = net.outputSingle(firstNegativeFeatures);
    final List<String> labels = testIter.getLabels();

    log.info("\n\nPredictions for first negative review:");
    int labelIdx = 0;
    for (String label : labels) {
        log.info("P(" + label + ") = " + firstNegativePredictions.getDouble(labelIdx));
        labelIdx++;
    }
}
示例3: testEvaluationAndRoc
import org.deeplearning4j.nn.graph.ComputationGraph; //导入方法依赖的package包/类
/**
 * Checks that evaluating a two-class Iris network through {@code SparkComputationGraph}
 * yields the same {@code Evaluation} and {@code ROC} results as evaluating it locally
 * through {@code ComputationGraph}.
 *
 * <p>The locally computed results are the reference ("expected") values; the Spark results
 * are the values under test ("actual"). All assertions pass them in that order so that
 * failure messages report the two sides correctly.
 */
@Test
public void testEvaluationAndRoc() {
    DataSetIterator iter = new IrisDataSetIterator(5, 150);

    //Make a 2-class version of iris: label column 0 is kept, columns 1 and 2 are summed
    List<DataSet> twoClassData = new ArrayList<>();
    iter.reset();
    while (iter.hasNext()) {
        DataSet ds = iter.next();
        INDArray newL = Nd4j.create(ds.getLabels().size(0), 2);
        newL.putColumn(0, ds.getLabels().getColumn(0));
        newL.putColumn(1, ds.getLabels().getColumn(1));
        newL.getColumn(1).addi(ds.getLabels().getColumn(2));
        ds.setLabels(newL);
        twoClassData.add(ds);
    }
    iter = new ListDataSetIterator<>(twoClassData);

    // Reference results: evaluate locally on the ComputationGraph.
    ComputationGraph cg = getBasicNetIris2Class();
    Evaluation expected = cg.evaluate(iter);
    ROC expectedRoc = cg.evaluateROC(iter, 32);

    // Results under test: evaluate the same network and data through Spark.
    SparkComputationGraph scg = new SparkComputationGraph(sc, cg, null);
    JavaRDD<DataSet> rdd = sc.parallelize(twoClassData);
    rdd = rdd.repartition(20);
    Evaluation actual = scg.evaluate(rdd);
    ROC actualRoc = scg.evaluateROC(rdd);

    // assertEquals takes (expected, actual[, delta]).
    assertEquals(expected.accuracy(), actual.accuracy(), 1e-3);
    assertEquals(expected.f1(), actual.f1(), 1e-3);
    assertEquals(expected.getNumRowCounter(), actual.getNumRowCounter(), 1e-3);
    assertEquals(expected.falseNegatives(), actual.falseNegatives());
    assertEquals(expected.falsePositives(), actual.falsePositives());
    assertEquals(expected.trueNegatives(), actual.trueNegatives());
    assertEquals(expected.truePositives(), actual.truePositives());
    assertEquals(expected.precision(), actual.precision(), 1e-3);
    assertEquals(expected.recall(), actual.recall(), 1e-3);
    assertEquals(expected.getConfusionMatrix(), actual.getConfusionMatrix());

    assertEquals(expectedRoc.calculateAUC(), actualRoc.calculateAUC(), 1e-5);
    assertEquals(expectedRoc.calculateAUCPR(), actualRoc.calculateAUCPR(), 1e-5);
}