当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


R tune collect_predictions 获取并格式化由调整函数产生的结果


获取并格式化由调整函数产生的结果

用法

collect_predictions(x, ...)

# S3 method for default
collect_predictions(x, ...)

# S3 method for tune_results
collect_predictions(x, summarize = FALSE, parameters = NULL, ...)

collect_metrics(x, ...)

# S3 method for tune_results
collect_metrics(x, summarize = TRUE, ...)

collect_notes(x, ...)

# S3 method for tune_results
collect_notes(x, ...)

collect_extracts(x, ...)

# S3 method for tune_results
collect_extracts(x, ...)

参数

x

tune_grid()tune_bayes()fit_resamples()last_fit() 的结果。对于collect_predictions(),应该使用控制选项save_pred = TRUE

...

目前未使用。

summarize

逻辑性强;应该通过重新采样(TRUE)汇总指标,还是返回每个单独重新采样的值。请注意,如果 xlast_fit() 创建,则 summarize 无效。对于其他对象类型,总结预测的方法详述如下。

parameters

可选的调整参数值小标题,可用于在处理之前过滤预测值。该小标题应该只包含每个调整参数标识符的列(例如,如果使用tune("my_param"),则为"my_param")。

一点点。列名称取决于结果和模型的模式。

对于 collect_metrics()collect_predictions() ,未汇总时,每个调整参数都有列(使用 tune() 中的 id(如果有))。 collect_metrics() 还具有列 .metric.estimator 。汇总结果时,会出现 meannstd_err 列。未汇总时,重采样标识符和 .estimate 的附加列。

对于 collect_predictions() ,还有用于重采样标识符的附加列、用于预测值的列(例如 .pred.pred_class 等)以及用于使用原始列的结果的列数据中的名称。

collect_predictions() 可以总结重复 out-of-sample 预测的各种结果。例如,当使用引导程序时,原始训练集中的每一行都有多个保留预测(跨评估集)。为了将这些结果转换为每个训练集都具有单个预测值的格式,需要对重复预测的结果进行平均。

对于回归情况,只需对数值预测进行平均。对于分类模型来说,问题更加复杂。当使用类别概率时,会对它们进行平均,然后重新标准化以确保它们相加为 1。如果数据中也存在硬类预测,则这些预测是根据汇总的概率估计确定的(以便它们匹配)。如果结果中只有硬类预测,则使用该模式进行总结。

collect_notes() 返回一个小标题,其中包含重采样指示器、位置(预处理器、模型等)、类型(错误或警告)和注释的列。

collect_extracts() 返回一个 tibble,其中包含重采样指示器的列、位置(预处理器、模型等)以及通过 control functionsextract 参数从工作流中提取的对象。

例子

data("example_ames_knn")
# The parameters for the model:
extract_parameter_set_dials(ames_wflow)
#> Collection of 5 parameters for tuning
#> 
#>   identifier        type    object
#>            K   neighbors nparam[+]
#>  weight_func weight_func dparam[+]
#>   dist_power  dist_power nparam[+]
#>          lon    deg_free nparam[+]
#>          lat    deg_free nparam[+]
#> 

# Summarized over resamples
collect_metrics(ames_grid_search)
#> # A tibble: 20 × 11
#>        K weight_func  dist_power   lon   lat .metric .estimator   mean
#>    <int> <chr>             <dbl> <int> <int> <chr>   <chr>       <dbl>
#>  1    35 optimal           1.32      8     1 rmse    standard   0.0785
#>  2    35 optimal           1.32      8     1 rsq     standard   0.823 
#>  3    35 rank              1.29      3    13 rmse    standard   0.0809
#>  4    35 rank              1.29      3    13 rsq     standard   0.814 
#>  5    21 cos               0.626     1     4 rmse    standard   0.0746
#>  6    21 cos               0.626     1     4 rsq     standard   0.836 
#>  7     4 biweight          0.311     8     4 rmse    standard   0.0777
#>  8     4 biweight          0.311     8     4 rsq     standard   0.814 
#>  9    32 triangular        0.165     9    15 rmse    standard   0.0770
#> 10    32 triangular        0.165     9    15 rsq     standard   0.826 
#> 11     3 rank              1.86     10    15 rmse    standard   0.0875
#> 12     3 rank              1.86     10    15 rsq     standard   0.762 
#> 13    40 triangular        0.167    11     7 rmse    standard   0.0778
#> 14    40 triangular        0.167    11     7 rsq     standard   0.822 
#> 15    12 epanechnikov      1.53      4     7 rmse    standard   0.0774
#> 16    12 epanechnikov      1.53      4     7 rsq     standard   0.820 
#> 17     5 rank              0.411     2     7 rmse    standard   0.0740
#> 18     5 rank              0.411     2     7 rsq     standard   0.833 
#> 19    33 triweight         0.511    10     3 rmse    standard   0.0728
#> 20    33 triweight         0.511    10     3 rsq     standard   0.842 
#> # ℹ 3 more variables: n <int>, std_err <dbl>, .config <chr>

# Per-resample values
collect_metrics(ames_grid_search, summarize = FALSE)
#> # A tibble: 200 × 10
#>    id         K weight_func dist_power   lon   lat .metric .estimator
#>    <chr>  <int> <chr>            <dbl> <int> <int> <chr>   <chr>     
#>  1 Fold01    35 optimal           1.32     8     1 rmse    standard  
#>  2 Fold01    35 optimal           1.32     8     1 rsq     standard  
#>  3 Fold02    35 optimal           1.32     8     1 rmse    standard  
#>  4 Fold02    35 optimal           1.32     8     1 rsq     standard  
#>  5 Fold03    35 optimal           1.32     8     1 rmse    standard  
#>  6 Fold03    35 optimal           1.32     8     1 rsq     standard  
#>  7 Fold04    35 optimal           1.32     8     1 rmse    standard  
#>  8 Fold04    35 optimal           1.32     8     1 rsq     standard  
#>  9 Fold05    35 optimal           1.32     8     1 rmse    standard  
#> 10 Fold05    35 optimal           1.32     8     1 rsq     standard  
#> # ℹ 190 more rows
#> # ℹ 2 more variables: .estimate <dbl>, .config <chr>


# ---------------------------------------------------------------------------

library(parsnip)
library(rsample)
library(dplyr)
#> 
#> Attaching package: ‘dplyr’
#> The following objects are masked from ‘package:stats’:
#> 
#>     filter, lag
#> The following objects are masked from ‘package:base’:
#> 
#>     intersect, setdiff, setequal, union
library(recipes)
#> 
#> Attaching package: ‘recipes’
#> The following object is masked from ‘package:stats’:
#> 
#>     step
library(tibble)

lm_mod <- linear_reg() %>% set_engine("lm")
set.seed(93599150)
car_folds <- vfold_cv(mtcars, v = 2, repeats = 3)
ctrl <- control_resamples(save_pred = TRUE, extract = extract_fit_engine)

spline_rec <-
  recipe(mpg ~ ., data = mtcars) %>%
  step_ns(disp, deg_free = tune("df"))

grid <- tibble(df = 3:6)

resampled <-
  lm_mod %>%
  tune_grid(spline_rec, resamples = car_folds, control = ctrl, grid = grid)

collect_predictions(resampled) %>% arrange(.row)
#> # A tibble: 384 × 7
#>    id      id2   .pred  .row    df   mpg .config             
#>    <chr>   <chr> <dbl> <int> <int> <dbl> <chr>               
#>  1 Repeat1 Fold2  16.5     1     3    21 Preprocessor1_Model1
#>  2 Repeat2 Fold1  19.0     1     3    21 Preprocessor1_Model1
#>  3 Repeat3 Fold1  20.0     1     3    21 Preprocessor1_Model1
#>  4 Repeat1 Fold2  15.1     1     4    21 Preprocessor2_Model1
#>  5 Repeat2 Fold1  17.7     1     4    21 Preprocessor2_Model1
#>  6 Repeat3 Fold1  20.1     1     4    21 Preprocessor2_Model1
#>  7 Repeat1 Fold2  17.9     1     5    21 Preprocessor3_Model1
#>  8 Repeat2 Fold1  18.3     1     5    21 Preprocessor3_Model1
#>  9 Repeat3 Fold1  20.4     1     5    21 Preprocessor3_Model1
#> 10 Repeat1 Fold2  15.1     1     6    21 Preprocessor4_Model1
#> # ℹ 374 more rows
collect_predictions(resampled, summarize = TRUE) %>% arrange(.row)
#> # A tibble: 128 × 5
#>     .row    df   mpg .config              .pred
#>    <int> <int> <dbl> <chr>                <dbl>
#>  1     1     3  21   Preprocessor1_Model1  18.5
#>  2     1     4  21   Preprocessor2_Model1  17.6
#>  3     1     5  21   Preprocessor3_Model1  18.9
#>  4     1     6  21   Preprocessor4_Model1  16.7
#>  5     2     3  21   Preprocessor1_Model1  19.4
#>  6     2     4  21   Preprocessor2_Model1  19.0
#>  7     2     5  21   Preprocessor3_Model1  18.7
#>  8     2     6  21   Preprocessor4_Model1  16.4
#>  9     3     3  22.8 Preprocessor1_Model1  31.8
#> 10     3     4  22.8 Preprocessor2_Model1  23.8
#> # ℹ 118 more rows
collect_predictions(resampled, summarize = TRUE, grid[1, ]) %>% arrange(.row)
#> # A tibble: 32 × 5
#>     .row    df   mpg .config              .pred
#>    <int> <int> <dbl> <chr>                <dbl>
#>  1     1     3  21   Preprocessor1_Model1  18.5
#>  2     2     3  21   Preprocessor1_Model1  19.4
#>  3     3     3  22.8 Preprocessor1_Model1  31.8
#>  4     4     3  21.4 Preprocessor1_Model1  20.2
#>  5     5     3  18.7 Preprocessor1_Model1  18.4
#>  6     6     3  18.1 Preprocessor1_Model1  20.6
#>  7     7     3  14.3 Preprocessor1_Model1  13.5
#>  8     8     3  24.4 Preprocessor1_Model1  19.2
#>  9     9     3  22.8 Preprocessor1_Model1  34.8
#> 10    10     3  19.2 Preprocessor1_Model1  16.6
#> # ℹ 22 more rows

collect_extracts(resampled)
#> # A tibble: 24 × 5
#>    id      id2      df .extracts .config             
#>    <chr>   <chr> <int> <list>    <chr>               
#>  1 Repeat1 Fold1     3 <lm>      Preprocessor1_Model1
#>  2 Repeat1 Fold1     4 <lm>      Preprocessor2_Model1
#>  3 Repeat1 Fold1     5 <lm>      Preprocessor3_Model1
#>  4 Repeat1 Fold1     6 <lm>      Preprocessor4_Model1
#>  5 Repeat1 Fold2     3 <lm>      Preprocessor1_Model1
#>  6 Repeat1 Fold2     4 <lm>      Preprocessor2_Model1
#>  7 Repeat1 Fold2     5 <lm>      Preprocessor3_Model1
#>  8 Repeat1 Fold2     6 <lm>      Preprocessor4_Model1
#>  9 Repeat2 Fold1     3 <lm>      Preprocessor1_Model1
#> 10 Repeat2 Fold1     4 <lm>      Preprocessor2_Model1
#> # ℹ 14 more rows
源代码:R/collect.R

相关用法


注:本文由纯净天空筛选整理自Max Kuhn等大神的英文原创作品 Obtain and format results produced by tuning functions。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。