通过对每个候选成员的评估预测拟合正则化模型来评估数据堆栈,以预测真实结果。
此过程确定模型堆栈的"stacking coefficients"。堆叠系数用于对每个候选者的预测进行加权(由数据堆栈中的唯一列表示),并由 LASSO 模型的 beta 给出,该模型将真实结果与数据堆栈的其余列中给出的预测进行拟合。
具有非零叠加系数的候选者是模型堆栈成员,需要使用 fit_members()
在完整训练集(而不仅仅是评估集)上进行训练。此函数通常在多次调用 add_candidates()
之后使用。
用法
blend_predictions(
data_stack,
penalty = 10^(-6:-1),
mixture = 1,
non_negative = TRUE,
metric = NULL,
control = tune::control_grid(),
times = 25,
...
)
参数
- data_stack
-
data_stack
对象 - penalty
-
成员加权中使用的正则化总量的建议值的数值向量。较高的惩罚通常会导致生成的模型堆栈中包含较少的成员,反之亦然。该包将调整由
penalty
和mixture
参数的叉积形成的网格。 - mixture
-
0 到 1(含)之间的数字,给出模型中 L1 正则化(即 lasso)的比例。
mixture = 1
表示纯套索模型,mixture = 0
表示岭回归,(0, 1)
中的值表示弹性网络。该包将调整由penalty
和mixture
参数的叉积形成的网格。 - non_negative
-
逻辑给出是否将堆叠系数限制为非负值。如果
TRUE
(默认),则在数据堆栈上拟合模型时将 0 作为lower.limits
参数传递给glmnet::glmnet()
。否则,-Inf
。 - metric
-
对
yardstick::metric_set()
的调用。用于调整堆叠系数的套索惩罚的度量。默认值由结果类中的tune::tune_grid()
确定。 - control
-
继承自
control_grid
的对象,将传递给确定堆叠系数的模型。有关可能值的详细信息,请参阅tune::control_grid()
文档。请注意,任何extract
条目都将在内部被覆盖。 - times
-
由确定堆叠系数的模型调整的引导样本数量。请参阅
rsample::bootstraps()
了解更多信息。 - ...
-
附加参数。目前被忽略。
细节
请注意,正则化线性模型是可用于拟合堆叠集成模型的许多可能的学习算法之一。有关其他集成学习算法的实现,请参阅 h2o::h2o.stackedEnsemble()
和 SuperLearner::SuperLearner()
。
示例数据
该软件包提供了一些重采样对象和数据集,用于源自对 1212 个red-eyed 树蛙胚胎的研究的示例和小插图!
如果 Red-eyed 树蛙 (RETF) 胚胎检测到潜在的捕食者威胁,它们的孵化时间可能会比正常情况下的 7 天更早。研究人员想要确定这些树蛙胚胎如何以及何时能够检测到来自环境的刺激。为此,他们通过用钝探针摇动胚胎,对不同发育阶段的胚胎进行"predator stimulus"测试。尽管一些胚胎事先接受了庆大霉素处理,庆大霉素是一种可以消除侧线(感觉器官)的化合物。研究员朱莉·荣格(Julie Jung)和她的团队发现,这些因子决定了胚胎是否过早孵化!
请注意,stacks 包中包含的数据不一定是完整数据集的代表性或无偏差子集,并且仅用于演示目的。
reg_folds
和 class_folds
是来自 rsample
的 rset
交叉验证对象,分别将训练数据分为回归模型对象和分类模型对象。 tree_frogs_reg_test
和tree_frogs_class_test
是类似的测试集。
reg_res_lr
、reg_res_svm
和 reg_res_sp
分别包含线性回归、支持向量机和样条模型的回归调整结果,拟合 latency
(即胚胎响应抖动需要多长时间孵化)在 tree_frogs
数据中,使用大多数其他变量作为预测变量。请注意,这些模型背后的数据经过过滤,仅包含来自响应刺激而孵化的胚胎的数据。
class_res_rf
和 class_res_nn
分别包含随机森林和神经网络分类模型的多类分类调整结果,使用大多数其他变量作为预测变量在数据中拟合 reflex
(耳朵函数的度量)。
log_res_rf
和 log_res_nn
分别包含随机森林和神经网络分类模型的二元分类调整结果,使用大多数其他变量拟合 hatched
(无论胚胎是否响应刺激而孵化)预测因子。
请参阅?example_data
了解有关这些对象的更多信息,并浏览生成它们的源代码。
也可以看看
其他核心动词:add_candidates()
、fit_members()
、stacks()
例子
# see the "Example Data" section above for
# clarification on the objects used in these examples!
# put together a data stack
reg_st <-
stacks() %>%
add_candidates(reg_res_lr) %>%
add_candidates(reg_res_svm) %>%
add_candidates(reg_res_sp)
reg_st
#> # A data stack with 3 model definitions and 16 candidate members:
#> # reg_res_lr: 1 model configuration
#> # reg_res_svm: 5 model configurations
#> # reg_res_sp: 10 model configurations
#> # Outcome: latency (numeric)
# evaluate the data stack
reg_st %>%
blend_predictions()
#> ── A stacked ensemble model ─────────────────────────────────────
#>
#> Out of 16 possible candidate members, the ensemble retained 3.
#> Penalty: 1e-06.
#> Mixture: 1.
#>
#> The 3 highest weighted members are:
#> # A tibble: 3 × 3
#> member type weight
#> <chr> <chr> <dbl>
#> 1 reg_res_svm_1_3 svm_rbf 0.638
#> 2 reg_res_sp_03_1 linear_reg 0.486
#> 3 reg_res_sp_10_1 linear_reg 0.0482
#>
#> Members have not yet been fitted with `fit_members()`.
# include fewer models by proposing higher penalties
reg_st %>%
blend_predictions(penalty = c(.5, 1))
#> ── A stacked ensemble model ─────────────────────────────────────
#>
#> Out of 16 possible candidate members, the ensemble retained 3.
#> Penalty: 0.5.
#> Mixture: 1.
#>
#> The 3 highest weighted members are:
#> # A tibble: 3 × 3
#> member type weight
#> <chr> <chr> <dbl>
#> 1 reg_res_svm_1_3 svm_rbf 0.629
#> 2 reg_res_sp_03_1 linear_reg 0.478
#> 3 reg_res_sp_10_1 linear_reg 0.0515
#>
#> Members have not yet been fitted with `fit_members()`.
# allow for negative stacking coefficients
# with the non_negative argument
reg_st %>%
blend_predictions(non_negative = FALSE)
#> ── A stacked ensemble model ─────────────────────────────────────
#>
#> Out of 16 possible candidate members, the ensemble retained 12.
#> Penalty: 0.1.
#> Mixture: 1.
#>
#> The 10 highest weighted members are:
#> # A tibble: 10 × 3
#> member type weight
#> <chr> <chr> <dbl>
#> 1 reg_res_svm_1_1 svm_rbf -10.5
#> 2 reg_res_sp_04_1 linear_reg -1.38
#> 3 reg_res_sp_05_1 linear_reg 1.35
#> 4 reg_res_svm_1_3 svm_rbf 1.19
#> 5 reg_res_svm_1_2 svm_rbf -0.963
#> 6 reg_res_sp_03_1 linear_reg 0.642
#> 7 reg_res_sp_01_1 linear_reg -0.400
#> 8 reg_res_sp_10_1 linear_reg 0.319
#> 9 reg_res_sp_06_1 linear_reg 0.193
#> 10 reg_res_lr_1_1 linear_reg 0.183
#>
#> Members have not yet been fitted with `fit_members()`.
# use a custom metric in tuning the lasso penalty
library(yardstick)
#> For binary classification, the first factor level is assumed to be the event.
#> Use the argument `event_level = "second"` to alter this as needed.
reg_st %>%
blend_predictions(metric = metric_set(rmse))
#> ── A stacked ensemble model ─────────────────────────────────────
#>
#> Out of 16 possible candidate members, the ensemble retained 3.
#> Penalty: 0.1.
#> Mixture: 1.
#>
#> The 3 highest weighted members are:
#> # A tibble: 3 × 3
#> member type weight
#> <chr> <chr> <dbl>
#> 1 reg_res_svm_1_3 svm_rbf 0.636
#> 2 reg_res_sp_03_1 linear_reg 0.484
#> 3 reg_res_sp_10_1 linear_reg 0.0496
#>
#> Members have not yet been fitted with `fit_members()`.
# pass control options for stack blending
reg_st %>%
blend_predictions(
control = tune::control_grid(allow_par = TRUE)
)
#> ── A stacked ensemble model ─────────────────────────────────────
#>
#> Out of 16 possible candidate members, the ensemble retained 3.
#> Penalty: 0.1.
#> Mixture: 1.
#>
#> The 3 highest weighted members are:
#> # A tibble: 3 × 3
#> member type weight
#> <chr> <chr> <dbl>
#> 1 reg_res_svm_1_3 svm_rbf 0.636
#> 2 reg_res_sp_03_1 linear_reg 0.484
#> 3 reg_res_sp_10_1 linear_reg 0.0496
#>
#> Members have not yet been fitted with `fit_members()`.
# to speed up the stacking process for preliminary
# results, bump down the `times` argument:
reg_st %>%
blend_predictions(times = 5)
#> ── A stacked ensemble model ─────────────────────────────────────
#>
#> Out of 16 possible candidate members, the ensemble retained 3.
#> Penalty: 1e-06.
#> Mixture: 1.
#>
#> The 3 highest weighted members are:
#> # A tibble: 3 × 3
#> member type weight
#> <chr> <chr> <dbl>
#> 1 reg_res_svm_1_3 svm_rbf 0.638
#> 2 reg_res_sp_03_1 linear_reg 0.486
#> 3 reg_res_sp_10_1 linear_reg 0.0482
#>
#> Members have not yet been fitted with `fit_members()`.
# the process looks the same with
# multinomial classification models
class_st <-
stacks() %>%
add_candidates(class_res_nn) %>%
add_candidates(class_res_rf) %>%
blend_predictions()
#> Warning: Predictions from 1 candidate were identical to those from existing
#> candidates and were removed from the data stack.
class_st
#> ── A stacked ensemble model ─────────────────────────────────────
#>
#> Out of 21 possible candidate members, the ensemble retained 8.
#> Penalty: 0.01.
#> Mixture: 1.
#> Across the 3 classes, there are an average of 4 coefficients per class.
#>
#> The 8 highest weighted member classes are:
#> # A tibble: 8 × 4
#> member type weight class
#> <chr> <chr> <dbl> <fct>
#> 1 .pred_full_class_res_nn_1_1 mlp 23.3 full
#> 2 .pred_mid_class_res_nn_1_1 mlp 1.89 mid
#> 3 .pred_mid_class_res_rf_1_06 rand_forest 1.71 mid
#> 4 .pred_mid_class_res_rf_1_10 rand_forest 1.17 mid
#> 5 .pred_full_class_res_rf_1_03 rand_forest 0.407 full
#> 6 .pred_full_class_res_rf_1_05 rand_forest 0.222 full
#> 7 .pred_full_class_res_rf_1_01 rand_forest 0.00160 full
#> 8 .pred_full_class_res_rf_1_02 rand_forest 0.000322 full
#>
#> Members have not yet been fitted with `fit_members()`.
# ...or binomial classification models
log_st <-
stacks() %>%
add_candidates(log_res_nn) %>%
add_candidates(log_res_rf) %>%
blend_predictions()
log_st
#> ── A stacked ensemble model ─────────────────────────────────────
#>
#> Out of 11 possible candidate members, the ensemble retained 2.
#> Penalty: 0.01.
#> Mixture: 1.
#>
#> The 2 highest weighted member classes are:
#> # A tibble: 2 × 3
#> member type weight
#> <chr> <chr> <dbl>
#> 1 .pred_no_log_res_nn_1_1 mlp 7.08
#> 2 .pred_no_log_res_rf_1_05 rand_forest 3.10
#>
#> Members have not yet been fitted with `fit_members()`.
相关用法
- R stacks axe_model_stack 砍掉 model_stack。
- R stacks predict.model_stack 使用模型堆栈进行预测
- R stacks add_candidates 将模型定义添加到数据堆栈
- R stacks fit_members 拟合具有非零堆叠系数的模型堆叠成员
- R stacks collect_parameters 收集候选参数和叠加系数
- R stlmethods STL 对象的方法
- R medpolish 矩阵的中值波兰(稳健双向分解)
- R naprint 调整缺失值
- R summary.nls 总结非线性最小二乘模型拟合
- R summary.manova 多元方差分析的汇总方法
- R formula 模型公式
- R nls.control 控制 nls 中的迭代
- R aggregate 计算数据子集的汇总统计
- R deriv 简单表达式的符号和算法导数
- R kruskal.test Kruskal-Wallis 秩和检验
- R quade.test 四方测试
- R decompose 移动平均线的经典季节性分解
- R profile-methods stats4 包中的函数配置文件方法
- R plot.stepfun 绘制阶跃函数
- R alias 查找模型中的别名(依赖项)
- R qqnorm 分位数-分位数图
- R update-methods stats4包中函数更新的方法
- R eff.aovlist 多层方差分析的计算效率
- R pairwise.t.test 成对 t 检验
- R loglin 拟合对数线性模型
注:本文由纯净天空筛选整理自Max Kuhn等大神的英文原创作品 Determine stacking coefficients from a data stack。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。