This article collects typical usage examples of the Nd4j.create method from the Java class org.nd4j.linalg.factory.Nd4j. If you have been wondering what Nd4j.create does, how to call it, or what real-world usage looks like, the curated examples below should help. You can also read more about the enclosing class, org.nd4j.linalg.factory.Nd4j.
The following sections present 15 code examples of Nd4j.create, sorted by popularity by default. You can upvote the examples you like or find useful; your feedback helps the system recommend better Java code samples.
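Before diving into the project examples, here is a self-contained sketch (not taken from any of the projects below; class and variable names are illustrative only) of the Nd4j.create overloads that appear most often in them: creating an array from a flat double[] plus an explicit shape and ordering, from a nested float[][], and from a shape alone.
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class Nd4jCreateSketch {
    public static void main(String[] args) {
        // 2x3 matrix from a flat buffer plus an explicit shape; 'c' = row-major order
        INDArray fromFlat = Nd4j.create(new double[] {1, 2, 3, 4, 5, 6}, new int[] {2, 3}, 'c');
        // 2x2 matrix directly from a nested float[][]
        INDArray fromNested = Nd4j.create(new float[][] {{1f, 2f}, {3f, 4f}});
        // Zero-filled 3x4 array built from a shape alone
        INDArray fromShape = Nd4j.create(3, 4);
        System.out.println(fromFlat);
        System.out.println(fromNested);
        System.out.println(fromShape.shapeInfoToString());
    }
}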
Example 1: fetch
import org.nd4j.linalg.factory.Nd4j; // import the package/class required for the method
@Override
public void fetch(int numExamples) {
    // Inner length 0 is a placeholder; each row below is replaced with a full feature/label vector
    float[][] featureData = new float[numExamples][0];
    float[][] labelData = new float[numExamples][0];
    int examplesRead = 0;
    for (; examplesRead < numExamples; examplesRead++) {
        if (cursor + examplesRead >= m_allFileNames.size()) {
            break;
        }
        Entry<String, String> entry = m_allFileNames.get(cursor + examplesRead);
        featureData[examplesRead] = imageFileNameToMnsitFormat(entry.getValue());
        labelData[examplesRead] = toLabelArray(entry.getKey());
    }
    cursor += examplesRead;
    // Trim unfilled rows so Nd4j.create does not receive ragged, zero-length rows
    if (examplesRead < numExamples) {
        featureData = Arrays.copyOf(featureData, examplesRead);
        labelData = Arrays.copyOf(labelData, examplesRead);
    }
    INDArray features = Nd4j.create(featureData);
    INDArray labels = Nd4j.create(labelData);
    curr = new DataSet(features, labels);
}
Example 2: nd4JExample
import org.nd4j.linalg.factory.Nd4j; // import the package/class required for the method
public void nd4JExample() {
    double[] A = {
            0.1950, 0.0311,
            0.3588, 0.2203,
            0.1716, 0.5931,
            0.2105, 0.3242};
    double[] B = {
            0.0502, 0.9823, 0.9472,
            0.5732, 0.2694, 0.916};
    INDArray aINDArray = Nd4j.create(A, new int[]{4, 2}, 'c');   // 4x2 matrix, row-major ('c') order
    INDArray bINDArray = Nd4j.create(B, new int[]{2, 3}, 'c');   // 2x3 matrix, row-major ('c') order
    INDArray cINDArray;
    cINDArray = aINDArray.mmul(bINDArray);                       // 4x2 times 2x3 gives a 4x3 product
    for (int i = 0; i < cINDArray.rows(); i++) {
        System.out.println(cINDArray.getRow(i));
    }
}
Author: PacktPublishing | Project: Machine-Learning-End-to-Endguide-for-Java-developers | Lines: 22 | Source file: MathExamples.java
Example 3: readTestData
import org.nd4j.linalg.factory.Nd4j; // import the package/class required for the method
private static List<TrainingData> readTestData(String fn) {
    int[] shape = { 3, 1 };
    List<TrainingData> trainingDataSet = new ArrayList<>();
    try (CSVReader reader = new CSVReader(new FileReader(fn))) {
        String[] row;
        while ((row = reader.readNext()) != null) {
            int type = Integer.parseInt(row[0]);
            double f1 = Double.parseDouble(row[1]);
            double f2 = Double.parseDouble(row[2]);
            double f3 = Double.parseDouble(row[3]);
            TrainingData trainingData = new TrainingData();
            trainingData.input = Nd4j.create(new double[] { f1, f2, f3 }, shape);
            trainingData.output = Nd4j.zeros(shape);
            trainingData.output.putScalar(type, 1.0);   // one-hot encode the class index
            trainingDataSet.add(trainingData);
        }
    } catch (java.io.IOException e) {
        e.printStackTrace();   // do not silently swallow read errors
    }
    return trainingDataSet;
}
Example 4: preTrain
import org.nd4j.linalg.factory.Nd4j; // import the package/class required for the method
public void preTrain(List<INDArray> X, int minibatchSize, int minibatch_N, int epochs, double learningRate,
                     double corruptionLevel) {
    for (int layer = 0; layer < nLayers; layer++) {
        for (int epoch = 0; epoch < epochs; epoch++) {
            for (int batch = 0; batch < minibatch_N; batch++) {
                INDArray X_ = Nd4j.create(new double[minibatchSize * nIn], new int[] { minibatchSize, nIn });
                INDArray prevLayerX_;
                // Set input data for the current layer: the raw minibatch for layer 0,
                // otherwise the minibatch propagated through the previously trained layers
                if (layer == 0) {
                    X_ = X.get(batch);
                } else {
                    X_ = X.get(batch);
                    for (int prev = 0; prev < layer; prev++) {
                        prevLayerX_ = X_;
                        X_ = hiddenLayers[prev].forward(prevLayerX_);
                    }
                }
                daLayers[layer].train(X_, minibatchSize, learningRate, corruptionLevel);
            }
        }
    }
}
Example 5: State
import org.nd4j.linalg.factory.Nd4j; // import the package/class required for the method
public State(
        final int id,
        final double timestamp,
        final double x,
        final double y,
        final double orientation,
        final TeamColor teamColor
) {
    this(
            id,
            new SimpleDistribution(Nd4j.create(
                    new double[]{
                            timestamp,
                            x,
                            y,
                            orientation
                    },
                    new int[]{4, 1}), Nd4j.eye(4)),
            teamColor);
}
Example 6: predict
import org.nd4j.linalg.factory.Nd4j; // import the package/class required for the method
public INDArray predict(INDArray x) {
    INDArray y = output(x); // activate input data through learned networks
    INDArray out = Nd4j.create(new double[x.rows() * nOut], new int[] { x.rows(), nOut });
    for (int i = 0; i < x.rows(); i++) {
        int argmax = -1;
        // NEGATIVE_INFINITY, not Double.MIN_VALUE: MIN_VALUE is the smallest positive double, not the lowest value
        double max = Double.NEGATIVE_INFINITY;
        for (int j = 0; j < nOut; j++) {
            if (max < y.getDouble(i, j)) {
                argmax = j;
                max = y.getDouble(i, j);
            }
        }
        out.put(i, argmax, Nd4j.scalar(1.0));
    }
    return out;
}
Example 7: next
import org.nd4j.linalg.factory.Nd4j; // import the package/class required for the method
@Override
public DataSet next(int num) {
    if (exampleStartOffsets.size() == 0) throw new NoSuchElementException();
    int actualMiniBatchSize = Math.min(num, exampleStartOffsets.size());
    INDArray input = Nd4j.create(new int[] {actualMiniBatchSize, VECTOR_SIZE, exampleLength}, 'f');
    INDArray label;
    if (category.equals(PriceCategory.ALL)) label = Nd4j.create(new int[] {actualMiniBatchSize, VECTOR_SIZE, exampleLength}, 'f');
    else label = Nd4j.create(new int[] {actualMiniBatchSize, predictLength, exampleLength}, 'f');
    for (int index = 0; index < actualMiniBatchSize; index++) {
        int startIdx = exampleStartOffsets.removeFirst();
        int endIdx = startIdx + exampleLength;
        StockData curData = train.get(startIdx);
        StockData nextData;
        for (int i = startIdx; i < endIdx; i++) {
            int c = i - startIdx;
            // Min-max normalize each feature into [0, 1]
            input.putScalar(new int[] {index, 0, c}, (curData.getOpen() - minArray[0]) / (maxArray[0] - minArray[0]));
            input.putScalar(new int[] {index, 1, c}, (curData.getClose() - minArray[1]) / (maxArray[1] - minArray[1]));
            input.putScalar(new int[] {index, 2, c}, (curData.getLow() - minArray[2]) / (maxArray[2] - minArray[2]));
            input.putScalar(new int[] {index, 3, c}, (curData.getHigh() - minArray[3]) / (maxArray[3] - minArray[3]));
            input.putScalar(new int[] {index, 4, c}, (curData.getVolume() - minArray[4]) / (maxArray[4] - minArray[4]));
            nextData = train.get(i + 1);
            if (category.equals(PriceCategory.ALL)) {
                // Open uses min/max index 0, matching the input normalization above
                label.putScalar(new int[] {index, 0, c}, (nextData.getOpen() - minArray[0]) / (maxArray[0] - minArray[0]));
                label.putScalar(new int[] {index, 1, c}, (nextData.getClose() - minArray[1]) / (maxArray[1] - minArray[1]));
                label.putScalar(new int[] {index, 2, c}, (nextData.getLow() - minArray[2]) / (maxArray[2] - minArray[2]));
                label.putScalar(new int[] {index, 3, c}, (nextData.getHigh() - minArray[3]) / (maxArray[3] - minArray[3]));
                label.putScalar(new int[] {index, 4, c}, (nextData.getVolume() - minArray[4]) / (maxArray[4] - minArray[4]));
            } else {
                label.putScalar(new int[]{index, 0, c}, feedLabel(nextData));
            }
            curData = nextData;
        }
        if (exampleStartOffsets.size() == 0) break;
    }
    return new DataSet(input, label);
}
Example 8: State
import org.nd4j.linalg.factory.Nd4j; // import the package/class required for the method
public State(
        final int id,
        final double timestamp,
        final double x,
        final double y,
        final double orientation,
        final double velocityX,
        final double velocityY,
        final double velocityR,
        final TeamColor teamColor
) {
    this(
            id,
            new SimpleDistribution(Nd4j.create(
                    new double[]{
                            timestamp,
                            x,
                            y,
                            orientation,
                            velocityX,
                            velocityY,
                            velocityR
                    },
                    new int[]{7, 1}), Nd4j.eye(7)),
            teamColor);
}
Example 9: ax
import org.nd4j.linalg.factory.Nd4j; // import the package/class required for the method
@Override
public double[] ax(double[] x, double[] y) {
    // Nd4j.getBlasWrapper().level2().gemv() crashes.
    // Use gemm for now.
    int m = nrows();
    int n = ncols();
    INDArray ndx = Nd4j.create(x, new int[]{n, 1});
    INDArray ndy = Nd4j.gemm(A, ndx, false, false);
    for (int i = 0; i < m; i++) {
        y[i] = ndy.getDouble(i);
    }
    return y;
}
Example 10: axpy
import org.nd4j.linalg.factory.Nd4j; // import the package/class required for the method
@Override
public double[] axpy(double[] x, double[] y) {
    // Nd4j.getBlasWrapper().level2().gemv() crashes.
    // Use gemm for now.
    int m = nrows();
    int n = ncols();
    INDArray ndx = Nd4j.create(x, new int[]{n, 1});
    INDArray ndy = Nd4j.gemm(A, ndx, false, false);
    for (int i = 0; i < m; i++) {
        y[i] += ndy.getDouble(i);
    }
    return y;
}
Example 11: binomial
import org.nd4j.linalg.factory.Nd4j; // import the package/class required for the method
private INDArray binomial(INDArray x, Random rng) {
    INDArray y = Nd4j.create(new double[x.rows() * x.columns()], new int[] { x.rows(), x.columns() });
    for (int i = 0; i < x.rows(); i++) {
        for (int j = 0; j < x.columns(); j++) {
            y.put(i, j, RandomGenerator.binomial(1, x.getDouble(i, j), rng));
        }
    }
    return y;
}
Example 12: atx
import org.nd4j.linalg.factory.Nd4j; // import the package/class required for the method
@Override
public double[] atx(double[] x, double[] y) {
    // Nd4j.getBlasWrapper().level2().gemv() crashes.
    // Use gemm for now.
    int m = nrows();
    int n = ncols();
    INDArray ndx = Nd4j.create(x, new int[]{m, 1});
    INDArray ndy = Nd4j.gemm(A, ndx, true, false);
    for (int i = 0; i < n; i++) {
        y[i] = ndy.getDouble(i);
    }
    return y;
}
Example 13: State
import org.nd4j.linalg.factory.Nd4j; // import the package/class required for the method
public State(
        final double timestamp,
        final double x,
        final double y,
        final double z
) {
    this(new SimpleDistribution(Nd4j.create(
            new double[]{
                    timestamp,
                    x,
                    y,
                    z
            },
            new int[]{4, 1}), Nd4j.eye(4)));
}
Example 14: atxpy
import org.nd4j.linalg.factory.Nd4j; // import the package/class required for the method
@Override
public double[] atxpy(double[] x, double[] y, double b) {
    // Nd4j.getBlasWrapper().level2().gemv() crashes.
    // Use gemm for now.
    int m = nrows();
    int n = ncols();
    INDArray ndx = Nd4j.create(x, new int[]{m, 1});
    INDArray ndy = Nd4j.gemm(A, ndx, true, false);
    for (int i = 0; i < n; i++) {
        y[i] = b * y[i] + ndy.getDouble(i);
    }
    return y;
}
Example 15: outputBinomial
import org.nd4j.linalg.factory.Nd4j; // import the package/class required for the method
public INDArray outputBinomial(INDArray X) {
    INDArray out = output(X);
    INDArray y = Nd4j.create(new double[out.rows() * out.columns()], new int[] { out.rows(), out.columns() });
    for (int i = 0; i < out.rows(); i++) {
        for (int j = 0; j < out.columns(); j++) {
            double value = RandomGenerator.binomial(1, out.getDouble(i, j), rng);
            y.put(i, j, Nd4j.scalar(value));
        }
    }
    return y;
}