当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


R tune tune_bayes 模型参数的贝叶斯优化。


tune_bayes() 使用模型根据之前的结果生成新的候选调整参数组合。

用法

tune_bayes(object, ...)

# S3 method for model_spec
tune_bayes(
  object,
  preprocessor,
  resamples,
  ...,
  iter = 10,
  param_info = NULL,
  metrics = NULL,
  objective = exp_improve(),
  initial = 5,
  control = control_bayes()
)

# S3 method for workflow
tune_bayes(
  object,
  resamples,
  ...,
  iter = 10,
  param_info = NULL,
  metrics = NULL,
  objective = exp_improve(),
  initial = 5,
  control = control_bayes()
)

参数

object

parsnip 模型规范或 workflows::workflow()

...

传递给 GPfit::GP_fit() 的选项(主要用于 corr 参数)。

preprocessor

使用 recipes::recipe() 创建的传统模型公式或配方。

resamples

rset() 对象。

iter

搜索迭代的最大次数。

param_info

dials::parameters() 对象或 NULL 。如果没有给出,则从其他参数派生参数集。当需要自定义参数范围时,传递此参数可能很有用。

metrics

yardstick::metric_set() 对象,包含有关如何评估模型性能的信息。 metrics 中的第一个指标是要优化的指标。

objective

应该优化哪个指标的字符串或获取函数对象。

initial

一组采用整齐格式的初始结果(如 tune_grid() 的结果)或正整数。建议初始结果的数量大于正在优化的参数的数量。

control

control_bayes()创建的控制对象

反映 tune_grid() 生成的结果的一小部分结果。但是,这些结果包含 .iter 列并复制 rset

在迭代中多次对象(以有限的额外内存成本)。

细节

优化从一组初始结果开始,例如 tune_grid() 生成的结果。如果不存在,该函数将创建多个组合并获得它们的性能估计。

使用性能估计之一作为模型结果,创建高斯过程 (GP) 模型,其中使用先前的调整参数组合作为预测变量。

使用该模型预测潜在超参数组合的大网格,并使用获取函数进行评分。这些函数通常结合 GP 的预测均值和方差来决定下一步要尝试的最佳参数组合。有关更多信息,请参阅 exp_improve() 的文档和相应的包插图。

使用重采样评估最佳组合,并继续该过程。

并行处理

这里使用foreach包。要并行执行重采样迭代,请注册并行后端函数。有关示例,请参阅 foreach::foreach() 的文档。

大多数情况下,训练期间生成的警告会在发生时显示,并与 control_bayes(verbose = TRUE) 时的特定重新采样相关联。它们(通常)直到处理结束才聚合。

对于贝叶斯优化,一旦估计出新的候选值集,就使用并行处理来估计重采样的性能值。

初始值

tune_grid() 的结果或之前运行的 tune_bayes() 可以在 initial 参数中使用。 initial也可以是正整数。在这种情况下,space-filling 设计将用于填充一组初步结果。为了获得好的结果,初始值的数量应该多于正在优化的参数的数量。

参数范围和值

在某些情况下,调整参数值取决于数据的维度(据说它们包含unknown值)。例如,随机森林模型中的mtry 取决于预测变量的数量。在这种情况下,必须事先确定调整参数对象中的未知数并通过 param_info 参数传递给函数。 dials::finalize() 可用于导出数据相关参数。否则,可以通过 dials::parameters() 创建参数集,并使用 dials update() 函数指定范围或值。

性能指标

要使用您自己的性能指标,可以使用 yardstick::metric_set() 函数来选择每个模型应测量的内容。如果需要多个指标,可以将它们捆绑在一起。例如,要估计 ROC 曲线下的面积以及灵敏度和特异性(在典型概率截止值 0.50 下),可以给出 metrics 参数:


  metrics = metric_set(roc_auc, sens, spec)

每个指标都是针对每个候选模型计算的。

如果未提供指标集,则会创建一个指标集:

  • 对于回归模型,计算均方根误差和确定系数。

  • 对于分类,计算 ROC 曲线下的面积和总体准确度。

请注意,这些指标还决定了调整期间估计的预测类型。例如,在分类问题中,如果使用的度量全部与硬类预测相关,则不会创建分类概率。

这些指标的 out-of-sample 估计值包含在名为 .metrics 的列表列中。该小标题包含每个指标的行和值、估计器类型等的列。

collect_metrics() 可用于这些对象来折叠重采样的结果(以获得每个调整参数组合的最终重采样估计)。

获取预测

control_bayes(save_pred = TRUE) 时,输出 tibble 包含一个名为 .predictions 的列表列,其中包含网格和每个折叠中每个参数组合的 out-of-sample 预测(可能非常大)。

tibble 的元素是 tibbles,其中包含调整参数的列、原始数据对象 ( .row ) 的行号、结果数据(与原始数据具有相同的名称)以及由的预测。例如,对于简单的回归问题,此函数会生成一个名为.pred 的列,依此类推。如上所述,返回的预测列由请求的度量类型确定。

此列表列可以是 unnested 使用 tidyr::unnest() 或使用便利函数 collect_predictions()

提取信息

extract 控制选项将导致返回一个名为 .extracts 的附加函数。这是一个列表列,其中包含每个调整参数组合的用户函数结果的标题。这可以允许返回在重采样期间创建的每个模型和/或配方对象。请注意,这可能会导致返回对象很大,具体取决于返回的内容。

控制函数包含一个选项 (extract),可用于保留在重采样中创建的任何模型或配方。该参数应该是具有单个参数的函数。每次重新采样中赋予函数的参数值是工作流对象(有关更多信息,请参阅workflows::workflow())。可以使用多个辅助函数轻松地从工作流程中提取预处理和/或模型信息,例如 extract_preprocessor()extract_fit_parsnip()

例如,如果有兴趣恢复每个防风草模型,可以使用:


  extract = function (x) extract_fit_parsnip(x)

请注意,赋予 extract 参数的函数是在每个适合的模型上评估的(而不是在评估的每个模型上)。如上所述,在某些情况下,可以针对 sub-models 导出模型预测,因此在这些情况下,并非调整参数网格中的每一行都有与其关联的单独 R 对象。

例子

library(recipes)
library(rsample)
library(parsnip)

# define resamples and minimal recipe on mtcars
set.seed(6735)
folds <- vfold_cv(mtcars, v = 5)

car_rec <-
  recipe(mpg ~ ., data = mtcars) %>%
  step_normalize(all_predictors())

# define an svm with parameters to tune
svm_mod <-
  svm_rbf(cost = tune(), rbf_sigma = tune()) %>%
  set_engine("kernlab") %>%
  set_mode("regression")

# use a space-filling design with 6 points
set.seed(3254)
svm_grid <- tune_grid(svm_mod, car_rec, folds, grid = 6)

show_best(svm_grid, metric = "rmse")
#> # A tibble: 5 × 8
#>       cost    rbf_sigma .metric .estimator  mean     n std_err .config    
#>      <dbl>        <dbl> <chr>   <chr>      <dbl> <int>   <dbl> <chr>      
#> 1 25.3     0.248        rmse    standard    3.17     5   0.678 Preprocess…
#> 2  3.88    0.000510     rmse    standard    4.07     5   0.727 Preprocess…
#> 3  0.102   0.00000592   rmse    standard    5.96     5   0.970 Preprocess…
#> 4  0.00125 0.000138     rmse    standard    5.96     5   0.970 Preprocess…
#> 5  0.0192  0.0000000427 rmse    standard    5.96     5   0.970 Preprocess…

# use bayesian optimization to evaluate at 6 more points
set.seed(8241)
svm_bayes <- tune_bayes(svm_mod, car_rec, folds, initial = svm_grid, iter = 6)

# note that bayesian optimization evaluated parameterizations
# similar to those that previously decreased rmse in svm_grid
show_best(svm_bayes, metric = "rmse")
#> # A tibble: 5 × 9
#>    cost rbf_sigma .metric .estimator  mean     n std_err .config .iter
#>   <dbl>     <dbl> <chr>   <chr>      <dbl> <int>   <dbl> <chr>   <int>
#> 1 31.6    0.00144 rmse    standard    2.60     5   0.232 Iter1       1
#> 2 28.7    0.00292 rmse    standard    2.61     5   0.208 Iter4       4
#> 3 28.3    0.00685 rmse    standard    2.62     5   0.195 Iter3       3
#> 4 31.4    0.00482 rmse    standard    2.64     5   0.183 Iter5       5
#> 5  7.30   0.0533  rmse    standard    2.71     5   0.303 Iter6       6

# specifying `initial` as a numeric rather than previous tuning results
# will result in `tune_bayes` initially evaluating an space-filling
# grid using `tune_grid` with `grid = initial`
set.seed(0239)
svm_init <- tune_bayes(svm_mod, car_rec, folds, initial = 6, iter = 6)

show_best(svm_init, metric = "rmse")
#> # A tibble: 5 × 9
#>    cost rbf_sigma .metric .estimator  mean     n std_err .config .iter
#>   <dbl>     <dbl> <chr>   <chr>      <dbl> <int>   <dbl> <chr>   <int>
#> 1 2.35     0.0269 rmse    standard    2.70     5   0.188 Iter5       5
#> 2 3.32     0.0361 rmse    standard    2.72     5   0.192 Iter6       6
#> 3 1.40     0.0479 rmse    standard    2.78     5   0.229 Iter4       4
#> 4 0.509    0.0256 rmse    standard    3.17     5   0.473 Iter3       3
#> 5 0.256    0.0201 rmse    standard    3.79     5   0.622 Iter2       2
源代码:R/tune_bayes.R

相关用法


注:本文由纯净天空筛选整理自Max Kuhn等大神的英文原创作品 Bayesian optimization of model parameters.。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。