本文整理汇总了 Java 中 org.encog.neural.networks.training.Train 类的典型用法代码示例。如果您正苦于以下问题:Java Train 类的具体用法?Java Train 怎么用?Java Train 使用的例子?那么,这里精选的类代码示例或许可以为您提供帮助。
Train 类属于 org.encog.neural.networks.training 包,在下文中一共展示了 Train 类的 4 个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的 Java 代码示例。
示例1: trainAndStore
import org.encog.neural.networks.training.Train; //导入依赖的package包/类
@Test
public void trainAndStore() {
BasicMLDataSet dataSet = getData();
// Create network
BasicNetwork network = getNetwork();
// Train
System.out.println("Training network...");
Train train = new ResilientPropagation(network, dataSet);
for (int i = 0; i < TRAIN_ITERATIONS; i++) {
train.iteration();
}
System.out.println("Training finished, error: " + train.getError());
// Save to file
System.out.println("Saving to file...");
saveToFile(network);
System.out.println("Done");
}
示例2: getTrain
import org.encog.neural.networks.training.Train; //导入依赖的package包/类
private Train getTrain (NeuralDataSet trainingSet, BasicNetwork network) {
//final Train train =
//new ManhattanPropagation(network, trainingSet,
//0.001);
// Train the neural network, we use resilient propagation
final ResilientPropagation train = new ResilientPropagation(network, trainingSet);
train.setThreadCount(0);
// Reset if improve is less than 1% over 5 cycles
train.addStrategy(new RequiredImprovementStrategy(DEFAULT_SELECTION_LIMIT));
return train;
}
示例3: test
import org.encog.neural.networks.training.Train; //导入依赖的package包/类
public static void test(double[][] inputValues, double[][] outputValues)
{
NeuralDataSet trainingSet = new BasicNeuralDataSet(inputValues, outputValues);
BasicNetwork network = new BasicNetwork();
network.addLayer(new BasicLayer(new ActivationSigmoid(), false, 4));
network.addLayer(new BasicLayer(new ActivationSigmoid(), false, 1000));
network.addLayer(new BasicLayer(new ActivationLinear(), false, 1));
network.getStructure().finalizeStructure();
network.reset();
final Train train = new ResilientPropagation(network, trainingSet);
int epoch = 1;
do
{
train.iteration();
System.out.println("Epoch #" + epoch + " Error:" + train.getError());
epoch++;
}
while(epoch < 10000);
System.out.println("Neural Network Results:");
for(MLDataPair pair : trainingSet)
{
final MLData output = network.compute(pair.getInput());
System.out.println(pair.getInput().getData(0) + "," + pair.getInput().getData(1) + ", actual="
+ output.getData(0) + ",ideal=" + pair.getIdeal().getData(0));
}
}
示例4: trainOnModelStructureChange
import org.encog.neural.networks.training.Train; //导入依赖的package包/类
private void trainOnModelStructureChange(double[][] x, double[][] y, Class<?> activation, BasicNetwork network, Train train) throws InstantiationException, IllegalAccessException {
NeuralDataSet trainingSet = new BasicNeuralDataSet(x, y);
if (network == null && train == null) {
network = buildNetwork(trainingSet, DEFAULT_CHANGE_HIDDEN_NO, x[0].length, activation);
train = getTrain(trainingSet, network);
}
int epoch = 1;
long time = System.currentTimeMillis();
double error[];
double[] previousNetworkBestError = null;
do {
train.iteration();
error = getError(trainingSet, network);
logger.debug("Epoch #" + epoch + " Error:" + print(error));
epoch++;
if (isBetter(error,bestQuality) &&
(bestEverError == null || 0.2 <= betterPercentage(error,bestEverError))) {
bestEverNetwork = (BasicNetwork)network.clone();
bestEverError = error;
logger.debug("Epoch #" + epoch + " Best Ever Error:" + print(error));
previousNetworkBestError = error;
network = buildNetwork(trainingSet, currentNoOfHiddenNode + DEFAULT_CHANGE_HIDDEN_NO, x[0].length, activation);
train = getTrain(trainingSet, network);
time = System.currentTimeMillis();
continue;
}
// Change on hidden layer, retrain a new ANN
if(System.currentTimeMillis() - time > DEFAULT_STRUCTURE_SELECTION_TIME_LIMIT ) {
if (previousNetworkBestError != null && previousNetworkBestError.equals(bestEverError)) {
currentNoOfHiddenNode -= DEFAULT_CHANGE_HIDDEN_NO;
break;
}
previousNetworkBestError = error;
network = buildNetwork(trainingSet, currentNoOfHiddenNode + DEFAULT_CHANGE_HIDDEN_NO, x[0].length, activation);
train = getTrain(trainingSet, network);
time = System.currentTimeMillis();
}
} while (true);
logger.debug("Training Time: "+ (System.currentTimeMillis() - time));
logger.debug("Current number of hidden node: "+ currentNoOfHiddenNode);
calculateError(trainingSet);
}