当前位置: 首页>>代码示例>>Java>>正文


Java Triple类代码示例

本文整理汇总了 Java 中 org.nd4j.linalg.primitives.Triple 类的典型用法代码示例。如果您正苦于以下问题：Java Triple 类的具体用法是什么？Java Triple 怎么用？有没有 Java Triple 的使用示例？那么，这里精选的类代码示例或许可以为您提供帮助。


Triple 类属于 org.nd4j.linalg.primitives 包，下文一共展示了 Triple 类的 12 个代码示例，这些例子默认按受欢迎程度排序。您可以为喜欢或者觉得有用的代码点赞，您的评价将有助于系统推荐出更棒的 Java 代码示例。

示例1: testBroadcastShapes

import org.nd4j.linalg.primitives.Triple; //导入依赖的package包/类
/**
 * Checks {@code Shape.broadcastOutputShape} against a table of broadcasting cases.
 * Every case holds (first input shape, second input shape, expected output shape).
 */
@Test
public void testBroadcastShapes(){
    List<Triple<int[], int[], int[]>> cases = new ArrayList<>();
    // 2D: a size-1 dimension takes the size of the other operand's dimension
    cases.add(new Triple<>(new int[]{3, 1}, new int[]{1, 4}, new int[]{3, 4}));
    cases.add(new Triple<>(new int[]{3, 1}, new int[]{3, 4}, new int[]{3, 4}));
    cases.add(new Triple<>(new int[]{3, 4}, new int[]{1, 4}, new int[]{3, 4}));
    // 3D
    cases.add(new Triple<>(new int[]{3, 4, 1}, new int[]{1, 1, 5}, new int[]{3, 4, 5}));
    cases.add(new Triple<>(new int[]{3, 4, 1}, new int[]{3, 1, 5}, new int[]{3, 4, 5}));
    cases.add(new Triple<>(new int[]{3, 1, 5}, new int[]{1, 4, 1}, new int[]{3, 4, 5}));
    cases.add(new Triple<>(new int[]{3, 1, 5}, new int[]{1, 4, 5}, new int[]{3, 4, 5}));
    cases.add(new Triple<>(new int[]{3, 1, 5}, new int[]{3, 4, 5}, new int[]{3, 4, 5}));
    // 4D
    cases.add(new Triple<>(new int[]{3, 1, 1, 1}, new int[]{1, 4, 5, 6}, new int[]{3, 4, 5, 6}));
    cases.add(new Triple<>(new int[]{1, 1, 1, 6}, new int[]{3, 4, 5, 6}, new int[]{3, 4, 5, 6}));
    cases.add(new Triple<>(new int[]{1, 4, 5, 1}, new int[]{3, 1, 1, 6}, new int[]{3, 4, 5, 6}));
    // Mixed rank: the lower-rank shape lines up with the trailing dimensions
    cases.add(new Triple<>(new int[]{1, 6}, new int[]{3, 4, 5, 1}, new int[]{3, 4, 5, 6}));

    for (int k = 0; k < cases.size(); k++) {
        Triple<int[], int[], int[]> c = cases.get(k);
        int[] actual = Shape.broadcastOutputShape(c.getFirst(), c.getSecond());
        assertArrayEquals(c.getThird(), actual);
    }
}
 
开发者ID:deeplearning4j,项目名称:nd4j,代码行数:27,代码来源:ShapeTests.java

示例2: getGraphInfo

import org.nd4j.linalg.primitives.Triple; //导入依赖的package包/类
/**
 * Builds graph information from the current configuration triple.
 * <p>
 * The elements are tried in order: the MultiLayerConfiguration, then the
 * ComputationGraphConfiguration, then the NeuralNetConfiguration. The first non-null one is used.
 *
 * @return the built graph info, or {@code null} when {@code getConfig()} returns null or all
 *         three elements are null
 */
private TrainModuleUtils.GraphInfo getGraphInfo() {
    Triple<MultiLayerConfiguration, ComputationGraphConfiguration, NeuralNetConfiguration> config = getConfig();
    if (config == null) {
        return null;
    }

    MultiLayerConfiguration multiLayer = config.getFirst();
    if (multiLayer != null) {
        return TrainModuleUtils.buildGraphInfo(multiLayer);
    }

    ComputationGraphConfiguration graph = config.getSecond();
    if (graph != null) {
        return TrainModuleUtils.buildGraphInfo(graph);
    }

    NeuralNetConfiguration single = config.getThird();
    return single == null ? null : TrainModuleUtils.buildGraphInfo(single);
}
 
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:17,代码来源:TrainModule.java

示例3: Word2VecChange

import org.nd4j.linalg.primitives.Triple; //导入依赖的package包/类
/**
 * Groups syn1 rows by point.
 * <p>
 * For each recorded triple, row {@code getSecond()} of
 * {@code param.getWeights().getSyn1()} is added to the set stored in {@code this.changes} under
 * key {@code getFirst()}. The set is created if absent. The third element is not read.
 *
 * @param counterMap recorded (point, row index, ...) triples
 * @param param      word2vec parameters supplying the syn1 weights
 */
public Word2VecChange(List<Triple<Integer, Integer, Integer>> counterMap, Word2VecParam param) {
    for (Triple<Integer, Integer, Integer> entry : counterMap) {
        Integer point = entry.getFirst();
        Integer rowIndex = entry.getSecond();

        Set<INDArray> bucket = this.changes.get(point);
        if (bucket == null) {
            bucket = new HashSet<>();
            this.changes.put(point, bucket);
        }
        bucket.add(param.getWeights().getSyn1().slice(rowIndex));
    }
}
 
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:18,代码来源:Word2VecChange.java

示例4: skipGram

import org.nd4j.linalg.primitives.Triple; //导入依赖的package包/类
/**
 * Train via skip gram on the context window around one word.
 * <p>
 * Window offsets run from {@code b} (inclusive) to {@code 2 * window + 1 - b} (exclusive). The
 * centre offset ({@code window}) is skipped, as are positions outside the sentence.
 *
 * @param param    word2vec parameters; supplies the window size
 * @param i        index of the current word within {@code sentence}
 * @param sentence the sentence to train on
 * @param b        offset that shrinks the window symmetrically from both ends
 * @param alpha    the learning rate
 * @param changed  collector passed through to {@code iterateSample}
 */
public void skipGram(Word2VecParam param, int i, List<VocabWord> sentence, int b, double alpha,
                List<Triple<Integer, Integer, Integer>> changed) {
    final VocabWord centre = sentence.get(i);
    final int window = param.getWindow();
    if (centre == null || sentence.isEmpty()) {
        return;
    }

    final int limit = 2 * window + 1 - b;
    for (int offset = b; offset < limit; offset++) {
        if (offset == window) {
            continue; // the centre word is not its own context
        }
        int contextPos = i - window + offset;
        if (contextPos < 0 || contextPos >= sentence.size()) {
            continue;
        }
        iterateSample(param, centre, sentence.get(contextPos), alpha, changed);
    }
}
 
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:26,代码来源:SentenceBatch.java

示例5: testSliceGradient

import org.nd4j.linalg.primitives.Triple; //导入依赖的package包/类
/**
 * Gradient-checks {@code SameDiff.slice} on 2D and 3D inputs.
 * <p>
 * Each test case stores (original shape, begin, end), with {@code end} exclusive. The slice op
 * takes (begin, size), so the size of each dimension is derived as {@code end - begin}.
 * Previously the end values were passed directly as sizes, which made several regions run past
 * the input bounds. The gradient-check result is also asserted now; it was previously ignored,
 * so the test could never fail.
 */
@Test
public void testSliceGradient() {
    Nd4j.getRandom().setSeed(12345);

    //Order here: original shape, begin, end (exclusive)
    List<Triple<int[], int[], int[]>> testCases = new ArrayList<>();
    testCases.add(new Triple<>(new int[]{3, 4}, new int[]{0, 0}, new int[]{3, 4}));
    testCases.add(new Triple<>(new int[]{3, 4}, new int[]{1, 1}, new int[]{3, 4}));
    testCases.add(new Triple<>(new int[]{3, 4}, new int[]{1, 2}, new int[]{2, 3}));
    testCases.add(new Triple<>(new int[]{3, 4, 5}, new int[]{0, 0, 0}, new int[]{3, 4, 5}));
    testCases.add(new Triple<>(new int[]{3, 4, 5}, new int[]{1, 1, 1}, new int[]{2, 3, 4}));
    testCases.add(new Triple<>(new int[]{3, 4, 5}, new int[]{1, 0, 2}, new int[]{3, 3, 4}));

    for (int i = 0; i < testCases.size(); i++) {
        Triple<int[], int[], int[]> t = testCases.get(i);
        int[] os = t.getFirst();
        int[] b = t.getSecond();
        int[] e = t.getThird();

        // sd.slice expects the extent of each dimension, not the end index
        int[] size = new int[b.length];
        for (int j = 0; j < b.length; j++) {
            size[j] = e[j] - b[j];
        }

        INDArray arr = Nd4j.rand(os);

        SameDiff sd = SameDiff.create();
        SDVariable in = sd.var("in", arr);
        SDVariable slice = sd.slice(in, b, size);
        // Scalar output over the slice, so the gradient check has something to differentiate
        SDVariable stdev = sd.standardDeviation(slice, true);

        String msg = "i=" + i + ": inShape=" + Arrays.toString(os) + ", begin=" + Arrays.toString(b)
                        + ", end=" + Arrays.toString(e) + ", size=" + Arrays.toString(size);
        log.info("Starting test: " + msg);
        boolean ok = GradCheckUtil.checkGradients(sd);
        org.junit.Assert.assertTrue(msg, ok);
    }
}
 
开发者ID:deeplearning4j,项目名称:nd4j,代码行数:31,代码来源:GradCheckMisc.java

示例6: nOutReplace

import org.nd4j.linalg.primitives.Triple; //导入依赖的package包/类
/**
 * Records an nOut replacement for one layer, together with the weight-init settings for that
 * layer and for the following one.
 * <p>
 * The layer index is added to {@code editedLayers}. The triple (nOut, (scheme, dist),
 * (schemeNext, distNext)) is stored in {@code editedLayersMap} under {@code layerNum}.
 *
 * @param layerNum   index of the layer being edited
 * @param nOut       new number of outputs for the layer
 * @param scheme     weight init scheme for the edited layer
 * @param schemeNext weight init scheme for the next layer
 * @param dist       distribution paired with {@code scheme}
 * @param distNext   distribution paired with {@code schemeNext}
 * @return this builder, for chaining
 */
private Builder nOutReplace(int layerNum, int nOut, WeightInit scheme, WeightInit schemeNext, Distribution dist,
                Distribution distNext) {
    editedLayers.add(layerNum);
    // Diamond rather than the raw type: the assignment is type-checked instead of unchecked
    Triple<Integer, Pair<WeightInit, Distribution>, Pair<WeightInit, Distribution>> t =
                    new Triple<>(nOut, new Pair<>(scheme, dist), new Pair<>(schemeNext, distNext));
    editedLayersMap.put(layerNum, t);
    return this;
}
 
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:10,代码来源:TransferLearning.java

示例7: call

import org.nd4j.linalg.primitives.Triple; //导入依赖的package包/类
/**
 * Trains on one sentence and packages the touched syn1 rows into a {@link Word2VecChange}.
 * <p>
 * The learning rate decays linearly with the fraction of words seen so far
 * ({@code wordsSeen / totalWords}). It never drops below {@code param.getMinAlpha()}.
 */
@Override
public Word2VecChange call(Word2VecFuncCall sentence) throws Exception {
    Word2VecParam param = sentence.getParam().getValue();

    double progress = (double) sentence.getWordsSeen() / param.getTotalWords();
    double decayed = param.getAlpha() * (1 - progress);
    double alpha = Math.max(param.getMinAlpha(), decayed);

    List<Triple<Integer, Integer, Integer>> changed = new ArrayList<>();
    trainSentence(param, sentence.getSentence(), alpha, changed);
    return new Word2VecChange(changed, param);
}
 
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:11,代码来源:SentenceBatch.java

示例8: trainSentence

import org.nd4j.linalg.primitives.Triple; //导入依赖的package包/类
/**
 * Train on a list of vocab words.
 * <p>
 * For each eligible word, the {@code nextRandom} state is advanced with a linear congruential
 * step. Its value, reduced modulo the window size, gives the offset {@code b} passed to
 * {@link #skipGram}.
 *
 * @param param    word2vec parameters; supplies the window size
 * @param sentence the list of vocab words to train on; null or empty lists are ignored
 * @param alpha    the learning rate
 * @param changed  collector passed through to {@code skipGram}
 */
public void trainSentence(Word2VecParam param, final List<VocabWord> sentence, double alpha,
                List<Triple<Integer, Integer, Integer>> changed) {
    if (sentence == null || sentence.isEmpty()) {
        return;
    }
    for (int i = 0; i < sentence.size(); i++) {
        VocabWord vocabWord = sentence.get(i);
        // NOTE(review): only words ending in "STOP" are trained on; this looks inverted — confirm intent
        if (vocabWord != null && vocabWord.getWord().endsWith("STOP")) {
            nextRandom.set(nextRandom.get() * 25214903917L + 11);
            // Unsigned remainder keeps b in [0, window). The previous `(int) x % window` cast first,
            // so a negative state produced a negative b, which widened the window instead of shrinking it.
            int b = (int) Long.remainderUnsigned(nextRandom.get(), param.getWindow());
            skipGram(param, i, sentence, b, alpha, changed);
        }
    }
}
 
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:17,代码来源:SentenceBatch.java

示例9: testCalls

import org.nd4j.linalg.primitives.Triple; //导入依赖的package包/类
/**
 * Checks that custom dropout instances see the expected (iteration, epoch, inPlace) calls.
 * <p>
 * Two fits over three minibatches give iterations 0..5, with epoch 0 for the first pass and
 * epoch 1 for the second. This is verified for a MultiLayerNetwork and for an equivalent
 * ComputationGraph.
 */
@Test
public void testCalls(){

    CustomDropout dropoutHidden = new CustomDropout();
    CustomDropout dropoutOut = new CustomDropout();

    MultiLayerConfiguration mlnConf = new NeuralNetConfiguration.Builder()
            .list()
            .layer(new DenseLayer.Builder().nIn(4).nOut(3).dropOut(dropoutHidden).build())
            .layer(new OutputLayer.Builder(LossFunctions.LossFunction.MSE).dropOut(dropoutOut).nIn(3).nOut(3).build())
            .build();
    MultiLayerNetwork mln = new MultiLayerNetwork(mlnConf);
    mln.init();

    // Three minibatches of 5 examples: 4 features in, 3 targets out
    List<DataSet> batches = new ArrayList<>();
    for (int n = 0; n < 3; n++) {
        batches.add(new DataSet(Nd4j.rand(5,4), Nd4j.rand(5,3)));
    }

    DataSetIterator iter = new ExistingDataSetIterator(batches);

    mln.fit(iter);
    mln.fit(iter);

    List<Triple<Integer,Integer,Boolean>> expList = new ArrayList<>();
    for (int iteration = 0; iteration < 6; iteration++) {
        expList.add(new Triple<>(iteration, iteration / 3, false));
    }

    assertEquals(expList, dropoutHidden.getAllCalls());
    assertEquals(expList, dropoutOut.getAllCalls());

    // Same check for the ComputationGraph equivalent, with fresh recorders
    dropoutHidden = new CustomDropout();
    dropoutOut = new CustomDropout();
    ComputationGraphConfiguration graphConf = new NeuralNetConfiguration.Builder()
            .graphBuilder()
            .addInputs("in")
            .addLayer("0", new DenseLayer.Builder().nIn(4).nOut(3).dropOut(dropoutHidden).build(), "in")
            .addLayer("1", new OutputLayer.Builder(LossFunctions.LossFunction.MSE).dropOut(dropoutOut).nIn(3).nOut(3).build(), "0")
            .setOutputs("1")
            .build();

    ComputationGraph graph = new ComputationGraph(graphConf);
    graph.init();

    graph.fit(iter);
    graph.fit(iter);

    assertEquals(expList, dropoutHidden.getAllCalls());
    assertEquals(expList, dropoutOut.getAllCalls());
}
 
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:56,代码来源:TestDropout.java

示例10: applyDropout

import org.nd4j.linalg.primitives.Triple; //导入依赖的package包/类
/**
 * Records the (iteration, epoch, inPlace) arguments of this call in {@code allCalls}.
 * No dropout is applied: the input activations are returned unchanged.
 */
@Override
public INDArray applyDropout(INDArray inputActivations, int iteration, int epoch, boolean inPlace) {
    Triple<Integer, Integer, Boolean> invocation = new Triple<>(iteration, epoch, inPlace);
    allCalls.add(invocation);
    return inputActivations;
}
 
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:6,代码来源:TestDropout.java

示例11: call

import org.nd4j.linalg.primitives.Triple; //导入依赖的package包/类
/**
 * Maps a (word, word, value) triple of strings to the corresponding {@link VocabWord}s.
 * <p>
 * The words are looked up via {@code vocab.getValue().wordFor(...)}; the third element passes
 * through unchanged. NOTE(review): the behavior for words missing from the vocabulary depends on
 * {@code wordFor} and is not visible here.
 */
@Override
public Triple<VocabWord, VocabWord, Double> call(Triple<String, String, Double> v1) throws Exception {
    VocabWord first = (VocabWord) vocab.getValue().wordFor(v1.getFirst());
    VocabWord second = (VocabWord) vocab.getValue().wordFor(v1.getSecond());
    Double weight = v1.getThird();
    return new Triple<>(first, second, weight);
}
 
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:6,代码来源:VocabWordPairs.java

示例12: call

import org.nd4j.linalg.primitives.Triple; //导入依赖的package包/类
/**
 * Placeholder implementation: ignores {@code pair} and always returns {@code null}.
 * NOTE(review): no GloVe update is computed here — presumably unimplemented; confirm before relying on it.
 */
@Override
public GloveChange call(Triple<VocabWord, VocabWord, Double> pair) throws Exception {
    return null;
}
 
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:5,代码来源:GlovePerformer.java


注：本文中的 org.nd4j.linalg.primitives.Triple 类示例由纯净天空整理自 GitHub/MSDocs 等开源代码及文档管理平台，相关代码片段筛选自各路编程大神贡献的开源项目，源码版权归原作者所有，传播和使用请参考对应项目的 License；未经允许，请勿转载。