當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


R parsnip predict.model_fit 模型預測


應用模型來創建不同類型的預測。 predict() 可用於所有類型的模型,並使用 "type" 參數以獲得更多特異性。

用法

# S3 method for model_fit
predict(object, new_data, type = NULL, opts = list(), ...)

# S3 method for model_fit
predict_raw(object, new_data, opts = list(), ...)

predict_raw(object, ...)

參數

object

model_fit 的對象。

new_data

矩形數據對象,例如 DataFrame 。

type

單個字符值或 NULL 。可能的值為 "numeric" , "class" , "prob" , "conf_int" , "pred_int" , "quantile" , "time" , "hazard" , "survival""raw" 。當 NULL 時,predict() 會根據模型的模式選擇合適的值。

opts

type = "raw" 時將使用的基礎預測函數的可選參數列表。該列表不應包含模型對象或正在預測的新數據的選項。

...

其他 parsnip 相關選項,具體取決於 type 的值。無法在此處傳遞底層模型預測函數的參數(請改用opts 參數)。可能的論點是:

  • interval :對於 type 等於 "survival""quantile" ,是否應該添加間隔估計(如果有)?選項是 "none""confidence"

  • level :對於 type 等於 "conf_int""pred_int""survival" ,這是間隔尾部區域的參數(例如置信區間的置信水平)。默認值為0.95

  • std_error :對於 type 等於 "conf_int""pred_int" ,添加擬合或預測的標準誤差(在線性預測變量的範圍內)。默認值為FALSE

  • quantile :對於 type 等於 quantile ,分布的分位數。默認為 (1:9)/10

  • eval_time :對於 type 等於 "survival""hazard" ,估計生存概率或危險的時間點。

除了 type = "raw" 之外,predict.model_fit() 的結果

  • 是一個小詞

  • 行數與 new_data 中的行數一樣多

  • 具有標準化的列名稱,如下所示:

對於 type = "numeric" ,tibble 具有用於單個結果的 .pred 列和用於多變量結果的 .pred_Yname 列。

對於 type = "class" ,tibble 有一個 .pred_class 列。

對於 type = "prob" ,標題具有 .pred_classlevel 列。

對於 type = "conf_int"type = "pred_int" ,tibble 具有帶有置信度屬性的 .pred_lower.pred_upper 列。在可以為類概率(或其他非標量輸出)生成間隔的情況下,列被命名為 .pred_lower_classlevel 等。

對於 type = "quantile" ,tibble 有一個 .pred 列,它是一個列表列。每個列表元素都包含一個帶有列 .pred.quantile(可能還有其他列)的 tibble。

對於 type = "time" ,tibble 有一個 .pred_time 列。

對於 type = "survival" ,tibble 有一個 .pred 列,它是一個列表列。每個列表元素都包含一個帶有列 .eval_time.pred_survival(可能還有其他列)的 tibble。

對於 type = "hazard" ,tibble 有一個 .pred 列,它是一個列表列。每個列表元素都包含一個帶有列 .eval_time.pred_hazard(可能還有其他列)的 tibble。

type = "raw"predict.model_fit() 結合使用將返回預測函數的純正結果。

對於基於 Spark 的模型,由於表列不能包含點,因此使用相同的約定,除了 1) 名稱中不出現點和 2) 永遠不會返回向量,但返回 type-specific 預測函數。

當模型擬合失敗並捕獲錯誤時,predict() 函數將返回與上述相同的結構,但填充缺失值。目前這不適用於多變量模型。

細節

對於 type = NULLpredict() 使用

  • type = "numeric" 用於回歸模型,

  • type = "class" 用於分類,以及

  • type = "time" 用於審查回歸。

區間預測

使用type = "conf_int"type = "pred_int"時,可以使用選項levelstd_error。後者是標準錯誤值的額外列(如果可用)的邏輯。

刪失回歸預測

對於審查回歸,當請求生存或危險概率時,需要 eval_time 的數值向量。時間值必須是唯一的、有限的、非缺失的且非負的。 predict() 函數將通過刪除違規點(帶有警告)來調整值以適應此規範。

predict.model_fit() 不要求存在結果。對於預測生存概率的性能指標,需要審查權重的逆概率 (IPCW)(請參閱下麵的 tidymodels.org 參考)。這些需要結果,因此 predict() 不會返回。如果 new_data 包含結果為 Surv 對象的列,則可以通過 augment.model_fit() 添加它們。

此外,當 type = "linear_pred" 時,截尾回歸模型將默認格式化,使得線性預測變量隨時間增加。這可能與底層模型的 predict() 方法產生的符號相反。設置increasing = FALSE 以抑製此行為。

例子

library(dplyr)

lm_model <-
  linear_reg() %>%
  set_engine("lm") %>%
  fit(mpg ~ ., data = mtcars %>% dplyr::slice(11:32))

pred_cars <-
  mtcars %>%
  dplyr::slice(1:10) %>%
  dplyr::select(-mpg)

predict(lm_model, pred_cars)
#> # A tibble: 10 × 1
#>    .pred
#>    <dbl>
#>  1  23.4
#>  2  23.3
#>  3  27.6
#>  4  21.5
#>  5  17.6
#>  6  21.6
#>  7  13.9
#>  8  21.7
#>  9  25.6
#> 10  17.1

predict(
  lm_model,
  pred_cars,
  type = "conf_int",
  level = 0.90
)
#> # A tibble: 10 × 2
#>    .pred_lower .pred_upper
#>          <dbl>       <dbl>
#>  1       17.9         29.0
#>  2       18.1         28.5
#>  3       24.0         31.3
#>  4       17.5         25.6
#>  5       14.3         20.8
#>  6       17.0         26.2
#>  7        9.65        18.2
#>  8       16.2         27.2
#>  9       14.2         37.0
#> 10       11.5         22.7

predict(
  lm_model,
  pred_cars,
  type = "raw",
  opts = list(type = "terms")
)
#>                            cyl       disp         hp        drat
#> Mazda RX4         -0.001433177 -0.8113275  0.6303467 -0.06120265
#> Mazda RX4 Wag     -0.001433177 -0.8113275  0.6303467 -0.06120265
#> Datsun 710        -0.009315653 -1.3336453  0.8557288 -0.05014798
#> Hornet 4 Drive    -0.001433177  0.1730406  0.6303467  0.12009386
#> Hornet Sportabout  0.006449298  1.1975870 -0.2314083  0.10461733
#> Valiant           -0.001433177 -0.1584303  0.6966356  0.19084372
#> Duster 360         0.006449298  1.1975870 -1.1594522  0.09135173
#> Merc 240D         -0.009315653 -0.9449204  1.2667197 -0.01477305
#> Merc 230          -0.009315653 -1.0041833  0.8292133 -0.06562451
#> Merc 280          -0.001433177 -0.7349888  0.4579957 -0.06562451
#>                           wt      qsec         vs       am        gear
#> Mazda RX4          2.4139815 -1.567729  0.2006406  2.88774  0.02512680
#> Mazda RX4 Wag      1.4488706 -0.736286  0.2006406  2.88774  0.02512680
#> Datsun 710         3.5494061  1.624418 -0.3511210  2.88774  0.02512680
#> Hornet 4 Drive     0.1620561  2.856736 -0.3511210 -2.40645 -0.06700481
#> Hornet Sportabout -0.6895124 -0.736286  0.2006406 -2.40645 -0.06700481
#> Valiant           -0.7652074  4.014817 -0.3511210 -2.40645 -0.06700481
#> Duster 360        -1.1815297 -2.488255  0.2006406 -2.40645 -0.06700481
#> Merc 240D          0.2566748  3.688179 -0.3511210 -2.40645  0.02512680
#> Merc 230           0.4080647  7.993866 -0.3511210 -2.40645  0.02512680
#> Merc 280          -0.6895124  1.164155 -0.3511210 -2.40645  0.02512680
#>                         carb
#> Mazda RX4         -0.2497240
#> Mazda RX4 Wag     -0.2497240
#> Datsun 710         0.4668753
#> Hornet 4 Drive     0.4668753
#> Hornet Sportabout  0.2280089
#> Valiant            0.4668753
#> Duster 360        -0.2497240
#> Merc 240D          0.2280089
#> Merc 230           0.2280089
#> Merc 280          -0.2497240
#> attr(,"constant")
#> [1] 19.96364

相關用法


注:本文由純淨天空篩選整理自Max Kuhn等大神的英文原創作品 Model predictions。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。