当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


R rsample vfold_cv V 折交叉验证


V-fold 交叉验证(也称为 k-fold 交叉验证)将数据随机分成大小大致相等的 V 组(称为 "folds")。分析数据的重采样由 V-1 个折叠组成,而评估集包含最终折叠。在基本的 V-fold 交叉验证(即无重复)中,重新采样的次数等于 V。

用法

vfold_cv(data, v = 10, repeats = 1, strata = NULL, breaks = 4, pool = 0.1, ...)

参数

data

一个 DataFrame 。

v

数据集的分区数。

repeats

重复 V-fold 分区的次数。

strata

data 中的变量(单个字符或名称)用于进行分层抽样。如果不是 NULL ,则每次重新采样都会在分层变量中创建。数字 strata 被分为四分位数。

breaks

给出对数值分层变量进行分层所需的箱数的单个数字。

pool

用于确定特定组是否太小的数据比例,是否应合并到另一个组中。我们不建议将此参数降低到默认值 0.1 以下,因为分层组太小存在危险。

...

这些点用于将来的扩展,并且必须为空。

带有类 vfold_cvrsettbl_dftbldata.frame 的 tibble。结果包括一列数据分割对象和一个或多个标识变量。对于单个重复,将有一个名为 id 的列,其中包含带有折叠标识符的字符串。对于重复,id 是重复编号,还有一个名为 id2 的附加列

包含折叠信息(重复内)。

细节

如果重复多次,每次都会进行基本的V-fold交叉验证。例如,如果 v = 10 使用 3 次重复,则总共有 30 个分割:分别生成 3 组,每组 10 个。

使用 strata 参数,在分层变量内进行随机抽样。这有助于确保重采样与原始数据集具有相同的比例。对于分类变量,采样是在每个类别内单独进行的。对于数字分层变量,strata 被分为四分位数,然后用于分层。低于总数10%的地层合并在一起;有关更多详细信息,请参阅make_strata()

例子

vfold_cv(mtcars, v = 10)
#> #  10-fold cross-validation 
#> # A tibble: 10 × 2
#>    splits         id    
#>    <list>         <chr> 
#>  1 <split [28/4]> Fold01
#>  2 <split [28/4]> Fold02
#>  3 <split [29/3]> Fold03
#>  4 <split [29/3]> Fold04
#>  5 <split [29/3]> Fold05
#>  6 <split [29/3]> Fold06
#>  7 <split [29/3]> Fold07
#>  8 <split [29/3]> Fold08
#>  9 <split [29/3]> Fold09
#> 10 <split [29/3]> Fold10
vfold_cv(mtcars, v = 10, repeats = 2)
#> #  10-fold cross-validation repeated 2 times 
#> # A tibble: 20 × 3
#>    splits         id      id2   
#>    <list>         <chr>   <chr> 
#>  1 <split [28/4]> Repeat1 Fold01
#>  2 <split [28/4]> Repeat1 Fold02
#>  3 <split [29/3]> Repeat1 Fold03
#>  4 <split [29/3]> Repeat1 Fold04
#>  5 <split [29/3]> Repeat1 Fold05
#>  6 <split [29/3]> Repeat1 Fold06
#>  7 <split [29/3]> Repeat1 Fold07
#>  8 <split [29/3]> Repeat1 Fold08
#>  9 <split [29/3]> Repeat1 Fold09
#> 10 <split [29/3]> Repeat1 Fold10
#> 11 <split [28/4]> Repeat2 Fold01
#> 12 <split [28/4]> Repeat2 Fold02
#> 13 <split [29/3]> Repeat2 Fold03
#> 14 <split [29/3]> Repeat2 Fold04
#> 15 <split [29/3]> Repeat2 Fold05
#> 16 <split [29/3]> Repeat2 Fold06
#> 17 <split [29/3]> Repeat2 Fold07
#> 18 <split [29/3]> Repeat2 Fold08
#> 19 <split [29/3]> Repeat2 Fold09
#> 20 <split [29/3]> Repeat2 Fold10

library(purrr)
data(wa_churn, package = "modeldata")

set.seed(13)
folds1 <- vfold_cv(wa_churn, v = 5)
map_dbl(
  folds1$splits,
  function(x) {
    dat <- as.data.frame(x)$churn
    mean(dat == "Yes")
  }
)
#> [1] 0.2649982 0.2660632 0.2609159 0.2679681 0.2669033

set.seed(13)
folds2 <- vfold_cv(wa_churn, strata = churn, v = 5)
map_dbl(
  folds2$splits,
  function(x) {
    dat <- as.data.frame(x)$churn
    mean(dat == "Yes")
  }
)
#> [1] 0.2653532 0.2653532 0.2653532 0.2653532 0.2654365

set.seed(13)
folds3 <- vfold_cv(wa_churn, strata = tenure, breaks = 6, v = 5)
map_dbl(
  folds3$splits,
  function(x) {
    dat <- as.data.frame(x)$churn
    mean(dat == "Yes")
  }
)
#> [1] 0.2656250 0.2661104 0.2652228 0.2638396 0.2660518
源代码:R/vfold.R

相关用法


注:本文由纯净天空筛选整理自Hannah Frick等大神的英文原创作品 V-Fold Cross-Validation。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。