通過對每個候選成員的評估預測擬合正則化模型來評估數據堆棧,以預測真實結果。
此過程確定模型堆棧的"stacking coefficients"。堆疊係數用於對每個候選者的預測進行加權(由數據堆棧中的唯一列表示),並由 LASSO 模型的 beta 給出,該模型將真實結果與數據堆棧的其餘列中給出的預測進行擬合。
具有非零疊加係數的候選者是模型堆棧成員,需要使用 fit_members()
在完整訓練集(而不僅僅是評估集)上進行訓練。此函數通常在多次調用 add_candidates()
之後使用。
用法
blend_predictions(
data_stack,
penalty = 10^(-6:-1),
mixture = 1,
non_negative = TRUE,
metric = NULL,
control = tune::control_grid(),
times = 25,
...
)
參數
- data_stack
-
data_stack
對象 - penalty
-
成員加權中使用的正則化總量的建議值的數值向量。較高的懲罰通常會導致生成的模型堆棧中包含較少的成員,反之亦然。該包將調整由
penalty
和mixture
參數的叉積形成的網格。 - mixture
-
0 到 1(含)之間的數字,給出模型中 L1 正則化(即 lasso)的比例。
mixture = 1
表示純套索模型,mixture = 0
表示嶺回歸,(0, 1)
中的值表示彈性網絡。該包將調整由penalty
和mixture
參數的叉積形成的網格。 - non_negative
-
邏輯給出是否將堆疊係數限製為非負值。如果
TRUE
(默認),則在數據堆棧上擬合模型時將 0 作為lower.limits
參數傳遞給glmnet::glmnet()
。否則,-Inf
。 - metric
-
對
yardstick::metric_set()
的調用。用於調整堆疊係數的套索懲罰的度量。默認值由結果類中的tune::tune_grid()
確定。 - control
-
繼承自
control_grid
的對象,將傳遞給確定堆疊係數的模型。有關可能值的詳細信息,請參閱tune::control_grid()
文檔。請注意,任何extract
條目都將在內部被覆蓋。 - times
-
由確定堆疊係數的模型調整的引導樣本數量。請參閱
rsample::bootstraps()
了解更多信息。 - ...
-
附加參數。目前被忽略。
細節
請注意,正則化線性模型是可用於擬合堆疊集成模型的許多可能的學習算法之一。有關其他集成學習算法的實現,請參閱 h2o::h2o.stackedEnsemble()
和 SuperLearner::SuperLearner()
。
示例數據
該軟件包提供了一些重采樣對象和數據集,用於源自對 1212 個red-eyed 樹蛙胚胎的研究的示例和小插圖!
如果 Red-eyed 樹蛙 (RETF) 胚胎檢測到潛在的捕食者威脅,它們的孵化時間可能會比正常情況下的 7 天更早。研究人員想要確定這些樹蛙胚胎如何以及何時能夠檢測到來自環境的刺激。為此,他們通過用鈍探針搖動胚胎,對不同發育階段的胚胎進行"predator stimulus"測試。盡管一些胚胎事先接受了慶大黴素處理,慶大黴素是一種可以消除側線(感覺器官)的化合物。研究員朱莉·榮格(Julie Jung)和她的團隊發現,這些因子決定了胚胎是否過早孵化!
請注意,stacks 包中包含的數據不一定是完整數據集的代表性或無偏差子集,並且僅用於演示目的。
reg_folds
和 class_folds
是來自 rsample
的 rset
交叉驗證對象,分別將訓練數據分為回歸模型對象和分類模型對象。 tree_frogs_reg_test
和tree_frogs_class_test
是類似的測試集。
reg_res_lr
、reg_res_svm
和 reg_res_sp
分別包含線性回歸、支持向量機和樣條模型的回歸調整結果,擬合 latency
(即胚胎響應抖動需要多長時間孵化)在 tree_frogs
數據中,使用大多數其他變量作為預測變量。請注意,這些模型背後的數據經過過濾,僅包含來自響應刺激而孵化的胚胎的數據。
class_res_rf
和 class_res_nn
分別包含隨機森林和神經網絡分類模型的多類分類調整結果,使用大多數其他變量作為預測變量在數據中擬合 reflex
(耳朵函數的度量)。
log_res_rf
和 log_res_nn
分別包含隨機森林和神經網絡分類模型的二元分類調整結果,使用大多數其他變量擬合 hatched
(無論胚胎是否響應刺激而孵化)預測因子。
請參閱?example_data
了解有關這些對象的更多信息,並瀏覽生成它們的源代碼。
也可以看看
其他核心動詞:add_candidates()
、fit_members()
、stacks()
例子
# see the "Example Data" section above for
# clarification on the objects used in these examples!
# put together a data stack
reg_st <-
stacks() %>%
add_candidates(reg_res_lr) %>%
add_candidates(reg_res_svm) %>%
add_candidates(reg_res_sp)
reg_st
#> # A data stack with 3 model definitions and 16 candidate members:
#> # reg_res_lr: 1 model configuration
#> # reg_res_svm: 5 model configurations
#> # reg_res_sp: 10 model configurations
#> # Outcome: latency (numeric)
# evaluate the data stack
reg_st %>%
blend_predictions()
#> ── A stacked ensemble model ─────────────────────────────────────
#>
#> Out of 16 possible candidate members, the ensemble retained 3.
#> Penalty: 1e-06.
#> Mixture: 1.
#>
#> The 3 highest weighted members are:
#> # A tibble: 3 × 3
#> member type weight
#> <chr> <chr> <dbl>
#> 1 reg_res_svm_1_3 svm_rbf 0.638
#> 2 reg_res_sp_03_1 linear_reg 0.486
#> 3 reg_res_sp_10_1 linear_reg 0.0482
#>
#> Members have not yet been fitted with `fit_members()`.
# include fewer models by proposing higher penalties
reg_st %>%
blend_predictions(penalty = c(.5, 1))
#> ── A stacked ensemble model ─────────────────────────────────────
#>
#> Out of 16 possible candidate members, the ensemble retained 3.
#> Penalty: 0.5.
#> Mixture: 1.
#>
#> The 3 highest weighted members are:
#> # A tibble: 3 × 3
#> member type weight
#> <chr> <chr> <dbl>
#> 1 reg_res_svm_1_3 svm_rbf 0.629
#> 2 reg_res_sp_03_1 linear_reg 0.478
#> 3 reg_res_sp_10_1 linear_reg 0.0515
#>
#> Members have not yet been fitted with `fit_members()`.
# allow for negative stacking coefficients
# with the non_negative argument
reg_st %>%
blend_predictions(non_negative = FALSE)
#> ── A stacked ensemble model ─────────────────────────────────────
#>
#> Out of 16 possible candidate members, the ensemble retained 12.
#> Penalty: 0.1.
#> Mixture: 1.
#>
#> The 10 highest weighted members are:
#> # A tibble: 10 × 3
#> member type weight
#> <chr> <chr> <dbl>
#> 1 reg_res_svm_1_1 svm_rbf -10.5
#> 2 reg_res_sp_04_1 linear_reg -1.38
#> 3 reg_res_sp_05_1 linear_reg 1.35
#> 4 reg_res_svm_1_3 svm_rbf 1.19
#> 5 reg_res_svm_1_2 svm_rbf -0.963
#> 6 reg_res_sp_03_1 linear_reg 0.642
#> 7 reg_res_sp_01_1 linear_reg -0.400
#> 8 reg_res_sp_10_1 linear_reg 0.319
#> 9 reg_res_sp_06_1 linear_reg 0.193
#> 10 reg_res_lr_1_1 linear_reg 0.183
#>
#> Members have not yet been fitted with `fit_members()`.
# use a custom metric in tuning the lasso penalty
library(yardstick)
#> For binary classification, the first factor level is assumed to be the event.
#> Use the argument `event_level = "second"` to alter this as needed.
reg_st %>%
blend_predictions(metric = metric_set(rmse))
#> ── A stacked ensemble model ─────────────────────────────────────
#>
#> Out of 16 possible candidate members, the ensemble retained 3.
#> Penalty: 0.1.
#> Mixture: 1.
#>
#> The 3 highest weighted members are:
#> # A tibble: 3 × 3
#> member type weight
#> <chr> <chr> <dbl>
#> 1 reg_res_svm_1_3 svm_rbf 0.636
#> 2 reg_res_sp_03_1 linear_reg 0.484
#> 3 reg_res_sp_10_1 linear_reg 0.0496
#>
#> Members have not yet been fitted with `fit_members()`.
# pass control options for stack blending
reg_st %>%
blend_predictions(
control = tune::control_grid(allow_par = TRUE)
)
#> ── A stacked ensemble model ─────────────────────────────────────
#>
#> Out of 16 possible candidate members, the ensemble retained 3.
#> Penalty: 0.1.
#> Mixture: 1.
#>
#> The 3 highest weighted members are:
#> # A tibble: 3 × 3
#> member type weight
#> <chr> <chr> <dbl>
#> 1 reg_res_svm_1_3 svm_rbf 0.636
#> 2 reg_res_sp_03_1 linear_reg 0.484
#> 3 reg_res_sp_10_1 linear_reg 0.0496
#>
#> Members have not yet been fitted with `fit_members()`.
# to speed up the stacking process for preliminary
# results, bump down the `times` argument:
reg_st %>%
blend_predictions(times = 5)
#> ── A stacked ensemble model ─────────────────────────────────────
#>
#> Out of 16 possible candidate members, the ensemble retained 3.
#> Penalty: 1e-06.
#> Mixture: 1.
#>
#> The 3 highest weighted members are:
#> # A tibble: 3 × 3
#> member type weight
#> <chr> <chr> <dbl>
#> 1 reg_res_svm_1_3 svm_rbf 0.638
#> 2 reg_res_sp_03_1 linear_reg 0.486
#> 3 reg_res_sp_10_1 linear_reg 0.0482
#>
#> Members have not yet been fitted with `fit_members()`.
# the process looks the same with
# multinomial classification models
class_st <-
stacks() %>%
add_candidates(class_res_nn) %>%
add_candidates(class_res_rf) %>%
blend_predictions()
#> Warning: Predictions from 1 candidate were identical to those from existing
#> candidates and were removed from the data stack.
class_st
#> ── A stacked ensemble model ─────────────────────────────────────
#>
#> Out of 21 possible candidate members, the ensemble retained 8.
#> Penalty: 0.01.
#> Mixture: 1.
#> Across the 3 classes, there are an average of 4 coefficients per class.
#>
#> The 8 highest weighted member classes are:
#> # A tibble: 8 × 4
#> member type weight class
#> <chr> <chr> <dbl> <fct>
#> 1 .pred_full_class_res_nn_1_1 mlp 23.3 full
#> 2 .pred_mid_class_res_nn_1_1 mlp 1.89 mid
#> 3 .pred_mid_class_res_rf_1_06 rand_forest 1.71 mid
#> 4 .pred_mid_class_res_rf_1_10 rand_forest 1.17 mid
#> 5 .pred_full_class_res_rf_1_03 rand_forest 0.407 full
#> 6 .pred_full_class_res_rf_1_05 rand_forest 0.222 full
#> 7 .pred_full_class_res_rf_1_01 rand_forest 0.00160 full
#> 8 .pred_full_class_res_rf_1_02 rand_forest 0.000322 full
#>
#> Members have not yet been fitted with `fit_members()`.
# ...or binomial classification models
log_st <-
stacks() %>%
add_candidates(log_res_nn) %>%
add_candidates(log_res_rf) %>%
blend_predictions()
log_st
#> ── A stacked ensemble model ─────────────────────────────────────
#>
#> Out of 11 possible candidate members, the ensemble retained 2.
#> Penalty: 0.01.
#> Mixture: 1.
#>
#> The 2 highest weighted member classes are:
#> # A tibble: 2 × 3
#> member type weight
#> <chr> <chr> <dbl>
#> 1 .pred_no_log_res_nn_1_1 mlp 7.08
#> 2 .pred_no_log_res_rf_1_05 rand_forest 3.10
#>
#> Members have not yet been fitted with `fit_members()`.
相關用法
- R stacks axe_model_stack 砍掉 model_stack。
- R stacks predict.model_stack 使用模型堆棧進行預測
- R stacks add_candidates 將模型定義添加到數據堆棧
- R stacks fit_members 擬合具有非零堆疊係數的模型堆疊成員
- R stacks collect_parameters 收集候選參數和疊加係數
- R stlmethods STL 對象的方法
- R medpolish 矩陣的中值波蘭(穩健雙向分解)
- R naprint 調整缺失值
- R summary.nls 總結非線性最小二乘模型擬合
- R summary.manova 多元方差分析的匯總方法
- R formula 模型公式
- R nls.control 控製 nls 中的迭代
- R aggregate 計算數據子集的匯總統計
- R deriv 簡單表達式的符號和算法導數
- R kruskal.test Kruskal-Wallis 秩和檢驗
- R quade.test 四方測試
- R decompose 移動平均線的經典季節性分解
- R profile-methods stats4 包中的函數配置文件方法
- R plot.stepfun 繪製階躍函數
- R alias 查找模型中的別名(依賴項)
- R qqnorm 分位數-分位數圖
- R update-methods stats4包中函數更新的方法
- R eff.aovlist 多層方差分析的計算效率
- R pairwise.t.test 成對 t 檢驗
- R loglin 擬合對數線性模型
注:本文由純淨天空篩選整理自Max Kuhn等大神的英文原創作品 Determine stacking coefficients from a data stack。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。