本文整理汇总了C++中Net::eval方法的典型用法代码示例。如果您正苦于以下问题:C++ Net::eval方法的具体用法?C++ Net::eval怎么用?C++ Net::eval使用的例子?那么, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在类Net
的用法示例。
在下文中一共展示了Net::eval方法的2个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的C++代码示例。
示例1: test
void test(
Net& model,
torch::Device device,
DataLoader& data_loader,
size_t dataset_size) {
torch::NoGradGuard no_grad;
model.eval();
double test_loss = 0;
int32_t correct = 0;
for (const auto& batch : data_loader) {
auto data = batch.data.to(device), targets = batch.target.to(device);
auto output = model.forward(data);
test_loss += torch::nll_loss(
output,
targets,
/*weight=*/{},
Reduction::Sum)
.template item<float>();
auto pred = output.argmax(1);
correct += pred.eq(targets).sum().template item<int64_t>();
}
test_loss /= dataset_size;
std::printf(
"\nTest set: Average loss: %.4f | Accuracy: %.3f\n",
test_loss,
static_cast<double>(correct) / dataset_size);
}
示例2: checkParameter
//.........这里部分代码省略.........
if (verbI > 5) {
BufferToConsole ( "Done Training. Resampling...\n");
}
_PMathObj tObj = _Constant(0).Time();
_Parameter time1 = tObj->Value(),
time2;
while (tIn.countitems() < checkSteps) {
NNMatrixSampler (0, vBounds, modelVariableList, variableMap, modelMatrix, tIn, tOut);
}
absError = 0.0;
DeleteObject (tObj);
tObj = _Constant(0).Time();
time2 = tObj->Value();
if (verbI > 5) {
snprintf (buffer, sizeof(buffer),"Done Resampling in %g seconds. Computing Error...\n", time2-time1);
BufferToConsole (buffer);
}
_Parameter maxValT,
maxValE;
for (long verCount = 0; verCount < checkSteps; verCount++) {
_Parameter* inData = ((_Matrix*)tIn(verCount))->theData,
* outData = ((_Matrix*)tOut(verCount))->theData;
for (long cellCount = 0; cellCount < fullDimension; cellCount++) {
Net *thisCell = matrixNet[cellCount];
_Parameter estVal = thisCell->eval(inData)[0],
trueVal = outData[cellCount],
localError;
localError = estVal-trueVal;
if (localError < 0) {
localError = -localError;
}
if (absError < localError) {
maxValT = trueVal;
maxValE = estVal;
absError = localError;
}
}
}
DeleteObject (tObj);
tObj = _Constant(0).Time();
time1 = tObj->Value();
DeleteObject (tObj);
if (verbI > 5) {
snprintf (buffer, sizeof(buffer), "Done Error Checking in %g seconds. Got max abs error %g on the pair %g %g\n", time1-time2, absError, maxValT, maxValE);
BufferToConsole (buffer);
}
if (absError <= errorTerm) {
break;
}
}