当前位置: 首页>>代码示例>>Java>>正文


Java INDArrayIndex类代码示例

本文整理汇总了Java中org.nd4j.linalg.indexing.INDArrayIndex的典型用法代码示例。如果您正苦于以下问题：Java INDArrayIndex类的具体用法？Java INDArrayIndex怎么用？Java INDArrayIndex使用的例子？那么恭喜您，这里精选的类代码示例或许可以为您提供帮助。


INDArrayIndex类属于org.nd4j.linalg.indexing包，在下文中一共展示了INDArrayIndex类的15个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Java代码示例。

示例1: loadFeaturesFromString

import org.nd4j.linalg.indexing.INDArrayIndex; //导入依赖的package包/类
/**
 * Used post training to convert a String to a features INDArray that can be passed to the network output method.
 *
 * <p>Only tokens for which {@code wordVectors.hasWord(..)} is true are kept. The time dimension of
 * the result is the number of such tokens, capped at {@code maxLength}.
 *
 * @param reviewContents Contents of the review to vectorize
 * @param maxLength Maximum length (if review is longer than this: truncate to maxLength). Use Integer.MAX_VALUE to not truncate
 * @return Features array of shape [1, vectorSize, min(maxLength, number of known tokens)]
 */
public INDArray loadFeaturesFromString(String reviewContents, int maxLength){
    List<String> tokens = tokenizerFactory.create(reviewContents).getTokens();
    List<String> tokensFiltered = new ArrayList<>();
    for (String t : tokens) {
        if (wordVectors.hasWord(t)) tokensFiltered.add(t);
    }
    // Truncate with min (not max): maxLength is an upper bound, and Integer.MAX_VALUE must mean
    // "no truncation" rather than "allocate Integer.MAX_VALUE time steps".
    // NOTE(review): if no token is in the vocabulary, outputLength is 0 -- confirm that callers
    // never pass such input or that Nd4j.create accepts a zero-length dimension here.
    int outputLength = Math.min(maxLength, tokensFiltered.size());

    INDArray features = Nd4j.create(1, vectorSize, outputLength);

    // Walk the filtered list so that every looked-up token has a vector (checked via hasWord above)
    // and so that time step j lines up with the outputLength allocated above.
    for (int j = 0; j < outputLength; j++) {
        INDArray vector = wordVectors.getWordVectorMatrix(tokensFiltered.get(j));
        features.put(new INDArrayIndex[]{NDArrayIndex.point(0), NDArrayIndex.all(), NDArrayIndex.point(j)}, vector);
    }

    return features;
}
 
开发者ID:IsaacChanghau,项目名称:NeuralNetworksLite,代码行数:26,代码来源:SentimentExampleIterator.java

示例2: testResolvePointVector

import org.nd4j.linalg.indexing.INDArrayIndex; //导入依赖的package包/类
@Test
public void testResolvePointVector() {
    INDArray vector = Nd4j.linspace(1, 4, 4);
    INDArrayIndex[] requested = {NDArrayIndex.point(1)};
    INDArrayIndex[] actual = NDArrayIndex.resolve(vector.shape(), requested);

    // Resolution may either return the single requested index unchanged, or expand it into
    // one point index per dimension: (0, 1).
    if (actual.length != requested.length) {
        assertEquals(2, actual.length);
        assertTrue(actual[0] instanceof PointIndex);
        assertEquals(0, actual[0].current());
        assertTrue(actual[1] instanceof PointIndex);
        assertEquals(1, actual[1].current());
        return;
    }
    assertArrayEquals(requested, actual);
}
 
开发者ID:deeplearning4j,项目名称:nd4j,代码行数:17,代码来源:NDArrayIndexResolveTests.java

示例3: testIndexPointInterval

import org.nd4j.linalg.indexing.INDArrayIndex; //导入依赖的package包/类
@Test
@Ignore
public void testIndexPointInterval() {
    INDArray zeros = Nd4j.zeros(3, 3, 3);
    // Slice 1, rows 1..2 (inclusive), column 1 receives a 1x2 block of ones.
    INDArrayIndex[] target = {NDArrayIndex.point(1), NDArrayIndex.interval(1, 2, true), NDArrayIndex.point(1)};
    zeros.put(target, Nd4j.ones(1, 2));

    // Two accepted renderings of the result: ',' as the decimal separator and '.' as the
    // decimal separator (presumably depending on the default locale -- confirm).
    String commaDecimals = "[[[0,00,0,00,0,00]\n"
                    + " [0,00,0,00,0,00]\n"
                    + " [0,00,0,00,0,00]]\n"
                    + "  [[0,00,0,00,0,00]\n"
                    + " [0,00,1,00,0,00]\n"
                    + " [0,00,1,00,0,00]]\n"
                    + "  [[0,00,0,00,0,00]\n"
                    + " [0,00,0,00,0,00]\n"
                    + " [0,00,0,00,0,00]]]";

    String dotDecimals = "[[[0.00,0.00,0.00]\n"
                    + " [0.00,0.00,0.00]\n"
                    + " [0.00,0.00,0.00]]\n"
                    + "  [[0.00,0.00,0.00]\n"
                    + " [0.00,1.00,0.00]\n"
                    + " [0.00,1.00,0.00]]\n"
                    + "  [[0.00,0.00,0.00]\n"
                    + " [0.00,0.00,0.00]\n"
                    + " [0.00,0.00,0.00]]]";

    String rendered = zeros.toString();
    boolean matchesEither = rendered.equals(dotDecimals) || rendered.equals(commaDecimals);
    if (!matchesEither) {
        assertEquals(dotDecimals, rendered);
    }
}
 
开发者ID:deeplearning4j,项目名称:nd4j,代码行数:23,代码来源:ShapeResolutionTestsC.java

示例4: testIndexPointAll

import org.nd4j.linalg.indexing.INDArrayIndex; //导入依赖的package包/类
@Test
@Ignore
public void testIndexPointAll() {
    INDArray zeros = Nd4j.zeros(3, 3, 3);
    // Slice 1, every row, column 1 receives a 1x3 block of ones.
    INDArrayIndex[] target = {NDArrayIndex.point(1), NDArrayIndex.all(), NDArrayIndex.point(1)};
    zeros.put(target, Nd4j.ones(1, 3));

    // Two accepted renderings of the result: ',' as the decimal separator and '.' as the
    // decimal separator (presumably depending on the default locale -- confirm).
    String commaDecimals = "[[[0,00,0,00,0,00]\n"
                    + " [0,00,0,00,0,00]\n"
                    + " [0,00,0,00,0,00]]\n"
                    + "  [[0,00,1,00,0,00]\n"
                    + " [0,00,1,00,0,00]\n"
                    + " [0,00,1,00,0,00]]\n"
                    + "  [[0,00,0,00,0,00]\n"
                    + " [0,00,0,00,0,00]\n"
                    + " [0,00,0,00,0,00]]]";

    String dotDecimals = "[[[0.00,0.00,0.00]\n"
                    + " [0.00,0.00,0.00]\n"
                    + " [0.00,0.00,0.00]]\n"
                    + "  [[0.00,1.00,0.00]\n"
                    + " [0.00,1.00,0.00]\n"
                    + " [0.00,1.00,0.00]]\n"
                    + "  [[0.00,0.00,0.00]\n"
                    + " [0.00,0.00,0.00]\n"
                    + " [0.00,0.00,0.00]]]";

    String rendered = zeros.toString();
    boolean matchesEither = rendered.equals(commaDecimals) || rendered.equals(dotDecimals);
    if (!matchesEither) {
        assertEquals(dotDecimals, rendered);
    }
}
 
开发者ID:deeplearning4j,项目名称:nd4j,代码行数:22,代码来源:ShapeResolutionTestsC.java

示例5: testIndexIntervalAll

import org.nd4j.linalg.indexing.INDArrayIndex; //导入依赖的package包/类
@Test
@Ignore
public void testIndexIntervalAll() {
    INDArray zeros = Nd4j.zeros(3, 3, 3);
    // Slices 0..1, every row, columns 1..2 (intervals inclusive) receive ones.
    INDArrayIndex[] target = {NDArrayIndex.interval(0, 1, true), NDArrayIndex.all(),
                    NDArrayIndex.interval(1, 2, true)};
    zeros.put(target, Nd4j.ones(2, 6));

    // Two accepted renderings of the result: ',' as the decimal separator and '.' as the
    // decimal separator (presumably depending on the default locale -- confirm).
    String commaDecimals = "[[[0,00,1,00,1,00]\n"
                    + " [0,00,1,00,1,00]\n"
                    + " [0,00,1,00,1,00]]\n"
                    + "  [[0,00,1,00,1,00]\n"
                    + " [0,00,1,00,1,00]\n"
                    + " [0,00,1,00,1,00]]\n"
                    + "  [[0,00,0,00,0,00]\n"
                    + " [0,00,0,00,0,00]\n"
                    + " [0,00,0,00,0,00]]]";

    String dotDecimals = "[[[0.00,1.00,1.00]\n"
                    + " [0.00,1.00,1.00]\n"
                    + " [0.00,1.00,1.00]]\n"
                    + "  [[0.00,1.00,1.00]\n"
                    + " [0.00,1.00,1.00]\n"
                    + " [0.00,1.00,1.00]]\n"
                    + "  [[0.00,0.00,0.00]\n"
                    + " [0.00,0.00,0.00]\n"
                    + " [0.00,0.00,0.00]]]";

    String rendered = zeros.toString();
    boolean matchesEither = rendered.equals(commaDecimals) || rendered.equals(dotDecimals);
    if (!matchesEither) {
        assertEquals(dotDecimals, rendered);
    }
}
 
开发者ID:deeplearning4j,项目名称:nd4j,代码行数:22,代码来源:ShapeResolutionTestsC.java

示例6: testIndexPointIntervalAll

import org.nd4j.linalg.indexing.INDArrayIndex; //导入依赖的package包/类
@Test
@Ignore
public void testIndexPointIntervalAll() {
    INDArray zeros = Nd4j.zeros(3, 3, 3);
    // Slice 1, every row, columns 1..2 (inclusive) receive a 3x2 block of ones.
    INDArrayIndex[] target = {NDArrayIndex.point(1), NDArrayIndex.all(), NDArrayIndex.interval(1, 2, true)};
    zeros.put(target, Nd4j.ones(3, 2));

    // Two accepted renderings of the result: ',' as the decimal separator and '.' as the
    // decimal separator (presumably depending on the default locale -- confirm).
    String commaDecimals = "[[[0,00,0,00,0,00]\n"
                    + " [0,00,0,00,0,00]\n"
                    + " [0,00,0,00,0,00]]\n"
                    + "  [[0,00,1,00,1,00]\n"
                    + " [0,00,1,00,1,00]\n"
                    + " [0,00,1,00,1,00]]\n"
                    + "  [[0,00,0,00,0,00]\n"
                    + " [0,00,0,00,0,00]\n"
                    + " [0,00,0,00,0,00]]]";

    String dotDecimals = "[[[0.00,0.00,0.00]\n"
                    + " [0.00,0.00,0.00]\n"
                    + " [0.00,0.00,0.00]]\n"
                    + "  [[0.00,1.00,1.00]\n"
                    + " [0.00,1.00,1.00]\n"
                    + " [0.00,1.00,1.00]]\n"
                    + "  [[0.00,0.00,0.00]\n"
                    + " [0.00,0.00,0.00]\n"
                    + " [0.00,0.00,0.00]]]";

    String rendered = zeros.toString();
    boolean matchesEither = rendered.equals(commaDecimals) || rendered.equals(dotDecimals);
    if (!matchesEither) {
        assertEquals(dotDecimals, rendered);
    }
}
 
开发者ID:deeplearning4j,项目名称:nd4j,代码行数:22,代码来源:ShapeResolutionTestsC.java

示例7: mergePerOutputMasks2d

import org.nd4j.linalg.indexing.INDArrayIndex; //导入依赖的package包/类
/**
 * Merges per-array 2d masks into one mask of shape {@code outShape}.
 *
 * <p>Masks are stacked along dimension 0 in the order of {@code arrays}. Entry i occupies
 * {@code arrays[i].size(0)} rows.
 *
 * @param outShape shape of the merged mask
 * @param arrays   the arrays being merged; only {@code size(0)} of each is read
 * @param masks    per-array masks; a {@code null} entry means "all present" for that array's rows
 * @return the merged mask; rows without a supplied mask are left as ones
 */
public static INDArray mergePerOutputMasks2d(int[] outShape, INDArray[] arrays, INDArray[] masks) {
    int[] numExamplesPerArr = new int[arrays.length];
    for (int i = 0; i < numExamplesPerArr.length; i++) {
        numExamplesPerArr[i] = arrays[i].size(0);
    }

    INDArray outMask = Nd4j.ones(outShape); //Initialize to 'all present' (1s)

    int rowsSoFar = 0;
    for (int i = 0; i < masks.length; i++) {
        int thisRows = numExamplesPerArr[i]; //Mask itself may be null -> all present, but may include multiple examples
        if (masks[i] != null) {
            outMask.put(new INDArrayIndex[] {NDArrayIndex.interval(rowsSoFar, rowsSoFar + thisRows),
                            NDArrayIndex.all()}, masks[i]);
        }
        // Always advance the row offset. A null mask still occupies thisRows rows of the output,
        // so skipping this increment would shift every later mask onto the wrong rows.
        rowsSoFar += thisRows;
    }
    return outMask;
}
 
开发者ID:deeplearning4j,项目名称:nd4j,代码行数:22,代码来源:DataSetUtil.java

示例8: toFlattened

import org.nd4j.linalg.indexing.INDArrayIndex; //导入依赖的package包/类
/**
 * Concatenates every array in {@code matrices}, in iteration order, into a single row vector.
 *
 * <p>The result has shape [1, L], where L is the sum of the arrays' lengths. Each array is
 * written into the next consecutive run of L positions.
 *
 * @param matrices the ndarrays to build a flattened representation of
 * @return the flattened ndarray
 */
@Override
public INDArray toFlattened(Collection<INDArray> matrices) {
    int totalLength = 0;
    for (INDArray matrix : matrices) {
        totalLength += matrix.length();
    }

    INDArray flat = Nd4j.create(1, totalLength);
    int offset = 0;
    for (INDArray matrix : matrices) {
        INDArrayIndex[] window = {NDArrayIndex.interval(offset, offset + matrix.length())};
        flat.put(window, matrix);
        offset += matrix.length();
    }
    return flat;
}
 
开发者ID:deeplearning4j,项目名称:nd4j,代码行数:23,代码来源:BaseNDArrayFactory.java

示例9: backpropGradient

import org.nd4j.linalg.indexing.INDArrayIndex; //导入依赖的package包/类
/**
 * Expands the last-time-step epsilon back to the underlying layer's full output shape.
 *
 * <p>It then delegates backprop to the underlying layer.
 */
@Override
public Pair<Gradient, INDArray> backpropGradient(INDArray epsilon) {
    INDArray fullEpsilon = Nd4j.create(origOutputShape, 'f');

    if (lastTimeStepIdxs != null) {
        // Masked case: each example's selected time step is different, so write row by row.
        //TODO probably possible to optimize this with reshape + scatter ops...
        for (int example = 0; example < lastTimeStepIdxs.length; example++) {
            INDArrayIndex[] target = new INDArrayIndex[]{point(example), all(), point(lastTimeStepIdxs[example])};
            fullEpsilon.put(target, epsilon.getRow(example));
        }
    } else {
        // No mask: every example used the final time step.
        fullEpsilon.put(new INDArrayIndex[]{all(), all(), point(origOutputShape[2] - 1)}, epsilon);
    }

    return underlying.backpropGradient(fullEpsilon);
}
 
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:18,代码来源:LastTimeStepLayer.java

示例10: preOutput

import org.nd4j.linalg.indexing.INDArrayIndex; //导入依赖的package包/类
/**
 * Computes the pre-activation output {@code input (*) W + b}.
 *
 * <p>{@code (*)} is element-wise multiplication, not a matrix product. {@code W} and {@code b}
 * are broadcast across every row of the input.
 *
 * @param training forwarded to {@code applyDropOutIfNecessary}
 * @return array of shape [input.rows(), input.columns()] holding the pre-activations, with the
 *         mask applied when {@code maskArray} is set
 * @throws DL4JInvalidInputException if the input column count differs from {@code W.columns()}
 */
public INDArray preOutput(boolean training) {
    INDArray b = getParam(DefaultParamInitializer.BIAS_KEY);
    INDArray W = getParam(DefaultParamInitializer.WEIGHT_KEY);

    if (input.columns() != W.columns()) {
        throw new DL4JInvalidInputException(
                "Input size (" + input.columns() + " columns; shape = " + Arrays.toString(input.shape())
                        + ") is invalid: does not match layer input size (layer # inputs = "
                        + W.shapeInfoToString() + ") " + layerId());
    }

    applyDropOutIfNecessary(training);

    INDArray ret = Nd4j.zeros(input.rows(), input.columns());

    // One broadcast op over the whole batch instead of a Java loop issuing a slice, multiply,
    // add and put per example. Like the per-row mul/addRowVector it replaces, this relies on
    // W and b being row vectors of length input.columns(). Assigning into ret keeps the
    // result array created (and ordered) exactly as before.
    ret.assign(input.mulRowVector(W).addiRowVector(b));

    if (maskArray != null) {
        applyMask(ret);
    }

    return ret;
}
 
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:26,代码来源:ElementWiseMultiplicationLayer.java

示例11: doBackward

import org.nd4j.linalg.indexing.INDArrayIndex; //导入依赖的package包/类
/**
 * Scatters the epsilon for the selected time step(s) back into an array of the forward-pass
 * input shape.
 *
 * <p>Entries for other time steps are left as created by {@code Nd4j.create}.
 */
@Override
public Pair<Gradient, INDArray[]> doBackward(boolean tbptt) {
    INDArray fullEpsilon = Nd4j.create(fwdPassShape);

    if (fwdPassTimeSteps != null) {
        //Different time steps were extracted for each example
        for (int example = 0; example < fwdPassTimeSteps.length; example++) {
            INDArrayIndex[] target = {NDArrayIndex.point(example), NDArrayIndex.all(),
                            NDArrayIndex.point(fwdPassTimeSteps[example])};
            fullEpsilon.put(target, epsilon.getRow(example));
        }
    } else {
        //Last time step for all examples
        INDArrayIndex[] lastStep = {NDArrayIndex.all(), NDArrayIndex.all(),
                        NDArrayIndex.point(fwdPassShape[2] - 1)};
        fullEpsilon.put(lastStep, epsilon);
    }

    return new Pair<>(null, new INDArray[] {fullEpsilon});
}
 
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:20,代码来源:LastTimeStepVertex.java

示例12: doBackward

import org.nd4j.linalg.indexing.INDArrayIndex; //导入依赖的package包/类
/**
 * Routes this vertex's epsilon back into the slice of the stacked input it was taken from.
 *
 * <p>Only ranks 2 to 4 are supported.
 */
@Override
public Pair<Gradient, INDArray[]> doBackward(boolean tbptt) {
    if (!canDoBackward()) {
        throw new IllegalStateException("Cannot do backward pass: error not set");
    }

    INDArray out = Nd4j.zeros(forwardShape);
    int start = from * step;
    int end = (from + 1) * step;

    int rank = forwardShape.length;
    if (rank < 2 || rank > 4) {
        throw new RuntimeException("Invalid activation rank"); //Should never happen
    }

    // Epsilon goes into positions start..end of dimension 0. Every other dimension is taken
    // whole, and everything outside that slice stays zero.
    INDArrayIndex[] slice = new INDArrayIndex[rank];
    slice[0] = NDArrayIndex.interval(start, end);
    for (int dim = 1; dim < rank; dim++) {
        slice[dim] = NDArrayIndex.all();
    }
    out.put(slice, epsilon);

    return new Pair<>(null, new INDArray[] {out});
}
 
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:27,代码来源:UnstackVertex.java

示例13: doBackward

import org.nd4j.linalg.indexing.INDArrayIndex; //导入依赖的package包/类
/**
 * Routes this vertex's epsilon back into the inclusive range [from, to] of dimension 1 of an
 * otherwise-zero array with the forward-pass shape.
 *
 * <p>Only ranks 2 to 4 are supported.
 */
@Override
public Pair<Gradient, INDArray[]> doBackward(boolean tbptt) {
    if (!canDoBackward()) {
        throw new IllegalStateException("Cannot do backward pass: error not set");
    }

    INDArray out = Nd4j.zeros(forwardShape);

    int rank = forwardShape.length;
    if (rank < 2 || rank > 4) {
        throw new RuntimeException("Invalid activation rank"); //Should never happen
    }

    INDArrayIndex[] subset = new INDArrayIndex[rank];
    subset[0] = NDArrayIndex.all();
    subset[1] = NDArrayIndex.interval(from, to, true);
    for (int dim = 2; dim < rank; dim++) {
        subset[dim] = NDArrayIndex.all();
    }
    out.put(subset, epsilon);

    return new Pair<>(null, new INDArray[] {out});
}
 
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:24,代码来源:SubsetVertex.java

示例14: putExample

import org.nd4j.linalg.indexing.INDArrayIndex; //导入依赖的package包/类
/**
 * Writes {@code singleExample} into position {@code exampleIdx} along dimension 0 of {@code arr}.
 *
 * <p>All remaining dimensions are taken whole. Only arrays of rank 2 to 4 are supported.
 */
private void putExample(INDArray arr, INDArray singleExample, int exampleIdx) {
    int rank = arr.rank();
    if (rank < 2 || rank > 4) {
        throw new RuntimeException("Unexpected rank: " + rank);
    }

    INDArrayIndex[] slot = new INDArrayIndex[rank];
    slot[0] = NDArrayIndex.point(exampleIdx);
    for (int dim = 1; dim < rank; dim++) {
        slot[dim] = NDArrayIndex.all();
    }
    arr.put(slot, singleExample);
}
 
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:18,代码来源:RecordReaderMultiDataSetIterator.java

示例15: loadSingleSentence

import org.nd4j.linalg.indexing.INDArrayIndex; //导入依赖的package包/类
/**
 * Generally used post training time to load a single sentence for predictions.
 *
 * @param sentence the raw sentence to vectorize
 * @return features of shape [1, 1, n, wordVectorSize] when {@code sentencesAlongHeight} is set,
 *         otherwise [1, 1, wordVectorSize, n], where n = min(maxSentenceLength, number of tokens)
 * @throws IllegalStateException if tokenizing the sentence yields no tokens
 */
public INDArray loadSingleSentence(String sentence) {
  List<String> tokens = tokenizeSentence(sentence);
  if (tokens.isEmpty()) {
    // Without tokens, the sentence dimension below would be 0, producing an unusable
    // features array. Fail fast with a message that shows the offending input instead.
    throw new IllegalStateException(
        "No tokens available for input sentence - empty string or no words in vocabulary? Sentence = \""
            + sentence + "\"");
  }

  // Number of time steps actually encoded (sentence is truncated to maxSentenceLength).
  int length = Math.min(maxSentenceLength, tokens.size());

  int[] featuresShape = new int[] {1, 1, 0, 0};
  if (sentencesAlongHeight) {
    featuresShape[2] = length;
    featuresShape[3] = wordVectorSize;
  } else {
    featuresShape[2] = wordVectorSize;
    featuresShape[3] = length;
  }

  INDArray features = Nd4j.create(featuresShape);
  for (int i = 0; i < length; i++) {
    INDArray vector = getVector(tokens.get(i));

    // Word i goes into row i (height layout) or column i (width layout) of the single channel.
    INDArrayIndex[] indices = new INDArrayIndex[4];
    indices[0] = NDArrayIndex.point(0);
    indices[1] = NDArrayIndex.point(0);
    if (sentencesAlongHeight) {
      indices[2] = NDArrayIndex.point(i);
      indices[3] = NDArrayIndex.all();
    } else {
      indices[2] = NDArrayIndex.all();
      indices[3] = NDArrayIndex.point(i);
    }

    features.put(indices, vector);
  }

  return features;
}
 
开发者ID:Waikato,项目名称:wekaDeeplearning4j,代码行数:35,代码来源:CnnSentenceDataSetIterator.java


注:本文中的org.nd4j.linalg.indexing.INDArrayIndex类示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。