当前位置: 首页>>代码示例>>Java>>正文


Java Transforms.not方法代码示例

本文整理汇总了Java中org.nd4j.linalg.ops.transforms.Transforms.not方法的典型用法代码示例。如果您正苦于以下问题:Java Transforms.not方法的具体用法?Java Transforms.not怎么用?Java Transforms.not使用的例子?那么, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在org.nd4j.linalg.ops.transforms.Transforms的用法示例。


在下文中一共展示了Transforms.not方法的3个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Java代码示例。

示例1: testNot1

import org.nd4j.linalg.ops.transforms.Transforms; //导入方法依赖的package包/类
@Test
public void testNot1() {
    INDArray x = Nd4j.create(new double[] {0, 0, 1, 0, 0});
    INDArray exp = Nd4j.create(new double[] {1, 1, 0, 1, 1});

    INDArray z = Transforms.not(x);

    assertEquals(exp, z);
}
 
开发者ID:deeplearning4j,项目名称:nd4j,代码行数:10,代码来源:TransformsTest.java

示例2: adjustMasks

import org.nd4j.linalg.ops.transforms.Transforms; //导入方法依赖的package包/类
public INDArray adjustMasks(INDArray label, INDArray labelMask, int minorityLabel, double targetDist) {

        if (labelMask == null) {
            labelMask = Nd4j.ones(label.size(0), label.size(2));
        }
        validateData(label, labelMask);

        INDArray bernoullis = Nd4j.zeros(labelMask.shape());
        int currentTimeSliceEnd = label.size(2);
        //iterate over each tbptt window
        while (currentTimeSliceEnd > 0) {

            int currentTimeSliceStart = Math.max(currentTimeSliceEnd - tbpttWindowSize, 0);

            //get views for current time slice
            INDArray currentWindowBernoulli = bernoullis.get(NDArrayIndex.all(),
                            NDArrayIndex.interval(currentTimeSliceStart, currentTimeSliceEnd));
            INDArray currentMask = labelMask.get(NDArrayIndex.all(),
                            NDArrayIndex.interval(currentTimeSliceStart, currentTimeSliceEnd));
            INDArray currentLabel;
            if (label.size(1) == 2) {
                //if one hot grab the right index
                currentLabel = label.get(NDArrayIndex.all(), NDArrayIndex.point(minorityLabel),
                                NDArrayIndex.interval(currentTimeSliceStart, currentTimeSliceEnd));
            } else {
                currentLabel = label.get(NDArrayIndex.all(), NDArrayIndex.point(0),
                                NDArrayIndex.interval(currentTimeSliceStart, currentTimeSliceEnd));
                if (minorityLabel == 0) {
                    currentLabel = Transforms.not(currentLabel);
                }
            }

            //calculate required probabilities and write into the view
            currentWindowBernoulli.assign(calculateBernoulli(currentLabel, currentMask, targetDist));

            currentTimeSliceEnd = currentTimeSliceStart;
        }

        return Nd4j.getExecutioner().exec(
                        new BernoulliDistribution(Nd4j.createUninitialized(bernoullis.shape()), bernoullis),
                        Nd4j.getRandom());
    }
 
开发者ID:deeplearning4j,项目名称:nd4j,代码行数:43,代码来源:BaseUnderSamplingPreProcessor.java

示例3: computeScoreNumDenom

import org.nd4j.linalg.ops.transforms.Transforms; //导入方法依赖的package包/类
private double[] computeScoreNumDenom(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask,
                boolean average) {
    INDArray output = activationFn.getActivation(preOutput.dup(), true);

    int n = labels.size(1);
    if (n != 1 && n != 2) {
        throw new UnsupportedOperationException(
                        "For binary classification: expect output size of 1 or 2. Got: " + n);
    }

    //First: determine positives and negatives
    INDArray isPositiveLabel;
    INDArray isNegativeLabel;
    INDArray pClass0;
    INDArray pClass1;
    if (n == 1) {
        isPositiveLabel = labels;
        isNegativeLabel = Transforms.not(isPositiveLabel);
        pClass0 = output.rsub(1.0);
        pClass1 = output;
    } else {
        isPositiveLabel = labels.getColumn(1);
        isNegativeLabel = labels.getColumn(0);
        pClass0 = output.getColumn(0);
        pClass1 = output.getColumn(1);
    }

    if (mask != null) {
        isPositiveLabel = isPositiveLabel.mulColumnVector(mask);
        isNegativeLabel = isNegativeLabel.mulColumnVector(mask);
    }

    double tp = isPositiveLabel.mul(pClass1).sumNumber().doubleValue();
    double fp = isNegativeLabel.mul(pClass1).sumNumber().doubleValue();
    double fn = isPositiveLabel.mul(pClass0).sumNumber().doubleValue();

    double numerator = (1.0 + beta * beta) * tp;
    double denominator = (1.0 + beta * beta) * tp + beta * beta * fn + fp;

    return new double[] {numerator, denominator};
}
 
开发者ID:deeplearning4j,项目名称:nd4j,代码行数:42,代码来源:LossFMeasure.java


注:本文中的org.nd4j.linalg.ops.transforms.Transforms.not方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。