应用模型来创建不同类型的预测。 predict()
可用于所有类型的模型,并使用 "type" 参数以获得更多特异性。
用法
# S3 method for model_fit
predict(object, new_data, type = NULL, opts = list(), ...)
# S3 method for model_fit
predict_raw(object, new_data, opts = list(), ...)
predict_raw(object, ...)
参数
- object
-
类
model_fit
的对象。 - new_data
-
矩形数据对象,例如 DataFrame 。
- type
-
单个字符值或
NULL
。可能的值为"numeric"
,"class"
,"prob"
,"conf_int"
,"pred_int"
,"quantile"
,"time"
,"hazard"
,"survival"
或"raw"
。当NULL
时,predict()
会根据模型的模式选择合适的值。 - opts
-
type = "raw"
时将使用的基础预测函数的可选参数列表。该列表不应包含模型对象或正在预测的新数据的选项。 - ...
-
其他
parsnip
相关选项,具体取决于type
的值。无法在此处传递底层模型预测函数的参数(请改用opts
参数)。可能的论点是:-
interval
:对于type
等于"survival"
或"quantile"
,是否应该添加间隔估计(如果有)?选项是"none"
和"confidence"
。 -
level
:对于type
等于"conf_int"
、"pred_int"
或"survival"
,这是间隔尾部区域的参数(例如置信区间的置信水平)。默认值为0.95
。 -
std_error
:对于type
等于"conf_int"
或"pred_int"
,添加拟合或预测的标准误差(在线性预测变量的范围内)。默认值为FALSE
。 -
quantile
:对于type
等于quantile
,分布的分位数。默认为(1:9)/10
。 -
eval_time
:对于type
等于"survival"
或"hazard"
,估计生存概率或危险的时间点。
-
值
除了 type = "raw"
之外,predict.model_fit()
的结果
-
是一个小词
-
行数与
new_data
中的行数一样多 -
具有标准化的列名称,如下所示:
对于 type = "numeric"
,tibble 具有用于单个结果的 .pred
列和用于多变量结果的 .pred_Yname
列。
对于 type = "class"
,tibble 有一个 .pred_class
列。
对于 type = "prob"
,标题具有 .pred_classlevel
列。
对于 type = "conf_int"
和 type = "pred_int"
,tibble 具有带有置信度属性的 .pred_lower
和 .pred_upper
列。在可以为类概率(或其他非标量输出)生成间隔的情况下,列被命名为 .pred_lower_classlevel
等。
对于 type = "quantile"
,tibble 有一个 .pred
列,它是一个列表列。每个列表元素都包含一个带有列 .pred
和 .quantile
(可能还有其他列)的 tibble。
对于 type = "time"
,tibble 有一个 .pred_time
列。
对于 type = "survival"
,tibble 有一个 .pred
列,它是一个列表列。每个列表元素都包含一个带有列 .eval_time
和 .pred_survival
(可能还有其他列)的 tibble。
对于 type = "hazard"
,tibble 有一个 .pred
列,它是一个列表列。每个列表元素都包含一个带有列 .eval_time
和 .pred_hazard
(可能还有其他列)的 tibble。
将 type = "raw"
与 predict.model_fit()
结合使用将返回预测函数的纯正结果。
对于基于 Spark 的模型,由于表列不能包含点,因此使用相同的约定,除了 1) 名称中不出现点和 2) 永远不会返回向量,但返回 type-specific 预测函数。
当模型拟合失败并捕获错误时,predict()
函数将返回与上述相同的结构,但填充缺失值。目前这不适用于多变量模型。
细节
对于 type = NULL
, predict()
使用
-
type = "numeric"
用于回归模型, -
type = "class"
用于分类,以及 -
type = "time"
用于审查回归。
删失回归预测
对于审查回归,当请求生存或危险概率时,需要 eval_time
的数值向量。时间值必须是唯一的、有限的、非缺失的且非负的。 predict()
函数将通过删除违规点(带有警告)来调整值以适应此规范。
predict.model_fit()
不要求存在结果。对于预测生存概率的性能指标,需要审查权重的逆概率 (IPCW)(请参阅下面的 tidymodels.org
参考)。这些需要结果,因此 predict()
不会返回。如果 new_data
包含结果为 Surv
对象的列,则可以通过 augment.model_fit()
添加它们。
此外,当 type = "linear_pred"
时,截尾回归模型将默认格式化,使得线性预测变量随时间增加。这可能与底层模型的 predict()
方法产生的符号相反。设置increasing = FALSE
以抑制此行为。
例子
library(dplyr)
lm_model <-
linear_reg() %>%
set_engine("lm") %>%
fit(mpg ~ ., data = mtcars %>% dplyr::slice(11:32))
pred_cars <-
mtcars %>%
dplyr::slice(1:10) %>%
dplyr::select(-mpg)
predict(lm_model, pred_cars)
#> # A tibble: 10 × 1
#> .pred
#> <dbl>
#> 1 23.4
#> 2 23.3
#> 3 27.6
#> 4 21.5
#> 5 17.6
#> 6 21.6
#> 7 13.9
#> 8 21.7
#> 9 25.6
#> 10 17.1
predict(
lm_model,
pred_cars,
type = "conf_int",
level = 0.90
)
#> # A tibble: 10 × 2
#> .pred_lower .pred_upper
#> <dbl> <dbl>
#> 1 17.9 29.0
#> 2 18.1 28.5
#> 3 24.0 31.3
#> 4 17.5 25.6
#> 5 14.3 20.8
#> 6 17.0 26.2
#> 7 9.65 18.2
#> 8 16.2 27.2
#> 9 14.2 37.0
#> 10 11.5 22.7
predict(
lm_model,
pred_cars,
type = "raw",
opts = list(type = "terms")
)
#> cyl disp hp drat
#> Mazda RX4 -0.001433177 -0.8113275 0.6303467 -0.06120265
#> Mazda RX4 Wag -0.001433177 -0.8113275 0.6303467 -0.06120265
#> Datsun 710 -0.009315653 -1.3336453 0.8557288 -0.05014798
#> Hornet 4 Drive -0.001433177 0.1730406 0.6303467 0.12009386
#> Hornet Sportabout 0.006449298 1.1975870 -0.2314083 0.10461733
#> Valiant -0.001433177 -0.1584303 0.6966356 0.19084372
#> Duster 360 0.006449298 1.1975870 -1.1594522 0.09135173
#> Merc 240D -0.009315653 -0.9449204 1.2667197 -0.01477305
#> Merc 230 -0.009315653 -1.0041833 0.8292133 -0.06562451
#> Merc 280 -0.001433177 -0.7349888 0.4579957 -0.06562451
#> wt qsec vs am gear
#> Mazda RX4 2.4139815 -1.567729 0.2006406 2.88774 0.02512680
#> Mazda RX4 Wag 1.4488706 -0.736286 0.2006406 2.88774 0.02512680
#> Datsun 710 3.5494061 1.624418 -0.3511210 2.88774 0.02512680
#> Hornet 4 Drive 0.1620561 2.856736 -0.3511210 -2.40645 -0.06700481
#> Hornet Sportabout -0.6895124 -0.736286 0.2006406 -2.40645 -0.06700481
#> Valiant -0.7652074 4.014817 -0.3511210 -2.40645 -0.06700481
#> Duster 360 -1.1815297 -2.488255 0.2006406 -2.40645 -0.06700481
#> Merc 240D 0.2566748 3.688179 -0.3511210 -2.40645 0.02512680
#> Merc 230 0.4080647 7.993866 -0.3511210 -2.40645 0.02512680
#> Merc 280 -0.6895124 1.164155 -0.3511210 -2.40645 0.02512680
#> carb
#> Mazda RX4 -0.2497240
#> Mazda RX4 Wag -0.2497240
#> Datsun 710 0.4668753
#> Hornet 4 Drive 0.4668753
#> Hornet Sportabout 0.2280089
#> Valiant 0.4668753
#> Duster 360 -0.2497240
#> Merc 240D 0.2280089
#> Merc 230 0.2280089
#> Merc 280 -0.2497240
#> attr(,"constant")
#> [1] 19.96364
相关用法
- R parsnip proportional_hazards 比例风险回归
- R parsnip parsnip_update 更新型号规格
- R parsnip logistic_reg 逻辑回归
- R parsnip linear_reg 线性回归
- R parsnip C5_rules C5.0 基于规则的分类模型
- R parsnip set_engine 声明计算引擎和特定参数
- R parsnip condense_control 将控制对象压缩为更小的控制对象
- R parsnip control_parsnip 控制拟合函数
- R parsnip augment 通过预测增强数据
- R parsnip repair_call 修复模型调用对象
- R parsnip dot-model_param_name_key 翻译模型调整参数的名称
- R parsnip glm_grouped 将数据集中的分组二项式结果与个案权重拟合
- R parsnip rule_fit 规则拟合模型
- R parsnip svm_rbf 径向基函数支持向量机
- R parsnip set_args 更改模型规范的元素
- R parsnip translate 解决计算引擎的模型规范
- R parsnip max_mtry_formula 根据公式确定 mtry 的最大值。此函数可能会根据公式和数据集限制 mtry 的值。对于生存和/或多变量模型来说,这是一种安全的方法。
- R parsnip svm_linear 线性支持向量机
- R parsnip set_new_model 注册模型的工具
- R parsnip rand_forest 随机森林
- R parsnip mlp 单层神经网络
- R parsnip nearest_neighbor K-最近邻
- R parsnip fit 将模型规范拟合到数据集
- R parsnip boost_tree 增强树
- R parsnip bart 贝叶斯加性回归树 (BART)
注:本文由纯净天空筛选整理自Max Kuhn等大神的英文原创作品 Model predictions。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。