本文整理汇总了Java中org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable.setSyn0方法的典型用法代码示例。如果您正苦于以下问题:Java InMemoryLookupTable.setSyn0方法的具体用法?Java InMemoryLookupTable.setSyn0怎么用?Java InMemoryLookupTable.setSyn0使用的例子?那么恭喜您, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在类org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable的用法示例。
在下文中一共展示了InMemoryLookupTable.setSyn0方法的3个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Java代码示例。
示例1: testParaVecSerialization1
import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable; //导入方法依赖的package包/类
@Test
public void testParaVecSerialization1() throws Exception {
    // Use non-default configuration values so the round-trip comparison is meaningful.
    VectorsConfiguration config = new VectorsConfiguration();
    config.setIterations(14123);
    config.setLayersSize(156);

    // Random weight matrices for a vocabulary of 100 elements.
    INDArray weightsSyn0 = Nd4j.rand(100, config.getLayersSize());
    INDArray weightsSyn1 = Nd4j.rand(100, config.getLayersSize());

    AbstractCache<VocabWord> vocab = new AbstractCache.Builder<VocabWord>().build();
    for (int idx = 0; idx < 100; idx++) {
        VocabWord element = new VocabWord((float) idx, "word_" + idx);

        // Attach a random-length list of points/codes to each element.
        List<Integer> elementPoints = new ArrayList<>();
        List<Byte> elementCodes = new ArrayList<>();
        int pathLength = org.apache.commons.lang3.RandomUtils.nextInt(1, 20);
        for (int p = 0; p < pathLength; p++) {
            elementPoints.add(org.apache.commons.lang3.RandomUtils.nextInt(1, 100000));
            elementCodes.add(org.apache.commons.lang3.RandomUtils.nextBytes(10)[0]);
        }

        // Roughly 30% of the elements get flagged as labels.
        if (RandomUtils.nextInt(10) < 3) {
            element.markAsLabel(true);
        }
        element.setIndex(idx);
        element.setPoints(elementPoints);
        element.setCodes(elementCodes);
        vocab.addToken(element);
        vocab.addWordToIndex(idx, element.getLabel());
    }

    InMemoryLookupTable<VocabWord> table =
                    (InMemoryLookupTable<VocabWord>) new InMemoryLookupTable.Builder<VocabWord>()
                                    .vectorLength(config.getLayersSize()).cache(vocab).build();
    table.setSyn0(weightsSyn0);
    table.setSyn1(weightsSyn1);

    ParagraphVectors originalVectors =
                    new ParagraphVectors.Builder(config).vocabCache(vocab).lookupTable(table).build();

    // Serialize to a temp file and read the model back.
    File tempFile = File.createTempFile("paravec", "tests");
    tempFile.deleteOnExit();
    WordVectorSerializer.writeParagraphVectors(originalVectors, tempFile);
    ParagraphVectors restoredVectors = WordVectorSerializer.readParagraphVectors(tempFile);

    InMemoryLookupTable<VocabWord> restoredTable =
                    (InMemoryLookupTable<VocabWord>) restoredVectors.getLookupTable();
    AbstractCache<VocabWord> restoredVocab = (AbstractCache<VocabWord>) restoredVectors.getVocab();

    // Weight matrices must survive the round trip exactly.
    assertEquals(restoredTable.getSyn0(), table.getSyn0());
    assertEquals(restoredTable.getSyn1(), table.getSyn1());

    // Every vocabulary element must match: label flag, word, frequency, points and codes.
    for (int idx = 0; idx < vocab.numWords(); idx++) {
        assertEquals(vocab.elementAtIndex(idx).isLabel(), restoredVocab.elementAtIndex(idx).isLabel());
        assertEquals(vocab.wordAtIndex(idx), restoredVocab.wordAtIndex(idx));
        assertEquals(vocab.elementAtIndex(idx).getElementFrequency(),
                        restoredVocab.elementAtIndex(idx).getElementFrequency(), 0.1f);

        List<Integer> originalPoints = vocab.elementAtIndex(idx).getPoints();
        List<Integer> restoredPoints = restoredVocab.elementAtIndex(idx).getPoints();
        assertEquals(originalPoints.size(), restoredPoints.size());
        for (int p = 0; p < originalPoints.size(); p++) {
            assertEquals(originalPoints.get(p), restoredPoints.get(p));
        }

        List<Byte> originalCodes = vocab.elementAtIndex(idx).getCodes();
        List<Byte> restoredCodes = restoredVocab.elementAtIndex(idx).getCodes();
        assertEquals(originalCodes.size(), restoredCodes.size());
        for (int p = 0; p < originalCodes.size(); p++) {
            assertEquals(originalCodes.get(p), restoredCodes.get(p));
        }
    }
}
示例2: loadTxtVectors
import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable; //导入方法依赖的package包/类
/**
 * This method can be used to load previously saved model from InputStream (like a HDFS-stream)
 *
 * Deprecation note: Please, consider using readWord2VecModel() or loadStaticModel() method instead
 *
 * @param stream InputStream that contains previously serialized model
 * @param skipFirstLine Set this TRUE if first line contains csv header, FALSE otherwise
 * @return WordVectors model restored from the stream
 * @throws IOException if the stream cannot be read
 */
@Deprecated
public static WordVectors loadTxtVectors(@NonNull InputStream stream, boolean skipFirstLine) throws IOException {
    AbstractCache<VocabWord> cache = new AbstractCache.Builder<VocabWord>().build();
    List<INDArray> arrays = new ArrayList<>();
    // try-with-resources: the original never closed the reader on any path.
    // Charset is pinned to UTF-8 (matching readSequenceVectors) instead of relying
    // on the platform default, which made the loader non-portable.
    try (BufferedReader reader = new BufferedReader(new InputStreamReader(stream, "UTF-8"))) {
        if (skipFirstLine)
            reader.readLine();
        String line;
        // Each line: word followed by space-separated float vector components.
        while ((line = reader.readLine()) != null) {
            String[] split = line.split(" ");
            // Whitespace inside words was escaped on save; restore it here.
            String word = split[0].replaceAll(whitespaceReplacement, " ");
            VocabWord word1 = new VocabWord(1.0, word);
            word1.setIndex(cache.numWords());
            cache.addToken(word1);
            cache.addWordToIndex(word1.getIndex(), word);
            cache.putVocabWord(word);
            float[] vector = new float[split.length - 1];
            for (int i = 1; i < split.length; i++) {
                vector[i - 1] = Float.parseFloat(split[i]);
            }
            arrays.add(Nd4j.create(vector));
        }
    }
    InMemoryLookupTable<VocabWord> lookupTable =
                    (InMemoryLookupTable<VocabWord>) new InMemoryLookupTable.Builder<VocabWord>()
                                    .vectorLength(arrays.get(0).columns()).cache(cache).build();
    // Stack all row vectors into the syn0 weight matrix; NaNs are zeroed out.
    INDArray syn = Nd4j.vstack(arrays);
    Nd4j.clearNans(syn);
    lookupTable.setSyn0(syn);
    return fromPair(Pair.makePair((InMemoryLookupTable) lookupTable, (VocabCache) cache));
}
示例3: readSequenceVectors
import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable; //导入方法依赖的package包/类
/**
 * This method loads previously saved SequenceVectors model from InputStream
 *
 * @param factory factory used to deserialize individual sequence elements
 * @param stream stream containing a model previously serialized with writeSequenceVectors()
 * @param <T> type of the sequence elements
 * @return restored SequenceVectors model
 * @throws IOException if the stream cannot be read
 */
public static <T extends SequenceElement> SequenceVectors<T> readSequenceVectors(
                @NonNull SequenceElementFactory<T> factory, @NonNull InputStream stream) throws IOException {
    VectorsConfiguration configuration;
    AbstractCache<T> vocabCache = new AbstractCache.Builder<T>().build();
    List<INDArray> rows = new ArrayList<>();
    // try-with-resources: the original only closed the reader on the success path,
    // leaking it whenever deserialization threw.
    try (BufferedReader reader = new BufferedReader(new InputStreamReader(stream, "UTF-8"))) {
        // First line is the Base64-encoded JSON of the vectors configuration.
        String line = reader.readLine();
        configuration = VectorsConfiguration.fromJson(new String(Base64.decodeBase64(line), "UTF-8"));
        // Each remaining non-empty line encodes one element together with its vector.
        while ((line = reader.readLine()) != null) {
            if (line.isEmpty()) // skip empty line
                continue;
            ElementPair pair = ElementPair.fromEncodedJson(line);
            T element = factory.deserialize(pair.getObject());
            rows.add(Nd4j.create(pair.getVector()));
            vocabCache.addToken(element);
            vocabCache.addWordToIndex(element.getIndex(), element.getLabel());
        }
    }
    InMemoryLookupTable<T> lookupTable = (InMemoryLookupTable<T>) new InMemoryLookupTable.Builder<T>()
                    .vectorLength(rows.get(0).columns()).cache(vocabCache).build();
    // Stack all row vectors into the syn0 weight matrix in one shot.
    INDArray syn0 = Nd4j.vstack(rows);
    lookupTable.setSyn0(syn0);
    return new SequenceVectors.Builder<T>(configuration).vocabCache(vocabCache)
                    .lookupTable(lookupTable).resetModel(false).build();
}