本文整理汇总了Java中org.nd4j.linalg.api.ndarray.INDArray.getDouble方法的典型用法代码示例。如果您正苦于以下问题：Java INDArray.getDouble方法的具体用法？Java INDArray.getDouble怎么用？Java INDArray.getDouble使用的例子？那么，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在类org.nd4j.linalg.api.ndarray.INDArray的用法示例。
在下文中一共展示了INDArray.getDouble方法的15个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Java代码示例。
示例1: validate
import org.nd4j.linalg.api.ndarray.INDArray; //导入方法依赖的package包/类
/**
 * Runs every sample in {@code trainingDataSet} through {@code feedForward} and prints how many
 * predictions matched the expected class ("Positive") and how many did not ("Negative").
 *
 * <p>The predicted class is the index of the strictly largest of the first three outputs. When
 * the largest value is shared by two or more outputs the prediction is ambiguous and is counted
 * as negative.
 */
public void validate() {
    int p = 0;
    int n = 0;
    for (TrainingData td : trainingDataSet) {
        INDArray a = feedForward(td.input);
        double r0 = a.getDouble(0);
        double r1 = a.getDouble(1);
        double r2 = a.getDouble(2);
        int m = -1;
        if (r0 > r1 && r0 > r2) {
            m = 0;
        }
        if (r1 > r0 && r1 > r2) {
            m = 1;
        }
        if (r2 > r0 && r2 > r1) {
            m = 2;
        }
        // A tie leaves m == -1, which is not a valid class index for the label vector, so an
        // ambiguous prediction is counted as a miss instead of indexing with -1.
        // NOTE(review): the == 1 test assumes one-hot labels; confirm against the data source.
        if (m >= 0 && td.output.getDouble(m) == 1) {
            p++;
        } else {
            n++;
        }
    }
    System.out.println("Positive: " + p);
    System.out.println("Negative: " + n);
}
示例2: predict
import org.nd4j.linalg.api.ndarray.INDArray; //导入方法依赖的package包/类
/**
 * Predicts a one-hot class matrix for the given inputs.
 *
 * @param x input matrix, one sample per row
 * @return a {@code x.rows() x nOut} matrix holding 1.0 in each row at the column of the largest
 *     network output for that sample (the first such column on ties) and 0.0 elsewhere
 */
public INDArray predict(INDArray x) {
    INDArray y = output(x); // activate input data through learned networks
    INDArray out = Nd4j.create(new double[x.rows() * nOut], new int[] { x.rows(), nOut });
    for (int i = 0; i < x.rows(); i++) {
        // Seed the search with column 0. The previous seed, Double.MIN_VALUE, is the smallest
        // *positive* double, so rows whose outputs were all <= ~0 never set argmax and
        // out.put(i, -1, ...) was attempted.
        int argmax = 0;
        double max = y.getDouble(i, 0);
        for (int j = 1; j < nOut; j++) {
            double value = y.getDouble(i, j);
            if (value > max) {
                argmax = j;
                max = value;
            }
        }
        out.put(i, argmax, Nd4j.scalar(1.0));
    }
    return out;
}
示例3: getTopN
import org.nd4j.linalg.api.ndarray.INDArray; //导入方法依赖的package包/类
/**
 * Get the indices of the top N elements of a vector.
 *
 * @param vec the vec to extract the top elements from
 * @param N the number of elements to extract; values {@code <= 0} yield an empty list
 * @return the indices (stored as {@code Double}) of the top N elements, ordered from the
 *     highest-ranked to the lowest-ranked according to {@code BasicModelUtils.ArrayComparator}
 *     (presumably comparing by element value — confirm against the comparator)
 */
private List<Double> getTopN(INDArray vec, int N) {
    if (N <= 0) {
        // Without this guard the queue stays empty, peek() returns null and the comparator
        // is handed a null head.
        return new ArrayList<>();
    }
    BasicModelUtils.ArrayComparator comparator = new BasicModelUtils.ArrayComparator();
    // The queue never holds more than N entries. The old capacity, vec.rows(), is 0 for an
    // empty array, which PriorityQueue rejects.
    int capacity = (int) Math.max(1L, Math.min((long) N, (long) vec.length()));
    PriorityQueue<Double[]> queue = new PriorityQueue<>(capacity, comparator);
    for (int j = 0; j < vec.length(); j++) {
        final Double[] pair = new Double[] {vec.getDouble(j), (double) j};
        if (queue.size() < N) {
            queue.add(pair);
        } else {
            Double[] head = queue.peek();
            // Replace the current lowest-ranked entry if the new pair outranks it.
            if (comparator.compare(pair, head) > 0) {
                queue.poll();
                queue.add(pair);
            }
        }
    }
    // Polling yields entries lowest-ranked first; reverse to get highest-ranked first.
    List<Double> lowToHighSimLst = new ArrayList<>();
    while (!queue.isEmpty()) {
        double ind = queue.poll()[1];
        lowToHighSimLst.add(ind);
    }
    return Lists.reverse(lowToHighSimLst);
}
示例4: clipMatrix
import org.nd4j.linalg.api.ndarray.INDArray; //导入方法依赖的package包/类
/**
 * Clamps every element of {@code matrix} into the closed range [-5, 5] to avoid exploding
 * gradients. NaN elements are left unchanged.
 *
 * @param matrix the array to clip; it is modified in place
 * @return the same {@code matrix} instance
 */
private static INDArray clipMatrix(INDArray matrix) {
    final double lower = -5;
    final double upper = 5;
    for (NdIndexIterator it = new NdIndexIterator(matrix.shape()); it.hasNext(); ) {
        int[] position = it.next();
        double clamped = Math.max(lower, Math.min(upper, matrix.getDouble(position)));
        matrix.putScalar(position, clamped);
    }
    return matrix;
}
示例5: run
import org.nd4j.linalg.api.ndarray.INDArray; //导入方法依赖的package包/类
/**
 * Transforms {@code input} with the transformer named by the model's {@code transformerName}
 * parameter, scores it with the model binary from {@code kc}, and collects the non-zero outputs.
 *
 * @param kc container that supplies the model binary stream
 * @param model model metadata; its params configure input parsing and its labels name the outputs
 * @param input raw input to transform and score
 * @return a result whose text is the raw output array and whose predictions map each label to
 *     its output value; outputs exactly equal to 0 are omitted
 * @throws IllegalArgumentException if the model has no parameters
 * @throws java.io.UncheckedIOException if closing the model stream fails
 */
public Result run(KieMLContainer kc, Model model, Input input) {
    List<ModelParam> params = model.getParams();
    if (params == null) {
        throw new IllegalArgumentException("Parameters to configure the input parsing are required!!");
    }
    String transformerName = ParamsUtil.getRequiredStringParam(params, "transformerName");
    Transformer transformer = TransformerFactory.get(transformerName);
    INDArray image = transformer.transform(params, input);
    INDArray output;
    // The model stream used to be left open; try-with-resources releases it even if scoring fails.
    try (InputStream isModel = kc.getModelBinInputStream(model)) {
        output = getOutput(isModel, image);
    } catch (java.io.IOException e) {
        // Only close() can raise a checked IOException here.
        throw new java.io.UncheckedIOException(e);
    }
    Result prediction = new Result();
    prediction.setText(output.toString());
    prediction.setPredictions(new HashMap<>());
    for (int i = 0; i < output.columns(); i++) {
        double value = output.getDouble(i);
        if (value == 0d) {
            continue;
        }
        prediction.getPredictions().put(model.getLabels().get(i), value);
    }
    return prediction;
}
示例6: doPredict
import org.nd4j.linalg.api.ndarray.INDArray; //导入方法依赖的package包/类
/**
 * Runs the model on a single record and converts the raw output to the configured output type.
 *
 * @param line the record, as a list of string fields
 * @return for {@code _float} outputs, the first output value; for {@code _class} outputs, the
 *     index (0 or 1) of the larger of the first two outputs, preferring 0 on a tie
 * @throws RuntimeException wrapping any failure while reading the record or running the model,
 *     including an {@code IllegalArgumentException} for unsupported output types
 */
@Override
protected Object doPredict(List<String> line) {
    try {
        ListStringSplit input = new ListStringSplit(Collections.singletonList(line));
        ListStringRecordReader rr = new ListStringRecordReader();
        rr.initialize(input);
        DataSetIterator iterator = new RecordReaderDataSetIterator(rr, 1);
        DataSet ds = iterator.next();
        INDArray prediction = model.output(ds.getFeatures());
        DataType outputType = types.get(this.output);
        switch (outputType) {
            case _float:
                return prediction.getDouble(0);
            case _class: {
                // Only the first two outputs are considered (binary classification).
                int numClasses = 2;
                // Seed with class 0 rather than max = 0, which returned -1 whenever every
                // score was <= 0.
                int maxIndex = 0;
                double max = prediction.getDouble(0);
                for (int i = 1; i < numClasses; i++) {
                    double score = prediction.getDouble(i);
                    if (score > max) {
                        maxIndex = i;
                        max = score;
                    }
                }
                return maxIndex;
            }
            default:
                throw new IllegalArgumentException("Output type not yet supported " + outputType);
        }
    } catch (Exception e) {
        if (e instanceof InterruptedException) {
            // Catching the exception cleared the thread's interrupt flag; restore it.
            Thread.currentThread().interrupt();
        }
        throw new RuntimeException(e);
    }
}
示例7: check
import org.nd4j.linalg.api.ndarray.INDArray; //导入方法依赖的package包/类
/**
 * Classifies {@code image} with {@code MainGUI.model} and shows one line per label, with its
 * score as a percentage, in {@code probabilityArea}.
 *
 * <p>The image is written to {@code tmp.png} and read back through a 150x150, 3-channel
 * {@code ImageRecordReader}. Pixel values are min-max normalized using statistics fitted on
 * this image alone.
 *
 * @param image the image to classify
 * @throws Exception if writing, reading or scoring the image fails
 */
private void check(BufferedImage image) throws Exception
{
    ImageIO.write(image, "png", new File("tmp.png")); //saves the image to the tmp.png file
    ImageRecordReader reader = new ImageRecordReader(150, 150, 3);
    try {
        reader.initialize(new FileSplit(new File("tmp.png"))); //reads the tmp.png file
        DataSetIterator dataIter = new RecordReaderDataSetIterator(reader, 1);
        // "0.00" always prints a leading digit. The previous "#.00" plus a manual "0" prefix
        // printed "01.00" for scores just below 1% that round up, and "0-.50" for negatives.
        DecimalFormat df = new DecimalFormat("0.00");
        while (dataIter.hasNext()) {
            //Normalize the data from the file
            DataNormalization normalization = new NormalizerMinMaxScaler();
            DataSet set = dataIter.next();
            normalization.fit(set);
            normalization.transform(set);
            INDArray array = MainGUI.model.output(set.getFeatures(), false); //send the data to the model and get the results
            //Process the results and print them in an understandable format (percentage scores)
            StringBuilder txt = new StringBuilder();
            for (int i = 0; i < array.length(); i++) {
                txt.append(MainGUI.labels.get(i))
                        .append(": ")
                        .append(df.format(array.getDouble(i) * 100))
                        .append("%\n");
            }
            probabilityArea.setText(txt.toString());
        }
    } finally {
        // Close even if reading or scoring throws; previously the reader leaked on error.
        reader.close();
    }
}
示例8: computeLineIntegral
import org.nd4j.linalg.api.ndarray.INDArray; //导入方法依赖的package包/类
/**
 * Returns an antiderivative, with respect to the line parameter {@code t}, of
 * {@code ln((A*t + B)^2 + (C*t + D)^2)}. Here the point {@code (A*t + B, C*t + D)} moves from
 * {@code lowerBound} (t = 0) to {@code upperBound} (t = 1). See
 * <a href="http://www.wolframalpha.com/input/?i=integral+of+ln((A*t%2BB)%5E2+%2B+(C*t%2BD)%5E2)">
 * this equation.</a>
 *
 * <p>Define {@code a = A^2 + C^2}, {@code u = a*t + A*B + C*D}, {@code k = A*D - B*C} and
 * {@code q = (A*t + B)^2 + (C*t + D)^2}. The antiderivative is then
 * {@code (u*ln(q) + 2*k*atan(u/k) - 2*a*t) / a}. The arctangent is taken of {@code u/k} rather
 * than the equivalent-looking {@code -k/u}. The latter differs by a constant whose sign flips
 * where {@code u} crosses zero (the point of the line closest to the origin), so the function
 * would jump by {@code 2*pi*|k|/a} there and break definite integrals across that point.
 *
 * @param lowerBound The starting point of the line along which to calculate the unscaled value
 *     of the parametrized line integral; elements (0,0) and (1,0) are read as its coordinates.
 * @param upperBound The final point of the line along which to calculate the unscaled value
 *     of the parametrized line integral; elements (0,0) and (1,0) are read as its coordinates.
 * @return A {@link UnaryOperator} representing the indefinite integral of the parametrized line
 *     integral. The bounds are read on every invocation.
 */
@Override
public UnaryOperator<Double> computeLineIntegral(
    final INDArray lowerBound,
    final INDArray upperBound
) {
    return input -> {
        // compute parametric coefficients: point(t) = (A*t + B, C*t + D)
        final double A = upperBound.getDouble(0, 0) - lowerBound.getDouble(0, 0);
        final double B = lowerBound.getDouble(0, 0);
        final double C = upperBound.getDouble(1, 0) - lowerBound.getDouble(1, 0);
        final double D = lowerBound.getDouble(1, 0);
        final double t = input;
        final double squaredLength = A * A + C * C;
        final double q = Math.pow(A * t + B, 2) + Math.pow(C * t + D, 2);
        if (squaredLength == 0) {
            // Degenerate segment: the integrand is the constant ln(B^2 + D^2).
            return t * Math.log(q);
        }
        final double u = squaredLength * t + A * B + C * D;
        final double k = A * D - B * C;
        // q == 0 forces u == 0, and u*ln(q) -> 0 there; avoid 0 * -Infinity = NaN.
        final double logarithmicTerm = u == 0 ? 0 : u * Math.log(q);
        // 2*k*atan(u/k) -> 0 as k -> 0 (line through the origin); avoid dividing by zero.
        final double inverseTangentTerm = k == 0 ? 0 : 2 * k * Math.atan(u / k);
        final double linearTerm = -2 * squaredLength * t;
        return (logarithmicTerm + inverseTangentTerm + linearTerm) / squaredLength;
    };
}
示例9: train
import org.nd4j.linalg.api.ndarray.INDArray; //导入方法依赖的package包/类
/**
 * Performs one perceptron-style training step on a single sample.
 *
 * @param x the sample
 * @param t the label; its first element's sign is compared with the score
 * @param learningRate step size applied to the weight update
 * @return 1 if {@code (x . w) * t > 0} (no update is made), otherwise 0 after adding
 *     {@code learningRate * t * x} to {@code w} in place
 */
public int train(INDArray x, INDArray t, double learningRate) {
    final double margin = x.mmul(w.transpose()).getDouble(0) * t.getDouble(0);
    if (margin > 0) {
        return 1;
    }
    INDArray update = x.transpose().mul(t).mul(learningRate).transpose();
    w.addi(update);
    return 0;
}
示例10: axpy
import org.nd4j.linalg.api.ndarray.INDArray; //导入方法依赖的package包/类
/**
 * Accumulates the product {@code A * x} into {@code y}.
 *
 * @param x vector reshaped to an {@code ncols() x 1} column (expected to hold ncols() values)
 * @param y output vector with at least {@code nrows()} entries; updated in place
 * @return {@code y}
 */
@Override
public double[] axpy(double[] x, double[] y) {
    // Nd4j.getBlasWrapper().level2().gemv() crashes, so the matrix-vector product is computed
    // with gemm on x as a single column.
    final int rowCount = nrows();
    final INDArray column = Nd4j.create(x, new int[]{ncols(), 1});
    final INDArray product = Nd4j.gemm(A, column, false, false);
    int r = 0;
    while (r < rowCount) {
        y[r] += product.getDouble(r);
        r++;
    }
    return y;
}
示例11: atx
import org.nd4j.linalg.api.ndarray.INDArray; //导入方法依赖的package包/类
/**
 * Stores the product {@code A^T * x} into {@code y}, overwriting its first {@code ncols()}
 * entries.
 *
 * @param x vector reshaped to an {@code nrows() x 1} column (expected to hold nrows() values)
 * @param y output vector with at least {@code ncols()} entries; updated in place
 * @return {@code y}
 */
@Override
public double[] atx(double[] x, double[] y) {
    // Nd4j.getBlasWrapper().level2().gemv() crashes, so the transposed product is computed
    // with gemm (transposing A) on x as a single column.
    final int rowCount = nrows();
    final int colCount = ncols();
    final INDArray column = Nd4j.create(x, new int[]{rowCount, 1});
    final INDArray product = Nd4j.gemm(A, column, true, false);
    for (int c = 0; c < colCount; ++c) {
        y[c] = product.getDouble(c);
    }
    return y;
}
示例12: atxpy
import org.nd4j.linalg.api.ndarray.INDArray; //导入方法依赖的package包/类
/**
 * Accumulates the product {@code A^T * x} into {@code y}.
 *
 * @param x vector reshaped to an {@code nrows() x 1} column (expected to hold nrows() values)
 * @param y output vector with at least {@code ncols()} entries; updated in place
 * @return {@code y}
 */
@Override
public double[] atxpy(double[] x, double[] y) {
    // Nd4j.getBlasWrapper().level2().gemv() crashes, so the transposed product is computed
    // with gemm (transposing A) on x as a single column.
    final int rowCount = nrows();
    final int colCount = ncols();
    final INDArray product = Nd4j.gemm(A, Nd4j.create(x, new int[]{rowCount, 1}), true, false);
    int c = 0;
    while (c < colCount) {
        y[c] += product.getDouble(c);
        c++;
    }
    return y;
}
示例13: getShape
import org.nd4j.linalg.api.ndarray.INDArray; //导入方法依赖的package包/类
/**
 * Report the shape name indicated in the labels vector.
 *
 * <p>Scans the first {@code numClasses} entries and picks the shape whose ordinal is the index
 * of the first non-zero entry.
 *
 * @param labels the labels vector (1.0 for a shape, 0.0 for the others)
 * @return the shape name, or {@code null} if every scanned entry is zero
 */
private OmrShape getShape (INDArray labels)
{
    int c = 0;
    while (c < numClasses && labels.getDouble(c) == 0) {
        c++;
    }
    return (c < numClasses) ? OmrShape.values()[c] : null;
}
示例14: storeDims
import org.nd4j.linalg.api.ndarray.INDArray; //导入方法依赖的package包/类
/**
 * Store the mean/std values for width/height of each (populated) omr shape.
 *
 * <p>Builds a {@code 4 x SHAPE_COUNT} matrix with these rows:
 * 0 = mean width, 1 = std width, 2 = mean height, 3 = std height.
 * There is one column per shape ordinal, and columns of shapes absent from {@code dimMap} stay 0.
 * The method logs each populated shape's values and saves the matrix with
 * {@code Nd4j.saveBinary} to {@code DIMS_PATH}.
 *
 * @throws IOException on IO error
 */
private void storeDims ()
        throws IOException
{
    final INDArray dimStats = Nd4j.zeros(4, SHAPE_COUNT);
    logger.info("Symbol dimensions for populated shapes:");

    for (Entry<OmrShape, DistributionStats.Builder> entry : dimMap.entrySet()) {
        final OmrShape shape = entry.getKey();
        final DistributionStats stats = entry.getValue().build();
        final INDArray means = stats.getMean();
        final INDArray stds = stats.getStd();

        // Index 0 of the stats holds width, index 1 holds height.
        final double[] column = {
            means.getDouble(0), stds.getDouble(0), means.getDouble(1), stds.getDouble(1)
        };

        final int index = shape.ordinal();
        for (int row = 0; row < column.length; row++) {
            dimStats.putScalar(new int[]{row, index}, column[row]);
        }

        logger.info(
                String.format(
                        "%27s width{mean:%.2f std:%.2f} height{mean:%.2f std:%.2f}",
                        shape,
                        column[0],
                        column[1],
                        column[2],
                        column[3]));
    }

    Nd4j.saveBinary(dimStats, DIMS_PATH.toFile());
}
示例15: isTransitive
import org.nd4j.linalg.api.ndarray.INDArray; //导入方法依赖的package包/类
/**
 * Checks whether the symmetric relation encoded by an ILP solution is transitive.
 *
 * <p>{@code ilpSolution} holds the strict upper triangle of an {@code n x n} adjacency matrix in
 * row-major order: pairs (0,1), (0,2), ..., (n-2,n-1). The matrix is mirrored and its diagonal
 * is set to 1. The relation is reported as non-transitive as soon as some pair (i, j) has a
 * positive entry in {@code M * M} but a zero entry in {@code M}. For 0/1 entries this means a
 * two-step connection with no direct one. The matrix, and the first violating pair if any, are
 * printed to stdout.
 *
 * @param ilpSolution upper-triangle entries; must hold at least {@code n*(n-1)/2} values
 * @param n number of items
 * @return {@code true} if no violating pair is found
 */
private boolean isTransitive(double[] ilpSolution, int n) {
    double[][] adjacency = new double[n][n];
    int next = 0;
    for (int row = 0; row < n; row++) {
        adjacency[row][row] = 1;
        for (int col = row + 1; col < n; col++, next++) {
            double value = ilpSolution[next];
            adjacency[row][col] = value;
            adjacency[col][row] = value;
        }
    }

    INDArray m = new NDArray(adjacency);
    INDArray squared = m.mmul(m);
    System.out.println(m);

    for (int i = 0; i < m.rows(); i++) {
        for (int j = 0; j < m.columns(); j++) {
            boolean reachableInTwoSteps = squared.getDouble(i, j) > 0;
            if (!reachableInTwoSteps || m.getDouble(i, j) != 0) {
                continue;
            }
            System.out.println(i + " " + j + " " + squared.getDouble(i, j) + " " + m.getDouble(i, j));
            return false;
        }
    }
    return true;
}