当前位置: 首页>>代码示例>>Java>>正文


Java INDArray.addi方法代码示例

本文整理汇总了Java中org.nd4j.linalg.api.ndarray.INDArray.addi方法的典型用法代码示例。如果您正苦于以下问题：Java INDArray.addi方法的具体用法？Java INDArray.addi怎么用？Java INDArray.addi使用的例子？那么，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在类org.nd4j.linalg.api.ndarray.INDArray的用法示例。


在下文中一共展示了INDArray.addi方法的7个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Java代码示例。

示例1: lexicalSubstituteAdd

import org.nd4j.linalg.api.ndarray.INDArray; //导入方法依赖的package包/类
/**
 * Lexical substitution by the additive method.
 *
 * <p>The target word vector is added to the sum of the vectors of all known contexts (optionally
 * averaged). The result is L2-normalised and scored against every vocabulary word with
 * {@code wordSimilarity}.
 *
 * @param word target word
 * @param contexts list of given contexts; contexts for which {@code hasContext} is false are
 *     skipped, and a {@code null} list is treated as empty
 * @param average if true, divide the summed context vector by the number of contexts found
 * @param top maximum number of results to return; {@code null} means 10
 * @return at most {@code top} (word, score) {@link Pair}s, sorted by descending score
 */
public List<Pair<String, Double>> lexicalSubstituteAdd (String word, List<String> contexts, boolean average, Integer top) {
	top = MoreObjects.firstNonNull(top, 10);
	// Work on a copy: getWordVector may return a view into the model's word matrix (TODO confirm),
	// and the in-place addi/divi below would otherwise silently modify the stored embedding.
	INDArray targetVec = getWordVector(word).dup();
	INDArray ctxVec = zeros();
	int found = 0;
	if (contexts != null) {
		for (String context : contexts) {
			if (!hasContext(context)) continue;
			found++;
			ctxVec.addi(getContextVector(context));
		}
	}
	if (average && found != 0) ctxVec.divi(Nd4j.scalar(found));
	targetVec.addi(ctxVec);
	// L2 norm via the dot product with itself; an all-zero vector is left unscaled.
	double norm = Math.sqrt(targetVec.mmul(targetVec.transpose()).getDouble(0));
	if (norm == 0.0) norm = 1.0;
	targetVec.divi(Nd4j.scalar(norm));
	INDArray scores = wordSimilarity(targetVec);
	List<Pair<String, Double>> list = new ArrayList<>(wordVocab.size());
	for (int i = 0; i < wordVocab.size(); i++) {
		list.add(new Pair<>(wordVocab.get(i), scores.getDouble(i)));
	}
	// Highest score first.
	return list.stream()
			.sorted((e1, e2) -> Double.compare(e2.getValue(), e1.getValue()))
			.limit(top)
			.collect(Collectors.toCollection(LinkedList::new));
}
 
开发者ID:IsaacChanghau,项目名称:Word2VecfJava,代码行数:29,代码来源:Word2Vecf.java

示例2: updateWeights

import org.nd4j.linalg.api.ndarray.INDArray; //导入方法依赖的package包/类
/**
 * Applies values received from the federated server to this model.
 *
 * <p>The method first logs the incoming array and the per-variable shapes of the network's current
 * gradient. It then adds {@code remoteGradient} to the array returned by
 * {@code mNetwork.params(true)}. Finally it passes the same values, wrapped in a
 * {@link DefaultGradient}, to {@code mNetwork.update}.
 *
 * <p>NOTE(review): the remote values are both added to the parameter array and handed to
 * {@code update(...)}. Confirm whether this double application is intended. Also confirm whether
 * {@code params(true)} returns a live view; if it returns a copy, the {@code addi} has no effect.
 *
 * @param remoteGradient flattened gradient received from the server
 */
@Override
public void updateWeights(INDArray remoteGradient) {
    Log.d(TAG, "Remote Gradient " + remoteGradient);
    Gradient wrapped = new DefaultGradient(remoteGradient);
    Log.d(TAG, "Updating weights from server with gradient " + wrapped.gradient().toString());

    // TODO Transform the remoteGradient flattened array into the map required by the network?
    Map<String, INDArray> perVariable = mNetwork.gradient().gradientForVariable();
    for (Map.Entry<String, INDArray> variable : perVariable.entrySet()) {
        INDArray values = variable.getValue();
        Log.d(TAG, variable.getKey());
        for (int extent : values.shape()) {
            Log.d(TAG, "Shape " + extent);
        }
        int rank = values.shape().length;
        for (int dim = 0; dim < rank; dim++) {
            Log.d(TAG, "Size (" + dim + ")" + values.size(dim));
        }
    }

    Log.d(TAG, "Updating weights with INDArray object");
    mNetwork.params(true).addi(remoteGradient);

    // Shapes logged in an earlier run: 0_W [2, 10], 0_b [1, 10], 1_W [10, 1], 1_b [1, 1].
    mNetwork.update(wrapped);
    Log.d(TAG, "Weights updated");
}
 
开发者ID:mccorby,项目名称:FederatedAndroidTrainer,代码行数:40,代码来源:LinearModel.java

示例3: getContextVectorMean

import org.nd4j.linalg.api.ndarray.INDArray; //导入方法依赖的package包/类
/**
 * Averages the context vectors of the given contexts.
 *
 * <p>Every entry is looked up with {@code getContextVector}; no {@code hasContext} filtering is
 * done here.
 *
 * @param contexts contexts whose vectors are averaged
 * @return the element-wise mean vector, or the result of {@code zeros()} when {@code contexts}
 *     is null or empty
 */
public INDArray getContextVectorMean (List<String> contexts) {
	INDArray sum = zeros();
	if (contexts != null && !contexts.isEmpty()) {
		for (String context : contexts) {
			sum.addi(getContextVector(context));
		}
		sum = sum.div(Nd4j.scalar(contexts.size()));
	}
	return sum;
}
 
开发者ID:IsaacChanghau,项目名称:Word2VecfJava,代码行数:8,代码来源:Word2Vecf.java

示例4: lexicalSubstituteAdaptive

import org.nd4j.linalg.api.ndarray.INDArray; //导入方法依赖的package包/类
/**
 * Lexical substitution by the adaptive method.
 *
 * <p>The method computes two score vectors:
 * <ul>
 *   <li>word-to-word scores ({@code wordSimilarity} of the target vector);</li>
 *   <li>for each vocabulary word w, the dot product of (w - target) with the mean vector of the
 *       known contexts.</li>
 * </ul>
 * Each is shifted to be non-negative and divided by its maximum. The two are blended with weight
 * {@code parameter} on the context scores, and the result is divided by its maximum again.
 *
 * @param word target word
 * @param contexts list of given contexts; unknown contexts are skipped, {@code null} is treated
 *     as empty
 * @param parameter weight of the context scores, expected in [0, 1]; values outside that range
 *     (or NaN) fall back to 0.5
 * @param top maximum number of results to return; {@code null} means 10
 * @return at most {@code top} (word, score) {@link Pair}s, sorted by descending score
 */
public List<Pair<String, Double>> lexicalSubstituteAdaptive (String word, List<String> contexts, double parameter, Integer top) {
	top = MoreObjects.firstNonNull(top, 10);
	// A single value can never be both < 0 and > 1, so the range check must use '||'.
	if (Double.isNaN(parameter) || parameter < 0.0 || parameter > 1.0)
		parameter = 0.5; // set default value
	INDArray targetVec = getWordVector(word);
	INDArray ctxVec = zeros();
	int found = 0;
	if (contexts != null) {
		for (String context : contexts) {
			if (hasContext(context)) {
				found++;
				ctxVec.addi(getContextVector(context));
			}
		}
	}
	if (found != 0) ctxVec.divi(Nd4j.scalar(found));
	INDArray wscores = wordSimilarity(targetVec);
	rescaleScores(wscores);
	INDArray cscores = wordVectors.subRowVector(targetVec).mmul(ctxVec.transpose());
	rescaleScores(cscores);
	INDArray scores = wscores.mul(1.0 - parameter).add(cscores.mul(parameter));
	double maxScore = scores.maxNumber().doubleValue();
	if (maxScore != 0.0) scores.divi(maxScore);
	List<Pair<String, Double>> list = new ArrayList<>(wordVocab.size());
	for (int i = 0; i < wordVocab.size(); i++) {
		list.add(new Pair<>(wordVocab.get(i), scores.getDouble(i)));
	}
	// Highest score first.
	return list.stream()
			.sorted((e1, e2) -> Double.compare(e2.getValue(), e1.getValue()))
			.limit(top)
			.collect(Collectors.toCollection(LinkedList::new));
}

/**
 * Rescales {@code scores} in place.
 *
 * <p>If the minimum is negative, every element is shifted by -min so the minimum becomes 0. The
 * array is then divided by its maximum. The division is skipped when the maximum is 0, which would
 * otherwise fill the array with NaN.
 */
private static void rescaleScores(INDArray scores) {
	double min = scores.minNumber().doubleValue();
	if (min < 0.0) scores.addi(-min);
	double max = scores.maxNumber().doubleValue();
	if (max != 0.0) scores.divi(max);
}
 
开发者ID:IsaacChanghau,项目名称:Word2VecfJava,代码行数:35,代码来源:Word2Vecf.java

示例5: getWordVectorMean

import org.nd4j.linalg.api.ndarray.INDArray; //导入方法依赖的package包/类
/**
 * Averages the word vectors of the given words.
 *
 * <p>Every entry is looked up with {@code getWordVector}; no vocabulary check is done here.
 *
 * @param words words whose vectors are averaged
 * @return the element-wise mean vector, or the result of {@code zeros()} when {@code words} is
 *     null or empty
 */
public INDArray getWordVectorMean (List<String> words) {
	INDArray total = zeros();
	boolean nothingToAverage = words == null || words.isEmpty();
	if (nothingToAverage) {
		return total;
	}
	for (String w : words) {
		total.addi(getWordVector(w));
	}
	INDArray count = Nd4j.scalar(words.size());
	return total.div(count);
}
 
开发者ID:IsaacChanghau,项目名称:Word2VecfJava,代码行数:8,代码来源:Word2Vec.java

示例6: updateWeights

import org.nd4j.linalg.api.ndarray.INDArray; //导入方法依赖的package包/类
/**
 * Adds values received from the federated server to this model's parameters.
 *
 * <p>NOTE(review): this relies on {@code model.params(true)} returning a live view of the
 * parameters. If it returns a copy, the update is lost. Confirm against the DL4J version in use.
 *
 * @param remoteGradient flattened array from the server; presumably must match the parameter
 *     array's length (TODO confirm)
 */
@Override
public void updateWeights(INDArray remoteGradient) {
    Log.d(TAG, "Updating weights with INDArray object");
    model.params(true).addi(remoteGradient);
}
 
开发者ID:mccorby,项目名称:FederatedAndroidTrainer,代码行数:7,代码来源:IrisModel.java

示例7: getPar2HierVector

import org.nd4j.linalg.api.ndarray.INDArray; //导入方法依赖的package包/类
/**
 * Computes the hierarchical vector (hv) of {@code node} and memoises it in {@code hvs}.
 *
 * <p>Base case: on a leaf, hv = pv, the paragraph vector from {@code lookupTable}.
 *
 * <p>On a non-leaf node with n children, the children's hvs are stacked into an n-row matrix and
 * reduced to centroids:
 * <ul>
 *   <li>k rows of the truncated V^T when n &gt; k;</li>
 *   <li>the single child vector when n == 1;</li>
 *   <li>otherwise 1 row.</li>
 * </ul>
 * The centroids are then combined with pv:
 * <ul>
 *   <li>{@code CLUSTER}: hv is built from the truncated V^T (1 row) of [pv; centroids], as
 *       returned by {@code Par2HierUtils.getTruncatedVT}.</li>
 *   <li>{@code SUM}: hv = pv + sum of the centroids.</li>
 * </ul>
 *
 * <p>Children are only looked up when {@code node.split(REGEX)} has exactly two parts. A child is
 * a trie entry whose key starts with the derived prefix and has no further "." after that prefix.
 *
 * @param lookupTable table providing the paragraph vector of each node
 * @param trie node labels, used to find the direct children of {@code node} by prefix
 * @param node label of the node to compute
 * @param k maximum number of centroids extracted from the children's vectors
 * @param hvs cache of already computed hierarchical vectors keyed by node label; updated in place
 * @param method how pv and the centroids are combined
 * @return the hierarchical vector of {@code node}
 */
private static INDArray getPar2HierVector(WeightLookupTable<VocabWord> lookupTable, PatriciaTrie<String> trie, String node,
                                          int k, Map<String, INDArray> hvs, Method method) {
  if (hvs.containsKey(node)) {
    return hvs.get(node);
  }
  INDArray hv = lookupTable.vector(node);
  String[] split = node.split(REGEX);
  Collection<String> descendants = new HashSet<>();
  if (split.length == 2) {
    String separator = ".";
    // NOTE(review): indexOf finds the FIRST occurrence of split[1]; if that text also occurs inside
    // split[0], the prefix is cut too early. Confirm against the label format REGEX is meant for.
    String prefix = node.substring(0, node.indexOf(split[1])) + separator;

    SortedMap<String, String> sortedMap = trie.prefixMap(prefix);

    for (Map.Entry<String, String> entry : sortedMap.entrySet()) {
      // Same position of the last separator => no deeper level, i.e. a direct child.
      if (prefix.lastIndexOf(separator) == entry.getKey().lastIndexOf(separator)) {
        descendants.add(entry.getValue());
      }
    }
  } else {
    descendants = Collections.emptyList();
  }
  if (descendants.isEmpty()) {
    // just the pv
    hvs.put(node, hv);
    return hv;
  }

  // Stack the children's hierarchical vectors, computed recursively, one per row.
  INDArray chvs = Nd4j.zeros(descendants.size(), hv.columns());
  int i = 0;
  for (String desc : descendants) {
    INDArray chv = getPar2HierVector(lookupTable, trie, desc, k, hvs, method);
    chvs.putRow(i, chv);
    i++;
  }

  double[][] centroids;
  if (chvs.rows() > k) {
    centroids = Par2HierUtils.getTruncatedVT(chvs, k);
  } else if (chvs.rows() == 1) {
    centroids = Par2HierUtils.getDoubles(chvs.getRow(0));
  } else {
    centroids = Par2HierUtils.getTruncatedVT(chvs, 1);
  }
  switch (method) {
    case CLUSTER:
      INDArray matrix = Nd4j.zeros(centroids.length + 1, hv.columns());
      matrix.putRow(0, hv);
      for (int c = 0; c < centroids.length; c++) {
        matrix.putRow(c + 1, Nd4j.create(centroids[c]));
      }
      hv = Nd4j.create(Par2HierUtils.getTruncatedVT(matrix, 1));
      break;
    case SUM:
      // Accumulate into a copy: lookupTable.vector(...) may return a view of the table's weights
      // (TODO confirm for the DL4J version in use), and addi would overwrite the stored pv.
      hv = hv.dup();
      for (double[] centroid : centroids) {
        hv.addi(Nd4j.create(centroid));
      }
      break;
  }

  hvs.put(node, hv);
  return hv;
}
 
开发者ID:tteofili,项目名称:par2hier,代码行数:69,代码来源:Par2HierUtils.java


注：本文中的org.nd4j.linalg.api.ndarray.INDArray.addi方法示例由纯净天空整理自GitHub/MSDocs等开源代码及文档管理平台，相关代码片段筛选自各路编程大神贡献的开源项目，源码版权归原作者所有，传播和使用请参考对应项目的License；未经允许，请勿转载。