

Java WordVectorSerializer.loadTxtVectors Method Code Examples

This article collects typical usage examples of the Java method org.deeplearning4j.models.embeddings.loader.WordVectorSerializer.loadTxtVectors. If you are wondering what WordVectorSerializer.loadTxtVectors does, how to call it, or simply want to see it used in real code, the curated examples below should help. You can also explore further usage examples of the enclosing class, org.deeplearning4j.models.embeddings.loader.WordVectorSerializer.


The sections below show 15 code examples of the WordVectorSerializer.loadTxtVectors method, listed in order of popularity.
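Before turning to the collected examples, here is a minimal, self-contained sketch of the typical call pattern. It is illustrative only: the file name vectors.txt is a placeholder for a word2vec/GloVe text-format embeddings file, and the class name LoadTxtVectorsDemo and the query word "day" are chosen for this sketch rather than taken from any project below.

import java.io.File;
import java.io.IOException;
import java.util.Collection;

import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
import org.deeplearning4j.models.embeddings.wordvectors.WordVectors;

public class LoadTxtVectorsDemo {
    public static void main(String[] args) throws IOException {
        // Load word vectors from a text file in word2vec/GloVe format
        // (one word per line followed by its vector components).
        // "vectors.txt" is a placeholder path -- point it at your own embeddings file.
        WordVectors vectors = WordVectorSerializer.loadTxtVectors(new File("vectors.txt"));

        // Fetch the raw vector for a word and query its nearest neighbours.
        double[] vector = vectors.getWordVector("day");
        Collection<String> nearest = vectors.wordsNearest("day", 10);

        System.out.println("Vector length: " + (vector == null ? 0 : vector.length));
        System.out.println("Nearest to 'day': " + nearest);
    }
}

Most of the examples below follow this same pattern: load the vectors once, then hand the resulting WordVectors instance to downstream code (lookups, similarity queries, or a DataSetIterator for model training).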

Example 1: testWriteWordVectorsFromWord2Vec

import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer; // import the package/class this method depends on
@Test
@Ignore
public void testWriteWordVectorsFromWord2Vec() throws IOException {
    WordVectors vec = WordVectorSerializer.loadGoogleModel(binaryFile, true);
    WordVectorSerializer.writeWordVectors((Word2Vec) vec, pathToWriteto);

    WordVectors wordVectors = WordVectorSerializer.loadTxtVectors(new File(pathToWriteto));
    INDArray wordVector1 = wordVectors.getWordVectorMatrix("Morgan_Freeman");
    INDArray wordVector2 = wordVectors.getWordVectorMatrix("JA_Montalbano");
    assertEquals(vec.getWordVectorMatrix("Morgan_Freeman"), wordVector1);
    assertEquals(vec.getWordVectorMatrix("JA_Montalbano"), wordVector2);
    assertTrue(wordVector1.length() == 300);
    assertTrue(wordVector2.length() == 300);
    assertEquals(wordVector1.getDouble(0), 0.044423, 1e-3);
    assertEquals(wordVector2.getDouble(0), 0.051964, 1e-3);
}
 
Developer ID: deeplearning4j, Project: deeplearning4j, Lines of code: 17, Source: WordVectorSerializerTest.java

Example 2: loadWordEmbeddings

import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer; // import the package/class this method depends on
public static void loadWordEmbeddings(String wePath, String weType) {
    try {
        switch (weType) {
            case "Google":
                WordEmbeddingRelatedness.wordVectors = WordVectorSerializer.loadGoogleModel(new File(wePath), true);
                break;
            case "Glove":
                WordEmbeddingRelatedness.wordVectors = WordVectorSerializer.loadTxtVectors(new File(wePath));
                break;
            default:
                System.out.println("Word Embeddings type is invalid! " + weType + " is not a valid type. Please use Google or Glove model.");
                System.exit(0);
        }
    } catch (IOException e) {
        System.out.println("Could not find Word Embeddings file in " + wePath);
    }
}
 
Developer ID: butnaruandrei, Project: ShotgunWSD, Lines of code: 18, Source: WordEmbeddingRelatedness.java

Example 3: main

import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer; // import the package/class this method depends on
/**
 * args[0] input: word2vec file name
 * args[1] input: sentiment model file name
 * args[2] input: test data parent folder name
 *
 * @param args
 * @throws Exception
 */
public static void main (final String[] args) throws Exception {
  if (args[0]==null || args[1]==null || args[2]==null)
    System.exit(1);

  WordVectors wvec = WordVectorSerializer.loadTxtVectors(new File(args[0]));
  MultiLayerNetwork model = ModelSerializer.restoreMultiLayerNetwork(args[1],false);

  DataSetIterator test = new AsyncDataSetIterator(
      new SentimentRecurrentIterator(args[2],wvec,100,300,false),1);
  Evaluation evaluation = new Evaluation();
  while(test.hasNext()) {
    DataSet t = test.next();
    INDArray features = t.getFeatures();
    INDArray labels = t.getLabels();
    INDArray inMask = t.getFeaturesMaskArray();
    INDArray outMask = t.getLabelsMaskArray();
    INDArray predicted = model.output(features,false,inMask,outMask);
    evaluation.evalTimeSeries(labels,predicted,outMask);
  }
  System.out.println(evaluation.stats());
}
 
Developer ID: keigohtr, Project: sentiment-rnn, Lines of code: 30, Source: SentimentRecurrentTestCmd.java

Example 4: testWriteWordVectors

import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer; // import the package/class this method depends on
@Test
@Ignore
public void testWriteWordVectors() throws IOException {
    WordVectors vec = WordVectorSerializer.loadGoogleModel(binaryFile, true);
    InMemoryLookupTable lookupTable = (InMemoryLookupTable) vec.lookupTable();
    InMemoryLookupCache lookupCache = (InMemoryLookupCache) vec.vocab();
    WordVectorSerializer.writeWordVectors(lookupTable, lookupCache, pathToWriteto);

    WordVectors wordVectors = WordVectorSerializer.loadTxtVectors(new File(pathToWriteto));
    double[] wordVector1 = wordVectors.getWordVector("Morgan_Freeman");
    double[] wordVector2 = wordVectors.getWordVector("JA_Montalbano");
    assertTrue(wordVector1.length == 300);
    assertTrue(wordVector2.length == 300);
    assertEquals(Doubles.asList(wordVector1).get(0), 0.044423, 1e-3);
    assertEquals(Doubles.asList(wordVector2).get(0), 0.051964, 1e-3);
}
 
Developer ID: deeplearning4j, Project: deeplearning4j, Lines of code: 17, Source: WordVectorSerializerTest.java

Example 5: testUnifiedLoaderText

import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer; // import the package/class this method depends on
/**
 * This method tests CSV file loading via unified loader
 *
 * @throws Exception
 */
@Test
public void testUnifiedLoaderText() throws Exception {
    logger.info("Executor name: {}", Nd4j.getExecutioner().getClass().getSimpleName());

    WordVectors vectorsLive = WordVectorSerializer.loadTxtVectors(textFile);
    WordVectors vectorsUnified = WordVectorSerializer.readWord2VecModel(textFile, true);

    INDArray arrayLive = vectorsLive.getWordVectorMatrix("Morgan_Freeman");
    INDArray arrayStatic = vectorsUnified.getWordVectorMatrix("Morgan_Freeman");

    assertNotEquals(null, arrayLive);
    assertEquals(arrayLive, arrayStatic);

    // we're trying EXTENDED model, but file doesn't have syn1/huffman info, so it should be silently degraded to simplified model
    assertEquals(null, ((InMemoryLookupTable) vectorsUnified.lookupTable()).getSyn1());
}
 
Developer ID: deeplearning4j, Project: deeplearning4j, Lines of code: 22, Source: WordVectorSerializerTest.java

Example 6: testGetMatrix

import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer; // import the package/class this method depends on
@Test
public void testGetMatrix() throws IOException {
	// Reload target and source vectors
	WordVectors ves = WordVectorSerializer.loadTxtVectors(new File(sourceVector));	 
	//Load source and target training set from dictionary
	System.out.println("Source vector loaded");
	WordVectors ven = WordVectorSerializer.loadTxtVectors(new File(targetVector));
	System.out.println("Target vector loaded");
	VectorTranslation mapper = new VectorTranslation(dictionaryFile, dictionaryLength, columns);
	DoubleMatrix translationMatrix = mapper.calculateTranslationMatrix(ves, ven);
	//Example Spanish -> English
	String[] terms1 = {
			"ser",
			"haber",
			"espacio",
			"mostrar",
			"asesino",
			"intimidad",
			// Hey, I know the numbers, too!
			"dos", "tres", "cuatro", "sesenta",
			"honradez",
			"banquero",
			"medios",
			"deporte",
			"decidido"
	};
	for (String term : terms1) {
		DoubleMatrix vsource = new DoubleMatrix(ves.getWordVector(term));
        double [] vtargetestimated = translationMatrix.mmul(vsource).transpose().toArray();
        mapper.getNMostSimilarByVector(n, term, ven, vtargetestimated);
	}
}
 
Developer ID: josemanuelgp, Project: word2vec_vector-translation-java, Lines of code: 33, Source: VectorTranslationTest.java

Example 7: main

import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer; // import the package/class this method depends on
/**
 * args[0] input: word2vec file name
 * args[1] input: trained model file name
 * args[2] input: train/test parent folder name
 * args[3] output: trained model file name
 *
 * @param args
 * @throws Exception
 */
public static void main (final String[] args) throws Exception {
  if (args[0]==null || args[1]==null || args[2]==null || args[3]==null)
    System.exit(1);

  WordVectors wvec = WordVectorSerializer.loadTxtVectors(new File(args[0]));
  MultiLayerNetwork model = ModelSerializer.restoreMultiLayerNetwork(args[1],true);
  int batchSize   = 16;//100;
  int testBatch   = 64;
  int nEpochs     = 1;

  System.out.println("Starting online training");
  DataSetIterator train = new AsyncDataSetIterator(
      new SentimentRecurrentIterator(args[2],wvec,batchSize,300,true),2);
  DataSetIterator test = new AsyncDataSetIterator(
      new SentimentRecurrentIterator(args[2],wvec,testBatch,300,false),2);
  for( int i=0; i<nEpochs; i++ ){
    model.fit(train);
    train.reset();

    System.out.println("Epoch " + i + " complete. Starting evaluation:");
    Evaluation evaluation = new Evaluation();
    while(test.hasNext()) {
      DataSet t = test.next();
      INDArray features = t.getFeatures();
      INDArray labels = t.getLabels();
      INDArray inMask = t.getFeaturesMaskArray();
      INDArray outMask = t.getLabelsMaskArray();
      INDArray predicted = model.output(features,false,inMask,outMask);
      evaluation.evalTimeSeries(labels,predicted,outMask);
    }
    test.reset();
    System.out.println(evaluation.stats());

    System.out.println("Save model");
    ModelSerializer.writeModel(model, new FileOutputStream(args[3]), true);
  }
}
 
Developer ID: keigohtr, Project: sentiment-rnn, Lines of code: 47, Source: SentimentRecurrentTrainOnlineCmd.java

Example 8: testLoadingWordVectors

import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer; // import the package/class this method depends on
@Test
public void testLoadingWordVectors() throws Exception {
    File modelFile = new File(pathToWriteto);
    if (!modelFile.exists()) {
        testRunWord2Vec();
    }
    WordVectors wordVectors = WordVectorSerializer.loadTxtVectors(modelFile);
    Collection<String> lst = wordVectors.wordsNearest("day", 10);
    System.out.println(Arrays.toString(lst.toArray()));
}
 
Developer ID: deeplearning4j, Project: deeplearning4j, Lines of code: 11, Source: Word2VecTests.java

Example 9: testLoaderStream

import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer; // import the package/class this method depends on
@Test
public void testLoaderStream() throws IOException {
    WordVectors vec = WordVectorSerializer.loadTxtVectors(new FileInputStream(textFile), true);

    assertEquals(vec.vocab().numWords(), 30);
    assertTrue(vec.vocab().hasToken("Morgan_Freeman"));
    assertTrue(vec.vocab().hasToken("JA_Montalbano"));
}
 
Developer ID: deeplearning4j, Project: deeplearning4j, Lines of code: 9, Source: WordVectorSerializerTest.java

Example 10: testLoader

import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer; // import the package/class this method depends on
@Test
@Ignore
public void testLoader() throws Exception {
    WordVectors vec = WordVectorSerializer.loadTxtVectors(new File("/home/raver119/Downloads/_vectors.txt"));

    logger.info("Rewinding: " + Arrays.toString(vec.getWordVector("rewinding")));
}
 
Developer ID: deeplearning4j, Project: deeplearning4j, Lines of code: 8, Source: WordVectorSerializerTest.java

Example 11: testStaticLoaderText

import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer; // import the package/class this method depends on
/**
 * This method tests CSV file loading as static model
 *
 * @throws Exception
 */
@Test
public void testStaticLoaderText() throws Exception {
    logger.info("Executor name: {}", Nd4j.getExecutioner().getClass().getSimpleName());

    WordVectors vectorsLive = WordVectorSerializer.loadTxtVectors(textFile);
    WordVectors vectorsStatic = WordVectorSerializer.loadStaticModel(textFile);

    INDArray arrayLive = vectorsLive.getWordVectorMatrix("Morgan_Freeman");
    INDArray arrayStatic = vectorsStatic.getWordVectorMatrix("Morgan_Freeman");

    assertNotEquals(null, arrayLive);
    assertEquals(arrayLive, arrayStatic);
}
 
Developer ID: deeplearning4j, Project: deeplearning4j, Lines of code: 19, Source: WordVectorSerializerTest.java

Example 12: testPortugeseW2V

import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer; // import the package/class this method depends on
@Test
@Ignore
public void testPortugeseW2V() throws Exception {
    WordVectors word2Vec = WordVectorSerializer.loadTxtVectors(new File("/ext/Temp/para.txt"));
    word2Vec.setModelUtils(new FlatModelUtils());

    Collection<String> portu = word2Vec.wordsNearest("carro", 10);
    printWords("carro", portu, word2Vec);

    portu = word2Vec.wordsNearest("davi", 10);
    printWords("davi", portu, word2Vec);
}
 
Developer ID: deeplearning4j, Project: deeplearning4j, Lines of code: 13, Source: Word2VecTest.java

Example 13: main

import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer; // import the package/class this method depends on
/**
 * args[0] input: word2vec file name
 * args[1] input: train/test parent folder name
 * args[2] output: output directory name
 *
 * @param args
 * @throws Exception
 */
public static void main (final String[] args) throws Exception {
  if (args[0]==null || args[1]==null || args[2]==null)
    System.exit(1);

  WordVectors wvec = WordVectorSerializer.loadTxtVectors(new File(args[0]));
  int numInputs   = wvec.lookupTable().layerSize();
  int numOutputs  = 2; // FIXME positive or negative
  int batchSize   = 16;//100;
  int testBatch   = 64;
  int nEpochs     = 5000;
  int thresEpochs = 10;
  double minImprovement = 1e-5;
  int listenfreq  = 10;

  MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
      .seed(7485)
      //.updater(Updater.RMSPROP)
      .updater(Updater.ADADELTA)
      //.learningRate(0.001) //RMSPROP
      //.rmsDecay(0.90) //RMSPROP
      .rho(0.95) //ADADELTA
      .epsilon(1e-5) //1e-8 //ALL
      .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
      .weightInit(WeightInit.XAVIER)
      .gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue)
      .gradientNormalizationThreshold(1.0)
      //.regularization(true)
      //.l2(1e-5)
      .list()
      .layer(0, new GravesLSTM.Builder()
          .nIn(numInputs).nOut(numInputs)
          .activation("softsign")
          .build())
      .layer(1, new RnnOutputLayer.Builder()
          .lossFunction(LossFunctions.LossFunction.MCXENT)
          .activation("softmax")
          .nIn(numInputs).nOut(numOutputs)
          .build())
      .pretrain(false).backprop(true).build();

  MultiLayerNetwork model = new MultiLayerNetwork(conf);
  model.setListeners(new ScoreIterationListener(listenfreq));
  //model.setListeners(new HistogramIterationListener(listenfreq)); //FIXME error occur


  LOG.info("Starting training");
  DataSetIterator train = new AsyncDataSetIterator(
      new SentimentRecurrentIterator(args[1],wvec,batchSize,300,true),2);
  DataSetIterator test = new AsyncDataSetIterator(
      new SentimentRecurrentIterator(args[1],wvec,testBatch,300,false),2);

  EarlyStoppingModelSaver<MultiLayerNetwork> saver = new LocalFileModelSaver(args[2]);//new InMemoryModelSaver<>();
  EarlyStoppingConfiguration<MultiLayerNetwork> esConf = new EarlyStoppingConfiguration.Builder<MultiLayerNetwork>()
      .epochTerminationConditions(
          new MaxEpochsTerminationCondition(nEpochs),
          new ScoreImprovementEpochTerminationCondition(thresEpochs,minImprovement))
      .scoreCalculator(new DataSetLossCalculator(test, true))
      .modelSaver(saver)
      .build();

  IEarlyStoppingTrainer<MultiLayerNetwork> trainer = new EarlyStoppingTrainer(esConf,model,train);
  EarlyStoppingResult<MultiLayerNetwork> result = trainer.fit();
  LOG.info("Termination reason: " + result.getTerminationReason());
  LOG.info("Termination details: " + result.getTerminationDetails());
  LOG.info("Total epochs: " + result.getTotalEpochs());
  LOG.info("Best epoch number: " + result.getBestModelEpoch());
  LOG.info("Score at best epoch: " + result.getBestModelScore());

  //LOG.info("Save model");
  //MultiLayerNetwork best = result.getBestModel();
  //ModelSerializer.writeModel(best, new FileOutputStream(args[2]+"/sentiment.rnn.es.model"), true);

}
 
Developer ID: keigohtr, Project: sentiment-rnn, Lines of code: 82, Source: SentimentRecurrentTrainEarlyStopCmd.java

Example 14: main

import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer; // import the package/class this method depends on
/**
 * args[0] input: word2vec file name
 * args[1] input: train/test parent folder name
 * args[2] output: trained model file name
 *
 * @param args
 * @throws Exception
 */
public static void main (final String[] args) throws Exception {
  if (args[0]==null || args[1]==null || args[2]==null)
    System.exit(1);

  WordVectors wvec = WordVectorSerializer.loadTxtVectors(new File(args[0]));
  int numInputs   = wvec.lookupTable().layerSize();
  int numOutputs  = 2; // FIXME positive or negative
  int batchSize   = 16;//100;
  int testBatch   = 64;
  int nEpochs     = 5000;
  int listenfreq  = 10;

  MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
      .seed(7485)
      .updater(Updater.RMSPROP) //ADADELTA
      .learningRate(0.001) //RMSPROP
      .rmsDecay(0.90) //RMSPROP
      //.rho(0.95) //ADADELTA
      .epsilon(1e-8) //ALL
      .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
      .weightInit(WeightInit.XAVIER)
      .gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue)
      .gradientNormalizationThreshold(1.0)
      //.regularization(true)
      //.l2(1e-5)
      .list()
      .layer(0, new GravesLSTM.Builder()
          .nIn(numInputs).nOut(numInputs)
          .activation("softsign")
          .build())
      .layer(1, new RnnOutputLayer.Builder()
          .lossFunction(LossFunctions.LossFunction.MCXENT)
          .activation("softmax")
          .nIn(numInputs).nOut(numOutputs)
          .build())
      .pretrain(false).backprop(true).build();

  MultiLayerNetwork model = new MultiLayerNetwork(conf);
  model.init();
  model.setListeners(new ScoreIterationListener(listenfreq));


  LOG.info("Starting training");
  DataSetIterator train = new AsyncDataSetIterator(
      new SentimentRecurrentIterator(args[1],wvec,batchSize,300,true),2);
  DataSetIterator test = new AsyncDataSetIterator(
      new SentimentRecurrentIterator(args[1],wvec,testBatch,300,false),2);
  for( int i=0; i<nEpochs; i++ ){
    model.fit(train);
    train.reset();

    LOG.info("Epoch " + i + " complete. Starting evaluation:");
    Evaluation evaluation = new Evaluation();
    while(test.hasNext()) {
      DataSet t = test.next();
      INDArray features = t.getFeatures();
      INDArray labels = t.getLabels();
      INDArray inMask = t.getFeaturesMaskArray();
      INDArray outMask = t.getLabelsMaskArray();
      INDArray predicted = model.output(features,false,inMask,outMask);
      evaluation.evalTimeSeries(labels,predicted,outMask);
    }
    test.reset();
    LOG.info(evaluation.stats());

    LOG.info("Save model");
    ModelSerializer.writeModel(model, new FileOutputStream(args[2]), true);
  }
}
 
Developer ID: keigohtr, Project: sentiment-rnn, Lines of code: 78, Source: SentimentRecurrentTrainCmd.java

Example 15: testIndexPersistence

import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer; // import the package/class this method depends on
@Test
public void testIndexPersistence() throws Exception {
    File inputFile = new ClassPathResource("/big/raw_sentences.txt").getFile();
    SentenceIterator iter = UimaSentenceIterator.createWithPath(inputFile.getAbsolutePath());
    // Split on white spaces in the line to get words
    TokenizerFactory t = new DefaultTokenizerFactory();
    t.setTokenPreProcessor(new CommonPreprocessor());

    Word2Vec vec = new Word2Vec.Builder().minWordFrequency(5).iterations(1).epochs(1).layerSize(100)
                    .stopWords(new ArrayList<String>()).useAdaGrad(false).negativeSample(5).seed(42).windowSize(5)
                    .iterate(iter).tokenizerFactory(t).build();

    vec.fit();

    VocabCache orig = vec.getVocab();

    File tempFile = File.createTempFile("temp", "w2v");
    tempFile.deleteOnExit();

    WordVectorSerializer.writeWordVectors(vec, tempFile);

    WordVectors vec2 = WordVectorSerializer.loadTxtVectors(tempFile);

    VocabCache rest = vec2.vocab();

    assertEquals(orig.totalNumberOfDocs(), rest.totalNumberOfDocs());

    for (VocabWord word : vec.getVocab().vocabWords()) {
        INDArray array1 = vec.getWordVectorMatrix(word.getLabel());
        INDArray array2 = vec2.getWordVectorMatrix(word.getLabel());

        assertEquals(array1, array2);
    }
}
 
Developer ID: deeplearning4j, Project: deeplearning4j, Lines of code: 35, Source: WordVectorSerializerTest.java


Note: The org.deeplearning4j.models.embeddings.loader.WordVectorSerializer.loadTxtVectors examples in this article were compiled from open-source code and documentation platforms such as GitHub and MSDocs, and the snippets were selected from open-source projects contributed by their respective developers. Copyright of the source code remains with the original authors; please consult each project's License before distributing or reusing the code.