当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


R parsnip augment 通过预测增强数据


augment() 将为给定数据添加预测列。

用法

# S3 method for model_fit
augment(x, new_data, eval_time = NULL, ...)

参数

x

fit.model_spec()fit_xy.model_spec() 生成的 model_fit 对象。

new_data

DataFrame 或矩阵。

eval_time

对于审查回归模型,估计生存概率的时间点向量。

...

目前未使用。

细节

回归

对于回归模型,添加 .pred 列。如果 x 是使用 fit.model_spec() 创建的,并且 new_data 包含回归结果列,则还会添加 .resid 列。

分类

对于分类模型,结果可以包括名为 .pred_class 的列以及名为 .pred_{level} 的类概率列。这取决于模型可用的预测类型。

删失回归

对于这些模型,将创建对预期时间和生存概率的预测(如果模型引擎支持它们)。如果模型支持生存预测,则需要 eval_time 参数。

如果创建了生存预测并且new_data包含一个survival::Surv()对象,还添加了额外的列以创建审查权重的逆概率(IPCW)(请参阅tidymodels.org以下参考文献中的页面)。这使得用户能够计算性能指标尺度包。

例子

car_trn <- mtcars[11:32,]
car_tst <- mtcars[ 1:10,]

reg_form <-
  linear_reg() %>%
  set_engine("lm") %>%
  fit(mpg ~ ., data = car_trn)
reg_xy <-
  linear_reg() %>%
  set_engine("lm") %>%
  fit_xy(car_trn[, -1], car_trn$mpg)

augment(reg_form, car_tst)
#> # A tibble: 10 × 13
#>    .pred .resid   mpg   cyl  disp    hp  drat    wt  qsec    vs    am
#>    <dbl>  <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
#>  1  23.4 -2.43   21       6  160    110  3.9   2.62  16.5     0     1
#>  2  23.3 -2.30   21       6  160    110  3.9   2.88  17.0     0     1
#>  3  27.6 -4.83   22.8     4  108     93  3.85  2.32  18.6     1     1
#>  4  21.5 -0.147  21.4     6  258    110  3.08  3.22  19.4     1     0
#>  5  17.6  1.13   18.7     8  360    175  3.15  3.44  17.0     0     0
#>  6  21.6 -3.48   18.1     6  225    105  2.76  3.46  20.2     1     0
#>  7  13.9  0.393  14.3     8  360    245  3.21  3.57  15.8     0     0
#>  8  21.7  2.70   24.4     4  147.    62  3.69  3.19  20       1     0
#>  9  25.6 -2.81   22.8     4  141.    95  3.92  3.15  22.9     1     0
#> 10  17.1  2.09   19.2     6  168.   123  3.92  3.44  18.3     1     0
#> # ℹ 2 more variables: gear <dbl>, carb <dbl>
augment(reg_form, car_tst[, -1])
#> # A tibble: 10 × 11
#>    .pred   cyl  disp    hp  drat    wt  qsec    vs    am  gear  carb
#>    <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
#>  1  23.4     6  160    110  3.9   2.62  16.5     0     1     4     4
#>  2  23.3     6  160    110  3.9   2.88  17.0     0     1     4     4
#>  3  27.6     4  108     93  3.85  2.32  18.6     1     1     4     1
#>  4  21.5     6  258    110  3.08  3.22  19.4     1     0     3     1
#>  5  17.6     8  360    175  3.15  3.44  17.0     0     0     3     2
#>  6  21.6     6  225    105  2.76  3.46  20.2     1     0     3     1
#>  7  13.9     8  360    245  3.21  3.57  15.8     0     0     3     4
#>  8  21.7     4  147.    62  3.69  3.19  20       1     0     4     2
#>  9  25.6     4  141.    95  3.92  3.15  22.9     1     0     4     2
#> 10  17.1     6  168.   123  3.92  3.44  18.3     1     0     4     4

augment(reg_xy, car_tst)
#> # A tibble: 10 × 12
#>    .pred   mpg   cyl  disp    hp  drat    wt  qsec    vs    am  gear  carb
#>    <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
#>  1  23.4  21       6  160    110  3.9   2.62  16.5     0     1     4     4
#>  2  23.3  21       6  160    110  3.9   2.88  17.0     0     1     4     4
#>  3  27.6  22.8     4  108     93  3.85  2.32  18.6     1     1     4     1
#>  4  21.5  21.4     6  258    110  3.08  3.22  19.4     1     0     3     1
#>  5  17.6  18.7     8  360    175  3.15  3.44  17.0     0     0     3     2
#>  6  21.6  18.1     6  225    105  2.76  3.46  20.2     1     0     3     1
#>  7  13.9  14.3     8  360    245  3.21  3.57  15.8     0     0     3     4
#>  8  21.7  24.4     4  147.    62  3.69  3.19  20       1     0     4     2
#>  9  25.6  22.8     4  141.    95  3.92  3.15  22.9     1     0     4     2
#> 10  17.1  19.2     6  168.   123  3.92  3.44  18.3     1     0     4     4
augment(reg_xy, car_tst[, -1])
#> # A tibble: 10 × 11
#>    .pred   cyl  disp    hp  drat    wt  qsec    vs    am  gear  carb
#>    <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
#>  1  23.4     6  160    110  3.9   2.62  16.5     0     1     4     4
#>  2  23.3     6  160    110  3.9   2.88  17.0     0     1     4     4
#>  3  27.6     4  108     93  3.85  2.32  18.6     1     1     4     1
#>  4  21.5     6  258    110  3.08  3.22  19.4     1     0     3     1
#>  5  17.6     8  360    175  3.15  3.44  17.0     0     0     3     2
#>  6  21.6     6  225    105  2.76  3.46  20.2     1     0     3     1
#>  7  13.9     8  360    245  3.21  3.57  15.8     0     0     3     4
#>  8  21.7     4  147.    62  3.69  3.19  20       1     0     4     2
#>  9  25.6     4  141.    95  3.92  3.15  22.9     1     0     4     2
#> 10  17.1     6  168.   123  3.92  3.44  18.3     1     0     4     4

# ------------------------------------------------------------------------------

data(two_class_dat, package = "modeldata")
cls_trn <- two_class_dat[-(1:10), ]
cls_tst <- two_class_dat[  1:10 , ]

cls_form <-
  logistic_reg() %>%
  set_engine("glm") %>%
  fit(Class ~ ., data = cls_trn)
cls_xy <-
  logistic_reg() %>%
  set_engine("glm") %>%
  fit_xy(cls_trn[, -3],
  cls_trn$Class)

augment(cls_form, cls_tst)
#> # A tibble: 10 × 6
#>    .pred_class .pred_Class1 .pred_Class2     A     B Class 
#>    <fct>              <dbl>        <dbl> <dbl> <dbl> <fct> 
#>  1 Class1             0.518      0.482    2.07 1.63  Class1
#>  2 Class1             0.909      0.0913   2.02 1.04  Class1
#>  3 Class1             0.648      0.352    1.69 1.37  Class2
#>  4 Class1             0.610      0.390    3.43 1.98  Class2
#>  5 Class2             0.443      0.557    2.88 1.98  Class1
#>  6 Class2             0.206      0.794    3.31 2.41  Class2
#>  7 Class1             0.708      0.292    2.50 1.56  Class2
#>  8 Class1             0.567      0.433    1.98 1.55  Class2
#>  9 Class1             0.994      0.00582  2.88 0.580 Class1
#> 10 Class2             0.108      0.892    3.74 2.74  Class2
augment(cls_form, cls_tst[, -3])
#> # A tibble: 10 × 5
#>    .pred_class .pred_Class1 .pred_Class2     A     B
#>    <fct>              <dbl>        <dbl> <dbl> <dbl>
#>  1 Class1             0.518      0.482    2.07 1.63 
#>  2 Class1             0.909      0.0913   2.02 1.04 
#>  3 Class1             0.648      0.352    1.69 1.37 
#>  4 Class1             0.610      0.390    3.43 1.98 
#>  5 Class2             0.443      0.557    2.88 1.98 
#>  6 Class2             0.206      0.794    3.31 2.41 
#>  7 Class1             0.708      0.292    2.50 1.56 
#>  8 Class1             0.567      0.433    1.98 1.55 
#>  9 Class1             0.994      0.00582  2.88 0.580
#> 10 Class2             0.108      0.892    3.74 2.74 

augment(cls_xy, cls_tst)
#> # A tibble: 10 × 6
#>    .pred_class .pred_Class1 .pred_Class2     A     B Class 
#>    <fct>              <dbl>        <dbl> <dbl> <dbl> <fct> 
#>  1 Class1             0.518      0.482    2.07 1.63  Class1
#>  2 Class1             0.909      0.0913   2.02 1.04  Class1
#>  3 Class1             0.648      0.352    1.69 1.37  Class2
#>  4 Class1             0.610      0.390    3.43 1.98  Class2
#>  5 Class2             0.443      0.557    2.88 1.98  Class1
#>  6 Class2             0.206      0.794    3.31 2.41  Class2
#>  7 Class1             0.708      0.292    2.50 1.56  Class2
#>  8 Class1             0.567      0.433    1.98 1.55  Class2
#>  9 Class1             0.994      0.00582  2.88 0.580 Class1
#> 10 Class2             0.108      0.892    3.74 2.74  Class2
augment(cls_xy, cls_tst[, -3])
#> # A tibble: 10 × 5
#>    .pred_class .pred_Class1 .pred_Class2     A     B
#>    <fct>              <dbl>        <dbl> <dbl> <dbl>
#>  1 Class1             0.518      0.482    2.07 1.63 
#>  2 Class1             0.909      0.0913   2.02 1.04 
#>  3 Class1             0.648      0.352    1.69 1.37 
#>  4 Class1             0.610      0.390    3.43 1.98 
#>  5 Class2             0.443      0.557    2.88 1.98 
#>  6 Class2             0.206      0.794    3.31 2.41 
#>  7 Class1             0.708      0.292    2.50 1.56 
#>  8 Class1             0.567      0.433    1.98 1.55 
#>  9 Class1             0.994      0.00582  2.88 0.580
#> 10 Class2             0.108      0.892    3.74 2.74 
源代码:R/augment.R

相关用法


注:本文由纯净天空筛选整理自Max Kuhn等大神的英文原创作品 Augment data with predictions。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。