获取并格式化由调整函数产生的结果
用法
collect_predictions(x, ...)
# S3 method for default
collect_predictions(x, ...)
# S3 method for tune_results
collect_predictions(x, summarize = FALSE, parameters = NULL, ...)
collect_metrics(x, ...)
# S3 method for tune_results
collect_metrics(x, summarize = TRUE, ...)
collect_notes(x, ...)
# S3 method for tune_results
collect_notes(x, ...)
collect_extracts(x, ...)
# S3 method for tune_results
collect_extracts(x, ...)
参数
- x
-
tune_grid()
、tune_bayes()
、fit_resamples()
或last_fit()
的结果。对于collect_predictions()
,应该使用控制选项save_pred = TRUE
。 - ...
-
目前未使用。
- summarize
-
逻辑性强;应该通过重新采样(
TRUE
)汇总指标,还是返回每个单独重新采样的值。请注意,如果x
由last_fit()
创建,则summarize
无效。对于其他对象类型,总结预测的方法详述如下。 - parameters
-
可选的调整参数值小标题,可用于在处理之前过滤预测值。该小标题应该只包含每个调整参数标识符的列(例如,如果使用
tune("my_param")
,则为"my_param"
)。
值
一点点。列名称取决于结果和模型的模式。
对于 collect_metrics()
和 collect_predictions()
,未汇总时,每个调整参数都有列(使用 tune()
中的 id
(如果有))。 collect_metrics()
还具有列 .metric
和 .estimator
。汇总结果时,会出现 mean
、 n
和 std_err
列。未汇总时,重采样标识符和 .estimate
的附加列。
对于 collect_predictions()
,还有用于重采样标识符的附加列、用于预测值的列(例如 .pred
、 .pred_class
等)以及用于使用原始列的结果的列数据中的名称。
collect_predictions()
可以总结重复 out-of-sample 预测的各种结果。例如,当使用引导程序时,原始训练集中的每一行都有多个保留预测(跨评估集)。为了将这些结果转换为每个训练集都具有单个预测值的格式,需要对重复预测的结果进行平均。
对于回归情况,只需对数值预测进行平均。对于分类模型来说,问题更加复杂。当使用类别概率时,会对它们进行平均,然后重新标准化以确保它们相加为 1。如果数据中也存在硬类预测,则这些预测是根据汇总的概率估计确定的(以便它们匹配)。如果结果中只有硬类预测,则使用该模式进行总结。
collect_notes()
返回一个小标题,其中包含重采样指示器、位置(预处理器、模型等)、类型(错误或警告)和注释的列。
collect_extracts()
返回一个 tibble,其中包含重采样指示器的列、位置(预处理器、模型等)以及通过 control functions 的 extract
参数从工作流中提取的对象。
例子
data("example_ames_knn")
# The parameters for the model:
extract_parameter_set_dials(ames_wflow)
#> Collection of 5 parameters for tuning
#>
#> identifier type object
#> K neighbors nparam[+]
#> weight_func weight_func dparam[+]
#> dist_power dist_power nparam[+]
#> lon deg_free nparam[+]
#> lat deg_free nparam[+]
#>
# Summarized over resamples
collect_metrics(ames_grid_search)
#> # A tibble: 20 × 11
#> K weight_func dist_power lon lat .metric .estimator mean
#> <int> <chr> <dbl> <int> <int> <chr> <chr> <dbl>
#> 1 35 optimal 1.32 8 1 rmse standard 0.0785
#> 2 35 optimal 1.32 8 1 rsq standard 0.823
#> 3 35 rank 1.29 3 13 rmse standard 0.0809
#> 4 35 rank 1.29 3 13 rsq standard 0.814
#> 5 21 cos 0.626 1 4 rmse standard 0.0746
#> 6 21 cos 0.626 1 4 rsq standard 0.836
#> 7 4 biweight 0.311 8 4 rmse standard 0.0777
#> 8 4 biweight 0.311 8 4 rsq standard 0.814
#> 9 32 triangular 0.165 9 15 rmse standard 0.0770
#> 10 32 triangular 0.165 9 15 rsq standard 0.826
#> 11 3 rank 1.86 10 15 rmse standard 0.0875
#> 12 3 rank 1.86 10 15 rsq standard 0.762
#> 13 40 triangular 0.167 11 7 rmse standard 0.0778
#> 14 40 triangular 0.167 11 7 rsq standard 0.822
#> 15 12 epanechnikov 1.53 4 7 rmse standard 0.0774
#> 16 12 epanechnikov 1.53 4 7 rsq standard 0.820
#> 17 5 rank 0.411 2 7 rmse standard 0.0740
#> 18 5 rank 0.411 2 7 rsq standard 0.833
#> 19 33 triweight 0.511 10 3 rmse standard 0.0728
#> 20 33 triweight 0.511 10 3 rsq standard 0.842
#> # ℹ 3 more variables: n <int>, std_err <dbl>, .config <chr>
# Per-resample values
collect_metrics(ames_grid_search, summarize = FALSE)
#> # A tibble: 200 × 10
#> id K weight_func dist_power lon lat .metric .estimator
#> <chr> <int> <chr> <dbl> <int> <int> <chr> <chr>
#> 1 Fold01 35 optimal 1.32 8 1 rmse standard
#> 2 Fold01 35 optimal 1.32 8 1 rsq standard
#> 3 Fold02 35 optimal 1.32 8 1 rmse standard
#> 4 Fold02 35 optimal 1.32 8 1 rsq standard
#> 5 Fold03 35 optimal 1.32 8 1 rmse standard
#> 6 Fold03 35 optimal 1.32 8 1 rsq standard
#> 7 Fold04 35 optimal 1.32 8 1 rmse standard
#> 8 Fold04 35 optimal 1.32 8 1 rsq standard
#> 9 Fold05 35 optimal 1.32 8 1 rmse standard
#> 10 Fold05 35 optimal 1.32 8 1 rsq standard
#> # ℹ 190 more rows
#> # ℹ 2 more variables: .estimate <dbl>, .config <chr>
# ---------------------------------------------------------------------------
library(parsnip)
library(rsample)
library(dplyr)
#>
#> Attaching package: ‘dplyr’
#> The following objects are masked from ‘package:stats’:
#>
#> filter, lag
#> The following objects are masked from ‘package:base’:
#>
#> intersect, setdiff, setequal, union
library(recipes)
#>
#> Attaching package: ‘recipes’
#> The following object is masked from ‘package:stats’:
#>
#> step
library(tibble)
lm_mod <- linear_reg() %>% set_engine("lm")
set.seed(93599150)
car_folds <- vfold_cv(mtcars, v = 2, repeats = 3)
ctrl <- control_resamples(save_pred = TRUE, extract = extract_fit_engine)
spline_rec <-
recipe(mpg ~ ., data = mtcars) %>%
step_ns(disp, deg_free = tune("df"))
grid <- tibble(df = 3:6)
resampled <-
lm_mod %>%
tune_grid(spline_rec, resamples = car_folds, control = ctrl, grid = grid)
collect_predictions(resampled) %>% arrange(.row)
#> # A tibble: 384 × 7
#> id id2 .pred .row df mpg .config
#> <chr> <chr> <dbl> <int> <int> <dbl> <chr>
#> 1 Repeat1 Fold2 16.5 1 3 21 Preprocessor1_Model1
#> 2 Repeat2 Fold1 19.0 1 3 21 Preprocessor1_Model1
#> 3 Repeat3 Fold1 20.0 1 3 21 Preprocessor1_Model1
#> 4 Repeat1 Fold2 15.1 1 4 21 Preprocessor2_Model1
#> 5 Repeat2 Fold1 17.7 1 4 21 Preprocessor2_Model1
#> 6 Repeat3 Fold1 20.1 1 4 21 Preprocessor2_Model1
#> 7 Repeat1 Fold2 17.9 1 5 21 Preprocessor3_Model1
#> 8 Repeat2 Fold1 18.3 1 5 21 Preprocessor3_Model1
#> 9 Repeat3 Fold1 20.4 1 5 21 Preprocessor3_Model1
#> 10 Repeat1 Fold2 15.1 1 6 21 Preprocessor4_Model1
#> # ℹ 374 more rows
collect_predictions(resampled, summarize = TRUE) %>% arrange(.row)
#> # A tibble: 128 × 5
#> .row df mpg .config .pred
#> <int> <int> <dbl> <chr> <dbl>
#> 1 1 3 21 Preprocessor1_Model1 18.5
#> 2 1 4 21 Preprocessor2_Model1 17.6
#> 3 1 5 21 Preprocessor3_Model1 18.9
#> 4 1 6 21 Preprocessor4_Model1 16.7
#> 5 2 3 21 Preprocessor1_Model1 19.4
#> 6 2 4 21 Preprocessor2_Model1 19.0
#> 7 2 5 21 Preprocessor3_Model1 18.7
#> 8 2 6 21 Preprocessor4_Model1 16.4
#> 9 3 3 22.8 Preprocessor1_Model1 31.8
#> 10 3 4 22.8 Preprocessor2_Model1 23.8
#> # ℹ 118 more rows
collect_predictions(resampled, summarize = TRUE, grid[1, ]) %>% arrange(.row)
#> # A tibble: 32 × 5
#> .row df mpg .config .pred
#> <int> <int> <dbl> <chr> <dbl>
#> 1 1 3 21 Preprocessor1_Model1 18.5
#> 2 2 3 21 Preprocessor1_Model1 19.4
#> 3 3 3 22.8 Preprocessor1_Model1 31.8
#> 4 4 3 21.4 Preprocessor1_Model1 20.2
#> 5 5 3 18.7 Preprocessor1_Model1 18.4
#> 6 6 3 18.1 Preprocessor1_Model1 20.6
#> 7 7 3 14.3 Preprocessor1_Model1 13.5
#> 8 8 3 24.4 Preprocessor1_Model1 19.2
#> 9 9 3 22.8 Preprocessor1_Model1 34.8
#> 10 10 3 19.2 Preprocessor1_Model1 16.6
#> # ℹ 22 more rows
collect_extracts(resampled)
#> # A tibble: 24 × 5
#> id id2 df .extracts .config
#> <chr> <chr> <int> <list> <chr>
#> 1 Repeat1 Fold1 3 <lm> Preprocessor1_Model1
#> 2 Repeat1 Fold1 4 <lm> Preprocessor2_Model1
#> 3 Repeat1 Fold1 5 <lm> Preprocessor3_Model1
#> 4 Repeat1 Fold1 6 <lm> Preprocessor4_Model1
#> 5 Repeat1 Fold2 3 <lm> Preprocessor1_Model1
#> 6 Repeat1 Fold2 4 <lm> Preprocessor2_Model1
#> 7 Repeat1 Fold2 5 <lm> Preprocessor3_Model1
#> 8 Repeat1 Fold2 6 <lm> Preprocessor4_Model1
#> 9 Repeat2 Fold1 3 <lm> Preprocessor1_Model1
#> 10 Repeat2 Fold1 4 <lm> Preprocessor2_Model1
#> # ℹ 14 more rows
相关用法
- R tune coord_obs_pred 对观察值与预测值的绘图使用相同的比例
- R tune conf_mat_resampled 计算重采样的平均混淆矩阵
- R tune extract-tune 提取调整对象的元素
- R tune filter_parameters 删除一些调整参数结果
- R tune fit_best 将模型拟合到数值最优配置
- R tune finalize_model 将最终参数拼接到对象中
- R tune tune_bayes 模型参数的贝叶斯优化。
- R tune show_best 研究最佳调整参数
- R tune expo_decay 指数衰减函数
- R tune fit_resamples 通过重采样拟合多个模型
- R tune merge.recipe 将参数网格值合并到对象中
- R tune autoplot.tune_results 绘图调整搜索结果
- R tune tune_grid 通过网格搜索进行模型调整
- R tune dot-use_case_weights_with_yardstick 确定案例权重是否应传递至标准
- R tune message_wrap 写一条尊重线宽的消息
- R tune prob_improve 用于对参数组合进行评分的获取函数
- R tune last_fit 将最终的最佳模型拟合到训练集并评估测试集
- R update_PACKAGES 更新现有的 PACKAGES 文件
- R textrecipes tokenlist 创建令牌对象
- R themis smotenc SMOTENC算法
- R print.via.format 打印实用程序
- R tibble tibble 构建 DataFrame 架
- R tidyr separate_rows 将折叠的列分成多行
- R textrecipes step_lemma 标记变量的词形还原
- R textrecipes show_tokens 显示配方的令牌输出
注:本文由纯净天空筛选整理自Max Kuhn等大神的英文原创作品 Obtain and format results produced by tuning functions。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。