当前位置: 首页>>代码示例>>Java>>正文


Java DataSet.getLabelsMaskArray方法代码示例

本文整理汇总了Java中org.nd4j.linalg.dataset.api.DataSet.getLabelsMaskArray方法的典型用法代码示例。如果您正苦于以下问题:Java DataSet.getLabelsMaskArray方法的具体用法是什么?Java DataSet.getLabelsMaskArray怎么用?有哪些Java DataSet.getLabelsMaskArray的使用例子?那么这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在类org.nd4j.linalg.dataset.api.DataSet的用法示例。


在下文中一共展示了DataSet.getLabelsMaskArray方法的8个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Java代码示例。

示例1: main

import org.nd4j.linalg.dataset.api.DataSet; //导入方法依赖的package包/类
/**
 * Loads word vectors and a saved sentiment model, runs the model over the test
 * data and prints the evaluation statistics to standard output.
 *
 * <ul>
 *   <li>args[0] input: word2vec file name</li>
 *   <li>args[1] input: sentiment model name</li>
 *   <li>args[2] input: test parent folder name</li>
 * </ul>
 *
 * @param args command-line arguments as listed above; the process exits with
 *             status 1 when fewer than three are supplied or any of them is null
 * @throws Exception any exception raised by the calls in the body is propagated
 */
public static void main (final String[] args) throws Exception {
  // Check the length before indexing: a too-short array would otherwise throw
  // ArrayIndexOutOfBoundsException before the null checks could run.
  if (args == null || args.length < 3
      || args[0] == null || args[1] == null || args[2] == null) {
    System.err.println("Usage: <word2vec file> <sentiment model> <test parent folder>");
    System.exit(1);
  }

  WordVectors wvec = WordVectorSerializer.loadTxtVectors(new File(args[0]));
  MultiLayerNetwork model = ModelSerializer.restoreMultiLayerNetwork(args[1],false);

  // NOTE(review): 100 is presumably the batch size and 300 a length/size limit,
  // judging by the sibling training commands -- confirm against SentimentRecurrentIterator.
  DataSetIterator test = new AsyncDataSetIterator(
      new SentimentRecurrentIterator(args[2],wvec,100,300,false),1);
  Evaluation evaluation = new Evaluation();
  while (test.hasNext()) {
    DataSet t = test.next();
    INDArray features = t.getFeatures();
    INDArray labels = t.getLabels();
    INDArray inMask = t.getFeaturesMaskArray();
    INDArray outMask = t.getLabelsMaskArray();
    // Both masks are handed to the network; the labels mask is reused for evaluation.
    INDArray predicted = model.output(features,false,inMask,outMask);
    evaluation.evalTimeSeries(labels,predicted,outMask);
  }
  System.out.println(evaluation.stats());
}
 
开发者ID:keigohtr,项目名称:sentiment-rnn,代码行数:30,代码来源:SentimentRecurrentTestCmd.java

示例2: fit

import org.nd4j.linalg.dataset.api.DataSet; //导入方法依赖的package包/类
/**
 * Trains this ComputationGraph on a single DataSet.
 * Only usable when the graph has exactly one input array and one output array;
 * for networks with more than one input or output, use {@link #fit(MultiDataSetIterator)}.
 *
 * @param dataSet source of the features, labels and (optional) mask arrays
 * @throws UnsupportedOperationException if the graph does not have exactly one
 *         input and one output
 */
public void fit(DataSet dataSet) {
    boolean singleInputAndOutput = (numInputArrays == 1 && numOutputArrays == 1);
    if (!singleInputAndOutput) {
        throw new UnsupportedOperationException("Cannot train ComputationGraph network with "
                + " multiple inputs or outputs using a DataSet");
    }

    if (dataSet.hasMaskArrays()) {
        // Wrap each mask that is present in a one-element array; absent masks stay null.
        INDArray[] featureMasks = null;
        if (dataSet.getFeaturesMaskArray() != null) {
            featureMasks = new INDArray[] {dataSet.getFeaturesMaskArray()};
        }
        INDArray[] labelMasks = null;
        if (dataSet.getLabelsMaskArray() != null) {
            labelMasks = new INDArray[] {dataSet.getLabelsMaskArray()};
        }

        INDArray[] features = {dataSet.getFeatures()};
        INDArray[] labels = {dataSet.getLabels()};
        fit(features, labels, featureMasks, labelMasks);

        // Mask arrays are cleared only when the data set reported having them.
        clearLayerMaskArrays();
    } else {
        INDArray[] features = {dataSet.getFeatures()};
        INDArray[] labels = {dataSet.getLabels()};
        fit(features, labels);
    }

    clearLayersStates();
}
 
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:27,代码来源:ComputationGraph.java

示例3: call

import org.nd4j.linalg.dataset.api.DataSet; //导入方法依赖的package包/类
/**
 * Asks the affinity manager to ensure that every non-null array of the given
 * data set (features, labels, features mask, labels mask) is at the DEVICE
 * location. A null data set is ignored.
 *
 * @param dataSet the data set whose arrays are relocated; may be null
 */
@Override
public void call(DataSet dataSet) {
    if (dataSet == null) {
        return;
    }

    final AffinityManager.Location target = AffinityManager.Location.DEVICE;

    if (dataSet.getFeatures() != null) {
        Nd4j.getAffinityManager().ensureLocation(dataSet.getFeatures(), target);
    }

    if (dataSet.getLabels() != null) {
        Nd4j.getAffinityManager().ensureLocation(dataSet.getLabels(), target);
    }

    if (dataSet.getFeaturesMaskArray() != null) {
        Nd4j.getAffinityManager().ensureLocation(dataSet.getFeaturesMaskArray(), target);
    }

    if (dataSet.getLabelsMaskArray() != null) {
        Nd4j.getAffinityManager().ensureLocation(dataSet.getLabelsMaskArray(), target);
    }
}
 
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:18,代码来源:DefaultCallback.java

示例4: main

import org.nd4j.linalg.dataset.api.DataSet; //导入方法依赖的package包/类
/**
 * Restores a saved model, continues training it on the training data and,
 * after every epoch, evaluates it on the test data and writes the model out.
 *
 * <ul>
 *   <li>args[0] input: word2vec file name</li>
 *   <li>args[1] input: trained model name</li>
 *   <li>args[2] input: train/test parent folder name</li>
 *   <li>args[3] output: trained model name</li>
 * </ul>
 *
 * @param args command-line arguments as listed above; the process exits with
 *             status 1 when fewer than four are supplied or any of them is null
 * @throws Exception any exception raised by the calls in the body is propagated
 */
public static void main (final String[] args) throws Exception {
  // Check the length before indexing: a too-short array would otherwise throw
  // ArrayIndexOutOfBoundsException before the null checks could run.
  if (args == null || args.length < 4
      || args[0] == null || args[1] == null || args[2] == null || args[3] == null) {
    System.err.println(
        "Usage: <word2vec file> <input model> <train/test parent folder> <output model>");
    System.exit(1);
  }

  WordVectors wvec = WordVectorSerializer.loadTxtVectors(new File(args[0]));
  MultiLayerNetwork model = ModelSerializer.restoreMultiLayerNetwork(args[1],true);
  int batchSize   = 16; // an earlier value of 100 was left commented out by the author
  int testBatch   = 64;
  int nEpochs     = 1;

  System.out.println("Starting online training");
  // NOTE(review): the trailing boolean presumably selects train (true) vs. test (false)
  // data under the same parent folder -- confirm against SentimentRecurrentIterator.
  DataSetIterator train = new AsyncDataSetIterator(
      new SentimentRecurrentIterator(args[2],wvec,batchSize,300,true),2);
  DataSetIterator test = new AsyncDataSetIterator(
      new SentimentRecurrentIterator(args[2],wvec,testBatch,300,false),2);
  for (int i = 0; i < nEpochs; i++) {
    model.fit(train);
    train.reset();

    System.out.println("Epoch " + i + " complete. Starting evaluation:");
    Evaluation evaluation = new Evaluation();
    while (test.hasNext()) {
      DataSet t = test.next();
      INDArray features = t.getFeatures();
      INDArray labels = t.getLabels();
      INDArray inMask = t.getFeaturesMaskArray();
      INDArray outMask = t.getLabelsMaskArray();
      INDArray predicted = model.output(features,false,inMask,outMask);
      evaluation.evalTimeSeries(labels,predicted,outMask);
    }
    test.reset();
    System.out.println(evaluation.stats());

    System.out.println("Save model");
    // Close the stream explicitly so a file handle is not leaked on every epoch.
    try (FileOutputStream out = new FileOutputStream(args[3])) {
      ModelSerializer.writeModel(model, out, true);
    }
  }
}
 
开发者ID:keigohtr,项目名称:sentiment-rnn,代码行数:47,代码来源:SentimentRecurrentTrainOnlineCmd.java

示例5: preProcess

import org.nd4j.linalg.dataset.api.DataSet; //导入方法依赖的package包/类
/**
 * Replaces the labels mask of the given data set with the mask returned by
 * {@code adjustMasks} for its current labels and labels mask, using this
 * pre-processor's {@code minorityLabel} and {@code targetMinorityDist}.
 *
 * @param toPreProcess the data set whose labels mask is replaced in place
 */
@Override
public void preProcess(DataSet toPreProcess) {
    toPreProcess.setLabelsMaskArray(
            adjustMasks(
                    toPreProcess.getLabels(),
                    toPreProcess.getLabelsMaskArray(),
                    minorityLabel,
                    targetMinorityDist));
}
 
开发者ID:deeplearning4j,项目名称:nd4j,代码行数:8,代码来源:UnderSamplingByMaskingPreProcessor.java

示例6: toMultiDataSet

import org.nd4j.linalg.dataset.api.DataSet; //导入方法依赖的package包/类
/**
 * Converts a DataSet to the equivalent MultiDataSet.
 * Each of the four arrays (features, labels, features mask, labels mask) is
 * wrapped in a one-element array when present and passed as null when absent.
 *
 * @param dataSet the data set to convert
 * @return a MultiDataSet built from the wrapped arrays
 */
public static MultiDataSet toMultiDataSet(DataSet dataSet) {
    final INDArray features = dataSet.getFeatures();
    final INDArray labels = dataSet.getLabels();
    final INDArray featuresMask = dataSet.getFeaturesMaskArray();
    final INDArray labelsMask = dataSet.getLabelsMaskArray();

    INDArray[] featureArrays = null;
    if (features != null) {
        featureArrays = new INDArray[] {features};
    }

    INDArray[] labelArrays = null;
    if (labels != null) {
        labelArrays = new INDArray[] {labels};
    }

    INDArray[] featureMaskArrays = null;
    if (featuresMask != null) {
        featureMaskArrays = new INDArray[] {featuresMask};
    }

    INDArray[] labelMaskArrays = null;
    if (labelsMask != null) {
        labelMaskArrays = new INDArray[] {labelsMask};
    }

    return new org.nd4j.linalg.dataset.MultiDataSet(
            featureArrays, labelArrays, featureMaskArrays, labelMaskArrays);
}
 
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:15,代码来源:ComputationGraphUtil.java

示例7: migrate

import org.nd4j.linalg.dataset.api.DataSet; //导入方法依赖的package包/类
/**
 * Replaces every array of the data set that is non-null and reports itself as
 * attached with the result of calling {@code migrate()} on it. The features,
 * labels, features mask and labels mask are handled in that order.
 * NOTE(review): presumably this moves the arrays into the current workspace --
 * confirm against the INDArray.migrate() documentation.
 *
 * @param ds the data set whose arrays are migrated in place
 */
protected void migrate(DataSet ds) {
    boolean featuresAttached = ds.getFeatures() != null && ds.getFeatures().isAttached();
    if (featuresAttached) {
        ds.setFeatures(ds.getFeatures().migrate());
    }

    boolean labelsAttached = ds.getLabels() != null && ds.getLabels().isAttached();
    if (labelsAttached) {
        ds.setLabels(ds.getLabels().migrate());
    }

    boolean featuresMaskAttached =
            ds.getFeaturesMaskArray() != null && ds.getFeaturesMaskArray().isAttached();
    if (featuresMaskAttached) {
        ds.setFeaturesMaskArray(ds.getFeaturesMaskArray().migrate());
    }

    boolean labelsMaskAttached =
            ds.getLabelsMaskArray() != null && ds.getLabelsMaskArray().isAttached();
    if (labelsMaskAttached) {
        ds.setLabelsMaskArray(ds.getLabelsMaskArray().migrate());
    }
}
 
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:14,代码来源:ComputationGraph.java

示例8: main

import org.nd4j.linalg.dataset.api.DataSet; //导入方法依赖的package包/类
/**
 * Builds a two-layer recurrent network (GravesLSTM followed by an RnnOutputLayer),
 * trains it on the training data and, after every epoch, evaluates it on the
 * test data and writes the model out.
 *
 * <ul>
 *   <li>args[0] input: word2vec file name</li>
 *   <li>args[1] input: train/test parent folder name</li>
 *   <li>args[2] output: trained model name</li>
 * </ul>
 *
 * @param args command-line arguments as listed above; the process exits with
 *             status 1 when fewer than three are supplied or any of them is null
 * @throws Exception any exception raised by the calls in the body is propagated
 */
public static void main (final String[] args) throws Exception {
  // Check the length before indexing: a too-short array would otherwise throw
  // ArrayIndexOutOfBoundsException before the null checks could run.
  if (args == null || args.length < 3
      || args[0] == null || args[1] == null || args[2] == null) {
    System.err.println("Usage: <word2vec file> <train/test parent folder> <output model>");
    System.exit(1);
  }

  WordVectors wvec = WordVectorSerializer.loadTxtVectors(new File(args[0]));
  // The network input width is taken from the word-vector layer size.
  int numInputs   = wvec.lookupTable().layerSize();
  int numOutputs  = 2; // FIXME positive or negative
  int batchSize   = 16; // an earlier value of 100 was left commented out by the author
  int testBatch   = 64;
  int nEpochs     = 5000;
  int listenfreq  = 10;

  // The trailing tags name the updater each setting applies to, as annotated by the
  // original author; ADADELTA was the alternative updater (with rho 0.95), and
  // L2 regularization (1e-5) was also left disabled.
  MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
      .seed(7485)
      .updater(Updater.RMSPROP)
      .learningRate(0.001) // RMSPROP
      .rmsDecay(0.90) // RMSPROP
      .epsilon(1e-8) // ALL
      .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
      .weightInit(WeightInit.XAVIER)
      .gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue)
      .gradientNormalizationThreshold(1.0)
      .list()
      .layer(0, new GravesLSTM.Builder()
          .nIn(numInputs).nOut(numInputs)
          .activation("softsign")
          .build())
      .layer(1, new RnnOutputLayer.Builder()
          .lossFunction(LossFunctions.LossFunction.MCXENT)
          .activation("softmax")
          .nIn(numInputs).nOut(numOutputs)
          .build())
      .pretrain(false).backprop(true).build();

  MultiLayerNetwork model = new MultiLayerNetwork(conf);
  model.init();
  model.setListeners(new ScoreIterationListener(listenfreq));

  LOG.info("Starting training");
  // NOTE(review): the trailing boolean presumably selects train (true) vs. test (false)
  // data under the same parent folder -- confirm against SentimentRecurrentIterator.
  DataSetIterator train = new AsyncDataSetIterator(
      new SentimentRecurrentIterator(args[1],wvec,batchSize,300,true),2);
  DataSetIterator test = new AsyncDataSetIterator(
      new SentimentRecurrentIterator(args[1],wvec,testBatch,300,false),2);
  for (int i = 0; i < nEpochs; i++) {
    model.fit(train);
    train.reset();

    LOG.info("Epoch " + i + " complete. Starting evaluation:");
    Evaluation evaluation = new Evaluation();
    while (test.hasNext()) {
      DataSet t = test.next();
      INDArray features = t.getFeatures();
      INDArray labels = t.getLabels();
      INDArray inMask = t.getFeaturesMaskArray();
      INDArray outMask = t.getLabelsMaskArray();
      INDArray predicted = model.output(features,false,inMask,outMask);
      evaluation.evalTimeSeries(labels,predicted,outMask);
    }
    test.reset();
    LOG.info(evaluation.stats());

    LOG.info("Save model");
    // Close the stream explicitly so a file handle is not leaked on every epoch.
    try (FileOutputStream out = new FileOutputStream(args[2])) {
      ModelSerializer.writeModel(model, out, true);
    }
  }
}
 
开发者ID:keigohtr,项目名称:sentiment-rnn,代码行数:78,代码来源:SentimentRecurrentTrainCmd.java


注:本文中的org.nd4j.linalg.dataset.api.DataSet.getLabelsMaskArray方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。