augment()
將為給定數據添加預測列。
參數
- x
-
由
fit.model_spec()
或fit_xy.model_spec()
生成的model_fit
對象。 - new_data
-
DataFrame 或矩陣。
- eval_time
-
對於審查回歸模型,估計生存概率的時間點向量。
- ...
-
目前未使用。
細節
回歸
對於回歸模型,添加 .pred
列。如果 x
是使用 fit.model_spec()
創建的,並且 new_data
包含回歸結果列,則還會添加 .resid
列。
刪失回歸
對於這些模型,將創建對預期時間和生存概率的預測(如果模型引擎支持它們)。如果模型支持生存預測,則需要 eval_time
參數。
如果創建了生存預測並且new_data
包含一個survival::Surv()
對象,還添加了額外的列以創建審查權重的逆概率(IPCW)(請參閱tidymodels.org
以下參考文獻中的頁麵)。這使得用戶能夠計算性能指標尺度包。
例子
car_trn <- mtcars[11:32,]
car_tst <- mtcars[ 1:10,]
reg_form <-
linear_reg() %>%
set_engine("lm") %>%
fit(mpg ~ ., data = car_trn)
reg_xy <-
linear_reg() %>%
set_engine("lm") %>%
fit_xy(car_trn[, -1], car_trn$mpg)
augment(reg_form, car_tst)
#> # A tibble: 10 × 13
#> .pred .resid mpg cyl disp hp drat wt qsec vs am
#> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
#> 1 23.4 -2.43 21 6 160 110 3.9 2.62 16.5 0 1
#> 2 23.3 -2.30 21 6 160 110 3.9 2.88 17.0 0 1
#> 3 27.6 -4.83 22.8 4 108 93 3.85 2.32 18.6 1 1
#> 4 21.5 -0.147 21.4 6 258 110 3.08 3.22 19.4 1 0
#> 5 17.6 1.13 18.7 8 360 175 3.15 3.44 17.0 0 0
#> 6 21.6 -3.48 18.1 6 225 105 2.76 3.46 20.2 1 0
#> 7 13.9 0.393 14.3 8 360 245 3.21 3.57 15.8 0 0
#> 8 21.7 2.70 24.4 4 147. 62 3.69 3.19 20 1 0
#> 9 25.6 -2.81 22.8 4 141. 95 3.92 3.15 22.9 1 0
#> 10 17.1 2.09 19.2 6 168. 123 3.92 3.44 18.3 1 0
#> # ℹ 2 more variables: gear <dbl>, carb <dbl>
augment(reg_form, car_tst[, -1])
#> # A tibble: 10 × 11
#> .pred cyl disp hp drat wt qsec vs am gear carb
#> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
#> 1 23.4 6 160 110 3.9 2.62 16.5 0 1 4 4
#> 2 23.3 6 160 110 3.9 2.88 17.0 0 1 4 4
#> 3 27.6 4 108 93 3.85 2.32 18.6 1 1 4 1
#> 4 21.5 6 258 110 3.08 3.22 19.4 1 0 3 1
#> 5 17.6 8 360 175 3.15 3.44 17.0 0 0 3 2
#> 6 21.6 6 225 105 2.76 3.46 20.2 1 0 3 1
#> 7 13.9 8 360 245 3.21 3.57 15.8 0 0 3 4
#> 8 21.7 4 147. 62 3.69 3.19 20 1 0 4 2
#> 9 25.6 4 141. 95 3.92 3.15 22.9 1 0 4 2
#> 10 17.1 6 168. 123 3.92 3.44 18.3 1 0 4 4
augment(reg_xy, car_tst)
#> # A tibble: 10 × 12
#> .pred mpg cyl disp hp drat wt qsec vs am gear carb
#> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
#> 1 23.4 21 6 160 110 3.9 2.62 16.5 0 1 4 4
#> 2 23.3 21 6 160 110 3.9 2.88 17.0 0 1 4 4
#> 3 27.6 22.8 4 108 93 3.85 2.32 18.6 1 1 4 1
#> 4 21.5 21.4 6 258 110 3.08 3.22 19.4 1 0 3 1
#> 5 17.6 18.7 8 360 175 3.15 3.44 17.0 0 0 3 2
#> 6 21.6 18.1 6 225 105 2.76 3.46 20.2 1 0 3 1
#> 7 13.9 14.3 8 360 245 3.21 3.57 15.8 0 0 3 4
#> 8 21.7 24.4 4 147. 62 3.69 3.19 20 1 0 4 2
#> 9 25.6 22.8 4 141. 95 3.92 3.15 22.9 1 0 4 2
#> 10 17.1 19.2 6 168. 123 3.92 3.44 18.3 1 0 4 4
augment(reg_xy, car_tst[, -1])
#> # A tibble: 10 × 11
#> .pred cyl disp hp drat wt qsec vs am gear carb
#> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
#> 1 23.4 6 160 110 3.9 2.62 16.5 0 1 4 4
#> 2 23.3 6 160 110 3.9 2.88 17.0 0 1 4 4
#> 3 27.6 4 108 93 3.85 2.32 18.6 1 1 4 1
#> 4 21.5 6 258 110 3.08 3.22 19.4 1 0 3 1
#> 5 17.6 8 360 175 3.15 3.44 17.0 0 0 3 2
#> 6 21.6 6 225 105 2.76 3.46 20.2 1 0 3 1
#> 7 13.9 8 360 245 3.21 3.57 15.8 0 0 3 4
#> 8 21.7 4 147. 62 3.69 3.19 20 1 0 4 2
#> 9 25.6 4 141. 95 3.92 3.15 22.9 1 0 4 2
#> 10 17.1 6 168. 123 3.92 3.44 18.3 1 0 4 4
# ------------------------------------------------------------------------------
data(two_class_dat, package = "modeldata")
cls_trn <- two_class_dat[-(1:10), ]
cls_tst <- two_class_dat[ 1:10 , ]
cls_form <-
logistic_reg() %>%
set_engine("glm") %>%
fit(Class ~ ., data = cls_trn)
cls_xy <-
logistic_reg() %>%
set_engine("glm") %>%
fit_xy(cls_trn[, -3],
cls_trn$Class)
augment(cls_form, cls_tst)
#> # A tibble: 10 × 6
#> .pred_class .pred_Class1 .pred_Class2 A B Class
#> <fct> <dbl> <dbl> <dbl> <dbl> <fct>
#> 1 Class1 0.518 0.482 2.07 1.63 Class1
#> 2 Class1 0.909 0.0913 2.02 1.04 Class1
#> 3 Class1 0.648 0.352 1.69 1.37 Class2
#> 4 Class1 0.610 0.390 3.43 1.98 Class2
#> 5 Class2 0.443 0.557 2.88 1.98 Class1
#> 6 Class2 0.206 0.794 3.31 2.41 Class2
#> 7 Class1 0.708 0.292 2.50 1.56 Class2
#> 8 Class1 0.567 0.433 1.98 1.55 Class2
#> 9 Class1 0.994 0.00582 2.88 0.580 Class1
#> 10 Class2 0.108 0.892 3.74 2.74 Class2
augment(cls_form, cls_tst[, -3])
#> # A tibble: 10 × 5
#> .pred_class .pred_Class1 .pred_Class2 A B
#> <fct> <dbl> <dbl> <dbl> <dbl>
#> 1 Class1 0.518 0.482 2.07 1.63
#> 2 Class1 0.909 0.0913 2.02 1.04
#> 3 Class1 0.648 0.352 1.69 1.37
#> 4 Class1 0.610 0.390 3.43 1.98
#> 5 Class2 0.443 0.557 2.88 1.98
#> 6 Class2 0.206 0.794 3.31 2.41
#> 7 Class1 0.708 0.292 2.50 1.56
#> 8 Class1 0.567 0.433 1.98 1.55
#> 9 Class1 0.994 0.00582 2.88 0.580
#> 10 Class2 0.108 0.892 3.74 2.74
augment(cls_xy, cls_tst)
#> # A tibble: 10 × 6
#> .pred_class .pred_Class1 .pred_Class2 A B Class
#> <fct> <dbl> <dbl> <dbl> <dbl> <fct>
#> 1 Class1 0.518 0.482 2.07 1.63 Class1
#> 2 Class1 0.909 0.0913 2.02 1.04 Class1
#> 3 Class1 0.648 0.352 1.69 1.37 Class2
#> 4 Class1 0.610 0.390 3.43 1.98 Class2
#> 5 Class2 0.443 0.557 2.88 1.98 Class1
#> 6 Class2 0.206 0.794 3.31 2.41 Class2
#> 7 Class1 0.708 0.292 2.50 1.56 Class2
#> 8 Class1 0.567 0.433 1.98 1.55 Class2
#> 9 Class1 0.994 0.00582 2.88 0.580 Class1
#> 10 Class2 0.108 0.892 3.74 2.74 Class2
augment(cls_xy, cls_tst[, -3])
#> # A tibble: 10 × 5
#> .pred_class .pred_Class1 .pred_Class2 A B
#> <fct> <dbl> <dbl> <dbl> <dbl>
#> 1 Class1 0.518 0.482 2.07 1.63
#> 2 Class1 0.909 0.0913 2.02 1.04
#> 3 Class1 0.648 0.352 1.69 1.37
#> 4 Class1 0.610 0.390 3.43 1.98
#> 5 Class2 0.443 0.557 2.88 1.98
#> 6 Class2 0.206 0.794 3.31 2.41
#> 7 Class1 0.708 0.292 2.50 1.56
#> 8 Class1 0.567 0.433 1.98 1.55
#> 9 Class1 0.994 0.00582 2.88 0.580
#> 10 Class2 0.108 0.892 3.74 2.74
相關用法
- R parsnip add_rowindex 將一列行號添加到 DataFrame
- R parsnip logistic_reg 邏輯回歸
- R parsnip predict.model_fit 模型預測
- R parsnip linear_reg 線性回歸
- R parsnip C5_rules C5.0 基於規則的分類模型
- R parsnip set_engine 聲明計算引擎和特定參數
- R parsnip condense_control 將控製對象壓縮為更小的控製對象
- R parsnip control_parsnip 控製擬合函數
- R parsnip repair_call 修複模型調用對象
- R parsnip dot-model_param_name_key 翻譯模型調整參數的名稱
- R parsnip glm_grouped 將數據集中的分組二項式結果與個案權重擬合
- R parsnip rule_fit 規則擬合模型
- R parsnip svm_rbf 徑向基函數支持向量機
- R parsnip set_args 更改模型規範的元素
- R parsnip translate 解決計算引擎的模型規範
- R parsnip max_mtry_formula 根據公式確定 mtry 的最大值。此函數可能會根據公式和數據集限製 mtry 的值。對於生存和/或多變量模型來說,這是一種安全的方法。
- R parsnip svm_linear 線性支持向量機
- R parsnip set_new_model 注冊模型的工具
- R parsnip rand_forest 隨機森林
- R parsnip mlp 單層神經網絡
- R parsnip nearest_neighbor K-最近鄰
- R parsnip parsnip_update 更新型號規格
- R parsnip fit 將模型規範擬合到數據集
- R parsnip boost_tree 增強樹
- R parsnip bart 貝葉斯加性回歸樹 (BART)
注:本文由純淨天空篩選整理自Max Kuhn等大神的英文原創作品 Augment data with predictions。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。