This article collects typical usages of the Java method org.deeplearning4j.models.embeddings.loader.WordVectorSerializer.loadStaticModel. If you are wondering what WordVectorSerializer.loadStaticModel does, how to call it, or where to find examples of it, the curated code samples below should help; you can also explore the enclosing class org.deeplearning4j.models.embeddings.loader.WordVectorSerializer for further usage.
The following presents 8 code examples of WordVectorSerializer.loadStaticModel, sorted by popularity.
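As a quick orientation before the examples, here is a minimal sketch of the basic call pattern. This is an illustration only: the file path and class name are placeholder assumptions, not taken from any of the projects quoted below. loadStaticModel returns a lightweight, lookup-only view of the vectors (no further training), which is why several of the tests below compare it against the corresponding "live" loaders.
import java.io.File;
import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
import org.deeplearning4j.models.embeddings.wordvectors.WordVectors;
import org.nd4j.linalg.api.ndarray.INDArray;

public class LoadStaticModelSketch {
    public static void main(String[] args) {
        // Placeholder path: binary, text/CSV, and DL4J ZIP formats are all accepted
        File modelFile = new File("/path/to/word-vectors.bin.gz");

        // Load the vectors as a read-only "static" model
        WordVectors vectors = WordVectorSerializer.loadStaticModel(modelFile);

        // Look up the embedding of a single word
        INDArray vector = vectors.getWordVectorMatrix("night");
        System.out.println("Vector length: " + vector.length());
    }
}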
Example 1: testStaticLoaderArchive
import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer; // import the package/class the method depends on
/**
 * This method tests ZIP file loading as a static model
 *
 * @throws Exception
 */
@Test
public void testStaticLoaderArchive() throws Exception {
    logger.info("Executor name: {}", Nd4j.getExecutioner().getClass().getSimpleName());

    File w2v = new ClassPathResource("word2vec.dl4j/file.w2v").getFile();

    WordVectors vectorsLive = WordVectorSerializer.readWord2Vec(w2v);
    WordVectors vectorsStatic = WordVectorSerializer.loadStaticModel(w2v);

    INDArray arrayLive = vectorsLive.getWordVectorMatrix("night");
    INDArray arrayStatic = vectorsStatic.getWordVectorMatrix("night");

    assertNotEquals(null, arrayLive);
    assertEquals(arrayLive, arrayStatic);
}
Example 2: initWordVectors
import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer; // import the package/class the method depends on
/** Initialize the word vectors from the given file */
public void initWordVectors() {
    if (wordVectors != null) {
        log.debug("Word vectors already loaded, skipping initialization.");
        return;
    }
    log.debug("Loading word vector model");

    final String path = wordVectorLocation.getAbsolutePath();
    final String pathLower = path.toLowerCase();
    if (pathLower.endsWith(".arff")) {
        loadEmbeddingFromArff(path);
    } else if (pathLower.endsWith(".csv")) {
        // Try to load the file as a CSV embedding
        boolean success = loadEmbeddingFromCSV(wordVectorLocation);
        if (!success) {
            throw new RuntimeException("Could not load the word vector file.");
        }
    } else if (pathLower.endsWith(".csv.gz")) {
        loadGZipped();
    } else {
        // No known file extension matched; try loading the file as-is
        wordVectors = WordVectorSerializer.loadStaticModel(wordVectorLocation);
    }
}
Example 3: loadGZipped
import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer; // import the package/class the method depends on
/** Load wordVectors from a gzipped CSV file */
private void loadGZipped() {
    try {
        wordVectors = WordVectorSerializer.loadStaticModel(wordVectorLocation);
    } catch (RuntimeException re) {
        // Not in a DL4J-readable format; decompress by hand and retry as CSV
        try {
            File tmpFile =
                Paths.get(System.getProperty("java.io.tmpdir"), "wordmodel-tmp.csv").toFile();
            tmpFile.delete();
            // try-with-resources guarantees both streams are closed, even on failure
            try (GZIPInputStream gzis = new GZIPInputStream(new FileInputStream(wordVectorLocation));
                 FileOutputStream fos = new FileOutputStream(tmpFile)) {
                byte[] buffer = new byte[1024];
                int length;
                while ((length = gzis.read(buffer)) > 0) {
                    fos.write(buffer, 0, length);
                }
            }
            // Try loading the decompressed CSV file
            boolean success = loadEmbeddingFromCSV(tmpFile);
            if (!success) {
                throw new RuntimeException("Could not load the word vector file.");
            }
        } catch (IOException e) {
            e.printStackTrace();
        }
    }
}
Example 4: makeData
import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer; // import the package/class the method depends on
public Instances makeData() throws Exception {
    final Instances data = TestUtil.makeTestDataset(
        42, 100, 0, 0, 1, 0, 0, 1, Attribute.NUMERIC, 1, false);

    WordVectors wordVectors =
        WordVectorSerializer.loadStaticModel(DatasetLoader.loadGoogleNewsVectors());
    String[] words = wordVectors.vocab().words().toArray(new String[0]);

    // Fill the string attribute with random ten-word sentences drawn from the vocabulary
    Random rand = new Random(42);
    for (Instance inst : data) {
        StringBuilder sentence = new StringBuilder();
        for (int i = 0; i < 10; i++) {
            final int idx = rand.nextInt(words.length);
            sentence.append(" ").append(words[idx]);
        }
        inst.setValue(0, sentence.toString());
    }
    return data;
}
Example 5: testStaticLoaderGoogleModel
import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer; // import the package/class the method depends on
/**
 * This method is only here to test the real Google model, a few gigabytes' worth.
 * Keep it ignored, since it requires the full Google model to be present on the system (1.6 GB compressed).
 *
 * @throws Exception
 */
@Test
@Ignore
public void testStaticLoaderGoogleModel() throws Exception {
    logger.info("Executor name: {}", Nd4j.getExecutioner().getClass().getSimpleName());

    long time1 = System.currentTimeMillis();
    WordVectors vectors = WordVectorSerializer
        .loadStaticModel(new File("C:\\Users\\raver\\develop\\GoogleNews-vectors-negative300.bin.gz"));
    long time2 = System.currentTimeMillis();

    logger.info("Loading time: {} ms", (time2 - time1));
}
Example 6: testStaticLoaderBinary
import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer; // import the package/class the method depends on
/**
 * This method tests binary file loading as a static model
 *
 * @throws Exception
 */
@Test
public void testStaticLoaderBinary() throws Exception {
    logger.info("Executor name: {}", Nd4j.getExecutioner().getClass().getSimpleName());

    WordVectors vectorsLive = WordVectorSerializer.loadGoogleModel(binaryFile, true);
    WordVectors vectorsStatic = WordVectorSerializer.loadStaticModel(binaryFile);

    INDArray arrayLive = vectorsLive.getWordVectorMatrix("Morgan_Freeman");
    INDArray arrayStatic = vectorsStatic.getWordVectorMatrix("Morgan_Freeman");

    assertNotEquals(null, arrayLive);
    assertEquals(arrayLive, arrayStatic);
}
Example 7: testStaticLoaderText
import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer; // import the package/class the method depends on
/**
 * This method tests CSV (text) file loading as a static model
 *
 * @throws Exception
 */
@Test
public void testStaticLoaderText() throws Exception {
    logger.info("Executor name: {}", Nd4j.getExecutioner().getClass().getSimpleName());

    WordVectors vectorsLive = WordVectorSerializer.loadTxtVectors(textFile);
    WordVectors vectorsStatic = WordVectorSerializer.loadStaticModel(textFile);

    INDArray arrayLive = vectorsLive.getWordVectorMatrix("Morgan_Freeman");
    INDArray arrayStatic = vectorsStatic.getWordVectorMatrix("Morgan_Freeman");

    assertNotEquals(null, arrayLive);
    assertEquals(arrayLive, arrayStatic);
}
Example 8: main
import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer; // import the package/class the method depends on
public static void main(String[] args) throws IOException {
    log.info("download and extract data...");
    CNNSentenceClassification.aclImdbDownloader(DATA_URL, DATA_PATH);

    // Basic configuration
    int batchSize = 32;
    int vectorSize = 300;               // Size of the word vectors: 300 in the Google News model
    int nEpochs = 1;                    // Number of epochs (full passes of training data) to train on
    int truncateReviewsToLength = 256;  // Truncate reviews with length (# words) greater than this
    int cnnLayerFeatureMaps = 100;      // Number of feature maps / channels / depth for each CNN layer
    PoolingType globalPoolingType = PoolingType.MAX;
    Random rng = new Random(12345);     // For shuffling repeatability

    log.info("construct cnn model...");
    ComputationGraph net = CNNSentenceClassification.buildCNNGraph(vectorSize, cnnLayerFeatureMaps, globalPoolingType);
    log.info("number of parameters by layer:");
    for (Layer l : net.getLayers()) {
        log.info("\t" + l.conf().getLayer().getLayerName() + "\t" + l.numParams());
    }

    // Load word vectors and get the DataSetIterators for training and testing
    log.info("loading word vectors and creating DataSetIterators...");
    WordVectors wordVectors = WordVectorSerializer.loadStaticModel(new File(WORD_VECTORS_PATH));
    DataSetIterator trainIter = CNNSentenceClassification.getDataSetIterator(DATA_PATH, true, wordVectors, batchSize,
        truncateReviewsToLength, rng);
    DataSetIterator testIter = CNNSentenceClassification.getDataSetIterator(DATA_PATH, false, wordVectors, batchSize,
        truncateReviewsToLength, rng);

    log.info("starting training...");
    for (int i = 0; i < nEpochs; i++) {
        net.fit(trainIter);
        log.info("Epoch " + i + " complete. Starting evaluation:");
        // Run evaluation. This is on 25k reviews, so it can take some time
        Evaluation evaluation = net.evaluate(testIter);
        log.info(evaluation.stats());
    }

    // After training: load a single sentence and generate a prediction
    String pathFirstNegativeFile = FilenameUtils.concat(DATA_PATH, "aclImdb/test/neg/0_2.txt");
    String contentsFirstNegative = FileUtils.readFileToString(new File(pathFirstNegativeFile));
    INDArray featuresFirstNegative = ((CnnSentenceDataSetIterator) testIter).loadSingleSentence(contentsFirstNegative);
    INDArray predictionsFirstNegative = net.outputSingle(featuresFirstNegative);
    List<String> labels = testIter.getLabels();
    log.info("\n\nPredictions for first negative review:");
    for (int i = 0; i < labels.size(); i++) {
        log.info("P(" + labels.get(i) + ") = " + predictionsFirstNegative.getDouble(i));
    }
}