本文整理汇总了Java中org.nd4j.linalg.dataset.api.iterator.DataSetIterator.resetSupported方法的典型用法代码示例。如果您正苦于以下问题：Java中DataSetIterator.resetSupported方法的具体用法是什么？该怎么用？有哪些使用示例？那么，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在类org.nd4j.linalg.dataset.api.iterator.DataSetIterator的用法示例。
下文共展示了DataSetIterator.resetSupported方法的5个代码示例，这些示例默认按受欢迎程度排序。您可以为喜欢或者觉得有用的代码点赞，您的评价将有助于系统推荐出更优秀的Java代码示例。
示例1: AsyncDataSetIterator
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; //导入方法依赖的package包/类
/**
 * Wraps {@code iterator} with a daemon background thread that prefetches DataSets into {@code queue}.
 *
 * <p>A requested queue size below 2 is raised to 2. If the source iterator supports reset and
 * currently reports no remaining elements, it is rewound before the prefetch thread starts.
 *
 * @param iterator     source iterator consumed by the prefetch thread
 * @param queueSize    requested prefetch depth; values below 2 are treated as 2
 * @param queue        buffer the prefetch thread is constructed with
 * @param useWorkspace stored as this iterator's workspace flag
 * @param callback     stored as this iterator's callback
 * @param deviceId     device the prefetch thread is attached to through the affinity manager
 */
public AsyncDataSetIterator(DataSetIterator iterator, int queueSize, BlockingQueue<DataSet> queue,
        boolean useWorkspace, DataSetCallback callback, Integer deviceId) {
    this.deviceId = deviceId;
    this.callback = callback;
    this.useWorkspace = useWorkspace;
    this.buffer = queue;
    this.prefetchSize = Math.max(2, queueSize);
    this.backedIterator = iterator;
    this.workspaceId = "ADSI_ITER-" + java.util.UUID.randomUUID();

    // Rewind only an exhausted source; resetSupported() is consulted first so that
    // hasNext() is queried on resettable iterators only.
    if (iterator.resetSupported()) {
        if (!iterator.hasNext()) {
            iterator.reset();
        }
    }

    thread = new AsyncPrefetchThread(buffer, iterator, terminator, null);
    // Pin the prefetch thread to deviceId; the stated intent is that it shares the
    // master thread's thread->device affinity.
    Nd4j.getAffinityManager().attachThreadToDevice(thread, deviceId);
    thread.setDaemon(true);
    thread.start();
}
示例2: DataSetIteratorSplitter
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; //导入方法依赖的package包/类
/**
 * Creates a splitter that divides the examples of {@code baseIterator} into a train part and a test part.
 *
 * <p>The train part size is {@code (long) (totalExamples * ratio)}; the test part is the remainder.
 * The underlying iterator must support reset, and it should not shuffle or randomize its order
 * (a warning is logged to that effect).
 *
 * @param baseIterator  underlying iterator; must support reset
 * @param totalExamples total number of examples to split; must not be negative
 * @param ratio         fraction assigned to the train part; must satisfy {@code 0.0 < ratio < 1.0}
 * @throws ND4JIllegalStateException if {@code ratio} is outside (0.0, 1.0) or NaN, if
 *                                   {@code totalExamples} is negative, or if
 *                                   {@code baseIterator} does not support reset
 */
public DataSetIteratorSplitter(@NonNull DataSetIterator baseIterator, long totalExamples, double ratio) {
    // Written as a negated conjunction so that NaN is rejected as well.
    if (!(ratio > 0.0 && ratio < 1.0)) {
        throw new ND4JIllegalStateException("Ratio value should be in range of 0.0 < X < 1.0, but got " + ratio);
    }
    if (totalExamples < 0) {
        throw new ND4JIllegalStateException("totalExamples should be a non-negative value, but got " + totalExamples);
    }
    if (!baseIterator.resetSupported()) {
        throw new ND4JIllegalStateException("Underlying iterator doesn't support reset, so it can't be used for runtime-split");
    }
    this.backedIterator = baseIterator;
    this.totalExamples = totalExamples;
    this.ratio = ratio;
    this.numTrain = (long) (totalExamples * ratio);
    this.numTest = totalExamples - numTrain;
    log.warn("IteratorSplitter is used: please ensure you don't use randomization/shuffle in underlying iterator!");
}
示例3: SparkADSI
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; //导入方法依赖的package包/类
/**
 * Spark-side asynchronous prefetching wrapper around {@code iterator}.
 *
 * <p>Delegates to the no-arg constructor first, then raises a queue size below 2 to 2. Whenever the
 * source supports reset, it is rewound, regardless of whether it is exhausted. The current Spark
 * {@code TaskContext} is captured, and a daemon {@code SparkPrefetchThread} is started on {@code queue}.
 *
 * @param iterator     source iterator consumed by the prefetch thread
 * @param queueSize    requested prefetch depth; values below 2 are treated as 2
 * @param queue        buffer the prefetch thread is constructed with
 * @param useWorkspace stored as this iterator's workspace flag
 * @param callback     stored as this iterator's callback
 * @param deviceId     device the prefetch thread is attached to through the affinity manager
 */
public SparkADSI(DataSetIterator iterator, int queueSize, BlockingQueue<DataSet> queue, boolean useWorkspace,
        DataSetCallback callback, Integer deviceId) {
    this();

    this.deviceId = deviceId;
    this.callback = callback;
    this.useWorkspace = useWorkspace;
    this.buffer = queue;
    this.prefetchSize = queueSize < 2 ? 2 : queueSize;
    this.backedIterator = iterator;
    this.workspaceId = "SADSI_ITER-" + java.util.UUID.randomUUID();

    // Unconditional rewind: prefetching starts from the beginning whenever the source allows it.
    if (iterator.resetSupported()) {
        iterator.reset();
    }

    // Captured on the constructing thread; presumably because TaskContext.get() is thread-local
    // and the prefetch thread needs it — TODO confirm against SparkPrefetchThread.
    context = TaskContext.get();

    this.thread = new SparkPrefetchThread(buffer, iterator, terminator, null);
    // Pin the prefetch thread to deviceId; the stated intent is that it shares the
    // master thread's thread->device affinity.
    Nd4j.getAffinityManager().attachThreadToDevice(thread, deviceId);
    thread.setDaemon(true);
    thread.start();
}
示例4: distributionsForInstances
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; //导入方法依赖的package包/类
/**
 * Computes predictions for a set of test instances.
 *
 * <p>If a ZeroR fallback model is present, prediction is delegated to it. Otherwise:
 * <ol>
 *   <li>The instances go through the same filters as the training data.</li>
 *   <li>They are wrapped in a {@link DataSetIterator} (with {@code CacheMode.NONE}) and fed through
 *       the model batch by batch.</li>
 *   <li>For every example, the model output is read at the time step selected by
 *       {@code argMax} over the labels mask.</li>
 * </ol>
 *
 * <p>Rows with more than one column (classification) are normalised. Single-column rows
 * (regression) are rescaled as {@code value * x1 + x0}.
 *
 * @param insts the instances to get predictions for
 * @return the class probability estimates (if the class is nominal) or the numeric predictions
 *     (if it is numeric), one row per instance
 * @throws Exception if something goes wrong at prediction time
 */
@Override
public double[][] distributionsForInstances(Instances insts) throws Exception {
    log.info("Calc. dist for {} instances", insts.numInstances());

    // A ZeroR-only model needs none of the network machinery below.
    if (zeroR != null) {
        return zeroR.distributionsForInstances(insts);
    }

    // The network expects inputs filtered exactly like the training data.
    insts = applyFilters(insts);

    final DataSetIterator it = getDataSetIterator(insts, CacheMode.NONE);
    final int numClasses = insts.numClasses();
    final double[][] preds = new double[insts.numInstances()][numClasses];
    if (it.resetSupported()) {
        it.reset();
    }

    // 'row' is the index of the first instance covered by the current batch. Iteration continues
    // while the iterator reports more data OR there are still unfilled prediction rows.
    int row = 0;
    for (boolean more = it.hasNext(); more; more = it.hasNext() || row < insts.numInstances()) {
        final DataSet batch = it.next();
        final INDArray features = batch.getFeatureMatrix();
        final INDArray lastSteps = Nd4j.argMax(batch.getLabelsMaskArray(), 1);
        final INDArray output = model.outputSingle(features);
        final int batchSize = output.size(0);

        for (int k = 0; k < batchSize; k++) {
            final INDArray probs = output.get(
                    NDArrayIndex.point(k),
                    NDArrayIndex.all(),
                    NDArrayIndex.point(lastSteps.getInt(k)));
            for (int c = 0; c < numClasses; c++) {
                preds[row + k][c] = probs.getDouble(c);
            }
        }
        row += batchSize;
    }

    // Classification rows are normalised; single-column (numeric) rows are rescaled
    // with the coefficients computed during initialization.
    for (final double[] dist : preds) {
        if (dist.length > 1) {
            weka.core.Utils.normalize(dist);
        } else {
            dist[0] = dist[0] * x1 + x0;
        }
    }
    return preds;
}
示例5: pretrainLayer
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; //导入方法依赖的package包/类
/**
 * Perform layerwise unsupervised training on a single pre-trainable layer in the network (VAEs, Autoencoders, etc)<br>
 * If the specified layer index (0 to numLayers - 1) is not a pretrainable layer, this is a no-op.
 * It is also a no-op when pretraining is disabled in the network configuration.
 *
 * <p>Before training, {@code iter} is reset if it reports no remaining data and supports reset.
 * For every remaining DataSet, the features are stored in the {@code input} field and passed to
 * the single-batch {@code pretrainLayer} overload. Afterwards the layer's epoch count is
 * incremented by one.
 *
 * @param layerIdx Index of the layer to train (0 to numLayers-1)
 * @param iter     Training data
 * @throws IllegalArgumentException if pretraining is enabled and {@code layerIdx} is negative or
 *                                  not less than the number of layers
 */
public void pretrainLayer(int layerIdx, DataSetIterator iter) {
    if (flattenedGradients == null) {
        initGradientsView();
    }
    if (!layerWiseConfigurations.isPretrain()) {
        return;
    }
    // Validate both bounds so that a bad index surfaces as IllegalArgumentException rather than
    // an ArrayIndexOutOfBoundsException from layers[layerIdx].
    if (layerIdx < 0) {
        throw new IllegalArgumentException("Cannot pretrain layer: layerIdx (" + layerIdx + ") < 0");
    }
    if (layerIdx >= layers.length) {
        throw new IllegalArgumentException(
                "Cannot pretrain layer: layerIdx (" + layerIdx + ") >= numLayers (" + layers.length + ")");
    }
    Layer layer = layers[layerIdx];
    if (!layer.isPretrainLayer()) {
        return;
    }
    if (!iter.hasNext() && iter.resetSupported()) {
        iter.reset();
    }

    MemoryWorkspace workspace =
            layerWiseConfigurations.getTrainingWorkspaceMode() == WorkspaceMode.NONE ? new DummyWorkspace()
                    : Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(
                            ComputationGraph.workspaceConfigurationExternal,
                            ComputationGraph.workspaceExternal);
    MemoryWorkspace cache =
            layerWiseConfigurations.getTrainingWorkspaceMode() == WorkspaceMode.NONE ? new DummyWorkspace()
                    : Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(
                            ComputationGraph.workspaceConfigurationCache,
                            ComputationGraph.workspaceCache);

    log.info("Starting unsupervised training on layer {}", layerIdx);
    while (iter.hasNext()) {
        DataSet next = iter.next();
        // Resources are closed in reverse order (workspace first, then cache), as with the nested form.
        try (MemoryWorkspace wsCache = cache.notifyScopeEntered();
             MemoryWorkspace ws = workspace.notifyScopeEntered()) {
            input = next.getFeatureMatrix();
            pretrainLayer(layerIdx, input);
        }
    }
    int ec = getLayer(layerIdx).conf().getEpochCount() + 1;
    getLayer(layerIdx).conf().setEpochCount(ec);
}