本文整理汇总了Scala中org.deeplearning4j.nn.conf.Updater类的典型用法代码示例。如果您正苦于以下问题：Scala Updater类的具体用法？Scala Updater怎么用？Scala Updater使用的例子？那么恭喜您，这里精选的类代码示例或许可以为您提供帮助。
在下文中一共展示了Updater类的1个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Scala代码示例。
示例1：MultiLayerNetworkExternalErrors
// Set the package name and import the classes this example depends on
package org.dl4scala.examples.misc.externalerrors
import org.deeplearning4j.nn.conf.layers.DenseLayer
import org.deeplearning4j.nn.conf.{NeuralNetConfiguration, Updater}
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork
import org.deeplearning4j.nn.weights.WeightInit
import org.nd4j.linalg.activations.Activation
import org.nd4j.linalg.factory.Nd4j
/**
 * Example: updating a [[MultiLayerNetwork]] from an error signal computed outside the network.
 *
 * Instead of calling `fit`, this example performs one manual training step:
 *  1. a forward pass on a random minibatch,
 *  2. back-propagation of an externally supplied error array via `backpropGradient`,
 *  3. transformation of the raw gradient by the network's own updater (learning rate, momentum, ...),
 *  4. in-place subtraction of the resulting update vector from the network parameters.
 */
object MultiLayerNetworkExternalErrors {

  /**
   * Runs a single external-error training step on a small two-layer dense network.
   *
   * @param array command-line arguments (not used)
   */
  def main(array: Array[String]): Unit = {
    // Network dimensions. The external error below is built with `nOut` columns, so the
    // output layer's width is tied to `nOut` instead of repeating the literal; the hidden
    // width is named once and shared by both layers that reference it.
    val nIn = 4
    val hiddenSize = 3
    val nOut = 3

    // Seed ND4J's global RNG so the random input and error arrays below are reproducible.
    Nd4j.getRandom.setSeed(12345)

    //Create the model
    val conf = new NeuralNetConfiguration.Builder()
      .seed(12345)
      .activation(Activation.TANH)
      .weightInit(WeightInit.XAVIER)
      .updater(Updater.NESTEROVS)
      .learningRate(0.1)
      .list()
      .layer(0, new DenseLayer.Builder().nIn(nIn).nOut(hiddenSize).build())
      .layer(1, new DenseLayer.Builder().nIn(hiddenSize).nOut(nOut).build())
      .backprop(true)
      .pretrain(false)
      .build()
    val model = new MultiLayerNetwork(conf)
    model.init()

    //Calculate gradient with respect to an external error
    val minibatch = 32
    val input = Nd4j.rand(minibatch, nIn)

    // Forward pass. Normally the external error would be computed from `output`; here it is
    // kept only to show where that happens. NOTE(review): the forward pass is presumably also
    // what populates the layer activations backpropGradient relies on — verify for your DL4J version.
    val output = model.output(input)

    // Stand-in for an externally computed error: one row per example, one column per output unit.
    val externalError = Nd4j.rand(minibatch, nOut)
    val p = model.backpropGradient(externalError) //Calculate backprop gradient based on error array

    //Update the gradient: apply learning rate, momentum, etc
    //This modifies the Gradient object in-place
    val gradient = p.getFirst
    val iteration = 0
    model.getUpdater.update(model, gradient, iteration, minibatch)

    //Get a row vector gradient array, and apply it to the parameters to update the model.
    // `subi` subtracts in place; this updates the model on the assumption that `params`
    // returns a view of the network's parameters rather than a copy.
    val updateVector = gradient.gradient
    model.params.subi(updateVector)
  }
}