当前位置: 首页>>代码示例>>Python>>正文


Python validation.check_random_state函数代码示例

本文整理汇总了Python中sklearn.utils.validation.check_random_state函数的典型用法代码示例。如果您正苦于以下问题:Python check_random_state函数的具体用法?Python check_random_state怎么用?Python check_random_state使用的例子?那么恭喜您, 这里精选的函数代码示例或许可以为您提供帮助。


在下文中一共展示了check_random_state函数的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: __init__

    def __init__(self, configuration, random_state=None):
        """Store the configuration and build the random number generator.

        Parameters
        ----------
        configuration
            Stored unchanged as ``self.configuration``.
        random_state : None, int or RandomState, optional
            Passed through ``check_random_state``. ``None`` is replaced by
            the fixed seed 1, so the default behaviour is reproducible
            rather than drawing from the global numpy generator.
        """
        self.configuration = configuration
        seed = 1 if random_state is None else random_state
        self.random_state = check_random_state(seed)
开发者ID:stokasto,项目名称:auto-sklearn,代码行数:7,代码来源:base.py

示例2: test_auc_score_non_binary_class

def test_auc_score_non_binary_class():
    # roc_auc_score must raise a ValueError whenever y_true is not a
    # two-class target. The identical sequence of checks is run twice:
    # once normally, and once more with warnings recorded after the
    # warning registry has been cleared.
    def _assert_non_binary_targets_rejected():
        rng = check_random_state(404)
        y_pred = rng.rand(10)
        # Degenerate targets that hold only a single class value.
        single_class_targets = (np.zeros(10, dtype="int"),
                                np.ones(10, dtype="int"),
                                -np.ones(10, dtype="int"))
        for y_true in single_class_targets:
            assert_raise_message(ValueError, "ROC AUC score is not defined",
                                 roc_auc_score, y_true, y_pred)
        # A target drawn from three class values (0, 1, 2).
        y_true = rng.randint(0, 3, size=10)
        assert_raise_message(ValueError, "multiclass format is not supported",
                             roc_auc_score, y_true, y_pred)

    _assert_non_binary_targets_rejected()

    clean_warning_registry()
    with warnings.catch_warnings(record=True):
        _assert_non_binary_targets_rejected()
开发者ID:jonathanwoodard,项目名称:scikit-learn,代码行数:31,代码来源:test_ranking.py

示例3: test_sample_weight_invariance

def test_sample_weight_invariance(n_samples=50):
    """Yield sample-weight invariance checks for every weight-aware metric.

    Binary, multiclass and multilabel-indicator targets are exercised in
    turn. Metrics in ``THRESHOLDED_METRICS`` are checked against continuous
    scores and all other metrics against hard predictions. Metrics listed in
    ``METRICS_WITHOUT_SAMPLE_WEIGHT`` are skipped everywhere.
    """
    # binary
    random_state = check_random_state(0)
    y_true = random_state.randint(0, 2, size=(n_samples, ))
    y_pred = random_state.randint(0, 2, size=(n_samples, ))
    y_score = random_state.random_sample(size=(n_samples,))
    for name in ALL_METRICS:
        if (name in METRICS_WITHOUT_SAMPLE_WEIGHT or
                name in METRIC_UNDEFINED_BINARY):
            continue
        metric = ALL_METRICS[name]
        # Thresholded metrics consume scores; the others consume labels.
        y_second = y_score if name in THRESHOLDED_METRICS else y_pred
        yield (_named_check(check_sample_weight_invariance, name), name,
               metric, y_true, y_second)

    # multiclass
    random_state = check_random_state(0)
    y_true = random_state.randint(0, 5, size=(n_samples, ))
    y_pred = random_state.randint(0, 5, size=(n_samples, ))
    y_score = random_state.random_sample(size=(n_samples, 5))
    for name in ALL_METRICS:
        if (name in METRICS_WITHOUT_SAMPLE_WEIGHT or
                name in METRIC_UNDEFINED_BINARY_MULTICLASS):
            continue
        metric = ALL_METRICS[name]
        y_second = y_score if name in THRESHOLDED_METRICS else y_pred
        yield (_named_check(check_sample_weight_invariance, name), name,
               metric, y_true, y_second)

    # multilabel indicator
    _, ya = make_multilabel_classification(n_features=1, n_classes=20,
                                           random_state=0, n_samples=100,
                                           allow_unlabeled=False)
    _, yb = make_multilabel_classification(n_features=1, n_classes=20,
                                           random_state=1, n_samples=100,
                                           allow_unlabeled=False)
    y_true = np.vstack([ya, yb])
    y_pred = np.vstack([ya, ya])
    # Integer scores in [1, 4). They are drawn from the generator left over
    # by the multiclass section, so they depend on the draws made there.
    y_score = random_state.randint(1, 4, size=y_true.shape)

    for name in (MULTILABELS_METRICS + THRESHOLDED_MULTILABEL_METRICS +
                 MULTIOUTPUT_METRICS):
        if name in METRICS_WITHOUT_SAMPLE_WEIGHT:
            continue
        metric = ALL_METRICS[name]
        y_second = y_score if name in THRESHOLDED_METRICS else y_pred
        yield (_named_check(check_sample_weight_invariance, name), name,
               metric, y_true, y_second)
开发者ID:Allenw3u,项目名称:scikit-learn,代码行数:60,代码来源:test_common.py

示例4: _predict_interval

 def _predict_interval(self, possible_intervals, rng=None):
     """Pick one interval out of ``possible_intervals`` according to ``self.method``.

     Parameters
     ----------
     possible_intervals : sequence
         Candidate intervals, indexed by integer position.
     rng : RandomState or None, optional
         Generator used when ``self.method == "random"``. When ``None``, one
         is built from ``self.random_state`` with ``check_random_state``.

     Returns
     -------
     The middle element for ``"center"``, an element at a position drawn
     with ``rng.randint`` for ``"random"``, and ``None`` for any other
     ``self.method`` value.
     """
     if self.method == "center":
         # Floor division keeps the index an int. On Python 3, ``/`` would
         # give a float and indexing a list with it raises TypeError.
         return possible_intervals[len(possible_intervals) // 2]
     elif self.method == "random":
         if rng is None:
             rng = check_random_state(self.random_state)
         return possible_intervals[rng.randint(len(possible_intervals))]
     # NOTE(review): any other method value falls through and returns None.
开发者ID:vene,项目名称:ambra,代码行数:7,代码来源:classifiers.py

示例5: __init__

 def __init__(self, shuffle_factor=0.05, not_shuffled_columns=None, random_state=None):
     """Store the shuffling configuration.

     Parameters
     ----------
     shuffle_factor : float, default 0.05
         Stored unchanged as ``self.shuffle_factor``.
     not_shuffled_columns : list or None
         Stored as ``self.not_shuffled_columns``. ``None`` is replaced by a
         fresh empty list, so instances never share a default list.
     random_state : None, int or RandomState
         Converted with ``check_random_state`` and stored.
     """
     self.shuffle_factor = shuffle_factor
     self.random_state = check_random_state(random_state)
     self.not_shuffled_columns = ([] if not_shuffled_columns is None
                                  else not_shuffled_columns)
开发者ID:remenska,项目名称:lhcb_trigger_ml,代码行数:7,代码来源:transformations.py

示例6: predict

 def predict(self, X, Y_possible):
     """Return one chosen interval per entry of ``Y_possible``.

     ``X`` is not used here. For ``self.method == "random"``, a single
     generator is built once from ``self.random_state`` and shared by every
     entry. Otherwise ``None`` is handed to ``_predict_interval``.
     """
     rng = (check_random_state(self.random_state)
            if self.method == "random" else None)
     return [self._predict_interval(candidates, rng)
             for candidates in Y_possible]
开发者ID:vene,项目名称:ambra,代码行数:7,代码来源:classifiers.py

示例7: test_iris

	def test_iris(self):
		"""Check consistency on dataset iris.

		Fits a CFClassifier on a shuffled copy of iris. Requires at least 97%
		training accuracy and a training ROC AUC of at least 0.97 on the
		samples whose target is not 2.
		"""
		# Load iris and permute the samples with a fixed seed.
		iris = datasets.load_iris()
		rng = check_random_state(0)
		perm = rng.permutation(iris.target.size)
		iris.data = iris.data[perm]
		iris.target = iris.target[perm]

		clf = CFClassifier("")
		clf.fit(iris.data, iris.target)

		# After fitting, clf.forest must name an existing file.
		self.assertTrue(os.path.isfile(clf.forest))

		preds = clf.predict(iris.data)

		# Fraction of training samples classified correctly.
		predicted_ratio = float(np.sum(preds == iris.target)) / float(len(iris.target))
		# The call form of print works on Python 2 and 3; the bare
		# ``print x`` statement is a SyntaxError on Python 3.
		print(predicted_ratio)

		self.assertGreaterEqual(predicted_ratio, .97)

		probs = clf.predict_proba(iris.data)

		# Keep only targets 0 and 1 so ROC AUC is a binary problem; column 1
		# of probs is used as the score for the positive class.
		bin_idx = iris.target != 2

		roc_auc = roc_auc_score(iris.target[bin_idx], probs[bin_idx, 1])

		self.assertGreaterEqual(roc_auc, .97)
开发者ID:0x0all,项目名称:CloudForest,代码行数:34,代码来源:test_CFClassifier.py

示例8: fit

    def fit(self, x, y):
        """Fit the MELM model on a binary dataset.

        Parameters
        ----------
        x : array-like
            Training samples.
        y : array-like
            Labels with at most two distinct values. The smaller one is stored
            as ``self.a`` and the larger as ``self.b``.

        Returns
        -------
        self
            The fitted estimator. Returning ``self`` follows the scikit-learn
            convention and allows chaining such as ``est.fit(x, y).predict(x)``.

        Raises
        ------
        NotImplementedError
            If ``y`` has more than two distinct labels, or if
            ``self.classifier`` is not one of ``'KDE'``, ``'SVM'``, ``'KNN'``.
        """
        if len(set(y)) > 2:
            raise NotImplementedError('Currently MELM supports only binary datasets')

        self.base_objective = DCS_kd(gamma=self.gamma, k=self.k,
                                     covariance_estimator=self.covariance_estimator)

        if self.classifier == 'KDE':
            self.clf = KDE(gamma=self.gamma)
        elif self.classifier == 'SVM':
            self.clf = SVM()
        elif self.classifier == 'KNN':
            self.clf = KNN()
        else:
            raise NotImplementedError('%s classifier is not implemented' % self.classifier)

        random_state = check_random_state(self.random_state)

        # NOTE(review): a single-class y gives a == b here; confirm whether
        # that case should be rejected as well.
        self.a = min(y)
        self.b = max(y)

        self.classes_ = np.array([self.a, self.b])

        # Fit self.w via _find_best_w, then train self.clf on self.transform(x).
        self.w = self._find_best_w(x, y, random_state)

        self.clf.fit(self.transform(x), y)
        return self
开发者ID:codeaudit,项目名称:melm,代码行数:27,代码来源:melm.py

示例9: test_symmetry

def test_symmetry():
    # A score/loss function must be symmetric in (y_true, y_pred) exactly
    # when it is listed in SYMMETRIC_METRICS.
    rng = check_random_state(0)
    y_true = rng.randint(0, 2, size=(20, ))
    y_pred = rng.randint(0, 2, size=(20, ))

    # Every metric must be classified somewhere, and none may be classified
    # as both symmetric and non-symmetric.
    covered = SYMMETRIC_METRICS.union(
        NOT_SYMMETRIC_METRICS, set(THRESHOLDED_METRICS),
        METRIC_UNDEFINED_BINARY_MULTICLASS)
    assert_equal(covered, set(ALL_METRICS))
    assert_equal(SYMMETRIC_METRICS.intersection(NOT_SYMMETRIC_METRICS), set())

    for metric_name in SYMMETRIC_METRICS:
        score_fn = ALL_METRICS[metric_name]
        assert_allclose(score_fn(y_true, y_pred), score_fn(y_pred, y_true),
                        err_msg="%s is not symmetric" % metric_name)

    for metric_name in NOT_SYMMETRIC_METRICS:
        score_fn = ALL_METRICS[metric_name]
        # The msg assignment only executes when assert_array_equal does NOT
        # raise, i.e. when the metric looks symmetric. In that case the
        # context manager reports "not raised", and (on Python 3's unittest)
        # includes this message.
        with assert_raises(AssertionError) as cm:
            assert_array_equal(score_fn(y_true, y_pred), score_fn(y_pred, y_true))
            cm.msg = ("%s seems to be symmetric" % metric_name)
开发者ID:SuryodayBasak,项目名称:scikit-learn,代码行数:30,代码来源:test_common.py

示例10: check_importances

def check_importances(name, criterion, X, y):
    # Fit the forest estimator registered under ``name`` and check that:
    # - exactly 3 of 10 feature importances exceed 0.1;
    # - importances computed with n_jobs=2 match the serial ones;
    # - importances are non-negative with sample weights and (nearly)
    #   unchanged when those weights are scaled.
    forest_cls = FOREST_ESTIMATORS[name]

    est = forest_cls(n_estimators=20, criterion=criterion, random_state=0)
    est.fit(X, y)
    baseline = est.feature_importances_
    informative_count = np.sum(baseline > 0.1)
    assert_equal(baseline.shape[0], 10)
    assert_equal(informative_count, 3)

    # Same fitted forest: read importances serially, then with n_jobs=2.
    serial_importances = est.feature_importances_
    est.set_params(n_jobs=2)
    parallel_importances = est.feature_importances_
    assert_array_almost_equal(serial_importances, parallel_importances)

    # Sample weights.
    weights = check_random_state(0).randint(1, 10, len(X))
    est = forest_cls(n_estimators=20, random_state=0, criterion=criterion)
    est.fit(X, y, sample_weight=weights)
    weighted_importances = est.feature_importances_
    assert_true(np.all(weighted_importances >= 0.0))

    for scale in (0.5, 10, 100):
        rescaled = forest_cls(n_estimators=20, random_state=0, criterion=criterion)
        rescaled.fit(X, y, sample_weight=scale * weights)
        drift = np.abs(weighted_importances - rescaled.feature_importances_).mean()
        assert_less(drift, 0.001)
开发者ID:nelson-liu,项目名称:scikit-learn,代码行数:28,代码来源:test_forest.py

示例11: test_RadiusNeighborsRegressor_multioutput_with_uniform_weight

def test_RadiusNeighborsRegressor_multioutput_with_uniform_weight():
    """Test radius neighbors in multi-output regression (uniform weight)"""
    # With uniform (or default) weights, the prediction for each test point
    # must equal the plain mean of its radius neighbours' training targets.
    rng = check_random_state(0)
    n_features = 5
    n_samples = 40
    n_output = 4

    X = rng.rand(n_samples, n_features)
    y = rng.rand(n_samples, n_output)
    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

    for algorithm, weights in product(ALGORITHMS, [None, 'uniform']):

        rnn = neighbors.RadiusNeighborsRegressor(weights=weights,
                                                 algorithm=algorithm)
        rnn.fit(X_train, y_train)

        # Reference prediction: average the targets of each point's
        # neighbourhood by hand.
        neigh_idx = rnn.radius_neighbors(X_test, return_distance=False)
        y_pred_idx = np.array([np.mean(y_train[idx], axis=0)
                               for idx in neigh_idx])

        y_pred = rnn.predict(X_test)

        assert_equal(y_pred_idx.shape, y_test.shape)
        assert_equal(y_pred.shape, y_test.shape)
        assert_array_almost_equal(y_pred, y_pred_idx)
开发者ID:93sam,项目名称:scikit-learn,代码行数:28,代码来源:test_neighbors.py

示例12: test_thresholded_invariance_string_vs_numbers_labels

def test_thresholded_invariance_string_vs_numbers_labels(name):
    # The thresholded metric ``name`` must give the same result whether the
    # true labels are the integers {0, 1} or the strings {"eggs", "spam"}
    # (both as a plain string array and as an object array). Metrics that
    # are undefined for binary input must instead reject string labels.
    rng = check_random_state(0)
    y1 = rng.randint(0, 2, size=(20, ))
    y2 = rng.randint(0, 2, size=(20, ))
    y1_str = np.array(["eggs", "spam"])[y1]
    pos_label_str = "spam"

    with ignore_warnings():
        metric = THRESHOLDED_METRICS[name]
        if name in METRIC_UNDEFINED_BINARY:
            # TODO: these metrics don't support string labels yet.
            assert_raises(ValueError, metric, y1_str, y2)
            assert_raises(ValueError, metric, y1_str.astype('O'), y2)
        else:
            # Metrics taking pos_label need it spelled out for string labels.
            if name in METRICS_WITH_POS_LABEL:
                metric_str = partial(metric, pos_label=pos_label_str)
            else:
                metric_str = metric

            reference = metric(y1, y2)
            assert_array_equal(reference, metric_str(y1_str, y2),
                               err_msg="{0} failed string vs number "
                                       "invariance test".format(name))
            assert_array_equal(reference, metric_str(y1_str.astype('O'), y2),
                               err_msg="{0} failed string object vs number "
                                       "invariance test".format(name))
开发者ID:allefpablo,项目名称:scikit-learn,代码行数:32,代码来源:test_common.py

示例13: test_binary_clf_curve

def test_binary_clf_curve():
    # precision_recall_curve must refuse a y_true drawn from three classes.
    rng = check_random_state(404)
    labels = rng.randint(0, 3, size=10)
    scores = rng.rand(10)
    assert_raise_message(ValueError, "multiclass format is not supported",
                         precision_recall_curve, labels, scores)
开发者ID:allefpablo,项目名称:scikit-learn,代码行数:7,代码来源:test_ranking.py

示例14: endless_permutations

def endless_permutations(N, random_state=None):
    """
    Generate an endless sequence of random integers from permutations of the
    set [0, ..., N).

    The first N values sweep through the entire set without replacement.
    The (N+1)th value starts a fresh permutation, and so on.

    Parameters
    ----------
    N: int
        the length of the set; must be at least 1
    random_state: None, int or RandomState, optional
        random seed or generator, passed to ``check_random_state``

    Yields
    ------
    int:
        a random int from the set [0, ..., N)

    Raises
    ------
    ValueError
        If N < 1 (raised on the first ``next()`` call). An empty permutation
        would otherwise make the generator loop forever without yielding.
    """
    if N < 1:
        raise ValueError("N must be a positive integer, got %r" % (N,))
    generator = check_random_state(random_state)
    while True:
        for index in generator.permutation(N):
            yield index
开发者ID:NICTA,项目名称:revrand,代码行数:25,代码来源:rand.py

示例15: check_alternative_lrap_implementation

def check_alternative_lrap_implementation(lrap_score, n_classes=5,
                                          n_samples=20, random_state=0):
    # Compare the LRAP implementation under test (``lrap_score``) with the
    # reference ``_my_lrap``. The scores are first sparse random values,
    # which contain ties, and then dense uniform values.
    # Previously ``lrap_score`` was ignored in favour of a hard-coded
    # label_ranking_average_precision_score; callers passing that function
    # see identical behaviour.
    _, y_true = make_multilabel_classification(n_features=1,
                                               allow_unlabeled=False,
                                               random_state=random_state,
                                               n_classes=n_classes,
                                               n_samples=n_samples)

    # Score with ties
    y_score = sparse_random_matrix(n_components=y_true.shape[0],
                                   n_features=y_true.shape[1],
                                   random_state=random_state)

    if hasattr(y_score, "toarray"):
        y_score = y_score.toarray()
    score_lrap = lrap_score(y_true, y_score)
    score_my_lrap = _my_lrap(y_true, y_score)
    assert_almost_equal(score_lrap, score_my_lrap)

    # Uniform score. A separate name keeps the ``random_state`` argument
    # distinct from the generator derived from it.
    rng = check_random_state(random_state)
    y_score = rng.uniform(size=(n_samples, n_classes))
    score_lrap = lrap_score(y_true, y_score)
    score_my_lrap = _my_lrap(y_true, y_score)
    assert_almost_equal(score_lrap, score_my_lrap)
开发者ID:BTY2684,项目名称:scikit-learn,代码行数:25,代码来源:test_ranking.py


注:本文中的sklearn.utils.validation.check_random_state函数示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。