当前位置: 首页>>代码示例>>Java>>正文


Java INDArray类代码示例

本文整理汇总了Java中org.nd4j.linalg.api.ndarray.INDArray类的典型用法代码示例。如果您正苦于以下问题：Java INDArray类的具体用法是什么？Java INDArray怎么用？有哪些Java INDArray的使用示例？那么，这里精选的类代码示例或许可以为您提供帮助。


INDArray类属于org.nd4j.linalg.api.ndarray包，下文一共展示了INDArray类的15个代码示例，这些示例默认按受欢迎程度排序。您可以为喜欢或者觉得有用的代码点赞，您的评价将有助于系统推荐出更好的Java代码示例。

示例1: getConceptVector

import org.nd4j.linalg.api.ndarray.INDArray; //导入依赖的package包/类
/**
 * Builds one embedding for a concept by averaging the word vectors of its name's tokens.
 *
 * <p>The name is lower-cased with {@link java.util.Locale#ROOT} (so the lookup key does not depend on
 * the JVM's default locale, e.g. Turkish dotted/dotless i), trimmed and tokenized with
 * {@code SimpleTokenizer}. Tokens unknown to {@code wordVectors} contribute {@code unkVector}
 * (presumably a shared placeholder vector — confirm where it is initialised).
 *
 * @param c the concept whose {@code name} is embedded
 * @return the mean over the stacked token vectors (dimension 0), or {@code null} when no token has a
 *         known vector, including when the name produces no tokens at all
 */
public INDArray getConceptVector(Concept c) {
	Tokenizer tok = SimpleTokenizer.INSTANCE;
	// Locale.ROOT: locale-independent lower-casing, so the vocabulary lookup is stable across JVMs.
	String normalizedName = c.name.toLowerCase(java.util.Locale.ROOT).trim();

	List<INDArray> vectors = new ArrayList<INDArray>();
	int countUnk = 0;
	for (String word : tok.tokenize(normalizedName)) {
		if (wordVectors.hasWord(word)) {
			vectors.add(wordVectors.getWordVectorMatrix(word));
		} else {
			vectors.add(unkVector);
			countUnk++;
		}
	}
	if (vectors.size() == countUnk) {
		return null; // all tokens unknown (or no tokens)
	}
	INDArray allVectors = Nd4j.vstack(vectors);

	// sum or mean is irrelevant for cosine similarity
	return allVectors.mean(0);
}
 
开发者ID:UKPLab,项目名称:ijcnlp2017-cmaps,代码行数:24,代码来源:WordEmbeddingDistance.java

示例2: loadFeaturesFromString

import org.nd4j.linalg.api.ndarray.INDArray; //导入依赖的package包/类
/**
 * Used post training to convert a String to a features INDArray that can be passed to the network output method.
 *
 * <p>Tokens without an entry in {@code wordVectors} are dropped first; at most {@code maxLength} of the
 * remaining tokens are then written, in order, one per position along the last dimension.
 * NOTE(review): if no token is known the result has a last dimension of 0 — confirm callers handle that.
 *
 * @param reviewContents Contents of the review to vectorize
 * @param maxLength Maximum length (if review is longer than this: truncate to maxLength). Use Integer.MAX_VALUE to not truncate
 * @return Features array of shape [1, vectorSize, min(maxLength, number of known tokens)]
 */
public INDArray loadFeaturesFromString(String reviewContents, int maxLength){
	List<String> tokens = tokenizerFactory.create(reviewContents).getTokens();
	List<String> tokensFiltered = new ArrayList<>();
	for(String t : tokens ){
		if(wordVectors.hasWord(t)) tokensFiltered.add(t);
	}
	// Truncate, don't pad: Math.max here would allocate maxLength positions (huge for
	// Integer.MAX_VALUE). Negative maxLength is clamped to 0.
	int outputLength = Math.min(Math.max(maxLength, 0), tokensFiltered.size());

	INDArray features = Nd4j.create(1, vectorSize, outputLength);

	// Walk the filtered tokens only: unknown words have no vector, and using the unfiltered list
	// would misplace every vector that follows an unknown word.
	for (int j = 0; j < outputLength; j++) {
		INDArray vector = wordVectors.getWordVectorMatrix(tokensFiltered.get(j));
		features.put(new INDArrayIndex[]{NDArrayIndex.point(0), NDArrayIndex.all(), NDArrayIndex.point(j)}, vector);
	}

	return features;
}
 
开发者ID:IsaacChanghau,项目名称:NeuralNetworksLite,代码行数:26,代码来源:SentimentExampleIterator.java

示例3: fetch

import org.nd4j.linalg.api.ndarray.INDArray; //导入依赖的package包/类
/**
 * Loads the next batch of up to {@code numExamples} examples starting at {@code cursor} into {@code curr}.
 *
 * <p>Each entry of {@code m_allFileNames} supplies an image file name (value) converted to a feature row
 * and a label key converted to a label row. Near the end of the list fewer than {@code numExamples}
 * rows may be available; the batch then contains only those rows, so the feature/label matrices are
 * never ragged. {@code cursor} advances by the number of rows actually read.
 *
 * @param numExamples maximum number of examples to load
 */
@Override
public void fetch(int numExamples) {
	// Size the batch to what is really left, instead of leaving zero-length trailing rows that
	// make Nd4j.create(float[][]) see a ragged matrix.
	int remaining = Math.max(0, m_allFileNames.size() - cursor);
	int examplesRead = Math.min(numExamples, remaining);

	float[][] featureData = new float[examplesRead][];
	float[][] labelData = new float[examplesRead][];

	for (int i = 0; i < examplesRead; i++) {
		Entry<String, String> entry = m_allFileNames.get(cursor + i);
		featureData[i] = imageFileNameToMnsitFormat(entry.getValue());
		labelData[i] = toLabelArray(entry.getKey());
	}
	cursor += examplesRead;

	INDArray features = Nd4j.create(featureData);
	INDArray labels = Nd4j.create(labelData);
	curr = new DataSet(features, labels);
}
 
开发者ID:braeunlich,项目名称:anagnostes,代码行数:23,代码来源:NumbersDataFetcher.java

示例4: getPar2Hier

import org.nd4j.linalg.api.ndarray.INDArray; //导入依赖的package包/类
/**
 * Transforms paragraph vectors into hierarchical vectors.
 *
 * @param iterator iterator over docs; only its labels source is read here
 * @param lookupTable the paragraph vector table
 * @param labels the labels; inserted into a {@code PatriciaTrie}, the list itself is not modified
 * @param k the no. of centroids
 * @param method the method forwarded to {@code Par2HierUtils.getPar2HierVector}
 * @return a map doc->hierarchical vector (a {@code TreeMap}, so iteration follows key order)
 */
static Map<String, INDArray> getPar2Hier(LabelAwareIterator iterator,
                                         WeightLookupTable<VocabWord> lookupTable,
                                         List<String> labels, int k, Method method) {
  LabelsSource labelsSource = iterator.getLabelsSource();
  // PatriciaTrie is a sorted map, so insertion order is irrelevant; no need to sort (and thereby
  // mutate) the caller's list, which would also fail on an immutable list.
  PatriciaTrie<String> trie = new PatriciaTrie<>();
  for (String label : labels) {
    trie.put(label, label);
  }

  Map<String, INDArray> hvs = new TreeMap<>();
  // for each doc
  for (String node : labelsSource.getLabels()) {
    Par2HierUtils.getPar2HierVector(lookupTable, trie, node, k, hvs, method);
  }
  return hvs;
}
 
开发者ID:tteofili,项目名称:par2hier,代码行数:26,代码来源:Par2HierUtils.java

示例5: computeSimilarity

import org.nd4j.linalg.api.ndarray.INDArray; //导入依赖的package包/类
/**
 * Cosine similarity between the averaged word embeddings of two concepts.
 *
 * <p>Names equal after locale-independent lower-casing short-circuit to 1. Word vectors are loaded
 * lazily on first use. Prints a message to stderr if the similarity comes out NaN.
 *
 * @param c1 first concept
 * @param c2 second concept
 * @return 1 for equal names, {@code Double.NaN} if either concept has no known token, otherwise the
 *         cosine similarity of the two concept vectors
 */
@Override
public double computeSimilarity(Concept c1, Concept c2) {
	// Locale.ROOT: the default locale can change how letters such as 'I' lower-case.
	if (c1.name.toLowerCase(java.util.Locale.ROOT).equals(c2.name.toLowerCase(java.util.Locale.ROOT)))
		return 1;

	if (wordVectors == null) {
		this.loadWordVectors(type, dimension);
		int[] shape = wordVectors.lookupTable().getWeights().shape();
		System.out.println("word embeddings loaded, " + shape[0] + " " + shape[1]);
	}

	INDArray cVector1 = this.getConceptVector(c1);
	INDArray cVector2 = this.getConceptVector(c2);
	if (cVector1 == null || cVector2 == null)
		return Double.NaN;

	double dist = Transforms.cosineSim(cVector1, cVector2);

	if (Double.isNaN(dist))
		System.err.println("Embedding NaN");

	return dist;
}
 
开发者ID:UKPLab,项目名称:ijcnlp2017-cmaps,代码行数:24,代码来源:WordEmbeddingDistance.java

示例6: fromText

import org.nd4j.linalg.api.ndarray.INDArray; //导入依赖的package包/类
/**
 * Reads word vectors from a text file.
 *
 * <p>The first line is a header {@code "<vocabSize> <layerSize>"}; each following line is a word
 * followed by {@code layerSize} space-separated floats. The file is always closed, even on error.
 *
 * @param wordFilePath path of the text file
 * @return the words in file order paired with a {@code [vocabSize, layerSize]} matrix whose row i is
 *         the vector of word i
 * @throws IOException if the file cannot be read or is empty
 * @throws IllegalArgumentException if a line has the wrong number of values or the file holds more
 *         vectors than the header declares
 */
private static Pair<List<String>, INDArray> fromText(String wordFilePath) throws IOException {
	try (FileInputStream in = new FileInputStream(new File(wordFilePath));
	     BufferedReader reader = new BufferedReader(Common.asReaderUTF8Lenient(in))) {
		String fstLine = reader.readLine();
		if (fstLine == null) {
			throw new IOException("Word vector file '" + wordFilePath + "' is empty");
		}
		String[] header = fstLine.split(" ");
		int vocabSize = Integer.parseInt(header[0]);
		int layerSize = Integer.parseInt(header[1]);
		List<String> wordVocab = Lists.newArrayList();
		INDArray wordVectors = Nd4j.create(vocabSize, layerSize);
		int n = 1;
		String line;
		while ((line = reader.readLine()) != null) {
			String[] values = line.split(" ");
			// Guard row index before putScalar, which would otherwise fail with an opaque bounds error.
			Preconditions.checkArgument(n <= vocabSize, "For file '%s', found more than the %s word vectors declared in the header",
					wordFilePath, vocabSize);
			Preconditions.checkArgument(layerSize == values.length - 1, "For file '%s', on line %s, layer size is %s, but found %s values in the word vector",
					wordFilePath, n, layerSize, values.length - 1); // Sanity check
			wordVocab.add(values[0]);
			for (int d = 1; d < values.length; d++) wordVectors.putScalar(n - 1, d - 1, Float.parseFloat(values[d]));
			n++;
		}
		return new Pair<>(wordVocab, wordVectors);
	}
}
 
开发者ID:IsaacChanghau,项目名称:Word2VecfJava,代码行数:20,代码来源:WordVectorSerializer.java

示例7: getTrainingData

import org.nd4j.linalg.api.ndarray.INDArray; //导入依赖的package包/类
/**
 * Generates the shuffled training set for learning the sum of two numbers.
 *
 * <p>{@code N_SAMPLES} rows are drawn from a {@code Random} seeded with {@code seed}; each row has two
 * inputs of the form {@code MIN_RANGE + (MAX_RANGE - MIN_RANGE) * nextDouble()} and a label equal to
 * their sum. Features are an {@code [N_SAMPLES, 2]} matrix, labels an {@code [N_SAMPLES, 1]} column.
 *
 * @return the training data, shuffled via {@code DataSet.shuffle()}
 */
@Override
public FederatedDataSet getTrainingData() {
    final Random generator = new Random(seed);
    final double span = MAX_RANGE - MIN_RANGE;
    final double[] first = new double[N_SAMPLES];
    final double[] second = new double[N_SAMPLES];
    final double[] target = new double[N_SAMPLES];

    for (int row = 0; row < N_SAMPLES; row++) {
        first[row] = MIN_RANGE + span * generator.nextDouble();
        second[row] = MIN_RANGE + span * generator.nextDouble();
        target[row] = first[row] + second[row];
    }

    final int[] columnShape = {N_SAMPLES, 1};
    final INDArray features = Nd4j.hstack(
            Nd4j.create(first, columnShape),
            Nd4j.create(second, columnShape));
    final INDArray labels = Nd4j.create(target, columnShape);

    final DataSet shuffled = new DataSet(features, labels);
    shuffled.shuffle();
    return new FederatedDataSetImpl(shuffled);
}
 
开发者ID:mccorby,项目名称:FederatedAndroidTrainer,代码行数:20,代码来源:SumDataSource.java

示例8: getTestData

import org.nd4j.linalg.api.ndarray.INDArray; //导入依赖的package包/类
/**
 * Generates the held-out test set for learning the sum of two numbers.
 *
 * <p>{@code N_SAMPLES / 10} rows; each row has two inputs of the form
 * {@code MIN_RANGE + (MAX_RANGE - MIN_RANGE) * nextDouble()} and a label equal to their sum.
 * Features are a {@code [numSamples, 2]} matrix, labels a {@code [numSamples, 1]} column.
 *
 * @return the test data (not shuffled)
 */
@Override
public FederatedDataSet getTestData() {
    // Seed with seed + 1, not seed: a Random seeded with seed and drawn in the same order as the
    // training generator would replay the first N_SAMPLES / 10 training samples exactly, so the
    // "test" set would be a copy of part of the training set.
    Random rand = new Random(seed + 1);
    int numSamples = N_SAMPLES / 10;
    double[] sum = new double[numSamples];
    double[] input1 = new double[numSamples];
    double[] input2 = new double[numSamples];
    for (int i = 0; i < numSamples; i++) {
        input1[i] = MIN_RANGE + (MAX_RANGE - MIN_RANGE) * rand.nextDouble();
        input2[i] = MIN_RANGE + (MAX_RANGE - MIN_RANGE) * rand.nextDouble();
        sum[i] = input1[i] + input2[i];
    }
    INDArray inputNDArray1 = Nd4j.create(input1, new int[]{numSamples, 1});
    INDArray inputNDArray2 = Nd4j.create(input2, new int[]{numSamples, 1});
    INDArray inputNDArray = Nd4j.hstack(inputNDArray1, inputNDArray2);
    INDArray outPut = Nd4j.create(sum, new int[]{numSamples, 1});
    return new FederatedDataSetImpl(new DataSet(inputNDArray, outPut));
}
 
开发者ID:mccorby,项目名称:FederatedAndroidTrainer,代码行数:19,代码来源:SumDataSource.java

示例9: toCsv

import org.nd4j.linalg.api.ndarray.INDArray; //导入依赖的package包/类
/**
 * Renders each example of {@code it} as one line {@code "<label>: <col0>, <col1>, ...\n"}.
 *
 * <p>Examples are pulled one at a time with {@code it.next(1)}; the i-th example is prefixed with
 * {@code labels.get(i)}. Each column is appended using the string form of {@code features.getColumn(i)}.
 *
 * @param it iterator over the examples; consumed by this call
 * @param labels one label per example, in iteration order
 * @param shape the exact feature shape every example must have
 * @return the rendered text
 * @throws IllegalStateException if {@code it.numExamples()} differs from {@code labels.size()}, or an
 *         example's feature shape differs from {@code shape}
 */
private String toCsv(DataSetIterator it, List<Integer> labels, int[] shape) {
    if (it.numExamples() != labels.size()) {
        throw new IllegalStateException(
                String.format("numExamples == %d != labels.size() == %d",
                        it.numExamples(), labels.size()));
    }

    // Single-threaded local buffer: StringBuilder avoids StringBuffer's needless synchronization.
    StringBuilder sb = new StringBuilder();
    int labelIndex = 0;

    while (it.hasNext()) {
        INDArray features = it.next(1).getFeatures();

        if (!Arrays.equals(features.shape(), shape)) {
            // Both shapes are reported; the format string previously lacked the second %s.
            throw new IllegalStateException(String.format("wrong shape: got %s, expected %s",
                    Arrays.toString(features.shape()), Arrays.toString(shape)));
        }

        // Prepend the label
        sb.append(labels.get(labelIndex)).append(": ");
        labelIndex++;

        int columns = features.columns();
        for (int i = 0; i < columns; i++) {
            if (i > 0) {
                sb.append(", ");
            }
            sb.append(features.getColumn(i));
        }

        sb.append("\n");
    }

    return sb.toString();
}
 
开发者ID:SkymindIO,项目名称:SKIL_CE_1.0.0_Examples,代码行数:36,代码来源:NormalizeUciData.java

示例10: doPredict

import org.nd4j.linalg.api.ndarray.INDArray; //导入依赖的package包/类
/**
 * Runs the model on a single input record and converts the output according to the configured type.
 *
 * <p>The record is wrapped in a one-element {@code ListStringSplit}, read through a
 * {@code RecordReaderDataSetIterator} with batch size 1, and passed to {@code model.output}.
 *
 * @param line the field values of one input record
 * @return for {@code _float}, the first output value as a {@code Double}; for {@code _class}, the index
 *         of the largest output value as an {@code Integer}
 * @throws RuntimeException wrapping any failure, including the {@code IllegalArgumentException}
 *         thrown for unsupported output types
 */
@Override
protected Object doPredict(List<String> line) {
    try {
        ListStringSplit input = new ListStringSplit(Collections.singletonList(line));
        ListStringRecordReader rr = new ListStringRecordReader();
        rr.initialize(input);
        DataSetIterator iterator = new RecordReaderDataSetIterator(rr, 1);

        DataSet ds = iterator.next();
        INDArray prediction = model.output(ds.getFeatures());

        DataType outputType = types.get(this.output);
        switch (outputType) {
            case _float : return prediction.getDouble(0);
            case _class: return indexOfMaxOutput(prediction);
            default: throw new IllegalArgumentException("Output type not yet supported "+outputType);
        }
    } catch (Exception e) {
        throw new RuntimeException(e);
    }
}

/**
 * Arg-max over every element of {@code prediction} (first index wins on ties).
 * Scans all outputs rather than a hard-coded class count, and starts from negative infinity so
 * non-positive scores are still ranked; returns -1 only if there are no elements or all are NaN.
 */
private static int indexOfMaxOutput(INDArray prediction) {
    int maxIndex = -1;
    double max = Double.NEGATIVE_INFINITY;
    for (int i = 0; i < prediction.length(); i++) {
        double value = prediction.getDouble(i);
        if (value > max) {
            max = value;
            maxIndex = i;
        }
    }
    return maxIndex;
}
 
开发者ID:neo4j-contrib,项目名称:neo4j-ml-procedures,代码行数:31,代码来源:DL4JMLModel.java

示例11: getScores

import org.nd4j.linalg.api.ndarray.INDArray; //导入依赖的package包/类
/**
 * Scores a document vector against every trained category label.
 *
 * <p>For each label in {@code labelsUsed}, its vector is fetched from {@code lookupTable} and compared
 * with {@code vector} using cosine similarity (larger means more similar).
 *
 * @param vector vector representing a document
 * @return one (label, cosine similarity) pair per entry of {@code labelsUsed}, in iteration order
 * @throws IllegalStateException if a label has no vector in the lookup table
 */
public List<Pair<String, Double>> getScores(@NonNull INDArray vector) {
    List<Pair<String, Double>> scores = new ArrayList<>();
    for (String category : labelsUsed) {
        INDArray categoryVector = lookupTable.vector(category);
        if (categoryVector == null) {
            throw new IllegalStateException("Label '" + category + "' has no known vector!");
        }
        double similarity = Transforms.cosineSim(vector, categoryVector);
        scores.add(new Pair<String, Double>(category, similarity));
    }
    return scores;
}
 
开发者ID:tteofili,项目名称:par2hier,代码行数:17,代码来源:LabelSeeker.java

示例12: add

import org.nd4j.linalg.api.ndarray.INDArray; //导入依赖的package包/类
/**
 * Adds the scalar {@code n} to every element and returns the result as a new {@code ArrayND}.
 *
 * <p>The work is done on the array returned by {@code this.copy()}; {@code n} is first converted to
 * the element type with the factory's scalar converter.
 *
 * @param n the value to add to each element
 * @return a new {@code ArrayND} wrapping the modified copy
 */
public INDArray add(final Number n) {
    final ArrayAnyD<N> result = this.copy();
    final N addend = myFactory.scalar().cast(n);
    result.modifyAll(myFactory.function().add().second(addend));
    return new ArrayND<>(myFactory, result);
}
 
开发者ID:optimatika,项目名称:ojAlgo-extensions,代码行数:8,代码来源:ArrayND.java

示例13: State

import org.nd4j.linalg.api.ndarray.INDArray; //导入依赖的package包/类
/**
 * Creates a state from a velocity vector plus kick and dribbler settings.
 *
 * <p>Reads elements (0, 0), (1, 0) and (2, 0) of {@code velocityVector} via {@code getFloat} and
 * passes them, in that order, as the first three arguments of the delegated constructor, followed by
 * the kick and dribbler values unchanged. No range checks are performed here.
 * NOTE(review): this presumably expects a column vector with at least three rows — confirm against
 * callers and the delegated constructor's parameter meanings.
 *
 * @param velocityVector The velocity vector to extract the velocity components from.
 * @param flatKick       The flat-kick strength as a percentage (between 0 and 1).
 * @param chipKick       The chip-kick strength as a percentage (between 0 and 1).
 * @param dribblerSpin   The dribbler spin as a percentage with the sign indicating the
 *                       direction of spin (between -1 and 1)
 */
public State(
    final INDArray velocityVector,
    final float flatKick,
    final float chipKick,
    final float dribblerSpin
) {
  this(
      velocityVector.getFloat(0, 0), // element (row 0, column 0)
      velocityVector.getFloat(1, 0), // element (row 1, column 0)
      velocityVector.getFloat(2, 0), // element (row 2, column 0)
      flatKick,
      chipKick,
      dribblerSpin);
}
 
开发者ID:delta-leonis,项目名称:subra,代码行数:22,代码来源:PlayerCommand.java

示例14: loadWord2VecModel

import org.nd4j.linalg.api.ndarray.INDArray; //导入依赖的package包/类
/**
 * Loads a {@link Word2Vec} model from a word-vector file.
 *
 * <p>The vector dimension is taken from the number of columns of the loaded matrix.
 *
 * @param wordFilePath path of the word-vector file
 * @param binary       {@code true} to read it with {@code fromBinary}, {@code false} with {@code fromText}
 * @return the model, or {@code null} if an {@code IOException} occurred (its stack trace is printed)
 */
public static Word2Vec loadWord2VecModel(String wordFilePath, boolean binary) {
	try {
		Pair<List<String>, INDArray> loaded = binary ? fromBinary(wordFilePath) : fromText(wordFilePath);
		INDArray vectors = loaded.getValue();
		return new Word2Vec(vectors.columns(), loaded.getKey(), vectors, true);
	} catch (IOException e) {
		e.printStackTrace();
		return null;
	}
}
 
开发者ID:IsaacChanghau,项目名称:Word2VecfJava,代码行数:14,代码来源:WordVectorSerializer.java

示例15: getPotential

import org.nd4j.linalg.api.ndarray.INDArray; //导入依赖的package包/类
/**
 * Computes the potential at a position: {@code positionVector.mul(this.multiplier)} summed over all
 * elements.
 * See <a href="http://www.wolframalpha.com/input/?i=-1*(x*Cos%5BA%5D+%2B+y*Sin%5BA%5D)">this equation.</a>
 *
 * @param positionVector The position vector at which to compute the potential.
 * @return the sum of the multiplied components as a {@code double}
 */
@Override
public double getPotential(final INDArray positionVector) {
  final INDArray weighted = positionVector.mul(this.multiplier);
  final Number total = weighted.sumNumber();
  return total.doubleValue();
}
 
开发者ID:delta-leonis,项目名称:algieba,代码行数:12,代码来源:UniformFlowPotentialField.java


注:本文中的org.nd4j.linalg.api.ndarray.INDArray类示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。