当前位置: 首页>>代码示例>>Java>>正文


Java DataSet.getFeatures方法代码示例

本文整理汇总了Java中org.nd4j.linalg.dataset.api.DataSet.getFeatures方法的典型用法代码示例。如果您正苦于以下问题：Java DataSet.getFeatures方法的具体用法？Java DataSet.getFeatures怎么用？Java DataSet.getFeatures使用的例子？那么恭喜您，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在类org.nd4j.linalg.dataset.api.DataSet的用法示例。


在下文中一共展示了DataSet.getFeatures方法的10个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Java代码示例。

示例1: main

import org.nd4j.linalg.dataset.api.DataSet; //导入方法依赖的package包/类
/**
 * Evaluates a saved sentiment RNN on a test set and prints the evaluation statistics.
 *
 * <p>args[0] input: word2vec file name
 * <p>args[1] input: sentiment model file name
 * <p>args[2] input: parent folder of the test data
 *
 * @param args command-line arguments as described above
 * @throws Exception if loading the word vectors or the model, or evaluation, fails
 */
public static void main (final String[] args) throws Exception {
  // Check the length first: indexing a too-short array throws
  // ArrayIndexOutOfBoundsException before the usage exit could be reached.
  if (args == null || args.length < 3
      || args[0] == null || args[1] == null || args[2] == null) {
    System.err.println("Usage: <word2vec file> <sentiment model> <test parent folder>");
    System.exit(1);
  }

  WordVectors wvec = WordVectorSerializer.loadTxtVectors(new File(args[0]));
  MultiLayerNetwork model = ModelSerializer.restoreMultiLayerNetwork(args[1], false);

  // NOTE(review): 100 / 300 presumably are batch size / max sequence length, and the
  // final false selects test data -- confirm against SentimentRecurrentIterator.
  DataSetIterator test = new AsyncDataSetIterator(
      new SentimentRecurrentIterator(args[2], wvec, 100, 300, false), 1);
  Evaluation evaluation = new Evaluation();
  while (test.hasNext()) {
    DataSet t = test.next();
    INDArray features = t.getFeatures();
    INDArray labels = t.getLabels();
    INDArray inMask = t.getFeaturesMaskArray();
    INDArray outMask = t.getLabelsMaskArray();
    // The masks are passed through to both inference and evaluation
    INDArray predicted = model.output(features, false, inMask, outMask);
    evaluation.evalTimeSeries(labels, predicted, outMask);
  }
  System.out.println(evaluation.stats());
}
 
开发者ID:keigohtr,项目名称:sentiment-rnn,代码行数:30,代码来源:SentimentRecurrentTestCmd.java

示例2: preProcess

import org.nd4j.linalg.dataset.api.DataSet; //导入方法依赖的package包/类
/**
 * Flattens rank 4 features (e.g. CNN image data) into a rank 2 matrix of shape
 * [miniBatch, d1 * d2 * d3] and stores the result back into the given DataSet.
 * Rank 2 features are left untouched.
 *
 * @param toPreProcess the DataSet whose features are replaced by their flattened form
 * @throws IllegalStateException if the features are neither rank 2 nor rank 4
 */
@Override
public void preProcess(DataSet toPreProcess) {
    INDArray input = toPreProcess.getFeatures();
    if (input.rank() == 2) {
        return; // No op: should usually never happen in a properly configured data pipeline
    }
    // Only rank 4 input can be flattened below. Without this check, rank 3 indexes past the
    // end of the shape array and rank 5+ reaches reshape with a mismatched length.
    if (input.rank() != 4) {
        throw new IllegalStateException(
                "Expected rank 2 or rank 4 features, but got rank " + input.rank());
    }

    // Assume input is standard rank 4 activations - i.e., CNN image data.
    // reshape needs c-ordered data: the declared 'c' ordering alone is not enough,
    // the strides must also be c-order, otherwise take a c-ordered copy.
    if (input.ordering() != 'c' || !Shape.strideDescendingCAscendingF(input)) {
        input = input.dup('c');
    }

    int[] inShape = input.shape(); // [miniBatch, depthOut, outH, outW]
    int[] outShape = new int[] {inShape[0], inShape[1] * inShape[2] * inShape[3]};

    INDArray reshaped = input.reshape('c', outShape);
    toPreProcess.setFeatures(reshaped);
}
 
开发者ID:deeplearning4j,项目名称:nd4j,代码行数:18,代码来源:ImageFlatteningDataSetPreProcessor.java

示例3: call

import org.nd4j.linalg.dataset.api.DataSet; //导入方法依赖的package包/类
/**
 * Asks the affinity manager to ensure that every non-null array of the DataSet (features,
 * labels, features mask, labels mask) is located on the device. A null DataSet is ignored.
 *
 * @param dataSet the DataSet whose arrays should be available on the device; may be null
 */
@Override
public void call(DataSet dataSet) {
    if (dataSet == null) {
        return; // nothing to relocate
    }

    final AffinityManager.Location target = AffinityManager.Location.DEVICE;

    if (dataSet.getFeatures() != null) {
        Nd4j.getAffinityManager().ensureLocation(dataSet.getFeatures(), target);
    }
    if (dataSet.getLabels() != null) {
        Nd4j.getAffinityManager().ensureLocation(dataSet.getLabels(), target);
    }
    if (dataSet.getFeaturesMaskArray() != null) {
        Nd4j.getAffinityManager().ensureLocation(dataSet.getFeaturesMaskArray(), target);
    }
    if (dataSet.getLabelsMaskArray() != null) {
        Nd4j.getAffinityManager().ensureLocation(dataSet.getLabelsMaskArray(), target);
    }
}
 
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:18,代码来源:DefaultCallback.java

示例4: main

import org.nd4j.linalg.dataset.api.DataSet; //导入方法依赖的package包/类
/**
 * Continues training ("online training") of a saved sentiment RNN; after every epoch the
 * model is evaluated on the test data and written to the output file.
 *
 * <p>args[0] input: word2vec file name
 * <p>args[1] input: file name of the model to continue training
 * <p>args[2] input: parent folder of the train/test data
 * <p>args[3] output: file name to write the trained model to
 *
 * @param args command-line arguments as described above
 * @throws Exception if loading, training, evaluation or saving fails
 */
public static void main (final String[] args) throws Exception {
  // Check the length first: indexing a too-short array throws
  // ArrayIndexOutOfBoundsException before the usage exit could be reached.
  if (args == null || args.length < 4
      || args[0] == null || args[1] == null || args[2] == null || args[3] == null) {
    System.err.println(
        "Usage: <word2vec file> <input model> <train/test parent folder> <output model>");
    System.exit(1);
  }

  WordVectors wvec = WordVectorSerializer.loadTxtVectors(new File(args[0]));
  MultiLayerNetwork model = ModelSerializer.restoreMultiLayerNetwork(args[1], true);
  int batchSize   = 16;//100;
  int testBatch   = 64;
  int nEpochs     = 1;

  System.out.println("Starting online training");
  DataSetIterator train = new AsyncDataSetIterator(
      new SentimentRecurrentIterator(args[2], wvec, batchSize, 300, true), 2);
  DataSetIterator test = new AsyncDataSetIterator(
      new SentimentRecurrentIterator(args[2], wvec, testBatch, 300, false), 2);
  for (int i = 0; i < nEpochs; i++) {
    model.fit(train);
    train.reset();

    System.out.println("Epoch " + i + " complete. Starting evaluation:");
    Evaluation evaluation = new Evaluation();
    while (test.hasNext()) {
      DataSet t = test.next();
      INDArray features = t.getFeatures();
      INDArray labels = t.getLabels();
      INDArray inMask = t.getFeaturesMaskArray();
      INDArray outMask = t.getLabelsMaskArray();
      INDArray predicted = model.output(features, false, inMask, outMask);
      evaluation.evalTimeSeries(labels, predicted, outMask);
    }
    test.reset();
    System.out.println(evaluation.stats());

    System.out.println("Save model");
    // Close the stream deterministically; the original never closed it explicitly and
    // opened a new one every epoch.
    try (FileOutputStream out = new FileOutputStream(args[3])) {
      ModelSerializer.writeModel(model, out, true);
    }
  }
}
 
开发者ID:keigohtr,项目名称:sentiment-rnn,代码行数:47,代码来源:SentimentRecurrentTrainOnlineCmd.java

示例5: toMultiDataSet

import org.nd4j.linalg.dataset.api.DataSet; //导入方法依赖的package包/类
/**
 * Converts a DataSet into the equivalent MultiDataSet.
 *
 * <p>Each of the features, labels, features mask and labels mask is wrapped in a one-element
 * array when present; absent (null) arrays are passed on as null. No array data is copied by
 * this method.
 *
 * @param dataSet the DataSet to convert
 * @return the equivalent MultiDataSet
 */
public static MultiDataSet toMultiDataSet(DataSet dataSet) {
    return new org.nd4j.linalg.dataset.MultiDataSet(
            singletonOrNull(dataSet.getFeatures()),
            singletonOrNull(dataSet.getLabels()),
            singletonOrNull(dataSet.getFeaturesMaskArray()),
            singletonOrNull(dataSet.getLabelsMaskArray()));
}

/** Wraps {@code array} in a one-element array, or returns null when {@code array} is null. */
private static INDArray[] singletonOrNull(INDArray array) {
    if (array == null) {
        return null;
    }
    return new INDArray[] {array};
}
 
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:15,代码来源:ComputationGraphUtil.java

示例6: migrate

import org.nd4j.linalg.dataset.api.DataSet; //导入方法依赖的package包/类
/**
 * For each of the DataSet's features, labels, features mask and labels mask: if the array is
 * non-null and {@code isAttached()}, replaces it in the DataSet with the result of
 * {@code migrate()}. Null or detached arrays are left unchanged.
 *
 * @param ds the DataSet whose arrays may be replaced in place
 */
protected void migrate(DataSet ds) {
    boolean migrateFeatures = ds.getFeatures() != null && ds.getFeatures().isAttached();
    if (migrateFeatures) {
        ds.setFeatures(ds.getFeatures().migrate());
    }

    boolean migrateLabels = ds.getLabels() != null && ds.getLabels().isAttached();
    if (migrateLabels) {
        ds.setLabels(ds.getLabels().migrate());
    }

    boolean migrateFeaturesMask =
            ds.getFeaturesMaskArray() != null && ds.getFeaturesMaskArray().isAttached();
    if (migrateFeaturesMask) {
        ds.setFeaturesMaskArray(ds.getFeaturesMaskArray().migrate());
    }

    boolean migrateLabelsMask =
            ds.getLabelsMaskArray() != null && ds.getLabelsMaskArray().isAttached();
    if (migrateLabelsMask) {
        ds.setLabelsMaskArray(ds.getLabelsMaskArray().migrate());
    }
}
 
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:14,代码来源:ComputationGraph.java

示例7: testRocMultiToHtml

import org.nd4j.linalg.dataset.api.DataSet; //导入方法依赖的package包/类
/**
 * Trains a small 3-class network on the standardized Iris data, evaluates it with
 * {@link ROCMultiClass} for numSteps 20 and 0, and checks that each result renders to
 * non-empty HTML.
 */
@Test
public void testRocMultiToHtml() throws Exception {
    DataSetIterator iter = new IrisDataSetIterator(150, 150);

    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().weightInit(WeightInit.XAVIER).list()
                    .layer(0, new DenseLayer.Builder().nIn(4).nOut(4).activation(Activation.TANH).build()).layer(1,
                                    new OutputLayer.Builder().nIn(4).nOut(3).activation(Activation.SOFTMAX)
                                                    .lossFunction(LossFunctions.LossFunction.MCXENT).build())
                    .build();
    MultiLayerNetwork net = new MultiLayerNetwork(conf);
    net.init();

    NormalizerStandardize ns = new NormalizerStandardize();
    DataSet ds = iter.next();
    ns.fit(ds);
    ns.transform(ds);

    for (int i = 0; i < 30; i++) {
        net.fit(ds);
    }

    // The predictions do not depend on numSteps, so compute them once. All data comes from ds,
    // so the former iter.reset() inside the loop had no effect and was dropped.
    INDArray f = ds.getFeatures();
    INDArray l = ds.getLabels();
    INDArray out = net.output(f);

    for (int numSteps : new int[] {20, 0}) {
        ROCMultiClass roc = new ROCMultiClass(numSteps);
        roc.eval(l, out);

        String str = EvaluationTools.rocChartToHtml(roc, Arrays.asList("setosa", "versicolor", "virginica"));
        // Plain AssertionError so the check is active regardless of the -ea flag
        if (str == null || str.isEmpty()) {
            throw new AssertionError("Expected non-empty ROC HTML for numSteps=" + numSteps);
        }
        System.out.println(str);
    }
}
 
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:36,代码来源:EvaluationToolsTests.java

示例8: main

import org.nd4j.linalg.dataset.api.DataSet; //导入方法依赖的package包/类
/**
 * Builds a new sentiment RNN (a GravesLSTM layer followed by an RNN output layer) and trains
 * it. After every epoch the model is evaluated on the test data and written to the output file.
 *
 * <p>args[0] input: word2vec file name
 * <p>args[1] input: parent folder of the train/test data
 * <p>args[2] output: file name to write the trained model to
 *
 * @param args command-line arguments as described above
 * @throws Exception if loading, training, evaluation or saving fails
 */
public static void main (final String[] args) throws Exception {
  // Check the length first: indexing a too-short array throws
  // ArrayIndexOutOfBoundsException before the usage exit could be reached.
  if (args == null || args.length < 3
      || args[0] == null || args[1] == null || args[2] == null) {
    System.err.println("Usage: <word2vec file> <train/test parent folder> <output model>");
    System.exit(1);
  }

  WordVectors wvec = WordVectorSerializer.loadTxtVectors(new File(args[0]));
  int numInputs   = wvec.lookupTable().layerSize(); // used as the LSTM input and layer size
  int numOutputs  = 2; // FIXME positive or negative
  int batchSize   = 16;//100;
  int testBatch   = 64;
  int nEpochs     = 5000;
  int listenfreq  = 10;

  MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
      .seed(7485)
      .updater(Updater.RMSPROP) //ADADELTA
      .learningRate(0.001) //RMSPROP
      .rmsDecay(0.90) //RMSPROP
      //.rho(0.95) //ADADELTA
      .epsilon(1e-8) //ALL
      .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
      .weightInit(WeightInit.XAVIER)
      .gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue)
      .gradientNormalizationThreshold(1.0)
      //.regularization(true)
      //.l2(1e-5)
      .list()
      .layer(0, new GravesLSTM.Builder()
          .nIn(numInputs).nOut(numInputs)
          .activation("softsign")
          .build())
      .layer(1, new RnnOutputLayer.Builder()
          .lossFunction(LossFunctions.LossFunction.MCXENT)
          .activation("softmax")
          .nIn(numInputs).nOut(numOutputs)
          .build())
      .pretrain(false).backprop(true).build();

  MultiLayerNetwork model = new MultiLayerNetwork(conf);
  model.init();
  model.setListeners(new ScoreIterationListener(listenfreq));


  LOG.info("Starting training");
  DataSetIterator train = new AsyncDataSetIterator(
      new SentimentRecurrentIterator(args[1], wvec, batchSize, 300, true), 2);
  DataSetIterator test = new AsyncDataSetIterator(
      new SentimentRecurrentIterator(args[1], wvec, testBatch, 300, false), 2);
  for (int i = 0; i < nEpochs; i++) {
    model.fit(train);
    train.reset();

    LOG.info("Epoch " + i + " complete. Starting evaluation:");
    Evaluation evaluation = new Evaluation();
    while (test.hasNext()) {
      DataSet t = test.next();
      INDArray features = t.getFeatures();
      INDArray labels = t.getLabels();
      INDArray inMask = t.getFeaturesMaskArray();
      INDArray outMask = t.getLabelsMaskArray();
      INDArray predicted = model.output(features, false, inMask, outMask);
      evaluation.evalTimeSeries(labels, predicted, outMask);
    }
    test.reset();
    LOG.info(evaluation.stats());

    LOG.info("Save model");
    // Close the stream deterministically; the original never closed it explicitly and
    // opened a new one on every one of the (up to 5000) epochs.
    try (FileOutputStream out = new FileOutputStream(args[2])) {
      ModelSerializer.writeModel(model, out, true);
    }
  }
}
 
开发者ID:keigohtr,项目名称:sentiment-rnn,代码行数:78,代码来源:SentimentRecurrentTrainCmd.java

示例9: preProcess

import org.nd4j.linalg.dataset.api.DataSet; //导入方法依赖的package包/类
/**
 * Applies this pre-processor's array-level {@code preProcess} overload to the features of the
 * given DataSet. The result is not stored back, so the overload presumably modifies the array
 * in place -- NOTE(review): confirm.
 *
 * @param toPreProcess the DataSet whose features are pre-processed
 */
@Override
public void preProcess(DataSet toPreProcess) {
    preProcess(toPreProcess.getFeatures());
}
 
开发者ID:deeplearning4j,项目名称:nd4j,代码行数:6,代码来源:VGG16ImagePreProcessor.java

示例10: testRocHtml

import org.nd4j.linalg.dataset.api.DataSet; //导入方法依赖的package包/类
/**
 * Trains a small binary network on standardized Iris data whose 3 classes are collapsed
 * into 2. It then evaluates the network with {@link ROC} for numSteps 20 and 0 and checks
 * that each result renders to non-empty HTML.
 */
@Test
public void testRocHtml() {

    DataSetIterator iter = new IrisDataSetIterator(150, 150);

    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().weightInit(WeightInit.XAVIER).list()
                    .layer(0, new DenseLayer.Builder().nIn(4).nOut(4).activation(Activation.TANH).build()).layer(1,
                                    new OutputLayer.Builder().nIn(4).nOut(2).activation(Activation.SOFTMAX)
                                                    .lossFunction(LossFunctions.LossFunction.MCXENT).build())
                    .build();
    MultiLayerNetwork net = new MultiLayerNetwork(conf);
    net.init();

    NormalizerStandardize ns = new NormalizerStandardize();
    DataSet ds = iter.next();
    ns.fit(ds);
    ns.transform(ds);

    // Collapse the 3 label columns into 2: columns 0 and 1 -> column 0, column 2 -> column 1.
    // The row count comes from the data instead of a hard-coded 150 tied to the iterator setup.
    INDArray oldLabels = ds.getLabels();
    INDArray newLabels = Nd4j.create(oldLabels.rows(), 2);
    newLabels.getColumn(0).assign(oldLabels.getColumn(0));
    newLabels.getColumn(0).addi(oldLabels.getColumn(1));
    newLabels.getColumn(1).assign(oldLabels.getColumn(2));
    ds.setLabels(newLabels);

    for (int i = 0; i < 30; i++) {
        net.fit(ds);
    }

    // The predictions do not depend on numSteps, so compute them once. All data comes from ds,
    // so the former iter.reset() inside the loop had no effect and was dropped.
    INDArray f = ds.getFeatures();
    INDArray l = ds.getLabels();
    INDArray out = net.output(f);

    for (int numSteps : new int[] {20, 0}) {
        ROC roc = new ROC(numSteps);
        roc.eval(l, out);

        String str = EvaluationTools.rocChartToHtml(roc);
        // Plain AssertionError so the check is active regardless of the -ea flag
        if (str == null || str.isEmpty()) {
            throw new AssertionError("Expected non-empty ROC HTML for numSteps=" + numSteps);
        }
    }
}
 
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:43,代码来源:EvaluationToolsTests.java


注：本文中的org.nd4j.linalg.dataset.api.DataSet.getFeatures方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台，相关代码片段筛选自各路编程大神贡献的开源项目，源码版权归原作者所有，传播和使用请参考对应项目的License；未经允许，请勿转载。