本文整理汇总了Java中org.nd4j.linalg.dataset.api.DataSet.getFeaturesMaskArray方法的典型用法代码示例。如果您正苦于以下问题：Java DataSet.getFeaturesMaskArray方法的具体用法？Java DataSet.getFeaturesMaskArray怎么用？Java DataSet.getFeaturesMaskArray使用的例子？那么恭喜您，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在类org.nd4j.linalg.dataset.api.DataSet的用法示例。
在下文中一共展示了DataSet.getFeaturesMaskArray方法的7个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Java代码示例。
示例1: main
import org.nd4j.linalg.dataset.api.DataSet; //导入方法依赖的package包/类
/**
 * Restores a saved network, runs it over a test data set and prints the evaluation statistics.
 *
 * <ul>
 *   <li>args[0] input: word2vec file name</li>
 *   <li>args[1] input: sentiment model name</li>
 *   <li>args[2] input: test parent folder name</li>
 * </ul>
 *
 * @param args command-line arguments as listed above; the process exits with status 1 when
 *             fewer than three are supplied or any of the first three is null
 * @throws Exception if loading the word vectors, restoring the model or evaluating fails
 */
public static void main (final String[] args) throws Exception {
    // Check the length before indexing: with too few arguments, args[i] would throw
    // ArrayIndexOutOfBoundsException and the intended exit(1) would never be reached.
    if (args == null || args.length < 3
            || args[0] == null || args[1] == null || args[2] == null) {
        System.err.println("Usage: <word2vec file> <sentiment model> <test parent folder>");
        System.exit(1);
    }

    WordVectors wvec = WordVectorSerializer.loadTxtVectors(new File(args[0]));
    // Second argument is false here — presumably "do not load the updater state",
    // which is not needed for evaluation only; confirm against ModelSerializer.
    MultiLayerNetwork model = ModelSerializer.restoreMultiLayerNetwork(args[1], false);

    // NOTE(review): 100 sits in the position where the training examples pass a batch size;
    // the meaning of 300 and of the trailing boolean should be confirmed against
    // SentimentRecurrentIterator.
    DataSetIterator test = new AsyncDataSetIterator(
            new SentimentRecurrentIterator(args[2], wvec, 100, 300, false), 1);

    Evaluation evaluation = new Evaluation();
    while (test.hasNext()) {
        DataSet t = test.next();
        INDArray features = t.getFeatures();
        INDArray labels = t.getLabels();
        INDArray inMask = t.getFeaturesMaskArray();
        INDArray outMask = t.getLabelsMaskArray();
        // Both masks are forwarded to the network, and the label mask is reused for scoring.
        INDArray predicted = model.output(features, false, inMask, outMask);
        evaluation.evalTimeSeries(labels, predicted, outMask);
    }
    System.out.println(evaluation.stats());
}
示例2: fit
import org.nd4j.linalg.dataset.api.DataSet; //导入方法依赖的package包/类
/**
 * Trains this ComputationGraph on a single DataSet.
 * Only graphs with exactly one input array and one output array are supported here;
 * for anything else use {@link #fit(MultiDataSetIterator)}.
 * After fitting, the layer mask arrays are cleared if the DataSet carried masks,
 * and the layer states are always cleared.
 *
 * @param dataSet the features/labels (and optional mask arrays) to train on
 * @throws UnsupportedOperationException if the graph does not have exactly one input and one output
 */
public void fit(DataSet dataSet) {
    final boolean singleInputAndOutput = (numInputArrays == 1 && numOutputArrays == 1);
    if (!singleInputAndOutput) {
        throw new UnsupportedOperationException("Cannot train ComputationGraph network with "
                + " multiple inputs or outputs using a DataSet");
    }

    final boolean masked = dataSet.hasMaskArrays();
    if (!masked) {
        fit(new INDArray[] {dataSet.getFeatures()}, new INDArray[] {dataSet.getLabels()});
    } else {
        // Wrap each mask in a one-element array, passing null through when a mask is absent.
        INDArray featuresMask = dataSet.getFeaturesMaskArray();
        INDArray[] wrappedFeaturesMask = null;
        if (featuresMask != null) {
            wrappedFeaturesMask = new INDArray[] {featuresMask};
        }

        INDArray labelsMask = dataSet.getLabelsMaskArray();
        INDArray[] wrappedLabelsMask = null;
        if (labelsMask != null) {
            wrappedLabelsMask = new INDArray[] {labelsMask};
        }

        fit(new INDArray[] {dataSet.getFeatures()}, new INDArray[] {dataSet.getLabels()},
                wrappedFeaturesMask, wrappedLabelsMask);
    }

    if (masked) {
        clearLayerMaskArrays();
    }
    clearLayersStates();
}
示例3: call
import org.nd4j.linalg.dataset.api.DataSet; //导入方法依赖的package包/类
/**
 * Asks the affinity manager to ensure that every non-null array of the given DataSet
 * (features, labels, features mask, labels mask) is at the DEVICE location.
 * A null DataSet is ignored.
 *
 * @param dataSet the data set whose arrays should be relocated; may be null
 */
@Override
public void call(DataSet dataSet) {
    if (dataSet == null) {
        return;
    }

    if (dataSet.getFeatures() != null) {
        Nd4j.getAffinityManager().ensureLocation(dataSet.getFeatures(), AffinityManager.Location.DEVICE);
    }

    if (dataSet.getLabels() != null) {
        Nd4j.getAffinityManager().ensureLocation(dataSet.getLabels(), AffinityManager.Location.DEVICE);
    }

    if (dataSet.getFeaturesMaskArray() != null) {
        Nd4j.getAffinityManager().ensureLocation(dataSet.getFeaturesMaskArray(),
                AffinityManager.Location.DEVICE);
    }

    if (dataSet.getLabelsMaskArray() != null) {
        Nd4j.getAffinityManager().ensureLocation(dataSet.getLabelsMaskArray(),
                AffinityManager.Location.DEVICE);
    }
}
示例4: main
import org.nd4j.linalg.dataset.api.DataSet; //导入方法依赖的package包/类
/**
 * Continues training a previously saved network, evaluating and saving it after every epoch.
 *
 * <ul>
 *   <li>args[0] input: word2vec file name</li>
 *   <li>args[1] input: trained model name (the model to restore)</li>
 *   <li>args[2] input: train/test parent folder name</li>
 *   <li>args[3] output: trained model name (where the model is written)</li>
 * </ul>
 *
 * @param args command-line arguments as listed above; the process exits with status 1 when
 *             fewer than four are supplied or any of the first four is null
 * @throws Exception if loading, training, evaluating or saving fails
 */
public static void main (final String[] args) throws Exception {
    // Check the length before indexing: with too few arguments, args[i] would throw
    // ArrayIndexOutOfBoundsException and the intended exit(1) would never be reached.
    if (args == null || args.length < 4
            || args[0] == null || args[1] == null || args[2] == null || args[3] == null) {
        System.err.println(
                "Usage: <word2vec file> <input model> <train/test parent folder> <output model>");
        System.exit(1);
    }

    WordVectors wvec = WordVectorSerializer.loadTxtVectors(new File(args[0]));
    // Second argument is true here — presumably "also load the updater state" so that
    // training can resume; confirm against ModelSerializer.
    MultiLayerNetwork model = ModelSerializer.restoreMultiLayerNetwork(args[1], true);

    int batchSize = 16;//100;
    int testBatch = 64;
    int nEpochs = 1;

    System.out.println("Starting online training");
    // NOTE(review): the meaning of 300 and of the trailing boolean (true for the training
    // iterator, false for the test iterator) should be confirmed against
    // SentimentRecurrentIterator.
    DataSetIterator train = new AsyncDataSetIterator(
            new SentimentRecurrentIterator(args[2], wvec, batchSize, 300, true), 2);
    DataSetIterator test = new AsyncDataSetIterator(
            new SentimentRecurrentIterator(args[2], wvec, testBatch, 300, false), 2);

    for (int i = 0; i < nEpochs; i++) {
        model.fit(train);
        train.reset();

        System.out.println("Epoch " + i + " complete. Starting evaluation:");
        Evaluation evaluation = new Evaluation();
        while (test.hasNext()) {
            DataSet t = test.next();
            INDArray features = t.getFeatures();
            INDArray labels = t.getLabels();
            INDArray inMask = t.getFeaturesMaskArray();
            INDArray outMask = t.getLabelsMaskArray();
            INDArray predicted = model.output(features, false, inMask, outMask);
            evaluation.evalTimeSeries(labels, predicted, outMask);
        }
        test.reset();
        System.out.println(evaluation.stats());

        System.out.println("Save model");
        // Close the stream deterministically; the original leaked one file handle per epoch.
        try (FileOutputStream out = new FileOutputStream(args[3])) {
            ModelSerializer.writeModel(model, out, true);
        }
    }
}
示例5: toMultiDataSet
import org.nd4j.linalg.dataset.api.DataSet; //导入方法依赖的package包/类
/**
 * Builds the MultiDataSet equivalent of a DataSet by wrapping each of its arrays
 * (features, labels, features mask, labels mask) in a single-element array.
 * An array that is null in the DataSet is passed on as null, not as a wrapped null.
 *
 * @param dataSet the DataSet to convert
 * @return a new MultiDataSet holding the same arrays
 */
public static MultiDataSet toMultiDataSet(DataSet dataSet) {
    INDArray features = dataSet.getFeatures();
    INDArray labels = dataSet.getLabels();
    INDArray featuresMask = dataSet.getFeaturesMaskArray();
    INDArray labelsMask = dataSet.getLabelsMaskArray();

    INDArray[] wrappedFeatures = null;
    if (features != null) {
        wrappedFeatures = new INDArray[] {features};
    }

    INDArray[] wrappedLabels = null;
    if (labels != null) {
        wrappedLabels = new INDArray[] {labels};
    }

    INDArray[] wrappedFeaturesMask = null;
    if (featuresMask != null) {
        wrappedFeaturesMask = new INDArray[] {featuresMask};
    }

    INDArray[] wrappedLabelsMask = null;
    if (labelsMask != null) {
        wrappedLabelsMask = new INDArray[] {labelsMask};
    }

    return new org.nd4j.linalg.dataset.MultiDataSet(
            wrappedFeatures, wrappedLabels, wrappedFeaturesMask, wrappedLabelsMask);
}
示例6: migrate
import org.nd4j.linalg.dataset.api.DataSet; //导入方法依赖的package包/类
/**
 * Replaces every non-null, attached array of the given DataSet (features, labels,
 * features mask, labels mask) with the result of calling {@code migrate()} on it.
 * Arrays that are null or not attached are left untouched.
 *
 * @param ds the DataSet whose arrays are migrated in place
 */
protected void migrate(DataSet ds) {
    if (ds.getFeatures() != null) {
        if (ds.getFeatures().isAttached()) {
            ds.setFeatures(ds.getFeatures().migrate());
        }
    }

    if (ds.getLabels() != null) {
        if (ds.getLabels().isAttached()) {
            ds.setLabels(ds.getLabels().migrate());
        }
    }

    if (ds.getFeaturesMaskArray() != null) {
        if (ds.getFeaturesMaskArray().isAttached()) {
            ds.setFeaturesMaskArray(ds.getFeaturesMaskArray().migrate());
        }
    }

    if (ds.getLabelsMaskArray() != null) {
        if (ds.getLabelsMaskArray().isAttached()) {
            ds.setLabelsMaskArray(ds.getLabelsMaskArray().migrate());
        }
    }
}
示例7: main
import org.nd4j.linalg.dataset.api.DataSet; //导入方法依赖的package包/类
/**
 * Builds a two-layer recurrent network (GravesLSTM followed by an RnnOutputLayer), trains it,
 * and after every epoch evaluates it on the test iterator and writes the model to disk.
 *
 * <ul>
 *   <li>args[0] input: word2vec file name</li>
 *   <li>args[1] input: train/test parent folder name</li>
 *   <li>args[2] output: trained model name</li>
 * </ul>
 *
 * @param args command-line arguments as listed above; the process exits with status 1 when
 *             fewer than three are supplied or any of the first three is null
 * @throws Exception if loading the word vectors, training, evaluating or saving fails
 */
public static void main (final String[] args) throws Exception {
    // Check the length before indexing: with too few arguments, args[i] would throw
    // ArrayIndexOutOfBoundsException and the intended exit(1) would never be reached.
    if (args == null || args.length < 3
            || args[0] == null || args[1] == null || args[2] == null) {
        System.err.println("Usage: <word2vec file> <train/test parent folder> <output model>");
        System.exit(1);
    }

    WordVectors wvec = WordVectorSerializer.loadTxtVectors(new File(args[0]));
    // The network's input width is taken from the word-vector layer size.
    int numInputs = wvec.lookupTable().layerSize();
    int numOutputs = 2; // FIXME positive or negative
    int batchSize = 16;//100;
    int testBatch = 64;
    int nEpochs = 5000;
    int listenfreq = 10;

    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .seed(7485)
            .updater(Updater.RMSPROP) //ADADELTA
            .learningRate(0.001) //RMSPROP
            .rmsDecay(0.90) //RMSPROP
            //.rho(0.95) //ADADELTA
            .epsilon(1e-8) //ALL
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .weightInit(WeightInit.XAVIER)
            .gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue)
            .gradientNormalizationThreshold(1.0)
            //.regularization(true)
            //.l2(1e-5)
            .list()
            .layer(0, new GravesLSTM.Builder()
                    .nIn(numInputs).nOut(numInputs)
                    .activation("softsign")
                    .build())
            .layer(1, new RnnOutputLayer.Builder()
                    .lossFunction(LossFunctions.LossFunction.MCXENT)
                    .activation("softmax")
                    .nIn(numInputs).nOut(numOutputs)
                    .build())
            .pretrain(false).backprop(true).build();

    MultiLayerNetwork model = new MultiLayerNetwork(conf);
    model.init();
    model.setListeners(new ScoreIterationListener(listenfreq));

    LOG.info("Starting training");
    // NOTE(review): the meaning of 300 and of the trailing boolean (true for the training
    // iterator, false for the test iterator) should be confirmed against
    // SentimentRecurrentIterator.
    DataSetIterator train = new AsyncDataSetIterator(
            new SentimentRecurrentIterator(args[1], wvec, batchSize, 300, true), 2);
    DataSetIterator test = new AsyncDataSetIterator(
            new SentimentRecurrentIterator(args[1], wvec, testBatch, 300, false), 2);

    for (int i = 0; i < nEpochs; i++) {
        model.fit(train);
        train.reset();

        LOG.info("Epoch " + i + " complete. Starting evaluation:");
        Evaluation evaluation = new Evaluation();
        while (test.hasNext()) {
            DataSet t = test.next();
            INDArray features = t.getFeatures();
            INDArray labels = t.getLabels();
            INDArray inMask = t.getFeaturesMaskArray();
            INDArray outMask = t.getLabelsMaskArray();
            INDArray predicted = model.output(features, false, inMask, outMask);
            evaluation.evalTimeSeries(labels, predicted, outMask);
        }
        test.reset();
        LOG.info(evaluation.stats());

        LOG.info("Save model");
        // Close the stream deterministically; with thousands of epochs the original
        // leaked one open file handle per iteration.
        try (FileOutputStream out = new FileOutputStream(args[2])) {
            ModelSerializer.writeModel(model, out, true);
        }
    }
}