當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


R tune fit_resamples 通過重采樣擬合多個模型


fit_resamples() 通過一個或多個重新采樣計算一組性能指標。它不執行任何調整(請參閱 tune_grid()tune_bayes()),而是用於在許多重新采樣中擬合單個模型+配方或模型+公式組合。

用法

fit_resamples(object, ...)

# S3 method for model_spec
fit_resamples(
  object,
  preprocessor,
  resamples,
  ...,
  metrics = NULL,
  control = control_resamples()
)

# S3 method for workflow
fit_resamples(
  object,
  resamples,
  ...,
  metrics = NULL,
  control = control_resamples()
)

參數

object

parsnip 模型規範或 workflows::workflow() 。不允許調整參數。

...

目前未使用。

preprocessor

使用 recipes::recipe() 創建的傳統模型公式或配方。

resamples

rsample 函數(例如 rsample::vfold_cv() )創建的重采樣 rset

metrics

yardstick::metric_set()NULL 用於計算一組標準指標。

control

用於微調重采樣過程的 control_resamples() 對象。

性能指標

要使用您自己的性能指標,可以使用 yardstick::metric_set() 函數來選擇每個模型應測量的內容。如果需要多個指標,可以將它們捆綁在一起。例如,要估計 ROC 曲線下的麵積以及靈敏度和特異性(在典型概率截止值 0.50 下),可以給出 metrics 參數:


  metrics = metric_set(roc_auc, sens, spec)

每個指標都是針對每個候選模型計算的。

如果未提供指標集,則會創建一個指標集:

  • 對於回歸模型,計算均方根誤差和確定係數。

  • 對於分類,計算 ROC 曲線下的麵積和總體準確度。

請注意,這些指標還決定了調整期間估計的預測類型。例如,在分類問題中,如果使用的度量全部與硬類預測相關,則不會創建分類概率。

這些指標的 out-of-sample 估計值包含在名為 .metrics 的列表列中。該小標題包含每個指標的行和值、估計器類型等的列。

collect_metrics() 可用於這些對象來折疊重采樣的結果(以獲得每個調整參數組合的最終重采樣估計)。

獲取預測

control_grid(save_pred = TRUE) 時,輸出 tibble 包含一個名為 .predictions 的列表列,其中包含網格和每個折疊中每個參數組合的 out-of-sample 預測(可能非常大)。

tibble 的元素是 tibbles,其中包含調整參數的列、原始數據對象 ( .row ) 的行號、結果數據(與原始數據具有相同的名稱)以及由的預測。例如,對於簡單的回歸問題,此函數會生成一個名為.pred 的列,依此類推。如上所述,返回的預測列由請求的度量類型確定。

此列表列可以是 unnested 使用 tidyr::unnest() 或使用便利函數 collect_predictions()

提取信息

extract 控製選項將導致返回一個名為 .extracts 的附加函數。這是一個列表列,其中包含每個調整參數組合的用戶函數結果的標題。這可以允許返回在重采樣期間創建的每個模型和/或配方對象。請注意,這可能會導致返回對象很大,具體取決於返回的內容。

控製函數包含一個選項 (extract),可用於保留在重采樣中創建的任何模型或配方。該參數應該是具有單個參數的函數。每次重新采樣中賦予函數的參數值是工作流對象(有關更多信息,請參閱workflows::workflow())。可以使用多個輔助函數輕鬆地從工作流程中提取預處理和/或模型信息,例如 extract_preprocessor()extract_fit_parsnip()

例如,如果有興趣恢複每個防風草模型,可以使用:


  extract = function (x) extract_fit_parsnip(x)

請注意,賦予 extract 參數的函數是在每個適合的模型上評估的(而不是在評估的每個模型上)。如上所述,在某些情況下,可以針對 sub-models 導出模型預測,因此在這些情況下,並非調整參數網格中的每一行都有與其關聯的單獨 R 對象。

例子

library(recipes)
library(rsample)
library(parsnip)
library(workflows)

set.seed(6735)
folds <- vfold_cv(mtcars, v = 5)

spline_rec <- recipe(mpg ~ ., data = mtcars) %>%
  step_ns(disp) %>%
  step_ns(wt)

lin_mod <- linear_reg() %>%
  set_engine("lm")

control <- control_resamples(save_pred = TRUE)

spline_res <- fit_resamples(lin_mod, spline_rec, folds, control = control)

spline_res
#> # Resampling results
#> # 5-fold cross-validation 
#> # A tibble: 5 × 5
#>   splits         id    .metrics         .notes           .predictions    
#>   <list>         <chr> <list>           <list>           <list>          
#> 1 <split [25/7]> Fold1 <tibble [2 × 4]> <tibble [0 × 3]> <tibble [7 × 4]>
#> 2 <split [25/7]> Fold2 <tibble [2 × 4]> <tibble [0 × 3]> <tibble [7 × 4]>
#> 3 <split [26/6]> Fold3 <tibble [2 × 4]> <tibble [0 × 3]> <tibble [6 × 4]>
#> 4 <split [26/6]> Fold4 <tibble [2 × 4]> <tibble [0 × 3]> <tibble [6 × 4]>
#> 5 <split [26/6]> Fold5 <tibble [2 × 4]> <tibble [0 × 3]> <tibble [6 × 4]>

show_best(spline_res, metric = "rmse")
#> # A tibble: 1 × 6
#>   .metric .estimator  mean     n std_err .config             
#>   <chr>   <chr>      <dbl> <int>   <dbl> <chr>               
#> 1 rmse    standard    3.11     5   0.168 Preprocessor1_Model1

# You can also wrap up a preprocessor and a model into a workflow, and
# supply that to `fit_resamples()` instead. Here, a workflows "variables"
# preprocessor is used, which lets you supply terms using dplyr selectors.
# The variables are used as-is, no preprocessing is done to them.
wf <- workflow() %>%
  add_variables(outcomes = mpg, predictors = everything()) %>%
  add_model(lin_mod)

wf_res <- fit_resamples(wf, folds)
源代碼:R/resample.R

相關用法


注:本文由純淨天空篩選整理自Max Kuhn等大神的英文原創作品 Fit multiple models via resampling。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。