当前位置: 首页>>代码示例>>Java>>正文


Java Nd4j类代码示例

本文整理汇总了Java中org.nd4j.linalg.factory.Nd4j的典型用法代码示例。如果您正苦于以下问题:Java Nd4j类的具体用法?Java Nd4j怎么用?Java Nd4j使用的例子?那么, 这里精选的类代码示例或许可以为您提供帮助。


Nd4j类属于org.nd4j.linalg.factory包,在下文中一共展示了Nd4j类的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Java代码示例。

示例1: fetch

import org.nd4j.linalg.factory.Nd4j; //导入依赖的package包/类
/**
 * Loads the next batch of at most {@code numExamples} examples, starting at {@code cursor},
 * into {@code curr} and advances {@code cursor} by the number of examples actually read.
 *
 * <p>The batch arrays are sized to the number of files that are really left, so a final,
 * shorter batch no longer carries trailing zero-length rows into {@code Nd4j.create}.
 *
 * @param numExamples the maximum number of examples to load
 */
@Override
public void fetch(int numExamples) {
	// Clamp the batch size to the number of files remaining after the cursor.
	int remaining = Math.max(0, m_allFileNames.size() - cursor);
	int batchSize = Math.min(numExamples, remaining);

	// NOTE(review): when nothing remains the batch has zero rows; confirm callers check for
	// remaining data before calling fetch.
	float[][] featureData = new float[batchSize][];
	float[][] labelData = new float[batchSize][];

	for (int i = 0; i < batchSize; i++) {
		Entry<String, String> entry = m_allFileNames.get(cursor + i);

		// The entry value is passed to the image converter, the key to the label encoder.
		featureData[i] = imageFileNameToMnsitFormat(entry.getValue());
		labelData[i] = toLabelArray(entry.getKey());
	}
	cursor += batchSize;

	INDArray features = Nd4j.create(featureData);
	INDArray labels = Nd4j.create(labelData);
	curr = new DataSet(features, labels);
}
 
开发者ID:braeunlich,项目名称:anagnostes,代码行数:23,代码来源:NumbersDataFetcher.java

示例2: nd4JExample

import org.nd4j.linalg.factory.Nd4j; //导入依赖的package包/类
/**
 * Multiplies a 4x2 matrix by a 2x3 matrix using ND4J and prints every row of the product
 * to standard output.
 */
public void nd4JExample() {
    // Row-major ('c' order) values of the 4x2 left operand.
    final double[] leftValues = {
        0.1950, 0.0311,
        0.3588, 0.2203,
        0.1716, 0.5931,
        0.2105, 0.3242};

    // Row-major ('c' order) values of the 2x3 right operand.
    final double[] rightValues = {
        0.0502, 0.9823, 0.9472,
        0.5732, 0.2694, 0.916};

    final INDArray product = Nd4j.create(leftValues, new int[]{4, 2}, 'c')
        .mmul(Nd4j.create(rightValues, new int[]{2, 3}, 'c'));

    int row = 0;
    while (row < product.rows()) {
        System.out.println(product.getRow(row));
        row++;
    }
}
 
开发者ID:PacktPublishing,项目名称:Machine-Learning-End-to-Endguide-for-Java-developers,代码行数:22,代码来源:MathExamples.java

示例3: readTestData

import org.nd4j.linalg.factory.Nd4j; //导入依赖的package包/类
/**
 * Reads labelled samples from a CSV file.
 *
 * <p>Each row is parsed as an integer class index in column 0 followed by three
 * floating-point features in columns 1 to 3. The features become a 3x1 input vector; the
 * output is a 3x1 zero vector with a 1 at the class index.
 *
 * <p>Reading is best-effort: if an I/O error occurs, the rows parsed so far are returned and
 * the failure is reported on standard error.
 *
 * @param fn path of the CSV file to read
 * @return the samples that could be read, possibly empty
 */
private static List<TrainingData> readTestData(String fn) {
    final int[] shape = { 3, 1 };
    final List<TrainingData> trainingDataSet = new ArrayList<>();
    // try-with-resources closes the reader (and the underlying file) on every exit path.
    try (CSVReader reader = new CSVReader(new FileReader(fn))) {
        String[] row;
        while ((row = reader.readNext()) != null) {
            int type = Integer.parseInt(row[0]);
            double f1 = Double.parseDouble(row[1]);
            double f2 = Double.parseDouble(row[2]);
            double f3 = Double.parseDouble(row[3]);
            TrainingData trainingData = new TrainingData();
            trainingData.input = Nd4j.create(new double[] { f1, f2, f3 }, shape);
            trainingData.output = Nd4j.zeros(shape);
            trainingData.output.putScalar(type, 1.0);
            trainingDataSet.add(trainingData);
        }
    } catch (java.io.IOException e) {
        // Keep the original best-effort contract, but do not hide the failure.
        System.err.println("Failed to read test data from '" + fn + "': " + e);
    }
    return trainingDataSet;
}
 
开发者ID:apuder,项目名称:ActivityMonitor,代码行数:22,代码来源:Main.java

示例4: collisionAvoidance

import org.nd4j.linalg.factory.Nd4j; //导入依赖的package包/类
/**
 * Computes a velocity vector for one boid from the positions of the other boids.
 *
 * <p>The positions of all boids in {@code otherBoids} that are not equal to
 * {@code currentBoid} are summed, the sum is divided by {@code otherBoids.size()}, the
 * position of {@code currentBoid} is subtracted, and the result is multiplied by
 * {@code scalingFactor}. The divisor is the size of the whole set, whether or not it contains
 * {@code currentBoid}.
 *
 * @param currentBoid   The boid to compute the velocity vector for.
 * @param otherBoids    A set of {@link Boid}s.
 * @param scalingFactor The factor the resulting vector is multiplied by. High values result
 *                      in large vectors.
 * @return The scaled vector described above.
 */
private INDArray collisionAvoidance(
    final O currentBoid,
    final Set<? extends O> otherBoids,
    final double scalingFactor
) {
  // Start from an array with the same shape as a position and accumulate the others.
  INDArray positionSum = Nd4j.create(currentBoid.getPosition().shape());
  for (final O boid : otherBoids) {
    if (boid.equals(currentBoid)) {
      continue;
    }
    positionSum = positionSum.add(boid.getPosition());
  }

  final INDArray averaged = positionSum.div(otherBoids.size());
  final INDArray relativeToCurrent = averaged.sub(currentBoid.getPosition());
  return relativeToCurrent.mul(scalingFactor);
}
 
开发者ID:delta-leonis,项目名称:algieba,代码行数:27,代码来源:SimpleBoid.java

示例5: velocityMatching

import org.nd4j.linalg.factory.Nd4j; //导入依赖的package包/类
/**
 * Returns a velocity vector for a specific boid, given a set of other boids and a range, such
 * that the boid will match its velocity vector with other boids within that range.
 *
 * <p>A boid counts as "within range" when the distance between its position and the position
 * of {@code currentBoid} is smaller than {@code range}. The range test is applied to the
 * boids' positions before their velocities are extracted; applying it to the velocity
 * vectors would compare a velocity with a position.
 *
 * @param currentBoid The boid to compute the velocity vector for.
 * @param otherBoids  A set of {@link Boid}s.
 * @param range       The range within which boids will be taken into account in computing the
 *                    velocity vector.
 * @return The accumulated vector: starting from an empty array, for every boid in range the
 * difference between that boid's velocity and the current boid's velocity is subtracted.
 */
private INDArray velocityMatching(
    final O currentBoid,
    final Set<? extends O> otherBoids,
    final double range
) {
  return otherBoids.stream()
      .filter(boid -> !boid.equals(currentBoid))
      // Range is a spatial criterion: test it on positions, then map to velocities.
      .filter(boid ->
          boid.getPosition().distance2(currentBoid.getPosition()) < range)
      .map(O::getVelocity)
      .reduce(
          Nd4j.create(currentBoid.getPosition().shape()),
          (velocity, neighbourVelocity)
              -> velocity.sub(
              neighbourVelocity.sub(
                  currentBoid.getVelocity())));
}
 
开发者ID:delta-leonis,项目名称:algieba,代码行数:29,代码来源:SimpleBoid.java

示例6: apply

import org.nd4j.linalg.factory.Nd4j; //导入依赖的package包/类
/**
 * Computes the filtered {@link Distribution} by delegating to another {@code apply} overload,
 * passing it a state-projection function and a measurement function built from the supplied
 * matrices, followed by the control input, two identity matrices, the process covariance, the
 * observation noise and the current state.
 *
 * <p>The two identity matrices are sized by the row counts of the state-transition matrix and
 * the measurement-transition matrix respectively.
 *
 * @param stateTransitionMatrix       The state-transition matrix used in projecting the {@link
 *                                    Distribution}.
 * @param measurementTransitionMatrix The measurement-transition matrix used in computing the
 *                                    measurement based on the supplied state.
 * @param controlTransitionMatrix     The control-transition matrix used in computing the effect
 *                                    of the control input on the resultant {@link Distribution}.
 * @param controlInput                The control input.
 * @param processCovariance           The process covariance.
 * @param observationNoise            The {@link Distribution} of observation noise.
 * @param state                       The {@link Distribution} of the latest state.
 * @return The filtered state.
 */
public Distribution apply(
    final INDArray stateTransitionMatrix,
    final INDArray measurementTransitionMatrix,
    final INDArray controlTransitionMatrix,
    final INDArray controlInput,
    final INDArray processCovariance,
    final Distribution observationNoise,
    final Distribution state
) {
  // State projection: F * mean + B * (B * control) + processCovariance.
  return apply((stateMean, controlVector) ->
          stateTransitionMatrix
              .mmul(stateMean)
              // NOTE(review): the control matrix is applied twice here (B * (B * u)); a
              // conventional linear model would use B * u. Confirm this is intended.
              .add(controlTransitionMatrix
                  .mmul(controlTransitionMatrix.mmul(controlVector)))
              // NOTE(review): this adds the process covariance to the projected mean;
              // confirm shapes and intent.
              .add(processCovariance),
      // NOTE(review): mul is the element-wise product in ND4J; confirm mmul was not intended.
      measurementTransitionMatrix::mul,
      controlInput,
      Nd4j.eye(stateTransitionMatrix.rows()),
      Nd4j.eye(measurementTransitionMatrix.rows()),
      processCovariance,
      observationNoise,
      state);
}
 
开发者ID:delta-leonis,项目名称:algieba,代码行数:41,代码来源:ExtendedKalmanFilter.java

示例7: getTrainingData

import org.nd4j.linalg.factory.Nd4j; //导入依赖的package包/类
/**
 * Generates the training set for the "sum of two numbers" task.
 *
 * <p>Draws {@code N_SAMPLES} pairs of values between {@code MIN_RANGE} and {@code MAX_RANGE}
 * from a {@link Random} seeded with {@code seed}, uses each pair as a two-column feature row
 * and their sum as the single-column label, shuffles the resulting data set and wraps it.
 *
 * @return the shuffled training data wrapped in a {@code FederatedDataSetImpl}
 */
@Override
public FederatedDataSet getTrainingData() {
    final Random generator = new Random(seed);
    final double span = MAX_RANGE - MIN_RANGE;

    final double[] totals = new double[N_SAMPLES];
    final double[] firstAddends = new double[N_SAMPLES];
    final double[] secondAddends = new double[N_SAMPLES];

    int index = 0;
    while (index < N_SAMPLES) {
        // Draw order matters for reproducibility: first addend, then second addend.
        firstAddends[index] = MIN_RANGE + span * generator.nextDouble();
        secondAddends[index] = MIN_RANGE + span * generator.nextDouble();
        totals[index] = firstAddends[index] + secondAddends[index];
        index++;
    }

    // Two N x 1 columns stacked side by side form the N x 2 feature matrix.
    final INDArray firstColumn = Nd4j.create(firstAddends, new int[]{N_SAMPLES, 1});
    final INDArray secondColumn = Nd4j.create(secondAddends, new int[]{N_SAMPLES, 1});
    final INDArray features = Nd4j.hstack(firstColumn, secondColumn);
    final INDArray labels = Nd4j.create(totals, new int[]{N_SAMPLES, 1});

    final DataSet trainingSet = new DataSet(features, labels);
    trainingSet.shuffle();
    return new FederatedDataSetImpl(trainingSet);
}
 
开发者ID:mccorby,项目名称:FederatedAndroidTrainer,代码行数:20,代码来源:SumDataSource.java

示例8: getTestData

import org.nd4j.linalg.factory.Nd4j; //导入依赖的package包/类
/**
 * Generates the test set for the "sum of two numbers" task.
 *
 * <p>Draws {@code N_SAMPLES / 10} pairs of values between {@code MIN_RANGE} and
 * {@code MAX_RANGE} from a {@link Random} seeded with {@code seed}; each pair is a two-column
 * feature row and their sum is the single-column label. The data set is not shuffled.
 *
 * @return the test data wrapped in a {@code FederatedDataSetImpl}
 */
@Override
public FederatedDataSet getTestData() {
    final Random generator = new Random(seed);
    final int sampleCount = N_SAMPLES / 10;
    final double span = MAX_RANGE - MIN_RANGE;

    final double[] totals = new double[sampleCount];
    final double[] firstAddends = new double[sampleCount];
    final double[] secondAddends = new double[sampleCount];

    for (int index = 0; index != sampleCount; ++index) {
        // Draw order matters for reproducibility: first addend, then second addend.
        final double first = MIN_RANGE + span * generator.nextDouble();
        final double second = MIN_RANGE + span * generator.nextDouble();
        firstAddends[index] = first;
        secondAddends[index] = second;
        totals[index] = first + second;
    }

    // Two columns stacked side by side form the feature matrix.
    final INDArray features = Nd4j.hstack(
        Nd4j.create(firstAddends, new int[]{sampleCount, 1}),
        Nd4j.create(secondAddends, new int[]{sampleCount, 1}));
    final INDArray labels = Nd4j.create(totals, new int[]{sampleCount, 1});

    return new FederatedDataSetImpl(new DataSet(features, labels));
}
 
开发者ID:mccorby,项目名称:FederatedAndroidTrainer,代码行数:19,代码来源:SumDataSource.java

示例9: testSVDPCA

import org.nd4j.linalg.factory.Nd4j; //导入依赖的package包/类
/**
 * Checks that {@code Par2HierUtils.svdPCA} keeps every key of the input table and reduces
 * each vector to the requested two columns.
 */
@Test
public void testSVDPCA() throws Exception {
  double[][] data = getDoubles();

  Map<String, INDArray> weightTable = new TreeMap<>();
  // A fixed seed keeps the generated keys, and therefore the test, reproducible.
  Random r = new Random(42L);
  for (double[] d : data) {
    byte[] bytes = new byte[10];
    r.nextBytes(bytes);
    // ISO-8859-1 maps every byte to exactly one char, so distinct byte sequences always
    // yield distinct keys regardless of the platform default charset.
    String key = new String(bytes, java.nio.charset.StandardCharsets.ISO_8859_1);
    weightTable.put(key, Nd4j.create(d));
  }
  // Guard against rows being silently dropped by colliding keys.
  assertEquals(data.length, weightTable.size());

  Map<String, INDArray> svdPCA = Par2HierUtils.svdPCA(weightTable, 2);
  assertEquals(weightTable.size(), svdPCA.size());
  for (Map.Entry<String, INDArray> e : svdPCA.entrySet()) {
    assertEquals(2, e.getValue().columns());
    assertNotNull(weightTable.get(e.getKey()));
  }
}
 
开发者ID:tteofili,项目名称:par2hier,代码行数:19,代码来源:Par2HierUtilsTest.java

示例10: State

import org.nd4j.linalg.factory.Nd4j; //导入依赖的package包/类
/**
 * Creates a state from scalar values by delegating to the constructor that takes an id, a
 * distribution and a team color.
 *
 * <p>The distribution is a {@code SimpleDistribution} built from a 4x1 column vector holding
 * {@code timestamp}, {@code x}, {@code y} and {@code orientation} (in that order) and a 4x4
 * identity matrix.
 *
 * @param id          the identifier passed through to the delegated constructor
 * @param timestamp   first element of the column vector
 * @param x           second element of the column vector
 * @param y           third element of the column vector
 * @param orientation fourth element of the column vector
 * @param teamColor   the team color passed through to the delegated constructor
 */
public State(
    final int id,
    final double timestamp,
    final double x,
    final double y,
    final double orientation,
    final TeamColor teamColor
) {
  this(
      id,
      // First argument: 4x1 vector of the scalar inputs. Second argument: 4x4 identity
      // (presumably the initial covariance; confirm against SimpleDistribution).
      new SimpleDistribution(Nd4j.create(
          new double[]{
              timestamp,
              x,
              y,
              orientation
          },
          new int[]{4, 1}), Nd4j.eye(4)),
      teamColor);
}
 
开发者ID:delta-leonis,项目名称:subra,代码行数:21,代码来源:Player.java

示例11: strategize

import org.nd4j.linalg.factory.Nd4j; //导入依赖的package包/类
/**
 * Builds a strategy that assigns every player of the first buffered game state a
 * {@code PlayerCommand.State}.
 *
 * <p>For each player the coordinate magnitudes for indices 0 and 1 are stacked, rotated by
 * the negated player orientation, and stacked with the magnitude for index 2. The remaining
 * three constructor arguments of the command are 0.
 *
 * @param gameBuffer A {@link List} of {@link Tuple2} containing a game state with the
 *                   corresponding formation, and a mapping from {@link Player} to a vector
 *                   representing the difference between the agent's current position and the
 *                   desired position.
 * @return The {@link Strategy.Supplier strategy} which minimizes the difference between {@link
 * Player agent} positions and their {@link Formation} positions.
 */
public Strategy.Supplier strategize(
    final List<Tuple2<F, Map<PlayerIdentity, Tuple2<Player, INDArray>>>> gameBuffer
) {
  return () ->
      gameBuffer.get(0).getT1().getPlayers().stream()
          .collect(Collectors.toMap(
              Player::getIdentity,
              player -> {
                // Magnitudes for indices 0 and 1, stacked into one vector.
                final INDArray planarMagnitudes = Nd4j.vstack(
                    this.computeCoordinateMagnitude(gameBuffer, player, 0),
                    this.computeCoordinateMagnitude(gameBuffer, player, 1));
                // Rotate by the negated orientation of the player.
                final INDArray rotatedPlanar = Vectors.rotatePlanarCartesian(
                    planarMagnitudes,
                    -1 * player.getOrientation());
                final INDArray thirdMagnitude =
                    this.computeCoordinateMagnitude(gameBuffer, player, 2);
                return new PlayerCommand.State(
                    Nd4j.vstack(rotatedPlanar, thirdMagnitude),
                    0,
                    0,
                    0);
              }));
}
 
开发者ID:delta-leonis,项目名称:subra,代码行数:26,代码来源:PSDFormationDeducer.java

示例12: apply

import org.nd4j.linalg.factory.Nd4j; //导入依赖的package包/类
/**
 * Maps every game emitted by the input publisher to a {@code PositionFormation}.
 *
 * <p>For each game, the players whose team color equals this deducer's team color are
 * collected into a map from player identity to a stacked vector. The upper part of that vector
 * is the ball's XY plus the unit vector of (ball XY minus player XY) multiplied by
 * {@code getDistanceFromBall()}. The lower part is a one-element array holding the arc cosine
 * of the cosine similarity between the player's XY and the ball's XY.
 *
 * @param inputPublisher the publisher of games to derive formations from
 * @return a publisher emitting one {@code PositionFormation} per input game
 */
@Override
public Publisher<PositionFormation> apply(final Publisher<G> inputPublisher) {
  return Flux.from(inputPublisher)
      .map(game ->
          new PositionFormation(
              game.getPlayers().stream()
                  // Only players of this deducer's own team get a formation position.
                  .filter(player -> player.getTeamColor().equals(this.getTeamColor()))
                  .collect(Collectors.toMap(
                      Player::getIdentity,
                      player ->
                          Nd4j.vstack(
                              // Position part: ball XY offset along the player-to-ball
                              // direction by the configured distance.
                              game.getBall().getXY()
                                  .add(Transforms
                                      .unitVec(game.getBall().getXY().sub(player.getXY()))
                                      .mul(this.getDistanceFromBall())),
                              // Angle part: angle between the player XY and ball XY vectors.
                              Nd4j.create(new double[]{
                                  Math.acos(Transforms.cosineSim(
                                      player.getXY(),
                                      game.getBall().getXY()))}))))));
}
 
开发者ID:delta-leonis,项目名称:subra,代码行数:21,代码来源:BallTrackerFormationDeducer.java

示例13: apply

import org.nd4j.linalg.factory.Nd4j; //导入依赖的package包/类
/**
 * Turns a stream of inputs into a stream of {@code MovingBall} states using the Kalman filter.
 *
 * <p>The publisher is scanned starting from an all-zero {@code MovingBall.State}. For each
 * input, its balls are reduced pairwise: every reduction step builds a new
 * {@code MovingBall.State} from the result of {@code kalmanFilter.apply(...)}. An input
 * without balls yields an all-zero state.
 *
 * @param inputPublisher the publisher of inputs that expose a collection of balls
 * @return a publisher emitting the initial state followed by one state per input
 */
@Override
public Publisher<MovingBall> apply(final Publisher<I> inputPublisher) {
  return Flux.from(inputPublisher)
      .scan(
          new MovingBall.State(0, 0, 0, 0, 0, 0, 0),
          // NOTE(review): previousResult is not used by the accumulator; each emitted state
          // depends only on the current input. Confirm this is intended.
          (previousResult, input) -> input.getBalls().stream()
              .reduce((previousBall, foundBall) ->
                  new MovingBall.State(
                      this.kalmanFilter.apply(
                          getStateTransitionMatrix(
                              // Timestamp difference scaled by 1e-6 (presumably microseconds
                              // to seconds; confirm the timestamp unit).
                              (foundBall.getTimestamp() - previousBall.getTimestamp())
                                  / 1000000d),
                          MEASUREMENT_TRANSITION_MATRIX,
                          CONTROL_TRANSITION_MATRIX,
                          // Control input: a 7x1 zero vector.
                          Nd4j.zeros(7, 1),
                          PROCESS_COVARIANCE_MATRIX,
                          new SimpleDistribution(
                              previousBall.getState().getMean(),
                              MEASUREMENT_COVARIANCE_MATRIX),
                          foundBall.getState())))
              // Empty ball collection: fall back to the all-zero state.
              .orElse(new MovingBall.State(0, 0, 0, 0, 0, 0, 0)));
}
 
开发者ID:delta-leonis,项目名称:subra,代码行数:23,代码来源:MovingBallsKalmanFilter.java

示例14: getConceptVector

import org.nd4j.linalg.factory.Nd4j; //导入依赖的package包/类
/**
 * Builds a vector for a concept by averaging the vectors of the tokens of its name.
 *
 * <p>The concept name is lower-cased, trimmed and tokenized. Tokens known to
 * {@code wordVectors} contribute their word vector; all other tokens contribute
 * {@code unkVector}.
 *
 * @param c the concept whose name is embedded
 * @return the mean of the token vectors, or {@code null} if no token is known
 */
public INDArray getConceptVector(Concept c) {
	Tokenizer tokenizer = SimpleTokenizer.INSTANCE;

	List<INDArray> tokenVectors = new ArrayList<INDArray>();
	int knownTokens = 0;
	for (String token : tokenizer.tokenize(c.name.toLowerCase().trim())) {
		if (!wordVectors.hasWord(token)) {
			tokenVectors.add(unkVector);
			continue;
		}
		tokenVectors.add(wordVectors.getWordVectorMatrix(token));
		knownTokens++;
	}

	// Nothing to average when every token (or no token at all) was unknown.
	if (knownTokens == 0) {
		return null;
	}

	// Sum or mean is irrelevant for cosine similarity.
	return Nd4j.vstack(tokenVectors).mean(0);
}
 
开发者ID:UKPLab,项目名称:ijcnlp2017-cmaps,代码行数:24,代码来源:WordEmbeddingDistance.java

示例15: fromText

import org.nd4j.linalg.factory.Nd4j; //导入依赖的package包/类
/**
 * Reads word vectors from a text file.
 *
 * <p>The first line holds the vocabulary size and the layer size, separated by a space. Every
 * following line holds a word followed by its vector components, separated by spaces. Row
 * {@code i} of the returned matrix belongs to the {@code i}-th word of the returned list.
 *
 * @param wordFilePath path of the text file to read
 * @return the vocabulary in file order and the matrix of word vectors
 * @throws IOException if the file cannot be read or is empty
 * @throws IllegalArgumentException if a line does not contain exactly {@code layerSize} values
 */
private static Pair<List<String>, INDArray> fromText(String wordFilePath) throws IOException {
	// try-with-resources closes the file on every exit path, including parse failures.
	try (FileInputStream in = new FileInputStream(new File(wordFilePath));
			BufferedReader reader = new BufferedReader(Common.asReaderUTF8Lenient(in))) {
		String fstLine = reader.readLine();
		if (fstLine == null) {
			throw new IOException("Word vector file '" + wordFilePath + "' is empty");
		}
		// Parse the header once: "<vocabSize> <layerSize>".
		String[] header = fstLine.split(" ");
		int vocabSize = Integer.parseInt(header[0]);
		int layerSize = Integer.parseInt(header[1]);
		List<String> wordVocab = Lists.newArrayList();
		INDArray wordVectors = Nd4j.create(vocabSize, layerSize);
		// n is the 1-based index of the current word line, used for rows and error messages.
		int n = 1;
		String line;
		while ((line = reader.readLine()) != null) {
			String[] values = line.split(" ");
			wordVocab.add(values[0]);
			Preconditions.checkArgument(layerSize == values.length - 1, "For file '%s', on line %s, layer size is %s, but found %s values in the word vector",
					wordFilePath, n, layerSize, values.length - 1); // Sanity check
			for (int d = 1; d < values.length; d++) {
				wordVectors.putScalar(n - 1, d - 1, Float.parseFloat(values[d]));
			}
			n++;
		}
		return new Pair<>(wordVocab, wordVectors);
	}
}
 
开发者ID:IsaacChanghau,项目名称:Word2VecfJava,代码行数:20,代码来源:WordVectorSerializer.java


注:本文中的org.nd4j.linalg.factory.Nd4j类示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。