当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


R parsnip predict.model_fit 模型预测


应用模型来创建不同类型的预测。 predict() 可用于所有类型的模型,并使用 "type" 参数以获得更多特异性。

用法

# S3 method for model_fit
predict(object, new_data, type = NULL, opts = list(), ...)

# S3 method for model_fit
predict_raw(object, new_data, opts = list(), ...)

predict_raw(object, ...)

参数

object

model_fit 的对象。

new_data

矩形数据对象,例如 DataFrame 。

type

单个字符值或 NULL 。可能的值为 "numeric" , "class" , "prob" , "conf_int" , "pred_int" , "quantile" , "time" , "hazard" , "survival""raw" 。当 NULL 时,predict() 会根据模型的模式选择合适的值。

opts

type = "raw" 时将使用的基础预测函数的可选参数列表。该列表不应包含模型对象或正在预测的新数据的选项。

...

其他 parsnip 相关选项,具体取决于 type 的值。无法在此处传递底层模型预测函数的参数(请改用opts 参数)。可能的论点是:

  • interval :对于 type 等于 "survival""quantile" ,是否应该添加间隔估计(如果有)?选项是 "none""confidence"

  • level :对于 type 等于 "conf_int""pred_int""survival" ,这是间隔尾部区域的参数(例如置信区间的置信水平)。默认值为0.95

  • std_error :对于 type 等于 "conf_int""pred_int" ,添加拟合或预测的标准误差(在线性预测变量的范围内)。默认值为FALSE

  • quantile :对于 type 等于 quantile ,分布的分位数。默认为 (1:9)/10

  • eval_time :对于 type 等于 "survival""hazard" ,估计生存概率或危险的时间点。

除了 type = "raw" 之外,predict.model_fit() 的结果

  • 是一个小词

  • 行数与 new_data 中的行数一样多

  • 具有标准化的列名称,如下所示:

对于 type = "numeric" ,tibble 具有用于单个结果的 .pred 列和用于多变量结果的 .pred_Yname 列。

对于 type = "class" ,tibble 有一个 .pred_class 列。

对于 type = "prob" ,标题具有 .pred_classlevel 列。

对于 type = "conf_int"type = "pred_int" ,tibble 具有带有置信度属性的 .pred_lower.pred_upper 列。在可以为类概率(或其他非标量输出)生成间隔的情况下,列被命名为 .pred_lower_classlevel 等。

对于 type = "quantile" ,tibble 有一个 .pred 列,它是一个列表列。每个列表元素都包含一个带有列 .pred.quantile(可能还有其他列)的 tibble。

对于 type = "time" ,tibble 有一个 .pred_time 列。

对于 type = "survival" ,tibble 有一个 .pred 列,它是一个列表列。每个列表元素都包含一个带有列 .eval_time.pred_survival(可能还有其他列)的 tibble。

对于 type = "hazard" ,tibble 有一个 .pred 列,它是一个列表列。每个列表元素都包含一个带有列 .eval_time.pred_hazard(可能还有其他列)的 tibble。

type = "raw"predict.model_fit() 结合使用将返回预测函数的纯正结果。

对于基于 Spark 的模型,由于表列不能包含点,因此使用相同的约定,除了 1) 名称中不出现点和 2) 永远不会返回向量,但返回 type-specific 预测函数。

当模型拟合失败并捕获错误时,predict() 函数将返回与上述相同的结构,但填充缺失值。目前这不适用于多变量模型。

细节

对于 type = NULLpredict() 使用

  • type = "numeric" 用于回归模型,

  • type = "class" 用于分类,以及

  • type = "time" 用于审查回归。

区间预测

使用type = "conf_int"type = "pred_int"时,可以使用选项levelstd_error。后者是标准错误值的额外列(如果可用)的逻辑。

删失回归预测

对于审查回归,当请求生存或危险概率时,需要 eval_time 的数值向量。时间值必须是唯一的、有限的、非缺失的且非负的。 predict() 函数将通过删除违规点(带有警告)来调整值以适应此规范。

predict.model_fit() 不要求存在结果。对于预测生存概率的性能指标,需要审查权重的逆概率 (IPCW)(请参阅下面的 tidymodels.org 参考)。这些需要结果,因此 predict() 不会返回。如果 new_data 包含结果为 Surv 对象的列,则可以通过 augment.model_fit() 添加它们。

此外,当 type = "linear_pred" 时,截尾回归模型将默认格式化,使得线性预测变量随时间增加。这可能与底层模型的 predict() 方法产生的符号相反。设置increasing = FALSE 以抑制此行为。

例子

library(dplyr)

lm_model <-
  linear_reg() %>%
  set_engine("lm") %>%
  fit(mpg ~ ., data = mtcars %>% dplyr::slice(11:32))

pred_cars <-
  mtcars %>%
  dplyr::slice(1:10) %>%
  dplyr::select(-mpg)

predict(lm_model, pred_cars)
#> # A tibble: 10 × 1
#>    .pred
#>    <dbl>
#>  1  23.4
#>  2  23.3
#>  3  27.6
#>  4  21.5
#>  5  17.6
#>  6  21.6
#>  7  13.9
#>  8  21.7
#>  9  25.6
#> 10  17.1

predict(
  lm_model,
  pred_cars,
  type = "conf_int",
  level = 0.90
)
#> # A tibble: 10 × 2
#>    .pred_lower .pred_upper
#>          <dbl>       <dbl>
#>  1       17.9         29.0
#>  2       18.1         28.5
#>  3       24.0         31.3
#>  4       17.5         25.6
#>  5       14.3         20.8
#>  6       17.0         26.2
#>  7        9.65        18.2
#>  8       16.2         27.2
#>  9       14.2         37.0
#> 10       11.5         22.7

predict(
  lm_model,
  pred_cars,
  type = "raw",
  opts = list(type = "terms")
)
#>                            cyl       disp         hp        drat
#> Mazda RX4         -0.001433177 -0.8113275  0.6303467 -0.06120265
#> Mazda RX4 Wag     -0.001433177 -0.8113275  0.6303467 -0.06120265
#> Datsun 710        -0.009315653 -1.3336453  0.8557288 -0.05014798
#> Hornet 4 Drive    -0.001433177  0.1730406  0.6303467  0.12009386
#> Hornet Sportabout  0.006449298  1.1975870 -0.2314083  0.10461733
#> Valiant           -0.001433177 -0.1584303  0.6966356  0.19084372
#> Duster 360         0.006449298  1.1975870 -1.1594522  0.09135173
#> Merc 240D         -0.009315653 -0.9449204  1.2667197 -0.01477305
#> Merc 230          -0.009315653 -1.0041833  0.8292133 -0.06562451
#> Merc 280          -0.001433177 -0.7349888  0.4579957 -0.06562451
#>                           wt      qsec         vs       am        gear
#> Mazda RX4          2.4139815 -1.567729  0.2006406  2.88774  0.02512680
#> Mazda RX4 Wag      1.4488706 -0.736286  0.2006406  2.88774  0.02512680
#> Datsun 710         3.5494061  1.624418 -0.3511210  2.88774  0.02512680
#> Hornet 4 Drive     0.1620561  2.856736 -0.3511210 -2.40645 -0.06700481
#> Hornet Sportabout -0.6895124 -0.736286  0.2006406 -2.40645 -0.06700481
#> Valiant           -0.7652074  4.014817 -0.3511210 -2.40645 -0.06700481
#> Duster 360        -1.1815297 -2.488255  0.2006406 -2.40645 -0.06700481
#> Merc 240D          0.2566748  3.688179 -0.3511210 -2.40645  0.02512680
#> Merc 230           0.4080647  7.993866 -0.3511210 -2.40645  0.02512680
#> Merc 280          -0.6895124  1.164155 -0.3511210 -2.40645  0.02512680
#>                         carb
#> Mazda RX4         -0.2497240
#> Mazda RX4 Wag     -0.2497240
#> Datsun 710         0.4668753
#> Hornet 4 Drive     0.4668753
#> Hornet Sportabout  0.2280089
#> Valiant            0.4668753
#> Duster 360        -0.2497240
#> Merc 240D          0.2280089
#> Merc 230           0.2280089
#> Merc 280          -0.2497240
#> attr(,"constant")
#> [1] 19.96364

相关用法


注:本文由纯净天空筛选整理自Max Kuhn等大神的英文原创作品 Model predictions。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。