当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


R tune tune_grid 通过网格搜索进行模型调整


tune_grid() 为一组预定义的调整参数计算一组性能指标(例如,准确性或 RMSE),这些参数对应于一次或多次数据重采样的模型或配方。

用法

tune_grid(object, ...)

# S3 method for model_spec
tune_grid(
  object,
  preprocessor,
  resamples,
  ...,
  param_info = NULL,
  grid = 10,
  metrics = NULL,
  control = control_grid()
)

# S3 method for workflow
tune_grid(
  object,
  resamples,
  ...,
  param_info = NULL,
  grid = 10,
  metrics = NULL,
  control = control_grid()
)

参数

object

parsnip 模型规范或 workflows::workflow()

...

目前未使用。

preprocessor

使用 recipes::recipe() 创建的传统模型公式或配方。

resamples

rset() 对象。

param_info

dials::parameters() 对象或 NULL 。如果没有给出,则从其他参数派生参数集。当需要自定义参数范围时,传递此参数可能很有用。

grid

调谐组合或正整数的 DataFrame 。 DataFrame 应具有用于调整每个参数的列和用于调整候选参数的行。整数表示要自动创建的候选参数集的数量。

metrics

一个 yardstick::metric_set()NULL

control

用于修改调整过程的对象。

resamples 的更新版本,带有 .metrics.notes 的额外列表列(可选列是 .predictions.extracts )。 .notes

包含执行期间发生的警告和错误。

细节

假设有 m 个调整参数组合。 tune_grid() 可能不需要每次重采样都适合所有 m 个模型/配方。例如:

  • 如果可以使用单个模型拟合来预测网格中的不同参数值,则仅使用一种拟合。例如,对于某些提升树,如果请求 100 次提升迭代,则可以使用 100 次迭代的模型对象对小于 100 次的迭代进行预测(如果所有其他参数都相等)。

  • 当结合预处理和/或后处理参数调整模型时,将使用最小拟合次数。例如,如果配方步骤中的 PCA 组件数量在三个值(以及模型调整参数)上进行调整,则仅训练三个配方。另一种方法是为每个模型调整参数多次重新训练相同的配方。

这里使用foreach包。要并行执行重采样迭代,请注册并行后端函数。有关示例,请参阅 foreach::foreach() 的文档。

大多数情况下,训练期间生成的警告会在发生时显示,并与 control_grid(verbose = TRUE) 时的特定重新采样相关联。它们(通常)直到处理结束才聚合。

参数网格

如果未提供调整网格,则会使用 10 个候选参数组合创建半随机网格(通过 dials::grid_latin_hypercube() )。

如果提供,网格应具有每个参数的列名称,并且这些名称应由参数名称或 id 命名。例如,如果使用 penalty = tune() 将参数标记为优化,则应该有一个名为 penalty 的列。如果使用可选标识符,例如 penalty = tune(id = 'lambda') ,则相应的列名称应为 lambda

在某些情况下,调整参数值取决于数据的维度。例如,随机森林模型中的mtry 取决于预测变量的数量。在这种情况下,默认调整参数对象需要一个上限。 dials::finalize() 可用于导出数据相关参数。否则,可以创建参数集(通过 dials::parameters() ),并使用 dials update() 函数来更改值。此更新的参数集可以通过 param_info 参数传递给函数。

性能指标

要使用您自己的性能指标,可以使用 yardstick::metric_set() 函数来选择每个模型应测量的内容。如果需要多个指标,可以将它们捆绑在一起。例如,要估计 ROC 曲线下的面积以及灵敏度和特异性(在典型概率截止值 0.50 下),可以给出 metrics 参数:


  metrics = metric_set(roc_auc, sens, spec)

每个指标都是针对每个候选模型计算的。

如果未提供指标集,则会创建一个指标集:

  • 对于回归模型,计算均方根误差和确定系数。

  • 对于分类,计算 ROC 曲线下的面积和总体准确度。

请注意,这些指标还决定了调整期间估计的预测类型。例如,在分类问题中,如果使用的度量全部与硬类预测相关,则不会创建分类概率。

这些指标的 out-of-sample 估计值包含在名为 .metrics 的列表列中。该小标题包含每个指标的行和值、估计器类型等的列。

collect_metrics() 可用于这些对象来折叠重采样的结果(以获得每个调整参数组合的最终重采样估计)。

获取预测

control_grid(save_pred = TRUE) 时,输出 tibble 包含一个名为 .predictions 的列表列,其中包含网格和每个折叠中每个参数组合的 out-of-sample 预测(可能非常大)。

tibble 的元素是 tibbles,其中包含调整参数的列、原始数据对象 ( .row ) 的行号、结果数据(与原始数据具有相同的名称)以及由的预测。例如,对于简单的回归问题,此函数会生成一个名为.pred 的列,依此类推。如上所述,返回的预测列由请求的度量类型确定。

此列表列可以是 unnested 使用 tidyr::unnest() 或使用便利函数 collect_predictions()

提取信息

extract 控制选项将导致返回一个名为 .extracts 的附加函数。这是一个列表列,其中包含每个调整参数组合的用户函数结果的标题。这可以允许返回在重采样期间创建的每个模型和/或配方对象。请注意,这可能会导致返回对象很大,具体取决于返回的内容。

控制函数包含一个选项 (extract),可用于保留在重采样中创建的任何模型或配方。该参数应该是具有单个参数的函数。每次重新采样中赋予函数的参数值是工作流对象(有关更多信息,请参阅workflows::workflow())。可以使用多个辅助函数轻松地从工作流程中提取预处理和/或模型信息,例如 extract_preprocessor()extract_fit_parsnip()

例如,如果有兴趣恢复每个防风草模型,可以使用:


  extract = function (x) extract_fit_parsnip(x)

请注意,赋予 extract 参数的函数是在每个适合的模型上评估的(而不是在评估的每个模型上)。如上所述,在某些情况下,可以针对 sub-models 导出模型预测,因此在这些情况下,并非调整参数网格中的每一行都有与其关联的单独 R 对象。

例子

library(recipes)
library(rsample)
library(parsnip)
library(workflows)
library(ggplot2)

# ---------------------------------------------------------------------------

set.seed(6735)
folds <- vfold_cv(mtcars, v = 5)

# ---------------------------------------------------------------------------

# tuning recipe parameters:

spline_rec <-
  recipe(mpg ~ ., data = mtcars) %>%
  step_ns(disp, deg_free = tune("disp")) %>%
  step_ns(wt, deg_free = tune("wt"))

lin_mod <-
  linear_reg() %>%
  set_engine("lm")

# manually create a grid
spline_grid <- expand.grid(disp = 2:5, wt = 2:5)

# Warnings will occur from making spline terms on the holdout data that are
# extrapolations.
spline_res <-
  tune_grid(lin_mod, spline_rec, resamples = folds, grid = spline_grid)
spline_res
#> # Tuning results
#> # 5-fold cross-validation 
#> # A tibble: 5 × 4
#>   splits         id    .metrics          .notes          
#>   <list>         <chr> <list>            <list>          
#> 1 <split [25/7]> Fold1 <tibble [32 × 6]> <tibble [0 × 3]>
#> 2 <split [25/7]> Fold2 <tibble [32 × 6]> <tibble [0 × 3]>
#> 3 <split [26/6]> Fold3 <tibble [32 × 6]> <tibble [0 × 3]>
#> 4 <split [26/6]> Fold4 <tibble [32 × 6]> <tibble [0 × 3]>
#> 5 <split [26/6]> Fold5 <tibble [32 × 6]> <tibble [0 × 3]>


show_best(spline_res, metric = "rmse")
#> # A tibble: 5 × 8
#>    disp    wt .metric .estimator  mean     n std_err .config              
#>   <int> <int> <chr>   <chr>      <dbl> <int>   <dbl> <chr>                
#> 1     3     2 rmse    standard    2.54     5   0.207 Preprocessor02_Model1
#> 2     3     3 rmse    standard    2.64     5   0.234 Preprocessor06_Model1
#> 3     4     3 rmse    standard    2.82     5   0.456 Preprocessor07_Model1
#> 4     4     2 rmse    standard    2.93     5   0.489 Preprocessor03_Model1
#> 5     4     4 rmse    standard    3.01     5   0.475 Preprocessor11_Model1

# ---------------------------------------------------------------------------

# tune model parameters only (example requires the `kernlab` package)

car_rec <-
  recipe(mpg ~ ., data = mtcars) %>%
  step_normalize(all_predictors())

svm_mod <-
  svm_rbf(cost = tune(), rbf_sigma = tune()) %>%
  set_engine("kernlab") %>%
  set_mode("regression")

# Use a space-filling design with 7 points
set.seed(3254)
svm_res <- tune_grid(svm_mod, car_rec, resamples = folds, grid = 7)
svm_res
#> # Tuning results
#> # 5-fold cross-validation 
#> # A tibble: 5 × 4
#>   splits         id    .metrics          .notes          
#>   <list>         <chr> <list>            <list>          
#> 1 <split [25/7]> Fold1 <tibble [14 × 6]> <tibble [0 × 3]>
#> 2 <split [25/7]> Fold2 <tibble [14 × 6]> <tibble [0 × 3]>
#> 3 <split [26/6]> Fold3 <tibble [14 × 6]> <tibble [0 × 3]>
#> 4 <split [26/6]> Fold4 <tibble [14 × 6]> <tibble [0 × 3]>
#> 5 <split [26/6]> Fold5 <tibble [14 × 6]> <tibble [0 × 3]>

show_best(svm_res, metric = "rmse")
#> # A tibble: 5 × 8
#>       cost   rbf_sigma .metric .estimator  mean     n std_err .config     
#>      <dbl>       <dbl> <chr>   <chr>      <dbl> <int>   <dbl> <chr>       
#> 1  0.304   0.117       rmse    standard    3.91     5   0.652 Preprocesso…
#> 2  4.53    0.000420    rmse    standard    4.13     5   0.741 Preprocesso…
#> 3  0.00247 0.00931     rmse    standard    5.94     5   0.966 Preprocesso…
#> 4 23.2     0.000000684 rmse    standard    5.94     5   0.967 Preprocesso…
#> 5  0.0126  0.00000239  rmse    standard    5.96     5   0.970 Preprocesso…

autoplot(svm_res, metric = "rmse") +
  scale_x_log10()
#> Warning: NaNs produced
#> Warning: Transformation introduced infinite values in continuous x-axis
#> Warning: Removed 12 rows containing missing values (`geom_point()`).


# ---------------------------------------------------------------------------

# Using a variables preprocessor with a workflow

# Rather than supplying a preprocessor (like a recipe) and a model directly
# to `tune_grid()`, you can also wrap them up in a workflow and pass
# that along instead (note that this doesn't do any preprocessing to
# the variables, it passes them along as-is).
wf <- workflow() %>%
  add_variables(outcomes = mpg, predictors = everything()) %>%
  add_model(svm_mod)

set.seed(3254)
svm_res_wf <- tune_grid(wf, resamples = folds, grid = 7)
源代码:R/tune_grid.R

相关用法


注:本文由纯净天空筛选整理自Max Kuhn等大神的英文原创作品 Model tuning via grid search。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。