当前位置: 首页>>代码示例>>Python>>正文


Python FistaClassifier.fit方法代码示例

本文整理汇总了Python中lightning.classification.FistaClassifier.fit方法的典型用法代码示例。如果您正苦于以下问题:Python FistaClassifier.fit方法的具体用法?Python FistaClassifier.fit怎么用?Python FistaClassifier.fit使用的例子?那么, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在lightning.classification.FistaClassifier的用法示例。


在下文中一共展示了FistaClassifier.fit方法的11个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: test_fista_custom_prox

# 需要导入模块: from lightning.classification import FistaClassifier [as 别名]
# 或者: from lightning.classification.FistaClassifier import fit [as 别名]
def test_fista_custom_prox():
    # test FISTA with a custom prox
    l1_pen = L1Penalty()
    for data in (bin_dense, bin_csr):
        clf = FistaClassifier(max_iter=500, penalty="l1", max_steps=0)
        clf.fit(data, bin_target)

        clf2 = FistaClassifier(max_iter=500, penalty=l1_pen, max_steps=0)
        clf2.fit(data, bin_target)
        np.testing.assert_array_almost_equal_nulp(clf.coef_.ravel(), clf2.coef_.ravel())
开发者ID:evgchz,项目名称:lightning,代码行数:12,代码来源:test_fista.py

示例2: test_fista_multiclass_tv1d

# 需要导入模块: from lightning.classification import FistaClassifier [as 别名]
# 或者: from lightning.classification.FistaClassifier import fit [as 别名]
def test_fista_multiclass_tv1d():
    for data in (mult_dense, mult_csr):
        clf = FistaClassifier(max_iter=200, penalty="tv1d", multiclass=True)
        clf.fit(data, mult_target)
        assert_almost_equal(clf.score(data, mult_target), 0.97, 2)

        # adding a lot of regularization coef_ should be constant
        clf = FistaClassifier(max_iter=200, penalty="tv1d", multiclass=True, alpha=1e6)
        clf.fit(data, mult_target)
        for i in range(clf.coef_.shape[0]):
            np.testing.assert_array_almost_equal(
                clf.coef_[i], np.mean(clf.coef_[i]) * np.ones(data.shape[1]))
开发者ID:evgchz,项目名称:lightning,代码行数:14,代码来源:test_fista.py

示例3: test_fista_multiclass_trace

# 需要导入模块: from lightning.classification import FistaClassifier [as 别名]
# 或者: from lightning.classification.FistaClassifier import fit [as 别名]
def test_fista_multiclass_trace():
    for data in (mult_dense, mult_csr):
        clf = FistaClassifier(max_iter=100, penalty="trace", multiclass=True)
        clf.fit(data, mult_target)
        assert_almost_equal(clf.score(data, mult_target), 0.98, 2)
开发者ID:FedericoV,项目名称:lightning,代码行数:7,代码来源:test_fista.py

示例4: test_fista_bin_l1_no_line_search

# 需要导入模块: from lightning.classification import FistaClassifier [as 别名]
# 或者: from lightning.classification.FistaClassifier import fit [as 别名]
def test_fista_bin_l1_no_line_search():
    for data in (bin_dense, bin_csr):
        clf = FistaClassifier(max_iter=500, penalty="l1", max_steps=0)
        clf.fit(data, bin_target)
        assert_almost_equal(clf.score(data, bin_target), 1.0, 2)
开发者ID:FedericoV,项目名称:lightning,代码行数:7,代码来源:test_fista.py

示例5: test_fista_bin_l1

# 需要导入模块: from lightning.classification import FistaClassifier [as 别名]
# 或者: from lightning.classification.FistaClassifier import fit [as 别名]
def test_fista_bin_l1():
    for data in (bin_dense, bin_csr):
        clf = FistaClassifier(max_iter=200, penalty="l1")
        clf.fit(data, bin_target)
        assert_almost_equal(clf.score(data, bin_target), 1.0, 2)
开发者ID:FedericoV,项目名称:lightning,代码行数:7,代码来源:test_fista.py

示例6: test_fista_multiclass_l1_no_line_search

# 需要导入模块: from lightning.classification import FistaClassifier [as 别名]
# 或者: from lightning.classification.FistaClassifier import fit [as 别名]
def test_fista_multiclass_l1_no_line_search():
    for data in (mult_dense, mult_csr):
        clf = FistaClassifier(max_iter=500, penalty="l1", multiclass=True,
                              max_steps=0)
        clf.fit(data, mult_target)
        assert_almost_equal(clf.score(data, mult_target), 0.95, 2)
开发者ID:FedericoV,项目名称:lightning,代码行数:8,代码来源:test_fista.py

示例7: test_fista_multiclass_l1l2_log_margin

# 需要导入模块: from lightning.classification import FistaClassifier [as 别名]
# 或者: from lightning.classification.FistaClassifier import fit [as 别名]
def test_fista_multiclass_l1l2_log_margin():
    for data in (mult_dense, mult_csr):
        clf = FistaClassifier(max_iter=200, penalty="l1/l2", loss="log_margin",
                              multiclass=True)
        clf.fit(data, mult_target)
        assert_almost_equal(clf.score(data, mult_target), 0.95, 2)
开发者ID:FedericoV,项目名称:lightning,代码行数:8,代码来源:test_fista.py

示例8: rank

# 需要导入模块: from lightning.classification import FistaClassifier [as 别名]
# 或者: from lightning.classification.FistaClassifier import fit [as 别名]
def rank(M, eps=1e-9):
    U, s, V = svd(M, full_matrices=False)
    return np.sum(s > eps)


bunch = fetch_20newsgroups_vectorized(subset="train")
X_train = bunch.data
y_train = bunch.target

# Reduces dimensionality to make the example faster
ch2 = SelectKBest(chi2, k=5000)
X_train = ch2.fit_transform(X_train, y_train)

bunch = fetch_20newsgroups_vectorized(subset="test")
X_test = bunch.data
y_test = bunch.target
X_test = ch2.transform(X_test)

clf = FistaClassifier(C=1.0 / X_train.shape[0],
                      max_iter=200,
                      penalty="trace",
                      multiclass=True)

for alpha in (1e-3, 1e-2, 0.1, 0.2, 0.3):
    print("alpha=", alpha)
    clf.alpha = alpha
    clf.fit(X_train, y_train)
    print(clf.score(X_test, y_test))
    print(rank(clf.coef_))
开发者ID:DEVESHTARASIA,项目名称:lightning,代码行数:31,代码来源:trace.py

示例9: fetch_20newsgroups_vectorized

# 需要导入模块: from lightning.classification import FistaClassifier [as 别名]
# 或者: from lightning.classification.FistaClassifier import fit [as 别名]
import time

import numpy as np

from sklearn.datasets import fetch_20newsgroups_vectorized
from lightning.classification import FistaClassifier

bunch = fetch_20newsgroups_vectorized(subset="all")
X = bunch.data
y = bunch.target
y[y >= 1] = 1

clf = FistaClassifier(C=1.0 / X.shape[0], alpha=1e-5, max_iter=200)
start = time.time()
clf.fit(X, y)

print "Training time", time.time() - start
print "Accuracy", np.mean(clf.predict(X) == y)
print "% non-zero", clf.n_nonzero(percentage=True)
开发者ID:pandasasa,项目名称:lightning,代码行数:21,代码来源:bench_fista.py

示例10: test_fista_multiclass_classes

# 需要导入模块: from lightning.classification import FistaClassifier [as 别名]
# 或者: from lightning.classification.FistaClassifier import fit [as 别名]
def test_fista_multiclass_classes():
    clf = FistaClassifier()
    clf.fit(mult_dense, mult_target)
    assert_equal(list(clf.classes_), [0, 1, 2])
开发者ID:evgchz,项目名称:lightning,代码行数:6,代码来源:test_fista.py

示例11: test_fista_bin_classes

# 需要导入模块: from lightning.classification import FistaClassifier [as 别名]
# 或者: from lightning.classification.FistaClassifier import fit [as 别名]
def test_fista_bin_classes():
    clf = FistaClassifier()
    clf.fit(bin_dense, bin_target)
    assert_equal(list(clf.classes_), [0, 1])
开发者ID:evgchz,项目名称:lightning,代码行数:6,代码来源:test_fista.py


注:本文中的lightning.classification.FistaClassifier.fit方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。