

Python datasets.fetch_20newsgroups method code examples

This article collects typical usage examples of the sklearn.datasets.fetch_20newsgroups method in Python. If you are wondering what datasets.fetch_20newsgroups does, how to call it, or simply want to see it used in real code, the hand-picked examples below should help. You can also explore further usage examples from the sklearn.datasets module.


Below are 15 code examples of the datasets.fetch_20newsgroups method, sorted by popularity by default.
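
Before the project-specific examples, here is a minimal standalone sketch of the call itself (assuming scikit-learn is installed and the dataset can be downloaded; the category and variable names are illustrative). The subset, categories, and remove parameters are the same ones that recur throughout the examples below.

from sklearn.datasets import fetch_20newsgroups

# Fetch two topics from the training split, stripping headers, footers and
# quoted replies so that a classifier cannot latch onto metadata.
categories = ['alt.atheism', 'sci.space']
train = fetch_20newsgroups(subset='train',
                           categories=categories,
                           remove=('headers', 'footers', 'quotes'))

print(len(train.data))       # raw post texts (list of str)
print(train.target[:10])     # integer labels indexing into train.target_names
print(train.target_names)    # ['alt.atheism', 'sci.space']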

Example 1: test_validate_sklearn_sgd_with_text_cv

# Required module import: from sklearn import datasets [as alias]
# Or alternatively: from sklearn.datasets import fetch_20newsgroups [as alias]
def test_validate_sklearn_sgd_with_text_cv(self):
        categories = ['alt.atheism','talk.religion.misc']
        data = fetch_20newsgroups(subset='train', categories=categories)
        X = data.data[:4]
        Y = data.target[:4]
        features = ['input']
        target = 'output'
        model = SGDClassifier(loss="log")
        file_name = model.__class__.__name__ + '_CountVec_.pmml'
        pipeline = Pipeline([
            ('vect', CountVectorizer()),
            ('clf', model)
        ])
        pipeline.fit(X, Y)
        skl_to_pmml(pipeline, features , target, file_name)
        self.assertEqual(self.schema.is_valid(file_name), True) 
Author: nyoka-pmml, Project: nyoka, Lines of code: 18, Source file: _validateSchema.py

Example 2: __init__

# Required module import: from sklearn import datasets [as alias]
# Or alternatively: from sklearn.datasets import fetch_20newsgroups [as alias]
def __init__(self,
                 cache: bool = False,
                 transform: Dict[str, Union[Field, Dict]] = None) -> None:
        """Initialize the NewsGroupDataset builtin."""
        try:
            from sklearn.datasets import fetch_20newsgroups
        except ImportError:
            raise ImportError("Install sklearn to use the NewsGroupDataset")

        train = fetch_20newsgroups(subset='train')
        test = fetch_20newsgroups(subset='test')

        train = [(' '.join(d.split()), str(t)) for d, t in zip(train['data'], train['target'])]
        test = [(' '.join(d.split()), str(t)) for d, t in zip(test['data'], test['target'])]

        named_cols = ['text', 'label']
        super().__init__(
            train=train,
            val=None,
            test=test,
            cache=cache,
            named_columns=named_cols,
            transform=transform
        ) 
Author: asappresearch, Project: flambe, Lines of code: 26, Source file: datasets.py

Example 3: load

# Required module import: from sklearn import datasets [as alias]
# Or alternatively: from sklearn.datasets import fetch_20newsgroups [as alias]
def load(self):
        categories = ['comp.sys.ibm.pc.hardware', 'comp.sys.mac.hardware']
        newsgroups_train = fetch_20newsgroups(
            subset='train', remove=('headers', 'footers', 'quotes'), categories=categories)
        newsgroups_test = fetch_20newsgroups(
            subset='test', remove=('headers', 'footers', 'quotes'), categories=categories)
        vectorizer = TfidfVectorizer(stop_words='english', min_df=0.001, max_df=0.20)
        vectors = vectorizer.fit_transform(newsgroups_train.data)
        vectors_test = vectorizer.transform(newsgroups_test.data)
        x1 = vectors
        y1 = newsgroups_train.target
        x2 = vectors_test
        y2 = newsgroups_test.target
        x = np.array(np.r_[x1.todense(), x2.todense()])
        y = np.r_[y1, y2]
        return x, y 
Author: sato9hara, Project: sgd-influence, Lines of code: 18, Source file: DataModule.py

Example 4: _te_ss_t_build

# Required module import: from sklearn import datasets [as alias]
# Or alternatively: from sklearn.datasets import fetch_20newsgroups [as alias]
def _te_ss_t_build(self):
		from sklearn.datasets import fetch_20newsgroups
		from sklearn.feature_extraction.text import CountVectorizer

		newsgroups_train = fetch_20newsgroups(subset='train', remove=('headers', 'footers', 'quotes'))
		count_vectorizer = CountVectorizer()
		X_counts = count_vectorizer.fit_transform(newsgroups_train.data)
		corpus = CorpusFromScikit(
			X=X_counts,
			y=newsgroups_train.target,
			feature_vocabulary=count_vectorizer.vocabulary_,
			category_names=newsgroups_train.target_names,
			raw_texts=newsgroups_train.data
		).build()
		self.assertEqual(corpus.get_categories()[:2], ['alt.atheism', 'comp.graphics'])
		self.assertEqual(corpus
		                 .get_term_freq_df()
		                 .assign(score=corpus.get_scaled_f_scores('alt.atheism'))
		                 .sort_values(by='score', ascending=False).index.tolist()[:5],
		                 ['atheism', 'atheists', 'islam', 'atheist', 'belief'])
		self.assertGreater(len(corpus.get_texts()[0]), 5) 
Author: JasonKessler, Project: scattertext, Lines of code: 23, Source file: test_corpusFromScikit.py

Example 5: test_MinHashEncoder

# Required module import: from sklearn import datasets [as alias]
# Or alternatively: from sklearn.datasets import fetch_20newsgroups [as alias]
def test_MinHashEncoder(n_sample=70, minmax_hash=False):
    X_txt = fetch_20newsgroups(subset='train')['data']
    X = X_txt[:n_sample]

    for minmax_hash in [True, False]:
        for hashing in ['fast', 'murmur']:

            if minmax_hash and hashing == 'murmur':
                pass # not implemented

            # Test output shape
            encoder = MinHashEncoder(n_components=50, hashing=hashing)
            encoder.fit(X)
            y = encoder.transform(X)
            assert y.shape == (n_sample, 50), str(y.shape)
            assert len(set(y[0])) == 50

            # Test same seed return the same output
            encoder = MinHashEncoder(50, hashing=hashing)
            encoder.fit(X)
            y2 = encoder.transform(X)
            np.testing.assert_array_equal(y, y2)

            # Test min property
            if not minmax_hash:
                X_substring = [x[:x.find(' ')] for x in X]
                encoder = MinHashEncoder(50, hashing=hashing)
                encoder.fit(X_substring)
                y_substring = encoder.transform(X_substring)
                np.testing.assert_array_less(y - y_substring, 0.0001) 
Author: dirty-cat, Project: dirty_cat, Lines of code: 32, Source file: test_minhash_encoder.py

Example 6: test_validate_sklearn_sgd_with_text

# Required module import: from sklearn import datasets [as alias]
# Or alternatively: from sklearn.datasets import fetch_20newsgroups [as alias]
def test_validate_sklearn_sgd_with_text(self):
        categories = ['alt.atheism','talk.religion.misc']
        data = fetch_20newsgroups(subset='train', categories=categories)
        X = data.data[:4]
        Y = data.target[:4]
        features = ['input']
        target = 'output'
        model = SGDClassifier(loss="log")
        file_name = model.__class__.__name__ + '_TfIdfVec_.pmml'
        pipeline = Pipeline([
            ('vect', TfidfVectorizer()),
            ('clf', model)
        ])
        pipeline.fit(X, Y)
        skl_to_pmml(pipeline, features , target, file_name)
        self.assertEqual(self.schema.is_valid(file_name), True) 
Author: nyoka-pmml, Project: nyoka, Lines of code: 18, Source file: _validateSchema.py

Example 7: ng

# Required module import: from sklearn import datasets [as alias]
# Or alternatively: from sklearn.datasets import fetch_20newsgroups [as alias]
def ng(partitions=['train', 'test']):
  '''loads 20 NewsGroups topic classification dataset
  Args:
    partitions: component(s) of data to load; can be a string (for one partition) or list of strings
  Returns:
    ((list of documents, list of labels) for each partition)
  '''

  if type(partitions) == str:
    data = fetch_20newsgroups(subset=partitions)
    return data['data'], list(data['target'])
  output = []
  for partition in partitions:
    data = fetch_20newsgroups(subset=partition)
    output.append((data['data'], list(data['target'])))
  return output 
Author: NLPrinceton, Project: text_embedding, Lines of code: 18, Source file: documents.py
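
A short usage note for the helper above, following directly from its docstring (a sketch; nothing here goes beyond what the function itself defines):

# Load both partitions at once; each element is a (documents, labels) pair.
(train_docs, train_labels), (test_docs, test_labels) = ng(['train', 'test'])
# Or load a single partition:
docs, labels = ng('train')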

Example 8: create_binary_newsgroups_data

# Required module import: from sklearn import datasets [as alias]
# Or alternatively: from sklearn.datasets import fetch_20newsgroups [as alias]
def create_binary_newsgroups_data():
    categories = ["alt.atheism", "soc.religion.christian"]
    newsgroups_train = fetch_20newsgroups(subset="train", categories=categories)
    newsgroups_test = fetch_20newsgroups(subset="test", categories=categories)
    class_names = ["atheism", "christian"]
    return newsgroups_train, newsgroups_test, class_names 
Author: interpretml, Project: interpret-text, Lines of code: 8, Source file: common_utils.py

Example 9: fetch_data

# Required module import: from sklearn import datasets [as alias]
# Or alternatively: from sklearn.datasets import fetch_20newsgroups [as alias]
def fetch_data(path):
    from sklearn.datasets import fetch_20newsgroups
    categories = ['comp.graphics', 'rec.sport.baseball', 'talk.politics.guns']
    dataset = fetch_20newsgroups(path, categories=categories)
    return dataset 
Author: thunlp, Project: OpenNE, Lines of code: 7, Source file: 20newsgroup.py

Example 10: test_fast_hash

# Required module import: from sklearn import datasets [as alias]
# Or alternatively: from sklearn.datasets import fetch_20newsgroups [as alias]
def test_fast_hash():

    from sklearn import datasets
    data = datasets.fetch_20newsgroups()
    a = data.data[0]

    min_hash = ngram_min_hash(a, seed=0)
    min_hash2 = ngram_min_hash(a, seed=0)
    assert min_hash == min_hash2

    list_min_hash = [ngram_min_hash(a, seed=seed) for seed in range(50)]
    assert len(set(list_min_hash)) > 45, 'Too many hash collisions'

    min_hash4 = ngram_min_hash(a, seed=0, return_minmax=True)
    assert len(min_hash4) == 2 
Author: dirty-cat, Project: dirty_cat, Lines of code: 17, Source file: test_fast_hash.py

Example 11: test_20news

# Required module import: from sklearn import datasets [as alias]
# Or alternatively: from sklearn.datasets import fetch_20newsgroups [as alias]
def test_20news():
    try:
        data = datasets.fetch_20newsgroups(
            subset='all', download_if_missing=False, shuffle=False)
    except IOError:
        raise SkipTest("Download 20 newsgroups to run this test")

    # Extract a reduced dataset
    data2cats = datasets.fetch_20newsgroups(
        subset='all', categories=data.target_names[-1:-3:-1], shuffle=False)
    # Check that the ordering of the target_names is the same
    # as the ordering in the full dataset
    assert_equal(data2cats.target_names,
                 data.target_names[-2:])
    # Assert that we have only 0 and 1 as labels
    assert_equal(np.unique(data2cats.target).tolist(), [0, 1])

    # Check that the number of filenames is consistent with data/target
    assert_equal(len(data2cats.filenames), len(data2cats.target))
    assert_equal(len(data2cats.filenames), len(data2cats.data))

    # Check that the first entry of the reduced dataset corresponds to
    # the first entry of the corresponding category in the full dataset
    entry1 = data2cats.data[0]
    category = data2cats.target_names[data2cats.target[0]]
    label = data.target_names.index(category)
    entry2 = data.data[np.where(data.target == label)[0][0]]
    assert_equal(entry1, entry2) 
Author: PacktPublishing, Project: Mastering-Elasticsearch-7.0, Lines of code: 30, Source file: test_20news.py

Example 12: test_20news_length_consistency

# Required module import: from sklearn import datasets [as alias]
# Or alternatively: from sklearn.datasets import fetch_20newsgroups [as alias]
def test_20news_length_consistency():
    """Checks the length consistencies within the bunch

    This is a non-regression test for a bug present in 0.16.1.
    """
    try:
        data = datasets.fetch_20newsgroups(
            subset='all', download_if_missing=False, shuffle=False)
    except IOError:
        raise SkipTest("Download 20 newsgroups to run this test")
    # Extract the full dataset
    data = datasets.fetch_20newsgroups(subset='all')
    assert_equal(len(data['data']), len(data.data))
    assert_equal(len(data['target']), len(data.target))
    assert_equal(len(data['filenames']), len(data.filenames)) 
Author: PacktPublishing, Project: Mastering-Elasticsearch-7.0, Lines of code: 17, Source file: test_20news.py

Example 13: test_20news_vectorized

# Required module import: from sklearn import datasets [as alias]
# Or alternatively: from sklearn.datasets import fetch_20newsgroups [as alias]
def test_20news_vectorized():
    try:
        datasets.fetch_20newsgroups(subset='all',
                                    download_if_missing=False)
    except IOError:
        raise SkipTest("Download 20 newsgroups to run this test")

    # test subset = train
    bunch = datasets.fetch_20newsgroups_vectorized(subset="train")
    assert sp.isspmatrix_csr(bunch.data)
    assert_equal(bunch.data.shape, (11314, 130107))
    assert_equal(bunch.target.shape[0], 11314)
    assert_equal(bunch.data.dtype, np.float64)

    # test subset = test
    bunch = datasets.fetch_20newsgroups_vectorized(subset="test")
    assert sp.isspmatrix_csr(bunch.data)
    assert_equal(bunch.data.shape, (7532, 130107))
    assert_equal(bunch.target.shape[0], 7532)
    assert_equal(bunch.data.dtype, np.float64)

    # test return_X_y option
    fetch_func = partial(datasets.fetch_20newsgroups_vectorized, subset='test')
    check_return_X_y(bunch, fetch_func)

    # test subset = all
    bunch = datasets.fetch_20newsgroups_vectorized(subset='all')
    assert sp.isspmatrix_csr(bunch.data)
    assert_equal(bunch.data.shape, (11314 + 7532, 130107))
    assert_equal(bunch.target.shape[0], 11314 + 7532)
    assert_equal(bunch.data.dtype, np.float64) 
Author: PacktPublishing, Project: Mastering-Elasticsearch-7.0, Lines of code: 33, Source file: test_20news.py
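
As a quick complement to the test above, a minimal sketch of the return_X_y option of fetch_20newsgroups_vectorized (assuming the dataset is already cached locally); the printed shapes match the ones asserted in the test.

from sklearn.datasets import fetch_20newsgroups_vectorized

# Returns the pre-vectorized sparse feature matrix and the label vector
# directly, skipping the Bunch container.
X, y = fetch_20newsgroups_vectorized(subset='train', return_X_y=True)
print(X.shape, y.shape)   # (11314, 130107) (11314,)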

Example 14: setUp

# Required module import: from sklearn import datasets [as alias]
# Or alternatively: from sklearn.datasets import fetch_20newsgroups [as alias]
def setUp(self):
        """Carga de los datos de prueba (20 Newsgroups corpus)."""
        newsdata = fetch_20newsgroups(data_home="./data/")
        self.ids = [str(i) for i in range(len(newsdata.target))]
        self.texts = newsdata.data
        self.labels = [newsdata.target_names[idx] for idx in newsdata.target]
        self.tc = TextClassifier(self.texts, self.ids) 
Author: datosgobar, Project: textar, Lines of code: 9, Source file: test_text_classifier.py

Example 15: load_newsgroups

# Required module import: from sklearn import datasets [as alias]
# Or alternatively: from sklearn.datasets import fetch_20newsgroups [as alias]
def load_newsgroups():
    """20 News Groups Dataset.

    The data of this dataset is a 1d numpy array vector containing the texts
    from 11314 newsgroups posts, and the target is a 1d numpy integer array
    containing the label of one of the 20 topics that they are about.
    """
    dataset = datasets.fetch_20newsgroups()
    return Dataset(load_newsgroups.__doc__, np.array(dataset.data), dataset.target,
                   accuracy_score, stratify=True) 
Author: HDI-Project, Project: MLBlocks, Lines of code: 12, Source file: datasets.py


Note: The sklearn.datasets.fetch_20newsgroups examples in this article were collected by 纯净天空 from open-source code and documentation platforms such as GitHub and MSDocs. The snippets are taken from open-source projects contributed by their respective developers, and copyright of the source code remains with the original authors; consult the corresponding project's License before distributing or using it. Do not reproduce without permission.