当前位置: 首页>>代码示例>>Java>>正文


Java Updater类代码示例

本文整理汇总了Java中org.deeplearning4j.nn.api.Updater的典型用法代码示例。如果您正苦于以下问题:Java Updater类的具体用法?Java Updater怎么用?Java Updater使用的例子?那么恭喜您, 这里精选的类代码示例或许可以为您提供帮助。


Updater类属于org.deeplearning4j.nn.api包,在下文中一共展示了Updater类的12个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Java代码示例。

示例1: fit

import org.deeplearning4j.nn.api.Updater; //导入依赖的package包/类
/**
 * Fits the model to the given examples and labels.
 *
 * <p>On the first call a {@link Solver} is built lazily. The updater state view array is then allocated here,
 * sized as the sum of the per-parameter updater state sizes. For MultiLayerNetwork and ComputationGraph this
 * allocation is done by MultiLayerUpdater and ComputationGraphUpdater respectively.
 *
 * @param input  the examples to classify (one example in each row)
 * @param labels the example labels (a binary outcome matrix)
 * @throws ArithmeticException if the total updater state size does not fit in an {@code int}
 */
@Override
public void fit(INDArray input, INDArray labels) {
    setInput(input);
    setLabels(labels);
    applyDropOutIfNecessary(true);
    if (solver == null) {
        solver = new Solver.Builder().configure(conf()).listeners(getListeners()).model(this).build();
        // Set the updater state view array. For MLN and CG, this is done by MultiLayerUpdater and
        // ComputationGraphUpdater respectively
        Updater updater = solver.getOptimizer().getUpdater();
        // Accumulate in a long: summing (and per-entry truncating) into an int could silently wrap around
        long updaterStateSize = 0;
        Map<String, INDArray> paramTable = paramTable();
        for (Map.Entry<String, INDArray> entry : paramTable.entrySet()) {
            updaterStateSize += conf().getLayer().getUpdaterByParam(entry.getKey())
                            .stateSize(entry.getValue().length());
        }
        if (updaterStateSize > 0) {
            // The view array shape is int-based, so fail loudly instead of allocating a wrapped size
            int stateLength = Math.toIntExact(updaterStateSize);
            updater.setStateViewArray(this, Nd4j.createUninitialized(new int[] {1, stateLength}, Nd4j.order()),
                            true);
        }
    }
    solver.optimize();
}
 
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:28,代码来源:BaseOutputLayer.java

示例2: testSetGetUpdater2

import org.deeplearning4j.nn.api.Updater; //导入依赖的package包/类
@Test
public void testSetGetUpdater2() {
    // Mirrors the preceding updater set/get test, but calls setUpdater on a freshly
    // initialized network whose layers each override the global Nesterovs updater.
    Nd4j.getRandom().setSeed(12345L);
    final double learningRate = 0.03;
    final int inputSize = 4;
    final int outputSize = 8;

    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                    .updater(new Nesterovs(learningRate, 0.6))
                    .list()
                    .layer(0, new DenseLayer.Builder()
                                    .nIn(inputSize)
                                    .nOut(5)
                                    .updater(org.deeplearning4j.nn.conf.Updater.SGD)
                                    .build())
                    .layer(1, new DenseLayer.Builder()
                                    .nIn(5)
                                    .nOut(6)
                                    .updater(new NoOp())
                                    .build())
                    .layer(2, new DenseLayer.Builder()
                                    .nIn(6)
                                    .nOut(7)
                                    .updater(org.deeplearning4j.nn.conf.Updater.ADAGRAD)
                                    .build())
                    .layer(3, new OutputLayer.Builder()
                                    .nIn(7)
                                    .nOut(outputSize)
                                    .updater(org.deeplearning4j.nn.conf.Updater.NESTEROVS)
                                    .build())
                    .backprop(true)
                    .pretrain(false)
                    .build();

    MultiLayerNetwork network = new MultiLayerNetwork(conf);
    network.init();

    Updater created = UpdaterCreator.getUpdater(network);
    network.setUpdater(created);
    // Reference equality: the network must hand back the exact instance that was set
    assertTrue(created == network.getUpdater());
}
 
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:27,代码来源:TestUpdaters.java

示例3: getFinalResult

import org.deeplearning4j.nn.api.Updater; //导入依赖的package包/类
@Override
public ParameterAveragingTrainingResult getFinalResult(MultiLayerNetwork network) {
    // Updater state is captured only when this worker is configured to save it
    INDArray updaterState = null;
    Updater netUpdater = saveUpdater ? network.getUpdater() : null;
    if (netUpdater != null) {
        updaterState = netUpdater.getStateViewArray();
    }

    // NOTE(review): presumably flushes queued ND4J operations before params/score are read — confirm
    Nd4j.getExecutioner().commit();

    // Listener data is only collected when the router is a VanillaStatsStorageRouter.
    // TODO: this instanceof check is a workaround; find a cleaner way to gather listener output
    StatsStorageRouter router = (listenerRouterProvider == null) ? null : listenerRouterProvider.getRouter();
    Collection<StorageMetaData> storageMetaData = null;
    Collection<Persistable> listenerStaticInfo = null;
    Collection<Persistable> listenerUpdates = null;
    if (router instanceof VanillaStatsStorageRouter) {
        VanillaStatsStorageRouter vanilla = (VanillaStatsStorageRouter) router;
        storageMetaData = vanilla.getStorageMetaData();
        listenerStaticInfo = vanilla.getStaticInfo();
        listenerUpdates = vanilla.getUpdates();
    }

    return new ParameterAveragingTrainingResult(network.params(), updaterState, network.score(),
                    storageMetaData, listenerStaticInfo, listenerUpdates);
}
 
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:27,代码来源:ParameterAveragingTrainingWorker.java

示例4: clone

import org.deeplearning4j.nn.api.Updater; //导入依赖的package包/类
/**
 * Clones this MultiLayerNetwork.
 *
 * <p>The layer-wise configuration is cloned and the parameters are duplicated into a newly initialized
 * network. Updater state is copied only if this network already has a solver. Layers that are
 * {@link FrozenLayer}s here are re-wrapped as FrozenLayers in the clone.
 *
 * @return the cloned network
 */
@Override
public MultiLayerNetwork clone() {
    MultiLayerNetwork copy = new MultiLayerNetwork(this.layerWiseConfigurations.clone());
    copy.init(this.params().dup(), false);

    // A null solver means the updater was never initialized; calling getUpdater() would force that
    // initialization, so updater state is only transferred when a solver already exists.
    if (solver != null) {
        INDArray stateView = this.getUpdater().getStateViewArray();
        if (stateView != null) {
            copy.getUpdater().setStateViewArray(copy, stateView.dup(), false);
        }
    }

    if (hasAFrozenLayer()) {
        // Re-apply the FrozenLayer wrappers at the same positions in the copy
        Layer[] copiedLayers = copy.getLayers();
        int position = 0;
        for (Layer original : layers) {
            if (original instanceof FrozenLayer) {
                copiedLayers[position] = new FrozenLayer(copy.getLayer(position));
            }
            position++;
        }
        copy.setLayers(copiedLayers);
    }
    return copy;
}
 
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:32,代码来源:MultiLayerNetwork.java

示例5: update

import org.deeplearning4j.nn.api.Updater; //导入依赖的package包/类
/**
 * Copies state from {@code network} into this network. Intended for use when loading from input streams,
 * factory methods, etc.
 *
 * <p>The default configuration is cloned, the input is duplicated and each layer is cloned. The labels
 * reference is shared rather than copied. If the source has a solver, its updater state view (when
 * present) is duplicated into a new MultiLayerUpdater installed on this network. Otherwise this
 * network's solver is cleared.
 *
 * @param network the network to copy state from
 */
public void update(MultiLayerNetwork network) {
    this.defaultConfiguration =
                    (network.defaultConfiguration == null ? null : network.defaultConfiguration.clone());
    if (network.input != null) {
        // Duplicate so later in-place changes (e.g. dropout) don't affect the source network's input
        setInput(network.input.dup());
    }
    this.labels = network.labels;

    if (network.layers == null) {
        this.layers = null;
    } else {
        this.layers = new Layer[network.layers.length];
        for (int idx = 0; idx < this.layers.length; idx++) {
            this.layers[idx] = network.layers[idx].clone();
        }
    }

    if (network.solver == null) {
        this.solver = null;
        return;
    }
    // The source network's updater state is carried over as well
    INDArray stateView = network.getUpdater().getStateViewArray();
    if (stateView != null) {
        Updater replacement = new MultiLayerUpdater(this);
        replacement.setStateViewArray(this, stateView.dup(), false);
        this.setUpdater(replacement);
    }
}
 
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:34,代码来源:MultiLayerNetwork.java

示例6: getUpdater

import org.deeplearning4j.nn.api.Updater; //导入依赖的package包/类
/**
 * Returns the updater for this MultiLayerNetwork.
 *
 * <p>If no solver exists yet, one is built and given a freshly created updater for this network.
 *
 * @return the updater held by the solver's optimizer
 */
public synchronized Updater getUpdater() {
    if (solver != null) {
        return solver.getOptimizer().getUpdater();
    }
    solver = new Solver.Builder()
                    .configure(conf())
                    .listeners(getListeners())
                    .model(this)
                    .build();
    solver.getOptimizer().setUpdater(UpdaterCreator.getUpdater(this));
    return solver.getOptimizer().getUpdater();
}
 
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:11,代码来源:MultiLayerNetwork.java

示例7: setUpdater

import org.deeplearning4j.nn.api.Updater; //导入依赖的package包/类
/**
 * Sets the updater for the MultiLayerNetwork, building the {@link Solver} first if none exists yet.
 *
 * <p>Synchronized for consistency with {@link #getUpdater()}, which performs the same lazy solver
 * creation under this object's lock. Without the lock, concurrent calls could both observe a null solver
 * and each build a separate one.
 *
 * @param updater the updater to install on the solver's optimizer
 */
public synchronized void setUpdater(Updater updater) {
    if (solver == null) {
        solver = new Solver.Builder().configure(conf()).listeners(getListeners()).model(this).build();
    }
    solver.getOptimizer().setUpdater(updater);
}
 
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:8,代码来源:MultiLayerNetwork.java

示例8: getUpdater

import org.deeplearning4j.nn.api.Updater; //导入依赖的package包/类
/**
 * Returns this optimizer's updater, creating one for the model via UpdaterCreator on first access.
 *
 * @return the cached updater
 */
@Override
public Updater getUpdater() {
    if (updater != null) {
        return updater;
    }
    updater = UpdaterCreator.getUpdater(model);
    return updater;
}
 
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:8,代码来源:BaseOptimizer.java

示例9: clone

import org.deeplearning4j.nn.api.Updater; //导入依赖的package包/类
/**
 * Returns a new MultiLayerUpdater for the same network.
 *
 * <p>The second constructor argument is {@code null}, so this updater's state is not passed to the copy.
 * NOTE(review): confirm that dropping the state here is intended.
 *
 * @return a new updater bound to the same network
 */
@Override
public Updater clone() {
    MultiLayerUpdater copy = new MultiLayerUpdater(network, null);
    return copy;
}
 
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:5,代码来源:MultiLayerUpdater.java

示例10: setUpdater

import org.deeplearning4j.nn.api.Updater; //导入依赖的package包/类
/**
 * Replaces the updater used by this optimizer.
 *
 * <p>No validation is performed. If {@code null} is passed, the next {@link #getUpdater()} call will
 * lazily create a new updater.
 *
 * @param updater the updater to use from now on
 */
@Override
public void setUpdater(Updater updater) {
    this.updater = updater;
}
 
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:5,代码来源:BaseOptimizer.java

示例11: getUpdater

import org.deeplearning4j.nn.api.Updater; //导入依赖的package包/类
/**
 * Returns the {@link Updater} used by this optimizer.
 *
 * <p>Implementations may create it lazily; BaseOptimizer, for instance, builds one via UpdaterCreator on
 * first access.
 *
 * @return this optimizer's updater
 */
Updater getUpdater(); 
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:2,代码来源:ConvexOptimizer.java

示例12: setUpdater

import org.deeplearning4j.nn.api.Updater; //导入依赖的package包/类
/**
 * Sets the {@link Updater} this optimizer should use.
 *
 * @param updater the updater to install
 */
void setUpdater(Updater updater); 
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:2,代码来源:ConvexOptimizer.java


注:本文中的org.deeplearning4j.nn.api.Updater类示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。