本文整理汇总了 Java 中 org.nd4j.linalg.dataset.SplitTestAndTrain.getTrain 方法的典型用法代码示例。如果您正苦于以下问题：Java SplitTestAndTrain.getTrain 方法的具体用法？Java SplitTestAndTrain.getTrain 怎么用？Java SplitTestAndTrain.getTrain 使用的例子？那么，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在类 org.nd4j.linalg.dataset.SplitTestAndTrain 的用法示例。
在下文中一共展示了 SplitTestAndTrain.getTrain 方法的 7 个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的 Java 代码示例。
示例1: createDataSource
import org.nd4j.linalg.dataset.SplitTestAndTrain; //导入方法依赖的package包/类
/**
 * Reads one batch of CSV records from {@code dataFile}, shuffles it, splits it 80/20 and
 * stores the standardized halves in {@code trainingData} and {@code testData}.
 *
 * <p>{@code dataFile}, {@code batchSize}, {@code trainingData} and {@code testData} are declared
 * outside this method (presumably fields of the enclosing class - confirm against the full file).
 *
 * @throws IOException          if the record reader cannot be initialized from the input
 * @throws InterruptedException if initializing the record reader is interrupted
 */
private void createDataSource() throws IOException, InterruptedException {
    // Each iris.txt row holds 5 values: 4 input features followed by an integer class index,
    // so the label sits at index 4 and takes one of the 3 values 0, 1 or 2.
    final int labelIndex = 4;
    final int numClasses = 3;

    // Step 1: the CSV reader does the loading/parsing (skip no lines, comma separated).
    RecordReader recordReader = new CSVRecordReader(0, ",");
    recordReader.initialize(new InputStreamInputSplit(dataFile));

    // Step 2: the iterator turns the parsed records into a DataSet usable by a network.
    DataSet allData =
            new RecordReaderDataSetIterator(recordReader, batchSize, labelIndex, numClasses).next();
    allData.shuffle();

    // 80% of the examples go to training, the remainder to testing.
    SplitTestAndTrain split = allData.splitTestAndTrain(0.80);
    trainingData = split.getTrain();
    testData = split.getTest();

    // Standardize (mean 0, unit variance). The statistics are collected from the training
    // split only and then applied to both splits.
    DataNormalization standardizer = new NormalizerStandardize();
    standardizer.fit(trainingData);
    standardizer.transform(trainingData);
    standardizer.transform(testData);
}
示例2: createDataSource
import org.nd4j.linalg.dataset.SplitTestAndTrain; //导入方法依赖的package包/类
/**
 * Reads one batch of CSV records from {@code dataFile}, splits it 80/20 in its loaded order
 * (no shuffle is performed here) and stores the standardized halves in {@code trainingData}
 * and {@code testData}.
 *
 * <p>{@code dataFile}, {@code batchSize}, {@code trainingData} and {@code testData} are declared
 * outside this method (presumably fields of the enclosing class - confirm against the full file).
 *
 * @throws IOException          if the record reader cannot be initialized from the input
 * @throws InterruptedException if initializing the record reader is interrupted
 */
private void createDataSource() throws IOException, InterruptedException {
    // Column 11 is used as both the first and the last label column.
    final int labelIndex = 11;

    // Step 1: the CSV reader does the loading/parsing (skip no lines, comma separated).
    RecordReader recordReader = new CSVRecordReader(0, ",");
    recordReader.initialize(new InputStreamInputSplit(dataFile));

    // Step 2: the iterator turns the parsed records into a DataSet usable by a network.
    // NOTE(review): the trailing 'true' looks like the iterator's regression flag - confirm
    // against the DL4J version in use.
    DataSet allData =
            new RecordReaderDataSetIterator(recordReader, batchSize, labelIndex, labelIndex, true).next();

    // 80% of the examples go to training, the remainder to testing.
    SplitTestAndTrain split = allData.splitTestAndTrain(0.80);
    trainingData = split.getTrain();
    testData = split.getTest();

    // Standardize (mean 0, unit variance). The statistics are collected from the training
    // split only and then applied to both splits.
    DataNormalization standardizer = new NormalizerStandardize();
    standardizer.fit(trainingData);
    standardizer.transform(trainingData);
    standardizer.transform(testData);
}
示例3: split
import org.nd4j.linalg.dataset.SplitTestAndTrain; //导入方法依赖的package包/类
/**
 * Partitions {@code dataset} into a training part and a testing part and stores both on this
 * object. The split is driven by a {@link Random} created from {@code seed}.
 *
 * @param trainingDataSize value handed to {@code splitTestAndTrain} as the training-set size
 * @return this instance
 */
@Override
public IrisData split(int trainingDataSize)
{
    final Random rng = new Random(seed);
    final SplitTestAndTrain partition = dataset.splitTestAndTrain(trainingDataSize, rng);

    trainingData = partition.getTrain();
    testingData = partition.getTest();

    return this;
}
示例4: main
import org.nd4j.linalg.dataset.SplitTestAndTrain; //导入方法依赖的package包/类
/**
 * Trains a small feed-forward classifier on the iris CSV, evaluates it on a held-out split,
 * then writes the network to disk and restores it to compare parameters and configuration.
 *
 * @param args command-line arguments (unused)
 * @throws Exception if reading the CSV or saving/restoring the model fails
 */
public static void main(String[] args) throws Exception {
    //First: get the dataset using the record reader. CSVRecordReader handles loading/parsing
    int numLinesToSkip = 0;
    char delimiter = ',';

    //5 values in each row of the iris.txt CSV: 4 input features followed by an integer label (class) index.
    //Labels are the 5th value (index 4) in each row
    int labelIndex = 4;
    //3 classes (types of iris flowers) in the iris data set. Classes have integer values 0, 1 or 2
    int numClasses = 3;
    //Iris data set: 150 examples total. We are loading all of them into one DataSet
    //(not recommended for large data sets)
    int batchSize = 150;

    DataSet allData;
    //The record reader holds the opened file; close it as soon as the single batch has been read
    //instead of leaking it for the rest of the (long) training run.
    try (RecordReader recordReader = new CSVRecordReader(numLinesToSkip, delimiter)) {
        recordReader.initialize(new FileSplit(new File("src/main/resources/DL4J_Resources/iris.txt")));

        //Second: the RecordReaderDataSetIterator handles conversion to DataSet objects, ready for use in neural network
        DataSetIterator iterator = new RecordReaderDataSetIterator(recordReader, batchSize, labelIndex, numClasses);
        allData = iterator.next();
    }
    allData.shuffle();

    SplitTestAndTrain testAndTrain = allData.splitTestAndTrain(0.65); //Use 65% of data for training
    DataSet trainingData = testAndTrain.getTrain();
    DataSet testData = testAndTrain.getTest();

    //We need to normalize our data. We'll use NormalizeStandardize (which gives us mean 0, unit variance):
    DataNormalization normalizer = new NormalizerStandardize();
    normalizer.fit(trainingData); //Collect the statistics (mean/stdev) from the training data. This does not modify the input data
    normalizer.transform(trainingData); //Apply normalization to the training data
    normalizer.transform(testData); //Apply normalization to the test data. This is using statistics calculated from the *training* set

    final int numInputs = 4;
    int outputNum = 3;
    int iterations = 1000;
    long seed = 6;

    log.info("Build model....");
    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .seed(seed)
            .iterations(iterations)
            .activation(Activation.TANH)
            .weightInit(WeightInit.XAVIER)
            .learningRate(0.1)
            .regularization(true).l2(1e-4)
            .list()
            .layer(0, new DenseLayer.Builder().nIn(numInputs).nOut(3)
                    .build())
            .layer(1, new DenseLayer.Builder().nIn(3).nOut(3)
                    .build())
            .layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                    .activation(Activation.SOFTMAX)
                    .nIn(3).nOut(outputNum).build())
            .backprop(true).pretrain(false)
            .build();

    //run the model
    MultiLayerNetwork model = new MultiLayerNetwork(conf);
    model.init();
    model.setListeners(new ScoreIterationListener(100));
    model.fit(trainingData);

    //evaluate the model on the test set (class count taken from numClasses rather than a repeated literal)
    Evaluation eval = new Evaluation(numClasses);
    INDArray output = model.output(testData.getFeatureMatrix());
    eval.eval(testData.getLabels(), output);
    log.info(eval.stats());

    //Save the model
    //Where to save the network. Note: the file is in .zip format - can be opened externally
    File locationToSave = new File("src/main/resources/generatedModels/DL4J/DL4J_Iris_Model.zip");
    //Updater: i.e., the state for Momentum, RMSProp, Adagrad etc. Save this if you want to train your network more in the future
    boolean saveUpdater = true;
    ModelSerializer.writeModel(model, locationToSave, saveUpdater);

    //Load the model
    MultiLayerNetwork restored = ModelSerializer.restoreMultiLayerNetwork(locationToSave);

    System.out.println("Saved and loaded parameters are equal: " + model.params().equals(restored.params()));
    System.out.println("Saved and loaded configurations are equal: " + model.getLayerWiseConfigurations().equals(restored.getLayerWiseConfigurations()));
}
开发者ID:kaiwaehner,项目名称:kafka-streams-machine-learning-examples,代码行数:78,代码来源:DeepLearning4J_CSV_Model.java
示例5: main
import org.nd4j.linalg.dataset.SplitTestAndTrain; //导入方法依赖的package包/类
/**
 * Trains a three-layer feed-forward classifier on the glass CSV and logs an evaluation on the
 * held-out split after every epoch.
 *
 * @param args command-line arguments (unused)
 * @throws IOException          if the CSV resource cannot be located or read
 * @throws InterruptedException if initializing the record reader is interrupted
 */
public static void main(String[] args) throws IOException, InterruptedException {
    // source from: https://www.kaggle.com/uciml/glass
    int numLinesToSkip = 1; // presumably a header row - confirm against glass.csv
    String delimiter = ",";
    int labelIndex = 9;
    int numClasses = 7;
    int batchSize = 214; // 214 rows in total, all loaded as one batch

    DataSet allData;
    // The record reader holds the opened file; close it as soon as the single batch has been
    // read instead of leaking it for the rest of the training run.
    try (RecordReader recordReader = new CSVRecordReader(numLinesToSkip, delimiter)) {
        recordReader.initialize(new FileSplit(new ClassPathResource("glass/glass.csv").getFile()));
        DataSetIterator iterator = new RecordReaderDataSetIterator(recordReader, batchSize, labelIndex, numClasses);
        allData = iterator.next();
    }
    allData.shuffle();

    // 80% of the examples are used for training, the rest for testing.
    SplitTestAndTrain testAndTrain = allData.splitTestAndTrain(0.8);
    DataSet trainingData = testAndTrain.getTrain();
    DataSet testData = testAndTrain.getTest();

    // Fit the normalizer on the training split only, then apply it to both splits.
    DataNormalization normalizer = new NormalizerStandardize();
    normalizer.fit(trainingData);
    normalizer.transform(trainingData);
    normalizer.transform(testData);

    int seed = 123;
    int numInputs = 9;
    int iterations = 1000;
    int epochs = 1;

    log.info("Construct model...");
    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .seed(seed)
            .iterations(iterations)
            .weightInit(WeightInit.XAVIER)
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .updater(Updater.NESTEROVS)
            .momentum(0.9)
            .learningRate(0.2)
            .regularization(true)
            .l2(1e-4)
            .list()
            .layer(0, new DenseLayer.Builder()
                    .nIn(numInputs)
                    .nOut(50)
                    .activation(Activation.TANH)
                    .build())
            .layer(1, new DenseLayer.Builder()
                    .nOut(100)
                    .activation(Activation.TANH)
                    .build())
            .layer(2, new OutputLayer.Builder()
                    .nOut(numClasses)
                    .lossFunction(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                    .activation(Activation.SOFTMAX)
                    .build())
            .backprop(true)
            .pretrain(false)
            .build();

    MultiLayerNetwork model = new MultiLayerNetwork(conf);
    model.init();
    model.setListeners(new ScoreIterationListener(100));

    for (int epoch = 0; epoch < epochs; epoch++) {
        model.fit(trainingData);
        log.info("*** Completed epoch {} ***", epoch);

        // Evaluate on the test split after each epoch.
        log.info("Evaluate model....");
        Evaluation eval = new Evaluation(numClasses);
        INDArray output = model.output(testData.getFeatureMatrix());
        eval.eval(testData.getLabels(), output);
        log.info(eval.stats());
    }
}
示例6: main
import org.nd4j.linalg.dataset.SplitTestAndTrain; //导入方法依赖的package包/类
/**
 * Trains a three-layer feed-forward classifier on the iris CSV for a fixed number of epochs
 * and logs an evaluation on the held-out split once training has finished.
 *
 * @param args command-line arguments (unused)
 * @throws Exception if the CSV resource cannot be located or read
 */
public static void main(String[] args) throws Exception {
    int numLinesToSkip = 0;
    String delimiter = ",";
    int labelIndex = 4;
    int numClasses = 3;
    int batchSize = 150;

    DataSet allData;
    //The record reader holds the opened file; close it as soon as the single batch has been read
    //instead of leaking it for the rest of the training run.
    try (RecordReader recordReader = new CSVRecordReader(numLinesToSkip, delimiter)) {
        recordReader.initialize(new FileSplit(new ClassPathResource("iris.txt").getFile()));
        DataSetIterator iterator = new RecordReaderDataSetIterator(recordReader, batchSize, labelIndex, numClasses);
        allData = iterator.next();
    }
    allData.shuffle();

    //65% of the examples are used for training, the rest for testing
    SplitTestAndTrain testAndTrain = allData.splitTestAndTrain(0.65);
    DataSet trainingData = testAndTrain.getTrain();
    DataSet testData = testAndTrain.getTest();

    //We need to normalize our data. We'll use NormalizeStandardize (which gives us mean 0, unit variance):
    DataNormalization normalizer = new NormalizerStandardize();
    normalizer.fit(trainingData); //Collect the statistics (mean/stdev) from the training data. This does not modify the input data
    normalizer.transform(trainingData); //Apply normalization to the training data
    normalizer.transform(testData); //Apply normalization to the test data. This is using statistics calculated from the *training* set

    final int numInputs = 4;
    int outputNum = 3;
    int iterations = 1000;
    long seed = 6;
    int epochs = 100;

    log.info("Build model....");
    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .seed(seed)
            .iterations(iterations)
            .weightInit(WeightInit.XAVIER)
            .learningRate(0.1)
            .regularization(true)
            .l2(1e-4)
            .list()
            .layer(0, new DenseLayer.Builder()
                    .nIn(numInputs)
                    .nOut(20)
                    .activation(Activation.TANH)
                    .build())
            .layer(1, new DenseLayer.Builder()
                    .nOut(10)
                    .activation(Activation.TANH)
                    .build())
            .layer(2, new OutputLayer.Builder()
                    .lossFunction(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                    .activation(Activation.SOFTMAX)
                    .nOut(outputNum)
                    .build())
            .backprop(true)
            .pretrain(false)
            .build();

    //run the model
    MultiLayerNetwork model = new MultiLayerNetwork(conf);
    model.init();
    model.setListeners(new ScoreIterationListener(100));
    for (int epoch = 0; epoch < epochs; epoch++) {
        model.fit(trainingData);
    }

    //evaluate the model on the test set
    Evaluation eval = new Evaluation(numClasses);
    INDArray output = model.output(testData.getFeatureMatrix());
    eval.eval(testData.getLabels(), output);
    log.info(eval.stats());
}
示例7: testIris
import org.nd4j.linalg.dataset.SplitTestAndTrain; //导入方法依赖的package包/类
/**
 * Fits a tiny two-layer network on a 5-example iris training split and checks that an
 * {@code Evaluation} built with an explicit class count, one built without it, and the one
 * returned by {@code model.evaluate(...)} agree. The confusion matrix is printed in several
 * formats at the end.
 */
@Test
public void testIris() {
    // Layers of the network: 4 inputs -> 2 hidden units (tanh) -> 3 outputs (softmax).
    DenseLayer hiddenLayer = new DenseLayer.Builder()
            .nIn(4).nOut(2)
            .activation(Activation.TANH)
            .weightInit(WeightInit.XAVIER)
            .build();
    org.deeplearning4j.nn.conf.layers.OutputLayer outputLayer =
            new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                    .nIn(2).nOut(3)
                    .weightInit(WeightInit.XAVIER)
                    .activation(Activation.SOFTMAX)
                    .build();

    // Network configuration
    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .optimizationAlgo(OptimizationAlgorithm.LINE_GRADIENT_DESCENT)
            .seed(42)
            .updater(new Sgd(1e-6))
            .list()
            .layer(0, hiddenLayer)
            .layer(1, outputLayer)
            .build();

    // Build and initialize the model, reporting the score on every iteration
    MultiLayerNetwork model = new MultiLayerNetwork(conf);
    model.init();
    model.setListeners(Arrays.asList((IterationListener) new ScoreIterationListener(1)));

    // Load the iris data as one batch, shuffle it and split off 5 training examples
    DataSetIterator irisIterator = new IrisDataSetIterator(150, 150);
    DataSet irisData = irisIterator.next();
    irisData.shuffle();
    SplitTestAndTrain partition = irisData.splitTestAndTrain(5, new Random(42));

    // Training part, normalized on its own
    DataSet train = partition.getTrain();
    train.normalizeZeroMeanZeroUnitVariance();

    // Testing part, normalized on its own
    DataSet test = partition.getTest();
    test.normalizeZeroMeanZeroUnitVariance();
    INDArray testFeatures = test.getFeatureMatrix();
    INDArray testLabels = test.getLabels();

    // Fit, then predict on the test features
    model.fit(train);
    INDArray predictedLabels = model.output(testFeatures);

    // Evaluation constructed with the class count
    Evaluation eval = new Evaluation(3);
    eval.eval(testLabels, predictedLabels);
    double f1WithClassCount = eval.f1();
    double accuracyWithClassCount = eval.accuracy();

    // Evaluation constructed without the class count
    Evaluation eval2 = new Evaluation();
    eval2.eval(testLabels, predictedLabels);
    double f1WithoutClassCount = eval2.f1();
    double accuracyWithoutClassCount = eval2.accuracy();

    // Both variants must report the same f1 and accuracy (there is only one batch)
    assertTrue(f1WithClassCount == f1WithoutClassCount
            && accuracyWithClassCount == accuracyWithoutClassCount);

    // The evaluation returned by the model itself must match as well
    Evaluation evalViaMethod = model.evaluate(new ListDataSetIterator<>(Collections.singletonList(test)));
    checkEvaluationEquality(eval, evalViaMethod);

    // Print the confusion matrix in each available representation
    System.out.println(eval.getConfusionMatrix().toString());
    System.out.println(eval.getConfusionMatrix().toCSV());
    System.out.println(eval.getConfusionMatrix().toHTML());
    System.out.println(eval.confusionToString());
}