當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


R tune tune_grid 通過網格搜索進行模型調整


tune_grid() 為一組預定義的調整參數計算一組性能指標(例如,準確性或 RMSE),這些參數對應於一次或多次數據重采樣的模型或配方。

用法

tune_grid(object, ...)

# S3 method for model_spec
tune_grid(
  object,
  preprocessor,
  resamples,
  ...,
  param_info = NULL,
  grid = 10,
  metrics = NULL,
  control = control_grid()
)

# S3 method for workflow
tune_grid(
  object,
  resamples,
  ...,
  param_info = NULL,
  grid = 10,
  metrics = NULL,
  control = control_grid()
)

參數

object

parsnip 模型規範或 workflows::workflow()

...

目前未使用。

preprocessor

使用 recipes::recipe() 創建的傳統模型公式或配方。

resamples

rset() 對象。

param_info

dials::parameters() 對象或 NULL 。如果沒有給出,則從其他參數派生參數集。當需要自定義參數範圍時,傳遞此參數可能很有用。

grid

調諧組合或正整數的 DataFrame 。 DataFrame 應具有用於調整每個參數的列和用於調整候選參數的行。整數表示要自動創建的候選參數集的數量。

metrics

一個 yardstick::metric_set()NULL

control

用於修改調整過程的對象。

resamples 的更新版本,帶有 .metrics.notes 的額外列表列(可選列是 .predictions.extracts )。 .notes

包含執行期間發生的警告和錯誤。

細節

假設有 m 個調整參數組合。 tune_grid() 可能不需要每次重采樣都適合所有 m 個模型/配方。例如:

  • 如果可以使用單個模型擬合來預測網格中的不同參數值,則僅使用一種擬合。例如,對於某些提升樹,如果請求 100 次提升迭代,則可以使用 100 次迭代的模型對象對小於 100 次的迭代進行預測(如果所有其他參數都相等)。

  • 當結合預處理和/或後處理參數調整模型時,將使用最小擬合次數。例如,如果配方步驟中的 PCA 組件數量在三個值(以及模型調整參數)上進行調整,則僅訓練三個配方。另一種方法是為每個模型調整參數多次重新訓練相同的配方。

這裏使用foreach包。要並行執行重采樣迭代,請注冊並行後端函數。有關示例,請參閱 foreach::foreach() 的文檔。

大多數情況下,訓練期間生成的警告會在發生時顯示,並與 control_grid(verbose = TRUE) 時的特定重新采樣相關聯。它們(通常)直到處理結束才聚合。

參數網格

如果未提供調整網格,則會使用 10 個候選參數組合創建半隨機網格(通過 dials::grid_latin_hypercube() )。

如果提供,網格應具有每個參數的列名稱,並且這些名稱應由參數名稱或 id 命名。例如,如果使用 penalty = tune() 將參數標記為優化,則應該有一個名為 penalty 的列。如果使用可選標識符,例如 penalty = tune(id = 'lambda') ,則相應的列名稱應為 lambda

在某些情況下,調整參數值取決於數據的維度。例如,隨機森林模型中的mtry 取決於預測變量的數量。在這種情況下,默認調整參數對象需要一個上限。 dials::finalize() 可用於導出數據相關參數。否則,可以創建參數集(通過 dials::parameters() ),並使用 dials update() 函數來更改值。此更新的參數集可以通過 param_info 參數傳遞給函數。

性能指標

要使用您自己的性能指標,可以使用 yardstick::metric_set() 函數來選擇每個模型應測量的內容。如果需要多個指標,可以將它們捆綁在一起。例如,要估計 ROC 曲線下的麵積以及靈敏度和特異性(在典型概率截止值 0.50 下),可以給出 metrics 參數:


  metrics = metric_set(roc_auc, sens, spec)

每個指標都是針對每個候選模型計算的。

如果未提供指標集,則會創建一個指標集:

  • 對於回歸模型,計算均方根誤差和確定係數。

  • 對於分類,計算 ROC 曲線下的麵積和總體準確度。

請注意,這些指標還決定了調整期間估計的預測類型。例如,在分類問題中,如果使用的度量全部與硬類預測相關,則不會創建分類概率。

這些指標的 out-of-sample 估計值包含在名為 .metrics 的列表列中。該小標題包含每個指標的行和值、估計器類型等的列。

collect_metrics() 可用於這些對象來折疊重采樣的結果(以獲得每個調整參數組合的最終重采樣估計)。

獲取預測

control_grid(save_pred = TRUE) 時,輸出 tibble 包含一個名為 .predictions 的列表列,其中包含網格和每個折疊中每個參數組合的 out-of-sample 預測(可能非常大)。

tibble 的元素是 tibbles,其中包含調整參數的列、原始數據對象 ( .row ) 的行號、結果數據(與原始數據具有相同的名稱)以及由的預測。例如,對於簡單的回歸問題,此函數會生成一個名為.pred 的列,依此類推。如上所述,返回的預測列由請求的度量類型確定。

此列表列可以是 unnested 使用 tidyr::unnest() 或使用便利函數 collect_predictions()

提取信息

extract 控製選項將導致返回一個名為 .extracts 的附加函數。這是一個列表列,其中包含每個調整參數組合的用戶函數結果的標題。這可以允許返回在重采樣期間創建的每個模型和/或配方對象。請注意,這可能會導致返回對象很大,具體取決於返回的內容。

控製函數包含一個選項 (extract),可用於保留在重采樣中創建的任何模型或配方。該參數應該是具有單個參數的函數。每次重新采樣中賦予函數的參數值是工作流對象(有關更多信息,請參閱workflows::workflow())。可以使用多個輔助函數輕鬆地從工作流程中提取預處理和/或模型信息,例如 extract_preprocessor()extract_fit_parsnip()

例如,如果有興趣恢複每個防風草模型,可以使用:


  extract = function (x) extract_fit_parsnip(x)

請注意,賦予 extract 參數的函數是在每個適合的模型上評估的(而不是在評估的每個模型上)。如上所述,在某些情況下,可以針對 sub-models 導出模型預測,因此在這些情況下,並非調整參數網格中的每一行都有與其關聯的單獨 R 對象。

例子

library(recipes)
library(rsample)
library(parsnip)
library(workflows)
library(ggplot2)

# ---------------------------------------------------------------------------

set.seed(6735)
folds <- vfold_cv(mtcars, v = 5)

# ---------------------------------------------------------------------------

# tuning recipe parameters:

spline_rec <-
  recipe(mpg ~ ., data = mtcars) %>%
  step_ns(disp, deg_free = tune("disp")) %>%
  step_ns(wt, deg_free = tune("wt"))

lin_mod <-
  linear_reg() %>%
  set_engine("lm")

# manually create a grid
spline_grid <- expand.grid(disp = 2:5, wt = 2:5)

# Warnings will occur from making spline terms on the holdout data that are
# extrapolations.
spline_res <-
  tune_grid(lin_mod, spline_rec, resamples = folds, grid = spline_grid)
spline_res
#> # Tuning results
#> # 5-fold cross-validation 
#> # A tibble: 5 × 4
#>   splits         id    .metrics          .notes          
#>   <list>         <chr> <list>            <list>          
#> 1 <split [25/7]> Fold1 <tibble [32 × 6]> <tibble [0 × 3]>
#> 2 <split [25/7]> Fold2 <tibble [32 × 6]> <tibble [0 × 3]>
#> 3 <split [26/6]> Fold3 <tibble [32 × 6]> <tibble [0 × 3]>
#> 4 <split [26/6]> Fold4 <tibble [32 × 6]> <tibble [0 × 3]>
#> 5 <split [26/6]> Fold5 <tibble [32 × 6]> <tibble [0 × 3]>


show_best(spline_res, metric = "rmse")
#> # A tibble: 5 × 8
#>    disp    wt .metric .estimator  mean     n std_err .config              
#>   <int> <int> <chr>   <chr>      <dbl> <int>   <dbl> <chr>                
#> 1     3     2 rmse    standard    2.54     5   0.207 Preprocessor02_Model1
#> 2     3     3 rmse    standard    2.64     5   0.234 Preprocessor06_Model1
#> 3     4     3 rmse    standard    2.82     5   0.456 Preprocessor07_Model1
#> 4     4     2 rmse    standard    2.93     5   0.489 Preprocessor03_Model1
#> 5     4     4 rmse    standard    3.01     5   0.475 Preprocessor11_Model1

# ---------------------------------------------------------------------------

# tune model parameters only (example requires the `kernlab` package)

car_rec <-
  recipe(mpg ~ ., data = mtcars) %>%
  step_normalize(all_predictors())

svm_mod <-
  svm_rbf(cost = tune(), rbf_sigma = tune()) %>%
  set_engine("kernlab") %>%
  set_mode("regression")

# Use a space-filling design with 7 points
set.seed(3254)
svm_res <- tune_grid(svm_mod, car_rec, resamples = folds, grid = 7)
svm_res
#> # Tuning results
#> # 5-fold cross-validation 
#> # A tibble: 5 × 4
#>   splits         id    .metrics          .notes          
#>   <list>         <chr> <list>            <list>          
#> 1 <split [25/7]> Fold1 <tibble [14 × 6]> <tibble [0 × 3]>
#> 2 <split [25/7]> Fold2 <tibble [14 × 6]> <tibble [0 × 3]>
#> 3 <split [26/6]> Fold3 <tibble [14 × 6]> <tibble [0 × 3]>
#> 4 <split [26/6]> Fold4 <tibble [14 × 6]> <tibble [0 × 3]>
#> 5 <split [26/6]> Fold5 <tibble [14 × 6]> <tibble [0 × 3]>

show_best(svm_res, metric = "rmse")
#> # A tibble: 5 × 8
#>       cost   rbf_sigma .metric .estimator  mean     n std_err .config     
#>      <dbl>       <dbl> <chr>   <chr>      <dbl> <int>   <dbl> <chr>       
#> 1  0.304   0.117       rmse    standard    3.91     5   0.652 Preprocesso…
#> 2  4.53    0.000420    rmse    standard    4.13     5   0.741 Preprocesso…
#> 3  0.00247 0.00931     rmse    standard    5.94     5   0.966 Preprocesso…
#> 4 23.2     0.000000684 rmse    standard    5.94     5   0.967 Preprocesso…
#> 5  0.0126  0.00000239  rmse    standard    5.96     5   0.970 Preprocesso…

autoplot(svm_res, metric = "rmse") +
  scale_x_log10()
#> Warning: NaNs produced
#> Warning: Transformation introduced infinite values in continuous x-axis
#> Warning: Removed 12 rows containing missing values (`geom_point()`).


# ---------------------------------------------------------------------------

# Using a variables preprocessor with a workflow

# Rather than supplying a preprocessor (like a recipe) and a model directly
# to `tune_grid()`, you can also wrap them up in a workflow and pass
# that along instead (note that this doesn't do any preprocessing to
# the variables, it passes them along as-is).
wf <- workflow() %>%
  add_variables(outcomes = mpg, predictors = everything()) %>%
  add_model(svm_mod)

set.seed(3254)
svm_res_wf <- tune_grid(wf, resamples = folds, grid = 7)
源代碼:R/tune_grid.R

相關用法


注:本文由純淨天空篩選整理自Max Kuhn等大神的英文原創作品 Model tuning via grid search。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。