本文整理汇总了 Java 中 org.ejml.simple.SimpleMatrix 类的典型用法代码示例。如果您正苦于以下问题：Java SimpleMatrix 类的具体用法是什么？Java SimpleMatrix 怎么用？有哪些 Java SimpleMatrix 的使用示例？那么，这里精选的类代码示例或许可以为您提供帮助。
SimpleMatrix 类属于 org.ejml.simple 包，下文中共展示了 SimpleMatrix 类的 15 个代码示例，这些示例默认按受欢迎程度排序。您可以为喜欢或觉得有用的代码点赞，您的评价将有助于系统推荐更好的 Java 代码示例。
示例1: MatrixTreeTheorem
import org.ejml.simple.SimpleMatrix; //导入依赖的package包/类
/**
 * Builds the matrix {@code Q} from a weight matrix and stores its determinant in {@code Z}.
 *
 * <p>Node 0 of {@code weight} is the root; the remaining {@code n = weight.length - 1} nodes
 * give an n x n matrix (0-based row/column {@code r, c} correspond to node {@code r+1, c+1}):
 * <ul>
 *   <li>row 0 holds the root weights {@code weight[0][j]};</li>
 *   <li>every other diagonal entry holds the column sum {@code weight[1][j] + ... + weight[n][j]};</li>
 *   <li>every other off-diagonal entry holds {@code -weight[i][j]}.</li>
 * </ul>
 * NOTE(review): this looks like the single-root directed Matrix-Tree construction (first row
 * replaced by root scores); confirm against callers. Assumes every row of {@code weight} has
 * at least {@code weight.length} entries.
 *
 * <p>The array is stored by reference in {@code W}; it is not copied.
 *
 * @param weight square weight matrix whose index 0 is the root node
 */
public MatrixTreeTheorem(double[][] weight) {
    W = weight;
    // Number of nodes, excluding the root node at index 0.
    final int nodes = W.length - 1;
    Q = new SimpleMatrix(nodes, nodes);
    for (int row = 0; row < nodes; row++) {
        final int i = row + 1;
        for (int col = 0; col < nodes; col++) {
            final int j = col + 1;
            double entry;
            if (row == 0) {
                // First row is replaced by the root weights.
                entry = W[0][j];
            } else if (i != j) {
                entry = -W[i][j];
            } else {
                // Diagonal: total incoming weight from non-root nodes.
                entry = 0;
                for (int k = 1; k <= nodes; k++) {
                    entry += W[k][j];
                }
            }
            Q.set(row, col, entry);
        }
    }
    Z = Q.determinant();
}
示例2: printMatrix
import org.ejml.simple.SimpleMatrix; //导入依赖的package包/类
/**
 * Renders a matrix as nested brackets, e.g. {@code [[1,2][3,4]]}.
 *
 * <p>Cells within a row are comma-separated. Rows are written back to back with no
 * separator. Each cell is formatted with {@code numToString}.
 *
 * @param mat the matrix to render
 * @return the bracketed text form of {@code mat}
 */
private static String printMatrix(SimpleMatrix mat) {
    final int rowCount = mat.numRows();
    final int colCount = mat.numCols();
    StringBuilder out = new StringBuilder();
    out.append('[');
    for (int r = 0; r < rowCount; r++) {
        out.append('[');
        for (int c = 0; c < colCount; c++) {
            // Separator goes before every cell except the first one in the row.
            if (c > 0) {
                out.append(',');
            }
            out.append(numToString(mat.get(r, c)));
        }
        out.append(']');
    }
    return out.append(']').toString();
}
示例3: matrixToString
import org.ejml.simple.SimpleMatrix; //导入依赖的package包/类
/**
 * Serializes a matrix into the bracketed text form {@code [[a,b][c,d]]}.
 *
 * <p>Each cell is passed, as {@code Double.toString(value)}, through {@code logic.evaluate}.
 * Cells within a row are comma-separated, and rows are concatenated without a separator.
 * A matrix with zero columns yields one empty {@code []} per row.
 *
 * @param matrix the matrix to serialize
 * @param logic evaluator used to render each cell
 * @return the bracketed text form of {@code matrix}
 * @throws SyntaxException propagated from {@code logic.evaluate}
 */
public static String matrixToString(SimpleMatrix matrix, Logic logic) throws SyntaxException {
    final int rows = matrix.numRows();
    final int columns = matrix.numCols();
    // StringBuilder avoids the quadratic cost of repeated String concatenation.
    StringBuilder input = new StringBuilder("[");
    for (int i = 0; i < rows; i++) {
        input.append('[');
        for (int j = 0; j < columns; j++) {
            // Emit separators only between cells. The old "append then strip the trailing
            // comma" approach removed the row's opening '[' when columns == 0.
            if (j > 0) {
                input.append(',');
            }
            input.append(logic.evaluate(Double.toString(matrix.get(i, j))));
        }
        input.append(']');
    }
    input.append(']');
    return input.toString();
}
示例4: actualizarParametrosMomentum
import org.ejml.simple.SimpleMatrix; //导入依赖的package包/类
/**
 * Applies one momentum update to the weights and biases of every layer.
 *
 * <p>For each layer {@code i}:
 * <pre>
 * vW = momentum * vW_prev - learningRate * (dW / m + regularization * W)
 * W  = W + vW
 * vB = momentum * vB_prev - learningRate * (dB / m)
 * B  = B + vB
 * </pre>
 * {@code dW}/{@code dB} come from {@code deltasW}/{@code deltasB}. {@code vW_prev}/{@code vB_prev}
 * come from {@code deltasWprev}/{@code deltasBprev}, and the new velocities are written back
 * there for the next call.
 *
 * @param m number of samples (used to average the deltas)
 */
protected void actualizarParametrosMomentum(double m) {
    for (int idx = 0; idx < net.getLayers().size(); idx++) {
        Layer current = net.getLayers().get(idx);
        SimpleMatrix weights = current.getW();
        SimpleMatrix bias = current.getB();

        // Weight velocity: momentum term minus the scaled, L2-regularized averaged delta.
        SimpleMatrix penalty = weights.scale(regularization);
        SimpleMatrix weightStep = deltasW.get(idx).divide(m).plus(penalty).scale(learningRate);
        SimpleMatrix weightVelocity = deltasWprev.get(idx).scale(momentum).minus(weightStep);

        // Bias velocity: same rule, without regularization.
        SimpleMatrix biasStep = deltasB.get(idx).divide(m).scale(learningRate);
        SimpleMatrix biasVelocity = deltasBprev.get(idx).scale(momentum).minus(biasStep);

        current.setW(weights.plus(weightVelocity));
        current.setB(bias.plus(biasVelocity));
        net.getLayers().set(idx, current);
        deltasBprev.set(idx, biasVelocity);
        deltasWprev.set(idx, weightVelocity);
    }
}
示例5: testDerivative
import org.ejml.simple.SimpleMatrix; //导入依赖的package包/类
/**
 * Tests {@code Relu.derivative}: negative inputs and 0 map to 0, positive inputs map to 1.
 * Currently disabled via {@code @Ignore}.
 */
@Ignore
@Test
public void testDerivative() {
    System.out.println("output");
    SimpleMatrix input = new SimpleMatrix(1, 5, true, -1, 1, 0, .1, -.1);
    Relu relu = new Relu();
    double[] expected = {0, 1, 0, 1, 0};
    double[] actual = relu.derivative(input).getMatrix().getData();
    assertArrayEquals(expected, actual, 0);
}
示例6: eval
import org.ejml.simple.SimpleMatrix; //导入依赖的package包/类
/**
 * Converts each row of {@code matrix} into a one-hot row that marks the row's largest value.
 *
 * <p>The result has the same shape as {@code matrix}. In every row, the column holding the
 * maximum is set to 1 and all other entries are 0. On ties, the first (left-most) maximum
 * wins. If column 0 of a row is NaN, that column is selected, since no value compares greater.
 * A matrix with no columns returns an empty {@code rows x 0} result instead of throwing.
 *
 * @param matrix the input scores
 * @return a new one-hot matrix with the same dimensions as {@code matrix}
 */
public static SimpleMatrix eval(SimpleMatrix matrix) {
    final int rows = matrix.numRows();
    final int cols = matrix.numCols();
    SimpleMatrix aux = new SimpleMatrix(rows, cols);
    // Rows have no cells to mark; reading column 0 would throw.
    if (cols == 0) {
        return aux;
    }
    for (int i = 0; i < rows; i++) {
        // Read cells in place rather than copying each row out with extractVector.
        double max = matrix.get(i, 0);
        int pos = 0;
        // Start at 1: column 0 already seeds the running maximum.
        for (int j = 1; j < cols; j++) {
            double value = matrix.get(i, j);
            if (max < value) { // strict comparison: ties keep the earlier column
                max = value;
                pos = j;
            }
        }
        aux.set(i, pos, 1);
    }
    return aux;
}
示例7: testOutput
import org.ejml.simple.SimpleMatrix; //导入依赖的package包/类
/**
 * Tests {@code Softmax.output} on the column vector [-3, -1, 0, 1, 3] against precomputed
 * softmax probabilities (tolerance 1e-9).
 */
@Test
public void testOutput() {
    TransferFunction softmax = FunctionFactory.getFunction("softmax");
    SimpleMatrix logits = new SimpleMatrix(5, 1, true, -3, -1, 0, 1, 3);
    double[] actual = softmax.output(logits).getMatrix().getData();
    double[] expected = {0.002055492, 0.015188145, 0.04128566, 0.112226059, 0.829244644};
    assertArrayEquals(expected, actual, 0.000000001);
}
示例8: addScalar
import org.ejml.simple.SimpleMatrix; //导入依赖的package包/类
/**
 * Returns a copy of {@code mat} with {@code scalar} added to every element.
 * {@code mat} itself is left unchanged.
 *
 * @param mat the source matrix
 * @param scalar the value to add to each element
 * @return a new matrix holding {@code mat[i][j] + scalar}
 */
static SimpleMatrix addScalar(SimpleMatrix mat, double scalar) {
    SimpleMatrix shifted = mat.copy();
    final int rowCount = mat.numRows();
    final int colCount = mat.numCols();
    for (int row = 0; row < rowCount; row++) {
        for (int col = 0; col < colCount; col++) {
            double source = mat.get(row, col);
            shifted.set(row, col, source + scalar);
        }
    }
    return shifted;
}
示例9: evaluateMatrices
import org.ejml.simple.SimpleMatrix; //导入依赖的package包/类
/**
 * Evaluates the matrix expression shown in {@code display} and returns the result formatted
 * for the current display mode.
 *
 * <p>Children of the display are read left to right:
 * <ul>
 *   <li>{@code MatrixView}: an operand;</li>
 *   <li>{@code MatrixTransposeView} / {@code MatrixInverseView}: applied to the running result;</li>
 *   <li>single-character text equal to {@code Logic.PLUS} or {@code Logic.MUL}: the operator
 *       for the next operand;</li>
 *   <li>empty text: ignored.</li>
 * </ul>
 * There is no operator precedence; operations apply strictly in display order.
 *
 * @param display the display whose children form the expression
 * @return the evaluated matrix as text, converted from decimal to the current mode
 * @throws SyntaxException if the expression is malformed (operator without a left operand, two
 *     operators in a row, two operands without an operator, no operand at all, unknown text) or
 *     if any matrix operation or formatting step fails; such a failure is attached as the cause
 */
String evaluateMatrices(AdvancedDisplay display) throws SyntaxException {
    try {
        SimpleMatrix matrix = null;
        // At most one is true: the operator waiting for its right-hand operand.
        boolean add = false;
        boolean multiply = false;
        for (int i = 0; i < display.getChildCount(); i++) {
            View child = display.getChildAt(i);
            if (child instanceof MatrixView) {
                if (add) {
                    add = false;
                    if (matrix == null) {
                        throw new SyntaxException();
                    }
                    matrix = matrix.plus(((MatrixView) child).getSimpleMatrix());
                } else if (multiply) {
                    multiply = false;
                    if (matrix == null) {
                        throw new SyntaxException();
                    }
                    matrix = matrix.mult(((MatrixView) child).getSimpleMatrix());
                } else {
                    // Two operands with no operator between them used to silently discard the
                    // left one; reject the expression instead.
                    if (matrix != null) {
                        throw new SyntaxException();
                    }
                    matrix = ((MatrixView) child).getSimpleMatrix();
                }
            } else if (child instanceof MatrixTransposeView) {
                if (matrix == null) {
                    throw new SyntaxException();
                }
                matrix = matrix.transpose();
            } else if (child instanceof MatrixInverseView) {
                if (matrix == null) {
                    throw new SyntaxException();
                }
                matrix = matrix.invert();
            } else {
                String text = child.toString();
                if (text.length() > 1) {
                    throw new SyntaxException();
                }
                if (text.isEmpty()) {
                    continue;
                }
                // An operator must follow an operand. Previously a second operator left a stale
                // flag behind that was silently applied to a later operand.
                if (add || multiply) {
                    throw new SyntaxException();
                }
                if (text.startsWith(String.valueOf(Logic.MUL))) {
                    multiply = true;
                } else if (text.startsWith(String.valueOf(Logic.PLUS))) {
                    add = true;
                } else {
                    throw new SyntaxException();
                }
            }
        }
        // No operand at all: report it explicitly instead of relying on an NPE downstream.
        if (matrix == null) {
            throw new SyntaxException();
        }
        // NOTE(review): a trailing operator with no right operand is still ignored and the left
        // operand is returned, as before; confirm this leniency is intended.
        return logic.mBaseModule.updateTextToNewMode(
                MatrixView.matrixToString(matrix, logic), Mode.DECIMAL, logic.mBaseModule.getMode());
    } catch (SyntaxException e) {
        // Already the declared type; rethrow as-is instead of replacing it with a fresh instance.
        throw e;
    } catch (Exception e) {
        // Matrix operations and formatting can throw runtime exceptions (e.g. dimension
        // mismatch, singular inverse). Surface them as syntax errors but keep the cause.
        SyntaxException wrapped = new SyntaxException();
        try {
            wrapped.initCause(e);
        } catch (IllegalStateException ignored) {
            // The exception's constructor already fixed a cause; keep that one.
        }
        throw wrapped;
    }
}
示例10: testDerivative
import org.ejml.simple.SimpleMatrix; //导入依赖的package包/类
/**
 * Tests {@code Purelim.derivative}: the expected derivative is 1 for every input.
 */
@Test
public void testDerivative() {
    TransferFunction purelim = FunctionFactory.getFunction("purelim");
    SimpleMatrix input = new SimpleMatrix(1, 5, true, -3, -1, 0, 1, 3);
    double[] actual = purelim.derivative(input).getMatrix().getData();
    double[] expected = {1, 1, 1, 1, 1};
    assertArrayEquals(expected, actual, 0.000000001);
}
示例11: bachDatos
import org.ejml.simple.SimpleMatrix; //导入依赖的package包/类
/**
 * Selects the rows of {@code input} and {@code output} that make up training batch {@code part}.
 *
 * <p>Every batch except the last covers exactly {@code batchSize} rows, starting at
 * {@code batchSize * part}. The last batch ({@code part >= cantidadBach - 1}) runs from its
 * start row to the end of the data, so any remainder from the division lands there. The
 * selections are stored in {@code inputBach} and {@code outputBach}.
 *
 * @param part zero-based index of the batch to extract
 */
protected void bachDatos(int part) {
    final int firstRow = batchSize * part;
    final boolean lastBatch = part >= cantidadBach - 1;
    // Exclusive end row; the last batch takes all remaining rows.
    final int endRow = lastBatch ? SimpleMatrix.END : firstRow + batchSize;
    inputBach = this.input.extractMatrix(firstRow, endRow, 0, SimpleMatrix.END);
    outputBach = this.output.extractMatrix(firstRow, endRow, 0, SimpleMatrix.END);
}
示例12: testOutput_doubleArr
import org.ejml.simple.SimpleMatrix; //导入依赖的package包/类
/**
 * Tests {@code Layer.output(double[])} with a 2x3 weight matrix, a 2x1 bias, and the LOGSIG
 * transfer function. Expected values are {@code logsig(W * x + b)}, i.e. logsig(0.38) and
 * logsig(0.84).
 */
@Test
public void testOutput_doubleArr() {
    System.out.println("output");
    SimpleMatrix weights = new SimpleMatrix(2, 3, true, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6);
    SimpleMatrix bias = new SimpleMatrix(2, 1, true, 0.1, 0.2);
    double[] sample = {0.2, 0.4, 0.6};
    Layer layer = new Layer(weights, bias, TransferFunction.LOGSIG);
    double[] expected = {0.5938731029341, 0.6984652160025};
    double[] actual = layer.output(sample).getMatrix().getData();
    assertArrayEquals(expected, actual, 0.0000000000001);
}
示例13: inicializar
import org.ejml.simple.SimpleMatrix; //导入依赖的package包/类
/**
 * Resets the trainer state before training.
 *
 * <p>Sets {@code iteracion} to 0, clears the cost history and rebuilds the per-layer buffers:
 * <ul>
 *   <li>{@code deltasW}/{@code deltasWprev} and {@code deltasB}/{@code deltasBprev} get one
 *       matrix per layer, copied from that layer's weights/biases (so they share its shape);</li>
 *   <li>{@code deriv}, {@code gradW} and {@code gradB} get one empty placeholder per layer.</li>
 * </ul>
 * The "previous" deltas are zeroed here. {@code deltasZero()} is then called to reset the
 * current deltas to 0.
 */
public void inicializar() {
    iteracion = 0;
    deltasB.clear();
    deltasW.clear();
    deriv.clear();
    gradB.clear();
    gradW.clear();
    deltasWprev.clear();
    deltasBprev.clear();
    for (int layerIdx = 0; layerIdx < net.getLayers().size(); layerIdx++) {
        // Weight deltas: copies shaped like this layer's weights.
        deltasW.add(new SimpleMatrix(net.getLayers().get(layerIdx).getW()));
        SimpleMatrix prevW = new SimpleMatrix(net.getLayers().get(layerIdx).getW());
        prevW.zero();
        deltasWprev.add(prevW);
        // Bias deltas: copies shaped like this layer's biases.
        deltasB.add(new SimpleMatrix(net.getLayers().get(layerIdx).getB()));
        SimpleMatrix prevB = new SimpleMatrix(net.getLayers().get(layerIdx).getB());
        prevB.zero();
        deltasBprev.add(prevB);
        // Per-layer derivative and gradient slots, one per layer.
        deriv.add(new SimpleMatrix());
        gradW.add(new SimpleMatrix());
        gradB.add(new SimpleMatrix());
    }
    cost.clear();
    deltasZero(); // reset the current deltas to 0
}
示例14: calcularGradientes
import org.ejml.simple.SimpleMatrix; //导入依赖的package包/类
/**
 * Accumulates the current mini-batch's gradients into {@code deltasW}/{@code deltasB}.
 *
 * <p>Calls {@code deltasZero()}, stores the batch size in {@code ND}, and then for every row
 * of {@code inputBach}:
 * <ol>
 *   <li>runs the network forward and keeps every layer's output;</li>
 *   <li>calls {@code derivadaOutputLayers}/{@code derivadaHiddenLayers}, which presumably fill
 *       {@code deriv} (it is read right afterwards; confirm in those methods);</li>
 *   <li>sets {@code gradW[j] = deriv[j] * prev^T} and {@code gradB[j] = deriv[j]}, where
 *       {@code prev} is the input sample for layer 0 and layer {@code j-1}'s output otherwise;</li>
 *   <li>adds those gradients to {@code deltasW}/{@code deltasB}.</li>
 * </ol>
 */
protected void calcularGradientes() {
    deltasZero(); // original author questioned whether this reset is needed
    ND = inputBach.numRows();
    // TODO (from original author): this per-sample loop is a candidate for parallelization.
    for (int sample = 0; sample < ND; sample++) {
        // Take row `sample` and turn it into a column vector.
        SimpleMatrix x = inputBach.extractVector(true, sample).transpose();
        SimpleMatrix target = outputBach.extractVector(true, sample).transpose();
        // Keep every layer's output so the gradient step can reuse it.
        List<SimpleMatrix> layerOutputs = net.outputLayers(x);
        SimpleMatrix prediction = layerOutputs.get(layerOutputs.size() - 1);
        derivadaOutputLayers(prediction, target);
        derivadaHiddenLayers(layerOutputs);

        // For the first layer, the "previous activation" is the input sample itself.
        SimpleMatrix prevT = x.transpose();
        for (int layer = 0; layer < gradW.size(); layer++) {
            SimpleMatrix delta = deriv.get(layer);
            gradW.set(layer, delta.mult(prevT));
            gradB.set(layer, delta);
            // This layer's output feeds the next layer.
            prevT = layerOutputs.get(layer).transpose();
        }
        for (int layer = 0; layer < gradW.size(); layer++) {
            deltasW.set(layer, deltasW.get(layer).plus(gradW.get(layer)));
            deltasB.set(layer, deltasB.get(layer).plus(gradB.get(layer)));
        }
    }
}
示例15: lossRegularization
import org.ejml.simple.SimpleMatrix; //导入依赖的package包/类
/**
 * Returns the regularization term added to the loss.
 *
 * <p>The term is 0 unless {@code regularization > 0}. Otherwise it is
 * {@code regularization * (w * w^T)[0]}, where {@code w = net.getParamsW()}.
 * NOTE(review): this equals {@code regularization * sum(w_k^2)} only if {@code getParamsW()}
 * returns a row vector; for a column vector, element 0 of {@code w * w^T} is just
 * {@code w_0^2}. Confirm the orientation.
 *
 * @return the regularization contribution to the loss
 */
protected double lossRegularization() {
    if (!(regularization > 0)) {
        return 0;
    }
    SimpleMatrix params = net.getParamsW();
    SimpleMatrix gram = params.mult(params.transpose());
    return gram.scale(regularization).get(0);
}