当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


R tune filter_parameters 删除一些调整参数结果


对于 tune_*() 函数生成的对象,可能只有感兴趣的调整参数组合的子集。对于大型数据集,能够删除一些结果可能会有所帮助。此函数会修剪 .metrics 列中不需要的结果以及 .predictions.extracts 列(如果需要的话)。

用法

filter_parameters(x, ..., parameters = NULL)

参数

x

具有多个调整参数的tune_results 类的对象。

...

返回逻辑值的表达式,并根据调整参数值进行定义。如果包含多个表达式,它们将与 & 运算符组合。仅保留所有条件评估为 TRUE 的行。

parameters

调整参数值的小标题,可用于在处理之前过滤预测值。该小标题应该只包含用于调整参数标识符的列(例如,如果使用tune("my_param"),则为"my_param")。可以有多行和一列或多列。如果使用,则必须命名该参数。

x 的一个版本,其中列表列仅保留 parameters 中的参数组合或满足过滤逻辑。

细节

删除某些参数组合可能会影响对象的 autoplot() 结果。

例子

library(dplyr)
library(tibble)

# For grid search:
data("example_ames_knn")

## -----------------------------------------------------------------------------
# select all combinations using the 'rank' weighting scheme

ames_grid_search %>%
  collect_metrics()
#> # A tibble: 20 × 11
#>        K weight_func  dist_power   lon   lat .metric .estimator   mean
#>    <int> <chr>             <dbl> <int> <int> <chr>   <chr>       <dbl>
#>  1    35 optimal           1.32      8     1 rmse    standard   0.0785
#>  2    35 optimal           1.32      8     1 rsq     standard   0.823 
#>  3    35 rank              1.29      3    13 rmse    standard   0.0809
#>  4    35 rank              1.29      3    13 rsq     standard   0.814 
#>  5    21 cos               0.626     1     4 rmse    standard   0.0746
#>  6    21 cos               0.626     1     4 rsq     standard   0.836 
#>  7     4 biweight          0.311     8     4 rmse    standard   0.0777
#>  8     4 biweight          0.311     8     4 rsq     standard   0.814 
#>  9    32 triangular        0.165     9    15 rmse    standard   0.0770
#> 10    32 triangular        0.165     9    15 rsq     standard   0.826 
#> 11     3 rank              1.86     10    15 rmse    standard   0.0875
#> 12     3 rank              1.86     10    15 rsq     standard   0.762 
#> 13    40 triangular        0.167    11     7 rmse    standard   0.0778
#> 14    40 triangular        0.167    11     7 rsq     standard   0.822 
#> 15    12 epanechnikov      1.53      4     7 rmse    standard   0.0774
#> 16    12 epanechnikov      1.53      4     7 rsq     standard   0.820 
#> 17     5 rank              0.411     2     7 rmse    standard   0.0740
#> 18     5 rank              0.411     2     7 rsq     standard   0.833 
#> 19    33 triweight         0.511    10     3 rmse    standard   0.0728
#> 20    33 triweight         0.511    10     3 rsq     standard   0.842 
#> # ℹ 3 more variables: n <int>, std_err <dbl>, .config <chr>

filter_parameters(ames_grid_search, weight_func == "rank") %>%
  collect_metrics()
#> # A tibble: 6 × 11
#>       K weight_func dist_power   lon   lat .metric .estimator   mean     n
#>   <int> <chr>            <dbl> <int> <int> <chr>   <chr>       <dbl> <int>
#> 1    35 rank             1.29      3    13 rmse    standard   0.0809    10
#> 2    35 rank             1.29      3    13 rsq     standard   0.814     10
#> 3     3 rank             1.86     10    15 rmse    standard   0.0875    10
#> 4     3 rank             1.86     10    15 rsq     standard   0.762     10
#> 5     5 rank             0.411     2     7 rmse    standard   0.0740    10
#> 6     5 rank             0.411     2     7 rsq     standard   0.833     10
#> # ℹ 2 more variables: std_err <dbl>, .config <chr>

rank_only <- tibble::tibble(weight_func = "rank")
filter_parameters(ames_grid_search, parameters = rank_only) %>%
  collect_metrics()
#> # A tibble: 6 × 11
#>       K weight_func dist_power   lon   lat .metric .estimator   mean     n
#>   <int> <chr>            <dbl> <int> <int> <chr>   <chr>       <dbl> <int>
#> 1    35 rank             1.29      3    13 rmse    standard   0.0809    10
#> 2    35 rank             1.29      3    13 rsq     standard   0.814     10
#> 3     3 rank             1.86     10    15 rmse    standard   0.0875    10
#> 4     3 rank             1.86     10    15 rsq     standard   0.762     10
#> 5     5 rank             0.411     2     7 rmse    standard   0.0740    10
#> 6     5 rank             0.411     2     7 rsq     standard   0.833     10
#> # ℹ 2 more variables: std_err <dbl>, .config <chr>

## -----------------------------------------------------------------------------
# Keep only the results from the numerically best combination

ames_iter_search %>%
  collect_metrics()
#> # A tibble: 40 × 12
#>        K weight_func dist_power   lon   lat .metric .estimator   mean
#>    <int> <chr>            <dbl> <int> <int> <chr>   <chr>       <dbl>
#>  1    35 optimal          1.32      8     1 rmse    standard   0.0785
#>  2    35 optimal          1.32      8     1 rsq     standard   0.823 
#>  3    35 rank             1.29      3    13 rmse    standard   0.0809
#>  4    35 rank             1.29      3    13 rsq     standard   0.814 
#>  5    21 cos              0.626     1     4 rmse    standard   0.0746
#>  6    21 cos              0.626     1     4 rsq     standard   0.836 
#>  7     4 biweight         0.311     8     4 rmse    standard   0.0777
#>  8     4 biweight         0.311     8     4 rsq     standard   0.814 
#>  9    32 triangular       0.165     9    15 rmse    standard   0.0770
#> 10    32 triangular       0.165     9    15 rsq     standard   0.826 
#> # ℹ 30 more rows
#> # ℹ 4 more variables: n <int>, std_err <dbl>, .config <chr>, .iter <int>

best_param <- select_best(ames_iter_search, metric = "rmse")
ames_iter_search %>%
  filter_parameters(parameters = best_param) %>%
  collect_metrics()
#> # A tibble: 2 × 12
#>       K weight_func dist_power   lon   lat .metric .estimator   mean     n
#>   <int> <chr>            <dbl> <int> <int> <chr>   <chr>       <dbl> <int>
#> 1    33 triweight        0.511    10     3 rmse    standard   0.0728    10
#> 2    33 triweight        0.511    10     3 rsq     standard   0.842     10
#> # ℹ 3 more variables: std_err <dbl>, .config <chr>, .iter <int>

相关用法


注:本文由纯净天空筛选整理自Max Kuhn等大神的英文原创作品 Remove some tuning parameter results。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。