当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python sklearn ClassifierChain用法及代码示例


本文简要介绍python语言中 sklearn.multioutput.ClassifierChain 的用法。

用法:

class sklearn.multioutput.ClassifierChain(base_estimator, *, order=None, cv=None, random_state=None)

将二元分类器排列成链的多标签模型。

每个模型使用提供给模型的所有可用特征加上链中较早模型的预测,按照链指定的顺序进行预测。

在用户指南中阅读更多信息。

参数

base_estimator估计器

构建分类器链的基本估计器。

order形状类似数组 (n_outputs,) 或 ‘random’,默认=无

如果 None ,则顺序将由标签矩阵 Y 中的列顺序决定:

order = [0, 1, 2, ..., Y.shape[1] - 1]

可以通过提供整数列表来显式设置链的顺序。例如,对于长度为 5 的链:

order = [1, 3, 2, 4, 0]

意味着链中的第一个模型将对 Y 矩阵中的第 1 列进行预测,第二个模型将对第 3 列进行预测,依此类推。

如果 order 是random,将使用随机排序。

cvint,交叉验证生成器或可迭代的,默认=无

确定是否对链中先前估计器的结果使用交叉验证的预测或真实标签。 cv 的可能输入是:

  • 无,在拟合时使用真实标签,
  • 整数,指定(分层)KFold 中的折叠数,
  • CV分配器,
  • 一个可迭代的 yield (train, test) 拆分为索引数组。
random_stateint、RandomState 实例或无,可选(默认=无)

如果 order='random' ,则确定链顺序的随机数生成。此外,它还控制每次链接迭代时每个 base_estimator 给出的随机种子。因此,仅当 base_estimator 公开 random_state 时才使用它。传递 int 以在多个函数调用之间实现可重现的输出。请参阅术语表。

属性

classes_列表

长度为 len(estimators_) 的数组列表,其中包含链中每个估计器的类标签。

estimators_列表

base_estimator 的克隆列表。

order_列表

分类器链中标签的顺序。

n_features_in_int

拟合期间看到的特征数。仅当底层 base_estimator 在合适时公开此类属性时才定义。

feature_names_in_ndarray 形状(n_features_in_,)

拟合期间看到的特征名称。仅当 X 具有全为字符串的函数名称时才定义。

参考

Jesse Read、Bernhard Pfahringer、Geoff Holmes、Eibe Frank,“用于多标签分类的分类器链”,2009 年。

例子

>>> from sklearn.datasets import make_multilabel_classification
>>> from sklearn.linear_model import LogisticRegression
>>> from sklearn.model_selection import train_test_split
>>> from sklearn.multioutput import ClassifierChain
>>> X, Y = make_multilabel_classification(
...    n_samples=12, n_classes=3, random_state=0
... )
>>> X_train, X_test, Y_train, Y_test = train_test_split(
...    X, Y, random_state=0
... )
>>> base_lr = LogisticRegression(solver='lbfgs', random_state=0)
>>> chain = ClassifierChain(base_lr, order='random', random_state=0)
>>> chain.fit(X_train, Y_train).predict(X_test)
array([[1., 1., 0.],
       [1., 0., 0.],
       [0., 1., 0.]])
>>> chain.predict_proba(X_test)
array([[0.8387..., 0.9431..., 0.4576...],
       [0.8878..., 0.3684..., 0.2640...],
       [0.0321..., 0.9935..., 0.0625...]])

相关用法


注:本文由纯净天空筛选整理自scikit-learn.org大神的英文原创作品 sklearn.multioutput.ClassifierChain。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。