当前位置: 首页>>代码示例>>Python>>正文


Python Classifier.predict方法代码示例

本文整理汇总了Python中weka.classifiers.Classifier.predict方法的典型用法代码示例。如果您正苦于以下问题：Python Classifier.predict方法的具体用法是什么？Python Classifier.predict怎么用？有没有Python Classifier.predict的使用示例？那么这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在类weka.classifiers.Classifier的用法示例。


在下文中一共展示了Classifier.predict方法的2个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者觉得有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: WekaRandomForestClassifier

# 需要导入模块: from weka.classifiers import Classifier [as 别名]
# 注意: predict 是 Classifier 类的实例方法, 无法单独导入 (weka.classifiers.Classifier 是类而非模块); 请先创建 Classifier 实例再调用 predict
class WekaRandomForestClassifier(BaseEstimator, ClassifierMixin):
    """scikit-learn style wrapper around Weka's ``RandomForest`` classifier.

    Data is handed to Weka through temporary ARFF files produced by
    ``to_arff``; training and prediction go through
    ``weka.classifiers.Classifier``.

    Parameters
    ----------
    n_estimators : int, default=10
        Number of trees (sent to Weka as ``-I``).
    max_depth : int or None, default=None
        Maximum tree depth (sent as ``-depth``); ``None`` is sent as ``0``,
        which Weka treats as unlimited depth.
    max_features : {"auto", "sqrt", "log2"}, int, float or None, default="auto"
        Number of features tried per split (sent as ``-K``):
        "auto"/"sqrt" -> sqrt(n_features), "log2" -> log2(n_features),
        None -> n_features, int -> used as given, float -> that fraction of
        n_features. String and float settings are clamped to at least 1.
    random_state : anything accepted by ``check_random_state``, default=None
        Source of the integer seed passed to Weka.
    """

    # Preferred directory for the temporary ARFF files (typically a RAM-backed
    # tmpfs on Linux). The default temp dir is used when it does not exist.
    _ARFF_DIR = "/dev/shm"

    def __init__(self, n_estimators=10,
                 max_depth=None,
                 max_features="auto",
                 random_state=None):
        self.n_estimators = n_estimators
        self.max_depth = max_depth
        self.max_features = max_features
        self.random_state = random_state

    def _resolve_max_features(self):
        """Translate ``self.max_features`` into a concrete feature count.

        Requires ``self.n_features_`` to be set. Raises ValueError for an
        unrecognised string value.
        """
        n_features = self.n_features_
        if isinstance(self.max_features, str):
            if self.max_features in ("auto", "sqrt"):
                return max(1, int(np.sqrt(n_features)))
            if self.max_features == "log2":
                return max(1, int(np.log2(n_features)))
            raise ValueError(
                'Invalid value for max_features. Allowed string '
                'values are "auto", "sqrt" or "log2".')
        if self.max_features is None:
            return n_features
        if isinstance(self.max_features, (numbers.Integral, np.integer)):
            return self.max_features
        # Float: a fraction of the features. Clamp so that a small fraction
        # never produces 0, matching the string branches above.
        return max(1, int(self.max_features * n_features))

    @classmethod
    def _write_arff(cls, X, y, n_classes):
        """Write (X, y) to a fresh temporary ARFF file and return its path.

        The caller owns the returned file and must delete it. If writing
        fails, the partial file is removed before the exception propagates.
        """
        tmp_dir = cls._ARFF_DIR if os.path.isdir(cls._ARFF_DIR) else None
        tf = tempfile.NamedTemporaryFile(mode="w", suffix=".arff",
                                         dir=tmp_dir, delete=False)
        try:
            with tf:  # closes (but, with delete=False, keeps) the file
                to_arff(X, y, n_classes, tf)
        except BaseException:
            os.remove(tf.name)
            raise
        return tf.name

    def fit(self, X, y):
        """Train a Weka RandomForest on (X, y).

        Parameters
        ----------
        X : 2-D array-like; ``X.shape[1]`` is taken as the number of features.
        y : labels, one per row of X. They are encoded as indices into
            ``self.classes_`` (the sorted unique labels) before export.

        Returns
        -------
        self
        """
        self.n_features_ = X.shape[1]
        random_state = check_random_state(self.random_state)
        max_features = self._resolve_max_features()

        # Weka command-line options; insertion order is kept as-is in case
        # the wrapper forwards them in order.
        # NOTE(review): Weka documents RandomForest's seed flag as "-S";
        # confirm that "-s" is honoured by the Weka version in use.
        params = {}
        params["-I"] = self.n_estimators
        params["-K"] = max_features
        params["-depth"] = 0 if self.max_depth is None else self.max_depth
        params["-no-cv"] = None
        params["-s"] = random_state.randint(1000000)

        # Encode labels as 0..n_classes-1 for the ARFF export.
        self.classes_ = np.unique(y)
        self.n_classes_ = len(self.classes_)
        y_encoded = np.searchsorted(self.classes_, y)

        path = self._write_arff(X, y_encoded, self.n_classes_)
        try:
            self.model_ = Classifier(name="weka.classifiers.trees.RandomForest",
                                     ckargs=params)
            self.model_.train(path)
        finally:
            # Remove the training file even if Weka fails.
            os.remove(path)

        return self

    def predict(self, X):
        """Predict class labels for each row of X.

        Returns an array of ``len(X)`` labels taken from ``self.classes_``.
        """
        path = self._write_arff(X, None, self.n_classes_)
        try:
            pred = np.zeros(len(X), dtype=np.int32)
            for i, result in enumerate(self.model_.predict(path)):
                # NOTE(review): assumes to_arff names class k as a 5-character
                # prefix followed by k (e.g. "class12"); confirm. Slicing from
                # index 5 keeps every digit, so indices >= 10 decode correctly.
                pred[i] = int(result.predicted[5:])
        finally:
            # Remove the query file even if prediction fails.
            os.remove(path)

        return self.classes_[pred]
开发者ID:ASBoldt,项目名称:phd-thesis,代码行数:71,代码来源:wrapper.py

示例2: test_IBk

# 需要导入模块: from weka.classifiers import Classifier [as 别名]
# 注意: predict 是 Classifier 类的实例方法, 无法单独导入 (weka.classifiers.Classifier 是类而非模块); 请先创建 Classifier 实例再调用 predict
    def test_IBk(self):
        """Train IBk on the abalone fixtures and check its predictions.

        Covers: querying from an ARFF file path; a malformed query file,
        which must raise PredictionError; an ArffFile built from a positional
        row, including its DENSE serialisation; a save/load round-trip of the
        trained model; and an ArffFile built from a dict row.
        """
        expected = PredictionResult(actual=None, predicted=7, probability=None)

        def make_query():
            """Return an empty ArffFile with the abalone query schema."""
            return arff.ArffFile(relation='test', schema=[
                ('Sex', ('M', 'F', 'I')),
                ('Length', 'numeric'),
                ('Diameter', 'numeric'),
                ('Height', 'numeric'),
                ('Whole weight', 'numeric'),
                ('Shucked weight', 'numeric'),
                ('Viscera weight', 'numeric'),
                ('Shell weight', 'numeric'),
                ('Class_Rings', 'integer'),
            ])

        # Train a classifier.
        print('Training IBk classifier...')
        c = Classifier(name='weka.classifiers.lazy.IBk', ckargs={'-K': 1})
        training_fn = os.path.join(BP, 'fixtures/abalone-train.arff')
        c.train(training_fn, verbose=1)
        self.assertTrue(c._model_data)

        # Make a valid query from an ARFF file on disk.
        print('Using IBk classifier...')
        query_fn = os.path.join(BP, 'fixtures/abalone-query.arff')
        predictions = list(c.predict(query_fn, verbose=1, cleanup=0))
        print('pred0:', predictions[0])
        print('pred1:', expected)
        self.assertEqual(predictions[0], expected)

        # A malformed query file must raise PredictionError. Only the predict
        # call is inside the context, so nothing else can satisfy it.
        bad_query_fn = os.path.join(BP, 'fixtures/abalone-query-bad.arff')
        with self.assertRaises(PredictionError):
            list(c.predict(bad_query_fn, verbose=1, cleanup=0))

        # Make a valid query manually from a positional row ('?' leaves the
        # class value missing) and check its DENSE serialisation.
        query = make_query()
        query.append(['M', 0.35, 0.265, 0.09, 0.2255, 0.0995, 0.0485, 0.07, '?'])
        # The trailing space after '%' on the first line is part of the
        # expected output, so it is written explicitly.
        expected_arff = (
            "% \n"
            "@relation test\n"
            "@attribute 'Sex' {F,I,M}\n"
            "@attribute 'Length' numeric\n"
            "@attribute 'Diameter' numeric\n"
            "@attribute 'Height' numeric\n"
            "@attribute 'Whole weight' numeric\n"
            "@attribute 'Shucked weight' numeric\n"
            "@attribute 'Viscera weight' numeric\n"
            "@attribute 'Shell weight' numeric\n"
            "@attribute 'Class_Rings' integer\n"
            "@data\n"
            "M,0.35,0.265,0.09,0.2255,0.0995,0.0485,0.07,?\n"
        )
        self.assertEqual(expected_arff, query.write(fmt=DENSE))
        predictions = list(c.predict(query, verbose=1, cleanup=0))
        self.assertEqual(predictions[0], expected)

        # Test pickling: save the trained model, reload it, query again.
        fn = os.path.join(BP, 'fixtures/IBk.pkl')
        c.save(fn)
        c = Classifier.load(fn)
        predictions = list(c.predict(query, verbose=1, cleanup=0))
        self.assertEqual(predictions[0], expected)

        # Make a valid query manually from a dict row, using the reloaded model.
        query = make_query()
        query.append({
            'Sex': 'M',
            'Length': 0.35,
            'Diameter': 0.265,
            'Height': 0.09,
            'Whole weight': 0.2255,
            'Shucked weight': 0.0995,
            'Viscera weight': 0.0485,
            'Shell weight': 0.07,
            'Class_Rings': arff.MISSING,
        })
        predictions = list(c.predict(query, verbose=1, cleanup=0))
        self.assertEqual(predictions[0], expected)
开发者ID:chrisspen,项目名称:weka,代码行数:96,代码来源:tests.py


注：本文中的weka.classifiers.Classifier.predict方法示例由纯净天空整理自GitHub/MSDocs等开源代码及文档管理平台，相关代码片段筛选自各路编程大神贡献的开源项目，源码版权归原作者所有，传播和使用请参考对应项目的License；未经允许，请勿转载。