

Java INDArray.mmul Method Code Examples

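This article collects typical usage examples of the org.nd4j.linalg.api.ndarray.INDArray.mmul method in Java, which computes the matrix product of two arrays. If you want to know how INDArray.mmul is used in practice, the selected examples below may help. You can also look further into usage examples of org.nd4j.linalg.api.ndarray.INDArray, the class this method belongs to.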


Four code examples of INDArray.mmul are shown below, sorted by popularity by default.
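Two INDArray calls are easy to confuse: `mmul` computes the matrix product, while `mul` multiplies element by element. The sketch below is plain Java, not ND4J code; the class `ProductSemantics` and its methods are hypothetical names that only mirror the semantics of the two ND4J calls for 2-D arrays of matching shape.

```java
import java.util.Arrays;

/** Hypothetical plain-Java illustration of INDArray.mmul versus INDArray.mul (not ND4J code). */
final class ProductSemantics {

    /** Matrix product: (m x k) times (k x n) gives (m x n), like a.mmul(b). */
    static double[][] mmul(double[][] a, double[][] b) {
        int m = a.length, k = a[0].length, n = b[0].length;
        if (b.length != k) {
            throw new IllegalArgumentException("inner dimensions differ: " + k + " vs " + b.length);
        }
        double[][] c = new double[m][n];
        for (int i = 0; i < m; i++) {
            for (int j = 0; j < n; j++) {
                double sum = 0.0;
                for (int p = 0; p < k; p++) {
                    sum += a[i][p] * b[p][j];
                }
                c[i][j] = sum;
            }
        }
        return c;
    }

    /** Element-wise product of two same-shaped matrices, like a.mul(b) without broadcasting. */
    static double[][] mul(double[][] a, double[][] b) {
        int m = a.length, n = a[0].length;
        if (b.length != m || b[0].length != n) {
            throw new IllegalArgumentException("shapes differ");
        }
        double[][] c = new double[m][n];
        for (int i = 0; i < m; i++) {
            for (int j = 0; j < n; j++) {
                c[i][j] = a[i][j] * b[i][j];
            }
        }
        return c;
    }

    public static void main(String[] args) {
        double[][] a = {{1, 2}, {3, 4}};
        double[][] b = {{5, 6}, {7, 8}};
        System.out.println(Arrays.deepToString(mmul(a, b))); // [[19.0, 22.0], [43.0, 50.0]]
        System.out.println(Arrays.deepToString(mul(a, b)));  // [[5.0, 12.0], [21.0, 32.0]]
    }
}
```

For same-shaped square operands both calls succeed and return different results, so a mix-up between them is not caught by any shape check.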

Example 1: nd4JExample

import org.nd4j.linalg.api.ndarray.INDArray; // package/class the method depends on
import org.nd4j.linalg.factory.Nd4j;
public void nd4JExample() {
    double[] A = {
        0.1950, 0.0311,
        0.3588, 0.2203,
        0.1716, 0.5931,
        0.2105, 0.3242};

    double[] B = {
        0.0502, 0.9823, 0.9472,
        0.5732, 0.2694, 0.916};

    
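    // Interpret the flat arrays in row-major ('c') order as a 4x2 and a 2x3 matrix
    INDArray aINDArray = Nd4j.create(A, new int[]{4, 2}, 'c');
    INDArray bINDArray = Nd4j.create(B, new int[]{2, 3}, 'c');

    // Matrix product: (4x2) mmul (2x3) gives a 4x3 result
    INDArray cINDArray = aINDArray.mmul(bINDArray);
    for (int i = 0; i < cINDArray.rows(); i++) {
        System.out.println(cINDArray.getRow(i));
    }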
}
 
Developer ID: PacktPublishing, Project: Machine-Learning-End-to-Endguide-for-Java-developers, Lines of code: 22, Source: MathExamples.java
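In Example 1, `Nd4j.create(A, new int[]{4,2}, 'c')` reads the flat array in row-major ('c') order, so `A[0]` and `A[1]` form the first row. The sketch below is a hypothetical plain-Java re-creation (no ND4J) that keeps the data in flat row-major buffers, where element (i, j) of an r x c matrix sits at index `i * c + j`, so the printed rows can be checked by hand. The class and method names are assumptions for illustration.

```java
import java.util.Arrays;

/** Hypothetical plain-Java re-creation of Example 1 using flat, row-major ('c' order) buffers. */
final class RowMajorProduct {

    /**
     * Multiplies an m x k matrix by a k x n matrix, both stored flat in row-major order,
     * and returns the m x n result in the same layout.
     */
    static double[] mmul(double[] a, double[] b, int m, int k, int n) {
        if (a.length != m * k || b.length != k * n) {
            throw new IllegalArgumentException("buffer lengths do not match the shapes");
        }
        double[] c = new double[m * n];
        for (int i = 0; i < m; i++) {
            for (int j = 0; j < n; j++) {
                double sum = 0.0;
                for (int p = 0; p < k; p++) {
                    sum += a[i * k + p] * b[p * n + j];
                }
                c[i * n + j] = sum;
            }
        }
        return c;
    }

    public static void main(String[] args) {
        double[] a = {0.1950, 0.0311, 0.3588, 0.2203, 0.1716, 0.5931, 0.2105, 0.3242}; // 4 x 2
        double[] b = {0.0502, 0.9823, 0.9472, 0.5732, 0.2694, 0.916};                   // 2 x 3
        double[] c = mmul(a, b, 4, 2, 3);                                                // 4 x 3
        for (int i = 0; i < 4; i++) {
            // Print one row at a time, like cINDArray.getRow(i) in Example 1
            System.out.println(Arrays.toString(Arrays.copyOfRange(c, i * 3, i * 3 + 3)));
        }
    }
}
```

The first entry of the result, for instance, is 0.1950 * 0.0502 + 0.0311 * 0.5732 = 0.02761552.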

Example 2: apply

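import org.nd4j.linalg.api.ndarray.INDArray; // package/class the method depends on
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.inverse.InvertMatrix;
// Distribution and SimpleDistribution are types from the algieba project.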
/**
 * Computes the filtered {@link GaussianDistribution} (mean and covariance) using the supplied
 * state-, measurement-, and control-transition matrices, and the supplied control input, process
 * noise, observation noise, and current state.
 *
 * @param stateTransitionMatrix       The state-transition matrix used in projecting the {@link
 *                                    GaussianDistribution}.
 * @param measurementTransitionMatrix The measurement-transition matrix used in computing the
 *                                    measurement based on the supplied state.
 * @param controlTransitionMatrix     The control-transition matrix used in computing the effect
 *                                    of the control input on the resultant {@link
 *                                    GaussianDistribution}.
 * @param controlInputVector          The control input.
 * @param processCovariance           The process variance matrix. This matrix determines the
 *                                    impact of the state transition matrix on the filter and can
 *                                    be seen as an indication of how trustworthy the state
 *                                    transition model is. Lower values in the process covariance
 *                                    matrix result in the filtered state tending towards the
 *                                    projected state using the state transition model. Higher
 *                                    values result in the filtered state tending towards the
 *                                    measurements.
 * @param measurement                 The {@link GaussianDistribution} of the measurement, where
 *                                    the mean is equal to the measurement vector and the
 *                                    covariance is equal to the measurement variance matrix.
 * @param previousState               The {@link GaussianDistribution} of the latest state.
 * @return The filtered state.
 */
public Distribution apply(
    final INDArray stateTransitionMatrix,
    final INDArray measurementTransitionMatrix,
    final INDArray controlTransitionMatrix,
    final INDArray controlInputVector,
    final INDArray processCovariance,
    final Distribution measurement,
    final Distribution previousState
) {
  // Predict: x' = F x + B u
  final INDArray projectedState = stateTransitionMatrix.mmul(previousState.getMean())
      .add(controlTransitionMatrix.mmul(controlInputVector));

  // Predict: P' = F P F^T + Q
  final INDArray projectedErrorCovariance =
      stateTransitionMatrix
          .mmul(previousState.getCovariance().mmul(stateTransitionMatrix.transpose()))
          .add(processCovariance);

  // Gain: K = P' H^T (H P' H^T + R)^-1; the second argument of invert is inPlace = false
  final INDArray kalmanGain = projectedErrorCovariance
      .mmul(measurementTransitionMatrix.transpose()
          .mmul(InvertMatrix.invert(measurementTransitionMatrix
              .mmul(projectedErrorCovariance
                  .mmul(measurementTransitionMatrix.transpose()))
              .add(measurement.getCovariance()), false)));

  // Update: x = x' + K (z - H x'), where z is the mean of the measurement distribution
  final INDArray estimatedState = projectedState
      .add(kalmanGain
          .mmul(measurement.getMean().sub(measurementTransitionMatrix.mmul(projectedState))));

  // Update: P = (I - K H) P'
  final INDArray estimatedErrorCovariance =
      (Nd4j.eye(estimatedState.rows()).sub(kalmanGain.mmul(measurementTransitionMatrix)))
          .mmul(projectedErrorCovariance);

  return new SimpleDistribution(estimatedState, estimatedErrorCovariance);
}
 
Developer ID: delta-leonis, Project: algieba, Lines of code: 62, Source: KalmanFilter.java
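With a one-dimensional state and measurement, every matrix in Example 2 becomes a scalar, `transpose()` is a no-op, and the matrix inverse becomes a division. The hypothetical plain-Java sketch below (not algieba code) follows the same five steps (projected state, projected covariance, gain, state update, covariance update) so the numbers can be traced by hand. The class `ScalarKalman` and its parameter names are assumptions for illustration.

```java
/** Hypothetical scalar version of the linear Kalman filter step in Example 2 (not algieba code). */
final class ScalarKalman {

    /** Filtered estimate: mean and variance. */
    static final class Estimate {
        final double mean;
        final double variance;

        Estimate(double mean, double variance) {
            this.mean = mean;
            this.variance = variance;
        }
    }

    /**
     * f: state transition, h: measurement transition, b: control transition,
     * u: control input, q: process variance, z: measurement, r: measurement variance.
     */
    static Estimate step(double f, double h, double b, double u, double q,
                         double z, double r, Estimate previous) {
        double projectedMean = f * previous.mean + b * u;                      // F x + B u
        double projectedVariance = f * previous.variance * f + q;              // F P F^T + Q
        double gain = projectedVariance * h / (h * projectedVariance * h + r); // P H^T (H P H^T + R)^-1
        double mean = projectedMean + gain * (z - h * projectedMean);          // x + K (z - H x)
        double variance = (1 - gain * h) * projectedVariance;                  // (I - K H) P
        return new Estimate(mean, variance);
    }

    public static void main(String[] args) {
        Estimate e = step(1, 1, 0, 0, 0, 2, 1, new Estimate(0, 1));
        System.out.println(e.mean + " " + e.variance); // 1.0 0.5
    }
}
```

Raising `q` from 0 to 3 in the same call raises the gain from 0.5 to 0.8 and moves the estimate from 1.0 to about 1.6, closer to the measurement z = 2, which is the behavior the Javadoc describes for higher process covariance.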

Example 3: apply

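import org.nd4j.linalg.api.ndarray.INDArray; // package/class the method depends on
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.inverse.InvertMatrix;
import java.util.function.BiFunction;
import java.util.function.Function;
// Distribution and SimpleDistribution are types from the algieba project.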
/**
 * Computes the filtered {@link Distribution} (mean and covariance) using the supplied state- and
 * measurement-transition functions, their respective Jacobians and the supplied control input,
 * process noise, observation noise, and current state.
 *
 * @param stateTransition               A {@link BiFunction} which takes as its first argument the
 *                                      previous state, and the control input as its second
 *                                      argument. The result is the next state.
 * @param measurementTransition         A {@link Function} which takes as its first argument the
 *                                      current state and returns the measurement.
 * @param stateTransitionJacobian       The Jacobian representing the state-transition.
 * @param measurementTransitionJacobian The Jacobian representing the measurement-transition.
 * @param controlInput                  The control input.
 * @param processCovariance             The process covariance.
 * @param observationNoise              The {@link Distribution} of observation noise.
 * @param state                         The {@link Distribution} of the latest state.
 * @return The filtered state.
 */
public Distribution apply(
    final BiFunction<INDArray, INDArray, INDArray> stateTransition,
    final Function<INDArray, INDArray> measurementTransition,
    final INDArray stateTransitionJacobian,
    final INDArray measurementTransitionJacobian,
    final INDArray controlInput,
    final INDArray processCovariance,
    final Distribution observationNoise,
    final Distribution state
) {
  final INDArray projectedState = stateTransition.apply(state.getMean(), controlInput);

  // Predict: P' = F P F^T + Q, propagating the previous state's covariance
  final INDArray projectedErrorCovariance = stateTransitionJacobian
      .mmul(state.getCovariance()
          .mmul(stateTransitionJacobian.transpose()))
      .add(processCovariance);

  final INDArray kalmanGain = projectedErrorCovariance
      .mmul(measurementTransitionJacobian.transpose()
          .mmul(InvertMatrix.invert(measurementTransitionJacobian
              .mmul(projectedErrorCovariance
                  .mmul(measurementTransitionJacobian.transpose()))
              .add(observationNoise.getCovariance()), false)));

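  // Update: x = x' + K * innovation. The textbook EKF innovation is z - h(x'), with z the
  // measured vector; this signature has no separate measurement argument, and the expression
  // below uses h(previous mean) - h(projected mean) in its place.
  final INDArray estimatedState = projectedState
      .add(kalmanGain
          .mmul(measurementTransition.apply(state.getMean())
              .sub(measurementTransition.apply(projectedState))));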

  // Update: P = (I - K H) P'; K H must be a matrix product (mmul), not element-wise mul
  final INDArray estimatedErrorCovariance = Nd4j.eye(estimatedState.rows())
      .sub(kalmanGain.mmul(measurementTransitionJacobian))
      .mmul(projectedErrorCovariance);

  return new SimpleDistribution(estimatedState, estimatedErrorCovariance);
}
 
Developer ID: delta-leonis, Project: algieba, Lines of code: 55, Source: ExtendedKalmanFilter.java
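For comparison with Example 3, here is a hypothetical scalar sketch of one extended Kalman filter step in plain Java (not algieba code). It follows the textbook EKF equations: the projected variance propagates the previous variance through the state-transition Jacobian, the Jacobians are evaluated at the previous and the projected mean respectively, and the innovation is the measurement z minus h of the projected mean. Only `DoubleBinaryOperator` and `DoubleUnaryOperator` come from `java.util.function`; the class, method, and parameter names are assumptions.

```java
import java.util.function.DoubleBinaryOperator;
import java.util.function.DoubleUnaryOperator;

/** Hypothetical scalar extended Kalman filter step (not algieba code). */
final class ScalarExtendedKalman {

    /** Returns {mean, variance} of the filtered state. */
    static double[] step(DoubleBinaryOperator f, DoubleUnaryOperator fJacobian,
                         DoubleUnaryOperator h, DoubleUnaryOperator hJacobian,
                         double u, double q, double z, double r,
                         double mean, double variance) {
        double projectedMean = f.applyAsDouble(mean, u);        // x' = f(x, u)
        double fx = fJacobian.applyAsDouble(mean);              // F evaluated at the previous mean
        double projectedVariance = fx * variance * fx + q;      // P' = F P F^T + Q
        double hx = hJacobian.applyAsDouble(projectedMean);     // H evaluated at the projected mean
        double gain = projectedVariance * hx / (hx * projectedVariance * hx + r);
        double innovation = z - h.applyAsDouble(projectedMean); // z - h(x')
        return new double[] {
            projectedMean + gain * innovation,                  // x = x' + K (z - h(x'))
            (1 - gain * hx) * projectedVariance                 // P = (I - K H) P'
        };
    }

    public static void main(String[] args) {
        // f(x, u) = x + u with Jacobian 1; h(x) = x^2 with Jacobian 2x
        double[] result = step((x, u) -> x + u, x -> 1.0, x -> x * x, x -> 2 * x,
                0.0, 0.0, 2.0, 1.0, 1.0, 1.0);
        System.out.println(result[0] + " " + result[1]);
    }
}
```

With these inputs the gain is 2 / 5 = 0.4, giving a mean of about 1.4 and a variance of about 0.2. With h(x) = x and a Jacobian of 1, the step reduces to the linear Kalman update.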

Example 4: isTransitive

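import org.nd4j.linalg.api.ndarray.INDArray; // package/class the method depends on
import org.nd4j.linalg.factory.Nd4j;

/**
 * Returns true if the pairwise relation encoded in ilpSolution (the upper triangle i < j of an
 * n x n 0/1 matrix, stored row by row) is transitive.
 */
private boolean isTransitive(double[] ilpSolution, int n) {
    // Build a symmetric adjacency matrix with self-loops on the diagonal
    double[][] adjacencyMatrix = new double[n][n];
    int k = 0;
    for (int i = 0; i < n; i++) {
        adjacencyMatrix[i][i] = 1;
        for (int j = i + 1; j < n; j++) {
            adjacencyMatrix[i][j] = ilpSolution[k];
            adjacencyMatrix[j][i] = ilpSolution[k];
            k++;
        }
    }

    // Nd4j.create works with whichever ND4J backend is on the classpath
    INDArray m = Nd4j.create(adjacencyMatrix);
    // (m x m)[i][j] > 0 iff i and j are linked by a path of length <= 2 (the diagonal covers length 1)
    INDArray m2 = m.mmul(m);

    System.out.println(m);

    // Transitive iff every pair reachable in at most two steps is also directly linked
    for (int i = 0; i < m.rows(); i++) {
        for (int j = 0; j < m.columns(); j++) {
            if (m2.getDouble(i, j) > 0 && m.getDouble(i, j) == 0) {
                System.out.println(i + " " + j + " " + m2.getDouble(i, j) + " " + m.getDouble(i, j));
                return false;
            }
        }
    }

    return true;
}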
 
Developer ID: UKPLab, Project: ijcnlp2017-cmaps, Lines of code: 30, Source: ILPClusterer_Cplex.java
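Example 4 relies on a reflexive adjacency matrix M: with ones on the diagonal and non-negative entries, an entry of M·M is positive exactly when i and j are joined by a path of at most two edges, so the relation is transitive iff every such pair is also joined directly. The dependency-free sketch below applies the same test with boolean matrices; instead of forming M·M it looks for a two-step path at each missing entry, which is the same condition. The class name is hypothetical, and the upper-triangle encoding is assumed to match the one `isTransitive` reads.

```java
/** Hypothetical boolean-matrix version of the transitivity test in Example 4. */
final class TransitivityCheck {

    /** Builds a reflexive, symmetric adjacency matrix from the upper triangle, read row by row. */
    static boolean[][] adjacency(boolean[] upperTriangle, int n) {
        if (upperTriangle.length != n * (n - 1) / 2) {
            throw new IllegalArgumentException("expected " + n * (n - 1) / 2 + " entries");
        }
        boolean[][] m = new boolean[n][n];
        int k = 0;
        for (int i = 0; i < n; i++) {
            m[i][i] = true;
            for (int j = i + 1; j < n; j++) {
                m[i][j] = upperTriangle[k];
                m[j][i] = upperTriangle[k];
                k++;
            }
        }
        return m;
    }

    /** True if every pair joined by a path of length at most two is also joined directly. */
    static boolean isTransitive(boolean[][] m) {
        int n = m.length;
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                if (m[i][j]) {
                    continue;
                }
                for (int p = 0; p < n; p++) {
                    if (m[i][p] && m[p][j]) {
                        return false; // (M*M)[i][j] > 0 but M[i][j] == 0
                    }
                }
            }
        }
        return true;
    }

    public static void main(String[] args) {
        // n = 3, upper-triangle entries in the order (0,1), (0,2), (1,2)
        System.out.println(isTransitive(adjacency(new boolean[] {true, false, false}, 3))); // true
        System.out.println(isTransitive(adjacency(new boolean[] {true, false, true}, 3)));  // false
    }
}
```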


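Note: The org.nd4j.linalg.api.ndarray.INDArray.mmul examples in this article were compiled by 纯净天空 from open-source code and documentation platforms such as GitHub and MSDocs; the snippets are taken from open-source projects contributed by various developers. Copyright of the source code belongs to the original authors; for distribution and use, refer to the License of the corresponding project. Do not reproduce without permission.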