当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


R tune fit_resamples 通过重采样拟合多个模型


fit_resamples() 通过一个或多个重新采样计算一组性能指标。它不执行任何调整(请参阅 tune_grid()tune_bayes()),而是用于在许多重新采样中拟合单个模型+配方或模型+公式组合。

用法

fit_resamples(object, ...)

# S3 method for model_spec
fit_resamples(
  object,
  preprocessor,
  resamples,
  ...,
  metrics = NULL,
  control = control_resamples()
)

# S3 method for workflow
fit_resamples(
  object,
  resamples,
  ...,
  metrics = NULL,
  control = control_resamples()
)

参数

object

parsnip 模型规范或 workflows::workflow() 。不允许调整参数。

...

目前未使用。

preprocessor

使用 recipes::recipe() 创建的传统模型公式或配方。

resamples

rsample 函数(例如 rsample::vfold_cv() )创建的重采样 rset

metrics

yardstick::metric_set()NULL 用于计算一组标准指标。

control

用于微调重采样过程的 control_resamples() 对象。

性能指标

要使用您自己的性能指标,可以使用 yardstick::metric_set() 函数来选择每个模型应测量的内容。如果需要多个指标,可以将它们捆绑在一起。例如,要估计 ROC 曲线下的面积以及灵敏度和特异性(在典型概率截止值 0.50 下),可以给出 metrics 参数:


  metrics = metric_set(roc_auc, sens, spec)

每个指标都是针对每个候选模型计算的。

如果未提供指标集,则会创建一个指标集:

  • 对于回归模型,计算均方根误差和确定系数。

  • 对于分类,计算 ROC 曲线下的面积和总体准确度。

请注意,这些指标还决定了调整期间估计的预测类型。例如,在分类问题中,如果使用的度量全部与硬类预测相关,则不会创建分类概率。

这些指标的 out-of-sample 估计值包含在名为 .metrics 的列表列中。该小标题包含每个指标的行和值、估计器类型等的列。

collect_metrics() 可用于这些对象来折叠重采样的结果(以获得每个调整参数组合的最终重采样估计)。

获取预测

control_grid(save_pred = TRUE) 时,输出 tibble 包含一个名为 .predictions 的列表列,其中包含网格和每个折叠中每个参数组合的 out-of-sample 预测(可能非常大)。

tibble 的元素是 tibbles,其中包含调整参数的列、原始数据对象 ( .row ) 的行号、结果数据(与原始数据具有相同的名称)以及由的预测。例如,对于简单的回归问题,此函数会生成一个名为.pred 的列,依此类推。如上所述,返回的预测列由请求的度量类型确定。

此列表列可以是 unnested 使用 tidyr::unnest() 或使用便利函数 collect_predictions()

提取信息

extract 控制选项将导致返回一个名为 .extracts 的附加函数。这是一个列表列,其中包含每个调整参数组合的用户函数结果的标题。这可以允许返回在重采样期间创建的每个模型和/或配方对象。请注意,这可能会导致返回对象很大,具体取决于返回的内容。

控制函数包含一个选项 (extract),可用于保留在重采样中创建的任何模型或配方。该参数应该是具有单个参数的函数。每次重新采样中赋予函数的参数值是工作流对象(有关更多信息,请参阅workflows::workflow())。可以使用多个辅助函数轻松地从工作流程中提取预处理和/或模型信息,例如 extract_preprocessor()extract_fit_parsnip()

例如,如果有兴趣恢复每个防风草模型,可以使用:


  extract = function (x) extract_fit_parsnip(x)

请注意,赋予 extract 参数的函数是在每个适合的模型上评估的(而不是在评估的每个模型上)。如上所述,在某些情况下,可以针对 sub-models 导出模型预测,因此在这些情况下,并非调整参数网格中的每一行都有与其关联的单独 R 对象。

例子

library(recipes)
library(rsample)
library(parsnip)
library(workflows)

set.seed(6735)
folds <- vfold_cv(mtcars, v = 5)

spline_rec <- recipe(mpg ~ ., data = mtcars) %>%
  step_ns(disp) %>%
  step_ns(wt)

lin_mod <- linear_reg() %>%
  set_engine("lm")

control <- control_resamples(save_pred = TRUE)

spline_res <- fit_resamples(lin_mod, spline_rec, folds, control = control)

spline_res
#> # Resampling results
#> # 5-fold cross-validation 
#> # A tibble: 5 × 5
#>   splits         id    .metrics         .notes           .predictions    
#>   <list>         <chr> <list>           <list>           <list>          
#> 1 <split [25/7]> Fold1 <tibble [2 × 4]> <tibble [0 × 3]> <tibble [7 × 4]>
#> 2 <split [25/7]> Fold2 <tibble [2 × 4]> <tibble [0 × 3]> <tibble [7 × 4]>
#> 3 <split [26/6]> Fold3 <tibble [2 × 4]> <tibble [0 × 3]> <tibble [6 × 4]>
#> 4 <split [26/6]> Fold4 <tibble [2 × 4]> <tibble [0 × 3]> <tibble [6 × 4]>
#> 5 <split [26/6]> Fold5 <tibble [2 × 4]> <tibble [0 × 3]> <tibble [6 × 4]>

show_best(spline_res, metric = "rmse")
#> # A tibble: 1 × 6
#>   .metric .estimator  mean     n std_err .config             
#>   <chr>   <chr>      <dbl> <int>   <dbl> <chr>               
#> 1 rmse    standard    3.11     5   0.168 Preprocessor1_Model1

# You can also wrap up a preprocessor and a model into a workflow, and
# supply that to `fit_resamples()` instead. Here, a workflows "variables"
# preprocessor is used, which lets you supply terms using dplyr selectors.
# The variables are used as-is, no preprocessing is done to them.
wf <- workflow() %>%
  add_variables(outcomes = mpg, predictors = everything()) %>%
  add_model(lin_mod)

wf_res <- fit_resamples(wf, folds)
源代码:R/resample.R

相关用法


注:本文由纯净天空筛选整理自Max Kuhn等大神的英文原创作品 Fit multiple models via resampling。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。