当前位置: 首页>>代码示例>>C++>>正文


C++ Net::eval方法代码示例

本文整理汇总了C++中Net::eval方法的典型用法代码示例。如果您正苦于以下问题:C++ Net::eval方法的具体用法?C++ Net::eval怎么用?C++ Net::eval使用的例子?那么, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在Net的用法示例。


在下文中一共展示了Net::eval方法的2个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的C++代码示例。

示例1: test

void test(
    Net& model,
    torch::Device device,
    DataLoader& data_loader,
    size_t dataset_size) {
  torch::NoGradGuard no_grad;
  model.eval();
  double test_loss = 0;
  int32_t correct = 0;
  for (const auto& batch : data_loader) {
    auto data = batch.data.to(device), targets = batch.target.to(device);
    auto output = model.forward(data);
    test_loss += torch::nll_loss(
                     output,
                     targets,
                     /*weight=*/{},
                     Reduction::Sum)
                     .template item<float>();
    auto pred = output.argmax(1);
    correct += pred.eq(targets).sum().template item<int64_t>();
  }

  test_loss /= dataset_size;
  std::printf(
      "\nTest set: Average loss: %.4f | Accuracy: %.3f\n",
      test_loss,
      static_cast<double>(correct) / dataset_size);
}
开发者ID:chiminghui,项目名称:examples,代码行数:28,代码来源:mnist.cpp

示例2: checkParameter


//.........这里部分代码省略.........
                                    if (verbI > 5) {
                                        BufferToConsole ( "Done Training. Resampling...\n");
                                    }

                                    _PMathObj  tObj = _Constant(0).Time();
                                    _Parameter time1 = tObj->Value(),
                                               time2;

                                    while   (tIn.countitems() < checkSteps) {
                                        NNMatrixSampler     (0, vBounds, modelVariableList, variableMap, modelMatrix, tIn, tOut);
                                    }

                                    absError = 0.0;

                                    DeleteObject (tObj);
                                    tObj = _Constant(0).Time();
                                    time2 = tObj->Value();

                                    if (verbI > 5) {
                                        snprintf (buffer, sizeof(buffer),"Done Resampling in %g seconds. Computing Error...\n", time2-time1);
                                        BufferToConsole (buffer);
                                    }

                                    _Parameter  maxValT,
                                                maxValE;

                                    for (long verCount = 0; verCount < checkSteps; verCount++) {
                                        _Parameter*  inData     = ((_Matrix*)tIn(verCount))->theData,
                                                     *  outData  = ((_Matrix*)tOut(verCount))->theData;

                                        for (long cellCount = 0; cellCount < fullDimension; cellCount++) {
                                            Net         *thisCell = matrixNet[cellCount];

                                            _Parameter estVal     = thisCell->eval(inData)[0],
                                                       trueVal      = outData[cellCount],
                                                       localError;

                                            localError = estVal-trueVal;

                                            if (localError < 0) {
                                                localError = -localError;
                                            }

                                            if (absError < localError) {
                                                maxValT  = trueVal;
                                                maxValE  = estVal;
                                                absError = localError;
                                            }
                                        }
                                    }

                                    DeleteObject (tObj);
                                    tObj = _Constant(0).Time();
                                    time1 = tObj->Value();
                                    DeleteObject (tObj);


                                    if (verbI > 5) {
                                        snprintf (buffer, sizeof(buffer), "Done Error Checking in %g seconds. Got max abs error %g on the pair %g %g\n", time1-time2, absError, maxValT, maxValE);
                                        BufferToConsole (buffer);
                                    }
                                    if (absError <= errorTerm) {
                                        break;
                                    }
                                }
开发者ID:Dietermielke,项目名称:hyphy,代码行数:66,代码来源:HYNetInterface.cpp


注:本文中的Net::eval方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。