當前位置: 首頁>>編程示例 >>用法及示例精選 >>正文


R tune tune_bayes 模型參數的貝葉斯優化。

tune_bayes() 使用模型根據之前的結果生成新的候選調整參數組合。

用法

tune_bayes(object, ...)

# S3 method for model_spec
tune_bayes(
  object,
  preprocessor,
  resamples,
  ...,
  iter = 10,
  param_info = NULL,
  metrics = NULL,
  objective = exp_improve(),
  initial = 5,
  control = control_bayes()
)

# S3 method for workflow
tune_bayes(
  object,
  resamples,
  ...,
  iter = 10,
  param_info = NULL,
  metrics = NULL,
  objective = exp_improve(),
  initial = 5,
  control = control_bayes()
)

參數

object

parsnip 模型規範或 workflows::workflow()

...

傳遞給 GPfit::GP_fit() 的選項(主要用於 corr 參數)。

preprocessor

使用 recipes::recipe() 創建的傳統模型公式或配方。

resamples

rset() 對象。

iter

搜索迭代的最大次數。

param_info

dials::parameters() 對象或 NULL 。如果沒有給出,則從其他參數派生參數集。當需要自定義參數範圍時,傳遞此參數可能很有用。

metrics

yardstick::metric_set() 對象,包含有關如何評估模型性能的信息。 metrics 中的第一個指標是要優化的指標。

objective

應該優化哪個指標的字符串或獲取函數對象。

initial

一組采用整齊格式的初始結果(如 tune_grid() 的結果)或正整數。建議初始結果的數量大於正在優化的參數的數量。

control

control_bayes()創建的控製對象

反映 tune_grid() 生成的結果的一小部分結果。但是,這些結果包含 .iter 列並複製 rset

在迭代中多次對象(以有限的額外內存成本)。

細節

優化從一組初始結果開始,例如 tune_grid() 生成的結果。如果不存在,該函數將創建多個組合並獲得它們的性能估計。

使用性能估計之一作為模型結果,創建高斯過程 (GP) 模型,其中使用先前的調整參數組合作為預測變量。

使用該模型預測潛在超參數組合的大網格,並使用獲取函數進行評分。這些函數通常結合 GP 的預測均值和方差來決定下一步要嘗試的最佳參數組合。有關更多信息,請參閱 exp_improve() 的文檔和相應的包插圖。

使用重采樣評估最佳組合,並繼續該過程。

並行處理

這裏使用foreach包。要並行執行重采樣迭代,請注冊並行後端函數。有關示例,請參閱 foreach::foreach() 的文檔。

大多數情況下,訓練期間生成的警告會在發生時顯示,並與 control_bayes(verbose = TRUE) 時的特定重新采樣相關聯。它們(通常)直到處理結束才聚合。

對於貝葉斯優化,一旦估計出新的候選值集,就使用並行處理來估計重采樣的性能值。

初始值

tune_grid() 的結果或之前運行的 tune_bayes() 可以在 initial 參數中使用。 initial也可以是正整數。在這種情況下,space-filling 設計將用於填充一組初步結果。為了獲得好的結果,初始值的數量應該多於正在優化的參數的數量。

參數範圍和值

在某些情況下,調整參數值取決於數據的維度(據說它們包含unknown值)。例如,隨機森林模型中的mtry 取決於預測變量的數量。在這種情況下,必須事先確定調整參數對象中的未知數並通過 param_info 參數傳遞給函數。 dials::finalize() 可用於導出數據相關參數。否則,可以通過 dials::parameters() 創建參數集,並使用 dials update() 函數指定範圍或值。

性能指標

要使用您自己的性能指標,可以使用 yardstick::metric_set() 函數來選擇每個模型應測量的內容。如果需要多個指標,可以將它們捆綁在一起。例如,要估計 ROC 曲線下的麵積以及靈敏度和特異性(在典型概率截止值 0.50 下),可以給出 metrics 參數:


  metrics = metric_set(roc_auc, sens, spec)

每個指標都是針對每個候選模型計算的。

如果未提供指標集,則會創建一個指標集:

  • 對於回歸模型,計算均方根誤差和確定係數。

  • 對於分類,計算 ROC 曲線下的麵積和總體準確度。

請注意,這些指標還決定了調整期間估計的預測類型。例如,在分類問題中,如果使用的度量全部與硬類預測相關,則不會創建分類概率。

這些指標的 out-of-sample 估計值包含在名為 .metrics 的列表列中。該小標題包含每個指標的行和值、估計器類型等的列。

collect_metrics() 可用於這些對象來折疊重采樣的結果(以獲得每個調整參數組合的最終重采樣估計)。

獲取預測

control_bayes(save_pred = TRUE) 時,輸出 tibble 包含一個名為 .predictions 的列表列,其中包含網格和每個折疊中每個參數組合的 out-of-sample 預測(可能非常大)。

tibble 的元素是 tibbles,其中包含調整參數的列、原始數據對象 ( .row ) 的行號、結果數據(與原始數據具有相同的名稱)以及由的預測。例如,對於簡單的回歸問題,此函數會生成一個名為.pred 的列,依此類推。如上所述,返回的預測列由請求的度量類型確定。

此列表列可以是 unnested 使用 tidyr::unnest() 或使用便利函數 collect_predictions()

提取信息

extract 控製選項將導致返回一個名為 .extracts 的附加函數。這是一個列表列,其中包含每個調整參數組合的用戶函數結果的標題。這可以允許返回在重采樣期間創建的每個模型和/或配方對象。請注意,這可能會導致返回對象很大,具體取決於返回的內容。

控製函數包含一個選項 (extract),可用於保留在重采樣中創建的任何模型或配方。該參數應該是具有單個參數的函數。每次重新采樣中賦予函數的參數值是工作流對象(有關更多信息,請參閱workflows::workflow())。可以使用多個輔助函數輕鬆地從工作流程中提取預處理和/或模型信息,例如 extract_preprocessor()extract_fit_parsnip()

例如,如果有興趣恢複每個防風草模型,可以使用:


  extract = function (x) extract_fit_parsnip(x)

請注意,賦予 extract 參數的函數是在每個適合的模型上評估的(而不是在評估的每個模型上)。如上所述,在某些情況下,可以針對 sub-models 導出模型預測,因此在這些情況下,並非調整參數網格中的每一行都有與其關聯的單獨 R 對象。

例子

library(recipes)
library(rsample)
library(parsnip)

# define resamples and minimal recipe on mtcars
set.seed(6735)
folds <- vfold_cv(mtcars, v = 5)

car_rec <-
  recipe(mpg ~ ., data = mtcars) %>%
  step_normalize(all_predictors())

# define an svm with parameters to tune
svm_mod <-
  svm_rbf(cost = tune(), rbf_sigma = tune()) %>%
  set_engine("kernlab") %>%
  set_mode("regression")

# use a space-filling design with 6 points
set.seed(3254)
svm_grid <- tune_grid(svm_mod, car_rec, folds, grid = 6)

show_best(svm_grid, metric = "rmse")
#> # A tibble: 5 × 8
#>       cost    rbf_sigma .metric .estimator  mean     n std_err .config    
#>      <dbl>        <dbl> <chr>   <chr>      <dbl> <int>   <dbl> <chr>      
#> 1 25.3     0.248        rmse    standard    3.17     5   0.678 Preprocess…
#> 2  3.88    0.000510     rmse    standard    4.07     5   0.727 Preprocess…
#> 3  0.102   0.00000592   rmse    standard    5.96     5   0.970 Preprocess…
#> 4  0.00125 0.000138     rmse    standard    5.96     5   0.970 Preprocess…
#> 5  0.0192  0.0000000427 rmse    standard    5.96     5   0.970 Preprocess…

# use bayesian optimization to evaluate at 6 more points
set.seed(8241)
svm_bayes <- tune_bayes(svm_mod, car_rec, folds, initial = svm_grid, iter = 6)

# note that bayesian optimization evaluated parameterizations
# similar to those that previously decreased rmse in svm_grid
show_best(svm_bayes, metric = "rmse")
#> # A tibble: 5 × 9
#>    cost rbf_sigma .metric .estimator  mean     n std_err .config .iter
#>   <dbl>     <dbl> <chr>   <chr>      <dbl> <int>   <dbl> <chr>   <int>
#> 1 31.6    0.00144 rmse    standard    2.60     5   0.232 Iter1       1
#> 2 28.7    0.00292 rmse    standard    2.61     5   0.208 Iter4       4
#> 3 28.3    0.00685 rmse    standard    2.62     5   0.195 Iter3       3
#> 4 31.4    0.00482 rmse    standard    2.64     5   0.183 Iter5       5
#> 5  7.30   0.0533  rmse    standard    2.71     5   0.303 Iter6       6

# specifying `initial` as a numeric rather than previous tuning results
# will result in `tune_bayes` initially evaluating an space-filling
# grid using `tune_grid` with `grid = initial`
set.seed(0239)
svm_init <- tune_bayes(svm_mod, car_rec, folds, initial = 6, iter = 6)

show_best(svm_init, metric = "rmse")
#> # A tibble: 5 × 9
#>    cost rbf_sigma .metric .estimator  mean     n std_err .config .iter
#>   <dbl>     <dbl> <chr>   <chr>      <dbl> <int>   <dbl> <chr>   <int>
#> 1 2.35     0.0269 rmse    standard    2.70     5   0.188 Iter5       5
#> 2 3.32     0.0361 rmse    standard    2.72     5   0.192 Iter6       6
#> 3 1.40     0.0479 rmse    standard    2.78     5   0.229 Iter4       4
#> 4 0.509    0.0256 rmse    standard    3.17     5   0.473 Iter3       3
#> 5 0.256    0.0201 rmse    standard    3.79     5   0.622 Iter2       2
源代碼:R/tune_bayes.R

相關用法


注:本文由純淨天空篩選整理自Max Kuhn等大神的英文原創作品 Bayesian optimization of model parameters.。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。