獲取並格式化由調整函數產生的結果
用法
collect_predictions(x, ...)
# S3 method for default
collect_predictions(x, ...)
# S3 method for tune_results
collect_predictions(x, summarize = FALSE, parameters = NULL, ...)
collect_metrics(x, ...)
# S3 method for tune_results
collect_metrics(x, summarize = TRUE, ...)
collect_notes(x, ...)
# S3 method for tune_results
collect_notes(x, ...)
collect_extracts(x, ...)
# S3 method for tune_results
collect_extracts(x, ...)
參數
- x
-
tune_grid()
、tune_bayes()
、fit_resamples()
或last_fit()
的結果。對於collect_predictions()
,應該使用控製選項save_pred = TRUE
。 - ...
-
目前未使用。
- summarize
-
邏輯性強;應該通過重新采樣(
TRUE
)匯總指標,還是返回每個單獨重新采樣的值。請注意,如果x
由last_fit()
創建,則summarize
無效。對於其他對象類型,總結預測的方法詳述如下。 - parameters
-
可選的調整參數值小標題,可用於在處理之前過濾預測值。該小標題應該隻包含每個調整參數標識符的列(例如,如果使用
tune("my_param")
,則為"my_param"
)。
值
一點點。列名稱取決於結果和模型的模式。
對於 collect_metrics()
和 collect_predictions()
,未匯總時,每個調整參數都有列(使用 tune()
中的 id
(如果有))。 collect_metrics()
還具有列 .metric
和 .estimator
。匯總結果時,會出現 mean
、 n
和 std_err
列。未匯總時,重采樣標識符和 .estimate
的附加列。
對於 collect_predictions()
,還有用於重采樣標識符的附加列、用於預測值的列(例如 .pred
、 .pred_class
等)以及用於使用原始列的結果的列數據中的名稱。
collect_predictions()
可以總結重複 out-of-sample 預測的各種結果。例如,當使用引導程序時,原始訓練集中的每一行都有多個保留預測(跨評估集)。為了將這些結果轉換為每個訓練集都具有單個預測值的格式,需要對重複預測的結果進行平均。
對於回歸情況,隻需對數值預測進行平均。對於分類模型來說,問題更加複雜。當使用類別概率時,會對它們進行平均,然後重新標準化以確保它們相加為 1。如果數據中也存在硬類預測,則這些預測是根據匯總的概率估計確定的(以便它們匹配)。如果結果中隻有硬類預測,則使用該模式進行總結。
collect_notes()
返回一個小標題,其中包含重采樣指示器、位置(預處理器、模型等)、類型(錯誤或警告)和注釋的列。
collect_extracts()
返回一個 tibble,其中包含重采樣指示器的列、位置(預處理器、模型等)以及通過 control functions 的 extract
參數從工作流中提取的對象。
例子
data("example_ames_knn")
# The parameters for the model:
extract_parameter_set_dials(ames_wflow)
#> Collection of 5 parameters for tuning
#>
#> identifier type object
#> K neighbors nparam[+]
#> weight_func weight_func dparam[+]
#> dist_power dist_power nparam[+]
#> lon deg_free nparam[+]
#> lat deg_free nparam[+]
#>
# Summarized over resamples
collect_metrics(ames_grid_search)
#> # A tibble: 20 × 11
#> K weight_func dist_power lon lat .metric .estimator mean
#> <int> <chr> <dbl> <int> <int> <chr> <chr> <dbl>
#> 1 35 optimal 1.32 8 1 rmse standard 0.0785
#> 2 35 optimal 1.32 8 1 rsq standard 0.823
#> 3 35 rank 1.29 3 13 rmse standard 0.0809
#> 4 35 rank 1.29 3 13 rsq standard 0.814
#> 5 21 cos 0.626 1 4 rmse standard 0.0746
#> 6 21 cos 0.626 1 4 rsq standard 0.836
#> 7 4 biweight 0.311 8 4 rmse standard 0.0777
#> 8 4 biweight 0.311 8 4 rsq standard 0.814
#> 9 32 triangular 0.165 9 15 rmse standard 0.0770
#> 10 32 triangular 0.165 9 15 rsq standard 0.826
#> 11 3 rank 1.86 10 15 rmse standard 0.0875
#> 12 3 rank 1.86 10 15 rsq standard 0.762
#> 13 40 triangular 0.167 11 7 rmse standard 0.0778
#> 14 40 triangular 0.167 11 7 rsq standard 0.822
#> 15 12 epanechnikov 1.53 4 7 rmse standard 0.0774
#> 16 12 epanechnikov 1.53 4 7 rsq standard 0.820
#> 17 5 rank 0.411 2 7 rmse standard 0.0740
#> 18 5 rank 0.411 2 7 rsq standard 0.833
#> 19 33 triweight 0.511 10 3 rmse standard 0.0728
#> 20 33 triweight 0.511 10 3 rsq standard 0.842
#> # ℹ 3 more variables: n <int>, std_err <dbl>, .config <chr>
# Per-resample values
collect_metrics(ames_grid_search, summarize = FALSE)
#> # A tibble: 200 × 10
#> id K weight_func dist_power lon lat .metric .estimator
#> <chr> <int> <chr> <dbl> <int> <int> <chr> <chr>
#> 1 Fold01 35 optimal 1.32 8 1 rmse standard
#> 2 Fold01 35 optimal 1.32 8 1 rsq standard
#> 3 Fold02 35 optimal 1.32 8 1 rmse standard
#> 4 Fold02 35 optimal 1.32 8 1 rsq standard
#> 5 Fold03 35 optimal 1.32 8 1 rmse standard
#> 6 Fold03 35 optimal 1.32 8 1 rsq standard
#> 7 Fold04 35 optimal 1.32 8 1 rmse standard
#> 8 Fold04 35 optimal 1.32 8 1 rsq standard
#> 9 Fold05 35 optimal 1.32 8 1 rmse standard
#> 10 Fold05 35 optimal 1.32 8 1 rsq standard
#> # ℹ 190 more rows
#> # ℹ 2 more variables: .estimate <dbl>, .config <chr>
# ---------------------------------------------------------------------------
library(parsnip)
library(rsample)
library(dplyr)
#>
#> Attaching package: ‘dplyr’
#> The following objects are masked from ‘package:stats’:
#>
#> filter, lag
#> The following objects are masked from ‘package:base’:
#>
#> intersect, setdiff, setequal, union
library(recipes)
#>
#> Attaching package: ‘recipes’
#> The following object is masked from ‘package:stats’:
#>
#> step
library(tibble)
lm_mod <- linear_reg() %>% set_engine("lm")
set.seed(93599150)
car_folds <- vfold_cv(mtcars, v = 2, repeats = 3)
ctrl <- control_resamples(save_pred = TRUE, extract = extract_fit_engine)
spline_rec <-
recipe(mpg ~ ., data = mtcars) %>%
step_ns(disp, deg_free = tune("df"))
grid <- tibble(df = 3:6)
resampled <-
lm_mod %>%
tune_grid(spline_rec, resamples = car_folds, control = ctrl, grid = grid)
collect_predictions(resampled) %>% arrange(.row)
#> # A tibble: 384 × 7
#> id id2 .pred .row df mpg .config
#> <chr> <chr> <dbl> <int> <int> <dbl> <chr>
#> 1 Repeat1 Fold2 16.5 1 3 21 Preprocessor1_Model1
#> 2 Repeat2 Fold1 19.0 1 3 21 Preprocessor1_Model1
#> 3 Repeat3 Fold1 20.0 1 3 21 Preprocessor1_Model1
#> 4 Repeat1 Fold2 15.1 1 4 21 Preprocessor2_Model1
#> 5 Repeat2 Fold1 17.7 1 4 21 Preprocessor2_Model1
#> 6 Repeat3 Fold1 20.1 1 4 21 Preprocessor2_Model1
#> 7 Repeat1 Fold2 17.9 1 5 21 Preprocessor3_Model1
#> 8 Repeat2 Fold1 18.3 1 5 21 Preprocessor3_Model1
#> 9 Repeat3 Fold1 20.4 1 5 21 Preprocessor3_Model1
#> 10 Repeat1 Fold2 15.1 1 6 21 Preprocessor4_Model1
#> # ℹ 374 more rows
collect_predictions(resampled, summarize = TRUE) %>% arrange(.row)
#> # A tibble: 128 × 5
#> .row df mpg .config .pred
#> <int> <int> <dbl> <chr> <dbl>
#> 1 1 3 21 Preprocessor1_Model1 18.5
#> 2 1 4 21 Preprocessor2_Model1 17.6
#> 3 1 5 21 Preprocessor3_Model1 18.9
#> 4 1 6 21 Preprocessor4_Model1 16.7
#> 5 2 3 21 Preprocessor1_Model1 19.4
#> 6 2 4 21 Preprocessor2_Model1 19.0
#> 7 2 5 21 Preprocessor3_Model1 18.7
#> 8 2 6 21 Preprocessor4_Model1 16.4
#> 9 3 3 22.8 Preprocessor1_Model1 31.8
#> 10 3 4 22.8 Preprocessor2_Model1 23.8
#> # ℹ 118 more rows
collect_predictions(resampled, summarize = TRUE, grid[1, ]) %>% arrange(.row)
#> # A tibble: 32 × 5
#> .row df mpg .config .pred
#> <int> <int> <dbl> <chr> <dbl>
#> 1 1 3 21 Preprocessor1_Model1 18.5
#> 2 2 3 21 Preprocessor1_Model1 19.4
#> 3 3 3 22.8 Preprocessor1_Model1 31.8
#> 4 4 3 21.4 Preprocessor1_Model1 20.2
#> 5 5 3 18.7 Preprocessor1_Model1 18.4
#> 6 6 3 18.1 Preprocessor1_Model1 20.6
#> 7 7 3 14.3 Preprocessor1_Model1 13.5
#> 8 8 3 24.4 Preprocessor1_Model1 19.2
#> 9 9 3 22.8 Preprocessor1_Model1 34.8
#> 10 10 3 19.2 Preprocessor1_Model1 16.6
#> # ℹ 22 more rows
collect_extracts(resampled)
#> # A tibble: 24 × 5
#> id id2 df .extracts .config
#> <chr> <chr> <int> <list> <chr>
#> 1 Repeat1 Fold1 3 <lm> Preprocessor1_Model1
#> 2 Repeat1 Fold1 4 <lm> Preprocessor2_Model1
#> 3 Repeat1 Fold1 5 <lm> Preprocessor3_Model1
#> 4 Repeat1 Fold1 6 <lm> Preprocessor4_Model1
#> 5 Repeat1 Fold2 3 <lm> Preprocessor1_Model1
#> 6 Repeat1 Fold2 4 <lm> Preprocessor2_Model1
#> 7 Repeat1 Fold2 5 <lm> Preprocessor3_Model1
#> 8 Repeat1 Fold2 6 <lm> Preprocessor4_Model1
#> 9 Repeat2 Fold1 3 <lm> Preprocessor1_Model1
#> 10 Repeat2 Fold1 4 <lm> Preprocessor2_Model1
#> # ℹ 14 more rows
相關用法
- R tune coord_obs_pred 對觀察值與預測值的繪圖使用相同的比例
- R tune conf_mat_resampled 計算重采樣的平均混淆矩陣
- R tune extract-tune 提取調整對象的元素
- R tune filter_parameters 刪除一些調整參數結果
- R tune fit_best 將模型擬合到數值最優配置
- R tune finalize_model 將最終參數拚接到對象中
- R tune tune_bayes 模型參數的貝葉斯優化。
- R tune show_best 研究最佳調整參數
- R tune expo_decay 指數衰減函數
- R tune fit_resamples 通過重采樣擬合多個模型
- R tune merge.recipe 將參數網格值合並到對象中
- R tune autoplot.tune_results 繪圖調整搜索結果
- R tune tune_grid 通過網格搜索進行模型調整
- R tune dot-use_case_weights_with_yardstick 確定案例權重是否應傳遞至標準
- R tune message_wrap 寫一條尊重線寬的消息
- R tune prob_improve 用於對參數組合進行評分的獲取函數
- R tune last_fit 將最終的最佳模型擬合到訓練集並評估測試集
- R update_PACKAGES 更新現有的 PACKAGES 文件
- R textrecipes tokenlist 創建令牌對象
- R themis smotenc SMOTENC算法
- R print.via.format 打印實用程序
- R tibble tibble 構建 DataFrame 架
- R tidyr separate_rows 將折疊的列分成多行
- R textrecipes step_lemma 標記變量的詞形還原
- R textrecipes show_tokens 顯示配方的令牌輸出
注:本文由純淨天空篩選整理自Max Kuhn等大神的英文原創作品 Obtain and format results produced by tuning functions。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。