當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


R tune collect_predictions 獲取並格式化由調整函數產生的結果


獲取並格式化由調整函數產生的結果

用法

collect_predictions(x, ...)

# S3 method for default
collect_predictions(x, ...)

# S3 method for tune_results
collect_predictions(x, summarize = FALSE, parameters = NULL, ...)

collect_metrics(x, ...)

# S3 method for tune_results
collect_metrics(x, summarize = TRUE, ...)

collect_notes(x, ...)

# S3 method for tune_results
collect_notes(x, ...)

collect_extracts(x, ...)

# S3 method for tune_results
collect_extracts(x, ...)

參數

x

tune_grid()tune_bayes()fit_resamples()last_fit() 的結果。對於collect_predictions(),應該使用控製選項save_pred = TRUE

...

目前未使用。

summarize

邏輯性強;應該通過重新采樣(TRUE)匯總指標,還是返回每個單獨重新采樣的值。請注意,如果 xlast_fit() 創建,則 summarize 無效。對於其他對象類型,總結預測的方法詳述如下。

parameters

可選的調整參數值小標題,可用於在處理之前過濾預測值。該小標題應該隻包含每個調整參數標識符的列(例如,如果使用tune("my_param"),則為"my_param")。

一點點。列名稱取決於結果和模型的模式。

對於 collect_metrics()collect_predictions() ,未匯總時,每個調整參數都有列(使用 tune() 中的 id(如果有))。 collect_metrics() 還具有列 .metric.estimator 。匯總結果時,會出現 meannstd_err 列。未匯總時,重采樣標識符和 .estimate 的附加列。

對於 collect_predictions() ,還有用於重采樣標識符的附加列、用於預測值的列(例如 .pred.pred_class 等)以及用於使用原始列的結果的列數據中的名稱。

collect_predictions() 可以總結重複 out-of-sample 預測的各種結果。例如,當使用引導程序時,原始訓練集中的每一行都有多個保留預測(跨評估集)。為了將這些結果轉換為每個訓練集都具有單個預測值的格式,需要對重複預測的結果進行平均。

對於回歸情況,隻需對數值預測進行平均。對於分類模型來說,問題更加複雜。當使用類別概率時,會對它們進行平均,然後重新標準化以確保它們相加為 1。如果數據中也存在硬類預測,則這些預測是根據匯總的概率估計確定的(以便它們匹配)。如果結果中隻有硬類預測,則使用該模式進行總結。

collect_notes() 返回一個小標題,其中包含重采樣指示器、位置(預處理器、模型等)、類型(錯誤或警告)和注釋的列。

collect_extracts() 返回一個 tibble,其中包含重采樣指示器的列、位置(預處理器、模型等)以及通過 control functionsextract 參數從工作流中提取的對象。

例子

data("example_ames_knn")
# The parameters for the model:
extract_parameter_set_dials(ames_wflow)
#> Collection of 5 parameters for tuning
#> 
#>   identifier        type    object
#>            K   neighbors nparam[+]
#>  weight_func weight_func dparam[+]
#>   dist_power  dist_power nparam[+]
#>          lon    deg_free nparam[+]
#>          lat    deg_free nparam[+]
#> 

# Summarized over resamples
collect_metrics(ames_grid_search)
#> # A tibble: 20 × 11
#>        K weight_func  dist_power   lon   lat .metric .estimator   mean
#>    <int> <chr>             <dbl> <int> <int> <chr>   <chr>       <dbl>
#>  1    35 optimal           1.32      8     1 rmse    standard   0.0785
#>  2    35 optimal           1.32      8     1 rsq     standard   0.823 
#>  3    35 rank              1.29      3    13 rmse    standard   0.0809
#>  4    35 rank              1.29      3    13 rsq     standard   0.814 
#>  5    21 cos               0.626     1     4 rmse    standard   0.0746
#>  6    21 cos               0.626     1     4 rsq     standard   0.836 
#>  7     4 biweight          0.311     8     4 rmse    standard   0.0777
#>  8     4 biweight          0.311     8     4 rsq     standard   0.814 
#>  9    32 triangular        0.165     9    15 rmse    standard   0.0770
#> 10    32 triangular        0.165     9    15 rsq     standard   0.826 
#> 11     3 rank              1.86     10    15 rmse    standard   0.0875
#> 12     3 rank              1.86     10    15 rsq     standard   0.762 
#> 13    40 triangular        0.167    11     7 rmse    standard   0.0778
#> 14    40 triangular        0.167    11     7 rsq     standard   0.822 
#> 15    12 epanechnikov      1.53      4     7 rmse    standard   0.0774
#> 16    12 epanechnikov      1.53      4     7 rsq     standard   0.820 
#> 17     5 rank              0.411     2     7 rmse    standard   0.0740
#> 18     5 rank              0.411     2     7 rsq     standard   0.833 
#> 19    33 triweight         0.511    10     3 rmse    standard   0.0728
#> 20    33 triweight         0.511    10     3 rsq     standard   0.842 
#> # ℹ 3 more variables: n <int>, std_err <dbl>, .config <chr>

# Per-resample values
collect_metrics(ames_grid_search, summarize = FALSE)
#> # A tibble: 200 × 10
#>    id         K weight_func dist_power   lon   lat .metric .estimator
#>    <chr>  <int> <chr>            <dbl> <int> <int> <chr>   <chr>     
#>  1 Fold01    35 optimal           1.32     8     1 rmse    standard  
#>  2 Fold01    35 optimal           1.32     8     1 rsq     standard  
#>  3 Fold02    35 optimal           1.32     8     1 rmse    standard  
#>  4 Fold02    35 optimal           1.32     8     1 rsq     standard  
#>  5 Fold03    35 optimal           1.32     8     1 rmse    standard  
#>  6 Fold03    35 optimal           1.32     8     1 rsq     standard  
#>  7 Fold04    35 optimal           1.32     8     1 rmse    standard  
#>  8 Fold04    35 optimal           1.32     8     1 rsq     standard  
#>  9 Fold05    35 optimal           1.32     8     1 rmse    standard  
#> 10 Fold05    35 optimal           1.32     8     1 rsq     standard  
#> # ℹ 190 more rows
#> # ℹ 2 more variables: .estimate <dbl>, .config <chr>


# ---------------------------------------------------------------------------

library(parsnip)
library(rsample)
library(dplyr)
#> 
#> Attaching package: ‘dplyr’
#> The following objects are masked from ‘package:stats’:
#> 
#>     filter, lag
#> The following objects are masked from ‘package:base’:
#> 
#>     intersect, setdiff, setequal, union
library(recipes)
#> 
#> Attaching package: ‘recipes’
#> The following object is masked from ‘package:stats’:
#> 
#>     step
library(tibble)

lm_mod <- linear_reg() %>% set_engine("lm")
set.seed(93599150)
car_folds <- vfold_cv(mtcars, v = 2, repeats = 3)
ctrl <- control_resamples(save_pred = TRUE, extract = extract_fit_engine)

spline_rec <-
  recipe(mpg ~ ., data = mtcars) %>%
  step_ns(disp, deg_free = tune("df"))

grid <- tibble(df = 3:6)

resampled <-
  lm_mod %>%
  tune_grid(spline_rec, resamples = car_folds, control = ctrl, grid = grid)

collect_predictions(resampled) %>% arrange(.row)
#> # A tibble: 384 × 7
#>    id      id2   .pred  .row    df   mpg .config             
#>    <chr>   <chr> <dbl> <int> <int> <dbl> <chr>               
#>  1 Repeat1 Fold2  16.5     1     3    21 Preprocessor1_Model1
#>  2 Repeat2 Fold1  19.0     1     3    21 Preprocessor1_Model1
#>  3 Repeat3 Fold1  20.0     1     3    21 Preprocessor1_Model1
#>  4 Repeat1 Fold2  15.1     1     4    21 Preprocessor2_Model1
#>  5 Repeat2 Fold1  17.7     1     4    21 Preprocessor2_Model1
#>  6 Repeat3 Fold1  20.1     1     4    21 Preprocessor2_Model1
#>  7 Repeat1 Fold2  17.9     1     5    21 Preprocessor3_Model1
#>  8 Repeat2 Fold1  18.3     1     5    21 Preprocessor3_Model1
#>  9 Repeat3 Fold1  20.4     1     5    21 Preprocessor3_Model1
#> 10 Repeat1 Fold2  15.1     1     6    21 Preprocessor4_Model1
#> # ℹ 374 more rows
collect_predictions(resampled, summarize = TRUE) %>% arrange(.row)
#> # A tibble: 128 × 5
#>     .row    df   mpg .config              .pred
#>    <int> <int> <dbl> <chr>                <dbl>
#>  1     1     3  21   Preprocessor1_Model1  18.5
#>  2     1     4  21   Preprocessor2_Model1  17.6
#>  3     1     5  21   Preprocessor3_Model1  18.9
#>  4     1     6  21   Preprocessor4_Model1  16.7
#>  5     2     3  21   Preprocessor1_Model1  19.4
#>  6     2     4  21   Preprocessor2_Model1  19.0
#>  7     2     5  21   Preprocessor3_Model1  18.7
#>  8     2     6  21   Preprocessor4_Model1  16.4
#>  9     3     3  22.8 Preprocessor1_Model1  31.8
#> 10     3     4  22.8 Preprocessor2_Model1  23.8
#> # ℹ 118 more rows
collect_predictions(resampled, summarize = TRUE, grid[1, ]) %>% arrange(.row)
#> # A tibble: 32 × 5
#>     .row    df   mpg .config              .pred
#>    <int> <int> <dbl> <chr>                <dbl>
#>  1     1     3  21   Preprocessor1_Model1  18.5
#>  2     2     3  21   Preprocessor1_Model1  19.4
#>  3     3     3  22.8 Preprocessor1_Model1  31.8
#>  4     4     3  21.4 Preprocessor1_Model1  20.2
#>  5     5     3  18.7 Preprocessor1_Model1  18.4
#>  6     6     3  18.1 Preprocessor1_Model1  20.6
#>  7     7     3  14.3 Preprocessor1_Model1  13.5
#>  8     8     3  24.4 Preprocessor1_Model1  19.2
#>  9     9     3  22.8 Preprocessor1_Model1  34.8
#> 10    10     3  19.2 Preprocessor1_Model1  16.6
#> # ℹ 22 more rows

collect_extracts(resampled)
#> # A tibble: 24 × 5
#>    id      id2      df .extracts .config             
#>    <chr>   <chr> <int> <list>    <chr>               
#>  1 Repeat1 Fold1     3 <lm>      Preprocessor1_Model1
#>  2 Repeat1 Fold1     4 <lm>      Preprocessor2_Model1
#>  3 Repeat1 Fold1     5 <lm>      Preprocessor3_Model1
#>  4 Repeat1 Fold1     6 <lm>      Preprocessor4_Model1
#>  5 Repeat1 Fold2     3 <lm>      Preprocessor1_Model1
#>  6 Repeat1 Fold2     4 <lm>      Preprocessor2_Model1
#>  7 Repeat1 Fold2     5 <lm>      Preprocessor3_Model1
#>  8 Repeat1 Fold2     6 <lm>      Preprocessor4_Model1
#>  9 Repeat2 Fold1     3 <lm>      Preprocessor1_Model1
#> 10 Repeat2 Fold1     4 <lm>      Preprocessor2_Model1
#> # ℹ 14 more rows
源代碼:R/collect.R

相關用法


注:本文由純淨天空篩選整理自Max Kuhn等大神的英文原創作品 Obtain and format results produced by tuning functions。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。