當前位置: 首頁>>代碼示例>>Java>>正文


Java Updater類代碼示例

本文整理匯總了Java中org.deeplearning4j.nn.api.Updater的典型用法代碼示例。如果您正苦於以下問題:Java Updater類的具體用法?Java Updater怎麼用?Java Updater使用的例子?那麼,這裡精選的類代碼示例或許可以為您提供幫助。


Updater類屬於org.deeplearning4j.nn.api包,在下文中一共展示了Updater類的12個代碼示例,這些例子默認根據受歡迎程度排序。您可以為喜歡或者感覺有用的代碼點讚,您的評價將有助於系統推薦出更棒的Java代碼示例。

示例1: fit

import org.deeplearning4j.nn.api.Updater; //導入依賴的package包/類
/**
 * Fits this model on the supplied examples and labels.
 *
 * <p>The input and labels are stored on the model and {@code applyDropOutIfNecessary(true)} is
 * invoked. If no solver exists yet, one is built and the updater state size is summed over every
 * entry of the parameter table; when that total is positive, an uninitialized {@code [1, total]}
 * array is handed to the updater as its state view. Finally the solver's optimization is run.
 *
 * @param input  the examples to classify (one example in each row)
 * @param labels the example labels (a binary outcome matrix)
 */
@Override
public void fit(INDArray input, INDArray labels) {
    setInput(input);
    setLabels(labels);
    applyDropOutIfNecessary(true);

    if (solver == null) {
        solver = new Solver.Builder().configure(conf()).listeners(getListeners()).model(this).build();

        //Set the updater state view array. For MLN and CG, this is done by MultiLayerUpdater and ComputationGraphUpdater respectively
        Updater layerUpdater = solver.getOptimizer().getUpdater();

        //Accumulate the state size required by each parameter's updater
        int totalStateSize = 0;
        Map<String, INDArray> params = paramTable();
        for (Map.Entry<String, INDArray> param : params.entrySet()) {
            String paramName = param.getKey();
            INDArray paramValues = param.getValue();
            totalStateSize += (int) conf().getLayer().getUpdaterByParam(paramName)
                            .stateSize(paramValues.length());
        }

        //A state view is only attached when at least one updater reports a positive state size
        if (totalStateSize > 0) {
            int[] viewShape = new int[] {1, totalStateSize};
            updater:
            {
                layerUpdater.setStateViewArray(this, Nd4j.createUninitialized(viewShape, Nd4j.order()), true);
            }
        }
    }

    solver.optimize();
}
 
開發者ID:deeplearning4j,項目名稱:deeplearning4j,代碼行數:28,代碼來源:BaseOutputLayer.java

示例2: testSetGetUpdater2

import org.deeplearning4j.nn.api.Updater; //導入依賴的package包/類
/**
 * Builds a four-layer network whose layers each override the updater, installs a freshly created
 * updater on the newly initialized network via {@code setUpdater}, and checks that
 * {@code getUpdater} returns that very same instance.
 */
@Test
public void testSetGetUpdater2() {
    //Here setUpdater is exercised on a brand-new network
    Nd4j.getRandom().setSeed(12345L);
    final double learningRate = 0.03;
    final int inputSize = 4;
    final int outputSize = 8;

    //Per-layer configurations, each with its own updater setting
    DenseLayer sgdLayer = new DenseLayer.Builder().nIn(inputSize).nOut(5)
                    .updater(org.deeplearning4j.nn.conf.Updater.SGD).build();
    DenseLayer noOpLayer = new DenseLayer.Builder().nIn(5).nOut(6)
                    .updater(new NoOp()).build();
    DenseLayer adagradLayer = new DenseLayer.Builder().nIn(6).nOut(7)
                    .updater(org.deeplearning4j.nn.conf.Updater.ADAGRAD).build();
    OutputLayer nesterovsLayer = new OutputLayer.Builder().nIn(7).nOut(outputSize)
                    .updater(org.deeplearning4j.nn.conf.Updater.NESTEROVS).build();

    MultiLayerConfiguration netConf = new NeuralNetConfiguration.Builder()
                    .updater(new Nesterovs(learningRate, 0.6))
                    .list()
                    .layer(0, sgdLayer)
                    .layer(1, noOpLayer)
                    .layer(2, adagradLayer)
                    .layer(3, nesterovsLayer)
                    .backprop(true)
                    .pretrain(false)
                    .build();

    MultiLayerNetwork network = new MultiLayerNetwork(netConf);
    network.init();

    Updater created = UpdaterCreator.getUpdater(network);
    network.setUpdater(created);

    //Reference equality: the network must hand back the identical object
    assertTrue(created == network.getUpdater());
}
 
開發者ID:deeplearning4j,項目名稱:deeplearning4j,代碼行數:27,代碼來源:TestUpdaters.java

示例3: getFinalResult

import org.deeplearning4j.nn.api.Updater; //導入依賴的package包/類
/**
 * Assembles the final training result for the given network.
 *
 * <p>When {@code saveUpdater} is set and the network has a non-null updater, its state view array
 * is included; otherwise the updater state is {@code null}. Pending operations are committed on the
 * executioner, and if the listener router is a {@link VanillaStatsStorageRouter} its storage
 * metadata, static info and updates are included as well (all {@code null} otherwise).
 *
 * @param network the network to take parameters, score and updater state from
 * @return the result holding parameters, updater state, score and any collected listener data
 */
@Override
public ParameterAveragingTrainingResult getFinalResult(MultiLayerNetwork network) {
    INDArray updaterState = null;
    if (saveUpdater) {
        Updater netUpdater = network.getUpdater();
        updaterState = (netUpdater == null) ? null : netUpdater.getStateViewArray();
    }

    Nd4j.getExecutioner().commit();

    Collection<StorageMetaData> metaData = null;
    Collection<Persistable> staticInfo = null;
    Collection<Persistable> updates = null;

    //A null provider yields a null router, which fails the instanceof check below
    StatsStorageRouter router = (listenerRouterProvider == null) ? null : listenerRouterProvider.getRouter();
    if (router instanceof VanillaStatsStorageRouter) { //TODO this is ugly... need to find a better solution
        VanillaStatsStorageRouter vanillaRouter = (VanillaStatsStorageRouter) router;
        metaData = vanillaRouter.getStorageMetaData();
        staticInfo = vanillaRouter.getStaticInfo();
        updates = vanillaRouter.getUpdates();
    }

    return new ParameterAveragingTrainingResult(network.params(), updaterState, network.score(), metaData,
                    staticInfo, updates);
}
 
開發者ID:deeplearning4j,項目名稱:deeplearning4j,代碼行數:27,代碼來源:ParameterAveragingTrainingWorker.java

示例4: clone

import org.deeplearning4j.nn.api.Updater; //導入依賴的package包/類
/**
 * Clones this MultiLayerNetwork.
 *
 * <p>The copy is built from a clone of the layer-wise configuration and initialized with a
 * duplicate of this network's parameters. If this network already has a solver and its updater
 * has a non-null state view, a duplicate of that state is set on the copy's updater. When this
 * network has a frozen layer, each layer of the copy at the position of a {@link FrozenLayer}
 * is wrapped in a new {@link FrozenLayer}.
 *
 * @return the cloned network
 */
@Override
public MultiLayerNetwork clone() {
    MultiLayerNetwork copy = new MultiLayerNetwork(this.layerWiseConfigurations.clone());
    copy.init(this.params().dup(), false);

    //Only consult the updater when a solver exists: with a null solver the updater hasn't been
    //initialized, and calling getUpdater() would force that initialization
    if (solver != null) {
        INDArray state = this.getUpdater().getStateViewArray();
        if (state != null) {
            copy.getUpdater().setStateViewArray(copy, state.dup(), false);
        }
    }

    if (!hasAFrozenLayer()) {
        return copy;
    }

    //Re-wrap the copied layers so frozen layers stay frozen in the copy
    Layer[] copiedLayers = copy.getLayers();
    for (int idx = 0; idx < layers.length; idx++) {
        if (layers[idx] instanceof FrozenLayer) {
            copiedLayers[idx] = new FrozenLayer(copy.getLayer(idx));
        }
    }
    copy.setLayers(copiedLayers);
    return copy;
}
 
開發者ID:deeplearning4j,項目名稱:deeplearning4j,代碼行數:32,代碼來源:MultiLayerNetwork.java

示例5: update

import org.deeplearning4j.nn.api.Updater; //導入依賴的package包/類
/**
 * Copies the state of the given network into this one. This is used in loading from input
 * streams, factory methods, etc.
 *
 * <p>The default configuration is cloned (or set to {@code null}), a non-null input is duplicated
 * and set, the labels reference is shared, and every layer is cloned (or the layer array is set to
 * {@code null}). If the source has a solver and its updater has a non-null state view, a new
 * {@link MultiLayerUpdater} carrying a duplicate of that state is installed on this network; if
 * the source has no solver, this network's solver is cleared.
 *
 * @param network the network to copy state from
 */
public void update(MultiLayerNetwork network) {
    if (network.defaultConfiguration == null) {
        this.defaultConfiguration = null;
    } else {
        this.defaultConfiguration = network.defaultConfiguration.clone();
    }

    if (network.input != null) {
        setInput(network.input.dup()); //Dup in case of dropout etc
    }
    this.labels = network.labels;

    if (network.layers == null) {
        this.layers = null;
    } else {
        layers = new Layer[network.layers.length];
        for (int idx = 0; idx < layers.length; idx++) {
            layers[idx] = network.layers[idx].clone();
        }
    }

    if (network.solver == null) {
        this.solver = null;
        return;
    }

    //Network updater state: should be cloned over also
    INDArray sourceState = network.getUpdater().getStateViewArray();
    if (sourceState != null) {
        Updater replacement = new MultiLayerUpdater(this);
        replacement.setStateViewArray(this, sourceState.dup(), false);
        this.setUpdater(replacement);
    }
}
 
開發者ID:deeplearning4j,項目名稱:deeplearning4j,代碼行數:34,代碼來源:MultiLayerNetwork.java

示例6: getUpdater

import org.deeplearning4j.nn.api.Updater; //導入依賴的package包/類
/**
 * Get the updater for this MultiLayerNetwork.
 *
 * <p>If no solver exists yet, one is built for this network and given an updater obtained from
 * {@code UpdaterCreator.getUpdater(this)} before the updater is returned.
 *
 * @return Updater for MultiLayerNetwork
 */
public synchronized Updater getUpdater() {
    if (solver != null) {
        return solver.getOptimizer().getUpdater();
    }

    //First access: create the solver and attach a new updater to its optimizer
    solver = new Solver.Builder()
                    .configure(conf())
                    .listeners(getListeners())
                    .model(this)
                    .build();
    solver.getOptimizer().setUpdater(UpdaterCreator.getUpdater(this));
    return solver.getOptimizer().getUpdater();
}
 
開發者ID:deeplearning4j,項目名稱:deeplearning4j,代碼行數:11,代碼來源:MultiLayerNetwork.java

示例7: setUpdater

import org.deeplearning4j.nn.api.Updater; //導入依賴的package包/類
/**
 * Set the updater for the MultiLayerNetwork.
 *
 * <p>A solver is built first if this network does not have one yet; the updater is then set on
 * the solver's optimizer.
 *
 * @param updater the updater to set on the optimizer
 */
public void setUpdater(Updater updater) {
    if (solver == null) {
        //No solver yet: build one so there is an optimizer to attach the updater to
        solver = new Solver.Builder()
                        .configure(conf())
                        .listeners(getListeners())
                        .model(this)
                        .build();
    }
    solver.getOptimizer().setUpdater(updater);
}
 
開發者ID:deeplearning4j,項目名稱:deeplearning4j,代碼行數:8,代碼來源:MultiLayerNetwork.java

示例8: getUpdater

import org.deeplearning4j.nn.api.Updater; //導入依賴的package包/類
/**
 * Returns this optimizer's updater, creating it from the model via
 * {@code UpdaterCreator.getUpdater(model)} on first access.
 *
 * @return the existing or newly created updater
 */
@Override
public Updater getUpdater() {
    if (updater != null) {
        return updater;
    }
    //Not created yet: build it from the model and keep it for later calls
    updater = UpdaterCreator.getUpdater(model);
    return updater;
}
 
開發者ID:deeplearning4j,項目名稱:deeplearning4j,代碼行數:8,代碼來源:BaseOptimizer.java

示例9: clone

import org.deeplearning4j.nn.api.Updater; //導入依賴的package包/類
/**
 * Returns a new {@link MultiLayerUpdater} constructed from this updater's {@code network}
 * field, with {@code null} passed as the second constructor argument.
 *
 * <p>NOTE(review): the second argument is presumably the updater state view array, in which
 * case the clone would not carry over this updater's state. Confirm against the constructor.
 *
 * @return a newly constructed updater for the same network
 */
@Override
public Updater clone() {
    return new MultiLayerUpdater(network, null);
}
 
開發者ID:deeplearning4j,項目名稱:deeplearning4j,代碼行數:5,代碼來源:MultiLayerUpdater.java

示例10: setUpdater

import org.deeplearning4j.nn.api.Updater; //導入依賴的package包/類
/**
 * Stores the given updater in this optimizer's {@code updater} field, replacing any previous
 * value. No null check is performed.
 *
 * @param updater the updater to store
 */
@Override
public void setUpdater(Updater updater) {
    this.updater = updater;
}
 
開發者ID:deeplearning4j,項目名稱:deeplearning4j,代碼行數:5,代碼來源:BaseOptimizer.java

示例11: getUpdater

import org.deeplearning4j.nn.api.Updater; //導入依賴的package包/類
/**
 * Accessor that takes no arguments and returns an {@link Updater}.
 * NOTE(review): presumably the updater this optimizer uses. Confirm against implementations.
 *
 * @return the updater
 */
Updater getUpdater(); 
開發者ID:deeplearning4j,項目名稱:deeplearning4j,代碼行數:2,代碼來源:ConvexOptimizer.java

示例12: setUpdater

import org.deeplearning4j.nn.api.Updater; //導入依賴的package包/類
/**
 * Mutator that accepts an {@link Updater} and returns nothing.
 * NOTE(review): presumably replaces the updater this optimizer uses. Confirm against implementations.
 *
 * @param updater the updater to set
 */
void setUpdater(Updater updater); 
開發者ID:deeplearning4j,項目名稱:deeplearning4j,代碼行數:2,代碼來源:ConvexOptimizer.java


注:本文中的org.deeplearning4j.nn.api.Updater類示例由純淨天空整理自Github/MSDocs等開源代碼及文檔管理平台,相關代碼片段篩選自各路編程大神貢獻的開源項目,源碼版權歸原作者所有,傳播和使用請參考對應項目的License;未經允許,請勿轉載。