当前位置: 首页>>代码示例>>Java>>正文


Java IterationListener类代码示例

本文整理汇总了Java中org.deeplearning4j.optimize.api.IterationListener类的典型用法代码示例。如果您正苦于以下问题：Java IterationListener类的具体用法？Java IterationListener怎么用？Java IterationListener使用的例子？那么，这里精选的类代码示例或许可以为您提供帮助。


IterationListener类属于org.deeplearning4j.optimize.api包，在下文中一共展示了IterationListener类的15个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Java代码示例。

示例1: BarnesHutTsne

import org.deeplearning4j.optimize.api.IterationListener; //导入依赖的package包/类
/**
 * Creates a Barnes-Hut t-SNE instance.
 *
 * <p>Every argument is copied verbatim into the field of the same name; no validation is
 * performed and no derived state is computed here. Note that the similarity-function
 * parameter and field are spelled {@code simiarlityFunction}.
 */
public BarnesHutTsne(int numDimensions, String simiarlityFunction, double theta, boolean invert, int maxIter,
                double realMin, double initialMomentum, double finalMomentum, double momentum,
                int switchMomentumIteration, boolean normalize, int stopLyingIteration, double tolerance,
                double learningRate, boolean useAdaGrad, double perplexity, IterationListener iterationListener,
                double minGain,int vpTreeWorkers) {
    this.numDimensions = numDimensions;
    this.simiarlityFunction = simiarlityFunction;
    this.theta = theta;
    this.invert = invert;
    this.vpTreeWorkers = vpTreeWorkers;
    this.normalize = normalize;
    this.perplexity = perplexity;

    this.maxIter = maxIter;
    this.tolerance = tolerance;
    this.learningRate = learningRate;
    this.useAdaGrad = useAdaGrad;
    this.minGain = minGain;
    this.realMin = realMin;
    this.stopLyingIteration = stopLyingIteration;

    this.initialMomentum = initialMomentum;
    this.momentum = momentum;
    this.finalMomentum = finalMomentum;
    this.switchMomentumIteration = switchMomentumIteration;

    this.iterationListener = iterationListener;
}
 
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:26,代码来源:BarnesHutTsne.java

示例2: getListener

import org.deeplearning4j.optimize.api.IterationListener; //导入依赖的package包/类
/**
 * Builds the list of iteration listeners to attach to the network.
 *
 * <p>The configured {@code iterationListener} is always added, exactly once. When it is a
 * Weka {@code EpochListener}, it is first initialised with the number of classes, the number
 * of epochs, the number of training instances, the training iterator and the early-stopping
 * validation iterator, and its log file is set to {@code logFile}.
 *
 * @return a new mutable list containing the configured listener
 * @throws Exception propagated from the data/configuration calls made below
 */
protected List<IterationListener> getListener() throws Exception {
  int numSamples = trainData.numInstances();
  List<IterationListener> listeners = new ArrayList<>();

  // The Weka epoch listener needs dataset metadata before training starts.
  if (iterationListener instanceof weka.dl4j.listener.EpochListener) {
    EpochListener epochListener = (EpochListener) iterationListener;
    int numEpochs = getNumEpochs();
    epochListener.init(
        trainData.numClasses(),
        numEpochs,
        numSamples,
        trainIterator,
        earlyStopping.getValDataSetIterator());
    epochListener.setLogFile(logFile);
  }

  // Previously duplicated in both if/else branches; the listener is added in every case.
  listeners.add(iterationListener);
  return listeners;
}
 
开发者ID:Waikato,项目名称:wekaDeeplearning4j,代码行数:27,代码来源:Dl4jMlpClassifier.java

示例3: setListeners

import org.deeplearning4j.optimize.api.IterationListener; //导入依赖的package包/类
/**
 * Replaces this layer's listeners with {@code listeners}.
 *
 * <p>Both internal lists are created (if absent) or emptied first, so passing {@code null}
 * or an empty collection leaves the layer without listeners. Every supplied listener is put
 * into {@code iterationListeners}; those that also implement {@link TrainingListener} are
 * additionally recorded in {@code trainingListeners}.
 *
 * @param listeners the new listeners; may be {@code null}
 */
@Override
public void setListeners(Collection<IterationListener> listeners) {
    if (iterationListeners != null) {
        iterationListeners.clear();
    } else {
        iterationListeners = new ArrayList<>();
    }
    if (trainingListeners != null) {
        trainingListeners.clear();
    } else {
        trainingListeners = new ArrayList<>();
    }

    if (listeners == null || listeners.isEmpty()) {
        return;
    }

    iterationListeners.addAll(listeners);
    for (IterationListener candidate : listeners) {
        if (candidate instanceof TrainingListener) {
            trainingListeners.add((TrainingListener) candidate);
        }
    }
}
 
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:21,代码来源:BasePretrainNetwork.java

示例4: testIterationListener

import org.deeplearning4j.optimize.api.IterationListener; //导入依赖的package包/类
/**
 * Verifies that a listener set on a {@code MultiLayerNetwork} ends up on every layer, both
 * when {@code setListeners} is called after {@code init()} and when it is called before.
 */
@Test
public void testIterationListener() {
    // Listener attached after initialisation.
    MultiLayerNetwork initThenListen = new MultiLayerNetwork(getConf());
    initThenListen.init();
    initThenListen.setListeners(Collections.singletonList((IterationListener) new ScoreIterationListener(1)));

    // Listener attached before initialisation.
    MultiLayerNetwork listenThenInit = new MultiLayerNetwork(getConf());
    listenThenInit.setListeners(Collections.singletonList((IterationListener) new ScoreIterationListener(1)));
    listenThenInit.init();

    for (Layer layer : initThenListen.getLayers()) {
        assertTrue(layer.getListeners() != null && layer.getListeners().size() == 1);
    }

    for (Layer layer : listenThenInit.getLayers()) {
        assertTrue(layer.getListeners() != null && layer.getListeners().size() == 1);
    }
}
 
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:19,代码来源:MultiLayerNeuralNetConfigurationTest.java

示例5: setListeners

import org.deeplearning4j.optimize.api.IterationListener; //导入依赖的package包/类
/**
 * Installs {@code listeners} on this network and propagates them.
 *
 * <p>The collection reference is stored as-is (not copied). If the layers have not been
 * created yet, {@code init()} is called first. The same collection is then handed to every
 * layer and, when present, to the solver. Finally {@code trainingListeners} is rebuilt from
 * the entries that implement {@link TrainingListener}.
 *
 * @param listeners the listeners to install; may be {@code null}
 */
@Override
public void setListeners(Collection<IterationListener> listeners) {
    this.listeners = listeners;

    if (layers == null) {
        init();
    }
    for (Layer each : layers) {
        each.setListeners(listeners);
    }
    if (solver != null) {
        solver.setListeners(listeners);
    }

    this.trainingListeners.clear();
    if (listeners == null) {
        return;
    }
    for (IterationListener candidate : listeners) {
        if (candidate instanceof TrainingListener) {
            this.trainingListeners.add((TrainingListener) candidate);
        }
    }
}
 
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:25,代码来源:MultiLayerNetwork.java

示例6: instantiate

import org.deeplearning4j.optimize.api.IterationListener; //导入依赖的package包/类
/**
 * Builds a runtime {@code BidirectionalLayer} that wraps separately instantiated forward and
 * backward recurrent layers.
 *
 * <p>Each wrapped layer gets its own clone of {@code conf} with the matching layer
 * configuration ({@code fwd} or {@code bwd}) swapped in. {@code layerParamsView} is split on
 * row 0 into two halves: columns {@code [0, half)} back the forward layer and columns
 * {@code [half, 2*half)} back the backward layer. Both wrapped layers receive the same
 * listeners, index and {@code initializeParams} flag.
 * NOTE(review): assumes the view's length is even (forward and backward have equal parameter
 * counts) — confirm.
 *
 * @return the wrapping layer with its parameter table initialised from {@code layerParamsView}
 */
@Override
public org.deeplearning4j.nn.api.Layer instantiate(NeuralNetConfiguration conf,
                                                   Collection<IterationListener> iterationListeners, int layerIndex,
                                                   INDArray layerParamsView, boolean initializeParams) {
    NeuralNetConfiguration forwardConf = conf.clone();
    NeuralNetConfiguration backwardConf = conf.clone();
    forwardConf.setLayer(fwd);
    backwardConf.setLayer(bwd);

    int half = layerParamsView.length() / 2;
    INDArray forwardParams = layerParamsView.get(point(0), interval(0, half));
    INDArray backwardParams = layerParamsView.get(point(0), interval(half, 2 * half));

    org.deeplearning4j.nn.api.layers.RecurrentLayer forwardLayer = (RecurrentLayer) fwd.instantiate(
            forwardConf, iterationListeners, layerIndex, forwardParams, initializeParams);
    org.deeplearning4j.nn.api.layers.RecurrentLayer backwardLayer = (RecurrentLayer) bwd.instantiate(
            backwardConf, iterationListeners, layerIndex, backwardParams, initializeParams);

    BidirectionalLayer wrapper = new BidirectionalLayer(conf, forwardLayer, backwardLayer);
    Map<String, INDArray> table = initializer().init(conf, layerParamsView, initializeParams);
    wrapper.setParamTable(table);
    wrapper.setConf(conf);
    return wrapper;
}
 
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:26,代码来源:Bidirectional.java

示例7: instantiate

import org.deeplearning4j.optimize.api.IterationListener; //导入依赖的package包/类
/**
 * Creates the runtime element-wise multiplication layer for this configuration and wires in
 * the listeners, index, parameter view and a freshly initialised parameter table.
 *
 * @throws IllegalStateException if {@code nIn != nOut}; this layer requires identical input
 *                               and output sizes
 */
@Override
public Layer instantiate(NeuralNetConfiguration conf, Collection<IterationListener> iterationListeners, int layerIndex,
                         INDArray layerParamsView, boolean initializeParams) {
    boolean sizesMatch = this.nIn == this.nOut;
    if (!sizesMatch) {
        throw new IllegalStateException("Element wise layer must have the same input and output size. Got nIn="
                + nIn + ", nOut=" + nOut);
    }

    org.deeplearning4j.nn.layers.feedforward.elementwise.ElementWiseMultiplicationLayer layer =
            new org.deeplearning4j.nn.layers.feedforward.elementwise.ElementWiseMultiplicationLayer(conf);
    layer.setListeners(iterationListeners);
    layer.setIndex(layerIndex);
    layer.setParamsViewArray(layerParamsView);
    layer.setParamTable(initializer().init(conf, layerParamsView, initializeParams));
    layer.setConf(conf);
    return layer;
}
 
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:19,代码来源:ElementWiseMultiplicationLayer.java

示例8: instantiate

import org.deeplearning4j.optimize.api.IterationListener; //导入依赖的package包/类
/**
 * Creates the runtime {@code VariationalAutoencoder} layer for this configuration.
 *
 * <p>nIn/nOut are first checked via {@code LayerValidation.assertNInNOutSet} (presumably it
 * throws when either is unset — confirm). The new layer then receives the listeners, index,
 * parameter view and a freshly initialised parameter table.
 */
@Override
public Layer instantiate(NeuralNetConfiguration conf, Collection<IterationListener> iterationListeners,
                int layerIndex, INDArray layerParamsView, boolean initializeParams) {
    LayerValidation.assertNInNOutSet("VariationalAutoencoder", getLayerName(), layerIndex, getNIn(), getNOut());

    org.deeplearning4j.nn.layers.variational.VariationalAutoencoder vae =
                    new org.deeplearning4j.nn.layers.variational.VariationalAutoencoder(conf);
    vae.setListeners(iterationListeners);
    vae.setIndex(layerIndex);
    vae.setParamsViewArray(layerParamsView);

    Map<String, INDArray> table = initializer().init(conf, layerParamsView, initializeParams);
    vae.setParamTable(table);
    vae.setConf(conf);

    return vae;
}
 
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:17,代码来源:VariationalAutoencoder.java

示例9: instantiate

import org.deeplearning4j.optimize.api.IterationListener; //导入依赖的package包/类
/**
 * Creates the runtime {@code SeparableConvolution2DLayer} for this configuration.
 *
 * <p>nIn/nOut are first checked via {@code LayerValidation.assertNInNOutSet}. The new layer
 * then receives the listeners, index, parameter view and a freshly initialised parameter
 * table.
 */
@Override
public Layer instantiate(NeuralNetConfiguration conf, Collection<IterationListener> iterationListeners,
                         int layerIndex, INDArray layerParamsView, boolean initializeParams) {
    LayerValidation.assertNInNOutSet("SeparableConvolution2D", getLayerName(), layerIndex, getNIn(), getNOut());

    org.deeplearning4j.nn.layers.convolution.SeparableConvolution2DLayer layer =
            new org.deeplearning4j.nn.layers.convolution.SeparableConvolution2DLayer(conf);
    layer.setListeners(iterationListeners);
    layer.setIndex(layerIndex);
    layer.setParamsViewArray(layerParamsView);
    layer.setParamTable(initializer().init(conf, layerParamsView, initializeParams));
    layer.setConf(conf);
    return layer;
}
 
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:17,代码来源:SeparableConvolution2D.java

示例10: addListeners

import org.deeplearning4j.optimize.api.IterationListener; //导入依赖的package包/类
/**
 * Appends {@code listeners} to this graph's existing listeners.
 *
 * <p>If no listeners have been set yet, this delegates straight to {@code setListeners}.
 * Otherwise a new mutable list is built from the current listeners plus the new ones and
 * installed with {@code setListeners(Collection)}.
 *
 * <p>Two problems with mutating {@code this.listeners} in place are avoided this way:
 * <ul>
 *   <li>It throws {@link UnsupportedOperationException} when the stored collection is
 *       immutable (e.g. created with {@code Collections.singletonList}).</li>
 *   <li>It never re-installs the listeners, so additions would not be handed on.</li>
 * </ul>
 * NOTE(review): this assumes {@code setListeners(Collection)} propagates to the layers and
 * rebuilds {@code trainingListeners}, as {@code MultiLayerNetwork.setListeners} does —
 * confirm for ComputationGraph.
 *
 * @param listeners listeners to add; a {@code null} array is ignored when listeners already
 *                  exist
 */
@Override
public void addListeners(IterationListener... listeners) {
    if (this.listeners == null) {
        setListeners(listeners);
        return;
    }
    if (listeners == null) {
        return; // nothing to add
    }

    // Copy first: the current collection may be immutable.
    java.util.List<IterationListener> combined = new java.util.ArrayList<>(this.listeners);
    java.util.Collections.addAll(combined, listeners);
    setListeners(combined);

    if (solver != null) {
        solver.setListeners(this.listeners);
    }
}
 
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:24,代码来源:ComputationGraph.java

示例11: testListenersForModel

import org.deeplearning4j.optimize.api.IterationListener; //导入依赖的package包/类
/**
 * Fits {@code model} through a {@link ParallelWrapper} with two workers, one random example
 * per worker, and checks the static counters collected by {@code TestListener}. The
 * expected values are:
 * <ul>
 *   <li>one worker ID per worker;</li>
 *   <li>a single session ID;</li>
 *   <li>as many forward and backward passes as there are examples.</li>
 * </ul>
 *
 * @param model     the model to wrap
 * @param listeners listeners to register on the wrapper, or {@code null} to register none
 */
private static void testListenersForModel(Model model, List<IterationListener> listeners) {
    final int nWorkers = 2;
    ParallelWrapper wrapper = new ParallelWrapper.Builder(model)
                    .workers(nWorkers)
                    .averagingFrequency(1)
                    .reportScoreAfterAveraging(true)
                    .build();
    if (listeners != null) {
        wrapper.setListeners(listeners);
    }

    List<DataSet> examples = new ArrayList<>();
    for (int w = 0; w < nWorkers; w++) {
        examples.add(new DataSet(Nd4j.rand(1, 10), Nd4j.rand(1, 10)));
    }
    DataSetIterator iterator = new ExistingDataSetIterator(examples);

    TestListener.clearCounts();
    wrapper.fit(iterator);

    assertEquals(nWorkers, TestListener.workerIDs.size());
    assertEquals(1, TestListener.sessionIDs.size());
    assertEquals(examples.size(), TestListener.forwardPassCount.get());
    assertEquals(examples.size(), TestListener.backwardPassCount.get());
}
 
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:26,代码来源:TestListeners.java

示例12: configureListeners

import org.deeplearning4j.optimize.api.IterationListener; //导入依赖的package包/类
/**
 * Attaches the configured {@code iterationListeners} to model {@code m}. Does nothing when
 * no listeners are configured.
 *
 * <p>When a {@code listenerRouterProvider} is available, every
 * {@link RoutingIterationListener} is given that provider's storage router and a worker ID
 * of the form {@code <JVM UID>_<counter>}. Listeners are passed through without cloning.
 *
 * @param m       the model; a {@code MultiLayerNetwork}, otherwise it is cast to
 *                {@code ComputationGraph}
 * @param counter suffix used when building the worker ID
 */
private void configureListeners(Model m, int counter) {
    if (iterationListeners == null) {
        return;
    }

    List<IterationListener> attached = new ArrayList<>(iterationListeners.size());
    for (IterationListener listener : iterationListeners) {
        boolean routable = listenerRouterProvider != null && listener instanceof RoutingIterationListener;
        if (routable) {
            RoutingIterationListener routing = (RoutingIterationListener) listener;
            routing.setStorageRouter(listenerRouterProvider.getRouter());
            routing.setWorkerID(UIDProvider.getJVMUID() + "_" + counter);
        }
        // No clone needed: per the original note, these listeners are not from a broadcast,
        // so deserialization handles it.
        attached.add(listener);
    }

    if (m instanceof MultiLayerNetwork) {
        ((MultiLayerNetwork) m).setListeners(attached);
    } else {
        ((ComputationGraph) m).setListeners(attached);
    }
}
 
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:19,代码来源:ParameterAveragingTrainingWorker.java

示例13: testAutoEncoder

import org.deeplearning4j.optimize.api.IterationListener; //导入依赖的package包/类
/**
 * Instantiates an {@code AutoEncoder} (nIn=784, nOut=600, corruption level 0.6) directly from
 * its layer configuration and checks three things:
 * <ul>
 *   <li>its parameter count;</li>
 *   <li>that {@code setParams(params())} leaves the parameters unchanged;</li>
 *   <li>that {@code fit} runs on a batch of 100 examples from {@code MnistDataFetcher}.</li>
 * </ul>
 */
@Test
public void testAutoEncoder() throws Exception {

    MnistDataFetcher fetcher = new MnistDataFetcher(true);
    NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder()
                    .optimizationAlgo(OptimizationAlgorithm.LINE_GRADIENT_DESCENT).updater(new Sgd(0.1))
                    .layer(new org.deeplearning4j.nn.conf.layers.AutoEncoder.Builder().nIn(784).nOut(600)
                                    .corruptionLevel(0.6)
                                    .lossFunction(LossFunctions.LossFunction.RECONSTRUCTION_CROSSENTROPY).build())
                    .build();

    fetcher.fetch(100);
    DataSet d2 = fetcher.next();
    INDArray input = d2.getFeatureMatrix();

    int numParams = conf.getLayer().initializer().numParams(conf);
    INDArray params = Nd4j.create(1, numParams);
    AutoEncoder da = (AutoEncoder) conf.getLayer().instantiate(conf,
                    Arrays.<IterationListener>asList(new ScoreIterationListener(1)), 0, params, true);

    // The original compared da.params() with itself, which can never fail. Check the layer
    // exposes exactly the parameters its initializer allocated instead.
    // 471784 = 784*600 + 600 + 784 (presumably weights + hidden bias + visible bias).
    assertEquals(numParams, da.params().length());
    assertEquals(471784, da.params().length());

    // Round trip: setting the layer's own parameters must not change their values.
    INDArray before = da.params().dup();
    da.setParams(da.params());
    assertEquals(before, da.params());

    da.setBackpropGradientsViewArray(Nd4j.create(1, params.length()));
    da.fit(input);
}
 
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:27,代码来源:AutoEncoderTest.java

示例14: provideModelConfigurationFactory

import org.deeplearning4j.optimize.api.IterationListener; //导入依赖的package包/类
/**
 * Provides the {@link ModelConfigurationFactory}, built from the injected {@code Context},
 * the injected {@code iterationListener}, and the batch size read from {@code params}.
 */
@Provides
public ModelConfigurationFactory provideModelConfigurationFactory(Context context,
                                                                  FederatedParams params,
                                                                  IterationListener iterationListener) {
    return new ModelConfigurationFactory(
            context,
            iterationListener,
            params.getBatchSize());
}
 
开发者ID:mccorby,项目名称:FederatedAndroidTrainer,代码行数:8,代码来源:MainModule.java

示例15: addListener

import org.deeplearning4j.optimize.api.IterationListener; //导入依赖的package包/类
/**
 * Registers {@code listener} on the underlying model, unless it is {@code null} or already
 * registered.
 *
 * <p>The model's current listeners are copied into a new mutable list, the listener is
 * appended, and the list is installed with {@code model.setListeners}. Copying avoids
 * mutating the model's collection in place, which fails in two cases:
 * <ul>
 *   <li>an {@link UnsupportedOperationException} when that collection is immutable (e.g.
 *       one built with {@code Collections.singletonList});</li>
 *   <li>a {@link NullPointerException} when {@code getListeners()} returns {@code null}.</li>
 * </ul>
 *
 * @param listener the listener to add; ignored when {@code null}
 */
@Override
public void addListener (IterationListener listener)
{
    if (listener == null) {
        return;
    }

    Collection<IterationListener> current = model.getListeners();

    if ((current != null) && current.contains(listener)) {
        return; // already registered
    }

    Collection<IterationListener> updated = (current == null)
            ? new java.util.ArrayList<IterationListener>()
            : new java.util.ArrayList<IterationListener>(current);
    updated.add(listener);
    model.setListeners(updated);
}
 
开发者ID:Audiveris,项目名称:audiveris,代码行数:13,代码来源:DeepClassifier.java


注：本文中的org.deeplearning4j.optimize.api.IterationListener类示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台，相关代码片段筛选自各路编程大神贡献的开源项目，源码版权归原作者所有，传播和使用请参考对应项目的License；未经允许，请勿转载。