当前位置: 首页>>代码示例>>Java>>正文


Java NormalizerStandardize.fit方法代码示例

本文整理汇总了Java中org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize.fit方法的典型用法代码示例。如果您正苦于以下问题:Java NormalizerStandardize.fit方法的具体用法?Java NormalizerStandardize.fit怎么用?Java NormalizerStandardize.fit使用的例子?那么, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize的用法示例。


在下文中一共展示了NormalizerStandardize.fit方法的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Java代码示例。

示例1: testBruteForce4d

import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize; //导入方法依赖的package包/类
/**
 * Fits a standardizing normalizer and a min-max scaler on a 4d data set and compares the
 * fitted statistics with the expected values carried by the {@code Construct4dDataSet} fixture.
 * The mean, min and max are compared with {@code assertEquals}; the standard deviation only has
 * to be within 5% of the expected value. Finally the standardizer is applied to a copy of the
 * data, which merely has to complete without throwing.
 */
@Test
public void testBruteForce4d() {
    Construct4dDataSet fixture = new Construct4dDataSet(10, 5, 10, 15);

    NormalizerStandardize standardizer = new NormalizerStandardize();
    standardizer.fit(fixture.sampleDataSet);
    assertEquals(fixture.expectedMean, standardizer.getMean());

    // Largest relative deviation of the fitted std from the expected std, plus the largest
    // fitted and expected std values; the latter two are printed for diagnostics only.
    float worstRelativeStdError =
                    Transforms.abs(standardizer.getStd().div(fixture.expectedStd).sub(1)).maxNumber().floatValue();
    float largestFittedStd = standardizer.getStd().maxNumber().floatValue();
    float largestExpectedStd = fixture.expectedStd.maxNumber().floatValue();
    System.out.println("ValA: " + worstRelativeStdError);
    System.out.println("ValB: " + largestFittedStd);
    System.out.println("ValC: " + largestExpectedStd);
    assertTrue(worstRelativeStdError < 0.05);

    NormalizerMinMaxScaler minMaxScaler = new NormalizerMinMaxScaler();
    minMaxScaler.fit(fixture.sampleDataSet);
    assertEquals(fixture.expectedMin, minMaxScaler.getMin());
    assertEquals(fixture.expectedMax, minMaxScaler.getMax());

    // Smoke test: transforming a copy must not throw.
    DataSet scratchCopy = fixture.sampleDataSet.copy();
    standardizer.transform(scratchCopy);
}
 
开发者ID:deeplearning4j,项目名称:nd4j,代码行数:25,代码来源:PreProcessor3D4DTest.java

示例2: testRevert

import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize; //导入方法依赖的package包/类
/**
 * Checks that {@code revert} undoes {@code transform}: a copy of a random data set is
 * standardized and reverted, and every feature value must end up within
 * {@code tolerancePerc} percent of the corresponding original value.
 */
@Test
public void testRevert() {
    double tolerancePerc = 0.01; // 0.01% of correct value
    int nSamples = 500;
    int nFeatures = 3;

    INDArray featureSet = Nd4j.randn(nSamples, nFeatures);
    INDArray labelSet = Nd4j.zeros(nSamples, 1);
    DataSet sampleDataSet = new DataSet(featureSet, labelSet);

    NormalizerStandardize myNormalizer = new NormalizerStandardize();
    myNormalizer.fit(sampleDataSet);

    // Work on a copy so the original features stay available as the reference.
    DataSet transformed = sampleDataSet.copy();
    myNormalizer.transform(transformed);
    myNormalizer.revert(transformed);

    INDArray original = sampleDataSet.getFeatures();
    INDArray absError = Transforms.abs(transformed.getFeatures().sub(original));
    // Divide by the magnitude of the original value, not the signed value: with a signed
    // denominator every negative feature yields a negative "relative error" that max() never
    // selects, so errors on negative features would go unnoticed.
    INDArray relativeError = absError.div(Transforms.abs(original));
    double maxdeltaPerc = relativeError.max(0, 1).mul(100).getDouble(0, 0);
    assertTrue(maxdeltaPerc < tolerancePerc);
}
 
开发者ID:deeplearning4j,项目名称:nd4j,代码行数:23,代码来源:NormalizerStandardizeTest.java

示例3: testRestoreUnsavedNormalizerFromInputStream

import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize; //导入方法依赖的package包/类
/**
 * Verifies that restoring a normalizer from a model file that was written without one
 * yields {@code null}.
 *
 * @throws Exception if the temporary file cannot be created, written or read
 */
@Test
public void testRestoreUnsavedNormalizerFromInputStream() throws Exception {
    DataSet dataSet = trivialDataSet();

    // The normalizer is fitted but deliberately never passed to the serializer.
    NormalizerStandardize norm = new NormalizerStandardize();
    norm.fit(dataSet);

    ComputationGraph cg = simpleComputationGraph();
    cg.init();

    File tempFile = File.createTempFile("tsfs", "fdfsdf");
    tempFile.deleteOnExit();
    ModelSerializer.writeModel(cg, tempFile, true);

    // try-with-resources guarantees the stream is closed even when the assertion fails;
    // an open handle can also prevent the temp file from being deleted on some platforms.
    try (FileInputStream fis = new FileInputStream(tempFile)) {
        NormalizerStandardize restored = ModelSerializer.restoreNormalizerFromInputStream(fis);

        assertEquals(null, restored);
    }
}
 
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:21,代码来源:ModelSerializerTest.java

示例4: testMeanStdZeros

import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize; //导入方法依赖的package包/类
/**
 * Builds a 5 x 6 table in which every value is 1.0 and checks that the per-column standard
 * deviation and mean reported by {@code Normalization.stdDevMeanColumns} on the data frame
 * agree (within 1e-1) with the statistics fitted by {@code NormalizerStandardize} on the
 * equivalent matrix.
 */
@Test
public void testMeanStdZeros() {
    int numColumns = 6;

    // One double column per index, named "0" .. "5".
    Schema.Builder schemaBuilder = new Schema.Builder();
    for (int col = 0; col < numColumns; col++) {
        schemaBuilder.addColumnDouble(String.valueOf(col));
    }

    // Five identical records whose values are all 1.0.
    List<List<Writable>> records = new ArrayList<>();
    for (int row = 0; row < 5; row++) {
        List<Writable> record = new ArrayList<>(numColumns);
        for (int col = 0; col < numColumns; col++) {
            record.add(new DoubleWritable(1.0));
        }
        records.add(record);
    }

    INDArray matrix = RecordConverter.toMatrix(records);

    Schema schema = schemaBuilder.build();
    JavaRDD<List<Writable>> rdd = sc.parallelize(records);
    DataRowsFacade dataFrame = DataFrames.toDataFrame(schema, rdd);

    // Reference statistics from the ndarray pre-processors.
    NormalizerStandardize standardizer = new NormalizerStandardize();
    standardizer.fit(new DataSet(matrix.dup(), matrix.dup()));
    INDArray standardized = matrix.dup();
    standardizer.transform(new DataSet(standardized, standardized));

    DataNormalization minMaxScaler = new NormalizerMinMaxScaler();
    minMaxScaler.fit(new DataSet(matrix.dup(), matrix.dup()));
    INDArray rescaled = matrix.dup();
    minMaxScaler.transform(new DataSet(rescaled, rescaled));

    List<Row> statRows = Normalization.stdDevMeanColumns(dataFrame, dataFrame.get().columns());
    INDArray statMatrix = DataFrames.toMatrix(statRows);

    // Row 0 is compared with the standard deviation, row 1 with the mean.
    assertTrue(standardizer.getStd().equalsWithEps(statMatrix.getRow(0), 1e-1));
    assertTrue(standardizer.getMean().equalsWithEps(statMatrix.getRow(1), 1e-1));
}
 
开发者ID:deeplearning4j,项目名称:DataVec,代码行数:41,代码来源:NormalizationTests.java

示例5: irisCsv

import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize; //导入方法依赖的package包/类
/**
 * Creates a data iterator over an iris-style CSV file together with a standardizing
 * normalizer fitted on the data. The normalizer is also installed as the iterator's
 * pre-processor.
 *
 * @param name path of the CSV file; each row holds 4 feature values followed by an
 *             integer class label (0, 1 or 2)
 * @return the iterator (reset to the start) bundled with its fitted normalizer
 * @throws IllegalStateException if the CSV reader cannot be initialised for {@code name}
 */
static DataIterator<NormalizerStandardize> irisCsv(String name) {
    CSVRecordReader recordReader = new CSVRecordReader(0, ",");
    try {
        recordReader.initialize(new FileSplit(new File(name)));
    } catch (Exception e) {
        // Continuing with an uninitialised reader only defers the failure to a less
        // obvious place, so fail here and keep the original cause.
        throw new IllegalStateException("Unable to initialise CSV record reader for " + name, e);
    }

    int labelIndex = 4;     //5 values in each row of the iris.txt CSV: 4 input features followed by an integer label (class) index. Labels are the 5th value (index 4) in each row
    int numClasses = 3;     //3 classes (types of iris flowers) in the iris data set. Classes have integer values 0, 1 or 2
    int batchSize = 50;     //Iris data set: 150 examples total.

    RecordReaderDataSetIterator iterator = new RecordReaderDataSetIterator(
            recordReader,
            batchSize,
            labelIndex,
            numClasses
    );

    NormalizerStandardize normalizer = new NormalizerStandardize();

    // Fit on the iterator as a whole so the statistics cover every batch. Calling
    // fit(DataSet) once per batch refits from scratch each time, leaving only the
    // statistics of the last batch.
    normalizer.fit(iterator);
    iterator.reset();

    iterator.setPreProcessor(normalizer);

    return new DataIterator<>(iterator, normalizer);
}
 
开发者ID:wmeddie,项目名称:dl4j-trainer-archetype,代码行数:31,代码来源:DataIterator.java

示例6: testBruteForce3d

import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize; //导入方法依赖的package包/类
/**
 * Fits a standardizing normalizer (case A) and a min-max scaler (case B) on 3d data sets,
 * first from the whole data set and then from a {@code TestDataSetIterator} over the same
 * data, and compares the fitted statistics with the expected values carried by the
 * {@code Construct3dDataSet} fixtures. The standard deviation only has to be within 1% of the
 * expected value; mean, min and max are compared with {@code assertEquals}.
 */
@Test
public void testBruteForce3d() {
    int timeSteps = 15;
    int samples = 100;

    // Per-feature multipliers for the two fixtures.
    INDArray scaleA = Nd4j.create(new double[] {1, -2, 3}).reshape(3, 1);
    INDArray scaleB = Nd4j.create(new double[] {2, 2, 3}).reshape(3, 1);

    Construct3dDataSet caseA = new Construct3dDataSet(scaleA, timeSteps, samples, 1);
    Construct3dDataSet caseB = new Construct3dDataSet(scaleB, timeSteps, samples, 1);

    NormalizerStandardize standardizer = new NormalizerStandardize();
    NormalizerMinMaxScaler minMaxScaler = new NormalizerMinMaxScaler();

    // --- fitted from complete data sets ---
    standardizer.fit(caseA.sampleDataSet);
    assertEquals(caseA.expectedMean, standardizer.getMean());
    float worstStdErrorDataSet =
                    Transforms.abs(standardizer.getStd().div(caseA.expectedStd).sub(1)).maxNumber().floatValue();
    assertTrue(worstStdErrorDataSet < 0.01);

    minMaxScaler.fit(caseB.sampleDataSet);
    assertEquals(caseB.expectedMin, minMaxScaler.getMin());
    assertEquals(caseB.expectedMax, minMaxScaler.getMax());

    // --- fitted from iterators: std should be close, everything else exact ---
    DataSetIterator iteratorA = new TestDataSetIterator(caseA.sampleDataSet, 5);
    DataSetIterator iteratorB = new TestDataSetIterator(caseB.sampleDataSet, 5);

    standardizer.fit(iteratorA);
    assertEquals(standardizer.getMean(), caseA.expectedMean);
    float worstStdErrorIterator =
                    Transforms.abs(standardizer.getStd().div(caseA.expectedStd).sub(1)).maxNumber().floatValue();
    assertTrue(worstStdErrorIterator < 0.01);

    minMaxScaler.fit(iteratorB);
    assertEquals(minMaxScaler.getMin(), caseB.expectedMin);
    assertEquals(minMaxScaler.getMax(), caseB.expectedMax);
}
 
开发者ID:deeplearning4j,项目名称:nd4j,代码行数:37,代码来源:PreProcessor3D4DTest.java

示例7: testDifferentBatchSizes

import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize; //导入方法依赖的package包/类
/**
 * Fits the normalizer on the numbers 1 through 6, once from the whole data set and then from
 * iterators with batch sizes 3, 4 and 1, and checks that the fitted mean and standard
 * deviation are the same in every case.
 */
@Test
public void testDifferentBatchSizes() {
    float expectedMean = 3.5f;
    float expectedStd = 1.70783f;

    // Create 6x1 matrix of the numbers 1 through 6
    INDArray oneToSix = Nd4j.linspace(1, 6, 6).transpose();
    DataSet dataSet = new DataSet(oneToSix, oneToSix);

    // Whole data set in one go
    NormalizerStandardize wholeSetNorm = new NormalizerStandardize();
    wholeSetNorm.fit(dataSet);
    assertEquals(expectedMean, wholeSetNorm.getMean().getFloat(0), 1e-6);
    assertEquals(expectedStd, wholeSetNorm.getStd().getFloat(0), 1e-4);

    // Equal batch sizes: 2 batches of 3 rows
    DataSetIterator equalBatches = new TestDataSetIterator(dataSet, 3);
    NormalizerStandardize equalBatchNorm = new NormalizerStandardize();
    equalBatchNorm.fit(equalBatches);
    assertEquals(expectedMean, equalBatchNorm.getMean().getFloat(0), 1e-6);
    assertEquals(expectedStd, equalBatchNorm.getStd().getFloat(0), 1e-4);

    // Varying batch sizes: a batch of 4 rows followed by a batch of 2 rows
    DataSetIterator unevenBatches = new TestDataSetIterator(dataSet, 4);
    NormalizerStandardize unevenBatchNorm = new NormalizerStandardize();
    unevenBatchNorm.fit(unevenBatches);
    assertEquals(expectedMean, unevenBatchNorm.getMean().getFloat(0), 1e-6);
    assertEquals(expectedStd, unevenBatchNorm.getStd().getFloat(0), 1e-4);

    // Single-row batches: 6 batches of 1 row
    DataSetIterator singleRowBatches = new TestDataSetIterator(dataSet, 1);
    NormalizerStandardize singleRowNorm = new NormalizerStandardize();
    singleRowNorm.fit(singleRowBatches);
    assertEquals(expectedMean, singleRowNorm.getMean().getFloat(0), 1e-6);
    assertEquals(expectedStd, singleRowNorm.getStd().getFloat(0), 1e-4);
}
 
开发者ID:deeplearning4j,项目名称:nd4j,代码行数:34,代码来源:NormalizerStandardizeTest.java

示例8: testUnderOverflow

import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize; //导入方法依赖的package包/类
/**
 * Fits the normalizer, batch by batch, on features that are large constants (about -1e6 and
 * +1e6) plus a small random offset, and checks that the fitted mean stays within
 * {@code tolerancePerc} percent of the constant. Transforming the data afterwards only has to
 * complete without throwing.
 */
@Test
public void testUnderOverflow() {
    // This dataset will be basically constant with a small std deviation
    // And the constant is large. Checking if algorithm can handle
    double tolerancePerc = 1; //Within 1 %
    int nSamples = 1000;
    int bSize = 10;
    int x = -1000000, y = 1000000;
    double z = 1000000;

    INDArray featureX = Nd4j.rand(nSamples, 1).mul(1).add(x);
    INDArray featureY = Nd4j.rand(nSamples, 1).mul(2).add(y);
    INDArray featureZ = Nd4j.rand(nSamples, 1).mul(3).add(z);
    INDArray featureSet = Nd4j.concat(1, featureX, featureY, featureZ);
    INDArray labelSet = Nd4j.zeros(nSamples, 1);
    DataSet sampleDataSet = new DataSet(featureSet, labelSet);
    DataSetIterator sampleIter = new TestDataSetIterator(sampleDataSet, bSize);

    // The small random offsets are ignored here; they are far below the 1% tolerance.
    INDArray theoreticalMean = Nd4j.create(new double[] {x, y, z});

    NormalizerStandardize myNormalizer = new NormalizerStandardize();
    myNormalizer.fit(sampleIter);

    INDArray meanDelta = Transforms.abs(theoreticalMean.sub(myNormalizer.getMean()));
    // Divide by the magnitude of the theoretical mean: x is negative, and a signed
    // denominator would turn its percentage negative so that max() could never flag it.
    INDArray meanDeltaPerc = meanDelta.mul(100).div(Transforms.abs(theoreticalMean));
    assertTrue(meanDeltaPerc.max(1).getDouble(0, 0) < tolerancePerc);

    //this just has to not barf
    //myNormalizer.transform(sampleIter);
    myNormalizer.transform(sampleDataSet);
}
 
开发者ID:deeplearning4j,项目名称:nd4j,代码行数:33,代码来源:NormalizerStandardizeTest.java

示例9: testConstant

import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize; //导入方法依赖的package包/类
/**
 * Fits the normalizer on features that all equal the same constant (so the standard
 * deviation is zero) and checks that fitting, transforming and reverting produce no NaNs,
 * that the transformed values are close to zero, and that reverting restores the original
 * values.
 */
@Test
public void testConstant() {
    double tolerancePerc = 10.0; // 10% of correct value
    int nSamples = 500;
    int nFeatures = 3;
    int constant = 100;

    INDArray featureSet = Nd4j.zeros(nSamples, nFeatures).add(constant);
    INDArray labelSet = Nd4j.zeros(nSamples, 1);
    DataSet sampleDataSet = new DataSet(featureSet, labelSet);
    // Independent snapshot of the untouched features. The data set holds featureSet itself
    // and transform/revert appear to modify it in place, so comparing the reverted data with
    // featureSet would compare an array with itself.
    INDArray originalFeatures = featureSet.dup();

    NormalizerStandardize myNormalizer = new NormalizerStandardize();
    myNormalizer.fit(sampleDataSet);
    //Checking if we get NaNs
    assertFalse(Double.isNaN(myNormalizer.getStd().getDouble(0)));

    myNormalizer.transform(sampleDataSet);
    //Checking if we get NaNs, because std dev is zero
    assertFalse(Double.isNaN(sampleDataSet.getFeatures().min(0, 1).getDouble(0)));
    //Checking to see if transformed values are close enough to zero
    assertEquals(0.0, Transforms.abs(sampleDataSet.getFeatures()).max(0, 1).getDouble(0, 0),
                    constant * tolerancePerc / 100.0);

    myNormalizer.revert(sampleDataSet);
    //Checking if we get NaNs, because std dev is zero
    assertFalse(Double.isNaN(sampleDataSet.getFeatures().min(0, 1).getDouble(0)));
    // The worst-case (max) deviation from the original values must be within tolerance;
    // using min would let the test pass as long as a single element was restored.
    assertEquals(0.0, Transforms.abs(sampleDataSet.getFeatures().sub(originalFeatures)).max(0, 1).getDouble(0),
                    constant * tolerancePerc / 100.0);
}
 
开发者ID:deeplearning4j,项目名称:nd4j,代码行数:31,代码来源:NormalizerStandardizeTest.java

示例10: normalize

import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize; //导入方法依赖的package包/类
/**
 * Standardizes this data set by fitting a fresh {@link NormalizerStandardize} on it and then
 * applying that normalizer to it.
 */
@Override
public void normalize() {
    NormalizerStandardize standardizer = new NormalizerStandardize();
    standardizer.fit(this);
    standardizer.transform(this);
}
 
开发者ID:deeplearning4j,项目名称:nd4j,代码行数:8,代码来源:DataSet.java

示例11: testRocMultiToHtml

import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize; //导入方法依赖的package包/类
/**
 * Trains a small two-layer network on one standardized batch from the iris iterator, then
 * renders a multi-class ROC chart to HTML for two threshold-step settings and prints the
 * result. The test makes no assertions; it only has to run without throwing.
 *
 * @throws Exception if anything fails while training, evaluating or rendering
 */
@Test
public void testRocMultiToHtml() throws Exception {
    DataSetIterator iter = new IrisDataSetIterator(150, 150);

    // 4 inputs -> 4 tanh units -> 3-way softmax output
    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                    .weightInit(WeightInit.XAVIER)
                    .list()
                    .layer(0, new DenseLayer.Builder().nIn(4).nOut(4).activation(Activation.TANH).build())
                    .layer(1, new OutputLayer.Builder().nIn(4).nOut(3).activation(Activation.SOFTMAX)
                                    .lossFunction(LossFunctions.LossFunction.MCXENT).build())
                    .build();
    MultiLayerNetwork net = new MultiLayerNetwork(conf);
    net.init();

    // Standardize the single batch using statistics fitted on that same batch.
    NormalizerStandardize standardizer = new NormalizerStandardize();
    DataSet batch = iter.next();
    standardizer.fit(batch);
    standardizer.transform(batch);

    for (int pass = 0; pass < 30; pass++) {
        net.fit(batch);
    }

    int[] thresholdStepSettings = {20, 0};
    for (int numSteps : thresholdStepSettings) {
        ROCMultiClass roc = new ROCMultiClass(numSteps);
        iter.reset();

        INDArray features = batch.getFeatures();
        INDArray labels = batch.getLabels();
        INDArray predictions = net.output(features);
        roc.eval(labels, predictions);

        String html = EvaluationTools.rocChartToHtml(roc, Arrays.asList("setosa", "versicolor", "virginica"));
        System.out.println(html);
    }
}
 
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:36,代码来源:EvaluationToolsTests.java

示例12: normalizationTests

import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize; //导入方法依赖的package包/类
/**
 * Builds a 5 x 6 table in which every value is 1.0, checks the schema and record round trips
 * through the data-frame conversions, and then compares the data-frame based normalizations
 * ({@code zeromeanUnitVariance} and {@code normalize}) with the ndarray pre-processors
 * ({@code NormalizerStandardize} and {@code NormalizerMinMaxScaler}) applied to the same data.
 */
@Test
public void normalizationTests() {
    int numColumns = 6;

    // One double column per index, named "0" .. "5".
    Schema.Builder schemaBuilder = new Schema.Builder();
    for (int col = 0; col < numColumns; col++) {
        schemaBuilder.addColumnDouble(String.valueOf(col));
    }

    // Five identical records whose values are all 1.0.
    List<List<Writable>> records = new ArrayList<>();
    for (int row = 0; row < 5; row++) {
        List<Writable> record = new ArrayList<>(numColumns);
        for (int col = 0; col < numColumns; col++) {
            record.add(new DoubleWritable(1.0));
        }
        records.add(record);
    }

    INDArray matrix = RecordConverter.toMatrix(records);

    Schema schema = schemaBuilder.build();
    JavaRDD<List<Writable>> rdd = sc.parallelize(records);

    // Round trips: schema <-> struct type, records <-> data frame.
    assertEquals(schema, DataFrames.fromStructType(DataFrames.fromSchema(schema)));
    assertEquals(rdd.collect(), DataFrames.toRecords(DataFrames.toDataFrame(schema, rdd)).getSecond().collect());

    DataRowsFacade dataFrame = DataFrames.toDataFrame(schema, rdd);
    dataFrame.get().show();
    Normalization.zeromeanUnitVariance(dataFrame).get().show();
    Normalization.normalize(dataFrame).get().show();

    // Reference results from the ndarray pre-processors.
    NormalizerStandardize standardizer = new NormalizerStandardize();
    standardizer.fit(new DataSet(matrix.dup(), matrix.dup()));
    INDArray standardized = matrix.dup();
    standardizer.transform(new DataSet(standardized, standardized));

    DataNormalization minMaxScaler = new NormalizerMinMaxScaler();
    minMaxScaler.fit(new DataSet(matrix.dup(), matrix.dup()));
    INDArray rescaled = matrix.dup();
    minMaxScaler.transform(new DataSet(rescaled, rescaled));

    // The same normalizations computed through the data-frame API.
    INDArray standardizedViaDataFrame =
                    RecordConverter.toMatrix(Normalization.zeromeanUnitVariance(schema, rdd).collect());
    INDArray rescaledViaDataFrame =
                    RecordConverter.toMatrix(Normalization.normalize(schema, rdd).collect());

    assertEquals(standardized, standardizedViaDataFrame);
    assertTrue(rescaled.equalsWithEps(rescaledViaDataFrame, 1e-1));
}
 
开发者ID:deeplearning4j,项目名称:DataVec,代码行数:48,代码来源:NormalizationTests.java

示例13: testBruteForce

import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize; //导入方法依赖的package包/类
/**
 * Fits the normalizer (features and labels) on a data set whose feature columns are
 * 1x, 2x and 3x the natural numbers {@code 1..nSamples} and compares the fitted mean and
 * standard deviation with the closed-form values, first for the whole data set and then for
 * an iterator over the same data in batches of {@code bSize}. All comparisons are expressed
 * as a percentage of the theoretical value.
 */
@Test
public void testBruteForce() {
    /* This test creates a dataset where feature values are multiples of consecutive natural numbers
       The obtained values are compared to the theoretical mean and std dev
     */
    double tolerancePerc = 0.01;
    int nSamples = 5120;
    int x = 1, y = 2, z = 3;

    INDArray featureX = Nd4j.linspace(1, nSamples, nSamples).reshape(nSamples, 1).mul(x);
    INDArray featureY = featureX.mul(y);
    INDArray featureZ = featureX.mul(z);
    INDArray featureSet = Nd4j.concat(1, featureX, featureY, featureZ);
    // The label is a copy of the first feature column.
    INDArray labelSet = featureSet.dup().getColumns(new int[] {0});
    DataSet sampleDataSet = new DataSet(featureSet, labelSet);

    // Mean of 1..n is (n + 1) / 2; its (population) std is sqrt((n^2 - 1) / 12).
    double meanNaturalNums = (nSamples + 1) / 2.0;
    INDArray theoreticalMean =
                    Nd4j.create(new double[] {meanNaturalNums * x, meanNaturalNums * y, meanNaturalNums * z});
    INDArray theoreticallabelMean = theoreticalMean.dup().getColumns(new int[] {0});
    double stdNaturalNums = Math.sqrt((nSamples * nSamples - 1) / 12.0);
    INDArray theoreticalStd =
                    Nd4j.create(new double[] {stdNaturalNums * x, stdNaturalNums * y, stdNaturalNums * z});
    INDArray theoreticallabelStd = theoreticalStd.dup().getColumns(new int[] {0});

    NormalizerStandardize myNormalizer = new NormalizerStandardize();
    myNormalizer.fitLabel(true);
    myNormalizer.fit(sampleDataSet);

    INDArray meanDelta = Transforms.abs(theoreticalMean.sub(myNormalizer.getMean()));
    INDArray labelDelta = Transforms.abs(theoreticallabelMean.sub(myNormalizer.getLabelMean()));
    INDArray meanDeltaPerc = meanDelta.div(theoreticalMean).mul(100);
    INDArray labelDeltaPerc = labelDelta.div(theoreticallabelMean).mul(100);
    double maxMeanDeltaPerc = meanDeltaPerc.max(1).getDouble(0, 0);
    assertTrue(maxMeanDeltaPerc < tolerancePerc);
    assertTrue(labelDeltaPerc.max(1).getDouble(0, 0) < tolerancePerc);

    INDArray stdDelta = Transforms.abs(theoreticalStd.sub(myNormalizer.getStd()));
    INDArray stdDeltaPerc = stdDelta.div(theoreticalStd).mul(100);
    // Convert the label deviation to a percentage too, so that it is compared with
    // tolerancePerc in the same unit as every other check.
    INDArray stdlabelDeltaPerc = Transforms.abs(theoreticallabelStd.sub(myNormalizer.getLabelStd()))
                    .div(theoreticallabelStd).mul(100);
    // stdDeltaPerc is already a percentage; it must not be multiplied by 100 a second time.
    double maxStdDeltaPerc = stdDeltaPerc.max(1).getDouble(0, 0);
    double maxlabelStdDeltaPerc = stdlabelDeltaPerc.max(1).getDouble(0, 0);
    assertTrue(maxStdDeltaPerc < tolerancePerc);
    assertTrue(maxlabelStdDeltaPerc < tolerancePerc);


    // SAME TEST WITH THE ITERATOR
    int bSize = 10;
    tolerancePerc = 0.1; // 0.1% of correct value
    DataSetIterator sampleIter = new TestDataSetIterator(sampleDataSet, bSize);
    myNormalizer.fit(sampleIter);

    meanDelta = Transforms.abs(theoreticalMean.sub(myNormalizer.getMean()));
    meanDeltaPerc = meanDelta.div(theoreticalMean).mul(100);
    maxMeanDeltaPerc = meanDeltaPerc.max(1).getDouble(0, 0);
    assertTrue(maxMeanDeltaPerc < tolerancePerc);

    // Compare the fitted std with the theoretical std (not the means again), so that the
    // iterator-based standard deviation is actually verified.
    stdDelta = Transforms.abs(theoreticalStd.sub(myNormalizer.getStd()));
    stdDeltaPerc = stdDelta.div(theoreticalStd).mul(100);
    maxStdDeltaPerc = stdDeltaPerc.max(1).getDouble(0, 0);
    assertTrue(maxStdDeltaPerc < tolerancePerc);
}
 
开发者ID:deeplearning4j,项目名称:nd4j,代码行数:64,代码来源:NormalizerStandardizeLabelsTest.java

示例14: testTransform

import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize; //导入方法依赖的package包/类
/**
 * Fits the normalizer (features and labels) on a seeded random data set, checks the fitted
 * mean and standard deviation against the theoretical values carried by the fixture, and then
 * walks three iterators in lockstep to check that each normalized batch is close to the
 * corresponding batch of the "expected" fixture built with scale 1 and offset 0.
 */
@Test
public void testTransform() {
    /*Random dataset is generated such that
        AX + B where X is from a normal distribution with mean 0 and std 1
        The mean of above will be B and std A
        Obtained mean and std dev are compared to theoretical
        Transformed values should be the same as X with the same seed.
     */
    long randSeed = 2227724;

    int nFeatures = 2;
    int nSamples = 6400;
    int bsize = 8;
    int a = 5;
    int b = 100;
    INDArray sampleMean, sampleStd, sampleMeanDelta, sampleStdDelta, delta, deltaPerc;
    double maxDeltaPerc, sampleMeanSEM;

    // Three fixtures built from the same seed: the data to normalize (a, b), the expected
    // outcome (scale 1, offset 0), and a second (a, b) fixture that is never normalized.
    // NOTE(review): this relies on genRandomDataSet being deterministic for a given seed - confirm.
    genRandomDataSet normData = new genRandomDataSet(nSamples, nFeatures, a, b, randSeed);
    genRandomDataSet expectedData = new genRandomDataSet(nSamples, nFeatures, 1, 0, randSeed);
    genRandomDataSet beforeTransformData = new genRandomDataSet(nSamples, nFeatures, a, b, randSeed);

    NormalizerStandardize myNormalizer = new NormalizerStandardize();
    myNormalizer.fitLabel(true);
    DataSetIterator normIterator = normData.getIter(bsize);
    DataSetIterator expectedIterator = expectedData.getIter(bsize);
    DataSetIterator beforeTransformIterator = beforeTransformData.getIter(bsize);

    myNormalizer.fit(normIterator);

    // --- fitted mean versus theoretical mean ---
    double tolerancePerc = 0.5; //within 0.5%
    sampleMean = myNormalizer.getMean();
    sampleMeanDelta = Transforms.abs(sampleMean.sub(normData.theoreticalMean));
    assertTrue(sampleMeanDelta.mul(100).div(normData.theoreticalMean).max(1).getDouble(0, 0) < tolerancePerc);
    //sanity check to see if it's within the theoretical standard error of mean
    sampleMeanSEM = sampleMeanDelta.div(normData.theoreticalSEM).max(1).getDouble(0, 0);
    assertTrue(sampleMeanSEM < 2.6); //99% of the time it should be within this many SEMs

    // --- fitted std versus theoretical std ---
    tolerancePerc = 5; //within 5%
    sampleStd = myNormalizer.getStd();
    sampleStdDelta = Transforms.abs(sampleStd.sub(normData.theoreticalStd));
    assertTrue(sampleStdDelta.div(normData.theoreticalStd).max(1).mul(100).getDouble(0, 0) < tolerancePerc);

    // --- batch-by-batch check of the normalized values ---
    tolerancePerc = 1; //within 1%
    // Install the fitted normalizer as pre-processor; presumably every batch returned by
    // normIterator.next() below is then already normalized - TODO confirm.
    normIterator.setPreProcessor(myNormalizer);
    // The three iterators are advanced once per pass so that their batches line up.
    // NOTE(review): assumes fit(normIterator) leaves the iterator reset to its start - confirm.
    while (normIterator.hasNext()) {
        INDArray before = beforeTransformIterator.next().getFeatures();
        DataSet here = normIterator.next();
        assertEquals(here.getFeatures(), here.getLabels()); //bootstrapping existing test on features
        INDArray after = here.getFeatures();
        INDArray expected = expectedIterator.next().getFeatures();
        // Error of the normalized values, expressed as a percentage of the distance between
        // the raw value and the expected value.
        delta = Transforms.abs(after.sub(expected));
        deltaPerc = delta.div(before.sub(expected));
        deltaPerc.muli(100);
        maxDeltaPerc = deltaPerc.max(0, 1).getDouble(0, 0);
        //System.out.println("=== BEFORE ===");
        //System.out.println(before);
        //System.out.println("=== AFTER ===");
        //System.out.println(after);
        //System.out.println("=== SHOULD BE ===");
        //System.out.println(expected);
        assertTrue(maxDeltaPerc < tolerancePerc);
    }
}
 
开发者ID:deeplearning4j,项目名称:nd4j,代码行数:65,代码来源:NormalizerStandardizeLabelsTest.java

示例15: testBruteForce3dMaskLabels

import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize; //导入方法依赖的package包/类
/**
 * Merges two 3d data sets with different numbers of time steps (5 and 3) into one data set
 * and checks that the standardizing and min-max normalizers, fitted on features and labels,
 * report the same statistics as expected for a single data set with 8 time steps built
 * without merging. The check is done first on whole data sets and then on iterators, where the
 * standard deviation only has to be within 1% of the expected value.
 */
@Test
public void testBruteForce3dMaskLabels() {
    NormalizerStandardize standardizer = new NormalizerStandardize();
    standardizer.fitLabel(true);
    NormalizerMinMaxScaler minMaxScaler = new NormalizerMinMaxScaler();
    minMaxScaler.fitLabel(true);

    //generating a dataset with consecutive numbers as feature values. Dataset also has masks
    int samples = 100;
    int timeStepsU = 5;
    int timeStepsV = 3;
    INDArray featureScale = Nd4j.create(new float[] {1, 2, 10}).reshape(3, 1);

    // The second part starts at the origin reported by the first one.
    Construct3dDataSet partU = new Construct3dDataSet(featureScale, timeStepsU, samples, 1);
    Construct3dDataSet partV = new Construct3dDataSet(featureScale, timeStepsV, samples, partU.newOrigin);

    List<DataSet> parts = new ArrayList<DataSet>();
    parts.add(partU.sampleDataSet);
    parts.add(partV.sampleDataSet);

    DataSet merged = DataSet.merge(parts);
    DataSet mergedCopy = merged.copy();

    //This should be the same datasets as above without a mask
    Construct3dDataSet reference =
                    new Construct3dDataSet(featureScale, timeStepsU + timeStepsV, samples, 1);

    // --- fitted from whole data sets; label and feature values are the same ---
    standardizer.fit(merged);
    assertEquals(standardizer.getMean(), reference.expectedMean);
    assertEquals(standardizer.getStd(), reference.expectedStd);
    assertEquals(standardizer.getLabelMean(), reference.expectedMean);
    assertEquals(standardizer.getLabelStd(), reference.expectedStd);

    minMaxScaler.fit(mergedCopy);
    assertEquals(minMaxScaler.getMin(), reference.expectedMin);
    assertEquals(minMaxScaler.getMax(), reference.expectedMax);
    assertEquals(minMaxScaler.getLabelMin(), reference.expectedMin);
    assertEquals(minMaxScaler.getLabelMax(), reference.expectedMax);

    // --- fitted from iterators: std should be close, everything else exact ---
    DataSetIterator mergedIterator = new TestDataSetIterator(merged, 5);
    DataSetIterator mergedCopyIterator = new TestDataSetIterator(mergedCopy, 5);

    standardizer.fit(mergedIterator);
    assertEquals(standardizer.getMean(), reference.expectedMean);
    assertEquals(standardizer.getLabelMean(), reference.expectedMean);
    float worstFeatureStdError =
                    Transforms.abs(standardizer.getStd().div(reference.expectedStd).sub(1)).maxNumber().floatValue();
    assertTrue(worstFeatureStdError < 0.01);
    float worstLabelStdError = Transforms.abs(standardizer.getLabelStd().div(reference.expectedStd).sub(1))
                    .maxNumber().floatValue();
    assertTrue(worstLabelStdError < 0.01);

    minMaxScaler.fit(mergedCopyIterator);
    assertEquals(minMaxScaler.getMin(), reference.expectedMin);
    assertEquals(minMaxScaler.getMax(), reference.expectedMax);
    assertEquals(minMaxScaler.getLabelMin(), reference.expectedMin);
    assertEquals(minMaxScaler.getLabelMax(), reference.expectedMax);
}
 
开发者ID:deeplearning4j,项目名称:nd4j,代码行数:58,代码来源:PreProcessor3D4DTest.java


注:本文中的org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize.fit方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。