本文整理汇总了Java中org.nd4j.linalg.api.ndarray.INDArray.add方法的典型用法代码示例。如果您正苦于以下问题：Java INDArray.add方法的具体用法？Java INDArray.add怎么用？Java INDArray.add使用的例子？那么，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在类org.nd4j.linalg.api.ndarray.INDArray的用法示例。
在下文中一共展示了INDArray.add方法的3个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Java代码示例。
示例1: apply
import org.nd4j.linalg.api.ndarray.INDArray; //导入方法依赖的package包/类
/**
 * Performs a single predict-and-correct step of a linear Kalman filter.
 *
 * <p>Prediction: {@code x' = F x + B u} and {@code P' = F (P F^T) + Q}. Correction:
 * {@code K = P' (H^T (H (P' H^T) + R)^-1)}, {@code x = x' + K (z - H x')} and
 * {@code P = (I - K H) P'}, where {@code z} and {@code R} are the mean and covariance of
 * {@code measurement}.
 *
 * @param stateTransitionMatrix       The state-transition matrix {@code F}, used to project the
 *                                    previous state and its covariance.
 * @param measurementTransitionMatrix The measurement matrix {@code H}, mapping a state onto the
 *                                    expected measurement.
 * @param controlTransitionMatrix     The control matrix {@code B}, mapping the control input
 *                                    onto the state.
 * @param controlInputVector          The control input {@code u}.
 * @param processCovariance           The process covariance {@code Q}, added to the projected
 *                                    error covariance. It expresses how far the
 *                                    state-transition model is trusted: lower values pull the
 *                                    filtered state towards the model's projection, higher
 *                                    values pull it towards the measurements.
 * @param measurement                 The measurement distribution: its mean is the measurement
 *                                    vector {@code z} and its covariance is the measurement
 *                                    noise matrix {@code R}.
 * @param previousState               The distribution (mean {@code x}, covariance {@code P}) of
 *                                    the latest state.
 * @return A new {@link SimpleDistribution} holding the filtered state and its covariance.
 */
public Distribution apply(
        final INDArray stateTransitionMatrix,
        final INDArray measurementTransitionMatrix,
        final INDArray controlTransitionMatrix,
        final INDArray controlInputVector,
        final INDArray processCovariance,
        final Distribution measurement,
        final Distribution previousState
) {
    // ---- Predict ----
    final INDArray transitionedMean = stateTransitionMatrix.mmul(previousState.getMean());
    final INDArray controlContribution = controlTransitionMatrix.mmul(controlInputVector);
    final INDArray projectedState = transitionedMean.add(controlContribution);

    // P' = F (P F^T) + Q, keeping the original product grouping.
    final INDArray covarianceTimesTransposedF =
            previousState.getCovariance().mmul(stateTransitionMatrix.transpose());
    final INDArray projectedErrorCovariance =
            stateTransitionMatrix.mmul(covarianceTimesTransposedF).add(processCovariance);

    // ---- Correct ----
    final INDArray transposedH = measurementTransitionMatrix.transpose();
    final INDArray innovationCovariance = measurementTransitionMatrix
            .mmul(projectedErrorCovariance.mmul(transposedH))
            .add(measurement.getCovariance());
    final INDArray invertedInnovationCovariance = InvertMatrix.invert(innovationCovariance, false);
    final INDArray kalmanGain =
            projectedErrorCovariance.mmul(transposedH.mmul(invertedInnovationCovariance));

    final INDArray predictedMeasurement = measurementTransitionMatrix.mmul(projectedState);
    final INDArray innovation = measurement.getMean().sub(predictedMeasurement);
    final INDArray estimatedState = projectedState.add(kalmanGain.mmul(innovation));

    final INDArray identity = Nd4j.eye(estimatedState.rows());
    final INDArray gainTimesH = kalmanGain.mmul(measurementTransitionMatrix);
    final INDArray estimatedErrorCovariance =
            identity.sub(gainTimesH).mmul(projectedErrorCovariance);

    return new SimpleDistribution(estimatedState, estimatedErrorCovariance);
}
示例2: apply
import org.nd4j.linalg.api.ndarray.INDArray; //导入方法依赖的package包/类
/**
 * Performs a single predict-and-correct step of an extended Kalman filter.
 *
 * <p>Prediction: {@code x' = f(x, u)} and {@code P' = F (P F^T) + Q}, where {@code F} is the
 * state-transition Jacobian. Correction: {@code K = P' (H^T (H (P' H^T) + R)^-1)},
 * {@code x = x' + K (z - h(x'))} and {@code P = (I - K H) P'}, where {@code H} is the
 * measurement-transition Jacobian and {@code z}/{@code R} are the mean/covariance of
 * {@code observationNoise}.
 *
 * @param stateTransition               A {@link BiFunction} taking the previous state as its
 *                                      first argument and the control input as its second; the
 *                                      result is the projected state.
 * @param measurementTransition         A {@link Function} mapping a state onto the measurement
 *                                      expected for it.
 * @param stateTransitionJacobian       The Jacobian {@code F} of the state transition.
 * @param measurementTransitionJacobian The Jacobian {@code H} of the measurement transition.
 * @param controlInput                  The control input {@code u}.
 * @param processCovariance             The process covariance {@code Q}.
 * @param observationNoise              The observation distribution: its mean is used as the
 *                                      observed measurement vector {@code z} and its covariance
 *                                      as the measurement noise matrix {@code R}.
 * @param state                         The distribution (mean {@code x}, covariance {@code P})
 *                                      of the latest state.
 * @return A new {@link SimpleDistribution} holding the filtered state and its covariance.
 */
public Distribution apply(
        final BiFunction<INDArray, INDArray, INDArray> stateTransition,
        final Function<INDArray, INDArray> measurementTransition,
        final INDArray stateTransitionJacobian,
        final INDArray measurementTransitionJacobian,
        final INDArray controlInput,
        final INDArray processCovariance,
        final Distribution observationNoise,
        final Distribution state
) {
    // Predict the state through the (possibly non-linear) transition function.
    final INDArray projectedState = stateTransition.apply(state.getMean(), controlInput);

    // Propagate the state COVARIANCE (not the mean) through the linearised transition.
    final INDArray projectedErrorCovariance = stateTransitionJacobian
            .mmul(state.getCovariance().mmul(stateTransitionJacobian.transpose()))
            .add(processCovariance);

    final INDArray transposedMeasurementJacobian = measurementTransitionJacobian.transpose();
    final INDArray innovationCovariance = measurementTransitionJacobian
            .mmul(projectedErrorCovariance.mmul(transposedMeasurementJacobian))
            .add(observationNoise.getCovariance());
    final INDArray kalmanGain = projectedErrorCovariance.mmul(
            transposedMeasurementJacobian.mmul(InvertMatrix.invert(innovationCovariance, false)));

    // Innovation: observed measurement minus the measurement predicted from the projected state.
    // Without the observed measurement the filter would never incorporate any observation.
    final INDArray innovation = observationNoise.getMean()
            .sub(measurementTransition.apply(projectedState));
    final INDArray estimatedState = projectedState.add(kalmanGain.mmul(innovation));

    // K H must be a matrix product (mmul); mul would be element-wise.
    final INDArray estimatedErrorCovariance = Nd4j.eye(estimatedState.rows())
            .sub(kalmanGain.mmul(measurementTransitionJacobian))
            .mmul(projectedErrorCovariance);
    return new SimpleDistribution(estimatedState, estimatedErrorCovariance);
}
示例3: aggregate
import org.nd4j.linalg.api.ndarray.INDArray; //导入方法依赖的package包/类
/**
 * Combines the outputs of child units by averaging their 2-D data element-wise.
 *
 * <p>Every output must carry 2-D data with the same shape as the first output. The response
 * data is built with {@code PredictorUtils.updateData} from the first output's data and the
 * averaged array; meta and status are copied from the first output.
 *
 * @param outputs the child responses to combine; must not be empty
 * @param state   the state of this predictive unit (not used by this combiner)
 * @return a message whose data holds the element-wise mean of all inputs
 * @throws APIException of type {@code ENGINE_INVALID_COMBINER_RESPONSE} if {@code outputs} is
 *                      empty, an input's shape cannot be extracted or is not 2-D, or the input
 *                      shapes differ
 */
@Override
public SeldonMessage aggregate(List<SeldonMessage> outputs, PredictiveUnitState state){
    if (outputs.isEmpty()){
        throw new APIException(APIException.ApiExceptionType.ENGINE_INVALID_COMBINER_RESPONSE, "Combiner received no inputs");
    }
    final SeldonMessage first = outputs.get(0);
    final int[] shape = requireTwoDimensionalShape(first.getData());

    INDArray currentSum = Nd4j.zeros(shape[0], shape[1]);
    for (SeldonMessage output : outputs){
        final DefaultData inputData = output.getData();
        final int[] inputShape = requireTwoDimensionalShape(inputData);
        if (inputShape[0] != shape[0]){
            throw new APIException(APIException.ApiExceptionType.ENGINE_INVALID_COMBINER_RESPONSE, String.format("Expected batch length %d but found %d", shape[0], inputShape[0]));
        }
        if (inputShape[1] != shape[1]){
            // Second dimension mismatch: report it as a column count, not a batch length.
            throw new APIException(APIException.ApiExceptionType.ENGINE_INVALID_COMBINER_RESPONSE, String.format("Expected column count %d but found %d", shape[1], inputShape[1]));
        }
        currentSum = currentSum.add(PredictorUtils.getINDArray(inputData));
    }
    currentSum = currentSum.div((float) outputs.size());

    final DefaultData newData = PredictorUtils.updateData(first.getData(), currentSum);
    return SeldonMessage.newBuilder()
            .setData(newData)
            .setMeta(first.getMeta())
            .setStatus(first.getStatus())
            .build();
}

/**
 * Returns the shape of {@code data}, throwing an {@link APIException} when the shape cannot be
 * extracted or is not 2-dimensional.
 */
private int[] requireTwoDimensionalShape(DefaultData data){
    final int[] shape = PredictorUtils.getShape(data);
    if (shape == null){
        throw new APIException(APIException.ApiExceptionType.ENGINE_INVALID_COMBINER_RESPONSE, "Combiner cannot extract data shape");
    }
    if (shape.length != 2){
        throw new APIException(APIException.ApiExceptionType.ENGINE_INVALID_COMBINER_RESPONSE, "Combiner received data that is not 2 dimensional");
    }
    return shape;
}