当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


R stacks predict.model_stack 使用模型堆栈进行预测


应用模型堆栈来创建不同类型的预测。

用法

# S3 method for model_stack
predict(object, new_data, type = NULL, members = FALSE, opts = list(), ...)

参数

object

fit_members() 输出的具有拟合成员的模型堆栈。

new_data

矩形数据对象,例如 DataFrame 。

type

返回预测值的格式 — "numeric"、"class" 或 "prob" 之一。当为NULL时,predict()将根据模型的模式选择合适的值。

members

逻辑性强。是否另外返回每个集合成员的预测。

opts

传递给每个成员的 parsnip::predict.model_fit 的基础预测函数的可选参数列表。

...

附加参数。目前被忽略。

示例数据

该软件包提供了一些重采样对象和数据集,用于源自对 1212 个red-eyed 树蛙胚胎的研究的示例和小插图!

如果 Red-eyed 树蛙 (RETF) 胚胎检测到潜在的捕食者威胁,它们的孵化时间可能会比正常情况下的 7 天更早。研究人员想要确定这些树蛙胚胎如何以及何时能够检测到来自环境的刺激。为此,他们通过用钝探针摇动胚胎,对不同发育阶段的胚胎进行"predator stimulus"测试。尽管一些胚胎事先接受了庆大霉素处理,庆大霉素是一种可以消除侧线(感觉器官)的化合物。研究员朱莉·荣格(Julie Jung)和她的团队发现,这些因子决定了胚胎是否过早孵化!

请注意,stacks 包中包含的数据不一定是完整数据集的代表性或无偏差子集,并且仅用于演示目的。

reg_foldsclass_folds 是来自 rsamplerset 交叉验证对象,分别将训练数据分为回归模型对象和分类模型对象。 tree_frogs_reg_testtree_frogs_class_test 是类似的测试集。

reg_res_lrreg_res_svmreg_res_sp 分别包含线性回归、支持向量机和样条模型的回归调整结果,拟合 latency(即胚胎响应抖动需要多长时间孵化)在 tree_frogs 数据中,使用大多数其他变量作为预测变量。请注意,这些模型背后的数据经过过滤,仅包含来自响应刺激而孵化的胚胎的数据。

class_res_rfclass_res_nn 分别包含随机森林和神经网络分类模型的多类分类调整结果,使用大多数其他变量作为预测变量在数据中拟合 reflex(耳朵函数的度量)。

log_res_rflog_res_nn 分别包含随机森林和神经网络分类模型的二元分类调整结果,使用大多数其他变量拟合 hatched(无论胚胎是否响应刺激而孵化)预测因子。

请参阅?example_data 了解有关这些对象的更多信息,并浏览生成它们的源代码。

例子


# see the "Example Data" section above for
# clarification on the data and tuning results
# objects used in these examples!

data(tree_frogs_reg_test)
data(tree_frogs_class_test)

# build and fit a regression model stack
reg_st <-
  stacks() %>%
  add_candidates(reg_res_lr) %>%
  add_candidates(reg_res_sp) %>%
  blend_predictions() %>%
  fit_members()

reg_st
#> ── A stacked ensemble model ─────────────────────────────────────
#> 
#> Out of 11 possible candidate members, the ensemble retained 4.
#> Penalty: 1e-06.
#> Mixture: 1.
#> 
#> The 4 highest weighted members are:
#> # A tibble: 4 × 3
#>   member          type       weight
#>   <chr>           <chr>       <dbl>
#> 1 reg_res_sp_03_1 linear_reg 0.485 
#> 2 reg_res_sp_10_1 linear_reg 0.247 
#> 3 reg_res_lr_1_1  linear_reg 0.129 
#> 4 reg_res_sp_05_1 linear_reg 0.0666

# predict on the tree frogs testing data
predict(reg_st, tree_frogs_reg_test)
#> # A tibble: 143 × 1
#>    .pred
#>    <dbl>
#>  1 119. 
#>  2  81.4
#>  3 102. 
#>  4  35.5
#>  5 119. 
#>  6  50.5
#>  7 122. 
#>  8  82.7
#>  9  50.2
#> 10  75.7
#> # ℹ 133 more rows

# include the predictions from the members
predict(reg_st, tree_frogs_reg_test, members = TRUE)
#> # A tibble: 143 × 5
#>    .pred reg_res_lr_1_1 reg_res_sp_10_1 reg_res_sp_05_1 reg_res_sp_03_1
#>    <dbl>          <dbl>           <dbl>           <dbl>           <dbl>
#>  1 119.           138.            125.            121.            114. 
#>  2  81.4           82.4            84.8            81.8            77.1
#>  3 102.           116.            111.            112.             93.3
#>  4  35.5           35.8            29.7            32.5            29.6
#>  5 119.           111.            115.            115.            127. 
#>  6  50.5           38.8            37.4            36.2            55.3
#>  7 122.           123.            103.            104.            137. 
#>  8  82.7           82.3            78.6            82.0            82.8
#>  9  50.2           38.7            37.3            36.2            54.8
#> 10  75.7           78.8            75.3            76.9            71.8
#> # ℹ 133 more rows

# build and fit a classification model stack
class_st <-
  stacks() %>%
  add_candidates(class_res_nn) %>%
  add_candidates(class_res_rf) %>%
  blend_predictions() %>%
  fit_members()
#> Warning: Predictions from 1 candidate were identical to those from existing
#> candidates and were removed from the data stack.
 
class_st
#> ── A stacked ensemble model ─────────────────────────────────────
#> 
#> Out of 21 possible candidate members, the ensemble retained 5.
#> Penalty: 0.1.
#> Mixture: 1.
#> Across the 3 classes, there are an average of 2.5 coefficients per class.
#> 
#> The 5 highest weighted member classes are:
#> # A tibble: 5 × 4
#>   member                       type          weight class
#>   <chr>                        <chr>          <dbl> <fct>
#> 1 .pred_full_class_res_nn_1_1  mlp         12.0     full 
#> 2 .pred_mid_class_res_rf_1_06  rand_forest  0.670   mid  
#> 3 .pred_full_class_res_rf_1_05 rand_forest  0.101   full 
#> 4 .pred_full_class_res_rf_1_07 rand_forest  0.00457 full 
#> 5 .pred_full_class_res_rf_1_01 rand_forest  0.00219 full 

# predict reflex, first as a class, then as
# class probabilities
predict(class_st, tree_frogs_class_test)
#> # A tibble: 303 × 1
#>    .pred_class
#>    <fct>      
#>  1 full       
#>  2 mid        
#>  3 mid        
#>  4 mid        
#>  5 full       
#>  6 full       
#>  7 full       
#>  8 full       
#>  9 full       
#> 10 full       
#> # ℹ 293 more rows
predict(class_st, tree_frogs_class_test, type = "prob")
#> # A tibble: 303 × 3
#>    .pred_full .pred_low .pred_mid
#>         <dbl>     <dbl>     <dbl>
#>  1     0.909     0.0540    0.0371
#>  2     0.0992    0.435     0.466 
#>  3     0.0956    0.418     0.486 
#>  4     0.104     0.432     0.464 
#>  5     0.908     0.0542    0.0373
#>  6     0.909     0.0540    0.0372
#>  7     0.909     0.0540    0.0372
#>  8     0.909     0.0540    0.0371
#>  9     0.909     0.0540    0.0371
#> 10     0.909     0.0540    0.0371
#> # ℹ 293 more rows

# returning the member predictions as well
predict(
  class_st, 
  tree_frogs_class_test, 
  type = "prob", 
  members = TRUE
)
#> # A tibble: 303 × 18
#>    .pred_full .pred_low .pred_mid .pred_low_class_res_rf_1_06
#>         <dbl>     <dbl>     <dbl>                       <dbl>
#>  1     0.909     0.0540    0.0371                       0    
#>  2     0.0992    0.435     0.466                        0.337
#>  3     0.0956    0.418     0.486                        0.216
#>  4     0.104     0.432     0.464                        0.334
#>  5     0.908     0.0542    0.0373                       0    
#>  6     0.909     0.0540    0.0372                       0    
#>  7     0.909     0.0540    0.0372                       0    
#>  8     0.909     0.0540    0.0371                       0    
#>  9     0.909     0.0540    0.0371                       0    
#> 10     0.909     0.0540    0.0371                       0    
#> # ℹ 293 more rows
#> # ℹ 14 more variables: .pred_low_class_res_nn_1_1 <dbl>,
#> #   .pred_low_class_res_rf_1_05 <dbl>, .pred_low_class_res_rf_1_01 <dbl>,
#> #   .pred_low_class_res_rf_1_07 <dbl>, .pred_mid_class_res_rf_1_06 <dbl>,
#> #   .pred_mid_class_res_nn_1_1 <dbl>, .pred_mid_class_res_rf_1_05 <dbl>,
#> #   .pred_mid_class_res_rf_1_01 <dbl>, .pred_mid_class_res_rf_1_07 <dbl>,
#> #   .pred_full_class_res_rf_1_06 <dbl>, …
源代码:R/predict.R

相关用法


注:本文由纯净天空筛选整理自Max Kuhn等大神的英文原创作品 Predicting with a model stack。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。