應用模型來創建不同類型的預測。 predict()
可用於所有類型的模型,並使用 "type" 參數以獲得更多特異性。
用法
# S3 method for model_fit
predict(object, new_data, type = NULL, opts = list(), ...)
# S3 method for model_fit
predict_raw(object, new_data, opts = list(), ...)
predict_raw(object, ...)
參數
- object
-
類
model_fit
的對象。 - new_data
-
矩形數據對象,例如 DataFrame 。
- type
-
單個字符值或
NULL
。可能的值為"numeric"
,"class"
,"prob"
,"conf_int"
,"pred_int"
,"quantile"
,"time"
,"hazard"
,"survival"
或"raw"
。當NULL
時,predict()
會根據模型的模式選擇合適的值。 - opts
-
type = "raw"
時將使用的基礎預測函數的可選參數列表。該列表不應包含模型對象或正在預測的新數據的選項。 - ...
-
其他
parsnip
相關選項,具體取決於type
的值。無法在此處傳遞底層模型預測函數的參數(請改用opts
參數)。可能的論點是:-
interval
:對於type
等於"survival"
或"quantile"
,是否應該添加間隔估計(如果有)?選項是"none"
和"confidence"
。 -
level
:對於type
等於"conf_int"
、"pred_int"
或"survival"
,這是間隔尾部區域的參數(例如置信區間的置信水平)。默認值為0.95
。 -
std_error
:對於type
等於"conf_int"
或"pred_int"
,添加擬合或預測的標準誤差(在線性預測變量的範圍內)。默認值為FALSE
。 -
quantile
:對於type
等於quantile
,分布的分位數。默認為(1:9)/10
。 -
eval_time
:對於type
等於"survival"
或"hazard"
,估計生存概率或危險的時間點。
-
值
除了 type = "raw"
之外,predict.model_fit()
的結果
-
是一個小詞
-
行數與
new_data
中的行數一樣多 -
具有標準化的列名稱,如下所示:
對於 type = "numeric"
,tibble 具有用於單個結果的 .pred
列和用於多變量結果的 .pred_Yname
列。
對於 type = "class"
,tibble 有一個 .pred_class
列。
對於 type = "prob"
,標題具有 .pred_classlevel
列。
對於 type = "conf_int"
和 type = "pred_int"
,tibble 具有帶有置信度屬性的 .pred_lower
和 .pred_upper
列。在可以為類概率(或其他非標量輸出)生成間隔的情況下,列被命名為 .pred_lower_classlevel
等。
對於 type = "quantile"
,tibble 有一個 .pred
列,它是一個列表列。每個列表元素都包含一個帶有列 .pred
和 .quantile
(可能還有其他列)的 tibble。
對於 type = "time"
,tibble 有一個 .pred_time
列。
對於 type = "survival"
,tibble 有一個 .pred
列,它是一個列表列。每個列表元素都包含一個帶有列 .eval_time
和 .pred_survival
(可能還有其他列)的 tibble。
對於 type = "hazard"
,tibble 有一個 .pred
列,它是一個列表列。每個列表元素都包含一個帶有列 .eval_time
和 .pred_hazard
(可能還有其他列)的 tibble。
將 type = "raw"
與 predict.model_fit()
結合使用將返回預測函數的純正結果。
對於基於 Spark 的模型,由於表列不能包含點,因此使用相同的約定,除了 1) 名稱中不出現點和 2) 永遠不會返回向量,但返回 type-specific 預測函數。
當模型擬合失敗並捕獲錯誤時,predict()
函數將返回與上述相同的結構,但填充缺失值。目前這不適用於多變量模型。
細節
對於 type = NULL
, predict()
使用
-
type = "numeric"
用於回歸模型, -
type = "class"
用於分類,以及 -
type = "time"
用於審查回歸。
刪失回歸預測
對於審查回歸,當請求生存或危險概率時,需要 eval_time
的數值向量。時間值必須是唯一的、有限的、非缺失的且非負的。 predict()
函數將通過刪除違規點(帶有警告)來調整值以適應此規範。
predict.model_fit()
不要求存在結果。對於預測生存概率的性能指標,需要審查權重的逆概率 (IPCW)(請參閱下麵的 tidymodels.org
參考)。這些需要結果,因此 predict()
不會返回。如果 new_data
包含結果為 Surv
對象的列,則可以通過 augment.model_fit()
添加它們。
此外,當 type = "linear_pred"
時,截尾回歸模型將默認格式化,使得線性預測變量隨時間增加。這可能與底層模型的 predict()
方法產生的符號相反。設置increasing = FALSE
以抑製此行為。
例子
library(dplyr)
lm_model <-
linear_reg() %>%
set_engine("lm") %>%
fit(mpg ~ ., data = mtcars %>% dplyr::slice(11:32))
pred_cars <-
mtcars %>%
dplyr::slice(1:10) %>%
dplyr::select(-mpg)
predict(lm_model, pred_cars)
#> # A tibble: 10 × 1
#> .pred
#> <dbl>
#> 1 23.4
#> 2 23.3
#> 3 27.6
#> 4 21.5
#> 5 17.6
#> 6 21.6
#> 7 13.9
#> 8 21.7
#> 9 25.6
#> 10 17.1
predict(
lm_model,
pred_cars,
type = "conf_int",
level = 0.90
)
#> # A tibble: 10 × 2
#> .pred_lower .pred_upper
#> <dbl> <dbl>
#> 1 17.9 29.0
#> 2 18.1 28.5
#> 3 24.0 31.3
#> 4 17.5 25.6
#> 5 14.3 20.8
#> 6 17.0 26.2
#> 7 9.65 18.2
#> 8 16.2 27.2
#> 9 14.2 37.0
#> 10 11.5 22.7
predict(
lm_model,
pred_cars,
type = "raw",
opts = list(type = "terms")
)
#> cyl disp hp drat
#> Mazda RX4 -0.001433177 -0.8113275 0.6303467 -0.06120265
#> Mazda RX4 Wag -0.001433177 -0.8113275 0.6303467 -0.06120265
#> Datsun 710 -0.009315653 -1.3336453 0.8557288 -0.05014798
#> Hornet 4 Drive -0.001433177 0.1730406 0.6303467 0.12009386
#> Hornet Sportabout 0.006449298 1.1975870 -0.2314083 0.10461733
#> Valiant -0.001433177 -0.1584303 0.6966356 0.19084372
#> Duster 360 0.006449298 1.1975870 -1.1594522 0.09135173
#> Merc 240D -0.009315653 -0.9449204 1.2667197 -0.01477305
#> Merc 230 -0.009315653 -1.0041833 0.8292133 -0.06562451
#> Merc 280 -0.001433177 -0.7349888 0.4579957 -0.06562451
#> wt qsec vs am gear
#> Mazda RX4 2.4139815 -1.567729 0.2006406 2.88774 0.02512680
#> Mazda RX4 Wag 1.4488706 -0.736286 0.2006406 2.88774 0.02512680
#> Datsun 710 3.5494061 1.624418 -0.3511210 2.88774 0.02512680
#> Hornet 4 Drive 0.1620561 2.856736 -0.3511210 -2.40645 -0.06700481
#> Hornet Sportabout -0.6895124 -0.736286 0.2006406 -2.40645 -0.06700481
#> Valiant -0.7652074 4.014817 -0.3511210 -2.40645 -0.06700481
#> Duster 360 -1.1815297 -2.488255 0.2006406 -2.40645 -0.06700481
#> Merc 240D 0.2566748 3.688179 -0.3511210 -2.40645 0.02512680
#> Merc 230 0.4080647 7.993866 -0.3511210 -2.40645 0.02512680
#> Merc 280 -0.6895124 1.164155 -0.3511210 -2.40645 0.02512680
#> carb
#> Mazda RX4 -0.2497240
#> Mazda RX4 Wag -0.2497240
#> Datsun 710 0.4668753
#> Hornet 4 Drive 0.4668753
#> Hornet Sportabout 0.2280089
#> Valiant 0.4668753
#> Duster 360 -0.2497240
#> Merc 240D 0.2280089
#> Merc 230 0.2280089
#> Merc 280 -0.2497240
#> attr(,"constant")
#> [1] 19.96364
相關用法
- R parsnip proportional_hazards 比例風險回歸
- R parsnip parsnip_update 更新型號規格
- R parsnip logistic_reg 邏輯回歸
- R parsnip linear_reg 線性回歸
- R parsnip C5_rules C5.0 基於規則的分類模型
- R parsnip set_engine 聲明計算引擎和特定參數
- R parsnip condense_control 將控製對象壓縮為更小的控製對象
- R parsnip control_parsnip 控製擬合函數
- R parsnip augment 通過預測增強數據
- R parsnip repair_call 修複模型調用對象
- R parsnip dot-model_param_name_key 翻譯模型調整參數的名稱
- R parsnip glm_grouped 將數據集中的分組二項式結果與個案權重擬合
- R parsnip rule_fit 規則擬合模型
- R parsnip svm_rbf 徑向基函數支持向量機
- R parsnip set_args 更改模型規範的元素
- R parsnip translate 解決計算引擎的模型規範
- R parsnip max_mtry_formula 根據公式確定 mtry 的最大值。此函數可能會根據公式和數據集限製 mtry 的值。對於生存和/或多變量模型來說,這是一種安全的方法。
- R parsnip svm_linear 線性支持向量機
- R parsnip set_new_model 注冊模型的工具
- R parsnip rand_forest 隨機森林
- R parsnip mlp 單層神經網絡
- R parsnip nearest_neighbor K-最近鄰
- R parsnip fit 將模型規範擬合到數據集
- R parsnip boost_tree 增強樹
- R parsnip bart 貝葉斯加性回歸樹 (BART)
注:本文由純淨天空篩選整理自Max Kuhn等大神的英文原創作品 Model predictions。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。