当前位置: 首页>>代码示例>>C++>>正文


C++ LinearRegression::train方法代码示例

本文整理汇总了C++中LinearRegression::train方法的典型用法代码示例。如果您正苦于以下问题:C++ LinearRegression::train方法的具体用法?C++ LinearRegression::train怎么用?C++ LinearRegression::train使用的例子?那么, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在LinearRegression的用法示例。


在下文中一共展示了LinearRegression::train方法的2个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的C++代码示例。

示例1: main

int main() {
    LinearRegression lr;
    vector<DataGroup> train_set;

    DataGroup train1, train2, train3, train4, train5;
    //y = 3 * x1 - 5 * x2 + 3;
    DataGroup test1;
    //1  1 1
    train1.in.push_back(1); train1.in.push_back(1); train1.out.push_back(1);
    //1 -1 1
    train2.in.push_back(2); train2.in.push_back(1); train2.out.push_back(4);
    //-1 1 1
    train3.in.push_back(-1); train3.in.push_back(3); train3.out.push_back(-15);
    //0  0 1
    train4.in.push_back(9); train4.in.push_back(1); train4.out.push_back(25);
    //-1 -1 0
    train5.in.push_back(0); train5.in.push_back(1); train5.out.push_back(-2);
    train_set.push_back(train1);
    train_set.push_back(train2);
    train_set.push_back(train3);
    train_set.push_back(train4);
    train_set.push_back(train5);
    lr.train(train_set, 1e-5);

    test1.in.push_back(1); test1.in.push_back(1); test1.out.push_back(0);
    lr.predict(test1);
    cout << "result:" << test1.out[0] << endl;
    return 0;
}
开发者ID:Markz2z,项目名称:Linear-Regression,代码行数:29,代码来源:Test.cpp

示例2: test_linearity

void test_linearity() {
	Radix r;
	SimpleLSystemWithBranching sls;

	int b_index = 3;
	int grid_size = 10;
	//int grid_size = 5;

	int N = 10000;
	cv::Mat_<double> X(N, (grid_size - b_index - 2) * 2 + b_index);
	//cv::Mat_<double> X(N, (grid_size - 3) * 2 + 1);
	cv::Mat_<double> Y(N, grid_size * grid_size);
	
	for (int i = 0; i < N; ++i) {
		cv::Mat_<double> param(1, X.cols);
		for (int c = 0; c < X.cols; ++c) {
			param(0, c) = rand() % 3 - 1;
		}
		param.copyTo(X.row(i));
	}

	cv::Mat_<double> X2(X.rows, X.cols);

	int count = 0;
	for (int i = 0; i < X.rows; ++i) {
		try {
			//cv::Mat_<double> density = sls.computeDensity(grid_size, X.row(i), true, true);
			cv::Mat_<double> density = sls.computeDensity(grid_size, X.row(i), true, false);
			density.copyTo(Y.row(count));
			X.row(i).copyTo(X2.row(count));
			count++;
		} catch (char* ex) {
			//cout << "conflict" << endl;
		}
			
	}

	ml::saveDataset("dataX.txt", X2(cv::Rect(0, 0, X2.cols, count)));
	ml::saveDataset("dataY.txt", Y(cv::Rect(0, 0, Y.cols, count)));


	//cv::Mat_<double> X, Y;
	ml::loadDataset("dataX.txt", X);
	ml::loadDataset("dataY.txt", Y);

	cv::Mat_<double> trainX, trainY;
	cv::Mat_<double> testX, testY;
	ml::splitDataset(X, 0.8, trainX, testX);
	ml::splitDataset(Y, 0.8, trainY, testY);

	// Forward
	{
		LinearRegression lr;
		lr.train(trainX, trainY);
		cv::Mat_<double> Y_hat = lr.predict(testX);

		cv::Mat_<double> Y_avg;
		cv::reduce(trainY, Y_avg, 0, CV_REDUCE_AVG);
		Y_avg = cv::repeat(Y_avg, testY.rows, 1);

		cout << "-----------------------" << endl;
		cout << "Forward:" << endl;
		cout << "RMSE: " << ml::rmse(testY, Y_hat, true) << endl;
		cout << "Baselime: " << ml::rmse(testY, Y_avg, true) << endl;
	}

	// Inverse
	{
		LinearRegression lr;
		lr.train(trainY, trainX);
		
		cv::Mat_<double> X_hat = lr.predict(testY);

		// Xの各値を-1,0,1にdiscretizeする
		{
			for (int r = 0; r < X_hat.rows; ++r) {
				for (int c = 0; c < X_hat.cols; ++c) {
					if (X_hat(r, c) < -0.5) {
						X_hat(r, c) = -1;
					} else if (X_hat(r,c ) > 0.5) {
						X_hat(r, c) = 1;
					} else {
						X_hat(r, c) = 0;
					}
				}
			}
		}
		
		for (int i = 0; i < testX.cols; ++i) {
			cout << ml::rmse(testX.col(i), X_hat.col(i), true) << endl;
		}


		cv::Mat_<double> Y_hat(testY.rows, testY.cols);
		for (int i = 0; i < testX.rows; ++i) {
			cv::Mat_<double> density_hat = sls.computeDensity(grid_size, X_hat.row(i), true, false);
			density_hat.copyTo(Y_hat.row(i));
		}

		cv::Mat X_avg;
//.........这里部分代码省略.........
开发者ID:gnishida,项目名称:SimpleLSystem,代码行数:101,代码来源:main.cpp


注:本文中的LinearRegression::train方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。