本文整理汇总了Java中org.nd4j.linalg.ops.transforms.Transforms.max方法的典型用法代码示例。如果您正苦于以下问题：Java Transforms.max方法的具体用法是什么？Java Transforms.max怎么用？有没有Java Transforms.max的使用示例？那么，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在类org.nd4j.linalg.ops.transforms.Transforms的用法示例。
在下文中一共展示了Transforms.max方法的9个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Java代码示例。
示例1: testJaccardDistance
import org.nd4j.linalg.ops.transforms.Transforms; //导入方法依赖的package包/类
/**
 * Compares SameDiff's jaccardDistance op with a reference value computed directly from
 * element-wise extremes: 1 - sum(min(a, b)) / sum(max(a, b)).
 */
@Test
public void testJaccardDistance() {
    Nd4j.getRandom().setSeed(12345);
    INDArray first = Nd4j.rand(new int[] {3, 4}).addi(0.1);
    INDArray second = Nd4j.rand(new int[] {3, 4}).addi(0.1);

    SameDiff graph = SameDiff.create();
    SDVariable left = graph.var("in1", first);
    SDVariable right = graph.var("in2", second);
    graph.jaccardDistance("out", left, right);

    // Reference value built from the element-wise minimum and maximum of the two inputs
    double intersection = Transforms.min(first, second).sumNumber().doubleValue();
    double union = Transforms.max(first, second).sumNumber().doubleValue();
    double expected = 1.0 - intersection / union;

    INDArray result = graph.execAndEndResult();
    assertEquals(1, result.length());
    assertEquals(expected, result.getDouble(0), 1e-6);
}
示例2: testScalarMinMax1
import org.nd4j.linalg.ops.transforms.Transforms; //导入方法依赖的package包/类
/**
 * Verifies scalar Transforms.max / Transforms.min in copy mode (dup = true: a new array is
 * returned and the input stays unchanged) and in in-place mode (dup = false: the input is
 * overwritten).
 */
@Test
public void testScalarMinMax1() {
    INDArray x = Nd4j.create(new double[] {1, 3, 5, 7});
    INDArray xCopy = x.dup();
    INDArray exp1 = Nd4j.create(new double[] {1, 3, 5, 7});
    // NOTE(review): 1e-5 is assumed to equal Nd4j.EPS_THRESHOLD; update if that constant changes
    INDArray exp2 = Nd4j.create(new double[] {1e-5, 1e-5, 1e-5, 1e-5});
    INDArray z1 = Transforms.max(x, Nd4j.EPS_THRESHOLD, true);
    INDArray z2 = Transforms.min(x, Nd4j.EPS_THRESHOLD, true);
    assertEquals(exp1, z1);
    assertEquals(exp2, z2);
    // Assert that x was not modified (expected value first, so failure messages read correctly)
    assertEquals(xCopy, x);
    INDArray exp3 = Nd4j.create(new double[] {10, 10, 10, 10});
    // In place: every element of x is raised to at least 10
    Transforms.max(x, 10, false);
    assertEquals(exp3, x);
    // In place: every element of x (now all 10) is lowered to EPS_THRESHOLD
    Transforms.min(x, Nd4j.EPS_THRESHOLD, false);
    assertEquals(exp2, x);
}
示例3: testPinnedScalarMax
import org.nd4j.linalg.ops.transforms.Transforms; //导入方法依赖的package包/类
/**
 * Checks scalar Transforms.max with dup = true on the CUDA backend. All elements of the input
 * are already above 0.5, so the returned copy must equal the input, and the input itself must
 * be left unmodified.
 */
@Test
public void testPinnedScalarMax() throws Exception {
    // simple way to stop test if we're not on CUDA backend here
    // (note: this fails the test rather than skipping it on other backends)
    assertEquals("JcublasLevel1", Nd4j.getBlasWrapper().level1().getClass().getSimpleName());
    // NOTE(review): array1 is never read; presumably kept to exercise an extra device allocation
    // before the op under test — confirm, or remove it
    INDArray array1 = Nd4j.create(new float[]{1.01f, 1.01f, 1.01f, 1.01f, 1.01f, 1.01f, 1.01f, 1.01f, 1.01f, 1.01f, 1.01f, 1.01f, 1.01f, 1.01f, 1.01f});
    INDArray array2 = Nd4j.create(new float[]{2.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f});
    INDArray max = Transforms.max(array2, 0.5f, true);
    // Every element of array2 is >= 1.0 > 0.5, so the clamped copy must equal the input.
    // (Previously the result was only printed and never checked.)
    assertEquals(array2, max);
    // The input must not have been modified, since dup = true
    assertEquals(2.0f, array2.getFloat(0), 0.01f);
    assertEquals(1.0f, array2.getFloat(1), 0.01f);
}
示例4: testArrayMinMax
import org.nd4j.linalg.ops.transforms.Transforms; //导入方法依赖的package包/类
/**
 * Exercises element-wise (array vs. array) Transforms.max / Transforms.min in copy mode
 * (dup = true) and in-place mode (dup = false). Only the first argument may ever be
 * overwritten; the second must always stay intact.
 */
@Test
public void testArrayMinMax() {
    INDArray lhs = Nd4j.create(new double[] {1, 3, 5, 7});
    INDArray rhs = Nd4j.create(new double[] {2, 2, 6, 6});
    INDArray lhsOriginal = lhs.dup();
    INDArray rhsOriginal = rhs.dup();
    INDArray expectedMax = Nd4j.create(new double[] {2, 3, 6, 7});
    INDArray expectedMin = Nd4j.create(new double[] {1, 2, 5, 6});

    // Copy mode: fresh result arrays, lhs untouched
    INDArray maxResult = Transforms.max(lhs, rhs, true);
    INDArray minResult = Transforms.min(lhs, rhs, true);
    assertEquals(expectedMax, maxResult);
    assertEquals(expectedMin, minResult);
    assertEquals(lhsOriginal, lhs);

    // In-place max: lhs receives the result, rhs is unchanged
    Transforms.max(lhs, rhs, false);
    assertEquals(expectedMax, lhs);
    assertEquals(rhsOriginal, rhs);

    // Restart from a pristine lhs, then in-place min: lhs receives the result, rhs is unchanged
    lhs = lhsOriginal.dup();
    Transforms.min(lhs, rhs, false);
    assertEquals(expectedMin, lhs);
    assertEquals(rhsOriginal, rhs);
}
示例5: DistributionStats
import org.nd4j.linalg.ops.transforms.Transforms; //导入方法依赖的package包/类
/**
 * Creates distribution statistics from precomputed means and standard deviations.
 * <p>
 * Note: {@code std} is modified in place. Every entry is raised to at least
 * {@link Nd4j#EPS_THRESHOLD} so that later normalization does not produce NaNs.
 *
 * @param mean row vector of means
 * @param std  row vector of standard deviations (clamped in place, see above)
 */
public DistributionStats(@NonNull INDArray mean, @NonNull INDArray std) {
    // Compare the numeric minimum. Comparing two INDArray objects with == only tests reference
    // identity, which is always false here. Check before clamping, so that any entry that is
    // about to be rounded up is reported.
    if (std.minNumber().doubleValue() < Nd4j.EPS_THRESHOLD) {
        logger.info("API_INFO: Std deviation found to be zero. Transform will round up to epsilon to avoid nans.");
    }
    Transforms.max(std, Nd4j.EPS_THRESHOLD, false);
    this.mean = mean;
    this.std = std;
}
示例6: add
import org.nd4j.linalg.ops.transforms.Transforms; //导入方法依赖的package包/类
/**
 * Add rows of data to the statistics, widening the running lower/upper bounds with this
 * batch's minima and maxima (reduced over dimension 0).
 *
 * @param data the matrix containing multiple rows of data to include
 * @param mask (optionally) the mask of the data, useful for e.g. time series
 * @return this builder, for chaining
 * @throws IllegalStateException if the batch minimum and maximum arrays differ in shape
 */
public MinMaxStats.Builder add(@NonNull INDArray data, INDArray mask) {
    data = DataSetUtil.tailor2d(data, mask);
    if (data == null) {
        // Nothing to add. Either data is empty or completely masked. Just skip it, otherwise we will get
        // null pointer exceptions.
        return this;
    }
    INDArray batchMin = data.min(0);
    INDArray batchMax = data.max(0);
    if (!Arrays.equals(batchMin.shape(), batchMax.shape()))
        throw new IllegalStateException(
                        "Data min and max must be same shape. Likely a bug in the operation changing the input?");
    if (runningLower == null) {
        // First batch
        // Create copies because min and max are views to the same data set, which will cause problems with the
        // side effects of Transforms.min and Transforms.max
        runningLower = batchMin.dup();
        runningUpper = batchMax.dup();
    } else {
        // Update running bounds in place (dup = false overwrites runningLower / runningUpper)
        Transforms.min(runningLower, batchMin, false);
        Transforms.max(runningUpper, batchMax, false);
    }
    return this;
}
示例7: scoreArray
import org.nd4j.linalg.ops.transforms.Transforms; //导入方法依赖的package包/类
/**
 * Computes the element-wise cosine-proximity score array.
 * <p>
 * Each entry is {@code -(yhat_ij * y_ij) / (||yhat_i|| * ||y_i||)}, where {@code yhat} is the
 * activated output and the norms are L2 norms taken along dimension 1. Both norms are clamped
 * to at least {@link Nd4j#EPS_THRESHOLD} to avoid division by zero. Summing one row of the
 * result therefore gives the negated cosine similarity for that example.
 *
 * @param labels       labels array; must have the same size(1) as {@code preOutput}
 * @param preOutput    pre-activation output; not modified (the activation is applied to a copy)
 * @param activationFn activation function applied to the copy of {@code preOutput}
 * @param mask         optional column-vector mask (one value per example), or {@code null}
 * @return the negated, normalized element-wise product of output and labels
 * @throws IllegalArgumentException      if labels and output have a different size(1)
 * @throws UnsupportedOperationException if {@code mask} is not a column vector
 */
public INDArray scoreArray(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask) {
    if (labels.size(1) != preOutput.size(1)) {
        throw new IllegalArgumentException(
                        "Labels array numColumns (size(1) = " + labels.size(1) + ") does not match output layer"
                                        + " number of outputs (nOut = " + preOutput.size(1) + ") ");
    }
    /*
     mean of -(y.dot(yhat)/||y||*||yhat||)
     */
    INDArray postOutput = activationFn.getActivation(preOutput.dup(), true);
    INDArray yhatmag = postOutput.norm2(1);
    INDArray ymag = labels.norm2(1);
    // Clamp magnitudes (in place) so the divisions below cannot divide by zero
    yhatmag = Transforms.max(yhatmag, Nd4j.EPS_THRESHOLD, false);
    ymag = Transforms.max(ymag, Nd4j.EPS_THRESHOLD, false);
    INDArray scoreArr = postOutput.mul(labels);
    scoreArr.diviColumnVector(yhatmag);
    scoreArr.diviColumnVector(ymag);
    if (mask != null) {
        if (!mask.isColumnVector()) {
            //Per-output masking doesn't really make sense for cosine proximity
            throw new UnsupportedOperationException("Expected column vector mask array for LossCosineProximity."
                            + " Got mask array with shape " + Arrays.toString(mask.shape())
                            + "; per-output masking is not " + "supported for LossCosineProximity");
        }
        scoreArr.muliColumnVector(mask);
    }
    // Negate: higher similarity must give a lower loss
    return scoreArr.muli(-1);
}
示例8: computeGradient
import org.nd4j.linalg.ops.transforms.Transforms; //导入方法依赖的package包/类
/**
 * Computes dL/dz for the cosine-proximity loss {@code L = -(y . yhat) / (||y|| * ||yhat||)}.
 * <p>
 * Per example, {@code dL/dyhat = -(y * ||yhat||^2 - yhat * (y . yhat)) / (||y|| * ||yhat||^3)}.
 * This is then back-propagated through the activation function to give the gradient with
 * respect to the pre-activation output. Norms in the denominator are clamped to at least
 * {@link Nd4j#EPS_THRESHOLD} to avoid NaNs.
 *
 * @param labels       labels array; must have the same size(1) as {@code preOutput}
 * @param preOutput    pre-activation output
 * @param activationFn activation function used by the output layer
 * @param mask         optional column-vector mask (one value per example), or {@code null}
 * @return gradient with respect to {@code preOutput}, with the mask applied if present
 * @throws IllegalArgumentException      if labels and output have a different size(1)
 * @throws UnsupportedOperationException if {@code mask} is not a column vector (same contract
 *                                       as scoreArray)
 */
@Override
public INDArray computeGradient(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask) {
    if (labels.size(1) != preOutput.size(1)) {
        throw new IllegalArgumentException(
                        "Labels array numColumns (size(1) = " + labels.size(1) + ") does not match output layer"
                                        + " number of outputs (nOut = " + preOutput.size(1) + ") ");
    }
    // Validate the mask up front, consistently with scoreArray, instead of failing late inside
    // muliColumnVector with an unrelated error.
    if (mask != null && !mask.isColumnVector()) {
        //Per-output masking doesn't really make sense for cosine proximity
        throw new UnsupportedOperationException("Expected column vector mask array for LossCosineProximity."
                        + " Got mask array with shape " + java.util.Arrays.toString(mask.shape())
                        + "; per-output masking is not " + "supported for LossCosineProximity");
    }
    INDArray yhat = activationFn.getActivation(preOutput.dup(), true);
    INDArray yL2norm = labels.norm2(1);
    INDArray yhatL2norm = yhat.norm2(1);
    INDArray yhatL2normSq = yhatL2norm.mul(yhatL2norm);
    // Dot product y . yhat along dimension 1.
    // Note: this is not really the L1 norm, since no abs values are taken.
    INDArray yhatDotyL1norm = labels.mul(yhat).sum(1);
    // Numerator (before the sign flip): y * ||yhat||^2 - yhat * (y . yhat), using unclamped norms
    INDArray dLda = labels.mulColumnVector(yhatL2normSq);
    dLda.subi(yhat.mulColumnVector(yhatDotyL1norm));
    // transform vals to avoid nans before div (clamped in place)
    yL2norm = Transforms.max(yL2norm, Nd4j.EPS_THRESHOLD, false);
    yhatL2norm = Transforms.max(yhatL2norm, Nd4j.EPS_THRESHOLD, false);
    yhatL2normSq = Transforms.max(yhatL2normSq, Nd4j.EPS_THRESHOLD, false);
    // Denominator: ||y|| * ||yhat|| * ||yhat||^2
    dLda.diviColumnVector(yL2norm);
    dLda.diviColumnVector(yhatL2norm.mul(yhatL2normSq));
    dLda.muli(-1);
    //dL/dz
    INDArray gradients = activationFn.backprop(preOutput, dLda).getFirst(); //TODO loss functions with params
    if (mask != null) {
        gradients.muliColumnVector(mask);
    }
    return gradients;
}
示例9: map
import org.nd4j.linalg.ops.transforms.Transforms; //导入方法依赖的package包/类
/**
 * Applies the configured scalar operation ({@code mathOp} with operand {@code scalar}) to a
 * copy of the array held by the given writable.
 *
 * @param w the input writable; must be an {@link NDArrayWritable}
 * @return a new NDArrayWritable wrapping the transformed copy; the input's array is not modified
 * @throws IllegalArgumentException      if {@code w} is not an NDArrayWritable
 * @throws UnsupportedOperationException for Modulus, or for any op without an array implementation
 */
@Override
public NDArrayWritable map(Writable w) {
    if (w instanceof NDArrayWritable) {
        // Work on a copy: the caller may still use the original INDArray afterwards
        INDArray result = ((NDArrayWritable) w).get().dup();
        switch (mathOp) {
            case Add:
                result.addi(scalar);
                break;
            case Subtract:
                result.subi(scalar);
                break;
            case ReverseSubtract:
                result.rsubi(scalar);
                break;
            case Multiply:
                result.muli(scalar);
                break;
            case Divide:
                result.divi(scalar);
                break;
            case ReverseDivide:
                result.rdivi(scalar);
                break;
            case ScalarMin:
                // dup = false: clamp the copy in place
                Transforms.min(result, scalar, false);
                break;
            case ScalarMax:
                Transforms.max(result, scalar, false);
                break;
            case Modulus:
                throw new UnsupportedOperationException(mathOp + " is not supported for NDArrayWritable");
            default:
                throw new UnsupportedOperationException("Unknown or not supported op: " + mathOp);
        }
        // Flush queued ops before handing the array on, to avoid threading issues
        Nd4j.getExecutioner().commit();
        return new NDArrayWritable(result);
    }
    throw new IllegalArgumentException("Input writable is not an NDArrayWritable: is " + w.getClass());
}