本文整理汇总了C++中Parameters::GetEta方法的典型用法代码示例。如果您正苦于以下问题:C++ Parameters::GetEta方法的具体用法?C++ Parameters::GetEta怎么用?C++ Parameters::GetEta使用的例子?那么,这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在类Parameters的用法示例。
在下文中一共展示了Parameters::GetEta方法的3个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的C++代码示例。
示例1: training
// Trains a neural net with the training set provided. The weights for the
// trained net will be stored in the file specified by the Parameter object.
void training(Parameters ¶ms, vector<YearData> &trainingSet)
{
string weights_filename = params.GetWeightsFileName();
// Create a neural net
Net ann(params.GetNodesPerLayer(), params.GetEta(), params.GetAlpha());
// Set the weights according to the weights file
ann.read_in_weights(weights_filename, params.GetNodesPerLayer());
// Get number of training epochs
int num_epochs = params.GetEpochs();
// Get error threshold for stopping
double error_thresh = params.GetErrorThresh();
// Only present floats to the thousandths place
cout << setprecision(3);
// Epoch loop
for (int i = 0; i < num_epochs; i++)
{
// Shuffle order of records
random_shuffle(trainingSet.begin(), trainingSet.end());
// Perform feed forward and back prop once for each record
for (unsigned int j = 0; j < trainingSet.size(); j++)
{
ann.feed_forward(trainingSet[j].inputs);
ann.back_prop(trainingSet[j].class_outputs);
}
// Stop if error threshold is reached
if (ann.get_avg_error() < error_thresh)
{
cout << "Epoch" << setw(7) << i << ": RMS error = "
<< ann.get_avg_error() << endl;
cout << "Error Threshold Met" << endl;
break;
}
// Print out the average RMS error for every ten epochs
if (i % 10 == 0)
{
cout << "Epoch" << setw(7) << i << ": RMS error = "
<< ann.get_avg_error() << endl;
ann.reset_avg_error();
}
}
// Save net's weights to the weight file
ann.print_weights(weights_filename);
return;
}
示例2: testing
// Tests a neural net with the testing set provided.
void testing(Parameters &params, vector<YearData> &testSet)
{
int numberCorrect = 0; // count of number of correct predictions
bool low, mid, high; // bools to determine what the actual was
unsigned int i; // count of total number of test sets
double percentCorrect; // the percent of correct predictions
vector<double> outputsFromNet; // outputs from the net
vector<double> expected_outputs; // expected outputs
double error = 0.0; // current error of the output node
double avg_error = 0.0; // average error of the net
// create the net
Net ann(params.GetNodesPerLayer(), params.GetEta(), params.GetAlpha());
// read in weight file
ann.read_in_weights(params.GetWeightsFileName(), params.GetNodesPerLayer());
// output the type of format
cout << "Sample, Actual, Predicted" << endl;
// loop through each test set
for (i = 0; i < testSet.size(); i++)
{
// set the low, mid, and high bools to false
low = false;
mid = false;
high = false;
// perform the feed forward on the current test set
ann.feed_forward(testSet[i].inputs);
// get the output of the net for the current test set
ann.get_output(outputsFromNet);
// output the sample
cout << i << ", ";
// if the actual fire severity was mid
if (testSet[i].actualburnedacres > params.GetFireSeverityCutoffs().at(0) &&
testSet[i].actualburnedacres < params.GetFireSeverityCutoffs().at(1))
{
// output the mid
cout << "010, ";
mid = true;
}
// if the actual fire severity was high
else if (testSet[i].actualburnedacres > params.GetFireSeverityCutoffs().at(1))
{
// output the high
cout << "001, ";
high = true;
}
// if the actual fire severity was low
else
{
// output the low
cout << "100, ";
low = true;
}
// if low fire severity is predicted
if (outputsFromNet.at(0) > outputsFromNet.at(1) &&
outputsFromNet.at(0) > outputsFromNet.at(2))
{
// output low
cout << "100";
// if it predicted correctly increment the count
if (low)
numberCorrect++;
// if its wrong star that line
else
cout << ", *";
}
// if mid fire severity is predicted
else if (outputsFromNet.at(1) > outputsFromNet.at(0) &&
outputsFromNet.at(1) > outputsFromNet.at(2))
{
// output mid
cout << "010";
// if it predicted correctly increment the count
if (mid)
numberCorrect++;
// if its wrong star that line
else
cout << ", *";
}
// if high fire severity is predicted
else
{
// output high
cout << "001";
// if it predicted correctly increment the count
if (high)
numberCorrect++;
// if its wrong star that line
else
//.........这里部分代码省略.........
示例3: crossValidate
void crossValidate(Parameters &params, vector<YearData> &cvSet)
{
//split cv set into training set and testing set
vector<YearData> trainingSet;
YearData testSet;
int numberCorrect = 0;
bool low, mid, high;
double percentCorrect;
vector<double> outputsFromNet;
// Get number of training epochs
int num_epochs = params.GetEpochs();
// Get error threshold for stopping
double error_thresh = params.GetErrorThresh();
double avg_error;
string break_error_thresh;
cout << "Year, Burned, Actual, Predicted (training error)" << endl;
for (unsigned int q = 0; q < cvSet.size(); q++)
{
trainingSet = vector<YearData>(cvSet);
trainingSet.erase(trainingSet.begin() + q);
testSet = YearData(cvSet.at(q));
break_error_thresh.clear();
// Create a neural net
Net ann(params.GetNodesPerLayer(), params.GetEta(), params.GetAlpha());
// Only present floats to the thousandths place
cout << setprecision(3);
// Epoch loop
for (int i = 0; i < num_epochs; i++)
{
// Shuffle order of records
random_shuffle(trainingSet.begin(), trainingSet.end());
// Perform feed forward and back prop once for each record
for (unsigned int j = 0; j < trainingSet.size(); j++)
{
ann.feed_forward(trainingSet[j].inputs);
ann.back_prop(trainingSet[j].class_outputs);
}
// Stop if error threshold is reached
if (ann.get_avg_error() < error_thresh)
{
break_error_thresh = " Reached error threshold at epoch " + i;
break;
}
}
avg_error = ann.get_avg_error();
// set the low, mid, and high bools to false
low = false;
mid = false;
high = false;
// perform the feed forward on the current test set
ann.feed_forward(testSet.inputs);
// get the output of the net for the current test set
ann.get_output(outputsFromNet);
// output the sample
cout << testSet.year << ", " << setw(6) << right << testSet.actualburnedacres << ", ";
// if the actual fire severity was mid
if (testSet.actualburnedacres > params.GetFireSeverityCutoffs().at(0) &&
testSet.actualburnedacres < params.GetFireSeverityCutoffs().at(1))
{
// output the mid
cout << setw(8) << right << "010, ";
mid = true;
}
// if the actual fire severity was high
else if (testSet.actualburnedacres > params.GetFireSeverityCutoffs().at(1))
{
// output the high
cout << setw(8) << right << "001, ";
high = true;
}
// if the actual fire severity was low
else
{
// output the low
cout << setw(8) << right << "100, ";
low = true;
}
// if low fire severity is predicted
if (outputsFromNet.at(0) > outputsFromNet.at(1) &&
outputsFromNet.at(0) > outputsFromNet.at(2))
{
// if it predicted correctly increment the count
if (low)
{
cout << setw(9) << right << "100 ";
//.........这里部分代码省略.........