本文整理汇总了Java中org.nd4j.linalg.api.ndarray.INDArray类的典型用法代码示例。如果您正苦于以下问题：Java INDArray类的具体用法是什么？Java INDArray怎么用？有哪些Java INDArray的使用示例？那么，这里精选的类代码示例或许可以为您提供帮助。
INDArray类属于org.nd4j.linalg.api.ndarray包，下文一共展示了INDArray类的15个代码示例，这些示例默认按受欢迎程度排序。您可以为喜欢或觉得有用的代码点赞，您的评价将有助于系统推荐更优秀的Java代码示例。
示例1: getConceptVector
import org.nd4j.linalg.api.ndarray.INDArray; //导入依赖的package包/类
/**
 * Builds one embedding for a concept from the word vectors of the tokens in its name.
 *
 * <p>The name is lower-cased and trimmed, then tokenized with {@code SimpleTokenizer.INSTANCE}.
 * A token found in {@code wordVectors} contributes its word vector. Any other token
 * contributes {@code unkVector}.
 *
 * @param c the concept whose {@code name} is embedded
 * @return the mean along dimension 0 of the vertically stacked token vectors, or {@code null}
 *         when no token is known (this includes a name that yields no tokens at all)
 */
public INDArray getConceptVector(Concept c) {
    final Tokenizer tokenizer = SimpleTokenizer.INSTANCE;
    final String normalizedName = c.name.toLowerCase().trim();
    final List<INDArray> tokenVectors = new ArrayList<INDArray>();
    int unknownTokens = 0;
    for (final String token : tokenizer.tokenize(normalizedName)) {
        if (!wordVectors.hasWord(token)) {
            tokenVectors.add(unkVector);
            unknownTokens++;
            continue;
        }
        tokenVectors.add(wordVectors.getWordVectorMatrix(token));
    }
    if (unknownTokens == tokenVectors.size()) {
        return null; // nothing known to average
    }
    // Mean vs. sum does not matter for cosine similarity; the mean is used here.
    final INDArray stacked = Nd4j.vstack(tokenVectors);
    return stacked.mean(0);
}
示例2: loadFeaturesFromString
import org.nd4j.linalg.api.ndarray.INDArray; //导入依赖的package包/类
/**
 * Converts review text into a features array after training. The result can be passed to
 * the network's output method.
 *
 * <p>The text is tokenized with {@code tokenizerFactory}, and tokens that have no word vector
 * in {@code wordVectors} are dropped. The remaining tokens, truncated to {@code maxLength},
 * are written one per time step into an array created as
 * {@code Nd4j.create(1, vectorSize, outputLength)}, where
 * {@code outputLength = min(maxLength, number of known tokens)}.
 *
 * @param reviewContents contents of the review to vectorize
 * @param maxLength maximum number of time steps. Longer reviews are truncated to this length.
 *        Use {@link Integer#MAX_VALUE} to disable truncation.
 * @return features array for the given input string
 */
public INDArray loadFeaturesFromString(String reviewContents, int maxLength){
    List<String> tokens = tokenizerFactory.create(reviewContents).getTokens();
    List<String> tokensFiltered = new ArrayList<>();
    for (String t : tokens) {
        if (wordVectors.hasWord(t)) {
            tokensFiltered.add(t);
        }
    }
    // min, not max: the array has exactly one time step per (possibly truncated) known token.
    // Taking the max would allocate maxLength steps, which is enormous for Integer.MAX_VALUE.
    // NOTE(review): if no token is known, outputLength is 0; confirm the ND4J version in use
    // accepts a zero-length dimension.
    int outputLength = Math.min(maxLength, tokensFiltered.size());
    INDArray features = Nd4j.create(1, vectorSize, outputLength);
    // Iterate the *filtered* tokens so every time step has a known vector. Walking the raw
    // tokens would look up unknown words and leave gaps or misaligned steps.
    for (int j = 0; j < outputLength; j++) {
        String token = tokensFiltered.get(j);
        INDArray vector = wordVectors.getWordVectorMatrix(token);
        features.put(new INDArrayIndex[]{NDArrayIndex.point(0), NDArrayIndex.all(), NDArrayIndex.point(j)}, vector);
    }
    return features;
}
示例3: fetch
import org.nd4j.linalg.api.ndarray.INDArray; //导入依赖的package包/类
/**
 * Loads the next batch of at most {@code numExamples} examples into {@code curr}.
 *
 * <p>Entries are read from {@code m_allFileNames} starting at index {@code cursor}. For each
 * entry, the value is converted with {@code imageFileNameToMnsitFormat} into a feature row and
 * the key is converted with {@code toLabelArray} into a label row. {@code cursor} is then
 * advanced by the number of entries actually read.
 *
 * @param numExamples maximum number of examples to read
 */
@Override
public void fetch(int numExamples) {
    float[][] featureData = new float[numExamples][];
    float[][] labelData = new float[numExamples][];
    int examplesRead = 0;
    while (examplesRead < numExamples && cursor + examplesRead < m_allFileNames.size()) {
        Entry<String, String> entry = m_allFileNames.get(cursor + examplesRead);
        featureData[examplesRead] = imageFileNameToMnsitFormat(entry.getValue());
        labelData[examplesRead] = toLabelArray(entry.getKey());
        examplesRead++;
    }
    cursor += examplesRead;
    if (examplesRead < numExamples) {
        // Near the end of the file list, fewer entries remain than were requested. Drop the
        // unfilled rows so only real examples (with consistent row lengths) reach Nd4j.create.
        featureData = java.util.Arrays.copyOf(featureData, examplesRead);
        labelData = java.util.Arrays.copyOf(labelData, examplesRead);
    }
    INDArray features = Nd4j.create(featureData);
    INDArray labels = Nd4j.create(labelData);
    curr = new DataSet(features, labels);
}
示例4: getPar2Hier
import org.nd4j.linalg.api.ndarray.INDArray; //导入依赖的package包/类
/**
 * Transforms paragraph vectors into hierarchical vectors.
 *
 * <p>Side effect: {@code labels} is sorted in place before its entries are indexed in a
 * {@link PatriciaTrie}. Then, for every label reported by the iterator's labels source,
 * {@code Par2HierUtils.getPar2HierVector} is invoked with the lookup table, trie, {@code k},
 * the result map and {@code method}.
 *
 * @param iterator iterator over the documents; only its labels source is used here
 * @param lookupTable the paragraph-vector lookup table
 * @param labels the labels to index in the trie (sorted in place)
 * @param k the number of centroids
 * @param method the method forwarded to {@code Par2HierUtils.getPar2HierVector}
 * @return a map, ordered by key, from document label to hierarchical vector (presumably
 *         populated by {@code Par2HierUtils.getPar2HierVector}; verify against that helper)
 */
static Map<String, INDArray> getPar2Hier(LabelAwareIterator iterator,
                                         WeightLookupTable<VocabWord> lookupTable,
                                         List<String> labels, int k, Method method) {
    Collections.sort(labels);
    final LabelsSource docLabels = iterator.getLabelsSource();

    final PatriciaTrie<String> labelTrie = new PatriciaTrie<>();
    for (final String label : labels) {
        labelTrie.put(label, label);
    }

    final Map<String, INDArray> hierarchicalVectors = new TreeMap<>();
    for (final String docLabel : docLabels.getLabels()) {
        Par2HierUtils.getPar2HierVector(lookupTable, labelTrie, docLabel, k, hierarchicalVectors, method);
    }
    return hierarchicalVectors;
}
示例5: computeSimilarity
import org.nd4j.linalg.api.ndarray.INDArray; //导入依赖的package包/类
/**
 * Computes the cosine similarity between the embeddings of two concepts.
 *
 * <p>Names that are equal after {@code toLowerCase()} short-circuit to 1. Word vectors are
 * loaded lazily on the first call that needs them, and the weight shape is printed to
 * standard output.
 *
 * @param c1 the first concept
 * @param c2 the second concept
 * @return 1 for equal lower-cased names; {@link Double#NaN} if either concept has no
 *         embedding; otherwise the cosine similarity (which may itself be NaN, in which case
 *         a message is printed to standard error)
 */
@Override
public double computeSimilarity(Concept c1, Concept c2) {
    final String firstName = c1.name.toLowerCase();
    final String secondName = c2.name.toLowerCase();
    if (firstName.equals(secondName)) {
        return 1;
    }

    if (wordVectors == null) {
        loadWordVectors(type, dimension);
        final int[] weightShape = wordVectors.lookupTable().getWeights().shape();
        System.out.println("word embeddings loaded, " + weightShape[0] + " " + weightShape[1]);
    }

    final INDArray firstVector = getConceptVector(c1);
    final INDArray secondVector = getConceptVector(c2);
    if (firstVector == null || secondVector == null) {
        return Double.NaN;
    }

    final double similarity = Transforms.cosineSim(firstVector, secondVector);
    if (Double.isNaN(similarity)) {
        System.err.println("Embedding NaN");
    }
    return similarity;
}
示例6: fromText
import org.nd4j.linalg.api.ndarray.INDArray; //导入依赖的package包/类
/**
 * Reads word vectors from a text file in which the first line holds the vocabulary size and
 * the vector dimension, separated by a space. Every following line holds a word followed by
 * its space-separated vector components.
 *
 * <p>Row {@code i} of the returned matrix is the vector of the {@code i}-th word line. If the
 * file has fewer word lines than the header declares, the remaining rows keep whatever
 * {@code Nd4j.create} initialised them with (presumably zeros).
 *
 * @param wordFilePath path of the text file to read
 * @return the vocabulary in file order, paired with a {@code [vocabSize, layerSize]} matrix
 * @throws IOException if the file cannot be read or is empty
 * @throws IllegalArgumentException if the header is malformed, a line has the wrong number
 *         of values, or there are more word lines than the header declares
 */
private static Pair<List<String>, INDArray> fromText(String wordFilePath) throws IOException {
    // try-with-resources so the file handle is released on success and on any parse error.
    try (FileInputStream in = new FileInputStream(new File(wordFilePath));
         BufferedReader reader = new BufferedReader(Common.asReaderUTF8Lenient(in))) {
        String fstLine = reader.readLine();
        if (fstLine == null) {
            throw new IOException("Word vector file '" + wordFilePath + "' is empty");
        }
        String[] header = fstLine.split(" ");
        Preconditions.checkArgument(header.length >= 2,
                "For file '%s', header line '%s' must contain the vocab size and the layer size",
                wordFilePath, fstLine);
        int vocabSize = Integer.parseInt(header[0]);
        int layerSize = Integer.parseInt(header[1]);
        List<String> wordVocab = Lists.newArrayList();
        INDArray wordVectors = Nd4j.create(vocabSize, layerSize);
        int n = 1;
        String line;
        while ((line = reader.readLine()) != null) {
            // Guard the matrix bounds: row n - 1 must exist.
            Preconditions.checkArgument(n <= vocabSize,
                    "For file '%s', found more than the %s word lines declared in the header",
                    wordFilePath, vocabSize);
            String[] values = line.split(" ");
            Preconditions.checkArgument(layerSize == values.length - 1, "For file '%s', on line %s, layer size is %s, but found %s values in the word vector",
                    wordFilePath, n, layerSize, values.length - 1); // Sanity check
            wordVocab.add(values[0]);
            for (int d = 1; d < values.length; d++) {
                wordVectors.putScalar(n - 1, d - 1, Float.parseFloat(values[d]));
            }
            n++;
        }
        return new Pair<>(wordVocab, wordVectors);
    }
}
示例7: getTrainingData
import org.nd4j.linalg.api.ndarray.INDArray; //导入依赖的package包/类
/**
 * Generates the training set for learning to add two numbers.
 *
 * <p>{@code N_SAMPLES} input pairs are drawn as {@code MIN_RANGE + (MAX_RANGE - MIN_RANGE) * r},
 * where {@code r} comes from a {@link Random} seeded with {@code seed}. Each label is the sum
 * of its pair. Features form an {@code [N_SAMPLES, 2]} array and labels an
 * {@code [N_SAMPLES, 1]} array. The data set is shuffled before being wrapped.
 *
 * @return the shuffled training data
 */
@Override
public FederatedDataSet getTrainingData() {
    final Random generator = new Random(seed);
    final double span = MAX_RANGE - MIN_RANGE;
    final double[] firstOperands = new double[N_SAMPLES];
    final double[] secondOperands = new double[N_SAMPLES];
    final double[] sums = new double[N_SAMPLES];

    int row = 0;
    while (row < N_SAMPLES) {
        firstOperands[row] = MIN_RANGE + span * generator.nextDouble();
        secondOperands[row] = MIN_RANGE + span * generator.nextDouble();
        sums[row] = firstOperands[row] + secondOperands[row];
        row++;
    }

    final INDArray features = Nd4j.hstack(
            Nd4j.create(firstOperands, new int[]{N_SAMPLES, 1}),
            Nd4j.create(secondOperands, new int[]{N_SAMPLES, 1}));
    final INDArray targets = Nd4j.create(sums, new int[]{N_SAMPLES, 1});

    final DataSet trainingSet = new DataSet(features, targets);
    trainingSet.shuffle();
    return new FederatedDataSetImpl(trainingSet);
}
示例8: getTestData
import org.nd4j.linalg.api.ndarray.INDArray; //导入依赖的package包/类
/**
 * Generates the test set for learning to add two numbers.
 *
 * <p>{@code N_SAMPLES / 10} input pairs are drawn as
 * {@code MIN_RANGE + (MAX_RANGE - MIN_RANGE) * r}, and each label is the sum of its pair.
 * Unlike the training set, the result is not shuffled.
 *
 * @return the test data
 */
@Override
public FederatedDataSet getTestData() {
    Random rand = new Random(seed);
    // getTrainingData() builds its samples from the first 2 * N_SAMPLES draws of
    // new Random(seed). Drawing the test set from the same seed from the start reproduced
    // exactly those values, making every test sample a training sample (train/test leakage).
    // Skip past the training draws so the test set continues the stream instead of repeating it.
    for (int skipped = 0; skipped < 2 * N_SAMPLES; skipped++) {
        rand.nextDouble();
    }
    int numSamples = N_SAMPLES / 10;
    double[] sum = new double[numSamples];
    double[] input1 = new double[numSamples];
    double[] input2 = new double[numSamples];
    for (int i = 0; i < numSamples; i++) {
        input1[i] = MIN_RANGE + (MAX_RANGE - MIN_RANGE) * rand.nextDouble();
        input2[i] = MIN_RANGE + (MAX_RANGE - MIN_RANGE) * rand.nextDouble();
        sum[i] = input1[i] + input2[i];
    }
    INDArray inputNDArray1 = Nd4j.create(input1, new int[]{numSamples, 1});
    INDArray inputNDArray2 = Nd4j.create(input2, new int[]{numSamples, 1});
    INDArray inputNDArray = Nd4j.hstack(inputNDArray1, inputNDArray2);
    INDArray outPut = Nd4j.create(sum, new int[]{numSamples, 1});
    return new FederatedDataSetImpl(new DataSet(inputNDArray, outPut));
}
示例9: toCsv
import org.nd4j.linalg.api.ndarray.INDArray; //导入依赖的package包/类
/**
 * Renders every example of {@code it} as one text line. Each line is the example's label,
 * then {@code ": "}, then the feature columns joined with {@code ", "}.
 *
 * @param it iterator over the examples; consumed one example at a time via {@code next(1)}
 * @param labels label for each example, in iteration order
 * @param shape expected shape of every single-example feature array
 * @return the rendered lines, each terminated with {@code '\n'}
 * @throws IllegalStateException if the label count differs from {@code it.numExamples()}, or
 *         if an example's feature shape differs from {@code shape}
 */
private String toCsv(DataSetIterator it, List<Integer> labels, int[] shape) {
    int numExamples = it.numExamples();
    if (numExamples != labels.size()) {
        throw new IllegalStateException(
                String.format("numExamples == %d != labels.size() == %d",
                        numExamples, labels.size()));
    }
    StringBuilder sb = new StringBuilder();
    int l = 0;
    while (it.hasNext()) {
        INDArray features = it.next(1).getFeatures();
        if (!Arrays.equals(features.shape(), shape)) {
            // Report both shapes so the mismatch can be diagnosed from the message alone.
            throw new IllegalStateException(String.format("wrong shape: got %s, expected %s",
                    Arrays.toString(features.shape()), Arrays.toString(shape)));
        }
        // Prepend the label
        sb.append(labels.get(l)).append(": ");
        l++;
        int columns = features.columns();
        for (int i = 0; i < columns; i++) {
            sb.append(features.getColumn(i));
            if (i < columns - 1) {
                sb.append(", ");
            }
        }
        sb.append("\n");
    }
    return sb.toString();
}
示例10: doPredict
import org.nd4j.linalg.api.ndarray.INDArray; //导入依赖的package包/类
/**
 * Runs the model on a single input record. The raw output is then converted according to
 * the {@code DataType} registered in {@code types} for {@code this.output}.
 *
 * @param line the field values of one record
 * @return for {@code _float}, the first output value (a {@code Double}); for {@code _class},
 *         the index of the first largest output value (an {@code Integer}), or -1 if the
 *         output is empty or all NaN
 * @throws RuntimeException wrapping any failure while reading the record or running the model,
 *         and wrapping the {@link IllegalArgumentException} raised for an unsupported output type
 */
@Override
protected Object doPredict(List<String> line) {
    try {
        ListStringSplit input = new ListStringSplit(Collections.singletonList(line));
        ListStringRecordReader rr = new ListStringRecordReader();
        rr.initialize(input);
        DataSetIterator iterator = new RecordReaderDataSetIterator(rr, 1);
        DataSet ds = iterator.next();
        INDArray prediction = model.output(ds.getFeatures());
        DataType outputType = types.get(this.output);
        switch (outputType) {
            case _float:
                return prediction.getDouble(0);
            case _class: {
                // Argmax over every output entry rather than a hard-coded two classes. Start
                // from -infinity so that non-positive scores still select a valid index.
                long numClasses = prediction.length();
                double max = Double.NEGATIVE_INFINITY;
                int maxIndex = -1;
                for (int i = 0; i < numClasses; i++) {
                    double score = prediction.getDouble(i);
                    if (score > max) {
                        max = score;
                        maxIndex = i;
                    }
                }
                return maxIndex;
            }
            default:
                throw new IllegalArgumentException("Output type not yet supported " + outputType);
        }
    } catch (Exception e) {
        throw new RuntimeException(e);
    }
}
示例11: getScores
import org.nd4j.linalg.api.ndarray.INDArray; //导入依赖的package包/类
/**
 * Scores a document vector against every label in {@code labelsUsed}.
 *
 * <p>For each label, the label's vector is fetched from {@code lookupTable} and compared
 * with {@code vector} using cosine similarity (higher means more similar; this is not a
 * distance).
 *
 * @param vector vector representing a document; must not be null
 * @return one {@code (label, cosine similarity)} pair per label, in {@code labelsUsed}
 *         iteration order
 * @throws IllegalStateException if {@code lookupTable} has no vector for one of the labels
 */
public List<Pair<String, Double>> getScores(@NonNull INDArray vector) {
    final List<Pair<String, Double>> scores = new ArrayList<>();
    for (final String label : labelsUsed) {
        final INDArray labelVector = lookupTable.vector(label);
        if (labelVector == null) {
            throw new IllegalStateException("Label '" + label + "' has no known vector!");
        }
        final double similarity = Transforms.cosineSim(vector, labelVector);
        scores.add(new Pair<String, Double>(label, similarity));
    }
    return scores;
}
示例12: add
import org.nd4j.linalg.api.ndarray.INDArray; //导入依赖的package包/类
/**
 * Adds a scalar to every element and returns the result wrapped in a new {@code ArrayND}.
 *
 * <p>The work is done on {@code this.copy()}: {@code n} is cast to the element type via
 * {@code myFactory.scalar().cast(n)}, and the factory's {@code add} function, with that value
 * bound as its second argument, is applied to every element of the copy.
 *
 * @param n the scalar to add
 * @return a new {@code ArrayND} backed by the modified copy
 */
public INDArray add(final Number n) {
    final ArrayAnyD<N> shifted = this.copy();
    final N addend = myFactory.scalar().cast(n);
    final UnaryFunction<N> plusAddend = myFactory.function().add().second(addend);
    shifted.modifyAll(plusAddend);
    return new ArrayND<>(myFactory, shifted);
}
示例13: State
import org.nd4j.linalg.api.ndarray.INDArray; //导入依赖的package包/类
/**
 * Creates a state whose three velocity components are read from {@code velocityVector}.
 *
 * @param velocityVector vector holding the velocity components. The entries at (0, 0),
 *        (1, 0) and (2, 0), i.e. the first three rows of its first column, are used, in that
 *        order, as the first three arguments of the delegated constructor.
 * @param flatKick The flat-kick strength as a percentage (between 0 and 1).
 * @param chipKick The chip-kick strength as a percentage (between 0 and 1).
 * @param dribblerSpin The dribbler spin as a percentage with the sign indicating the
 *        direction of spin (between -1 and 1)
 */
public State(
        final INDArray velocityVector,
        final float flatKick,
        final float chipKick,
        final float dribblerSpin
) {
    this(velocityComponentAt(velocityVector, 0),
            velocityComponentAt(velocityVector, 1),
            velocityComponentAt(velocityVector, 2),
            flatKick, chipKick, dribblerSpin);
}

/** Reads the entry at row {@code row}, column 0 of {@code vector} as a float. */
private static float velocityComponentAt(final INDArray vector, final int row) {
    return vector.getFloat(row, 0);
}
示例14: loadWord2VecModel
import org.nd4j.linalg.api.ndarray.INDArray; //导入依赖的package包/类
/**
 * Loads a {@link Word2Vec} model from a word-vector file.
 *
 * <p>The file is parsed with {@code fromBinary} when {@code binary} is true and with
 * {@code fromText} otherwise. The model's vector size is taken from the number of columns
 * of the loaded matrix.
 *
 * @param wordFilePath path of the word-vector file
 * @param binary whether the file is in the binary format
 * @return the loaded model, or {@code null} if reading failed with an {@link IOException}
 *         (the stack trace is printed)
 */
public static Word2Vec loadWord2VecModel (String wordFilePath, boolean binary) {
    try {
        final Pair<List<String>, INDArray> vocabAndVectors =
                binary ? fromBinary(wordFilePath) : fromText(wordFilePath);
        final INDArray vectors = vocabAndVectors.getValue();
        return new Word2Vec(vectors.columns(), vocabAndVectors.getKey(), vectors, true);
    } catch (IOException e) {
        e.printStackTrace();
        return null;
    }
}
示例15: getPotential
import org.nd4j.linalg.api.ndarray.INDArray; //导入依赖的package包/类
/**
 * Computes the potential at a position.
 *
 * <p>The result is {@code positionVector.mul(this.multiplier)}, summed over all entries and
 * returned as a double. Following
 * <a href="http://www.wolframalpha.com/input/?i=-1*(x*Cos%5BA%5D+%2B+y*Sin%5BA%5D)">this equation</a>,
 * {@code multiplier} presumably holds {@code [-cos A, -sin A]}. Confirm against where it is
 * set.
 *
 * @param positionVector The position vector at which to compute the potential.
 * @return the potential at {@code positionVector}
 */
@Override
public double getPotential(final INDArray positionVector) {
    final INDArray weighted = positionVector.mul(this.multiplier);
    final Number total = weighted.sumNumber();
    return total.doubleValue();
}