當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


R parsnip augment 通過預測增強數據


augment() 將為給定數據添加預測列。

用法

# S3 method for model_fit
augment(x, new_data, eval_time = NULL, ...)

參數

x

fit.model_spec()fit_xy.model_spec() 生成的 model_fit 對象。

new_data

DataFrame 或矩陣。

eval_time

對於審查回歸模型,估計生存概率的時間點向量。

...

目前未使用。

細節

回歸

對於回歸模型,添加 .pred 列。如果 x 是使用 fit.model_spec() 創建的,並且 new_data 包含回歸結果列,則還會添加 .resid 列。

分類

對於分類模型,結果可以包括名為 .pred_class 的列以及名為 .pred_{level} 的類概率列。這取決於模型可用的預測類型。

刪失回歸

對於這些模型,將創建對預期時間和生存概率的預測(如果模型引擎支持它們)。如果模型支持生存預測,則需要 eval_time 參數。

如果創建了生存預測並且new_data包含一個survival::Surv()對象,還添加了額外的列以創建審查權重的逆概率(IPCW)(請參閱tidymodels.org以下參考文獻中的頁麵)。這使得用戶能夠計算性能指標尺度包。

例子

car_trn <- mtcars[11:32,]
car_tst <- mtcars[ 1:10,]

reg_form <-
  linear_reg() %>%
  set_engine("lm") %>%
  fit(mpg ~ ., data = car_trn)
reg_xy <-
  linear_reg() %>%
  set_engine("lm") %>%
  fit_xy(car_trn[, -1], car_trn$mpg)

augment(reg_form, car_tst)
#> # A tibble: 10 × 13
#>    .pred .resid   mpg   cyl  disp    hp  drat    wt  qsec    vs    am
#>    <dbl>  <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
#>  1  23.4 -2.43   21       6  160    110  3.9   2.62  16.5     0     1
#>  2  23.3 -2.30   21       6  160    110  3.9   2.88  17.0     0     1
#>  3  27.6 -4.83   22.8     4  108     93  3.85  2.32  18.6     1     1
#>  4  21.5 -0.147  21.4     6  258    110  3.08  3.22  19.4     1     0
#>  5  17.6  1.13   18.7     8  360    175  3.15  3.44  17.0     0     0
#>  6  21.6 -3.48   18.1     6  225    105  2.76  3.46  20.2     1     0
#>  7  13.9  0.393  14.3     8  360    245  3.21  3.57  15.8     0     0
#>  8  21.7  2.70   24.4     4  147.    62  3.69  3.19  20       1     0
#>  9  25.6 -2.81   22.8     4  141.    95  3.92  3.15  22.9     1     0
#> 10  17.1  2.09   19.2     6  168.   123  3.92  3.44  18.3     1     0
#> # ℹ 2 more variables: gear <dbl>, carb <dbl>
augment(reg_form, car_tst[, -1])
#> # A tibble: 10 × 11
#>    .pred   cyl  disp    hp  drat    wt  qsec    vs    am  gear  carb
#>    <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
#>  1  23.4     6  160    110  3.9   2.62  16.5     0     1     4     4
#>  2  23.3     6  160    110  3.9   2.88  17.0     0     1     4     4
#>  3  27.6     4  108     93  3.85  2.32  18.6     1     1     4     1
#>  4  21.5     6  258    110  3.08  3.22  19.4     1     0     3     1
#>  5  17.6     8  360    175  3.15  3.44  17.0     0     0     3     2
#>  6  21.6     6  225    105  2.76  3.46  20.2     1     0     3     1
#>  7  13.9     8  360    245  3.21  3.57  15.8     0     0     3     4
#>  8  21.7     4  147.    62  3.69  3.19  20       1     0     4     2
#>  9  25.6     4  141.    95  3.92  3.15  22.9     1     0     4     2
#> 10  17.1     6  168.   123  3.92  3.44  18.3     1     0     4     4

augment(reg_xy, car_tst)
#> # A tibble: 10 × 12
#>    .pred   mpg   cyl  disp    hp  drat    wt  qsec    vs    am  gear  carb
#>    <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
#>  1  23.4  21       6  160    110  3.9   2.62  16.5     0     1     4     4
#>  2  23.3  21       6  160    110  3.9   2.88  17.0     0     1     4     4
#>  3  27.6  22.8     4  108     93  3.85  2.32  18.6     1     1     4     1
#>  4  21.5  21.4     6  258    110  3.08  3.22  19.4     1     0     3     1
#>  5  17.6  18.7     8  360    175  3.15  3.44  17.0     0     0     3     2
#>  6  21.6  18.1     6  225    105  2.76  3.46  20.2     1     0     3     1
#>  7  13.9  14.3     8  360    245  3.21  3.57  15.8     0     0     3     4
#>  8  21.7  24.4     4  147.    62  3.69  3.19  20       1     0     4     2
#>  9  25.6  22.8     4  141.    95  3.92  3.15  22.9     1     0     4     2
#> 10  17.1  19.2     6  168.   123  3.92  3.44  18.3     1     0     4     4
augment(reg_xy, car_tst[, -1])
#> # A tibble: 10 × 11
#>    .pred   cyl  disp    hp  drat    wt  qsec    vs    am  gear  carb
#>    <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
#>  1  23.4     6  160    110  3.9   2.62  16.5     0     1     4     4
#>  2  23.3     6  160    110  3.9   2.88  17.0     0     1     4     4
#>  3  27.6     4  108     93  3.85  2.32  18.6     1     1     4     1
#>  4  21.5     6  258    110  3.08  3.22  19.4     1     0     3     1
#>  5  17.6     8  360    175  3.15  3.44  17.0     0     0     3     2
#>  6  21.6     6  225    105  2.76  3.46  20.2     1     0     3     1
#>  7  13.9     8  360    245  3.21  3.57  15.8     0     0     3     4
#>  8  21.7     4  147.    62  3.69  3.19  20       1     0     4     2
#>  9  25.6     4  141.    95  3.92  3.15  22.9     1     0     4     2
#> 10  17.1     6  168.   123  3.92  3.44  18.3     1     0     4     4

# ------------------------------------------------------------------------------

data(two_class_dat, package = "modeldata")
cls_trn <- two_class_dat[-(1:10), ]
cls_tst <- two_class_dat[  1:10 , ]

cls_form <-
  logistic_reg() %>%
  set_engine("glm") %>%
  fit(Class ~ ., data = cls_trn)
cls_xy <-
  logistic_reg() %>%
  set_engine("glm") %>%
  fit_xy(cls_trn[, -3],
  cls_trn$Class)

augment(cls_form, cls_tst)
#> # A tibble: 10 × 6
#>    .pred_class .pred_Class1 .pred_Class2     A     B Class 
#>    <fct>              <dbl>        <dbl> <dbl> <dbl> <fct> 
#>  1 Class1             0.518      0.482    2.07 1.63  Class1
#>  2 Class1             0.909      0.0913   2.02 1.04  Class1
#>  3 Class1             0.648      0.352    1.69 1.37  Class2
#>  4 Class1             0.610      0.390    3.43 1.98  Class2
#>  5 Class2             0.443      0.557    2.88 1.98  Class1
#>  6 Class2             0.206      0.794    3.31 2.41  Class2
#>  7 Class1             0.708      0.292    2.50 1.56  Class2
#>  8 Class1             0.567      0.433    1.98 1.55  Class2
#>  9 Class1             0.994      0.00582  2.88 0.580 Class1
#> 10 Class2             0.108      0.892    3.74 2.74  Class2
augment(cls_form, cls_tst[, -3])
#> # A tibble: 10 × 5
#>    .pred_class .pred_Class1 .pred_Class2     A     B
#>    <fct>              <dbl>        <dbl> <dbl> <dbl>
#>  1 Class1             0.518      0.482    2.07 1.63 
#>  2 Class1             0.909      0.0913   2.02 1.04 
#>  3 Class1             0.648      0.352    1.69 1.37 
#>  4 Class1             0.610      0.390    3.43 1.98 
#>  5 Class2             0.443      0.557    2.88 1.98 
#>  6 Class2             0.206      0.794    3.31 2.41 
#>  7 Class1             0.708      0.292    2.50 1.56 
#>  8 Class1             0.567      0.433    1.98 1.55 
#>  9 Class1             0.994      0.00582  2.88 0.580
#> 10 Class2             0.108      0.892    3.74 2.74 

augment(cls_xy, cls_tst)
#> # A tibble: 10 × 6
#>    .pred_class .pred_Class1 .pred_Class2     A     B Class 
#>    <fct>              <dbl>        <dbl> <dbl> <dbl> <fct> 
#>  1 Class1             0.518      0.482    2.07 1.63  Class1
#>  2 Class1             0.909      0.0913   2.02 1.04  Class1
#>  3 Class1             0.648      0.352    1.69 1.37  Class2
#>  4 Class1             0.610      0.390    3.43 1.98  Class2
#>  5 Class2             0.443      0.557    2.88 1.98  Class1
#>  6 Class2             0.206      0.794    3.31 2.41  Class2
#>  7 Class1             0.708      0.292    2.50 1.56  Class2
#>  8 Class1             0.567      0.433    1.98 1.55  Class2
#>  9 Class1             0.994      0.00582  2.88 0.580 Class1
#> 10 Class2             0.108      0.892    3.74 2.74  Class2
augment(cls_xy, cls_tst[, -3])
#> # A tibble: 10 × 5
#>    .pred_class .pred_Class1 .pred_Class2     A     B
#>    <fct>              <dbl>        <dbl> <dbl> <dbl>
#>  1 Class1             0.518      0.482    2.07 1.63 
#>  2 Class1             0.909      0.0913   2.02 1.04 
#>  3 Class1             0.648      0.352    1.69 1.37 
#>  4 Class1             0.610      0.390    3.43 1.98 
#>  5 Class2             0.443      0.557    2.88 1.98 
#>  6 Class2             0.206      0.794    3.31 2.41 
#>  7 Class1             0.708      0.292    2.50 1.56 
#>  8 Class1             0.567      0.433    1.98 1.55 
#>  9 Class1             0.994      0.00582  2.88 0.580
#> 10 Class2             0.108      0.892    3.74 2.74 
源代碼:R/augment.R

相關用法


注:本文由純淨天空篩選整理自Max Kuhn等大神的英文原創作品 Augment data with predictions。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。