当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


R yardstick classification_cost 不良分类的成本函数


classification_cost() 根据用户定义的成本计算不良预测的成本。将成本乘以估计的类别概率并返回平均成本。

用法

classification_cost(data, ...)

# S3 method for data.frame
classification_cost(
  data,
  truth,
  ...,
  costs = NULL,
  na_rm = TRUE,
  event_level = yardstick_event_level(),
  case_weights = NULL
)

classification_cost_vec(
  truth,
  estimate,
  costs = NULL,
  na_rm = TRUE,
  event_level = yardstick_event_level(),
  case_weights = NULL,
  ...
)

参数

data

包含 truth... 指定的列的 data.frame

...

一组不带引号的列名称或一个或多个 dplyr 选择器函数,用于选择哪些变量包含类概率。如果 truth 是二进制,则仅应选择 1 列,并且它应对应于 event_level 的值。否则,列的数量应与 truth 的因子级别一样多,并且列的顺序应与 truth 的因子级别相同。

truth

真实类结果的列标识符(即 factor )。这应该是一个不带引号的列名,尽管此参数是通过表达式传递的并且支持quasiquotation(您可以不带引号的列名)。对于 _vec() 函数,一个 factor 向量。

costs

具有列 "truth""estimate""cost" 的 DataFrame 。

"truth""estimate" 应该是包含 truth 因子级别的唯一组合的字符列。

"costs" 应该是一个数字列,表示预测 "estimate" 时应应用的成本,但真实结果是 "truth"

通常情况下,当 "truth" == "estimate" 时,成本为零(正确预测不会受到惩罚)。

如果 truth 级别的任何组合丢失,则假定它们的成本为零。

如果 NULL ,则使用相等的成本,将 0 的成本应用于正确的预测,并将 1 的成本应用于错误的预测。

na_rm

logical 值,指示在计算继续之前是否应剥离 NA 值。

event_level

单个字符串。 "first""second" 指定将truth 的哪个级别视为"event"。此参数仅适用于 estimator = "binary" 。默认使用内部帮助程序,通常默认为 "first" ,但是,如果设置了已弃用的全局选项 yardstick.event_first ,则将使用该帮助程序并发出警告。

case_weights

案例权重的可选列标识符。这应该是一个不带引号的列名称,其计算结果为 data 中的数字列。对于 _vec() 函数,一个数值向量。

estimate

如果truth是二进制的,对应于 "relevant" 类的类概率的数值向量。否则,矩阵的列数与因子级别一样多truth.假设它们的顺序与 truth 的级别相同。

tibble 包含列 .metric.estimator.estimate 以及 1 行值。

对于分组 DataFrame ,返回的行数将与组数相同。

对于 class_cost_vec() ,单个 numeric 值(或 NA )。

细节

例如,假设存在三个类: "A""B""C" 。假设存在真正的 "A" 观察,其类概率为 A = 0.3 / B = 0.3 / C = 0.4 。假设,当真实结果是类 "A" 时,每个类的成本为 A = 0 / B = 5 / C = 10 ,错误预测 "C" 的概率比预测 "B" 受到的惩罚更大。此预测的成本为 0.3 * 0 + 0.3 * 5 + 0.4 * 10 。该计算针对每个样本进行,并对各个成本进行平均。

也可以看看

其他类概率指标:average_precision() , brier_class() , gain_capture() , mn_log_loss() , pr_auc() , roc_auc() , roc_aunp() , roc_aunu()

作者

马克斯·库恩

例子

library(dplyr)

# ---------------------------------------------------------------------------
# Two class example
data(two_class_example)

# Assuming `Class1` is our "event", this penalizes false positives heavily
costs1 <- tribble(
  ~truth,   ~estimate, ~cost,
  "Class1", "Class2",  1,
  "Class2", "Class1",  2
)

# Assuming `Class1` is our "event", this penalizes false negatives heavily
costs2 <- tribble(
  ~truth,   ~estimate, ~cost,
  "Class1", "Class2",  2,
  "Class2", "Class1",  1
)

classification_cost(two_class_example, truth, Class1, costs = costs1)
#> # A tibble: 1 × 3
#>   .metric             .estimator .estimate
#>   <chr>               <chr>          <dbl>
#> 1 classification_cost binary         0.288

classification_cost(two_class_example, truth, Class1, costs = costs2)
#> # A tibble: 1 × 3
#>   .metric             .estimator .estimate
#>   <chr>               <chr>          <dbl>
#> 1 classification_cost binary         0.260

# ---------------------------------------------------------------------------
# Multiclass
data(hpc_cv)

# Define cost matrix from Kuhn and Johnson (2013)
hpc_costs <- tribble(
  ~estimate, ~truth, ~cost,
  "VF",      "VF",    0,
  "VF",      "F",     1,
  "VF",      "M",     5,
  "VF",      "L",    10,
  "F",       "VF",    1,
  "F",       "F",     0,
  "F",       "M",     5,
  "F",       "L",     5,
  "M",       "VF",    1,
  "M",       "F",     1,
  "M",       "M",     0,
  "M",       "L",     1,
  "L",       "VF",    1,
  "L",       "F",     1,
  "L",       "M",     1,
  "L",       "L",     0
)

# You can use the col1:colN tidyselect syntax
hpc_cv %>%
  filter(Resample == "Fold01") %>%
  classification_cost(obs, VF:L, costs = hpc_costs)
#> # A tibble: 1 × 3
#>   .metric             .estimator .estimate
#>   <chr>               <chr>          <dbl>
#> 1 classification_cost multiclass     0.779

# Groups are respected
hpc_cv %>%
  group_by(Resample) %>%
  classification_cost(obs, VF:L, costs = hpc_costs)
#> # A tibble: 10 × 4
#>    Resample .metric             .estimator .estimate
#>    <chr>    <chr>               <chr>          <dbl>
#>  1 Fold01   classification_cost multiclass     0.779
#>  2 Fold02   classification_cost multiclass     0.735
#>  3 Fold03   classification_cost multiclass     0.654
#>  4 Fold04   classification_cost multiclass     0.754
#>  5 Fold05   classification_cost multiclass     0.777
#>  6 Fold06   classification_cost multiclass     0.737
#>  7 Fold07   classification_cost multiclass     0.743
#>  8 Fold08   classification_cost multiclass     0.749
#>  9 Fold09   classification_cost multiclass     0.760
#> 10 Fold10   classification_cost multiclass     0.771

相关用法


注:本文由纯净天空筛选整理自Max Kuhn等大神的英文原创作品 Costs function for poor classification。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。