

Java SkipGram Class Code Examples

This article collects typical usage examples of the Java class org.deeplearning4j.models.embeddings.learning.impl.elements.SkipGram. If you are wondering what the SkipGram class does, how to use it, or where to find working examples, the curated samples below should help.


The SkipGram class belongs to the org.deeplearning4j.models.embeddings.learning.impl.elements package. Seven code examples of the class are shown below, ordered by popularity.
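Before the examples, here is a minimal, hypothetical sketch showing where SkipGram plugs into the Word2Vec builder. The corpus path "corpus.txt" is an assumption; every API call used also appears in the examples below.

import org.deeplearning4j.models.embeddings.learning.impl.elements.SkipGram;
import org.deeplearning4j.models.word2vec.VocabWord;
import org.deeplearning4j.models.word2vec.Word2Vec;
import org.deeplearning4j.text.sentenceiterator.BasicLineIterator;
import org.deeplearning4j.text.sentenceiterator.SentenceIterator;
import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory;
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;

public class SkipGramMinimal {
    public static void main(String[] args) throws Exception {
        // hypothetical corpus file, one sentence per line
        SentenceIterator iter = new BasicLineIterator("corpus.txt");
        TokenizerFactory t = new DefaultTokenizerFactory();

        Word2Vec vec = new Word2Vec.Builder()
                .elementsLearningAlgorithm(new SkipGram<VocabWord>()) // select skip-gram as the learning algorithm
                .iterate(iter)
                .tokenizerFactory(t)
                .build();

        vec.fit(); // train the embeddings
    }
}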

Example 1: w2vBuilder

import org.deeplearning4j.models.embeddings.learning.impl.elements.SkipGram; // import the required package/class
public static Word2Vec w2vBuilder(SentenceIterator iter, TokenizerFactory t) {
	return new Word2Vec.Builder()
			.seed(12345)
			.iterate(iter)
			.tokenizerFactory(t)
			.batchSize(1000)
			.allowParallelTokenization(true) // enable parallel tokenization
			.epochs(1) // number of epochs (full passes over the training corpus)
			.iterations(3) // number of iterations done for each mini-batch during training
			.elementsLearningAlgorithm(new SkipGram<>()) // use the skip-gram model; for CBOW, pass new CBOW<>()
			.minWordFrequency(50) // discard words that occur fewer times than this threshold
			.windowSize(5) // set max skip length between words
			.learningRate(0.05) // the starting learning rate
			.minLearningRate(5e-4) // floor below which the learning rate will not decay
			.negativeSample(10) // number of negative examples
			// sub-sampling threshold: words that occur more frequently than this
			// are randomly down-sampled during training
			.sampling(1e-5)
			.useHierarchicSoftmax(true) // use hierarchical softmax
			.layerSize(300) // size of word vectors
			.workers(8) // number of threads
			.build();
}
 
Developer: IsaacChanghau, Project: Word2VecfJava, Lines of code: 24, Source: Word2VecTrainer.java
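A short usage sketch for w2vBuilder above; the corpus path is hypothetical, and CommonPreprocessor is applied as in the later test examples:

SentenceIterator iter = new BasicLineIterator("/path/to/corpus.txt"); // hypothetical one-sentence-per-line corpus
TokenizerFactory t = new DefaultTokenizerFactory();
t.setTokenPreProcessor(new CommonPreprocessor());

Word2Vec vec = w2vBuilder(iter, t);
vec.fit(); // train with the skip-gram configuration above
System.out.println(vec.wordsNearest("day", 10)); // inspect the 10 nearest neighbors of a probe word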

Example 2: shouldLoadAndCreateSameWord2Vec

import org.deeplearning4j.models.embeddings.learning.impl.elements.SkipGram; // import the required package/class
@Test
public void shouldLoadAndCreateSameWord2Vec() {
    //given
    Pan15Parser parser = new Pan15Parser();
    HashMap<String, Pan15Author> english = parser.parseCSVCorpus().get(Language.ENGLISH);
    TokenizerFactory t = new DefaultTokenizerFactory();
    t.setTokenPreProcessor(new CommonPreprocessor());
    List<String> englishSentences = english.values().stream()
            .map(Author::getDocuments)
            .flatMap(List::stream) // flatten each author's documents into one list of sentences
            .collect(Collectors.toList());

    SentenceIterator englishIter = new CollectionSentenceIterator(new Pan15SentencePreProcessor(), englishSentences);
    // when
    Word2Vec englishVec = new Word2Vec.Builder()
            .minWordFrequency(6)
            .iterations(15)
            .layerSize(250)
            .seed(42)
            .windowSize(5)
            .iterate(englishIter)
            .tokenizerFactory(t)
            .build();

    englishVec.fit();
    Word2Vec loadedEnglishVec1 = new Pan15Word2Vec(new SkipGram<>()).readModelFromFile(Language.ENGLISH);
    Word2Vec loadedEnglishVec2 = new Pan15Word2Vec(new CBOW<>()).readModelFromFile(Language.ENGLISH);
    Word2Vec loadedEnglishVec3 = new Pan15Word2Vec(new GloVe<>()).readModelFromFile(Language.ENGLISH);
    loadedEnglishVec1.setTokenizerFactory(t);
    loadedEnglishVec1.setSentenceIterator(englishIter);
    loadedEnglishVec2.setTokenizerFactory(t);
    loadedEnglishVec2.setSentenceIterator(englishIter);
    loadedEnglishVec3.setTokenizerFactory(t);
    loadedEnglishVec3.setSentenceIterator(englishIter);

    //then
    Assert.assertNotNull(loadedEnglishVec1);
    System.out.println(englishVec.wordsNearest("home", 15));
    System.out.println(loadedEnglishVec1.wordsNearest("home", 15));
    System.out.println(loadedEnglishVec2.wordsNearest("home", 15));
    System.out.println(loadedEnglishVec3.wordsNearest("home", 15));
}
 
Developer: madeleine789, Project: dl4j-apr, Lines of code: 42, Source: Pan15Word2VecTest.java

Example 3: testRunWord2Vec

import org.deeplearning4j.models.embeddings.learning.impl.elements.SkipGram; // import the required package/class
@Test
public void testRunWord2Vec() throws Exception {
    // Strip white space before and after for each line
    SentenceIterator iter = new BasicLineIterator(inputFile.getAbsolutePath());
    // Split on white spaces in the line to get words
    TokenizerFactory t = new DefaultTokenizerFactory();
    t.setTokenPreProcessor(new CommonPreprocessor());


    Word2Vec vec = new Word2Vec.Builder().minWordFrequency(1).iterations(3).batchSize(64).layerSize(100)
                    .stopWords(new ArrayList<String>()).seed(42).learningRate(0.025).minLearningRate(0.001)
                    .sampling(0).elementsLearningAlgorithm(new SkipGram<VocabWord>())
                    //.negativeSample(10)
                    .epochs(1).windowSize(5).allowParallelTokenization(true)
                    .modelUtils(new BasicModelUtils<VocabWord>()).iterate(iter).tokenizerFactory(t).build();

    assertEquals(new ArrayList<String>(), vec.getStopWords());
    vec.fit();
    File tempFile = File.createTempFile("temp", "temp");
    tempFile.deleteOnExit();

    WordVectorSerializer.writeFullModel(vec, tempFile.getAbsolutePath());
    Collection<String> lst = vec.wordsNearest("day", 10);
    //log.info(Arrays.toString(lst.toArray()));
    printWords("day", lst, vec);

    assertEquals(10, lst.size());

    double sim = vec.similarity("day", "night");
    log.info("Day/night similarity: " + sim);

    assertTrue(sim < 1.0);
    assertTrue(sim > 0.4);


    assertTrue(lst.contains("week"));
    assertTrue(lst.contains("night"));
    assertTrue(lst.contains("year"));

    assertFalse(lst.contains(null));


    lst = vec.wordsNearest("day", 10);
    //log.info(Arrays.toString(lst.toArray()));
    printWords("day", lst, vec);

    assertTrue(lst.contains("week"));
    assertTrue(lst.contains("night"));
    assertTrue(lst.contains("year"));

    new File("cache.ser").delete();

    ArrayList<String> labels = new ArrayList<>();
    labels.add("day");
    labels.add("night");
    labels.add("week");

    INDArray matrix = vec.getWordVectors(labels);
    assertEquals(matrix.getRow(0), vec.getWordVectorMatrix("day"));
    assertEquals(matrix.getRow(1), vec.getWordVectorMatrix("night"));
    assertEquals(matrix.getRow(2), vec.getWordVectorMatrix("week"));

    WordVectorSerializer.writeWordVectors(vec, pathToWriteto);
}
 
Developer: deeplearning4j, Project: deeplearning4j, Lines of code: 65, Source: Word2VecTests.java

Example 4: testW2VnegativeOnRestore

import org.deeplearning4j.models.embeddings.learning.impl.elements.SkipGram; // import the required package/class
@Test
public void testW2VnegativeOnRestore() throws Exception {
    // Strip white space before and after for each line
    SentenceIterator iter = new BasicLineIterator(inputFile.getAbsolutePath());
    // Split on white spaces in the line to get words
    TokenizerFactory t = new DefaultTokenizerFactory();
    t.setTokenPreProcessor(new CommonPreprocessor());


    Word2Vec vec = new Word2Vec.Builder().minWordFrequency(1).iterations(3).batchSize(64).layerSize(100)
                    .stopWords(new ArrayList<String>()).seed(42).learningRate(0.025).minLearningRate(0.001)
                    .sampling(0).elementsLearningAlgorithm(new SkipGram<VocabWord>()).negativeSample(10).epochs(1)
                    .windowSize(5).useHierarchicSoftmax(false).allowParallelTokenization(true)
                    .modelUtils(new FlatModelUtils<VocabWord>()).iterate(iter).tokenizerFactory(t).build();


    assertEquals(false, vec.getConfiguration().isUseHierarchicSoftmax());

    log.info("Fit 1");
    vec.fit();

    File tmpFile = File.createTempFile("temp", "file");
    tmpFile.deleteOnExit();

    WordVectorSerializer.writeWord2VecModel(vec, tmpFile);

    iter.reset();

    Word2Vec restoredVec = WordVectorSerializer.readWord2VecModel(tmpFile, true);
    restoredVec.setTokenizerFactory(t);
    restoredVec.setSentenceIterator(iter);

    assertEquals(false, restoredVec.getConfiguration().isUseHierarchicSoftmax());
    assertTrue(restoredVec.getModelUtils() instanceof FlatModelUtils);
    assertTrue(restoredVec.getConfiguration().isAllowParallelTokenization());

    log.info("Fit 2");
    restoredVec.fit();


    iter.reset();
    restoredVec = WordVectorSerializer.readWord2VecModel(tmpFile, false);
    restoredVec.setTokenizerFactory(t);
    restoredVec.setSentenceIterator(iter);

    assertEquals(false, restoredVec.getConfiguration().isUseHierarchicSoftmax());
    assertTrue(restoredVec.getModelUtils() instanceof BasicModelUtils);

    log.info("Fit 3");
    restoredVec.fit();
}
 
Developer: deeplearning4j, Project: deeplearning4j, Lines of code: 52, Source: Word2VecTests.java

Example 5: testDirectInference

import org.deeplearning4j.models.embeddings.learning.impl.elements.SkipGram; // import the required package/class
@Test
public void testDirectInference() throws Exception {
    ClassPathResource resource_sentences = new ClassPathResource("/big/raw_sentences.txt");
    ClassPathResource resource_mixed = new ClassPathResource("/paravec");
    SentenceIterator iter = new AggregatingSentenceIterator.Builder()
                    .addSentenceIterator(new BasicLineIterator(resource_sentences.getFile()))
                    .addSentenceIterator(new FileSentenceIterator(resource_mixed.getFile())).build();

    TokenizerFactory t = new DefaultTokenizerFactory();
    t.setTokenPreProcessor(new CommonPreprocessor());

    Word2Vec wordVectors = new Word2Vec.Builder().minWordFrequency(1).batchSize(250).iterations(1).epochs(3)
                    .learningRate(0.025).layerSize(150).minLearningRate(0.001)
                    .elementsLearningAlgorithm(new SkipGram<VocabWord>()).useHierarchicSoftmax(true).windowSize(5)
                    .iterate(iter).tokenizerFactory(t).build();

    wordVectors.fit();

    ParagraphVectors pv = new ParagraphVectors.Builder().tokenizerFactory(t).iterations(10)
                    .useHierarchicSoftmax(true).trainWordVectors(true).useExistingWordVectors(wordVectors)
                    .negativeSample(0).sequenceLearningAlgorithm(new DM<VocabWord>()).build();

    INDArray vec1 = pv.inferVector("This text is pretty awesome");
    INDArray vec2 = pv.inferVector("Fantastic process of crazy things happening inside just for history purposes");

    log.info("vec1/vec2: {}", Transforms.cosineSim(vec1, vec2));
}
 
Developer: deeplearning4j, Project: deeplearning4j, Lines of code: 28, Source: ParagraphVectorsTest.java

Example 6: readModelFromFile

import org.deeplearning4j.models.embeddings.learning.impl.elements.SkipGram; // import the required package/class
public Word2Vec readModelFromFile(Language language) {
    String path = (learningAlgorithm instanceof SkipGram)
            ? language.getName() + "_model.txt"
            : language.getName() + "_model_" + learningAlgorithm.getCodeName() + ".txt";
    URL resource = Pan15Word2Vec.class.getClassLoader()
            .getResource("word2vec/" + path);
    try {
        return WordVectorSerializer.readWord2VecModel(Paths.get(resource.toURI()).toFile());
    } catch (URISyntaxException e) {
        e.printStackTrace();
    }
    return null;
}
 
Developer: madeleine789, Project: dl4j-apr, Lines of code: 13, Source: Pan15Word2Vec.java

Example 7: saveModel

import org.deeplearning4j.models.embeddings.learning.impl.elements.SkipGram; // import the required package/class
public void saveModel(Word2Vec model, Language language) {
    String dir = "./src/main/resources/word2vec";
    String path = (learningAlgorithm instanceof SkipGram)
            ? dir + "/" + language.getName() + "_model.txt"
            : dir + "/" + language.getName() + "_model_" + learningAlgorithm.getCodeName() + ".txt";
    WordVectorSerializer.writeWord2VecModel(model, path);
}
 
Developer: madeleine789, Project: dl4j-apr, Lines of code: 8, Source: Pan15Word2Vec.java
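Examples 6 and 7 form a persistence round trip: saveModel writes the vectors under ./src/main/resources/word2vec, and readModelFromFile loads them back from the classpath. A hypothetical sketch, assuming model is a trained Word2Vec instance and using the Pan15Word2Vec constructor seen in Example 2:

Pan15Word2Vec w2v = new Pan15Word2Vec(new SkipGram<>()); // skip-gram models use the plain "<language>_model.txt" file name
w2v.saveModel(model, Language.ENGLISH);
Word2Vec restored = w2v.readModelFromFile(Language.ENGLISH);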


Note: The org.deeplearning4j.models.embeddings.learning.impl.elements.SkipGram class examples in this article were compiled by 纯净天空 from GitHub, MSDocs, and other open-source code and documentation platforms. The code snippets were selected from open-source projects contributed by various developers; copyright of the source code remains with the original authors. Please consult each project's License before distributing or using the code. Do not reproduce without permission.