当前位置: 首页>>代码示例>>Java>>正文


Java Evaluation类代码示例

本文整理汇总了Java中org.deeplearning4j.eval.Evaluation类的典型用法代码示例。如果您正苦于以下问题：Java Evaluation类的具体用法是什么？Java Evaluation怎么用？Java Evaluation有哪些使用示例？那么，这里精选的类代码示例或许可以为您提供帮助。


Evaluation类属于org.deeplearning4j.eval包，在下文中一共展示了Evaluation类的15个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Java代码示例。

示例1: testEvaluation

import org.deeplearning4j.eval.Evaluation; //导入依赖的package包/类
/**
 * Evaluating through Spark must yield the same metrics as evaluating a local clone of the
 * network on the same inputs and labels.
 */
@Test
public void testEvaluation(){
    final SparkDl4jMultiLayer distributedNet = getBasicNetwork();
    final MultiLayerNetwork localNet = distributedNet.getNetwork().clone();

    final Evaluation expected = new Evaluation();
    final INDArray localOutput = localNet.output(input, Layer.TrainingMode.TEST);
    expected.eval(labels, localOutput);

    final Evaluation actual = distributedNet.evaluate(sparkData);

    assertEvaluationsMatch(expected, actual);
}

/**
 * Asserts that two evaluations agree on the scalar metrics, the counter maps, and the
 * confusion matrix, in that order.
 */
private void assertEvaluationsMatch(Evaluation expected, Evaluation actual) {
    final double eps = 1e-3;
    assertEquals(expected.accuracy(), actual.accuracy(), eps);
    assertEquals(expected.f1(), actual.f1(), eps);
    assertEquals(expected.getNumRowCounter(), actual.getNumRowCounter(), eps);
    assertMapEquals(expected.falseNegatives(), actual.falseNegatives());
    assertMapEquals(expected.falsePositives(), actual.falsePositives());
    assertMapEquals(expected.trueNegatives(), actual.trueNegatives());
    assertMapEquals(expected.truePositives(), actual.truePositives());
    assertEquals(expected.precision(), actual.precision(), eps);
    assertEquals(expected.recall(), actual.recall(), eps);
    assertEquals(expected.getConfusionMatrix(), actual.getConfusionMatrix());
}
 
开发者ID:PacktPublishing,项目名称:Deep-Learning-with-Hadoop,代码行数:24,代码来源:TestSparkMultiLayerParameterAveraging.java

示例2: evalMnistTestSet

import org.deeplearning4j.eval.Evaluation; //导入依赖的package包/类
/**
 * Runs the given network over the MNIST test iterator (train = false, seed 12345) in batches
 * of 64 and logs the resulting evaluation statistics.
 *
 * @param leNetModel the network to evaluate
 * @throws Exception propagated from constructing or iterating the MNIST data
 */
private static void evalMnistTestSet(MultiLayerNetwork leNetModel) throws Exception {
    log.info("Load test data....");
    final int testBatchSize = 64;
    final DataSetIterator testIterator = new MnistDataSetIterator(testBatchSize, false, 12345);

    log.info("Evaluate model....");
    final int numClasses = 10;
    final Evaluation evaluation = new Evaluation(numClasses);

    while (testIterator.hasNext()) {
        final DataSet batch = testIterator.next();
        // false = inference mode for output()
        final INDArray predictions = leNetModel.output(batch.getFeatureMatrix(), false);
        evaluation.eval(batch.getLabels(), predictions);
    }

    log.info(evaluation.stats());
}
 
开发者ID:matthiaszimmermann,项目名称:ml_demo,代码行数:19,代码来源:LeNetMnistTester.java

示例3: evaluate

import org.deeplearning4j.eval.Evaluation; //导入依赖的package包/类
/**
 * Evaluates the model on the Iris testing data. It prints each actual/predicted row pair and
 * then the aggregated evaluation statistics.
 *
 * @return this model, for call chaining
 */
@SuppressWarnings("rawtypes")
public DeepBeliefNetworkModel evaluate()
{
    final DataSet testingData = ((IrisData) data).getTestingData();
    final INDArray labels = testingData.getLabels();

    final Evaluation evaluation = new Evaluation(parameters.getOutputSize());

    // A single pass is enough. Feeding the same data twice into one Evaluation would double
    // every count reported by stats() without adding any information.
    final INDArray output = model.output(testingData.getFeatureMatrix(), Layer.TrainingMode.TEST);

    for (int i = 0; i < output.rows(); i++)
    {
        String actual = labels.getRow(i).toString().trim();
        String predicted = output.getRow(i).toString().trim();
        System.out.println("actual " + actual + " vs predicted " + predicted);
    }

    evaluation.eval(labels, output);
    System.out.println(evaluation.stats());
    return this;
}
 
开发者ID:amrabed,项目名称:DL4J,代码行数:23,代码来源:DeepBeliefNetworkModel.java

示例4: evaluate

import org.deeplearning4j.eval.Evaluation; //导入依赖的package包/类
/**
 * Evaluates the model on the MNIST test split and prints the evaluation statistics.
 *
 * <p>Best effort: if the MNIST data cannot be loaded, the IOException is printed and no
 * statistics are produced.
 *
 * @return this model, for call chaining
 */
@Override
@SuppressWarnings("rawtypes")
public Model evaluate()
{
    final Evaluation evaluation = new Evaluation(parameters.getOutputSize());
    try
    {
        // Use the held-out test split (train = false). The (batch, numExamples) constructor
        // reads the training split, which would measure performance on data the model was
        // trained on.
        final DataSetIterator iterator = new MnistDataSetIterator(100, false, 12345);
        while (iterator.hasNext())
        {
            final DataSet testingData = iterator.next();
            evaluation.eval(testingData.getLabels(), model.output(testingData.getFeatureMatrix()));
        }

        System.out.println(evaluation.stats());
    }
    catch (IOException e)
    {
        // Deliberately non-fatal: report the failure and return without statistics.
        e.printStackTrace();
    }
    return this;
}
 
开发者ID:amrabed,项目名称:DL4J,代码行数:23,代码来源:StackedAutoEncoderModel.java

示例5: evaluate

import org.deeplearning4j.eval.Evaluation; //导入依赖的package包/类
/**
 * Evaluates the model on every pre-batched MNIST testing feature/label pair and prints the
 * accumulated evaluation statistics.
 *
 * @return this model, for call chaining
 */
@Override
@SuppressWarnings("rawtypes")
public Model evaluate()
{
    final MnistData mnist = (MnistData) data;
    final List<INDArray> features = mnist.getTestingFeatures();
    final List<INDArray> labels = mnist.getTestingLabels();
    final Evaluation evaluation = new Evaluation(parameters.getOutputSize());

    // Features and labels are parallel lists; entry k of each forms one batch.
    for (int batch = 0; batch < features.size(); batch++)
    {
        final INDArray expected = labels.get(batch);
        final INDArray predicted = model.output(features.get(batch));
        evaluation.eval(expected, predicted);
    }

    System.out.println(evaluation.stats());
    return this;
}
 
开发者ID:amrabed,项目名称:DL4J,代码行数:17,代码来源:ConvolutionalNetModel.java

示例6: main

import org.deeplearning4j.eval.Evaluation; //导入依赖的package包/类
/**
 * Evaluates a saved sentiment RNN on a test folder and prints the evaluation statistics.
 *
 * <ul>
 *   <li>args[0] input: word2vec file name</li>
 *   <li>args[1] input: sentiment model file name</li>
 *   <li>args[2] input: parent folder of the test data</li>
 * </ul>
 *
 * <p>Exits with status 1 and a usage message when fewer than three arguments are given.
 *
 * @param args command-line arguments as described above
 * @throws Exception propagated from loading the word vectors, the model, or the test data
 */
public static void main (final String[] args) throws Exception {
  // Check the length first: with too few arguments, indexing args would throw
  // ArrayIndexOutOfBoundsException before any null check could run.
  if (args.length < 3 || args[0] == null || args[1] == null || args[2] == null) {
    System.err.println("Usage: <word2vec-file> <sentiment-model> <test-parent-dir>");
    System.exit(1);
  }

  WordVectors wvec = WordVectorSerializer.loadTxtVectors(new File(args[0]));
  MultiLayerNetwork model = ModelSerializer.restoreMultiLayerNetwork(args[1], false);

  DataSetIterator test = new AsyncDataSetIterator(
      new SentimentRecurrentIterator(args[2], wvec, 100, 300, false), 1);
  Evaluation evaluation = new Evaluation();
  while (test.hasNext()) {
    DataSet t = test.next();
    INDArray features = t.getFeatures();
    INDArray labels = t.getLabels();
    INDArray inMask = t.getFeaturesMaskArray();
    INDArray outMask = t.getLabelsMaskArray();
    INDArray predicted = model.output(features, false, inMask, outMask);
    // Time-series evaluation that uses the label mask for padded time steps.
    evaluation.evalTimeSeries(labels, predicted, outMask);
  }
  System.out.println(evaluation.stats());
}
 
开发者ID:keigohtr,项目名称:sentiment-rnn,代码行数:30,代码来源:SentimentRecurrentTestCmd.java

示例7: testMLPMultiLayerBackprop

import org.deeplearning4j.eval.Evaluation; //导入依赖的package包/类
/**
 * Trains two identically configured dense networks on the same data. It then checks that they
 * end up with identical parameters and matching F1 scores on the first batch.
 */
@Test
public void testMLPMultiLayerBackprop() {
    // `iter` is not local to this test, so rewind it before each fit. Both networks then see
    // the full dataset, whatever state earlier tests or the previous fit left it in.
    MultiLayerNetwork model = getDenseMLNConfig(true, false);
    iter.reset();
    model.fit(iter);

    MultiLayerNetwork model2 = getDenseMLNConfig(true, false);
    iter.reset();
    model2.fit(iter);
    iter.reset();

    DataSet test = iter.next();

    assertEquals(model.params(), model2.params());

    Evaluation eval = new Evaluation();
    INDArray output = model.output(test.getFeatureMatrix());
    eval.eval(test.getLabels(), output);
    double f1Score = eval.f1();

    Evaluation eval2 = new Evaluation();
    INDArray output2 = model2.output(test.getFeatureMatrix());
    eval2.eval(test.getLabels(), output2);
    double f1Score2 = eval2.f1();

    assertEquals(f1Score, f1Score2, 1e-4);
}
 
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:27,代码来源:DenseTest.java

示例8: testCGEvaluation

import org.deeplearning4j.eval.Evaluation; //导入依赖的package包/类
/**
 * Trains an equivalent MultiLayerNetwork and ComputationGraph from the same seed on Iris. It
 * then checks that both report the same evaluation accuracy.
 */
@Test
public void testCGEvaluation() {

    Nd4j.getRandom().setSeed(12345);
    ComputationGraphConfiguration configuration = getIrisGraphConfiguration();
    ComputationGraph graph = new ComputationGraph(configuration);
    graph.init();

    Nd4j.getRandom().setSeed(12345);
    MultiLayerConfiguration mlnConfig = getIrisMLNConfiguration();
    MultiLayerNetwork net = new MultiLayerNetwork(mlnConfig);
    net.init();

    DataSetIterator iris = new IrisDataSetIterator(75, 150);

    net.fit(iris);
    iris.reset();
    graph.fit(iris);

    iris.reset();
    Evaluation evalExpected = net.evaluate(iris);
    iris.reset();
    Evaluation evalActual = graph.evaluate(iris);

    // Accuracy moves in steps of 1/150 here, so 1e-4 still demands the same number of correct
    // predictions. It only absorbs floating-point rounding; the old literal `0e-4` was 0.0.
    assertEquals(evalExpected.accuracy(), evalActual.accuracy(), 1e-4);
}
 
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:27,代码来源:TestComputationGraphNetwork.java

示例9: testEvaluation

import org.deeplearning4j.eval.Evaluation; //导入依赖的package包/类
/**
 * Verifies that distributed evaluation through Spark reproduces the metrics of a local
 * evaluation on a clone of the same network.
 */
@Test
public void testEvaluation() {
    SparkDl4jMultiLayer network = getBasicNetwork();
    MultiLayerNetwork local = network.getNetwork().clone();

    Evaluation reference = new Evaluation();
    INDArray localPredictions = local.output(input, Layer.TrainingMode.TEST);
    reference.eval(labels, localPredictions);

    Evaluation distributed = network.evaluate(sparkData);

    double tolerance = 1e-3;
    // Scalar summary metrics.
    assertEquals(reference.accuracy(), distributed.accuracy(), tolerance);
    assertEquals(reference.f1(), distributed.f1(), tolerance);
    assertEquals(reference.getNumRowCounter(), distributed.getNumRowCounter(), tolerance);
    // Counter maps.
    assertMapEquals(reference.falseNegatives(), distributed.falseNegatives());
    assertMapEquals(reference.falsePositives(), distributed.falsePositives());
    assertMapEquals(reference.trueNegatives(), distributed.trueNegatives());
    assertMapEquals(reference.truePositives(), distributed.truePositives());
    assertEquals(reference.precision(), distributed.precision(), tolerance);
    assertEquals(reference.recall(), distributed.recall(), tolerance);
    assertEquals(reference.getConfusionMatrix(), distributed.getConfusionMatrix());
}
 
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:24,代码来源:TestSparkMultiLayerParameterAveraging.java

示例10: evaluate

import org.deeplearning4j.eval.Evaluation; //导入依赖的package包/类
/**
 * Evaluates the model on the test set carried by {@code federatedDataSet}.
 *
 * @param federatedDataSet wrapper whose native data set is expected to be a {@code DataSet}
 * @return a report that starts with {@code "Score: <score>"}, followed on the next line by
 *     the classification statistics from {@link Evaluation#stats()}
 */
@Override
public String evaluate(FederatedDataSet federatedDataSet) {
    // Evaluate the model on the test set.
    DataSet testData = (DataSet) federatedDataSet.getNativeDataSet();
    double score = model.score(testData);
    Evaluation eval = new Evaluation(numClasses);
    INDArray output = model.output(testData.getFeatureMatrix());
    eval.eval(testData.getLabels(), output);
    // Report the classification metrics too; otherwise the evaluation above would be wasted work.
    return "Score: " + score + "\n" + eval.stats();
}
 
开发者ID:mccorby,项目名称:FederatedAndroidTrainer,代码行数:11,代码来源:IrisModel.java

示例11: evaluate

import org.deeplearning4j.eval.Evaluation; //导入依赖的package包/类
/**
 * Evaluates the model on the federated test data in mini-batches of {@code BATCH_SIZE}.
 *
 * @param federatedDataSet wrapper whose native data set is expected to be a {@code DataSet}
 * @return the evaluation statistics report
 */
@Override
public String evaluate(FederatedDataSet federatedDataSet) {
    final DataSet fullTestSet = (DataSet) federatedDataSet.getNativeDataSet();
    final DataSetIterator batches = new ListDataSetIterator(fullTestSet.asList(), BATCH_SIZE);

    // One evaluation object over OUTPUT_NUM classes, fed every batch.
    final Evaluation evaluation = new Evaluation(OUTPUT_NUM);
    while (batches.hasNext()) {
        final DataSet batch = batches.next();
        final INDArray predictions = model.output(batch.getFeatureMatrix());
        evaluation.eval(batch.getLabels(), predictions);
    }

    return evaluation.stats();
}
 
开发者ID:mccorby,项目名称:FederatedAndroidTrainer,代码行数:16,代码来源:MNISTModel.java

示例12: evaluate

import org.deeplearning4j.eval.Evaluation; //导入依赖的package包/类
/**
 * Evaluates {@code m_model} on the remaining batches of {@code m_testSet} and logs the
 * statistics. The test iterator is then rewound, even if evaluation fails part-way, so the
 * next call starts from the beginning.
 */
public void evaluate() {
	log.info("Evaluate model....");
	Evaluation eval = new Evaluation(ConfigurationFactory.NUM_OUTPUTS);
	try {
		while (m_testSet.hasNext()) {
			DataSet ds = m_testSet.next();
			INDArray output = m_model.output(ds.getFeatureMatrix(), false);
			eval.eval(ds.getLabels(), output);
		}
		log.info(eval.stats());
	} finally {
		// Without this, an exception would leave the iterator partially consumed and later
		// evaluations would silently cover only the remaining batches.
		m_testSet.reset();
	}
}
 
开发者ID:braeunlich,项目名称:anagnostes,代码行数:12,代码来源:NeuralNetwork.java

示例13: toNetworkStatisticsResource

import org.deeplearning4j.eval.Evaluation; //导入依赖的package包/类
/**
 * Copies accuracy, F1, precision and recall from an evaluation into a new resource object.
 *
 * @param evaluation the evaluation to read the metrics from
 * @return a freshly created resource carrying those four metrics
 */
private NetworkStatisticsResource toNetworkStatisticsResource(Evaluation evaluation) {
    final double accuracy = evaluation.accuracy();
    final double f1 = evaluation.f1();
    final double precision = evaluation.precision();
    final double recall = evaluation.recall();

    final NetworkStatisticsResource statistics = new NetworkStatisticsResource();
    statistics.setAccuracy(accuracy);
    statistics.setF1(f1);
    statistics.setPrecision(precision);
    statistics.setRecall(recall);
    return statistics;
}
 
开发者ID:scaliby,项目名称:ceidg-captcha,代码行数:9,代码来源:MachineLearningServiceImpl.java

示例14: evaluate

import org.deeplearning4j.eval.Evaluation; //导入依赖的package包/类
/**
 * Evaluates the model once on the full Iris testing data and prints the statistics.
 *
 * @return this model, for call chaining
 */
@Override
@SuppressWarnings("rawtypes")
public Model evaluate()
{
    final IrisData irisData = (IrisData) data;
    final DataSet testSet = irisData.getTestingData();

    final Evaluation evaluation = new Evaluation(parameters.getOutputSize());
    evaluation.eval(testSet.getLabels(), model.output(testSet.getFeatureMatrix()));

    final String report = evaluation.stats();
    System.out.println(report);
    return this;
}
 
开发者ID:amrabed,项目名称:DL4J,代码行数:11,代码来源:ConvolutionalNetModel.java

示例15: main

import org.deeplearning4j.eval.Evaluation; //导入依赖的package包/类
/**
 * Continues ("online") training of a saved sentiment RNN. After every epoch it evaluates the
 * network on the test data and saves it.
 *
 * <ul>
 *   <li>args[0] input: word2vec file name</li>
 *   <li>args[1] input: saved model file to continue training from</li>
 *   <li>args[2] input: parent folder containing the train/test data</li>
 *   <li>args[3] output: file the updated model is written to</li>
 * </ul>
 *
 * <p>Exits with status 1 and a usage message when fewer than four arguments are given.
 *
 * @param args command-line arguments as described above
 * @throws Exception propagated from loading the vectors/model/data or writing the model
 */
public static void main (final String[] args) throws Exception {
  // Check the length first: with too few arguments, indexing args would throw
  // ArrayIndexOutOfBoundsException before any null check could run.
  if (args.length < 4 || args[0] == null || args[1] == null || args[2] == null || args[3] == null) {
    System.err.println("Usage: <word2vec-file> <input-model> <data-parent-dir> <output-model>");
    System.exit(1);
  }

  WordVectors wvec = WordVectorSerializer.loadTxtVectors(new File(args[0]));
  MultiLayerNetwork model = ModelSerializer.restoreMultiLayerNetwork(args[1], true);
  int batchSize   = 16;
  int testBatch   = 64;
  int nEpochs     = 1;

  System.out.println("Starting online training");
  DataSetIterator train = new AsyncDataSetIterator(
      new SentimentRecurrentIterator(args[2], wvec, batchSize, 300, true), 2);
  DataSetIterator test = new AsyncDataSetIterator(
      new SentimentRecurrentIterator(args[2], wvec, testBatch, 300, false), 2);
  for (int i = 0; i < nEpochs; i++) {
    model.fit(train);
    train.reset();

    System.out.println("Epoch " + i + " complete. Starting evaluation:");
    Evaluation evaluation = new Evaluation();
    while (test.hasNext()) {
      DataSet t = test.next();
      INDArray features = t.getFeatures();
      INDArray labels = t.getLabels();
      INDArray inMask = t.getFeaturesMaskArray();
      INDArray outMask = t.getLabelsMaskArray();
      INDArray predicted = model.output(features, false, inMask, outMask);
      evaluation.evalTimeSeries(labels, predicted, outMask);
    }
    test.reset();
    System.out.println(evaluation.stats());

    System.out.println("Save model");
    // Close the file even if writeModel leaves the stream open or throws.
    try (FileOutputStream out = new FileOutputStream(args[3])) {
      ModelSerializer.writeModel(model, out, true);
    }
  }
}
 
开发者ID:keigohtr,项目名称:sentiment-rnn,代码行数:47,代码来源:SentimentRecurrentTrainOnlineCmd.java


注:本文中的org.deeplearning4j.eval.Evaluation类示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。