当前位置: 首页>>代码示例>>Java>>正文


Java InMemoryLookupTable.setSyn0方法代码示例

本文整理汇总了Java中org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable.setSyn0方法的典型用法代码示例。如果您正苦于以下问题:Java InMemoryLookupTable.setSyn0方法的具体用法?Java InMemoryLookupTable.setSyn0怎么用?Java InMemoryLookupTable.setSyn0使用的例子?那么恭喜您, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable的用法示例。


在下文中一共展示了InMemoryLookupTable.setSyn0方法的3个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Java代码示例。

示例1: testParaVecSerialization1

import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable; //导入方法依赖的package包/类
/**
 * Round-trips a {@link ParagraphVectors} model through
 * {@code WordVectorSerializer.writeParagraphVectors} and
 * {@code WordVectorSerializer.readParagraphVectors}, then verifies that the syn0/syn1 matrices
 * and the per-word vocabulary data (label flag, word, frequency, points, codes) are restored
 * unchanged.
 */
@Test
public void testParaVecSerialization1() throws Exception {
    VectorsConfiguration configuration = new VectorsConfiguration();
    configuration.setIterations(14123);
    configuration.setLayersSize(156);

    INDArray syn0 = Nd4j.rand(100, configuration.getLayersSize());
    INDArray syn1 = Nd4j.rand(100, configuration.getLayersSize());

    AbstractCache<VocabWord> cache = new AbstractCache.Builder<VocabWord>().build();

    // Build a 100-word vocabulary with random points/codes so that every serialized field
    // carries a non-trivial value.
    for (int i = 0; i < 100; i++) {
        VocabWord word = new VocabWord((float) i, "word_" + i);
        List<Integer> points = new ArrayList<>();
        List<Byte> codes = new ArrayList<>();
        int num = org.apache.commons.lang3.RandomUtils.nextInt(1, 20);
        for (int x = 0; x < num; x++) {
            points.add(org.apache.commons.lang3.RandomUtils.nextInt(1, 100000));
            codes.add(org.apache.commons.lang3.RandomUtils.nextBytes(10)[0]);
        }
        // Randomly flag some of the words as labels so both isLabel() states are exercised.
        if (RandomUtils.nextInt(10) < 3) {
            word.markAsLabel(true);
        }
        word.setIndex(i);
        word.setPoints(points);
        word.setCodes(codes);
        cache.addToken(word);
        cache.addWordToIndex(i, word.getLabel());
    }

    InMemoryLookupTable<VocabWord> lookupTable =
                    (InMemoryLookupTable<VocabWord>) new InMemoryLookupTable.Builder<VocabWord>()
                                    .vectorLength(configuration.getLayersSize()).cache(cache).build();

    lookupTable.setSyn0(syn0);
    lookupTable.setSyn1(syn1);

    ParagraphVectors originalVectors =
                    new ParagraphVectors.Builder(configuration).vocabCache(cache).lookupTable(lookupTable).build();

    File tempFile = File.createTempFile("paravec", "tests");
    tempFile.deleteOnExit();

    WordVectorSerializer.writeParagraphVectors(originalVectors, tempFile);

    ParagraphVectors restoredVectors = WordVectorSerializer.readParagraphVectors(tempFile);

    InMemoryLookupTable<VocabWord> restoredLookupTable =
                    (InMemoryLookupTable<VocabWord>) restoredVectors.getLookupTable();
    AbstractCache<VocabWord> restoredVocab = (AbstractCache<VocabWord>) restoredVectors.getVocab();

    // assertEquals takes (expected, actual): the original model is the expectation.
    assertEquals(lookupTable.getSyn0(), restoredLookupTable.getSyn0());
    assertEquals(lookupTable.getSyn1(), restoredLookupTable.getSyn1());

    for (int i = 0; i < cache.numWords(); i++) {
        VocabWord expectedWord = cache.elementAtIndex(i);
        VocabWord restoredWord = restoredVocab.elementAtIndex(i);

        assertEquals(expectedWord.isLabel(), restoredWord.isLabel());
        assertEquals(cache.wordAtIndex(i), restoredVocab.wordAtIndex(i));
        assertEquals(expectedWord.getElementFrequency(), restoredWord.getElementFrequency(), 0.1f);

        List<Integer> originalPoints = expectedWord.getPoints();
        List<Integer> restoredPoints = restoredWord.getPoints();
        assertEquals(originalPoints.size(), restoredPoints.size());
        for (int x = 0; x < originalPoints.size(); x++) {
            assertEquals(originalPoints.get(x), restoredPoints.get(x));
        }

        List<Byte> originalCodes = expectedWord.getCodes();
        List<Byte> restoredCodes = restoredWord.getCodes();
        assertEquals(originalCodes.size(), restoredCodes.size());
        for (int x = 0; x < originalCodes.size(); x++) {
            assertEquals(originalCodes.get(x), restoredCodes.get(x));
        }
    }
}
 
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:75,代码来源:WordVectorSerializerTest.java

示例2: loadTxtVectors

import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable; //导入方法依赖的package包/类
/**
 * Loads a previously saved text-format model from an InputStream (like a HDFS-stream).
 *
 * Each non-empty line is expected to hold a word followed by its vector components, separated
 * by single spaces. Occurrences of {@code whitespaceReplacement} inside the word are turned
 * back into spaces. Empty lines are skipped. NaN values in the resulting matrix are cleared
 * via {@code Nd4j.clearNans} before it is installed as syn0.
 *
 * The stream is decoded as UTF-8 and is NOT closed by this method; closing it remains the
 * caller's responsibility.
 *
 * Deprecation note: Please, consider using readWord2VecModel() or loadStaticModel() method instead
 *
 * @param stream InputStream that contains previously serialized model
 * @param skipFirstLine Set this TRUE if first line contains csv header, FALSE otherwise
 * @return word vectors built from the vocabulary and the weight matrix read from the stream
 * @throws IOException if the stream cannot be read or contains no word vectors
 */
@Deprecated
public static WordVectors loadTxtVectors(@NonNull InputStream stream, boolean skipFirstLine) throws IOException {
    AbstractCache<VocabWord> cache = new AbstractCache.Builder<VocabWord>().build();

    // Decode explicitly as UTF-8 so the result does not depend on the platform default charset.
    BufferedReader reader = new BufferedReader(new InputStreamReader(stream, "UTF-8"));
    List<INDArray> arrays = new ArrayList<>();

    if (skipFirstLine)
        reader.readLine();

    String line;
    while ((line = reader.readLine()) != null) {
        if (line.isEmpty()) // skip empty line: it carries neither a word nor a vector
            continue;

        String[] split = line.split(" ");
        String word = split[0].replaceAll(whitespaceReplacement, " ");
        VocabWord word1 = new VocabWord(1.0, word);

        // Words are indexed in the order they appear in the stream.
        word1.setIndex(cache.numWords());

        cache.addToken(word1);

        cache.addWordToIndex(word1.getIndex(), word);

        cache.putVocabWord(word);

        // Everything after the first token is a vector component.
        float[] vector = new float[split.length - 1];

        for (int i = 1; i < split.length; i++) {
            vector[i - 1] = Float.parseFloat(split[i]);
        }

        INDArray row = Nd4j.create(vector);

        arrays.add(row);
    }

    // Without at least one row neither the vector length nor the weight matrix can be derived.
    if (arrays.isEmpty())
        throw new IOException("No word vectors found in the input stream");

    InMemoryLookupTable<VocabWord> lookupTable =
                    (InMemoryLookupTable<VocabWord>) new InMemoryLookupTable.Builder<VocabWord>()
                                    .vectorLength(arrays.get(0).columns()).cache(cache).build();

    INDArray syn = Nd4j.vstack(arrays);

    Nd4j.clearNans(syn);
    lookupTable.setSyn0(syn);

    return fromPair(Pair.makePair((InMemoryLookupTable) lookupTable, (VocabCache) cache));
}
 
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:57,代码来源:WordVectorSerializer.java

示例3: readSequenceVectors

import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable; //导入方法依赖的package包/类
/**
 * Loads a previously saved SequenceVectors model from an InputStream.
 *
 * The first line of the stream must be the Base64-encoded JSON of the
 * {@link VectorsConfiguration}. Every following non-empty line is an encoded
 * {@code ElementPair} holding one serialized element together with its vector. The vectors are
 * stacked, in stream order, into the syn0 matrix of the lookup table.
 *
 * The stream is read as UTF-8 and is closed by this method, whether it completes normally or
 * throws.
 *
 * @param factory factory used to deserialize each element from its stored representation
 * @param stream InputStream that contains the previously serialized model
 * @param <T> type of the sequence elements held by the model
 * @return SequenceVectors model rebuilt from the stream
 * @throws IOException if the stream cannot be read, is empty, or contains no element vectors
 */
public static <T extends SequenceElement> SequenceVectors<T> readSequenceVectors(
                @NonNull SequenceElementFactory<T> factory, @NonNull InputStream stream) throws IOException {
    AbstractCache<T> vocabCache = new AbstractCache.Builder<T>().build();
    List<INDArray> rows = new ArrayList<>();
    VectorsConfiguration configuration;

    // try-with-resources guarantees the reader is closed even when parsing fails midway
    try (BufferedReader reader = new BufferedReader(new InputStreamReader(stream, "UTF-8"))) {
        // at first we load vectors configuration
        String line = reader.readLine();
        if (line == null)
            throw new IOException("Input stream is empty: vectors configuration line is missing");

        configuration = VectorsConfiguration.fromJson(new String(Base64.decodeBase64(line), "UTF-8"));

        while ((line = reader.readLine()) != null) {
            if (line.isEmpty()) // skip empty line
                continue;
            ElementPair pair = ElementPair.fromEncodedJson(line);
            T element = factory.deserialize(pair.getObject());
            rows.add(Nd4j.create(pair.getVector()));
            vocabCache.addToken(element);
            vocabCache.addWordToIndex(element.getIndex(), element.getLabel());
        }
    }

    // Without at least one row neither the vector length nor the syn0 matrix can be derived.
    if (rows.isEmpty())
        throw new IOException("No element vectors found in the input stream");

    // The lookup table shares the vocabulary that was just rebuilt from the stream.
    InMemoryLookupTable<T> lookupTable = (InMemoryLookupTable<T>) new InMemoryLookupTable.Builder<T>()
                    .vectorLength(rows.get(0).columns()).cache(vocabCache).build();

    INDArray syn0 = Nd4j.vstack(rows);

    lookupTable.setSyn0(syn0);

    // resetModel(false) is passed so the builder is told not to reset the freshly loaded model.
    return new SequenceVectors.Builder<T>(configuration).vocabCache(vocabCache)
                    .lookupTable(lookupTable).resetModel(false).build();
}
 
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:51,代码来源:WordVectorSerializer.java


注:本文中的org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable.setSyn0方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。