當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


R tune filter_parameters 刪除一些調整參數結果


對於 tune_*() 函數生成的對象,可能隻有感興趣的調整參數組合的子集。對於大型數據集,能夠刪除一些結果可能會有所幫助。此函數會修剪 .metrics 列中不需要的結果以及 .predictions.extracts 列(如果需要的話)。

用法

filter_parameters(x, ..., parameters = NULL)

參數

x

具有多個調整參數的tune_results 類的對象。

...

返回邏輯值的表達式,並根據調整參數值進行定義。如果包含多個表達式,它們將與 & 運算符組合。僅保留所有條件評估為 TRUE 的行。

parameters

調整參數值的小標題,可用於在處理之前過濾預測值。該小標題應該隻包含用於調整參數標識符的列(例如,如果使用tune("my_param"),則為"my_param")。可以有多行和一列或多列。如果使用,則必須命名該參數。

x 的一個版本,其中列表列僅保留 parameters 中的參數組合或滿足過濾邏輯。

細節

刪除某些參數組合可能會影響對象的 autoplot() 結果。

例子

library(dplyr)
library(tibble)

# For grid search:
data("example_ames_knn")

## -----------------------------------------------------------------------------
# select all combinations using the 'rank' weighting scheme

ames_grid_search %>%
  collect_metrics()
#> # A tibble: 20 × 11
#>        K weight_func  dist_power   lon   lat .metric .estimator   mean
#>    <int> <chr>             <dbl> <int> <int> <chr>   <chr>       <dbl>
#>  1    35 optimal           1.32      8     1 rmse    standard   0.0785
#>  2    35 optimal           1.32      8     1 rsq     standard   0.823 
#>  3    35 rank              1.29      3    13 rmse    standard   0.0809
#>  4    35 rank              1.29      3    13 rsq     standard   0.814 
#>  5    21 cos               0.626     1     4 rmse    standard   0.0746
#>  6    21 cos               0.626     1     4 rsq     standard   0.836 
#>  7     4 biweight          0.311     8     4 rmse    standard   0.0777
#>  8     4 biweight          0.311     8     4 rsq     standard   0.814 
#>  9    32 triangular        0.165     9    15 rmse    standard   0.0770
#> 10    32 triangular        0.165     9    15 rsq     standard   0.826 
#> 11     3 rank              1.86     10    15 rmse    standard   0.0875
#> 12     3 rank              1.86     10    15 rsq     standard   0.762 
#> 13    40 triangular        0.167    11     7 rmse    standard   0.0778
#> 14    40 triangular        0.167    11     7 rsq     standard   0.822 
#> 15    12 epanechnikov      1.53      4     7 rmse    standard   0.0774
#> 16    12 epanechnikov      1.53      4     7 rsq     standard   0.820 
#> 17     5 rank              0.411     2     7 rmse    standard   0.0740
#> 18     5 rank              0.411     2     7 rsq     standard   0.833 
#> 19    33 triweight         0.511    10     3 rmse    standard   0.0728
#> 20    33 triweight         0.511    10     3 rsq     standard   0.842 
#> # ℹ 3 more variables: n <int>, std_err <dbl>, .config <chr>

filter_parameters(ames_grid_search, weight_func == "rank") %>%
  collect_metrics()
#> # A tibble: 6 × 11
#>       K weight_func dist_power   lon   lat .metric .estimator   mean     n
#>   <int> <chr>            <dbl> <int> <int> <chr>   <chr>       <dbl> <int>
#> 1    35 rank             1.29      3    13 rmse    standard   0.0809    10
#> 2    35 rank             1.29      3    13 rsq     standard   0.814     10
#> 3     3 rank             1.86     10    15 rmse    standard   0.0875    10
#> 4     3 rank             1.86     10    15 rsq     standard   0.762     10
#> 5     5 rank             0.411     2     7 rmse    standard   0.0740    10
#> 6     5 rank             0.411     2     7 rsq     standard   0.833     10
#> # ℹ 2 more variables: std_err <dbl>, .config <chr>

rank_only <- tibble::tibble(weight_func = "rank")
filter_parameters(ames_grid_search, parameters = rank_only) %>%
  collect_metrics()
#> # A tibble: 6 × 11
#>       K weight_func dist_power   lon   lat .metric .estimator   mean     n
#>   <int> <chr>            <dbl> <int> <int> <chr>   <chr>       <dbl> <int>
#> 1    35 rank             1.29      3    13 rmse    standard   0.0809    10
#> 2    35 rank             1.29      3    13 rsq     standard   0.814     10
#> 3     3 rank             1.86     10    15 rmse    standard   0.0875    10
#> 4     3 rank             1.86     10    15 rsq     standard   0.762     10
#> 5     5 rank             0.411     2     7 rmse    standard   0.0740    10
#> 6     5 rank             0.411     2     7 rsq     standard   0.833     10
#> # ℹ 2 more variables: std_err <dbl>, .config <chr>

## -----------------------------------------------------------------------------
# Keep only the results from the numerically best combination

ames_iter_search %>%
  collect_metrics()
#> # A tibble: 40 × 12
#>        K weight_func dist_power   lon   lat .metric .estimator   mean
#>    <int> <chr>            <dbl> <int> <int> <chr>   <chr>       <dbl>
#>  1    35 optimal          1.32      8     1 rmse    standard   0.0785
#>  2    35 optimal          1.32      8     1 rsq     standard   0.823 
#>  3    35 rank             1.29      3    13 rmse    standard   0.0809
#>  4    35 rank             1.29      3    13 rsq     standard   0.814 
#>  5    21 cos              0.626     1     4 rmse    standard   0.0746
#>  6    21 cos              0.626     1     4 rsq     standard   0.836 
#>  7     4 biweight         0.311     8     4 rmse    standard   0.0777
#>  8     4 biweight         0.311     8     4 rsq     standard   0.814 
#>  9    32 triangular       0.165     9    15 rmse    standard   0.0770
#> 10    32 triangular       0.165     9    15 rsq     standard   0.826 
#> # ℹ 30 more rows
#> # ℹ 4 more variables: n <int>, std_err <dbl>, .config <chr>, .iter <int>

best_param <- select_best(ames_iter_search, metric = "rmse")
ames_iter_search %>%
  filter_parameters(parameters = best_param) %>%
  collect_metrics()
#> # A tibble: 2 × 12
#>       K weight_func dist_power   lon   lat .metric .estimator   mean     n
#>   <int> <chr>            <dbl> <int> <int> <chr>   <chr>       <dbl> <int>
#> 1    33 triweight        0.511    10     3 rmse    standard   0.0728    10
#> 2    33 triweight        0.511    10     3 rsq     standard   0.842     10
#> # ℹ 3 more variables: std_err <dbl>, .config <chr>, .iter <int>

相關用法


注:本文由純淨天空篩選整理自Max Kuhn等大神的英文原創作品 Remove some tuning parameter results。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。