Python exceptions.UndefinedMetricWarning Code Examples

This article collects typical usage examples of sklearn.exceptions.UndefinedMetricWarning in Python, gathered from open-source projects. If you are unsure what UndefinedMetricWarning is, how to use it, or what real-world code that uses it looks like, the curated examples below may help. You can also browse further usage examples for the sklearn.exceptions module.


The following presents 15 code examples of exceptions.UndefinedMetricWarning, sorted by popularity by default.
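
Before the individual examples, here is a minimal sketch (not taken from any of the projects below) of when scikit-learn emits this warning and how to handle it; note that the zero_division parameter assumed in the last line requires scikit-learn 0.22 or later:

import warnings

from sklearn.exceptions import UndefinedMetricWarning
from sklearn.metrics import precision_score

# Precision is ill-defined here: class 1 is never predicted, so TP + FP == 0.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    p = precision_score([0, 1, 1], [0, 0, 0])

print(p)  # 0.0
print(any(issubclass(w.category, UndefinedMetricWarning) for w in caught))  # True

# Since scikit-learn 0.22, zero_division=0 silences the warning explicitly.
p = precision_score([0, 1, 1], [0, 0, 0], zero_division=0)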

Example 1: test_average_binary_jaccard_score

# Required import: from sklearn import exceptions [as alias]
# Or: from sklearn.exceptions import UndefinedMetricWarning [as alias]
def test_average_binary_jaccard_score(recwarn):
    # tp=0, fp=0, fn=1, tn=0
    assert jaccard_score([1], [0], average='binary') == 0.
    # tp=0, fp=0, fn=0, tn=2
    msg = ('Jaccard is ill-defined and being set to 0.0 due to '
           'no true or predicted samples')
    assert assert_warns_message(UndefinedMetricWarning,
                                msg,
                                jaccard_score,
                                [0, 0], [0, 0],
                                average='binary') == 0.
    # tp=1, fp=0, fn=0, tn=0 (pos_label=0)
    assert jaccard_score([0], [0], pos_label=0,
                         average='binary') == 1.
    y_true = np.array([1, 0, 1, 1, 0])
    y_pred = np.array([1, 0, 1, 1, 1])
    assert_almost_equal(jaccard_score(y_true, y_pred,
                                      average='binary'), 3. / 4)
    assert_almost_equal(jaccard_score(y_true, y_pred,
                                      average='binary',
                                      pos_label=0), 1. / 2)

    assert not list(recwarn) 
Author: PacktPublishing, Project: Mastering-Elasticsearch-7.0, Lines: 25, Source: test_classification.py
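
The expected value 3/4 in the example above follows directly from the Jaccard definition J = tp / (tp + fp + fn); a quick sketch verifying it by hand:

import numpy as np
from sklearn.metrics import jaccard_score

y_true = np.array([1, 0, 1, 1, 0])
y_pred = np.array([1, 0, 1, 1, 1])
tp = np.sum((y_true == 1) & (y_pred == 1))  # 3
fp = np.sum((y_true == 0) & (y_pred == 1))  # 1
fn = np.sum((y_true == 1) & (y_pred == 0))  # 0
assert jaccard_score(y_true, y_pred, average='binary') == tp / (tp + fp + fn)  # 3/4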

Example 2: test_precision_recall_f1_no_labels

# Required import: from sklearn import exceptions [as alias]
# Or: from sklearn.exceptions import UndefinedMetricWarning [as alias]
def test_precision_recall_f1_no_labels(beta, average):
    y_true = np.zeros((20, 3))
    y_pred = np.zeros_like(y_true)

    p, r, f, s = assert_warns(UndefinedMetricWarning,
                              precision_recall_fscore_support,
                              y_true, y_pred, average=average,
                              beta=beta)
    assert_almost_equal(p, 0)
    assert_almost_equal(r, 0)
    assert_almost_equal(f, 0)
    assert_equal(s, None)

    fbeta = assert_warns(UndefinedMetricWarning, fbeta_score,
                         y_true, y_pred,
                         beta=beta, average=average)
    assert_almost_equal(fbeta, 0) 
Author: PacktPublishing, Project: Mastering-Elasticsearch-7.0, Lines: 19, Source: test_classification.py
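
This test and the next rely on assert_warns from scikit-learn's legacy testing utilities. A rough equivalent using plain pytest.warns, assuming the same all-zero inputs:

import numpy as np
import pytest
from sklearn.exceptions import UndefinedMetricWarning
from sklearn.metrics import precision_recall_fscore_support

def test_no_labels_warns():
    y_true = np.zeros((20, 3))
    y_pred = np.zeros_like(y_true)
    with pytest.warns(UndefinedMetricWarning):
        p, r, f, s = precision_recall_fscore_support(
            y_true, y_pred, average='macro', beta=1)
    assert p == r == f == 0
    assert s is None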

Example 3: test_precision_recall_f1_no_labels_average_none

# Required import: from sklearn import exceptions [as alias]
# Or: from sklearn.exceptions import UndefinedMetricWarning [as alias]
def test_precision_recall_f1_no_labels_average_none():
    y_true = np.zeros((20, 3))
    y_pred = np.zeros_like(y_true)

    beta = 1

    # tp = [0, 0, 0]
    # fn = [0, 0, 0]
    # fp = [0, 0, 0]
    # support = [0, 0, 0]
    # |y_hat_i inter y_i | = [0, 0, 0]
    # |y_i| = [0, 0, 0]
    # |y_hat_i| = [0, 0, 0]

    p, r, f, s = assert_warns(UndefinedMetricWarning,
                              precision_recall_fscore_support,
                              y_true, y_pred, average=None, beta=beta)
    assert_array_almost_equal(p, [0, 0, 0], 2)
    assert_array_almost_equal(r, [0, 0, 0], 2)
    assert_array_almost_equal(f, [0, 0, 0], 2)
    assert_array_almost_equal(s, [0, 0, 0], 2)

    fbeta = assert_warns(UndefinedMetricWarning, fbeta_score,
                         y_true, y_pred, beta=beta, average=None)
    assert_array_almost_equal(fbeta, [0, 0, 0], 2) 
Author: PacktPublishing, Project: Mastering-Elasticsearch-7.0, Lines: 27, Source: test_classification.py

Example 4: test_roc_curve_one_label

# Required import: from sklearn import exceptions [as alias]
# Or: from sklearn.exceptions import UndefinedMetricWarning [as alias]
def test_roc_curve_one_label():
    y_true = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
    y_pred = [0, 1, 0, 1, 0, 1, 0, 1, 0, 1]
    # assert there are warnings
    w = UndefinedMetricWarning
    fpr, tpr, thresholds = assert_warns(w, roc_curve, y_true, y_pred)
    # all true labels, all fpr should be nan
    assert_array_equal(fpr, np.full(len(thresholds), np.nan))
    assert_equal(fpr.shape, tpr.shape)
    assert_equal(fpr.shape, thresholds.shape)

    # assert there are warnings
    fpr, tpr, thresholds = assert_warns(w, roc_curve,
                                        [1 - x for x in y_true],
                                        y_pred)
    # all negative labels, all tpr should be nan
    assert_array_equal(tpr, np.full(len(thresholds), np.nan))
    assert_equal(fpr.shape, tpr.shape)
    assert_equal(fpr.shape, thresholds.shape) 
Author: PacktPublishing, Project: Mastering-Elasticsearch-7.0, Lines: 21, Source: test_ranking.py
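
The NaNs asserted above are not arbitrary: fpr = FP / (FP + TN), and with only positive labels the denominator is zero at every threshold. The same situation outside the test harness, as a minimal sketch:

import warnings
import numpy as np
from sklearn.exceptions import UndefinedMetricWarning
from sklearn.metrics import roc_curve

y_true = [1] * 10
y_score = [0, 1] * 5
with warnings.catch_warnings():
    warnings.simplefilter("ignore", UndefinedMetricWarning)
    fpr, tpr, thresholds = roc_curve(y_true, y_score)
assert np.isnan(fpr).all()  # FP + TN == 0 at every threshold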

Example 5: specificity_score

# Required import: from sklearn import exceptions [as alias]
# Or: from sklearn.exceptions import UndefinedMetricWarning [as alias]
def specificity_score(y_true, y_pred, pos_label=1, sample_weight=None):
    """Compute the specificity or true negative rate.

    Args:
        y_true (array-like): Ground truth (correct) target values.
        y_pred (array-like): Estimated targets as returned by a classifier.
        pos_label (scalar, optional): The label of the positive class.
        sample_weight (array-like, optional): Sample weights.
    """
    MCM = multilabel_confusion_matrix(y_true, y_pred, labels=[pos_label],
                                      sample_weight=sample_weight)
    tn, fp, fn, tp = MCM.ravel()
    negs = tn + fp
    if negs == 0:
        warnings.warn('specificity_score is ill-defined and being set to 0.0 '
                      'due to no negative samples.', UndefinedMetricWarning)
        return 0.
    return tn / negs 
Author: IBM, Project: AIF360, Lines: 20, Source: metrics.py
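
A hypothetical call to the specificity_score helper above, with inputs chosen so that tn=1 and fp=1:

import numpy as np

# Assumes specificity_score from the example above is in scope.
y_true = np.array([0, 0, 1, 1])
y_pred = np.array([0, 1, 1, 1])
print(specificity_score(y_true, y_pred))  # tn=1, fp=1 -> 0.5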

Example 6: test_roc_curve_one_label

# Required import: from sklearn import exceptions [as alias]
# Or: from sklearn.exceptions import UndefinedMetricWarning [as alias]
def test_roc_curve_one_label():
    y_true = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
    y_pred = [0, 1, 0, 1, 0, 1, 0, 1, 0, 1]
    # assert there are warnings
    w = UndefinedMetricWarning
    fpr, tpr, thresholds = assert_warns(w, roc_curve, y_true, y_pred)
    # all true labels, all fpr should be nan
    assert_array_equal(fpr,
                       np.nan * np.ones(len(thresholds)))
    assert_equal(fpr.shape, tpr.shape)
    assert_equal(fpr.shape, thresholds.shape)

    # assert there are warnings
    fpr, tpr, thresholds = assert_warns(w, roc_curve,
                                        [1 - x for x in y_true],
                                        y_pred)
    # all negative labels, all tpr should be nan
    assert_array_equal(tpr,
                       np.nan * np.ones(len(thresholds)))
    assert_equal(fpr.shape, tpr.shape)
    assert_equal(fpr.shape, thresholds.shape) 
Author: alvarobartt, Project: twitter-stock-recommendation, Lines: 23, Source: test_ranking.py

Example 7: generalized_fpr

# Required import: from sklearn import exceptions [as alias]
# Or: from sklearn.exceptions import UndefinedMetricWarning [as alias]
def generalized_fpr(y_true, probas_pred, pos_label=1, sample_weight=None):
    r"""Return the ratio of generalized false positives to negative examples in
    the dataset, :math:`GFPR = \tfrac{GFP}{N}`.

    Generalized confusion matrix measures such as this are calculated by summing
    the probabilities of the positive class instead of the hard predictions.

    Args:
        y_true (array-like): Ground-truth (correct) target values.
        probas_pred (array-like): Probability estimates of the positive class.
        pos_label (scalar, optional): The label of the positive class.
        sample_weight (array-like, optional): Sample weights.

    Returns:
        float: Generalized false positive rate. If there are no negative samples
        in y_true, this will raise an
        :class:`~sklearn.exceptions.UndefinedMetricWarning` and return 0.
    """
    idx = (y_true != pos_label)
    if not np.any(idx):
        warnings.warn("generalized_fpr is ill-defined because there are no "
                      "negative samples in y_true.", UndefinedMetricWarning)
        return 0.
    if sample_weight is None:
        return probas_pred[idx].mean()
    return np.average(probas_pred[idx], weights=sample_weight[idx]) 
Author: IBM, Project: AIF360, Lines: 28, Source: metrics.py
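
A hypothetical call to generalized_fpr above: the two negative examples contribute their positive-class probabilities 0.2 and 0.4, so GFPR = (0.2 + 0.4) / 2:

import numpy as np

# Assumes generalized_fpr from the example above is in scope.
y_true = np.array([0, 0, 1])
probas_pred = np.array([0.2, 0.4, 0.9])
print(generalized_fpr(y_true, probas_pred))  # ~0.3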

Example 8: generalized_fnr

# Required import: from sklearn import exceptions [as alias]
# Or: from sklearn.exceptions import UndefinedMetricWarning [as alias]
def generalized_fnr(y_true, probas_pred, pos_label=1, sample_weight=None):
    r"""Return the ratio of generalized false negatives to positive examples in
    the dataset, :math:`GFNR = \tfrac{GFN}{P}`.

    Generalized confusion matrix measures such as this are calculated by summing
    the probabilities of the positive class instead of the hard predictions.

    Args:
        y_true (array-like): Ground-truth (correct) target values.
        probas_pred (array-like): Probability estimates of the positive class.
        pos_label (scalar, optional): The label of the positive class.
        sample_weight (array-like, optional): Sample weights.

    Returns:
        float: Generalized false negative rate. If there are no positive samples
        in y_true, this will raise an
        :class:`~sklearn.exceptions.UndefinedMetricWarning` and return 0.
    """
    idx = (y_true == pos_label)
    if not np.any(idx):
        warnings.warn("generalized_fnr is ill-defined because there are no "
                      "positive samples in y_true.", UndefinedMetricWarning)
        return 0.
    if sample_weight is None:
        return 1 - probas_pred[idx].mean()
    return 1 - np.average(probas_pred[idx], weights=sample_weight[idx])


# ============================ GROUP FAIRNESS ================================== 
Author: IBM, Project: AIF360, Lines: 31, Source: metrics.py
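
Analogously, a hypothetical call to generalized_fnr: the two positive examples contribute 1 - 0.8 and 1 - 0.6, so GFNR = (0.2 + 0.4) / 2:

import numpy as np

# Assumes generalized_fnr from the example above is in scope.
y_true = np.array([0, 1, 1])
probas_pred = np.array([0.1, 0.8, 0.6])
print(generalized_fnr(y_true, probas_pred))  # ~0.3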

Example 9: test_precision_recall_f1_no_labels

# Required import: from sklearn import exceptions [as alias]
# Or: from sklearn.exceptions import UndefinedMetricWarning [as alias]
def test_precision_recall_f1_no_labels():
    y_true = np.zeros((20, 3))
    y_pred = np.zeros_like(y_true)

    # tp = [0, 0, 0]
    # fn = [0, 0, 0]
    # fp = [0, 0, 0]
    # support = [0, 0, 0]
    # |y_hat_i inter y_i | = [0, 0, 0]
    # |y_i| = [0, 0, 0]
    # |y_hat_i| = [0, 0, 0]

    for beta in [1]:
        p, r, f, s = assert_warns(UndefinedMetricWarning,
                                  precision_recall_fscore_support,
                                  y_true, y_pred, average=None, beta=beta)
        assert_array_almost_equal(p, [0, 0, 0], 2)
        assert_array_almost_equal(r, [0, 0, 0], 2)
        assert_array_almost_equal(f, [0, 0, 0], 2)
        assert_array_almost_equal(s, [0, 0, 0], 2)

        fbeta = assert_warns(UndefinedMetricWarning, fbeta_score,
                             y_true, y_pred, beta=beta, average=None)
        assert_array_almost_equal(fbeta, [0, 0, 0], 2)

        for average in ["macro", "micro", "weighted", "samples"]:
            p, r, f, s = assert_warns(UndefinedMetricWarning,
                                      precision_recall_fscore_support,
                                      y_true, y_pred, average=average,
                                      beta=beta)
            assert_almost_equal(p, 0)
            assert_almost_equal(r, 0)
            assert_almost_equal(f, 0)
            assert_equal(s, None)

            fbeta = assert_warns(UndefinedMetricWarning, fbeta_score,
                                 y_true, y_pred,
                                 beta=beta, average=average)
            assert_almost_equal(fbeta, 0) 
Author: alvarobartt, Project: twitter-stock-recommendation, Lines: 41, Source: test_classification.py

Example 10: test_prf_warnings

# Required import: from sklearn import exceptions [as alias]
# Or: from sklearn.exceptions import UndefinedMetricWarning [as alias]
def test_prf_warnings():
    # average of per-label scores
    f, w = precision_recall_fscore_support, UndefinedMetricWarning
    my_assert = assert_warns_message
    for average in [None, 'weighted', 'macro']:
        msg = ('Precision and F-score are ill-defined and '
               'being set to 0.0 in labels with no predicted samples.')
        my_assert(w, msg, f, [0, 1, 2], [1, 1, 2], average=average)

        msg = ('Recall and F-score are ill-defined and '
               'being set to 0.0 in labels with no true samples.')
        my_assert(w, msg, f, [1, 1, 2], [0, 1, 2], average=average)

    # average of per-sample scores
    msg = ('Precision and F-score are ill-defined and '
           'being set to 0.0 in samples with no predicted labels.')
    my_assert(w, msg, f, np.array([[1, 0], [1, 0]]),
              np.array([[1, 0], [0, 0]]), average='samples')

    msg = ('Recall and F-score are ill-defined and '
           'being set to 0.0 in samples with no true labels.')
    my_assert(w, msg, f, np.array([[1, 0], [0, 0]]),
              np.array([[1, 0], [1, 0]]),
              average='samples')

    # single score: micro-average
    msg = ('Precision and F-score are ill-defined and '
           'being set to 0.0 due to no predicted samples.')
    my_assert(w, msg, f, np.array([[1, 1], [1, 1]]),
              np.array([[0, 0], [0, 0]]), average='micro')

    msg = ('Recall and F-score are ill-defined and '
           'being set to 0.0 due to no true samples.')
    my_assert(w, msg, f, np.array([[0, 0], [0, 0]]),
              np.array([[1, 1], [1, 1]]), average='micro')

    # single positive label
    msg = ('Precision and F-score are ill-defined and '
           'being set to 0.0 due to no predicted samples.')
    my_assert(w, msg, f, [1, 1], [-1, -1], average='binary')

    msg = ('Recall and F-score are ill-defined and '
           'being set to 0.0 due to no true samples.')
    my_assert(w, msg, f, [-1, -1], [1, 1], average='binary')

    clean_warning_registry()
    with warnings.catch_warnings(record=True) as record:
        warnings.simplefilter('always')
        precision_recall_fscore_support([0, 0], [0, 0], average="binary")
        msg = ('Recall and F-score are ill-defined and '
               'being set to 0.0 due to no true samples.')
        assert_equal(str(record.pop().message), msg)
        msg = ('Precision and F-score are ill-defined and '
               'being set to 0.0 due to no predicted samples.')
        assert_equal(str(record.pop().message), msg) 
Author: PacktPublishing, Project: Mastering-Elasticsearch-7.0, Lines: 57, Source: test_classification.py
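
Instead of asserting on warning messages, it can be useful during development to promote UndefinedMetricWarning to an error so that ill-defined metrics fail loudly; a minimal sketch:

import warnings
from sklearn.exceptions import UndefinedMetricWarning
from sklearn.metrics import recall_score

warnings.filterwarnings("error", category=UndefinedMetricWarning)  # global filter
try:
    recall_score([0, 0], [0, 1])  # recall undefined: no true samples for class 1
except UndefinedMetricWarning as w:
    print(f"caught: {w}")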

Example 11: collect_story_predictions

# Required import: from sklearn import exceptions [as alias]
# Or: from sklearn.exceptions import UndefinedMetricWarning [as alias]
def collect_story_predictions(
    completed_trackers: List['DialogueStateTracker'],
    agent: 'Agent',
    fail_on_prediction_errors: bool = False,
    use_e2e: bool = False
) -> Tuple[StoryEvalution, int]:
    """Test the stories from a file, running them through the stored model."""
    from rasa_nlu.test import get_evaluation_metrics
    from tqdm import tqdm

    story_eval_store = EvaluationStore()
    failed = []
    correct_dialogues = []
    num_stories = len(completed_trackers)

    logger.info("Evaluating {} stories\n"
                "Progress:".format(num_stories))

    action_list = []

    for tracker in tqdm(completed_trackers):
        tracker_results, predicted_tracker, tracker_actions = \
            _predict_tracker_actions(tracker, agent,
                                     fail_on_prediction_errors, use_e2e)

        story_eval_store.merge_store(tracker_results)

        action_list.extend(tracker_actions)

        if tracker_results.has_prediction_target_mismatch():
            # there is at least one wrong prediction
            failed.append(predicted_tracker)
            correct_dialogues.append(0)
        else:
            correct_dialogues.append(1)

    logger.info("Finished collecting predictions.")
    with warnings.catch_warnings():
        from sklearn.exceptions import UndefinedMetricWarning

        warnings.simplefilter("ignore", UndefinedMetricWarning)
        report, precision, f1, accuracy = get_evaluation_metrics(
            [1] * len(completed_trackers), correct_dialogues)

    in_training_data_fraction = _in_training_data_fraction(action_list)

    log_evaluation_table([1] * len(completed_trackers),
                         "END-TO-END" if use_e2e else "CONVERSATION",
                         report, precision, f1, accuracy,
                         in_training_data_fraction,
                         include_report=False)

    return (StoryEvalution(evaluation_store=story_eval_store,
                           failed_stories=failed,
                           action_list=action_list,
                           in_training_data_fraction=in_training_data_fraction),
            num_stories) 
Author: RasaHQ, Project: rasa_core, Lines: 59, Source: test.py
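
This example suppresses the warning only inside warnings.catch_warnings(), so the filter is restored when the block exits. The same pattern distilled with hypothetical targets and predictions:

import warnings
from sklearn.exceptions import UndefinedMetricWarning
from sklearn.metrics import classification_report

targets = ["greet", "bye", "greet"]
predictions = ["greet", "greet", "greet"]  # "bye" is never predicted

with warnings.catch_warnings():
    warnings.simplefilter("ignore", UndefinedMetricWarning)  # scoped to this block
    report = classification_report(targets, predictions)
print(report)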

Example 12: test

# Required import: from sklearn import exceptions [as alias]
# Or: from sklearn.exceptions import UndefinedMetricWarning [as alias]
async def test(stories: Text,
               agent: 'Agent',
               max_stories: Optional[int] = None,
               out_directory: Optional[Text] = None,
               fail_on_prediction_errors: bool = False,
               use_e2e: bool = False):
    """Run the evaluation of the stories, optionally plot the results."""
    from rasa_nlu.test import get_evaluation_metrics

    completed_trackers = await _generate_trackers(stories, agent,
                                                  max_stories, use_e2e)

    story_evaluation, _ = collect_story_predictions(completed_trackers, agent,
                                                    fail_on_prediction_errors,
                                                    use_e2e)

    evaluation_store = story_evaluation.evaluation_store

    with warnings.catch_warnings():
        from sklearn.exceptions import UndefinedMetricWarning

        warnings.simplefilter("ignore", UndefinedMetricWarning)
        report, precision, f1, accuracy = get_evaluation_metrics(
            evaluation_store.serialise_targets(),
            evaluation_store.serialise_predictions()
        )

    if out_directory:
        plot_story_evaluation(evaluation_store.action_targets,
                              evaluation_store.action_predictions,
                              report, precision, f1, accuracy,
                              story_evaluation.in_training_data_fraction,
                              out_directory)

    log_failed_stories(story_evaluation.failed_stories, out_directory)

    return {
        "report": report,
        "precision": precision,
        "f1": f1,
        "accuracy": accuracy,
        "actions": story_evaluation.action_list,
        "in_training_data_fraction":
            story_evaluation.in_training_data_fraction,
        "is_end_to_end_evaluation": use_e2e
    } 
Author: RasaHQ, Project: rasa_core, Lines: 48, Source: test.py

Example 13: test

# Required import: from sklearn import exceptions [as alias]
# Or: from sklearn.exceptions import UndefinedMetricWarning [as alias]
async def test(
    stories: Text,
    agent: "Agent",
    max_stories: Optional[int] = None,
    out_directory: Optional[Text] = None,
    fail_on_prediction_errors: bool = False,
    e2e: bool = False,
    disable_plotting: bool = False,
):
    """Run the evaluation of the stories, optionally plot the results."""
    from rasa.nlu.test import get_evaluation_metrics

    completed_trackers = await _generate_trackers(stories, agent, max_stories, e2e)

    story_evaluation, _ = collect_story_predictions(
        completed_trackers, agent, fail_on_prediction_errors, e2e
    )

    evaluation_store = story_evaluation.evaluation_store

    with warnings.catch_warnings():
        from sklearn.exceptions import UndefinedMetricWarning

        warnings.simplefilter("ignore", UndefinedMetricWarning)

        targets, predictions = evaluation_store.serialise()
        report, precision, f1, accuracy = get_evaluation_metrics(targets, predictions)

    if out_directory:
        plot_story_evaluation(
            evaluation_store.action_targets,
            evaluation_store.action_predictions,
            report,
            precision,
            f1,
            accuracy,
            story_evaluation.in_training_data_fraction,
            out_directory,
            disable_plotting,
        )

    log_failed_stories(story_evaluation.failed_stories, out_directory)

    return {
        "report": report,
        "precision": precision,
        "f1": f1,
        "accuracy": accuracy,
        "actions": story_evaluation.action_list,
        "in_training_data_fraction": story_evaluation.in_training_data_fraction,
        "is_end_to_end_evaluation": e2e,
    } 
Author: botfront, Project: rasa-for-botfront, Lines: 54, Source: test.py

Example 14: test_multilabel_input_NC

# Required import: from sklearn import exceptions [as alias]
# Or: from sklearn.exceptions import UndefinedMetricWarning [as alias]
def test_multilabel_input_NC():
    def _test(average):
        re = Recall(average=average, is_multilabel=True)

        y_pred = torch.randint(0, 2, size=(20, 5))
        y = torch.randint(0, 2, size=(20, 5)).long()
        re.update((y_pred, y))
        np_y_pred = to_numpy_multilabel(y_pred)
        np_y = to_numpy_multilabel(y)
        assert re._type == "multilabel"
        re_compute = re.compute() if average else re.compute().mean().item()
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", category=UndefinedMetricWarning)
            assert recall_score(np_y, np_y_pred, average="samples") == pytest.approx(re_compute)

        re.reset()
        y_pred = torch.randint(0, 2, size=(10, 4))
        y = torch.randint(0, 2, size=(10, 4)).long()
        re.update((y_pred, y))
        np_y_pred = y_pred.numpy()
        np_y = y.numpy()
        assert re._type == "multilabel"
        re_compute = re.compute() if average else re.compute().mean().item()
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", category=UndefinedMetricWarning)
            assert recall_score(np_y, np_y_pred, average="samples") == pytest.approx(re_compute)

        # Batched Updates
        re.reset()
        y_pred = torch.randint(0, 2, size=(100, 4))
        y = torch.randint(0, 2, size=(100, 4)).long()

        batch_size = 16
        n_iters = y.shape[0] // batch_size + 1

        for i in range(n_iters):
            idx = i * batch_size
            re.update((y_pred[idx : idx + batch_size], y[idx : idx + batch_size]))

        np_y = y.numpy()
        np_y_pred = y_pred.numpy()
        assert re._type == "multilabel"
        re_compute = re.compute() if average else re.compute().mean().item()
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", category=UndefinedMetricWarning)
            assert recall_score(np_y, np_y_pred, average="samples") == pytest.approx(re_compute)

    for _ in range(5):
        _test(average=True)
        _test(average=False)

    re1 = Recall(is_multilabel=True, average=True)
    re2 = Recall(is_multilabel=True, average=False)
    y_pred = torch.randint(0, 2, size=(10, 4))
    y = torch.randint(0, 2, size=(10, 4)).long()
    re1.update((y_pred, y))
    re2.update((y_pred, y))
    assert re1.compute() == pytest.approx(re2.compute().mean().item()) 
Author: pytorch, Project: ignite, Lines: 60, Source: test_recall.py
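
The suppression in these ignite tests is needed because random 0/1 tensors can produce samples with no true labels, which makes per-sample recall ill-defined; a sketch of the underlying scikit-learn behavior:

import warnings
import numpy as np
from sklearn.exceptions import UndefinedMetricWarning
from sklearn.metrics import recall_score

y_true = np.array([[0, 0], [1, 0]])  # first sample has no true labels
y_pred = np.array([[1, 0], [1, 0]])
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    recall_score(y_true, y_pred, average="samples")
assert any(issubclass(w.category, UndefinedMetricWarning) for w in caught)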

Example 15: test_multilabel_input_NCL

# Required import: from sklearn import exceptions [as alias]
# Or: from sklearn.exceptions import UndefinedMetricWarning [as alias]
def test_multilabel_input_NCL():
    def _test(average):
        re = Recall(average=average, is_multilabel=True)

        y_pred = torch.randint(0, 2, size=(10, 5, 10))
        y = torch.randint(0, 2, size=(10, 5, 10)).long()
        re.update((y_pred, y))
        np_y_pred = to_numpy_multilabel(y_pred)
        np_y = to_numpy_multilabel(y)
        assert re._type == "multilabel"
        re_compute = re.compute() if average else re.compute().mean().item()
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", category=UndefinedMetricWarning)
            assert recall_score(np_y, np_y_pred, average="samples") == pytest.approx(re_compute)

        re.reset()
        y_pred = torch.randint(0, 2, size=(15, 4, 10))
        y = torch.randint(0, 2, size=(15, 4, 10)).long()
        re.update((y_pred, y))
        np_y_pred = to_numpy_multilabel(y_pred)
        np_y = to_numpy_multilabel(y)
        assert re._type == "multilabel"
        re_compute = re.compute() if average else re.compute().mean().item()
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", category=UndefinedMetricWarning)
            assert recall_score(np_y, np_y_pred, average="samples") == pytest.approx(re_compute)

        # Batched Updates
        re.reset()
        y_pred = torch.randint(0, 2, size=(100, 4, 12))
        y = torch.randint(0, 2, size=(100, 4, 12)).long()

        batch_size = 16
        n_iters = y.shape[0] // batch_size + 1

        for i in range(n_iters):
            idx = i * batch_size
            re.update((y_pred[idx : idx + batch_size], y[idx : idx + batch_size]))

        np_y = to_numpy_multilabel(y)
        np_y_pred = to_numpy_multilabel(y_pred)
        assert re._type == "multilabel"
        re_compute = re.compute() if average else re.compute().mean().item()
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", category=UndefinedMetricWarning)
            assert recall_score(np_y, np_y_pred, average="samples") == pytest.approx(re_compute)

    for _ in range(5):
        _test(average=True)
        _test(average=False)

    re1 = Recall(is_multilabel=True, average=True)
    re2 = Recall(is_multilabel=True, average=False)
    y_pred = torch.randint(0, 2, size=(10, 4, 20))
    y = torch.randint(0, 2, size=(10, 4, 20)).long()
    re1.update((y_pred, y))
    re2.update((y_pred, y))
    assert re1.compute() == pytest.approx(re2.compute().mean().item()) 
Author: pytorch, Project: ignite, Lines: 60, Source: test_recall.py


Note: The sklearn.exceptions.UndefinedMetricWarning examples in this article were compiled by 纯净天空 from open-source code and documentation hosted on GitHub, MSDocs, and similar platforms. The snippets were selected from open-source projects contributed by their authors, who retain copyright over the source code; consult the corresponding project's license before using or redistributing it. Do not reproduce without permission.