当前位置: 首页>>代码示例>>Java>>正文


Java WordVectorSerializer.loadStaticModel方法代码示例

本文整理汇总了Java中org.deeplearning4j.models.embeddings.loader.WordVectorSerializer.loadStaticModel方法的典型用法代码示例。如果您正苦于以下问题:Java WordVectorSerializer.loadStaticModel方法的具体用法?Java WordVectorSerializer.loadStaticModel怎么用?Java WordVectorSerializer.loadStaticModel使用的例子?那么, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在org.deeplearning4j.models.embeddings.loader.WordVectorSerializer的用法示例。


在下文中一共展示了WordVectorSerializer.loadStaticModel方法的8个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Java代码示例。

示例1: testStaticLoaderArchive

import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer; //导入方法依赖的package包/类
/**
 * Verifies that a (ZIP) word2vec archive read through the static loader produces the same
 * vector for the word "night" as the regular {@code readWord2Vec} loader.
 *
 * @throws Exception if the classpath resource or either loader fails
 */
@Test
public void testStaticLoaderArchive() throws Exception {
    String executorName = Nd4j.getExecutioner().getClass().getSimpleName();
    logger.info("Executor name: {}", executorName);

    File archive = new ClassPathResource("word2vec.dl4j/file.w2v").getFile();

    WordVectors viaReadWord2Vec = WordVectorSerializer.readWord2Vec(archive);
    WordVectors viaStaticModel = WordVectorSerializer.loadStaticModel(archive);

    String probe = "night";
    INDArray expected = viaReadWord2Vec.getWordVectorMatrix(probe);
    INDArray actual = viaStaticModel.getWordVectorMatrix(probe);

    // The reference vector must exist, otherwise the equality check below is meaningless.
    assertNotEquals(null, expected);
    assertEquals(expected, actual);
}
 
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:21,代码来源:WordVectorSerializerTest.java

示例2: initWordVectors

import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer; //导入方法依赖的package包/类
/**
 * Initializes {@code wordVectors} from {@code wordVectorLocation}, unless they are already set.
 *
 * <p>The loader is picked from the lower-cased file extension: {@code .csv.gz},
 * {@code .csv} and {@code .arff} have dedicated loaders. Any other file is passed as-is to
 * {@link WordVectorSerializer#loadStaticModel(File)}.
 *
 * @throws RuntimeException if a {@code .csv} file cannot be loaded
 */
public void initWordVectors() {
  final boolean alreadyLoaded = wordVectors != null;
  if (alreadyLoaded) {
    log.debug("Word vectors already loaded, skipping initialization.");
    return;
  }

  log.debug("Loading word vector model");

  final String absolutePath = wordVectorLocation.getAbsolutePath();
  final String lowered = absolutePath.toLowerCase();

  // The three suffixes are mutually exclusive, so the order of these checks does not matter.
  if (lowered.endsWith(".csv.gz")) {
    loadGZipped();
    return;
  }
  if (lowered.endsWith(".csv")) {
    // The CSV loader signals failure through its return value rather than an exception.
    if (!loadEmbeddingFromCSV(wordVectorLocation)) {
      throw new RuntimeException("Could not load the word vector file.");
    }
    return;
  }
  if (lowered.endsWith(".arff")) {
    loadEmbeddingFromArff(absolutePath);
    return;
  }

  // No known extension matched: hand the file to the static loader as is.
  wordVectors = WordVectorSerializer.loadStaticModel(wordVectorLocation);
}
 
开发者ID:Waikato,项目名称:wekaDeeplearning4j,代码行数:28,代码来源:AbstractTextEmbeddingIterator.java

示例3: loadGZipped

import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer; //导入方法依赖的package包/类
/**
 * Loads {@code wordVectors} from a gzipped word-vector file at {@code wordVectorLocation}.
 *
 * <p>First tries {@link WordVectorSerializer#loadStaticModel(File)} on the file as-is. If that
 * throws a {@link RuntimeException} (presumably because the file is not in a format that loader
 * understands), the archive is decompressed by hand. The output goes to a uniquely named
 * temporary {@code .csv} file in {@code java.io.tmpdir}, which is then loaded with
 * {@code loadEmbeddingFromCSV}.
 *
 * @throws RuntimeException if the decompressed CSV cannot be loaded, or if decompression fails
 *     with an I/O error. In the I/O case the static-loader failure is attached as a suppressed
 *     exception.
 */
private void loadGZipped() {
  try {
    wordVectors = WordVectorSerializer.loadStaticModel(wordVectorLocation);
  } catch (RuntimeException re) {
    // Dl4j format not found, continue with decompression by hand
    final File tmpFile;
    try {
      // Use a unique temp file instead of a fixed name, so concurrent loads cannot overwrite
      // each other's decompressed output.
      tmpFile = File.createTempFile("wordmodel-tmp", ".csv");
      tmpFile.deleteOnExit();
      // Each stream is its own resource. That way the FileInputStream is still closed when the
      // GZIPInputStream constructor rejects an invalid header, and every stream is closed when
      // reading or writing fails.
      try (FileInputStream fis = new FileInputStream(wordVectorLocation);
          GZIPInputStream gzis = new GZIPInputStream(fis);
          FileOutputStream fos = new FileOutputStream(tmpFile)) {
        final byte[] buffer = new byte[8192];
        int length;
        while ((length = gzis.read(buffer)) != -1) {
          fos.write(buffer, 0, length);
        }
      }
    } catch (IOException e) {
      // Surface the I/O failure instead of returning silently with wordVectors still unset.
      final RuntimeException failure =
          new RuntimeException("Could not load the word vector file.", e);
      failure.addSuppressed(re);
      throw failure;
    }

    // Try loading decompressed CSV file
    if (!loadEmbeddingFromCSV(tmpFile)) {
      throw new RuntimeException("Could not load the word vector file.");
    }
  }
}
 
开发者ID:Waikato,项目名称:wekaDeeplearning4j,代码行数:31,代码来源:AbstractTextEmbeddingIterator.java

示例4: makeData

import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer; //导入方法依赖的package包/类
/**
 * Builds a test dataset whose first attribute of every instance is set to a sentence of ten
 * words. The words are drawn from the vocabulary of the Google News word vectors using a
 * {@code Random} seeded with 42, and each word is preceded by a single space.
 *
 * <p>NOTE(review): the meaning of the {@code TestUtil.makeTestDataset} arguments (presumably
 * seed 42 and 100 instances) is not visible here; confirm against {@code TestUtil}.
 *
 * @return the generated dataset
 * @throws Exception if the dataset or the word vectors cannot be created or loaded
 */
public Instances makeData() throws Exception {
  final Instances dataset =
      TestUtil.makeTestDataset(42, 100, 0, 0, 1, 0, 0, 1, Attribute.NUMERIC, 1, false);

  final WordVectors googleNews =
      WordVectorSerializer.loadStaticModel(DatasetLoader.loadGoogleNewsVectors());
  final String[] vocabulary = (String[]) googleNews.vocab().words().toArray(new String[0]);

  final Random random = new Random(42);
  final int wordsPerSentence = 10;
  for (Instance instance : dataset) {
    final StringBuilder text = new StringBuilder();
    int remaining = wordsPerSentence;
    while (remaining-- > 0) {
      text.append(' ').append(vocabulary[random.nextInt(vocabulary.length)]);
    }
    instance.setValue(0, text.toString());
  }
  return dataset;
}
 
开发者ID:Waikato,项目名称:wekaDeeplearning4j,代码行数:28,代码来源:CnnTextFilesEmbeddingInstanceIteratorTest.java

示例5: testStaticLoaderGoogleModel

import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer; //导入方法依赖的package包/类
/**
 * Manual benchmark that logs how long the static loader takes to read the full Google News
 * model (a few gigabytes, about 1.6 GB compressed).
 *
 * <p>It is kept {@code @Ignore}d because it requires that model at the hard-coded local path
 * below.
 *
 * @throws Exception if the model cannot be loaded
 */
@Test
@Ignore
public void testStaticLoaderGoogleModel() throws Exception {
    logger.info("Executor name: {}", Nd4j.getExecutioner().getClass().getSimpleName());

    long started = System.currentTimeMillis();
    File googleModel = new File("C:\\Users\\raver\\develop\\GoogleNews-vectors-negative300.bin.gz");
    WordVectors vectors = WordVectorSerializer.loadStaticModel(googleModel);
    long elapsed = System.currentTimeMillis() - started;

    logger.info("Loading time: {} ms", elapsed);
}
 
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:19,代码来源:WordVectorSerializerTest.java

示例6: testStaticLoaderBinary

import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer; //导入方法依赖的package包/类
/**
 * Verifies that {@code binaryFile} read through the static loader returns the same vector for
 * "Morgan_Freeman" as {@code loadGoogleModel(binaryFile, true)}.
 *
 * @throws Exception if either loader fails
 */
@Test
public void testStaticLoaderBinary() throws Exception {
    String executorName = Nd4j.getExecutioner().getClass().getSimpleName();
    logger.info("Executor name: {}", executorName);

    WordVectors referenceModel = WordVectorSerializer.loadGoogleModel(binaryFile, true);
    WordVectors staticModel = WordVectorSerializer.loadStaticModel(binaryFile);

    String probe = "Morgan_Freeman";
    INDArray expected = referenceModel.getWordVectorMatrix(probe);
    INDArray actual = staticModel.getWordVectorMatrix(probe);

    // The reference vector must exist, otherwise comparing the two would prove nothing.
    assertNotEquals(null, expected);
    assertEquals(expected, actual);
}
 
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:20,代码来源:WordVectorSerializerTest.java

示例7: testStaticLoaderText

import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer; //导入方法依赖的package包/类
/**
 * Verifies that the text-format word-vector file {@code textFile} read through the static
 * loader returns the same vector for "Morgan_Freeman" as {@code loadTxtVectors(textFile)}.
 *
 * @throws Exception if either loader fails
 */
@Test
public void testStaticLoaderText() throws Exception {
    String executorName = Nd4j.getExecutioner().getClass().getSimpleName();
    logger.info("Executor name: {}", executorName);

    WordVectors referenceModel = WordVectorSerializer.loadTxtVectors(textFile);
    WordVectors staticModel = WordVectorSerializer.loadStaticModel(textFile);

    String probe = "Morgan_Freeman";
    INDArray expected = referenceModel.getWordVectorMatrix(probe);
    INDArray actual = staticModel.getWordVectorMatrix(probe);

    // The reference vector must exist, otherwise comparing the two would prove nothing.
    assertNotEquals(null, expected);
    assertEquals(expected, actual);
}
 
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:19,代码来源:WordVectorSerializerTest.java

示例8: main

import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer; //导入方法依赖的package包/类
/**
 * Downloads the IMDB review data and builds a CNN sentence classifier. It loads the word
 * vectors from {@code WORD_VECTORS_PATH}, trains and evaluates the network, and finally logs the
 * predicted class probabilities for one negative test review.
 *
 * @param args unused
 * @throws IOException if one of the I/O steps (data download/extraction, reading the review)
 *     fails
 */
public static void main(String[] args) throws IOException {
    log.info("download and extract data...");
    CNNSentenceClassification.aclImdbDownloader(DATA_URL, DATA_PATH);

    // basic configuration
    int batchSize = 32;
    int vectorSize = 300;               //Size of the word vectors. 300 in the Google News model
    int nEpochs = 1;                    //Number of epochs (full passes of training data) to train on
    int truncateReviewsToLength = 256;  //Truncate reviews with length (# words) greater than this
    int cnnLayerFeatureMaps = 100;      //Number of feature maps / channels / depth for each CNN layer
    PoolingType globalPoolingType = PoolingType.MAX;
    Random rng = new Random(12345); //For shuffling repeatability

    log.info("construct cnn model...");
    ComputationGraph net = CNNSentenceClassification.buildCNNGraph(vectorSize, cnnLayerFeatureMaps, globalPoolingType);
    log.info("number of parameters by layer:");
    for (Layer l : net.getLayers()) {
        log.info("\t" + l.conf().getLayer().getLayerName() + "\t" + l.numParams());
    }

    // Load word vectors and get the DataSetIterators for training and testing
    log.info("loading word vectors and creating DataSetIterators...");
    WordVectors wordVectors = WordVectorSerializer.loadStaticModel(new File(WORD_VECTORS_PATH));
    DataSetIterator trainIter = CNNSentenceClassification.getDataSetIterator(DATA_PATH, true, wordVectors, batchSize,
            truncateReviewsToLength, rng);
    DataSetIterator testIter = CNNSentenceClassification.getDataSetIterator(DATA_PATH, false, wordVectors, batchSize,
            truncateReviewsToLength, rng);

    log.info("starting training...");
    for (int i = 0; i < nEpochs; i++) {
        net.fit(trainIter);
        log.info("Epoch " + i + " complete. Starting evaluation:");
        //Run evaluation. This is on 25k reviews, so can take some time
        Evaluation evaluation = net.evaluate(testIter);
        log.info(evaluation.stats());
    }

    // after training: load a single sentence and generate a prediction
    String pathFirstNegativeFile = FilenameUtils.concat(DATA_PATH, "aclImdb/test/neg/0_2.txt");
    // Decode explicitly as UTF-8. The single-argument overload is deprecated and uses the
    // platform-default charset, which can mangle non-ASCII characters in the review.
    String contentsFirstNegative = FileUtils.readFileToString(new File(pathFirstNegativeFile),
            java.nio.charset.StandardCharsets.UTF_8);
    INDArray featuresFirstNegative = ((CnnSentenceDataSetIterator) testIter).loadSingleSentence(contentsFirstNegative);
    INDArray predictionsFirstNegative = net.outputSingle(featuresFirstNegative);
    List<String> labels = testIter.getLabels();
    log.info("\n\nPredictions for first negative review:");
    for (int i = 0; i < labels.size(); i++) {
        log.info("P(" + labels.get(i) + ") = " + predictionsFirstNegative.getDouble(i));
    }
}
 
开发者ID:IsaacChanghau,项目名称:Word2VecfJava,代码行数:49,代码来源:DL4JCNNSentClassifyExample.java


注:本文中的org.deeplearning4j.models.embeddings.loader.WordVectorSerializer.loadStaticModel方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。