augment()
将为给定数据添加预测列。
参数
- x
-
由
fit.model_spec()
或fit_xy.model_spec()
生成的model_fit
对象。 - new_data
-
DataFrame 或矩阵。
- eval_time
-
对于审查回归模型,估计生存概率的时间点向量。
- ...
-
目前未使用。
细节
回归
对于回归模型,添加 .pred
列。如果 x
是使用 fit.model_spec()
创建的,并且 new_data
包含回归结果列,则还会添加 .resid
列。
删失回归
对于这些模型,将创建对预期时间和生存概率的预测(如果模型引擎支持它们)。如果模型支持生存预测,则需要 eval_time
参数。
如果创建了生存预测并且new_data
包含一个survival::Surv()
对象,还添加了额外的列以创建审查权重的逆概率(IPCW)(请参阅tidymodels.org
以下参考文献中的页面)。这使得用户能够计算性能指标尺度包。
例子
car_trn <- mtcars[11:32,]
car_tst <- mtcars[ 1:10,]
reg_form <-
linear_reg() %>%
set_engine("lm") %>%
fit(mpg ~ ., data = car_trn)
reg_xy <-
linear_reg() %>%
set_engine("lm") %>%
fit_xy(car_trn[, -1], car_trn$mpg)
augment(reg_form, car_tst)
#> # A tibble: 10 × 13
#> .pred .resid mpg cyl disp hp drat wt qsec vs am
#> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
#> 1 23.4 -2.43 21 6 160 110 3.9 2.62 16.5 0 1
#> 2 23.3 -2.30 21 6 160 110 3.9 2.88 17.0 0 1
#> 3 27.6 -4.83 22.8 4 108 93 3.85 2.32 18.6 1 1
#> 4 21.5 -0.147 21.4 6 258 110 3.08 3.22 19.4 1 0
#> 5 17.6 1.13 18.7 8 360 175 3.15 3.44 17.0 0 0
#> 6 21.6 -3.48 18.1 6 225 105 2.76 3.46 20.2 1 0
#> 7 13.9 0.393 14.3 8 360 245 3.21 3.57 15.8 0 0
#> 8 21.7 2.70 24.4 4 147. 62 3.69 3.19 20 1 0
#> 9 25.6 -2.81 22.8 4 141. 95 3.92 3.15 22.9 1 0
#> 10 17.1 2.09 19.2 6 168. 123 3.92 3.44 18.3 1 0
#> # ℹ 2 more variables: gear <dbl>, carb <dbl>
augment(reg_form, car_tst[, -1])
#> # A tibble: 10 × 11
#> .pred cyl disp hp drat wt qsec vs am gear carb
#> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
#> 1 23.4 6 160 110 3.9 2.62 16.5 0 1 4 4
#> 2 23.3 6 160 110 3.9 2.88 17.0 0 1 4 4
#> 3 27.6 4 108 93 3.85 2.32 18.6 1 1 4 1
#> 4 21.5 6 258 110 3.08 3.22 19.4 1 0 3 1
#> 5 17.6 8 360 175 3.15 3.44 17.0 0 0 3 2
#> 6 21.6 6 225 105 2.76 3.46 20.2 1 0 3 1
#> 7 13.9 8 360 245 3.21 3.57 15.8 0 0 3 4
#> 8 21.7 4 147. 62 3.69 3.19 20 1 0 4 2
#> 9 25.6 4 141. 95 3.92 3.15 22.9 1 0 4 2
#> 10 17.1 6 168. 123 3.92 3.44 18.3 1 0 4 4
augment(reg_xy, car_tst)
#> # A tibble: 10 × 12
#> .pred mpg cyl disp hp drat wt qsec vs am gear carb
#> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
#> 1 23.4 21 6 160 110 3.9 2.62 16.5 0 1 4 4
#> 2 23.3 21 6 160 110 3.9 2.88 17.0 0 1 4 4
#> 3 27.6 22.8 4 108 93 3.85 2.32 18.6 1 1 4 1
#> 4 21.5 21.4 6 258 110 3.08 3.22 19.4 1 0 3 1
#> 5 17.6 18.7 8 360 175 3.15 3.44 17.0 0 0 3 2
#> 6 21.6 18.1 6 225 105 2.76 3.46 20.2 1 0 3 1
#> 7 13.9 14.3 8 360 245 3.21 3.57 15.8 0 0 3 4
#> 8 21.7 24.4 4 147. 62 3.69 3.19 20 1 0 4 2
#> 9 25.6 22.8 4 141. 95 3.92 3.15 22.9 1 0 4 2
#> 10 17.1 19.2 6 168. 123 3.92 3.44 18.3 1 0 4 4
augment(reg_xy, car_tst[, -1])
#> # A tibble: 10 × 11
#> .pred cyl disp hp drat wt qsec vs am gear carb
#> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
#> 1 23.4 6 160 110 3.9 2.62 16.5 0 1 4 4
#> 2 23.3 6 160 110 3.9 2.88 17.0 0 1 4 4
#> 3 27.6 4 108 93 3.85 2.32 18.6 1 1 4 1
#> 4 21.5 6 258 110 3.08 3.22 19.4 1 0 3 1
#> 5 17.6 8 360 175 3.15 3.44 17.0 0 0 3 2
#> 6 21.6 6 225 105 2.76 3.46 20.2 1 0 3 1
#> 7 13.9 8 360 245 3.21 3.57 15.8 0 0 3 4
#> 8 21.7 4 147. 62 3.69 3.19 20 1 0 4 2
#> 9 25.6 4 141. 95 3.92 3.15 22.9 1 0 4 2
#> 10 17.1 6 168. 123 3.92 3.44 18.3 1 0 4 4
# ------------------------------------------------------------------------------
data(two_class_dat, package = "modeldata")
cls_trn <- two_class_dat[-(1:10), ]
cls_tst <- two_class_dat[ 1:10 , ]
cls_form <-
logistic_reg() %>%
set_engine("glm") %>%
fit(Class ~ ., data = cls_trn)
cls_xy <-
logistic_reg() %>%
set_engine("glm") %>%
fit_xy(cls_trn[, -3],
cls_trn$Class)
augment(cls_form, cls_tst)
#> # A tibble: 10 × 6
#> .pred_class .pred_Class1 .pred_Class2 A B Class
#> <fct> <dbl> <dbl> <dbl> <dbl> <fct>
#> 1 Class1 0.518 0.482 2.07 1.63 Class1
#> 2 Class1 0.909 0.0913 2.02 1.04 Class1
#> 3 Class1 0.648 0.352 1.69 1.37 Class2
#> 4 Class1 0.610 0.390 3.43 1.98 Class2
#> 5 Class2 0.443 0.557 2.88 1.98 Class1
#> 6 Class2 0.206 0.794 3.31 2.41 Class2
#> 7 Class1 0.708 0.292 2.50 1.56 Class2
#> 8 Class1 0.567 0.433 1.98 1.55 Class2
#> 9 Class1 0.994 0.00582 2.88 0.580 Class1
#> 10 Class2 0.108 0.892 3.74 2.74 Class2
augment(cls_form, cls_tst[, -3])
#> # A tibble: 10 × 5
#> .pred_class .pred_Class1 .pred_Class2 A B
#> <fct> <dbl> <dbl> <dbl> <dbl>
#> 1 Class1 0.518 0.482 2.07 1.63
#> 2 Class1 0.909 0.0913 2.02 1.04
#> 3 Class1 0.648 0.352 1.69 1.37
#> 4 Class1 0.610 0.390 3.43 1.98
#> 5 Class2 0.443 0.557 2.88 1.98
#> 6 Class2 0.206 0.794 3.31 2.41
#> 7 Class1 0.708 0.292 2.50 1.56
#> 8 Class1 0.567 0.433 1.98 1.55
#> 9 Class1 0.994 0.00582 2.88 0.580
#> 10 Class2 0.108 0.892 3.74 2.74
augment(cls_xy, cls_tst)
#> # A tibble: 10 × 6
#> .pred_class .pred_Class1 .pred_Class2 A B Class
#> <fct> <dbl> <dbl> <dbl> <dbl> <fct>
#> 1 Class1 0.518 0.482 2.07 1.63 Class1
#> 2 Class1 0.909 0.0913 2.02 1.04 Class1
#> 3 Class1 0.648 0.352 1.69 1.37 Class2
#> 4 Class1 0.610 0.390 3.43 1.98 Class2
#> 5 Class2 0.443 0.557 2.88 1.98 Class1
#> 6 Class2 0.206 0.794 3.31 2.41 Class2
#> 7 Class1 0.708 0.292 2.50 1.56 Class2
#> 8 Class1 0.567 0.433 1.98 1.55 Class2
#> 9 Class1 0.994 0.00582 2.88 0.580 Class1
#> 10 Class2 0.108 0.892 3.74 2.74 Class2
augment(cls_xy, cls_tst[, -3])
#> # A tibble: 10 × 5
#> .pred_class .pred_Class1 .pred_Class2 A B
#> <fct> <dbl> <dbl> <dbl> <dbl>
#> 1 Class1 0.518 0.482 2.07 1.63
#> 2 Class1 0.909 0.0913 2.02 1.04
#> 3 Class1 0.648 0.352 1.69 1.37
#> 4 Class1 0.610 0.390 3.43 1.98
#> 5 Class2 0.443 0.557 2.88 1.98
#> 6 Class2 0.206 0.794 3.31 2.41
#> 7 Class1 0.708 0.292 2.50 1.56
#> 8 Class1 0.567 0.433 1.98 1.55
#> 9 Class1 0.994 0.00582 2.88 0.580
#> 10 Class2 0.108 0.892 3.74 2.74
相关用法
- R parsnip add_rowindex 将一列行号添加到 DataFrame
- R parsnip logistic_reg 逻辑回归
- R parsnip predict.model_fit 模型预测
- R parsnip linear_reg 线性回归
- R parsnip C5_rules C5.0 基于规则的分类模型
- R parsnip set_engine 声明计算引擎和特定参数
- R parsnip condense_control 将控制对象压缩为更小的控制对象
- R parsnip control_parsnip 控制拟合函数
- R parsnip repair_call 修复模型调用对象
- R parsnip dot-model_param_name_key 翻译模型调整参数的名称
- R parsnip glm_grouped 将数据集中的分组二项式结果与个案权重拟合
- R parsnip rule_fit 规则拟合模型
- R parsnip svm_rbf 径向基函数支持向量机
- R parsnip set_args 更改模型规范的元素
- R parsnip translate 解决计算引擎的模型规范
- R parsnip max_mtry_formula 根据公式确定 mtry 的最大值。此函数可能会根据公式和数据集限制 mtry 的值。对于生存和/或多变量模型来说,这是一种安全的方法。
- R parsnip svm_linear 线性支持向量机
- R parsnip set_new_model 注册模型的工具
- R parsnip rand_forest 随机森林
- R parsnip mlp 单层神经网络
- R parsnip nearest_neighbor K-最近邻
- R parsnip parsnip_update 更新型号规格
- R parsnip fit 将模型规范拟合到数据集
- R parsnip boost_tree 增强树
- R parsnip bart 贝叶斯加性回归树 (BART)
注:本文由纯净天空筛选整理自Max Kuhn等大神的英文原创作品 Augment data with predictions。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。