当前位置: 首页>>代码示例>>C++>>正文


C++ utils::ConfusionMatrix方法代码示例

本文整理汇总了C++中utils::ConfusionMatrix方法的典型用法代码示例。如果您正苦于以下问题:C++ utils::ConfusionMatrix方法的具体用法?C++ utils::ConfusionMatrix怎么用?C++ utils::ConfusionMatrix使用的例子?那么恭喜您, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在utils的用法示例。


在下文中一共展示了utils::ConfusionMatrix方法的1个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的C++代码示例。

示例1: average_error

REAL_t average_error(const vector<model_t>& models, const vector<vector<Databatch>>& validation_sets) {
    atomic<int> correct(0);
    int set_size = 10;
    ReportProgress<double> journalist("Average error", validation_sets.size() * set_size);
    atomic<int> total(0);
    atomic<int> seen_minibatches(0);
    auto confusion = ConfusionMatrix(5, SST::label_names);

    for (int validation_set_num = 0; validation_set_num < validation_sets.size(); validation_set_num++) {

        auto random_batch_order = utils::random_arange(validation_sets[validation_set_num].size());
        if (random_batch_order.size() > set_size)
            random_batch_order.resize(set_size);

        for (auto& minibatch_num : random_batch_order) {
            pool->run([&confusion, &journalist, &models, &correct,&seen_minibatches,  &total, &validation_sets, validation_set_num, minibatch_num] {
                auto& valid_set = validation_sets[validation_set_num][minibatch_num];
                vector<mat> probs;
                for (int k = 0; k < models.size();k++) {

                    typedef decltype(models[k].initial_states()) state_t;
                    typedef std::tuple<state_t, Mat<REAL_t>> decode_state_t;


                    // TODO(Jonathan): This is currently incorrect mathematically.
                    // To make it correct we need to replace valid_set by valid_set.Slice(1:)
                    // (which is not yet implemented)

                    auto initial_state = make_tuple<state_t, Mat<REAL_t>>(
                            models[k].initial_states(),
                            models[k].embedding[valid_set.data[0]]);

                    probs.emplace_back(sequence_probability::sequence_score<REAL_t, decode_state_t>(
                        valid_set,
                        initial_state,
                        [&](decode_state_t state) -> Mat<REAL_t> {
                            return MatOps<REAL_t>::softmax_rowwise(models[k].decode(
                                std::get<1>(state),
                                std::get<0>(state)
                            )).log();
                        },
                        [&](Mat<int> obs, decode_state_t state) -> decode_state_t {
                            return make_tuple<state_t, Mat<REAL_t>>(
                                models[k].activate(std::get<0>(state), obs).lstm_state,
                                models[k].embedding[obs]
                            );
                        },
                        1
                    ));
                    // probs.emplace_back(FLAGS_use_surprise
                    //     ? sequence_probability::sequence_surprises(
                    //         models[k],
                    //         valid_set.data,
                    //         valid_set.code_lengths)
                    //     : sequence_probability::sequence_probabilities(
                    //         models[k],
                    //         valid_set.data,
                    //         valid_set.code_lengths));
                }

                for (size_t example_idx = 0; example_idx < valid_set.size(); ++example_idx) {
                    int best_model = -1;
                    double best_prob = (FLAGS_use_surprise ? 1.0 : -1.0) * std::numeric_limits<REAL_t>::infinity();
                    if (FLAGS_use_surprise) {
                        for (int k = 0; k < models.size();k++) {
                            auto prob = probs[k].w(example_idx);
                            if (prob < best_prob) {
                                best_prob = prob;
                                best_model = k;
                            }
                        }
                    } else {
                        for (int k = 0; k < models.size();k++) {
                            auto prob = probs[k].w(example_idx);
                            if (prob > best_prob) {
                                best_prob = prob;
                                best_model = k;
                            }
                        }
                    }
                    confusion.classified_a_when_b(best_model,
                            valid_set.target_for_example(example_idx));
                    if (best_model == valid_set.target_for_example(example_idx)) {
                        correct++;
                    }
                }
                seen_minibatches++;
                total += valid_set.size();
                journalist.tick(seen_minibatches, (REAL_t) 100.0 * correct / (REAL_t) total);
            });
        }
    }
    pool->wait_until_idle();
    confusion.report();

    return ((REAL_t) 100.0 * correct / (REAL_t) total);
};
开发者ID:codeaudit,项目名称:Dali,代码行数:97,代码来源:language_model_from_senti.cpp


注:本文中的utils::ConfusionMatrix方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。