

Java DataSetIterator.resetSupported Method Code Examples

This article collects typical usage examples of the Java method org.nd4j.linalg.dataset.api.iterator.DataSetIterator.resetSupported. If you are unsure what DataSetIterator.resetSupported does, how to call it, or what real-world usage looks like, the curated examples below should help. You can also explore further usages of org.nd4j.linalg.dataset.api.iterator.DataSetIterator.


The following 5 code examples of DataSetIterator.resetSupported are shown below, ordered by popularity.
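For orientation: resetSupported() reports whether an iterator can be rewound to its first batch via reset(), and callers are expected to guard every reset() call with it. Below is a minimal, self-contained sketch of that guard pattern; ListDataSetIterator and the random toy data are illustrative choices (assuming a DL4J version where ListDataSetIterator is generic), not taken from the examples that follow.

import java.util.ArrayList;
import java.util.List;
import org.deeplearning4j.datasets.iterator.impl.ListDataSetIterator;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;

public class ResetSupportedDemo {
    public static void main(String[] args) {
        // Four single-example DataSets: 3 features, 2 labels each
        List<DataSet> examples = new ArrayList<>();
        for (int i = 0; i < 4; i++)
            examples.add(new DataSet(Nd4j.rand(1, 3), Nd4j.rand(1, 2)));
        DataSetIterator iter = new ListDataSetIterator<>(examples, 2); // 2 examples per mini-batch

        for (int epoch = 0; epoch < 2; epoch++) {
            while (iter.hasNext()) {
                DataSet batch = iter.next();
                // ... consume the batch (training, evaluation, etc.)
            }
            // Rewind only if the iterator supports it; otherwise a fresh
            // iterator would be needed for the next epoch.
            if (iter.resetSupported())
                iter.reset();
        }
    }
}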

Example 1: AsyncDataSetIterator

import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; // import the package/class this method depends on
public AsyncDataSetIterator(DataSetIterator iterator, int queueSize, BlockingQueue<DataSet> queue,
                boolean useWorkspace, DataSetCallback callback, Integer deviceId) {
    if (queueSize < 2)
        queueSize = 2;

    this.deviceId = deviceId;
    this.callback = callback;
    this.useWorkspace = useWorkspace;
    this.buffer = queue;
    this.prefetchSize = queueSize;
    this.backedIterator = iterator;
    this.workspaceId = "ADSI_ITER-" + java.util.UUID.randomUUID().toString();

    if (iterator.resetSupported() && !iterator.hasNext())
        this.backedIterator.reset();

    this.thread = new AsyncPrefetchThread(buffer, iterator, terminator, null);

    // Ensure the background thread has the same thread -> device affinity as the master thread
    Nd4j.getAffinityManager().attachThreadToDevice(thread, deviceId);
    thread.setDaemon(true);
    thread.start();
}
 
Developer: deeplearning4j, Project: deeplearning4j, Lines: 27, Source: AsyncDataSetIterator.java
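Most user code reaches this constructor indirectly. A hedged sketch of the typical entry point, the simpler public constructor (the queue size of 4 is arbitrary; baseIter is assumed, e.g. the ListDataSetIterator from the sketch above):

import org.deeplearning4j.datasets.iterator.AsyncDataSetIterator;

DataSetIterator async = new AsyncDataSetIterator(baseIter, 4); // prefetch up to 4 batches on a background thread
while (async.hasNext()) {
    DataSet batch = async.next(); // typically served straight from the prefetch queue
}
if (async.resetSupported()) // delegates to baseIter.resetSupported()
    async.reset();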

Example 2: DataSetIteratorSplitter

import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; // import the package/class this method depends on
public DataSetIteratorSplitter(@NonNull DataSetIterator baseIterator, long totalExamples, double ratio) {
    if (!(ratio > 0.0 && ratio < 1.0))
        throw new ND4JIllegalStateException("Ratio value should be in range of 0.0 > X < 1.0");

    if (totalExamples <= 0)
        throw new ND4JIllegalStateException("totalExamples should be a positive value");

    if (!baseIterator.resetSupported())
        throw new ND4JIllegalStateException("Underlying iterator doesn't support reset, so it can't be used for runtime-split");


    this.backedIterator = baseIterator;
    this.totalExamples = totalExamples;
    this.ratio = ratio;
    this.numTrain = (long) (totalExamples * ratio);
    this.numTest = totalExamples - numTrain;

    log.warn("IteratorSplitter is used: please ensure you don't use randomization/shuffle in the underlying iterator!");
}
 
Developer: deeplearning4j, Project: deeplearning4j, Lines: 20, Source: DataSetIteratorSplitter.java
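A hedged usage sketch: the 80/20 split and the example count of 1000 are illustrative, and getTrainIterator()/getTestIterator() are the accessor names as I recall them from the same DL4J era. The base iterator must report resetSupported() == true, or the constructor above throws.

import org.deeplearning4j.datasets.iterator.DataSetIteratorSplitter;

// base is assumed: a resettable DataSetIterator over 1000 examples
DataSetIteratorSplitter splitter = new DataSetIteratorSplitter(base, 1000, 0.8);
DataSetIterator train = splitter.getTrainIterator(); // first 800 examples
DataSetIterator test = splitter.getTestIterator();   // remaining 200 examples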

Example 3: SparkADSI

import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; // import the package/class this method depends on
public SparkADSI(DataSetIterator iterator, int queueSize, BlockingQueue<DataSet> queue, boolean useWorkspace,
                DataSetCallback callback, Integer deviceId) {
    this();

    if (queueSize < 2)
        queueSize = 2;

    this.deviceId = deviceId;
    this.callback = callback;
    this.useWorkspace = useWorkspace;
    this.buffer = queue;
    this.prefetchSize = queueSize;
    this.backedIterator = iterator;
    this.workspaceId = "SADSI_ITER-" + java.util.UUID.randomUUID().toString();

    if (iterator.resetSupported())
        this.backedIterator.reset();

    context = TaskContext.get();

    this.thread = new SparkPrefetchThread(buffer, iterator, terminator, null);

    // Ensure the background thread has the same thread -> device affinity as the master thread
    Nd4j.getAffinityManager().attachThreadToDevice(thread, deviceId);
    thread.setDaemon(true);
    thread.start();
}
 
Developer: deeplearning4j, Project: deeplearning4j, Lines: 31, Source: SparkADSI.java
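Note the difference from Example 1: AsyncDataSetIterator only rewinds the backing iterator when it is resettable and already exhausted, while SparkADSI resets unconditionally whenever reset is supported, presumably so that each Spark task starts from the beginning of its partition's data.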

Example 4: distributionsForInstances

import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; // import the package/class this method depends on
/**
 * The method to use when making predictions for test instances.
 *
 * @param insts the instances to get predictions for
 * @return the class probability estimates (if the class is nominal) or the numeric predictions
 *     (if it is numeric)
 * @throws Exception if something goes wrong at prediction time
 */
@Override
public double[][] distributionsForInstances(Instances insts) throws Exception {

  log.info("Calc. dist for {} instances", insts.numInstances());

  // Do we only have a ZeroR model?
  if (zeroR != null) {
    return zeroR.distributionsForInstances(insts);
  }

  // Process input data to have the same filters applied as the training data
  insts = applyFilters(insts);

  // Get predictions
  final DataSetIterator it = getDataSetIterator(insts, CacheMode.NONE);
  double[][] preds = new double[insts.numInstances()][insts.numClasses()];

  if (it.resetSupported()) {
    it.reset();
  }

  int offset = 0;
  boolean next = it.hasNext();

  // Get predictions batch-wise
  while (next) {
    final DataSet ds = it.next();
    final INDArray features = ds.getFeatureMatrix();
    final INDArray labelsMask = ds.getLabelsMaskArray();
    INDArray lastTimeStepIndices = Nd4j.argMax(labelsMask, 1);
    INDArray predBatch = model.outputSingle(features);
    int currentBatchSize = predBatch.size(0);
    for (int i = 0; i < currentBatchSize; i++) {
      int thisTimeSeriesLastIndex = lastTimeStepIndices.getInt(i);
      INDArray thisExampleProbabilities =
          predBatch.get(
              NDArrayIndex.point(i),
              NDArrayIndex.all(),
              NDArrayIndex.point(thisTimeSeriesLastIndex));
      for (int j = 0; j < insts.numClasses(); j++) {
        preds[i + offset][j] = thisExampleProbabilities.getDouble(j);
      }
    }

    offset += currentBatchSize; // add batchsize as offset
    boolean hasInstancesLeft = offset < insts.numInstances();
    next = it.hasNext() || hasInstancesLeft;
  }

  // Fix classes
  for (int i = 0; i < preds.length; i++) {
    // only normalise if we're dealing with classification
    if (preds[i].length > 1) {
      weka.core.Utils.normalize(preds[i]);
    } else {
      // Rescale numeric classes with the computed coefficients in the initialization phase
      preds[i][0] = preds[i][0] * x1 + x0;
    }
  }
  return preds;
}
 
Developer: Waikato, Project: wekaDeeplearning4j, Lines: 70, Source: RnnSequenceClassifier.java
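A hedged sketch of calling this method from user code, assuming a trained RnnSequenceClassifier named clf; the file name "test.arff" is hypothetical:

import java.io.FileReader;
import weka.core.Instances;

Instances test = new Instances(new FileReader("test.arff")); // hypothetical path
test.setClassIndex(test.numAttributes() - 1);
double[][] dist = clf.distributionsForInstances(test);
// dist[i][j] is the probability of class j for instance i (nominal class);
// for a numeric class, dist[i][0] holds the rescaled prediction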

Example 5: pretrainLayer

import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; // import the package/class this method depends on
/**
 * Perform layerwise unsupervised training on a single pre-trainable layer in the network (VAEs, Autoencoders, etc)<br>
 * If the specified layer index (0 to numLayers - 1) is not a pretrainable layer, this is a no-op.
 *
 * @param layerIdx Index of the layer to train (0 to numLayers-1)
 * @param iter Training data
 */
public void pretrainLayer(int layerIdx, DataSetIterator iter) {
    if (flattenedGradients == null) {
        initGradientsView();
    }
    if (!layerWiseConfigurations.isPretrain())
        return;
    if (layerIdx >= layers.length) {
        throw new IllegalArgumentException(
                        "Cannot pretrain layer: layerIdx (" + layerIdx + ") >= numLayers (" + layers.length + ")");
    }

    Layer layer = layers[layerIdx];
    if (!layer.isPretrainLayer())
        return;

    if (!iter.hasNext() && iter.resetSupported()) {
        iter.reset();
    }

    MemoryWorkspace workspace =
                    layerWiseConfigurations.getTrainingWorkspaceMode() == WorkspaceMode.NONE ? new DummyWorkspace()
                                    : Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(
                                                    ComputationGraph.workspaceConfigurationExternal,
                                                    ComputationGraph.workspaceExternal);
    MemoryWorkspace cache =
                    layerWiseConfigurations.getTrainingWorkspaceMode() == WorkspaceMode.NONE ? new DummyWorkspace()
                                    : Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(
                                                    ComputationGraph.workspaceConfigurationCache,
                                                    ComputationGraph.workspaceCache);

    log.info("Starting unsupervised training on layer " + layerIdx);
    while (iter.hasNext()) {
        DataSet next = iter.next();

        try (MemoryWorkspace wsCache = cache.notifyScopeEntered()) {
            try (MemoryWorkspace ws = workspace.notifyScopeEntered()) {
                input = next.getFeatureMatrix();
                pretrainLayer(layerIdx, input);
            }
        }
    }

    int ec = getLayer(layerIdx).conf().getEpochCount() + 1;
    getLayer(layerIdx).conf().setEpochCount(ec);
}
 
Developer: deeplearning4j, Project: deeplearning4j, Lines: 53, Source: MultiLayerNetwork.java
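A rough sketch of driving this method from user code, written against the same-era DL4J builder API (where pretrain(true) on the list builder matches the isPretrain() check above); the layer sizes, MNIST-style dimensions, and the iterator name trainIter are assumptions:

import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.AutoEncoder;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.lossfunctions.LossFunctions;

MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
        .list()
        .layer(0, new AutoEncoder.Builder().nIn(784).nOut(100).build()) // pretrainable
        .layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                .activation(Activation.SOFTMAX).nIn(100).nOut(10).build())
        .pretrain(true).backprop(true) // pretrain(true) is required: isPretrain() is checked above
        .build();

MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
net.pretrainLayer(0, trainIter); // one unsupervised pass over layer 0
net.pretrainLayer(1, trainIter); // no-op: OutputLayer is not a pretrain layer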


Note: The org.nd4j.linalg.dataset.api.iterator.DataSetIterator.resetSupported examples in this article were collected by 纯净天空 from open-source code and documentation platforms such as GitHub and MSDocs. The snippets are excerpted from open-source projects and remain the copyright of their original authors; consult each project's license before distributing or using the code. Do not repost without permission.