本文整理汇总了Java中org.deeplearning4j.nn.api.Updater类的典型用法代码示例。如果您正苦于以下问题：Java Updater类的具体用法？Java Updater怎么用？Java Updater使用的例子？那么恭喜您，这里精选的类代码示例或许可以为您提供帮助。
Updater类属于org.deeplearning4j.nn.api包，在下文中一共展示了Updater类的12个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Java代码示例。
示例1: fit
import org.deeplearning4j.nn.api.Updater; //导入依赖的package包/类
/**
 * Trains this model on the given examples and their labels.
 * <p>
 * The input and labels are stored on the model and {@code applyDropOutIfNecessary(true)} is
 * invoked. If no solver exists yet, one is built and the updater's state view array is
 * allocated (only when the summed state size is positive). Every call ends by running the
 * solver's optimization.
 *
 * @param input  the examples to classify (one example in each row)
 * @param labels the example labels (a binary outcome matrix)
 */
@Override
public void fit(INDArray input, INDArray labels) {
    setInput(input);
    setLabels(labels);
    applyDropOutIfNecessary(true);

    if (solver == null) {
        solver = new Solver.Builder().configure(conf()).listeners(getListeners()).model(this).build();

        // The updater state view array has to be set here; for MLN and CG this is done by
        // MultiLayerUpdater and ComputationGraphUpdater respectively.
        Updater optimizerUpdater = solver.getOptimizer().getUpdater();

        // Accumulate the updater state size needed across all parameters in the table.
        int totalStateSize = 0;
        for (Map.Entry<String, INDArray> param : paramTable().entrySet()) {
            String paramName = param.getKey();
            INDArray paramValues = param.getValue();
            totalStateSize += (int) conf().getLayer().getUpdaterByParam(paramName)
                            .stateSize(paramValues.length());
        }

        // Nothing to allocate when no parameter requires updater state.
        if (totalStateSize > 0) {
            INDArray stateView = Nd4j.createUninitialized(new int[] {1, totalStateSize}, Nd4j.order());
            optimizerUpdater.setStateViewArray(this, stateView, true);
        }
    }

    solver.optimize();
}
示例2: testSetGetUpdater2
import org.deeplearning4j.nn.api.Updater; //导入依赖的package包/类
/**
 * Verifies that an updater installed through {@code setUpdater} on a newly initialised
 * network is the very same object that {@code getUpdater} returns afterwards.
 */
@Test
public void testSetGetUpdater2() {
    // Variant of the set/get updater test in which setUpdater is called on a new network
    Nd4j.getRandom().setSeed(12345L);

    int inputSize = 4;
    int outputSize = 8;
    double learningRate = 0.03;

    // Four layers, each overriding the global Nesterovs updater with its own setting.
    MultiLayerConfiguration configuration =
                    new NeuralNetConfiguration.Builder().updater(new Nesterovs(learningRate, 0.6)).list()
                                    .layer(0, new DenseLayer.Builder().nIn(inputSize).nOut(5)
                                                    .updater(org.deeplearning4j.nn.conf.Updater.SGD).build())
                                    .layer(1, new DenseLayer.Builder().nIn(5).nOut(6)
                                                    .updater(new NoOp()).build())
                                    .layer(2, new DenseLayer.Builder().nIn(6).nOut(7)
                                                    .updater(org.deeplearning4j.nn.conf.Updater.ADAGRAD).build())
                                    .layer(3, new OutputLayer.Builder().nIn(7).nOut(outputSize)
                                                    .updater(org.deeplearning4j.nn.conf.Updater.NESTEROVS).build())
                                    .backprop(true).pretrain(false).build();

    MultiLayerNetwork network = new MultiLayerNetwork(configuration);
    network.init();

    Updater replacement = UpdaterCreator.getUpdater(network);
    network.setUpdater(replacement);

    // Reference comparison on purpose: the getter must hand back the identical object.
    boolean sameInstance = (replacement == network.getUpdater());
    assertTrue(sameInstance);
}
示例3: getFinalResult
import org.deeplearning4j.nn.api.Updater; //导入依赖的package包/类
/**
 * Assembles the final training result for the given network.
 * <p>
 * The result carries the network's parameters and score. The updater state view is included
 * only when {@code saveUpdater} is set and the network has an updater. Stats-storage data is
 * included only when a listener router provider is present and its router is a
 * {@link VanillaStatsStorageRouter}. Any of these optional parts is {@code null} otherwise.
 *
 * @param network the network to collect the result from
 * @return the training result built from the network's state
 */
@Override
public ParameterAveragingTrainingResult getFinalResult(MultiLayerNetwork network) {
    INDArray updaterState = null;
    if (saveUpdater) {
        Updater networkUpdater = network.getUpdater();
        updaterState = (networkUpdater == null) ? null : networkUpdater.getStateViewArray();
    }

    Nd4j.getExecutioner().commit();

    // A missing provider yields a null router, which simply fails the instanceof test below.
    StatsStorageRouter router = (listenerRouterProvider == null) ? null : listenerRouterProvider.getRouter();

    Collection<StorageMetaData> storageMetaData = null;
    Collection<Persistable> listenerStaticInfo = null;
    Collection<Persistable> listenerUpdates = null;
    if (router instanceof VanillaStatsStorageRouter) { //TODO this is ugly... need to find a better solution
        VanillaStatsStorageRouter vanillaRouter = (VanillaStatsStorageRouter) router;
        storageMetaData = vanillaRouter.getStorageMetaData();
        listenerStaticInfo = vanillaRouter.getStaticInfo();
        listenerUpdates = vanillaRouter.getUpdates();
    }

    return new ParameterAveragingTrainingResult(network.params(), updaterState, network.score(), storageMetaData,
                    listenerStaticInfo, listenerUpdates);
}
示例4: clone
import org.deeplearning4j.nn.api.Updater; //导入依赖的package包/类
/**
 * Creates a copy of this MultiLayerNetwork.
 * <p>
 * The copy is constructed from a clone of the layer-wise configuration and initialised with a
 * duplicate of this network's parameters. When a solver already exists, a non-null updater
 * state view is duplicated into the copy's updater. When this network has a frozen layer,
 * every position holding a {@link FrozenLayer} here is wrapped in a {@link FrozenLayer} in
 * the copy as well.
 *
 * @return the cloned network
 */
@Override
public MultiLayerNetwork clone() {
    MultiLayerNetwork copy = new MultiLayerNetwork(this.layerWiseConfigurations.clone());
    copy.init(this.params().dup(), false);

    // Updater state is copied only when a solver already exists: with a null solver the
    // updater has not been initialised, and calling getUpdater() would force initialisation.
    if (solver != null) {
        INDArray sourceState = this.getUpdater().getStateViewArray();
        if (sourceState != null) {
            copy.getUpdater().setStateViewArray(copy, sourceState.dup(), false);
        }
    }

    if (hasAFrozenLayer()) {
        // Re-wrap the copied layers so frozen positions stay frozen in the copy.
        Layer[] copiedLayers = copy.getLayers();
        for (int idx = 0; idx < layers.length; idx++) {
            if (!(layers[idx] instanceof FrozenLayer)) {
                continue;
            }
            copiedLayers[idx] = new FrozenLayer(copy.getLayer(idx));
        }
        copy.setLayers(copiedLayers);
    }

    return copy;
}
示例5: update
import org.deeplearning4j.nn.api.Updater; //导入依赖的package包/类
/**
 * Overwrites the state of this model with the state of the given network. This is used in
 * loading from input streams, factory methods, etc.
 * <p>
 * The default configuration and the layers are cloned, the input (if any) is duplicated, and
 * the labels reference is shared. If the source has a solver and a non-null updater state
 * view, a new {@link MultiLayerUpdater} with a duplicate of that state is installed. If the
 * source has no solver, this model's solver is cleared.
 *
 * @param network the network to copy state from
 */
public void update(MultiLayerNetwork network) {
    if (network.defaultConfiguration != null) {
        this.defaultConfiguration = network.defaultConfiguration.clone();
    } else {
        this.defaultConfiguration = null;
    }

    if (network.input != null) {
        setInput(network.input.dup()); //Dup in case of dropout etc
    }
    this.labels = network.labels;

    if (network.layers == null) {
        this.layers = null;
    } else {
        layers = new Layer[network.layers.length];
        for (int idx = 0; idx < layers.length; idx++) {
            layers[idx] = network.layers[idx].clone();
        }
    }

    if (network.solver == null) {
        this.solver = null;
    } else {
        // Network updater state: should be cloned over also
        INDArray sourceState = network.getUpdater().getStateViewArray();
        if (sourceState != null) {
            Updater replacement = new MultiLayerUpdater(this);
            replacement.setStateViewArray(this, sourceState.dup(), false);
            this.setUpdater(replacement);
        }
    }
}
示例6: getUpdater
import org.deeplearning4j.nn.api.Updater; //导入依赖的package包/类
/**
 * Get the updater for this MultiLayerNetwork.
 * <p>
 * If no solver exists yet, one is built for this network and given an updater obtained from
 * {@code UpdaterCreator.getUpdater(this)} before the updater is returned.
 *
 * @return Updater for MultiLayerNetwork
 */
public synchronized Updater getUpdater() {
    if (solver != null) {
        return solver.getOptimizer().getUpdater();
    }

    // First access: build the solver and install a freshly created updater.
    solver = new Solver.Builder().configure(conf()).listeners(getListeners()).model(this).build();
    solver.getOptimizer().setUpdater(UpdaterCreator.getUpdater(this));
    return solver.getOptimizer().getUpdater();
}
示例7: setUpdater
import org.deeplearning4j.nn.api.Updater; //导入依赖的package包/类
/**
 * Set the updater for the MultiLayerNetwork.
 * <p>
 * If no solver exists yet, one is built for this network first; the updater is then handed
 * to the solver's optimizer.
 *
 * @param updater the updater to install on the solver's optimizer
 */
public void setUpdater(Updater updater) {
    final boolean solverMissing = (solver == null);
    if (solverMissing) {
        solver = new Solver.Builder().configure(conf()).listeners(getListeners()).model(this).build();
    }
    solver.getOptimizer().setUpdater(updater);
}
示例8: getUpdater
import org.deeplearning4j.nn.api.Updater; //导入依赖的package包/类
/**
 * Returns the updater, creating it on first access.
 * <p>
 * When the {@code updater} field is still {@code null}, it is populated from
 * {@code UpdaterCreator.getUpdater(model)} and that value is returned on this and later calls.
 *
 * @return the stored updater
 */
@Override
public Updater getUpdater() {
    if (updater != null) {
        return updater;
    }
    updater = UpdaterCreator.getUpdater(model);
    return updater;
}
示例9: clone
import org.deeplearning4j.nn.api.Updater; //导入依赖的package包/类
/**
 * Creates a new {@link MultiLayerUpdater} for the same {@code network}.
 * <p>
 * NOTE(review): {@code null} is passed as the second constructor argument; presumably this is
 * the updater state, meaning the clone would not carry over existing state. Confirm against
 * the constructor.
 *
 * @return a newly constructed updater for {@code network}
 */
@Override
public Updater clone() {
    Updater copy = new MultiLayerUpdater(network, null);
    return copy;
}
示例10: setUpdater
import org.deeplearning4j.nn.api.Updater; //导入依赖的package包/类
/**
 * Stores the given updater in the {@code updater} field, replacing any previous value.
 *
 * @param updater the updater to store
 */
@Override
public void setUpdater(Updater updater) {
    this.updater = updater;
}
示例11: getUpdater
import org.deeplearning4j.nn.api.Updater; //导入依赖的package包/类
/**
 * Returns the updater held by this object.
 * NOTE(review): whether the result may be {@code null} or is created lazily depends on the
 * implementation; confirm before relying on either.
 *
 * @return the {@link Updater}
 */
Updater getUpdater();
示例12: setUpdater
import org.deeplearning4j.nn.api.Updater; //导入依赖的package包/类
/**
 * Sets the updater to be used by this object.
 *
 * @param updater the {@link Updater} to use
 */
void setUpdater(Updater updater);