当前位置: 首页>>代码示例>>C++>>正文


C++ CvDTree::train方法代码示例

本文整理汇总了C++中CvDTree::train方法的典型用法代码示例。如果您正苦于以下问题:C++ CvDTree::train方法的具体用法?C++ CvDTree::train怎么用?C++ CvDTree::train使用的例子?那么,这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在类CvDTree的用法示例。


在下文中一共展示了CvDTree::train方法的10个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的C++代码示例。

示例1: Train_tree

void Model::Train_tree( const SampleSet& samples )
{
	CvDTree* model = (CvDTree*)m_pModel;
	CvDTreeParams* para = (CvDTreeParams*)m_trainPara;
	model->train(samples.Samples(), CV_ROW_SAMPLE, samples.Labels(), 
		cv::Mat(), cv::Mat(), cv::Mat(), cv::Mat(), *para);
}
开发者ID:ElmerNing,项目名称:OpencvML,代码行数:7,代码来源:Model.cpp

示例2: mushroom_create_dtree

// Trains a decision tree on the mushroom data and prints its hit-rate on the
// training set.
//
//   data      - training samples, one sample per row (CV_ROW_SAMPLE layout).
//   missing   - missing-value mask laid out like data; forwarded to train() and
//               predict(). A NULL mask is tolerated by the hit-rate loop below.
//   responses - expected class of every sample, read as floats and compared
//               against the character code 'p' (poisonous).
//   p_weight  - prior of the second class; the bigger it is, the more attention
//               the tree pays to the poisonous mushrooms.
//
// Returns the trained tree, allocated with new; it is not released here, so
// the caller is responsible for deleting it.
CvDTree* mushroom_create_dtree( const CvMat* data, const CvMat* missing,
                                const CvMat* responses, float p_weight )
{
    float priors[] = { 1, p_weight };

    // one type flag per input variable plus one for the response
    CvMat* var_type = cvCreateMat( data->cols + 1, 1, CV_8U );
    cvSet( var_type, cvScalarAll(CV_VAR_CATEGORICAL) ); // all the variables are categorical

    CvDTree* dtree = new CvDTree;

    dtree->train( data, CV_ROW_SAMPLE, responses, 0, 0, var_type, missing,
                  CvDTreeParams( 8, // max depth
                                 10, // min sample count
                                 0, // regression accuracy: N/A here
                                 true, // compute surrogate split, as we have missing data
                                 15, // max number of categories (use sub-optimal algorithm for larger numbers)
                                 10, // the number of cross-validation folds
                                 true, // use 1SE rule => smaller tree
                                 true, // throw away the pruned tree branches
                                 priors // the array of priors, the bigger p_weight, the more attention
                                        // to the poisonous mushrooms
                                        // (a mushroom will be judged to be poisonous with bigger chance)
                                 ));

    // compute hit-rate on the training database, demonstrates predict usage.
    int hr1 = 0;     // poisonous samples predicted as something else
    int hr2 = 0;     // other samples predicted as poisonous (false alarms)
    int p_total = 0; // number of poisonous samples
    for( int i = 0; i < data->rows; i++ )
    {
        CvMat sample, mask;
        cvGetRow( data, &sample, i );
        // the mask is optional: only slice it when the caller supplied one
        if( missing )
            cvGetRow( missing, &mask, i );
        double r = dtree->predict( &sample, missing ? &mask : 0 )->value;
        int d = fabs(r - responses->data.fl[i]) >= FLT_EPSILON;
        if( d )
        {
            if( r != 'p' )
                hr1++;
            else
                hr2++;
        }
        p_total += responses->data.fl[i] == 'p';
    }

    // Report 0% instead of dividing by zero when one of the classes is absent.
    const int e_total = data->rows - p_total;
    const double hr1_percent = p_total > 0 ? (double)hr1*100/p_total : 0.;
    const double hr2_percent = e_total > 0 ? (double)hr2*100/e_total : 0.;

    printf( "Results on the training database:\n"
            "\tPoisonous mushrooms mis-predicted: %d (%g%%)\n"
            "\tFalse-alarms: %d (%g%%)\n", hr1, hr1_percent,
            hr2, hr2_percent );

    cvReleaseMat( &var_type );

    return dtree;
}
开发者ID:AndrewShmig,项目名称:FaceDetect,代码行数:54,代码来源:mushroom.cpp

示例3: main

// Loads a CSV data set, splits it into a training part of
// train_sample_count samples and a test part, then trains several tree-based
// models on it. For each model the two calc_error() results and, where
// available, the variable importance are handed to print_result().
//
// Returns 0 on success and -1 when the data file cannot be read.
int main()
{
    const int train_sample_count = 300;

//#define LEPIOTA
#ifdef LEPIOTA
    const char* filename = "../../../OpenCV_SVN/samples/c/agaricus-lepiota.data";
#else
    const char* filename = "../../../OpenCV_SVN/samples/c/waveform.data";
#endif

    CvDTree dtree;
    CvBoost boost;
    CvRTrees rtrees;
    CvERTrees ertrees;

    CvMLData data;

    CvTrainTestSplit spl( train_sample_count );

    // Stop early when the file is missing or unreadable instead of training
    // on an empty data set.
    if( data.read_csv( filename ) != 0 )
    {
        printf( "couldn't read %s\n", filename );
        return -1;
    }

#ifdef LEPIOTA
    data.set_response_idx( 0 );
#else
    data.set_response_idx( 21 );
    // the response column holds class labels, not ordered values
    data.change_var_type( 21, CV_VAR_CATEGORICAL );
#endif

    data.set_train_test_split( &spl );

    printf("======DTREE=====\n");
    dtree.train( &data, CvDTreeParams( 10, 2, 0, false, 16, 0, false, false, 0 ));
    print_result( dtree.calc_error( &data, CV_TRAIN_ERROR), dtree.calc_error( &data ), dtree.get_var_importance() );

#ifdef LEPIOTA
    printf("======BOOST=====\n");
    boost.train( &data, CvBoostParams(CvBoost::DISCRETE, 100, 0.95, 2, false, 0));
    print_result( boost.calc_error( &data, CV_TRAIN_ERROR ), boost.calc_error( &data ), 0 );
#endif

    printf("======RTREES=====\n");
    rtrees.train( &data, CvRTParams( 10, 2, 0, false, 16, 0, true, 0, 100, 0, CV_TERMCRIT_ITER ));
    print_result( rtrees.calc_error( &data, CV_TRAIN_ERROR), rtrees.calc_error( &data ), rtrees.get_var_importance() );

    printf("======ERTREES=====\n");
    ertrees.train( &data, CvRTParams( 10, 2, 0, false, 16, 0, true, 0, 100, 0, CV_TERMCRIT_ITER ));
    print_result( ertrees.calc_error( &data, CV_TRAIN_ERROR), ertrees.calc_error( &data ), ertrees.get_var_importance() );

    return 0;
}
开发者ID:glo,项目名称:ee384b,代码行数:51,代码来源:tree_engine.cpp

示例4: decisiontree

// Decision tree.
// Trains a CvDTree on trainingData/trainingClasses (one sample per row),
// predicts a value for every row of testData, prints the score returned by
// evaluate() and passes the predictions to plot_binary().
// Assumes testData and testClasses have the same number of rows, since
// predicted is sized from testClasses but filled per row of testData.
void decisiontree( Mat& trainingData, Mat& trainingClasses, Mat& testData,
		Mat& testClasses ) {
	CvDTree dtree;
	// One type flag per input variable plus a final one for the output, all
	// numerical. The matrix is a column of 8-bit elements, so each flag is
	// written as a single byte at (row, 0); writing wider integers or indexing
	// along the columns would run past the end of the buffer.
	Mat var_type( trainingData.cols + 1, 1, CV_8U );
	for ( int i = 0; i < var_type.rows; i++ ) {
		var_type.at<unsigned char>( i, 0 ) = CV_VAR_NUMERICAL;
	}
	dtree.train( trainingData, CV_ROW_SAMPLE, trainingClasses, Mat(), Mat(),
			var_type, Mat(), CvDTreeParams() );
	Mat predicted( testClasses.rows, 1, CV_32F );
	for ( int i = 0; i < testData.rows; i++ ) {
		const Mat sample = testData.row( i );
		CvDTreeNode* prediction = dtree.predict( sample );
		predicted.at<float>( i, 0 ) = prediction->value;
	}
	cout << " Accuracy_ { TREE } = " << evaluate( predicted, testClasses ) << endl;
	plot_binary( testData, predicted, " Predictions tree " );
}
开发者ID:theunknowner,项目名称:WebDerm,代码行数:21,代码来源:mlearn.cpp

示例5: find_decision_boundary_DT

static void find_decision_boundary_DT()
{
    img.copyTo( imgDst );

    Mat trainSamples, trainClasses;
    prepare_train_data( trainSamples, trainClasses );

    // learn classifier
    CvDTree  dtree;

    Mat var_types( 1, trainSamples.cols + 1, CV_8UC1, Scalar(CV_VAR_ORDERED) );
    var_types.at<uchar>( trainSamples.cols ) = CV_VAR_CATEGORICAL;

    CvDTreeParams params;
    params.max_depth = 8;
    params.min_sample_count = 2;
    params.use_surrogates = false;
    params.cv_folds = 0; // the number of cross-validation folds
    params.use_1se_rule = false;
    params.truncate_pruned_tree = false;

    dtree.train( trainSamples, CV_ROW_SAMPLE, trainClasses,
                 Mat(), Mat(), var_types, Mat(), params );

    Mat testSample(1, 2, CV_32FC1 );
    for( int y = 0; y < img.rows; y += testStep )
    {
        for( int x = 0; x < img.cols; x += testStep )
        {
            testSample.at<float>(0) = (float)x;
            testSample.at<float>(1) = (float)y;

            int response = (int)dtree.predict( testSample )->value;
            circle( imgDst, Point(x,y), 2, classColors[response], 1 );
        }
    }
}
开发者ID:406089450,项目名称:opencv,代码行数:37,代码来源:points_classifier.cpp

示例6: train

// Trains the tree for one test case.
//
// The data set name is taken from data_sets_names[test_case_idx]; the model
// parameters are read from the test file storage under
// validation/<name>/<data set name>/model_params.
//
// Returns CvTS::OK on success, CvTS::FAIL_INVALID_TEST_DATA when a parameter
// is missing from the config file, and CvTS::FAIL_INVALID_OUTPUT when
// training itself fails.
int CV_DTreeTest :: train( int test_case_idx )
{
    int MAX_DEPTH, MIN_SAMPLE_COUNT, MAX_CATEGORIES, CV_FOLDS;
    float REG_ACCURACY = 0;
    bool USE_SURROGATE, IS_PRUNED;

    const char* data_name = ((CvFileNode*)cvGetSeqElem( data_sets_names, test_case_idx ))->data.str.ptr;

    // read validation params
    CvFileStorage* fs = ts->get_file_storage();
    CvFileNode* fnode = cvGetFileNodeByName( fs, 0, "validation" ), *fnode1 = 0;
    fnode = cvGetFileNodeByName( fs, fnode, name );
    fnode = cvGetFileNodeByName( fs, fnode, data_name );
    fnode = cvGetFileNodeByName( fs, fnode, "model_params" );

    fnode1 = cvGetFileNodeByName( fs, fnode, "max_depth" );
    if ( !fnode1 )
    {
        ts->printf( CvTS::LOG, "MAX_DEPTH can not be read from config file" );
        return CvTS::FAIL_INVALID_TEST_DATA;
    }
    MAX_DEPTH = fnode1->data.i;

    fnode1 = cvGetFileNodeByName( fs, fnode, "min_sample_count" );
    if ( !fnode1 )
    {
        // name the parameter that is actually missing
        ts->printf( CvTS::LOG, "MIN_SAMPLE_COUNT can not be read from config file" );
        return CvTS::FAIL_INVALID_TEST_DATA;
    }
    MIN_SAMPLE_COUNT = fnode1->data.i;

    fnode1 = cvGetFileNodeByName( fs, fnode, "use_surrogate" );
    if ( !fnode1 )
    {
        ts->printf( CvTS::LOG, "USE_SURROGATE can not be read from config file" );
        return CvTS::FAIL_INVALID_TEST_DATA;
    }
    USE_SURROGATE = ( fnode1->data.i!= 0);

    fnode1 = cvGetFileNodeByName( fs, fnode, "max_categories" );
    if ( !fnode1 )
    {
        ts->printf( CvTS::LOG, "MAX_CATEGORIES can not be read from config file" );
        return CvTS::FAIL_INVALID_TEST_DATA;
    }
    MAX_CATEGORIES = fnode1->data.i;

    fnode1 = cvGetFileNodeByName( fs, fnode, "cv_folds" );
    if ( !fnode1 )
    {
        ts->printf( CvTS::LOG, "CV_FOLDS can not be read from config file" );
        return CvTS::FAIL_INVALID_TEST_DATA;
    }
    CV_FOLDS = fnode1->data.i;

    fnode1 = cvGetFileNodeByName( fs, fnode, "is_pruned" );
    if ( !fnode1 )
    {
        ts->printf( CvTS::LOG, "IS_PRUNED can not be read from config file" );
        return CvTS::FAIL_INVALID_TEST_DATA;
    }
    IS_PRUNED = (fnode1->data.i != 0);

    // The 1SE rule is always off and no priors are passed; everything else
    // comes from the config file.
    if ( !tree->train( &data,
       CvDTreeParams(MAX_DEPTH, MIN_SAMPLE_COUNT, REG_ACCURACY, USE_SURROGATE,
       MAX_CATEGORIES, CV_FOLDS, false, IS_PRUNED, 0 )) )
    {
        ts->printf( CvTS::LOG, "in test case %d model training  was failed", test_case_idx );
        return CvTS::FAIL_INVALID_OUTPUT;
    }
    return CvTS::OK;
}
开发者ID:glo,项目名称:ee384b,代码行数:67,代码来源:adtree.cpp

示例7: main

int main(int argc, char **argv)
{

	float priors[] = { 1.0, 10.0 };	// Edible vs poisonos weights

	CvMat *var_type;
	CvMat *data;				// jmh add
	data = cvCreateMat(20, 30, CV_8U);	// jmh add

	var_type = cvCreateMat(data->cols + 1, 1, CV_8U);
	cvSet(var_type, cvScalarAll(CV_VAR_CATEGORICAL));	// all these vars 
	// are categorical
	CvDTree *dtree;
	dtree = new CvDTree;
	dtree->train(data, CV_ROW_SAMPLE, responses, 0, 0, var_type, missing, CvDTreeParams(8,	// max depth
																						10,	// min sample count
																						0,	// regression accuracy: N/A here
																						true,	// compute surrogate split, 
																						//   as we have missing data
																						15,	// max number of categories 
																						//   (use sub-optimal algorithm for
																						//   larger numbers)
																						10,	// cross-validations 
																						true,	// use 1SE rule => smaller tree
																						true,	// throw away the pruned tree branches
																						priors	// the array of priors, the bigger 
																						//   p_weight, the more attention
																						//   to the poisonous mushrooms
				 )
		);

	dtree->save("tree.xml", "MyTree");
	dtree->clear();
	dtree->load("tree.xml", "MyTree");

#define MAX_CLUSTERS 5
	CvScalar color_tab[MAX_CLUSTERS];
	IplImage *img = cvCreateImage(cvSize(500, 500), 8, 3);
	CvRNG rng = cvRNG(0xffffffff);

	color_tab[0] = CV_RGB(255, 0, 0);
	color_tab[1] = CV_RGB(0, 255, 0);
	color_tab[2] = CV_RGB(100, 100, 255);
	color_tab[3] = CV_RGB(255, 0, 255);
	color_tab[4] = CV_RGB(255, 255, 0);

	cvNamedWindow("clusters", 1);

	for (;;) {
		int k, cluster_count = cvRandInt(&rng) % MAX_CLUSTERS + 1;
		int i, sample_count = cvRandInt(&rng) % 1000 + 1;
		CvMat *points = cvCreateMat(sample_count, 1, CV_32FC2);
		CvMat *clusters = cvCreateMat(sample_count, 1, CV_32SC1);

		/* generate random sample from multivariate 
		   Gaussian distribution */
		for (k = 0; k < cluster_count; k++) {
			CvPoint center;
			CvMat point_chunk;
			center.x = cvRandInt(&rng) % img->width;
			center.y = cvRandInt(&rng) % img->height;
			cvGetRows(points, &point_chunk,
					  k * sample_count / cluster_count,
					  k == cluster_count - 1 ? sample_count :
					  (k + 1) * sample_count / cluster_count);
			cvRandArr(&rng, &point_chunk, CV_RAND_NORMAL,
					  cvScalar(center.x, center.y, 0, 0),
					  cvScalar(img->width / 6, img->height / 6, 0, 0));
		}

		/* shuffle samples */
		for (i = 0; i < sample_count / 2; i++) {
			CvPoint2D32f *pt1 = (CvPoint2D32f *) points->data.fl +
				cvRandInt(&rng) % sample_count;
			CvPoint2D32f *pt2 = (CvPoint2D32f *) points->data.fl +
				cvRandInt(&rng) % sample_count;
			CvPoint2D32f temp;
			CV_SWAP(*pt1, *pt2, temp);
		}

		cvKMeans2(points, cluster_count, clusters,
				  cvTermCriteria(CV_TERMCRIT_EPS + CV_TERMCRIT_ITER, 10, 1.0));
		cvZero(img);
		for (i = 0; i < sample_count; i++) {
			CvPoint2D32f pt = ((CvPoint2D32f *) points->data.fl)[i];
			int cluster_idx = clusters->data.i[i];
			cvCircle(img, cvPointFrom32f(pt), 2,
					 color_tab[cluster_idx], CV_FILLED);
		}

		cvReleaseMat(&points);
		cvReleaseMat(&clusters);

		cvShowImage("clusters", img);

		int key = cvWaitKey(0);
		if (key == 27)			// 'ESC'
			break;
	}
}
开发者ID:emcute0319,项目名称:LearningOpenCVCode,代码行数:100,代码来源:ch13_dtree.cpp

示例8: main

int main( int argc, char** argv )
{
	Mat img;
   char file[255];
	
	//total no of training samples
	int total_train_samples = 0;
	for(int cl=0; cl<nr_classes; cl++)
	{
		total_train_samples = total_train_samples + train_samples[cl];
	}
	
	// Training Data
	Mat training_data = Mat(total_train_samples,feature_size,CV_32FC1);
	Mat training_label = Mat(total_train_samples,1,CV_32FC1);
	// training data .csv file
	ofstream trainingDataCSV;
	trainingDataCSV.open("./training_data.csv");	
		
	int index = 0;	
	for(int cl=0; cl<nr_classes; cl++)
	{
      for(int ll=0; ll<train_samples[cl]; ll++)
      {
      	//assign sample label
			training_label.at<float>(index+ll,0) = class_labels[cl]; 	
			//image feature extraction
 			sprintf(file, "%s/%d/%d.png", pathToImages, class_labels[cl], ll);
         img = imread(file, 1);
         if (!img.data)
         {
             cout << "File " << file << " not found\n";
             exit(1);
         }
         imshow("sample",img);
         waitKey(1);
         //calculate feature vector
			vector<float> feature = ColorHistFeature(img);
			for(int ft=0; ft<feature.size(); ft++)
			{
				training_data.at<float>(index+ll,ft) = feature[ft];
				trainingDataCSV<<feature[ft]<<",";
			}
			trainingDataCSV<<class_labels[cl]<<"\n";
		}
		index = index + train_samples[cl];
	}	
	
	trainingDataCSV.close();

	/// Decision Tree
	// Training
	float *priors = NULL;
	CvDTreeParams DTParams = CvDTreeParams(25, // max depth
		                                    5, // min sample count
		                                    0, // regression accuracy: N/A here
		                                    false, // compute surrogate split, no missing data
		                                    15, // max number of categories (use sub-optimal algorithm for larger numbers)
		                                    15, // the number of cross-validation folds
		                                    false, // use 1SE rule => smaller tree
		                                    false, // throw away the pruned tree branches
		                                    priors // the array of priors
		                                   );
	CvDTree DTree;
	DTree.train(training_data,CV_ROW_SAMPLE,training_label,Mat(),Mat(),Mat(),Mat(),DTParams);
			
	// save model
	DTree.save("training.model");		
	
	// load model
	CvDTree DT;
	DT.load("training.model");	
	
	// test on sample image
	string filename = string(pathToImages)+"/test.png";
	Mat test_img = imread(filename.c_str());
	vector<float> test_feature = ColorHistFeature(test_img);
	CvDTreeNode* result_node = DT.predict(Mat(test_feature),Mat(),false);
	double predictedClass = result_node->value;
	cout<<"predictedClass "<<predictedClass<<"\n";

/*	
	//CvMLData for calculating error
	CvMLData* MLData;
	MLData = new CvMLData();
	MLData->read_csv("training_data.csv");
	MLData->set_response_idx(feature_size);
//	MLData->change_var_type(feature_size,CV_VAR_CATEGORICAL);
	
	// calculate training error
	float error = DT.calc_error(MLData,CV_TRAIN_ERROR,0);
	cout<<"training error "<<error<<"\n";
*/
	return 0;
}
开发者ID:Mohit-Chachada,项目名称:Vision_Test,代码行数:95,代码来源:training_dtree_sample.cpp

示例9: train

bool
CvGBTrees::train( const CvMat* _train_data, int _tflag,
              const CvMat* _responses, const CvMat* _var_idx,
              const CvMat* _sample_idx, const CvMat* _var_type,
              const CvMat* _missing_mask,
              CvGBTreesParams _params, bool /*_update*/ ) //update is not supported
{
    CvMemStorage* storage = 0;

    params = _params;
    bool is_regression = problem_type();

    clear();
    /*
      n - count of samples
      m - count of variables
    */
    int n = _train_data->rows;
    int m = _train_data->cols;
    if (_tflag != CV_ROW_SAMPLE)
    {
        int tmp;
        CV_SWAP(n,m,tmp);
    }

    CvMat* new_responses = cvCreateMat( n, 1, CV_32F);
    cvZero(new_responses);

    data = new CvDTreeTrainData( _train_data, _tflag, new_responses, _var_idx,
        _sample_idx, _var_type, _missing_mask, _params, true, true );
    if (_missing_mask)
    {
        missing = cvCreateMat(_missing_mask->rows, _missing_mask->cols,
                              _missing_mask->type);
        cvCopy( _missing_mask, missing);
    }

    orig_response = cvCreateMat( 1, n, CV_32F );
	int step = (_responses->cols > _responses->rows) ? 1 : _responses->step / CV_ELEM_SIZE(_responses->type);
    switch (CV_MAT_TYPE(_responses->type))
    {
        case CV_32FC1:
		{
			for (int i=0; i<n; ++i)
                orig_response->data.fl[i] = _responses->data.fl[i*step];
		}; break;
        case CV_32SC1:
        {
            for (int i=0; i<n; ++i)
                orig_response->data.fl[i] = (float) _responses->data.i[i*step];
        }; break;
        default:
            CV_Error(CV_StsUnmatchedFormats, "Response should be a 32fC1 or 32sC1 vector.");
    }

    if (!is_regression)
    {
        class_count = 0;
        unsigned char * mask = new unsigned char[n];
        memset(mask, 0, n);
        // compute the count of different output classes
        for (int i=0; i<n; ++i)
            if (!mask[i])
            {
                class_count++;
                for (int j=i; j<n; ++j)
                    if (int(orig_response->data.fl[j]) == int(orig_response->data.fl[i]))
                        mask[j] = 1;
            }
        delete[] mask;
    
        class_labels = cvCreateMat(1, class_count, CV_32S);
        class_labels->data.i[0] = int(orig_response->data.fl[0]);
        int j = 1;
        for (int i=1; i<n; ++i)
        {
            int k = 0;
            while ((int(orig_response->data.fl[i]) - class_labels->data.i[k]) && (k<j))
                k++;
            if (k == j)
            {
                class_labels->data.i[k] = int(orig_response->data.fl[i]);
                j++;
            }
        }
    }

    // inside gbt learning proccess only regression decision trees are built
    data->is_classifier = false;

    // preproccessing sample indices
    if (_sample_idx)
    {
        int sample_idx_len = get_len(_sample_idx);
        
        switch (CV_MAT_TYPE(_sample_idx->type))
        {
            case CV_32SC1:
            {
                sample_idx = cvCreateMat( 1, sample_idx_len, CV_32S );
//.........这里部分代码省略.........
开发者ID:Ashwini7,项目名称:smart-python-programs,代码行数:101,代码来源:gbt.cpp

示例10: main

int main()
{
    const int train_sample_count = 300;
    bool is_regression = false;

    const char* filename = "data/waveform.data";
    int response_idx = 21;

    CvMLData data;

    CvTrainTestSplit spl( train_sample_count );
    
    if(data.read_csv(filename) != 0)
    {
        printf("couldn't read %s\n", filename);
        exit(0);
    }

    data.set_response_idx(response_idx);
    data.change_var_type(response_idx, CV_VAR_CATEGORICAL);
    data.set_train_test_split( &spl );

    const CvMat* values = data.get_values();
    const CvMat* response = data.get_responses();
    const CvMat* missing = data.get_missing();
    const CvMat* var_types = data.get_var_types();
    const CvMat* train_sidx = data.get_train_sample_idx();
    const CvMat* var_idx = data.get_var_idx();
    CvMat*response_map;
    CvMat*ordered_response = cv_preprocess_categories(response, var_idx, response->rows, &response_map, NULL);
    int num_classes = response_map->cols;
    
    CvDTree dtree;
    printf("======DTREE=====\n");
    CvDTreeParams cvd_params( 10, 1, 0, false, 16, 0, false, false, 0);
    dtree.train( &data, cvd_params);
    print_result( dtree.calc_error( &data, CV_TRAIN_ERROR), dtree.calc_error( &data, CV_TEST_ERROR ), dtree.get_var_importance() );

#if 0
    /* boosted trees are only implemented for two classes */
    printf("======BOOST=====\n");
    CvBoost boost;
    boost.train( &data, CvBoostParams(CvBoost::DISCRETE, 100, 0.95, 2, false, 0));
    print_result( boost.calc_error( &data, CV_TRAIN_ERROR ), boost.calc_error( &data, CV_TEST_ERROR), 0 );
#endif

    printf("======RTREES=====\n");
    CvRTrees rtrees;
    rtrees.train( &data, CvRTParams( 10, 2, 0, false, 16, 0, true, 0, 100, 0, CV_TERMCRIT_ITER ));
    print_result( rtrees.calc_error( &data, CV_TRAIN_ERROR), rtrees.calc_error( &data, CV_TEST_ERROR ), rtrees.get_var_importance() );

    printf("======ERTREES=====\n");
    CvERTrees ertrees;
    ertrees.train( &data, CvRTParams( 10, 2, 0, false, 16, 0, true, 0, 100, 0, CV_TERMCRIT_ITER ));
    print_result( ertrees.calc_error( &data, CV_TRAIN_ERROR), ertrees.calc_error( &data, CV_TEST_ERROR ), ertrees.get_var_importance() );

    printf("======GBTREES=====\n");
    CvGBTrees gbtrees;
    CvGBTreesParams gbparams;
    gbparams.loss_function_type = CvGBTrees::DEVIANCE_LOSS; // classification, not regression
    gbtrees.train( &data, gbparams);
    
    //gbt_print_error(&gbtrees, values, response, response_idx, train_sidx);
    print_result( gbtrees.calc_error( &data, CV_TRAIN_ERROR), gbtrees.calc_error( &data, CV_TEST_ERROR ), 0);

    printf("======KNEAREST=====\n");
    CvKNearest knearest;
    //bool CvKNearest::train( const Mat& _train_data, const Mat& _responses,
    //                const Mat& _sample_idx, bool _is_regression,
    //                int _max_k, bool _update_base )
    bool is_classifier = var_types->data.ptr[var_types->cols-1] == CV_VAR_CATEGORICAL;
    assert(is_classifier);
    int max_k = 10;
    knearest.train(values, response, train_sidx, is_regression, max_k, false);

    CvMat* new_response = cvCreateMat(response->rows, 1, values->type);
    //print_types();

    //const CvMat* train_sidx = data.get_train_sample_idx();
    knearest.find_nearest(values, max_k, new_response, 0, 0, 0);

    print_result(knearest_calc_error(values, response, new_response, train_sidx, is_regression, CV_TRAIN_ERROR),
                 knearest_calc_error(values, response, new_response, train_sidx, is_regression, CV_TEST_ERROR), 0);

    printf("======== RBF SVM =======\n");
    //printf("indexes: %d / %d, responses: %d\n", train_sidx->cols, var_idx->cols, values->rows);
    CvMySVM svm1;
    CvSVMParams params1 = CvSVMParams(CvSVM::C_SVC, CvSVM::RBF,
                                     /*degree*/0, /*gamma*/1, /*coef0*/0, /*C*/1,
                                     /*nu*/0, /*p*/0, /*class_weights*/0,
                                     cvTermCriteria(CV_TERMCRIT_ITER+CV_TERMCRIT_EPS, 1000, FLT_EPSILON));
    //svm1.train(values, response, train_sidx, var_idx, params1);
    svm1.train_auto(values, response, var_idx, train_sidx, params1);
    svm_print_error(&svm1, values, response, response_idx, train_sidx);

    printf("======== Linear SVM =======\n");
    CvMySVM svm2;
    CvSVMParams params2 = CvSVMParams(CvSVM::C_SVC, CvSVM::LINEAR,
                                     /*degree*/0, /*gamma*/1, /*coef0*/0, /*C*/1,
                                     /*nu*/0, /*p*/0, /*class_weights*/0,
//.........这里部分代码省略.........
开发者ID:JackieXie168,项目名称:mrscake,代码行数:101,代码来源:test_cv.cpp


注:本文中的CvDTree::train方法示例由纯净天空整理自GitHub/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。