本文整理汇总了C++中utils::ConfusionMatrix方法的典型用法代码示例。如果您正苦于以下问题：C++ utils::ConfusionMatrix方法的具体用法是什么？C++ utils::ConfusionMatrix怎么用？C++ utils::ConfusionMatrix有哪些使用示例？那么，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在类utils的用法示例。
在下文中一共展示了utils::ConfusionMatrix方法的1个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者觉得有用的代码点赞，您的评价将有助于系统推荐出更好的C++代码示例。
示例1：average_error
/**
 * Evaluate a per-class ensemble of sequence models on held-out data.
 *
 * Every example is scored by each model in `models`. The index of the
 * best-scoring model is taken as the predicted class and compared with the
 * example's target, so `models[k]` acts as the model for class `k`. Despite its
 * name, the function returns an accuracy: the percentage of evaluated examples
 * whose prediction equals their target.
 *
 * At most `set_size` (10) randomly chosen minibatches are evaluated from each
 * validation set. Each minibatch is scored as a separate task on `pool`
 * (declared outside this function). The call blocks in
 * `pool->wait_until_idle()` until every task has finished, then prints the
 * confusion matrix and returns.
 *
 * Score direction: with FLAGS_use_surprise the lowest score wins; otherwise
 * the highest score wins.
 * NOTE(review): the scores come from `sequence_score` over log-softmax
 * outputs, which look like log-likelihoods (higher = better). The
 * lower-is-better branch dates from the older `sequence_surprises` path.
 * Confirm `sequence_score`'s sign convention before enabling
 * FLAGS_use_surprise.
 * NOTE(review): `confusion` and `journalist` are updated concurrently from
 * pool tasks, and their thread-safety is not visible here; confirm it. The
 * counters used below are std::atomic.
 *
 * @param models           one model per class, indexed by class label.
 * @param validation_sets  groups of minibatches to sample from.
 * @return percentage (0-100) of correctly classified examples, or 0 when no
 *         example was evaluated (instead of the NaN produced by 0/0).
 */
REAL_t average_error(const vector<model_t>& models, const vector<vector<Databatch>>& validation_sets) {
    // Maximum number of randomly chosen minibatches evaluated per validation set.
    const size_t set_size = 10;

    // 100 * correct / total, but 0 (not NaN) while nothing has been counted yet.
    auto percent_correct = [](int num_correct, int num_total) -> REAL_t {
        return num_total > 0
            ? static_cast<REAL_t>(100.0) * num_correct / static_cast<REAL_t>(num_total)
            : static_cast<REAL_t>(0.0);
    };

    // Pick the minibatches before starting any work, so the progress reporter
    // is given the real number of tasks. Validation sets with fewer than
    // `set_size` minibatches used to inflate the expected total.
    struct MinibatchRef {
        size_t set_num;
        size_t minibatch_num;
    };
    vector<MinibatchRef> jobs;
    for (size_t set_num = 0; set_num < validation_sets.size(); ++set_num) {
        auto random_batch_order = utils::random_arange(validation_sets[set_num].size());
        if (random_batch_order.size() > set_size)
            random_batch_order.resize(set_size);
        for (const auto& minibatch_num : random_batch_order)
            jobs.push_back({set_num, static_cast<size_t>(minibatch_num)});
    }

    ReportProgress<double> journalist("Average error", jobs.size());
    atomic<int> correct(0);
    atomic<int> total(0);
    atomic<int> seen_minibatches(0);
    auto confusion = ConfusionMatrix(5, SST::label_names);

    for (const auto& job : jobs) {
        // Locals captured by reference stay alive until pool->wait_until_idle() below.
        pool->run([&confusion, &journalist, &models, &correct, &seen_minibatches, &total,
                   &validation_sets, percent_correct, job] {
            const auto& valid_set = validation_sets[job.set_num][job.minibatch_num];

            // probs[k] holds model k's score for every example of this minibatch.
            vector<mat> probs;
            probs.reserve(models.size());
            for (const auto& model : models) {
                using state_t = decltype(model.initial_states());
                using decode_state_t = std::tuple<state_t, Mat<REAL_t>>;
                // TODO(Jonathan): This is currently incorrect mathematically.
                // To make it correct we need to replace valid_set by valid_set.Slice(1:)
                // (which is not yet implemented)
                decode_state_t initial_state(
                    model.initial_states(),
                    model.embedding[valid_set.data[0]]);
                probs.emplace_back(sequence_probability::sequence_score<REAL_t, decode_state_t>(
                    valid_set,
                    initial_state,
                    // Log of the row-wise softmax of the decoder output for this state.
                    [&model](decode_state_t state) -> Mat<REAL_t> {
                        return MatOps<REAL_t>::softmax_rowwise(
                            model.decode(std::get<1>(state), std::get<0>(state))).log();
                    },
                    // Advance the recurrent state with the observed input `obs`.
                    [&model](Mat<int> obs, decode_state_t state) -> decode_state_t {
                        return decode_state_t(
                            model.activate(std::get<0>(state), obs).lstm_state,
                            model.embedding[obs]);
                    },
                    1));
            }

            const bool lower_is_better = FLAGS_use_surprise;
            // Count locally and publish once, instead of one atomic op per example.
            int local_correct = 0;
            for (size_t example_idx = 0; example_idx < valid_set.size(); ++example_idx) {
                const auto target = valid_set.target_for_example(example_idx);
                int best_model = -1;
                double best_score = (lower_is_better ? 1.0 : -1.0) * std::numeric_limits<REAL_t>::infinity();
                // Strict comparison: on ties the lowest model index wins.
                for (size_t k = 0; k < probs.size(); ++k) {
                    const auto score = probs[k].w(example_idx);
                    if (lower_is_better ? score < best_score : score > best_score) {
                        best_score = score;
                        best_model = static_cast<int>(k);
                    }
                }
                if (best_model < 0) {
                    // No model produced a comparable score (no models, or every score NaN).
                    // Don't record -1 as a class index. The example still counts toward
                    // `total` below, so it is scored as misclassified.
                    continue;
                }
                confusion.classified_a_when_b(best_model, target);
                if (best_model == target)
                    ++local_correct;
            }

            correct += local_correct;
            total += static_cast<int>(valid_set.size());
            const int seen = ++seen_minibatches;
            journalist.tick(seen, percent_correct(correct.load(), total.load()));
        });
    }
    pool->wait_until_idle();
    confusion.report();
    return percent_correct(correct.load(), total.load());
}