当前位置: 首页>>代码示例>>Python>>正文


Python _search.BaseSearchCV方法代码示例

本文整理汇总了Python中sklearn.model_selection._search.BaseSearchCV方法的典型用法代码示例。如果您正苦于以下问题:Python _search.BaseSearchCV方法的具体用法?Python _search.BaseSearchCV怎么用?Python _search.BaseSearchCV使用的例子?那么恭喜您, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在sklearn.model_selection._search的用法示例。


在下文中一共展示了_search.BaseSearchCV方法的3个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: test__custom_fit_no_run_search

# 需要导入模块: from sklearn.model_selection import _search [as 别名]
# 或者: from sklearn.model_selection._search import BaseSearchCV [as 别名]
def test__custom_fit_no_run_search():
    class NoRunSearchSearchCV(BaseSearchCV):
        def __init__(self, estimator, **kwargs):
            super().__init__(estimator, **kwargs)

        def fit(self, X, y=None, groups=None, **fit_params):
            return self

    # this should not raise any exceptions
    NoRunSearchSearchCV(SVC(), cv=5).fit(X, y)

    class BadSearchCV(BaseSearchCV):
        def __init__(self, estimator, **kwargs):
            super().__init__(estimator, **kwargs)

    with pytest.raises(NotImplementedError,
                       match="_run_search not implemented."):
        # this should raise a NotImplementedError
        BadSearchCV(SVC(), cv=5).fit(X, y) 
开发者ID:PacktPublishing,项目名称:Mastering-Elasticsearch-7.0,代码行数:21,代码来源:test_search.py

示例2: print_cv_result

# 需要导入模块: from sklearn.model_selection import _search [as 别名]
# 或者: from sklearn.model_selection._search import BaseSearchCV [as 别名]
def print_cv_result(result, n):
    if isinstance(result, BaseSearchCV):
        result = result.cv_results_

    scores = result['mean_test_score']
    params = result['params']

    if n < 0:
        n = len(scores)

    print("Cross Validation result in descending order: (totalling {} trials)".format(n))
    for rank, candidate, in enumerate(heapq.nlargest(n, zip(scores, params), key=lambda tup: tup[0])):
        print("rank {}, score = {}\n hyperparams = {}".format(rank + 1, *candidate)) 
开发者ID:Johnny-Wish,项目名称:fake-news-detection-pipeline,代码行数:15,代码来源:__main__.py

示例3: test_custom_run_search

# 需要导入模块: from sklearn.model_selection import _search [as 别名]
# 或者: from sklearn.model_selection._search import BaseSearchCV [as 别名]
def test_custom_run_search():
    def check_results(results, gscv):
        exp_results = gscv.cv_results_
        assert sorted(results.keys()) == sorted(exp_results)
        for k in results:
            if not k.endswith('_time'):
                # XXX: results['params'] is a list :|
                results[k] = np.asanyarray(results[k])
                if results[k].dtype.kind == 'O':
                    assert_array_equal(exp_results[k], results[k],
                                       err_msg='Checking ' + k)
                else:
                    assert_allclose(exp_results[k], results[k],
                                    err_msg='Checking ' + k)

    def fit_grid(param_grid):
        return GridSearchCV(clf, param_grid, cv=5,
                            return_train_score=True).fit(X, y)

    class CustomSearchCV(BaseSearchCV):
        def __init__(self, estimator, **kwargs):
            super().__init__(estimator, **kwargs)

        def _run_search(self, evaluate):
            results = evaluate([{'max_depth': 1}, {'max_depth': 2}])
            check_results(results, fit_grid({'max_depth': [1, 2]}))
            results = evaluate([{'min_samples_split': 5},
                                {'min_samples_split': 10}])
            check_results(results, fit_grid([{'max_depth': [1, 2]},
                                             {'min_samples_split': [5, 10]}]))

    # Using regressor to make sure each score differs
    clf = DecisionTreeRegressor(random_state=0)
    X, y = make_classification(n_samples=100, n_informative=4,
                               random_state=0)
    mycv = CustomSearchCV(clf, cv=5, return_train_score=True).fit(X, y)
    gscv = fit_grid([{'max_depth': [1, 2]},
                     {'min_samples_split': [5, 10]}])

    results = mycv.cv_results_
    check_results(results, gscv)
    for attr in dir(gscv):
        if attr[0].islower() and attr[-1:] == '_' and \
           attr not in {'cv_results_', 'best_estimator_',
                        'refit_time_'}:
            assert getattr(gscv, attr) == getattr(mycv, attr), \
                   "Attribute %s not equal" % attr 
开发者ID:PacktPublishing,项目名称:Mastering-Elasticsearch-7.0,代码行数:49,代码来源:test_search.py


注:本文中的sklearn.model_selection._search.BaseSearchCV方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。