当前位置: 首页>>代码示例>>Java>>正文


Java INDArray.getDouble方法代码示例

本文整理汇总了Java中org.nd4j.linalg.api.ndarray.INDArray.getDouble方法的典型用法代码示例。如果您正苦于以下问题：Java INDArray.getDouble方法的具体用法？Java INDArray.getDouble怎么用？Java INDArray.getDouble使用的例子？那么，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在类org.nd4j.linalg.api.ndarray.INDArray的用法示例。


在下文中一共展示了INDArray.getDouble方法的15个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Java代码示例。

示例1: validate

import org.nd4j.linalg.api.ndarray.INDArray; //导入方法依赖的package包/类
/**
 * Feeds every sample of {@code trainingDataSet} through the network and prints how many were
 * classified correctly ("Positive") and incorrectly ("Negative").
 *
 * <p>The predicted class is the index of the largest network output; on ties the lowest index
 * wins, so a valid index is always produced. A sample counts as positive when the expected
 * output at the predicted index equals 1 (presumably one-hot labels — confirm against the
 * training data).
 */
public void validate() {
    int p = 0;
    int n = 0;
    for (TrainingData td : trainingDataSet) {
        INDArray a = feedForward(td.input);
        // First-maximum argmax: strict pairwise comparisons would leave no winner on a tie
        // and lead to reading index -1 of the expected output.
        int m = 0;
        double best = a.getDouble(0);
        for (int i = 1; i < a.length(); i++) {
            double value = a.getDouble(i);
            if (value > best) {
                best = value;
                m = i;
            }
        }
        if (td.output.getDouble(m) == 1) {
            p++;
        } else {
            n++;
        }
    }
    System.out.println("Positive: " + p);
    System.out.println("Negative: " + n);
}
 
开发者ID:apuder,项目名称:ActivityMonitor,代码行数:22,代码来源:SGD.java

示例2: predict

import org.nd4j.linalg.api.ndarray.INDArray; //导入方法依赖的package包/类
/**
 * Converts the network output for every row of {@code x} into a one-hot matrix.
 *
 * <p>For each row, the column with the largest activation (first one on ties) is set to 1.0 and
 * all other columns stay 0.0.
 *
 * @param x input samples, one per row
 * @return a {@code x.rows() x nOut} matrix holding one 1.0 per row
 */
public INDArray predict(INDArray x) {
    INDArray y = output(x); // activate input data through learned networks
    int rows = x.rows();
    INDArray out = Nd4j.create(new double[rows * nOut], new int[] { rows, nOut });
    for (int i = 0; i < rows; i++) {
        // Seed with column 0 so argmax is always a valid index. Double.MIN_VALUE is the smallest
        // positive double, so seeding with it left argmax at -1 whenever all outputs were <= 0.
        int argmax = 0;
        double max = y.getDouble(i, 0);
        for (int j = 1; j < nOut; j++) {
            double value = y.getDouble(i, j);
            if (value > max) {
                argmax = j;
                max = value;
            }
        }
        out.put(i, argmax, Nd4j.scalar(1.0));
    }
    return out;
}
 
开发者ID:IsaacChanghau,项目名称:NeuralNetworksLite,代码行数:17,代码来源:OutputLayer.java

示例3: getTopN

import org.nd4j.linalg.api.ndarray.INDArray; //导入方法依赖的package包/类
/**
 * Get the indices of the top N elements of a vector.
 *
 * <p>Pairs of {value, index} are ranked with {@link BasicModelUtils.ArrayComparator}
 * (presumably by value — confirm in BasicModelUtils); the result is ordered from the
 * highest-ranked pair to the lowest.
 *
 * @param vec the vec to extract the top elements from
 * @param N the number of elements to extract; values {@code <= 0} yield an empty list
 * @return the indices (stored as doubles) of the top N elements, highest-ranked first
 */
private List<Double> getTopN(INDArray vec, int N) {
  List<Double> lowToHighSimLst = new ArrayList<>();
  if (N <= 0) {
    // Nothing to select; also avoids comparing against the head of an empty queue.
    return lowToHighSimLst;
  }

  BasicModelUtils.ArrayComparator comparator = new BasicModelUtils.ArrayComparator();
  // The queue never holds more than N entries. PriorityQueue rejects a capacity below 1, which
  // vec.rows() could produce for an empty array, so bound the hint to [1, min(N, length)].
  int capacity = Math.max(1, Math.min(N, (int) vec.length()));
  PriorityQueue<Double[]> queue = new PriorityQueue<>(capacity, comparator);

  for (int j = 0; j < vec.length(); j++) {
    final Double[] pair = new Double[] {vec.getDouble(j), (double) j};
    if (queue.size() < N) {
      queue.add(pair);
    } else if (comparator.compare(pair, queue.peek()) > 0) {
      // Replace the current lowest-ranked of the kept pairs.
      queue.poll();
      queue.add(pair);
    }
  }

  // Polling yields pairs from lowest to highest rank; reverse for highest-first.
  while (!queue.isEmpty()) {
    double ind = queue.poll()[1];
    lowToHighSimLst.add(ind);
  }
  return Lists.reverse(lowToHighSimLst);
}
 
开发者ID:tteofili,项目名称:par2hier,代码行数:33,代码来源:Par2Hier.java

示例4: clipMatrix

import org.nd4j.linalg.api.ndarray.INDArray; //导入方法依赖的package包/类
/**
 * Clamps every element of {@code matrix} into the closed range [-5, 5], in place, to avoid
 * exploding gradients. NaN elements are left unchanged.
 *
 * @param matrix the array to clip; it is modified in place
 * @return the same {@code matrix} instance, after clipping
 */
private static INDArray clipMatrix(INDArray matrix) {
    for (NdIndexIterator it = new NdIndexIterator(matrix.shape()); it.hasNext(); ) {
        int[] position = it.next();
        double clamped = Math.max(-5, Math.min(5, matrix.getDouble(position)));
        matrix.putScalar(position, clamped);
    }
    return matrix;
}
 
开发者ID:guilherme-pombo,项目名称:JavaRNN,代码行数:21,代码来源:CharRNN.java

示例5: run

import org.nd4j.linalg.api.ndarray.INDArray; //导入方法依赖的package包/类
/**
 * Transforms {@code input} with the configured transformer, evaluates the model on it and
 * collects the non-zero outputs keyed by label.
 *
 * @param kc container that provides the serialized model binary
 * @param model model description; its params must include "transformerName"
 * @param input the raw input to classify
 * @return a result whose text is the raw output array and whose predictions map each label to
 *     its non-zero output value
 * @throws IllegalArgumentException if the model has no parameters
 * @throws java.io.UncheckedIOException if closing the model stream fails
 */
public Result run(KieMLContainer kc, Model model, Input input) {
	List<ModelParam> params = model.getParams();
	if(params == null) {
		throw new IllegalArgumentException("Parameters to configure the input parsing are required!!");
	}
	String transformerName = ParamsUtil.getRequiredStringParam(params, "transformerName");
	Transformer transformer = TransformerFactory.get(transformerName);
	INDArray image = transformer.transform(params, input);
	INDArray output;
	// Close the model stream once the network output is computed, so it is not leaked.
	try (InputStream isModel = kc.getModelBinInputStream(model)) {
		output = getOutput(isModel, image);
	} catch (java.io.IOException e) {
		throw new java.io.UncheckedIOException("Failed to close the model input stream", e);
	}
	Result prediction = new Result();
	prediction.setText(output.toString());
	prediction.setPredictions(new HashMap<>());
	for (int i = 0; i < output.columns(); i++) {
		double value = output.getDouble(i);
		if(value == 0d) {
			continue;
		}
		prediction.getPredictions().put(model.getLabels().get(i), value);
	}
	return prediction;
}
 
开发者ID:jesuino,项目名称:kie-ml,代码行数:23,代码来源:DL4JKieMLProvider.java

示例6: doPredict

import org.nd4j.linalg.api.ndarray.INDArray; //导入方法依赖的package包/类
/**
 * Runs the model on a single input record, given as a list of string fields, and converts the
 * raw network output according to the declared type of the output column.
 *
 * @param line the field values of one input record
 * @return for a {@code _float} output, the first network output (a {@link Double}); for a
 *     {@code _class} output, the index (an {@link Integer}) of the highest-scoring output, or
 *     -1 if no output is comparable (empty or all NaN)
 * @throws RuntimeException wrapping any failure while reading the record or evaluating the
 *     model, including an {@link IllegalArgumentException} for unsupported output types
 */
@Override
    protected Object doPredict(List<String> line) {
        try {
            ListStringSplit input = new ListStringSplit(Collections.singletonList(line));
            ListStringRecordReader rr = new ListStringRecordReader();
            rr.initialize(input);
            DataSetIterator iterator = new RecordReaderDataSetIterator(rr, 1);

            DataSet ds = iterator.next();
            INDArray prediction = model.output(ds.getFeatures());

            DataType outputType = types.get(this.output);
            switch (outputType) {
                case _float : return prediction.getDouble(0);
                case _class: {
                    // Arg-max over every score the model emits rather than a hard-coded two
                    // classes; seed with -infinity so zero or negative scores can still win.
                    double max = Double.NEGATIVE_INFINITY;
                    int maxIndex = -1;
                    for (int i = 0; i < prediction.length(); i++) {
                        double score = prediction.getDouble(i);
                        if (score > max) {
                            maxIndex = i;
                            max = score;
                        }
                    }
                    return maxIndex;
                }
                default: throw new IllegalArgumentException("Output type not yet supported "+outputType);
            }
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }
 
开发者ID:neo4j-contrib,项目名称:neo4j-ml-procedures,代码行数:31,代码来源:DL4JMLModel.java

示例7: check

import org.nd4j.linalg.api.ndarray.INDArray; //导入方法依赖的package包/类
/**
 * Classifies {@code image} with {@code MainGUI.model} and shows the per-label scores, as
 * percentages with two decimals, in {@code probabilityArea}.
 *
 * <p>The image is written to "tmp.png" and read back through an ImageRecordReader (150x150,
 * 3 channels). Each batch is min-max normalized with statistics fitted on that batch itself.
 *
 * @param image the image to classify
 * @throws Exception if writing, reading or evaluating the image fails
 */
private void check(BufferedImage image) throws Exception
{
    ImageIO.write(image, "png", new File("tmp.png")); //saves the image to the tmp.png file
    ImageRecordReader reader = new ImageRecordReader(150, 150, 3);
    try
    {
        reader.initialize(new FileSplit(new File("tmp.png"))); //reads the tmp.png file
        DataSetIterator dataIter = new RecordReaderDataSetIterator(reader, 1);

        // "0.00" always prints a leading digit (0.50, 1.00, -0.50). The previous "#.00" plus a
        // manual "0" prefix produced "01.00" for 0.995..1 % and "0-.50" for negatives.
        DecimalFormat df = new DecimalFormat("0.00");

        while (dataIter.hasNext())
        {
            //Normalize the data from the file
            DataNormalization normalization = new NormalizerMinMaxScaler();
            DataSet set = dataIter.next();
            normalization.fit(set);
            normalization.transform(set);

            INDArray array = MainGUI.model.output(set.getFeatures(), false); //send the data to the model and get the results

            //Process the results and print them in an understandable format (percentage scores)
            StringBuilder txt = new StringBuilder();
            for (int i = 0; i < array.length(); i++)
            {
                txt.append(MainGUI.labels.get(i))
                   .append(": ")
                   .append(df.format(array.getDouble(i) * 100))
                   .append("%\n");
            }

            probabilityArea.setText(txt.toString());
        }
    }
    finally
    {
        // Release the reader even when initialization or evaluation throws.
        reader.close();
    }
}
 
开发者ID:maksgraczyk,项目名称:DeepID,代码行数:32,代码来源:Identification.java

示例8: computeLineIntegral

import org.nd4j.linalg.api.ndarray.INDArray; //导入方法依赖的package包/类
/**
 * Returns an antiderivative, with respect to the line parameter {@code t}, of
 * {@code ln((A*t + B)^2 + (C*t + D)^2)} along the straight line from {@code lowerBound} to
 * {@code upperBound}. See <a href="http://www.wolframalpha.com/input/?i=integral+of+ln((A*t%2BB)%5E2+%2B+(C*t%2BD)%5E2)">
 * this equation.</a>
 *
 * <p>The line is parametrized as {@code x(t) = A*t + B}, {@code y(t) = C*t + D}, with
 * {@code A = upper[0,0] - lower[0,0]}, {@code B = lower[0,0]},
 * {@code C = upper[1,0] - lower[1,0]} and {@code D = lower[1,0]}. So {@code t = 0} is
 * {@code lowerBound} and {@code t = 1} is {@code upperBound}. The bounds are read each time the
 * returned operator is applied, not when it is created.
 *
 * <p>NOTE(review): if {@code lowerBound} equals {@code upperBound} then {@code A == C == 0}, the
 * denominator is 0 and the result is NaN or infinite.
 *
 * @param lowerBound The starting point of the line along which to calculate the unscaled value
 *                   of the parametrized line integral (rows 0 and 1 of column 0 are read,
 *                   presumably x and y — confirm).
 * @param upperBound The final point of the line along which to calculate the unscaled value
 *                   of the parametrized line integral (same layout as {@code lowerBound}).
 * @return A {@link UnaryOperator} mapping the line parameter {@code t} to the value of the
 *     indefinite integral of the parametrized line integral at {@code t}.
 */
@Override
public UnaryOperator<Double> computeLineIntegral(
    final INDArray lowerBound,
    final INDArray upperBound
) {
  return input -> {
    // compute parametric coefficients: x(t) = A*t + B, y(t) = C*t + D, with t = input
    final double A = upperBound.getDouble(0, 0) - lowerBound.getDouble(0, 0);
    final double B = lowerBound.getDouble(0, 0);
    final double C = upperBound.getDouble(1, 0) - lowerBound.getDouble(1, 0);
    final double D = lowerBound.getDouble(1, 0);

    // compute denominator (squared length of the segment; zero for a degenerate segment)
    final double denominator = Math.pow(A, 2) + Math.pow(C, 2);

    // compute numerator terms
    // the log argument below is (A*t + B)^2 + (C*t + D)^2 expanded
    final double logarithmicTerm =
        (Math.pow(A, 2) * input + A * B + C * (C * input + D))
            * Math.log(Math.pow(input, 2) * (Math.pow(A, 2) + Math.pow(C, 2))
            + 2 * A * B * input
            + Math.pow(B, 2)
            + 2 * C * D * input
            + Math.pow(D, 2));
    final double inverseTangentTerm =
        2 * (A * D - B * C)
            * Math.atan((B * C - A * D) / (Math.pow(A, 2) * input + A * B + C * (C * input + D)));
    final double linearTerm = -2 * input * (Math.pow(A, 2) + Math.pow(C, 2));

    return (logarithmicTerm + inverseTangentTerm + linearTerm) / denominator;
  };
}
 
开发者ID:delta-leonis,项目名称:algieba,代码行数:43,代码来源:HydrodynamicPotentialField.java

示例9: train

import org.nd4j.linalg.api.ndarray.INDArray; //导入方法依赖的package包/类
/**
 * Performs one perceptron step for a single sample.
 *
 * <p>If the sign of {@code x . w} agrees with the target (strictly positive margin), the weights
 * are left untouched. Otherwise {@code w} is updated in place by {@code learningRate * t * x}.
 *
 * @param x the input sample (row vector)
 * @param t the target label
 * @param learningRate step size of the weight update
 * @return 1 if the sample was already classified correctly, 0 if an update was applied
 */
public int train(INDArray x, INDArray t, double learningRate) {
	double activation = x.mmul(w.transpose()).getDouble(0);
	double margin = activation * t.getDouble(0);
	if (margin > 0) {
		return 1;
	}
	// Misclassified (or on the decision boundary): nudge the weights towards the target.
	INDArray delta = x.transpose().mul(t).mul(learningRate).transpose();
	w.addi(delta);
	return 0;
}
 
开发者ID:IsaacChanghau,项目名称:NeuralNetworksLite,代码行数:8,代码来源:Perceptron.java

示例10: axpy

import org.nd4j.linalg.api.ndarray.INDArray; //导入方法依赖的package包/类
/**
 * Accumulates the matrix-vector product into {@code y}: {@code y += A * x}.
 *
 * @param x input vector of length {@code ncols()}
 * @param y accumulator of length at least {@code nrows()}; updated in place
 * @return {@code y}
 */
@Override
public double[] axpy(double[] x, double[] y) {
    // Nd4j.getBlasWrapper().level2().gemv() crashes, so the product is computed with gemm
    // against x reshaped to an (ncols x 1) column.
    int rowCount = nrows();
    INDArray column = Nd4j.create(x, new int[]{ncols(), 1});
    INDArray product = Nd4j.gemm(A, column, false, false);
    for (int row = 0; row < rowCount; row++) {
        y[row] += product.getDouble(row);
    }
    return y;
}
 
开发者ID:takun2s,项目名称:smile_1.5.0_java7,代码行数:15,代码来源:NDMatrix.java

示例11: atx

import org.nd4j.linalg.api.ndarray.INDArray; //导入方法依赖的package包/类
/**
 * Computes the transposed product {@code y = A' * x}, overwriting {@code y}.
 *
 * @param x input vector of length {@code nrows()}
 * @param y output buffer of length at least {@code ncols()}
 * @return {@code y}
 */
@Override
public double[] atx(double[] x, double[] y) {
    // Nd4j.getBlasWrapper().level2().gemv() crashes, so use gemm with A transposed
    // against x reshaped to an (nrows x 1) column.
    int rowCount = nrows();
    int colCount = ncols();
    INDArray product = Nd4j.gemm(A, Nd4j.create(x, new int[]{rowCount, 1}), true, false);
    int k = 0;
    while (k < colCount) {
        y[k] = product.getDouble(k);
        k++;
    }
    return y;
}
 
开发者ID:takun2s,项目名称:smile_1.5.0_java7,代码行数:15,代码来源:NDMatrix.java

示例12: atxpy

import org.nd4j.linalg.api.ndarray.INDArray; //导入方法依赖的package包/类
/**
 * Accumulates the transposed product into {@code y}: {@code y += A' * x}.
 *
 * @param x input vector of length {@code nrows()}
 * @param y accumulator of length at least {@code ncols()}; updated in place
 * @return {@code y}
 */
@Override
public double[] atxpy(double[] x, double[] y) {
    // Nd4j.getBlasWrapper().level2().gemv() crashes, so use gemm with A transposed
    // against x reshaped to an (nrows x 1) column.
    int rowCount = nrows();
    int colCount = ncols();
    INDArray product = Nd4j.gemm(A, Nd4j.create(x, new int[]{rowCount, 1}), true, false);
    int k = 0;
    while (k < colCount) {
        y[k] += product.getDouble(k);
        k++;
    }
    return y;
}
 
开发者ID:takun2s,项目名称:smile_1.5.0_java7,代码行数:15,代码来源:NDMatrix.java

示例13: getShape

import org.nd4j.linalg.api.ndarray.INDArray; //导入方法依赖的package包/类
/**
 * Report the shape name indicated in the labels vector: the first of the {@code numClasses}
 * entries that is non-zero selects the {@link OmrShape} with the same ordinal.
 *
 * @param labels the labels vector (1.0 for a shape, 0.0 for the others)
 * @return the shape name, or null if every inspected entry is zero
 */
private OmrShape getShape (INDArray labels)
{
    int c = 0;
    while (c < numClasses) {
        if (labels.getDouble(c) != 0) {
            return OmrShape.values()[c];
        }
        c++;
    }

    return null;
}
 
开发者ID:Audiveris,项目名称:omr-dataset-tools,代码行数:19,代码来源:SubImages.java

示例14: storeDims

import org.nd4j.linalg.api.ndarray.INDArray; //导入方法依赖的package包/类
/**
 * Store the mean/std values for width/height of each (populated) omr shape.
 *
 * <p>Builds a {@code 4 x SHAPE_COUNT} matrix whose column is the shape ordinal and whose rows
 * are mean width, std width, mean height and std height. Each populated shape is logged, and the
 * matrix is saved to {@code DIMS_PATH}.
 *
 * @throws IOException on IO error
 */
private void storeDims ()
        throws IOException
{
    INDArray dimStats = Nd4j.zeros(4, SHAPE_COUNT);
    logger.info("Symbol dimensions for populated shapes:");

    for (Entry<OmrShape, DistributionStats.Builder> entry : dimMap.entrySet()) {
        OmrShape shape = entry.getKey();
        DistributionStats stats = entry.getValue().build();
        INDArray means = stats.getMean();
        INDArray stds = stats.getStd();

        // Row order: mean width, std width, mean height, std height.
        double[] dims = {
            means.getDouble(0), stds.getDouble(0), means.getDouble(1), stds.getDouble(1)
        };
        int col = shape.ordinal();
        for (int row = 0; row < dims.length; row++) {
            dimStats.putScalar(new int[]{row, col}, dims[row]);
        }

        String summary = String.format(
                "%27s width{mean:%.2f std:%.2f} height{mean:%.2f std:%.2f}",
                shape,
                dims[0],
                dims[1],
                dims[2],
                dims[3]);
        logger.info(summary);
    }

    Nd4j.saveBinary(dimStats, DIMS_PATH.toFile());
}
 
开发者ID:Audiveris,项目名称:omr-dataset-tools,代码行数:39,代码来源:Features.java

示例15: isTransitive

import org.nd4j.linalg.api.ndarray.INDArray; //导入方法依赖的package包/类
/**
 * Checks whether the clustering encoded by an ILP solution is transitive.
 *
 * <p>{@code ilpSolution} holds the strict upper triangle of an {@code n x n} symmetric matrix,
 * row by row, and the diagonal is taken as 1. The relation is reported non-transitive when some
 * pair (i, j) is reachable in two steps ({@code (M*M)[i][j] > 0}) but has no direct entry
 * ({@code M[i][j] == 0}). The matrix and the first offending pair are printed to stdout.
 *
 * @param ilpSolution upper-triangle values, at least {@code n*(n-1)/2} entries
 * @param n number of items
 * @return true if no offending pair exists
 */
private boolean isTransitive(double[] ilpSolution, int n) {
	// Unpack the upper triangle into a symmetric adjacency matrix with a unit diagonal.
	double[][] adjacency = new double[n][n];
	int pos = 0;
	for (int row = 0; row < n; row++) {
		adjacency[row][row] = 1;
		for (int col = row + 1; col < n; col++) {
			double value = ilpSolution[pos++];
			adjacency[row][col] = value;
			adjacency[col][row] = value;
		}
	}

	INDArray m = new NDArray(adjacency);
	INDArray m2 = m.mmul(m);

	System.out.println(m);

	for (int i = 0; i < m.rows(); i++) {
		for (int j = 0; j < m.columns(); j++) {
			double twoStep = m2.getDouble(i, j);
			double direct = m.getDouble(i, j);
			if (twoStep > 0 && direct == 0) {
				System.out.println(i + " " + j + " " + twoStep + " " + direct);
				return false;
			}
		}
	}

	return true;
}
 
开发者ID:UKPLab,项目名称:ijcnlp2017-cmaps,代码行数:30,代码来源:ILPClusterer_Cplex.java


注：本文中的org.nd4j.linalg.api.ndarray.INDArray.getDouble方法示例由纯净天空整理自GitHub/MSDocs等开源代码及文档管理平台，相关代码片段筛选自各路编程大神贡献的开源项目，源码版权归原作者所有，传播和使用请参考对应项目的License；未经允许，请勿转载。