本文整理汇总了Java中org.nd4j.linalg.primitives.Triple类的典型用法代码示例。如果您正苦于以下问题：Java Triple类的具体用法？Java Triple怎么用？Java Triple使用的例子？那么恭喜您，这里精选的类代码示例或许可以为您提供帮助。
Triple类属于org.nd4j.linalg.primitives包，在下文中一共展示了Triple类的12个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Java代码示例。
示例1: testBroadcastShapes
import org.nd4j.linalg.primitives.Triple; //导入依赖的package包/类
/**
 * Checks Shape.broadcastOutputShape against a table of input-shape pairs and the shape
 * their broadcast combination is expected to have.
 */
@Test
public void testBroadcastShapes() {
    // Each entry: {shape of in1, shape of in2, expected shape of op(in1, in2)}
    List<Triple<int[], int[], int[]>> cases = new ArrayList<>();
    cases.add(new Triple<>(new int[] {3, 1}, new int[] {1, 4}, new int[] {3, 4}));
    cases.add(new Triple<>(new int[] {3, 1}, new int[] {3, 4}, new int[] {3, 4}));
    cases.add(new Triple<>(new int[] {3, 4}, new int[] {1, 4}, new int[] {3, 4}));
    cases.add(new Triple<>(new int[] {3, 4, 1}, new int[] {1, 1, 5}, new int[] {3, 4, 5}));
    cases.add(new Triple<>(new int[] {3, 4, 1}, new int[] {3, 1, 5}, new int[] {3, 4, 5}));
    cases.add(new Triple<>(new int[] {3, 1, 5}, new int[] {1, 4, 1}, new int[] {3, 4, 5}));
    cases.add(new Triple<>(new int[] {3, 1, 5}, new int[] {1, 4, 5}, new int[] {3, 4, 5}));
    cases.add(new Triple<>(new int[] {3, 1, 5}, new int[] {3, 4, 5}, new int[] {3, 4, 5}));
    cases.add(new Triple<>(new int[] {3, 1, 1, 1}, new int[] {1, 4, 5, 6}, new int[] {3, 4, 5, 6}));
    cases.add(new Triple<>(new int[] {1, 1, 1, 6}, new int[] {3, 4, 5, 6}, new int[] {3, 4, 5, 6}));
    cases.add(new Triple<>(new int[] {1, 4, 5, 1}, new int[] {3, 1, 1, 6}, new int[] {3, 4, 5, 6}));
    // Inputs of different rank: the shorter shape is aligned against the trailing dimensions.
    cases.add(new Triple<>(new int[] {1, 6}, new int[] {3, 4, 5, 1}, new int[] {3, 4, 5, 6}));

    for (int idx = 0; idx < cases.size(); idx++) {
        Triple<int[], int[], int[]> tc = cases.get(idx);
        int[] lhs = tc.getFirst();
        int[] rhs = tc.getSecond();
        int[] expected = tc.getThird();
        assertArrayEquals(expected, Shape.broadcastOutputShape(lhs, rhs));
    }
}
示例2: getGraphInfo
import org.nd4j.linalg.primitives.Triple; //导入依赖的package包/类
/**
 * Builds graph info from the configuration returned by {@code getConfig()}.
 *
 * <p>The triple's entries are tried in order (MultiLayerConfiguration, then
 * ComputationGraphConfiguration, then NeuralNetConfiguration), and the first non-null one
 * is passed to {@code TrainModuleUtils.buildGraphInfo}.
 *
 * @return the built graph info, or {@code null} when there is no config or all entries are null
 */
private TrainModuleUtils.GraphInfo getGraphInfo() {
    Triple<MultiLayerConfiguration, ComputationGraphConfiguration, NeuralNetConfiguration> cfg = getConfig();
    if (cfg == null) {
        return null;
    }
    if (cfg.getFirst() != null) {
        return TrainModuleUtils.buildGraphInfo(cfg.getFirst());
    }
    if (cfg.getSecond() != null) {
        return TrainModuleUtils.buildGraphInfo(cfg.getSecond());
    }
    if (cfg.getThird() != null) {
        return TrainModuleUtils.buildGraphInfo(cfg.getThird());
    }
    return null;
}
示例3: Word2VecChange
import org.nd4j.linalg.primitives.Triple; //导入依赖的package包/类
/**
 * Collects, per point index, the syn1 weight slices touched during training.
 *
 * @param counterMap triples whose first element is the point (map key) and whose second element
 *                   is the syn1 row index to slice; the third element is not used here
 * @param param      the parameters holding the syn1 weights
 */
public Word2VecChange(List<Triple<Integer, Integer, Integer>> counterMap, Word2VecParam param) {
    for (Triple<Integer, Integer, Integer> entry : counterMap) {
        Integer point = entry.getFirst();
        Integer index = entry.getSecond();
        // Lazily create the set for this point the first time it is seen.
        Set<INDArray> bucket = this.changes.get(point);
        if (bucket == null) {
            bucket = new HashSet<>();
            this.changes.put(point, bucket);
        }
        bucket.add(param.getWeights().getSyn1().slice(index));
    }
}
示例4: skipGram
import org.nd4j.linalg.primitives.Triple; //导入依赖的package包/类
/**
 * Train via skip gram around a single center word.
 *
 * <p>Visits window offsets {@code a} in {@code [b, 2 * window + 1 - b)}, skipping
 * {@code a == window} (the center word itself), and calls {@code iterateSample} for every
 * context position {@code i - window + a} that lies inside the sentence. A larger {@code b}
 * therefore narrows the context on both sides; a negative {@code b} would widen it.
 *
 * @param param    the Word2Vec parameters; the window size is read here and the whole object is
 *                 passed on to {@code iterateSample}
 * @param i        index of the current (center) word in {@code sentence}
 * @param sentence the sentence to train on; {@code null} or empty is a no-op
 * @param b        how many positions the window is shrunk by on each side
 * @param alpha    the learning rate
 * @param changed  passed through to {@code iterateSample} (presumably collects the indices it
 *                 updates; confirm against iterateSample)
 */
public void skipGram(Word2VecParam param, int i, List<VocabWord> sentence, int b, double alpha,
                List<Triple<Integer, Integer, Integer>> changed) {
    // Guard before indexing so an empty sentence is a no-op rather than an
    // IndexOutOfBoundsException from sentence.get(i).
    if (sentence == null || sentence.isEmpty()) {
        return;
    }
    final VocabWord word = sentence.get(i);
    if (word == null) {
        return;
    }
    int window = param.getWindow();
    int end = window * 2 + 1 - b;
    for (int a = b; a < end; a++) {
        if (a == window) {
            continue; // the center word is not its own context
        }
        int c = i - window + a;
        if (c >= 0 && c < sentence.size()) {
            VocabWord lastWord = sentence.get(c);
            iterateSample(param, word, lastWord, alpha, changed);
        }
    }
}
示例5: testSliceGradient
import org.nd4j.linalg.primitives.Triple; //导入依赖的package包/类
/**
 * Runs a SameDiff gradient check on the standard deviation of a slice, for several input
 * shapes and slice parameters, and fails if any gradient check does not pass.
 */
@Test
public void testSliceGradient() {
    Nd4j.getRandom().setSeed(12345);
    //Order here: original shape, begin, size
    // NOTE(review): if the third element really is a size, several cases below are out of
    // bounds (e.g. begin {1,1} + size {3,4} on a [3,4] input). Every case is in bounds if it is
    // an exclusive end index, which matches the "end=" label in the log message. Confirm which
    // convention sd.slice expects, then fix either the data or this comment.
    List<Triple<int[], int[], int[]>> testCases = new ArrayList<>();
    testCases.add(new Triple<>(new int[]{3, 4}, new int[]{0, 0}, new int[]{3, 4}));
    testCases.add(new Triple<>(new int[]{3, 4}, new int[]{1, 1}, new int[]{3, 4}));
    testCases.add(new Triple<>(new int[]{3, 4}, new int[]{1, 2}, new int[]{2, 3}));
    testCases.add(new Triple<>(new int[]{3, 4, 5}, new int[]{0, 0, 0}, new int[]{3, 4, 5}));
    testCases.add(new Triple<>(new int[]{3, 4, 5}, new int[]{1, 1, 1}, new int[]{2, 3, 4}));
    testCases.add(new Triple<>(new int[]{3, 4, 5}, new int[]{1, 0, 2}, new int[]{3, 3, 4}));
    for (int i = 0; i < testCases.size(); i++) {
        Triple<int[], int[], int[]> t = testCases.get(i);
        int[] os = t.getFirst();
        int[] b = t.getSecond();
        int[] e = t.getThird();
        INDArray arr = Nd4j.rand(os);
        SameDiff sd = SameDiff.create();
        SDVariable in = sd.var("in", arr);
        SDVariable slice = sd.slice(in, b, e);
        // Presumably the scalar output the gradient check differentiates; it is not referenced
        // again, but it must be part of the graph.
        SDVariable stdev = sd.standardDeviation(slice, true);
        String msg = "i=" + i + ": inShape=" + Arrays.toString(os) + ", begin=" + Arrays.toString(b) + ", end=" + Arrays.toString(e);
        log.info("Starting test: " + msg);
        // checkGradients reports pass/fail through its boolean result. If the result is
        // discarded, a gradient mismatch cannot fail this test.
        boolean ok = GradCheckUtil.checkGradients(sd);
        if (!ok) {
            throw new AssertionError("Gradient check failed: " + msg);
        }
    }
}
示例6: nOutReplace
import org.nd4j.linalg.primitives.Triple; //导入依赖的package包/类
/**
 * Marks a layer as edited and records its replacement nOut together with two
 * (weight-init scheme, distribution) pairs: one for this layer and, per the {@code *Next}
 * parameter names, one for the following layer.
 *
 * @param layerNum   index of the layer being edited; used as the key in {@code editedLayersMap}
 * @param nOut       the new nOut value
 * @param scheme     weight-init scheme paired with {@code dist}
 * @param schemeNext weight-init scheme paired with {@code distNext}
 * @param dist       distribution paired with {@code scheme}
 * @param distNext   distribution paired with {@code schemeNext}
 * @return this builder, for chaining
 */
private Builder nOutReplace(int layerNum, int nOut, WeightInit scheme, WeightInit schemeNext, Distribution dist,
                Distribution distNext) {
    editedLayers.add(layerNum);
    // Diamond instead of the raw Triple type. The compiler now checks the element types
    // against the declared Triple type, and there is no unchecked conversion.
    Triple<Integer, Pair<WeightInit, Distribution>, Pair<WeightInit, Distribution>> t =
            new Triple<>(nOut, new Pair<>(scheme, dist), new Pair<>(schemeNext, distNext));
    editedLayersMap.put(layerNum, t);
    return this;
}
示例7: call
import org.nd4j.linalg.primitives.Triple; //导入依赖的package包/类
/**
 * Trains on one sentence and returns a Word2VecChange built from the recorded changes.
 *
 * <p>The learning rate decays linearly with the fraction of words seen so far
 * ({@code wordsSeen / totalWords}) and never drops below {@code minAlpha}.
 */
@Override
public Word2VecChange call(Word2VecFuncCall sentence) throws Exception {
    final Word2VecParam param = sentence.getParam().getValue();
    final List<Triple<Integer, Integer, Integer>> changed = new ArrayList<>();

    final double minAlpha = param.getMinAlpha();
    final double startAlpha = param.getAlpha();
    final double fractionSeen = 1.0 * sentence.getWordsSeen() / (double) param.getTotalWords();
    final double alpha = Math.max(minAlpha, startAlpha * (1 - fractionSeen));

    trainSentence(param, sentence.getSentence(), alpha, changed);
    return new Word2VecChange(changed, param);
}
示例8: trainSentence
import org.nd4j.linalg.primitives.Triple; //导入依赖的package包/类
/**
 * Train on a list of vocab words.
 *
 * <p>For each eligible word, the shared {@code nextRandom} state is advanced by a
 * linear-congruential step. The result then chooses how far the context window is shrunk
 * ({@code b}, in {@code [0, window)}) before {@code skipGram} is called for that word.
 *
 * @param param    the Word2Vec parameters (the window size is read here)
 * @param sentence the list of vocab words to train on; {@code null} or empty is a no-op
 * @param alpha    the learning rate
 * @param changed  passed through to {@code skipGram}
 */
public void trainSentence(Word2VecParam param, final List<VocabWord> sentence, double alpha,
                List<Triple<Integer, Integer, Integer>> changed) {
    if (sentence == null || sentence.isEmpty()) {
        return;
    }
    int window = param.getWindow();
    for (int i = 0; i < sentence.size(); i++) {
        VocabWord vocabWord = sentence.get(i);
        // NOTE(review): only words whose text ends with "STOP" are trained on. Confirm that
        // this filter is intended and not inverted.
        if (vocabWord != null && vocabWord.getWord().endsWith("STOP")) {
            nextRandom.set(nextRandom.get() * 25214903917L + 11);
            // The (int) cast binds tighter than %, and Java's % keeps the sign of the dividend,
            // so the raw remainder lies in (-window, window). A negative b makes skipGram
            // widen the window instead of shrinking it, so take the magnitude. This cannot
            // overflow because |remainder| < window.
            int b = Math.abs((int) nextRandom.get() % window);
            skipGram(param, i, sentence, b, alpha, changed);
        }
    }
}
示例9: testCalls
import org.nd4j.linalg.primitives.Triple; //导入依赖的package包/类
/**
 * Checks that custom dropout instances record the expected sequence of calls when a
 * MultiLayerNetwork, and then a ComputationGraph, is fit twice over 3 minibatches.
 */
@Test
public void testCalls(){
    CustomDropout hiddenDropout = new CustomDropout();
    CustomDropout outputDropout = new CustomDropout();
    MultiLayerConfiguration mlnConf = new NeuralNetConfiguration.Builder()
            .list()
            .layer(new DenseLayer.Builder().nIn(4).nOut(3).dropOut(hiddenDropout).build())
            .layer(new OutputLayer.Builder(LossFunctions.LossFunction.MSE).dropOut(outputDropout).nIn(3).nOut(3).build())
            .build();
    MultiLayerNetwork mln = new MultiLayerNetwork(mlnConf);
    mln.init();

    // Three minibatches of 5 examples: 4 input features, 3 targets.
    List<DataSet> batches = new ArrayList<>();
    for (int k = 0; k < 3; k++) {
        batches.add(new DataSet(Nd4j.rand(5, 4), Nd4j.rand(5, 3)));
    }
    DataSetIterator iter = new ExistingDataSetIterator(batches);

    // Two passes over the 3 minibatches.
    mln.fit(iter);
    mln.fit(iter);

    List<Triple<Integer, Integer, Boolean>> expList = Arrays.asList(
            new Triple<>(0, 0, false),
            new Triple<>(1, 0, false),
            new Triple<>(2, 0, false),
            new Triple<>(3, 1, false),
            new Triple<>(4, 1, false),
            new Triple<>(5, 1, false));
    assertEquals(expList, hiddenDropout.getAllCalls());
    assertEquals(expList, outputDropout.getAllCalls());

    // Same expectations for the ComputationGraph API, with fresh dropout instances.
    hiddenDropout = new CustomDropout();
    outputDropout = new CustomDropout();
    ComputationGraphConfiguration cgConf = new NeuralNetConfiguration.Builder()
            .graphBuilder()
            .addInputs("in")
            .addLayer("0", new DenseLayer.Builder().nIn(4).nOut(3).dropOut(hiddenDropout).build(), "in")
            .addLayer("1", new OutputLayer.Builder(LossFunctions.LossFunction.MSE).dropOut(outputDropout).nIn(3).nOut(3).build(), "0")
            .setOutputs("1")
            .build();
    ComputationGraph cg = new ComputationGraph(cgConf);
    cg.init();
    cg.fit(iter);
    cg.fit(iter);
    assertEquals(expList, hiddenDropout.getAllCalls());
    assertEquals(expList, outputDropout.getAllCalls());
}
示例10: applyDropout
import org.nd4j.linalg.primitives.Triple; //导入依赖的package包/类
/**
 * Records the (iteration, epoch, inPlace) arguments of this call in {@code allCalls} and returns
 * the input activations unchanged; no dropout is actually applied.
 */
@Override
public INDArray applyDropout(INDArray inputActivations, int iteration, int epoch, boolean inPlace) {
    Triple<Integer, Integer, Boolean> record = new Triple<>(iteration, epoch, inPlace);
    allCalls.add(record);
    return inputActivations;
}
示例11: call
import org.nd4j.linalg.primitives.Triple; //导入依赖的package包/类
/**
 * Maps a (word, word, value) triple of strings to the corresponding VocabWords from the
 * broadcast vocabulary, carrying the third element through unchanged.
 */
@Override
public Triple<VocabWord, VocabWord, Double> call(Triple<String, String, Double> v1) throws Exception {
    VocabWord first = (VocabWord) vocab.getValue().wordFor(v1.getFirst());
    VocabWord second = (VocabWord) vocab.getValue().wordFor(v1.getSecond());
    return new Triple<>(first, second, v1.getThird());
}
示例12: call
import org.nd4j.linalg.primitives.Triple; //导入依赖的package包/类
/**
 * Placeholder implementation: ignores {@code pair} and always returns {@code null}.
 *
 * <p>NOTE(review): no GloveChange is computed here. Confirm that downstream consumers tolerate
 * null results, or implement the update.
 */
@Override
public GloveChange call(Triple<VocabWord, VocabWord, Double> pair) throws Exception {
    final GloveChange noChange = null;
    return noChange;
}