本文整理汇总了Java中org.nd4j.linalg.indexing.INDArrayIndex类的典型用法代码示例。如果您正苦于以下问题：Java INDArrayIndex类的具体用法？Java INDArrayIndex怎么用？Java INDArrayIndex使用的例子？那么恭喜您，这里精选的类代码示例或许可以为您提供帮助。
INDArrayIndex类属于org.nd4j.linalg.indexing包，在下文中一共展示了INDArrayIndex类的15个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Java代码示例。
示例1: loadFeaturesFromString
import org.nd4j.linalg.indexing.INDArrayIndex; //导入依赖的package包/类
/**
 * Used post training to convert a String to a features INDArray that can be passed to the network output method.
 *
 * <p>The review is tokenized, tokens unknown to {@code wordVectors} are dropped, and the remaining
 * tokens are written one per column along the last dimension of the result.
 *
 * @param reviewContents Contents of the review to vectorize
 * @param maxLength Maximum length (if review is longer than this: truncate to maxLength). Use Integer.MAX_VALUE to not truncate
 * @return Features array for the given input String, created with shape
 *         [1, vectorSize, min(maxLength, number of known tokens)]
 */
public INDArray loadFeaturesFromString(String reviewContents, int maxLength){
    List<String> tokens = tokenizerFactory.create(reviewContents).getTokens();
    List<String> tokensFiltered = new ArrayList<>();
    for(String t : tokens ){
        if(wordVectors.hasWord(t)) tokensFiltered.add(t);
    }

    // Truncate to maxLength (min, not max): with max, passing Integer.MAX_VALUE to disable
    // truncation would request an array with Integer.MAX_VALUE time steps.
    // NOTE(review): outputLength is 0 when no token is in the vocabulary - confirm that callers
    // and Nd4j.create handle a zero-length last dimension.
    int outputLength = Math.min(maxLength, tokensFiltered.size());

    INDArray features = Nd4j.create(1, vectorSize, outputLength);

    // Iterate over the filtered tokens so that every looked-up word is known to wordVectors and
    // the column index j always stays below outputLength.
    for( int j=0; j<outputLength; j++ ){
        String token = tokensFiltered.get(j);
        INDArray vector = wordVectors.getWordVectorMatrix(token);
        features.put(new INDArrayIndex[]{NDArrayIndex.point(0), NDArrayIndex.all(), NDArrayIndex.point(j)}, vector);
    }

    return features;
}
示例2: testResolvePointVector
import org.nd4j.linalg.indexing.INDArrayIndex; //导入依赖的package包/类
/**
 * Resolves a single point index against the shape of a 4-element linspace array.
 *
 * <p>Two outcomes are accepted: the resolved indices have the same length as the request and are
 * equal to it, or the request is expanded to two point indices positioned at 0 and 1.
 */
@Test
public void testResolvePointVector() {
    final INDArray vector = Nd4j.linspace(1, 4, 4);
    final INDArrayIndex[] requested = new INDArrayIndex[] {NDArrayIndex.point(1)};
    final INDArrayIndex[] actual = NDArrayIndex.resolve(vector.shape(), requested);

    if (actual.length == requested.length) {
        // Nothing was added by resolve(): the indices must be passed through unchanged.
        assertArrayEquals(requested, actual);
        return;
    }

    // Otherwise resolve() must have expanded the request to exactly two point indices.
    assertEquals(2, actual.length);
    assertTrue(actual[0] instanceof PointIndex);
    assertEquals(0, actual[0].current());
    assertTrue(actual[1] instanceof PointIndex);
    assertEquals(1, actual[1].current());
}
示例3: testIndexPointInterval
import org.nd4j.linalg.indexing.INDArrayIndex; //导入依赖的package包/类
/**
 * Puts a 1x2 array of ones into a 3x3x3 zero array at (point 1, inclusive interval 1..2, point 1)
 * and compares the string form of the result with the expected rendering.
 *
 * <p>Two renderings are accepted, one with ',' and one with '.' as the decimal separator.
 */
@Test
@Ignore
public void testIndexPointInterval() {
    INDArray target = Nd4j.zeros(3, 3, 3);
    INDArrayIndex[] where = {
        NDArrayIndex.point(1),
        NDArrayIndex.interval(1, 2, true),
        NDArrayIndex.point(1)
    };
    target.put(where, Nd4j.ones(1, 2));

    String commaFormat = "[[[0,00,0,00,0,00]\n"
            + " [0,00,0,00,0,00]\n"
            + " [0,00,0,00,0,00]]\n"
            + " [[0,00,0,00,0,00]\n"
            + " [0,00,1,00,0,00]\n"
            + " [0,00,1,00,0,00]]\n"
            + " [[0,00,0,00,0,00]\n"
            + " [0,00,0,00,0,00]\n"
            + " [0,00,0,00,0,00]]]";
    String dotFormat = "[[[0.00,0.00,0.00]\n"
            + " [0.00,0.00,0.00]\n"
            + " [0.00,0.00,0.00]]\n"
            + " [[0.00,0.00,0.00]\n"
            + " [0.00,1.00,0.00]\n"
            + " [0.00,1.00,0.00]]\n"
            + " [[0.00,0.00,0.00]\n"
            + " [0.00,0.00,0.00]\n"
            + " [0.00,0.00,0.00]]]";

    String rendered = target.toString();
    boolean matchesKnownFormat = rendered.equals(dotFormat) || rendered.equals(commaFormat);
    if (!matchesKnownFormat) {
        // Report the mismatch against the '.'-separated form.
        assertEquals(dotFormat, rendered);
    }
}
示例4: testIndexPointAll
import org.nd4j.linalg.indexing.INDArrayIndex; //导入依赖的package包/类
/**
 * Puts a 1x3 array of ones into a 3x3x3 zero array at (point 1, all, point 1) and compares the
 * string form of the result with the expected rendering.
 *
 * <p>Two renderings are accepted, one with ',' and one with '.' as the decimal separator.
 */
@Test
@Ignore
public void testIndexPointAll() {
    INDArray target = Nd4j.zeros(3, 3, 3);
    INDArrayIndex[] where = {
        NDArrayIndex.point(1),
        NDArrayIndex.all(),
        NDArrayIndex.point(1)
    };
    target.put(where, Nd4j.ones(1, 3));

    String commaFormat = "[[[0,00,0,00,0,00]\n"
            + " [0,00,0,00,0,00]\n"
            + " [0,00,0,00,0,00]]\n"
            + " [[0,00,1,00,0,00]\n"
            + " [0,00,1,00,0,00]\n"
            + " [0,00,1,00,0,00]]\n"
            + " [[0,00,0,00,0,00]\n"
            + " [0,00,0,00,0,00]\n"
            + " [0,00,0,00,0,00]]]";
    String dotFormat = "[[[0.00,0.00,0.00]\n"
            + " [0.00,0.00,0.00]\n"
            + " [0.00,0.00,0.00]]\n"
            + " [[0.00,1.00,0.00]\n"
            + " [0.00,1.00,0.00]\n"
            + " [0.00,1.00,0.00]]\n"
            + " [[0.00,0.00,0.00]\n"
            + " [0.00,0.00,0.00]\n"
            + " [0.00,0.00,0.00]]]";

    String rendered = target.toString();
    boolean matchesKnownFormat = rendered.equals(commaFormat) || rendered.equals(dotFormat);
    if (!matchesKnownFormat) {
        // Report the mismatch against the '.'-separated form.
        assertEquals(dotFormat, rendered);
    }
}
示例5: testIndexIntervalAll
import org.nd4j.linalg.indexing.INDArrayIndex; //导入依赖的package包/类
/**
 * Puts a 2x6 array of ones into a 3x3x3 zero array at
 * (inclusive interval 0..1, all, inclusive interval 1..2) and compares the string form of the
 * result with the expected rendering.
 *
 * <p>Two renderings are accepted, one with ',' and one with '.' as the decimal separator.
 */
@Test
@Ignore
public void testIndexIntervalAll() {
    INDArray target = Nd4j.zeros(3, 3, 3);
    INDArrayIndex[] where = {
        NDArrayIndex.interval(0, 1, true),
        NDArrayIndex.all(),
        NDArrayIndex.interval(1, 2, true)
    };
    target.put(where, Nd4j.ones(2, 6));

    String commaFormat = "[[[0,00,1,00,1,00]\n"
            + " [0,00,1,00,1,00]\n"
            + " [0,00,1,00,1,00]]\n"
            + " [[0,00,1,00,1,00]\n"
            + " [0,00,1,00,1,00]\n"
            + " [0,00,1,00,1,00]]\n"
            + " [[0,00,0,00,0,00]\n"
            + " [0,00,0,00,0,00]\n"
            + " [0,00,0,00,0,00]]]";
    String dotFormat = "[[[0.00,1.00,1.00]\n"
            + " [0.00,1.00,1.00]\n"
            + " [0.00,1.00,1.00]]\n"
            + " [[0.00,1.00,1.00]\n"
            + " [0.00,1.00,1.00]\n"
            + " [0.00,1.00,1.00]]\n"
            + " [[0.00,0.00,0.00]\n"
            + " [0.00,0.00,0.00]\n"
            + " [0.00,0.00,0.00]]]";

    String rendered = target.toString();
    boolean matchesKnownFormat = rendered.equals(commaFormat) || rendered.equals(dotFormat);
    if (!matchesKnownFormat) {
        // Report the mismatch against the '.'-separated form.
        assertEquals(dotFormat, rendered);
    }
}
示例6: testIndexPointIntervalAll
import org.nd4j.linalg.indexing.INDArrayIndex; //导入依赖的package包/类
/**
 * Puts a 3x2 array of ones into a 3x3x3 zero array at (point 1, all, inclusive interval 1..2)
 * and compares the string form of the result with the expected rendering.
 *
 * <p>Two renderings are accepted, one with ',' and one with '.' as the decimal separator.
 */
@Test
@Ignore
public void testIndexPointIntervalAll() {
    INDArray target = Nd4j.zeros(3, 3, 3);
    INDArrayIndex[] where = {
        NDArrayIndex.point(1),
        NDArrayIndex.all(),
        NDArrayIndex.interval(1, 2, true)
    };
    target.put(where, Nd4j.ones(3, 2));

    String commaFormat = "[[[0,00,0,00,0,00]\n"
            + " [0,00,0,00,0,00]\n"
            + " [0,00,0,00,0,00]]\n"
            + " [[0,00,1,00,1,00]\n"
            + " [0,00,1,00,1,00]\n"
            + " [0,00,1,00,1,00]]\n"
            + " [[0,00,0,00,0,00]\n"
            + " [0,00,0,00,0,00]\n"
            + " [0,00,0,00,0,00]]]";
    String dotFormat = "[[[0.00,0.00,0.00]\n"
            + " [0.00,0.00,0.00]\n"
            + " [0.00,0.00,0.00]]\n"
            + " [[0.00,1.00,1.00]\n"
            + " [0.00,1.00,1.00]\n"
            + " [0.00,1.00,1.00]]\n"
            + " [[0.00,0.00,0.00]\n"
            + " [0.00,0.00,0.00]\n"
            + " [0.00,0.00,0.00]]]";

    String rendered = target.toString();
    boolean matchesKnownFormat = rendered.equals(commaFormat) || rendered.equals(dotFormat);
    if (!matchesKnownFormat) {
        // Report the mismatch against the '.'-separated form.
        assertEquals(dotFormat, rendered);
    }
}
示例7: mergePerOutputMasks2d
import org.nd4j.linalg.indexing.INDArrayIndex; //导入依赖的package包/类
/**
 * Merges per-array 2d masks into a single mask of shape {@code outShape} by stacking them along
 * dimension 0 in array order.
 *
 * <p>The output starts as all ones ("all present"). For each entry {@code i}, the block of
 * {@code arrays[i].size(0)} rows is overwritten with {@code masks[i]} when that mask is non-null;
 * a null mask leaves its block of rows at ones.
 *
 * @param outShape shape of the merged mask
 * @param arrays   arrays whose size along dimension 0 gives the number of rows covered by each mask
 * @param masks    masks to merge; individual entries may be null, meaning "all present"
 * @return the merged mask
 */
public static INDArray mergePerOutputMasks2d(int[] outShape, INDArray[] arrays, INDArray[] masks) {
    int[] numExamplesPerArr = new int[arrays.length];
    for (int i = 0; i < numExamplesPerArr.length; i++) {
        numExamplesPerArr[i] = arrays[i].size(0);
    }

    INDArray outMask = Nd4j.ones(outShape); //Initialize to 'all present' (1s)

    int rowsSoFar = 0;
    for (int i = 0; i < masks.length; i++) {
        int thisRows = numExamplesPerArr[i]; //Mask itself may be null -> all present, but may include multiple examples
        if (masks[i] != null) {
            outMask.put(new INDArrayIndex[] {NDArrayIndex.interval(rowsSoFar, rowsSoFar + thisRows),
                    NDArrayIndex.all()}, masks[i]);
        }
        // Always advance the row offset, also for a null mask: its rows keep their initial 1s
        // and the next mask must start after them, not on top of them.
        rowsSoFar += thisRows;
    }
    return outMask;
}
示例8: toFlattened
import org.nd4j.linalg.indexing.INDArrayIndex; //导入依赖的package包/类
/**
 * Copies the given ndarrays, in iteration order, into one array created as
 * {@code Nd4j.create(1, total)}, where {@code total} is the sum of the input lengths.
 *
 * @param matrices the ndarrays to flatten
 * @return the flattened ndarray
 */
@Override
public INDArray toFlattened(Collection<INDArray> matrices) {
    // First pass: total number of elements across all inputs.
    int totalLength = 0;
    for (INDArray matrix : matrices) {
        totalLength += matrix.length();
    }

    // Second pass: write each input into its own interval of the output.
    INDArray flattened = Nd4j.create(1, totalLength);
    int offset = 0;
    for (INDArray matrix : matrices) {
        INDArrayIndex[] destination = {NDArrayIndex.interval(offset, offset + matrix.length())};
        flattened.put(destination, matrix);
        offset += matrix.length();
    }
    return flattened;
}
示例9: backpropGradient
import org.nd4j.linalg.indexing.INDArrayIndex; //导入依赖的package包/类
/**
 * Scatters {@code epsilon} into a new 'f'-ordered array of shape {@code origOutputShape} and
 * passes that array to the underlying layer's backprop.
 *
 * <p>When {@code lastTimeStepIdxs} is null, {@code epsilon} is written at the last index of
 * dimension 2. Otherwise row {@code i} of {@code epsilon} is written at index
 * {@code lastTimeStepIdxs[i]} of dimension 2 for index {@code i} of dimension 0.
 *
 * @param epsilon gradient to scatter into the full-shape array
 * @return the result of {@code underlying.backpropGradient} on the scattered array
 */
@Override
public Pair<Gradient, INDArray> backpropGradient(INDArray epsilon) {
    INDArray expanded = Nd4j.create(origOutputShape, 'f');

    if (lastTimeStepIdxs != null) {
        //TODO probably possible to optimize this with reshape + scatter ops...
        for (int example = 0; example < lastTimeStepIdxs.length; example++) {
            INDArrayIndex[] slot = {point(example), all(), point(lastTimeStepIdxs[example])};
            expanded.put(slot, epsilon.getRow(example));
        }
    } else {
        //no mask case
        INDArrayIndex[] finalStep = {all(), all(), point(origOutputShape[2]-1)};
        expanded.put(finalStep, epsilon);
    }

    return underlying.backpropGradient(expanded);
}
示例10: preOutput
import org.nd4j.linalg.indexing.INDArrayIndex; //导入依赖的package包/类
/**
 * Computes the pre-activation output: for every row of the input,
 * {@code row.mul(W).addRowVector(b)}, written into the matching row of a new array.
 *
 * @param training passed to {@code applyDropOutIfNecessary}
 * @return array with as many rows and columns as the input; the mask is applied when
 *         {@code maskArray} is non-null
 * @throws DL4JInvalidInputException if the input and the weights differ in number of columns
 */
public INDArray preOutput(boolean training) {
    INDArray bias = getParam(DefaultParamInitializer.BIAS_KEY);
    INDArray weights = getParam(DefaultParamInitializer.WEIGHT_KEY);

    boolean sizeMismatch = input.columns() != weights.columns();
    if (sizeMismatch) {
        String message = "Input size (" + input.columns() + " columns; shape = " + Arrays.toString(input.shape())
                + ") is invalid: does not match layer input size (layer # inputs = "
                + weights.shapeInfoToString() + ") " + layerId();
        throw new DL4JInvalidInputException(message);
    }

    applyDropOutIfNecessary(training);

    INDArray result = Nd4j.zeros(input.rows(), input.columns());
    for (int r = 0; r < input.rows(); r++) {
        INDArrayIndex[] rowSelector = {NDArrayIndex.point(r), NDArrayIndex.all()};
        INDArray transformedRow = input.getRow(r).mul(weights).addRowVector(bias);
        result.put(rowSelector, transformedRow);
    }

    if (maskArray != null) {
        applyMask(result);
    }
    return result;
}
示例11: doBackward
import org.nd4j.linalg.indexing.INDArrayIndex; //导入依赖的package包/类
/**
 * Scatters {@code epsilon} into a new array of shape {@code fwdPassShape}.
 *
 * <p>When {@code fwdPassTimeSteps} is null, {@code epsilon} is written at the last index of
 * dimension 2 for all examples. Otherwise row {@code i} of {@code epsilon} is written at index
 * {@code fwdPassTimeSteps[i]} of dimension 2 for example {@code i}.
 *
 * @param tbptt not used by this method
 * @return a pair of a null gradient and a single-element array holding the scattered epsilons
 */
@Override
public Pair<Gradient, INDArray[]> doBackward(boolean tbptt) {
    //Allocate the appropriate sized array:
    INDArray scattered = Nd4j.create(fwdPassShape);

    if (fwdPassTimeSteps != null) {
        //Different time steps were extracted for each example
        for (int example = 0; example < fwdPassTimeSteps.length; example++) {
            INDArrayIndex[] slot = {
                NDArrayIndex.point(example),
                NDArrayIndex.all(),
                NDArrayIndex.point(fwdPassTimeSteps[example])
            };
            scattered.put(slot, epsilon.getRow(example));
        }
    } else {
        //Last time step for all examples
        INDArrayIndex[] finalStep = {
            NDArrayIndex.all(),
            NDArrayIndex.all(),
            NDArrayIndex.point(fwdPassShape[2] - 1)
        };
        scattered.put(finalStep, epsilon);
    }

    return new Pair<>(null, new INDArray[] {scattered});
}
示例12: doBackward
import org.nd4j.linalg.indexing.INDArrayIndex; //导入依赖的package包/类
/**
 * Writes {@code epsilon} into a zero array of shape {@code forwardShape}, at the interval
 * {@code [from * step, (from + 1) * step)} of dimension 0 and across all remaining dimensions.
 *
 * @param tbptt not used by this method
 * @return a pair of a null gradient and a single-element array holding the padded epsilon
 * @throws IllegalStateException if {@code canDoBackward()} is false
 * @throws RuntimeException if {@code forwardShape} does not have rank 2, 3 or 4
 */
@Override
public Pair<Gradient, INDArray[]> doBackward(boolean tbptt) {
    if (!canDoBackward()) {
        throw new IllegalStateException("Cannot do backward pass: error not set");
    }

    INDArray out = Nd4j.zeros(forwardShape);
    int start = from * step;
    int end = (from + 1) * step;

    int rank = forwardShape.length;
    if (rank < 2 || rank > 4) {
        throw new RuntimeException("Invalid activation rank"); //Should never happen
    }

    // Interval on dimension 0, everything on the remaining dimensions.
    INDArrayIndex[] destination = new INDArrayIndex[rank];
    destination[0] = NDArrayIndex.interval(start, end);
    for (int dim = 1; dim < rank; dim++) {
        destination[dim] = NDArrayIndex.all();
    }
    out.put(destination, epsilon);

    return new Pair<>(null, new INDArray[] {out});
}
示例13: doBackward
import org.nd4j.linalg.indexing.INDArrayIndex; //导入依赖的package包/类
/**
 * Writes {@code epsilon} into a zero array of shape {@code forwardShape}, at the inclusive
 * interval {@code from..to} of dimension 1 and across all other dimensions.
 *
 * @param tbptt not used by this method
 * @return a pair of a null gradient and a single-element array holding the padded epsilon
 * @throws IllegalStateException if {@code canDoBackward()} is false
 * @throws RuntimeException if {@code forwardShape} does not have rank 2, 3 or 4
 */
@Override
public Pair<Gradient, INDArray[]> doBackward(boolean tbptt) {
    if (!canDoBackward()) {
        throw new IllegalStateException("Cannot do backward pass: error not set");
    }

    INDArray out = Nd4j.zeros(forwardShape);

    // Select the destination region first, then perform a single put.
    INDArrayIndex[] destination;
    switch (forwardShape.length) {
        case 2:
            destination = new INDArrayIndex[] {
                NDArrayIndex.all(), NDArrayIndex.interval(from, to, true)};
            break;
        case 3:
            destination = new INDArrayIndex[] {
                NDArrayIndex.all(), NDArrayIndex.interval(from, to, true), NDArrayIndex.all()};
            break;
        case 4:
            destination = new INDArrayIndex[] {
                NDArrayIndex.all(), NDArrayIndex.interval(from, to, true), NDArrayIndex.all(),
                NDArrayIndex.all()};
            break;
        default:
            throw new RuntimeException("Invalid activation rank"); //Should never happen
    }
    out.put(destination, epsilon);

    return new Pair<>(null, new INDArray[] {out});
}
示例14: putExample
import org.nd4j.linalg.indexing.INDArrayIndex; //导入依赖的package包/类
/**
 * Writes {@code singleExample} into {@code arr} at index {@code exampleIdx} of dimension 0,
 * covering all remaining dimensions.
 *
 * @param arr           destination array; must have rank 2, 3 or 4
 * @param singleExample values to write
 * @param exampleIdx    index along dimension 0 to write to
 * @throws RuntimeException if {@code arr} has any other rank
 */
private void putExample(INDArray arr, INDArray singleExample, int exampleIdx) {
    int rank = arr.rank();
    if (rank < 2 || rank > 4) {
        throw new RuntimeException("Unexpected rank: " + arr.rank());
    }

    // Point on dimension 0, everything on the remaining dimensions.
    INDArrayIndex[] destination = new INDArrayIndex[rank];
    destination[0] = NDArrayIndex.point(exampleIdx);
    for (int dim = 1; dim < rank; dim++) {
        destination[dim] = NDArrayIndex.all();
    }
    arr.put(destination, singleExample);
}
示例15: loadSingleSentence
import org.nd4j.linalg.indexing.INDArrayIndex; //导入依赖的package包/类
/**
 * Generally used post training time to load a single sentence for predictions.
 *
 * <p>The sentence is tokenized and at most {@code maxSentenceLength} tokens are used. The result
 * has shape [1, 1, tokens, wordVectorSize] when {@code sentencesAlongHeight} is set and
 * [1, 1, wordVectorSize, tokens] otherwise; each token's vector is written at its position.
 *
 * @param sentence the sentence to convert
 * @return the features array for the sentence
 */
public INDArray loadSingleSentence(String sentence) {
    List<String> tokens = tokenizeSentence(sentence);
    int tokenCount = Math.min(maxSentenceLength, tokens.size());

    int[] featuresShape = sentencesAlongHeight
            ? new int[] {1, 1, tokenCount, wordVectorSize}
            : new int[] {1, 1, wordVectorSize, tokenCount};
    INDArray features = Nd4j.create(featuresShape);

    for (int pos = 0; pos < tokenCount; pos++) {
        INDArray vector = getVector(tokens.get(pos));

        // The token position goes on the sentence axis; the word vector fills the other axis.
        INDArrayIndex[] indices;
        if (sentencesAlongHeight) {
            indices = new INDArrayIndex[] {
                NDArrayIndex.point(0), NDArrayIndex.point(0), NDArrayIndex.point(pos), NDArrayIndex.all()};
        } else {
            indices = new INDArrayIndex[] {
                NDArrayIndex.point(0), NDArrayIndex.point(0), NDArrayIndex.all(), NDArrayIndex.point(pos)};
        }
        features.put(indices, vector);
    }
    return features;
}