Java IUpdater Class Code Examples

This article collects typical usage examples of the Java class org.nd4j.linalg.learning.config.IUpdater. If you are wondering what the IUpdater class is for, how it is used, or what working examples look like, the curated snippets below may help.


The IUpdater class belongs to the org.nd4j.linalg.learning.config package. A total of 15 IUpdater code examples are shown below, sorted by popularity by default.
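
As a quick orientation before the examples (a minimal sketch, not taken from the snippets below): an IUpdater implementation such as Adam or Nesterovs is normally handed to the network configuration builder, which copies it into each layer. The layer sizes and the 1e-3 learning rate here are arbitrary placeholder values.

import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.learning.config.IUpdater;
import org.nd4j.linalg.lossfunctions.LossFunctions;

public class IUpdaterQuickStart {
    public static void main(String[] args) {
        IUpdater updater = new Adam(1e-3);   // gradient updater with learning rate 0.001

        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .updater(updater)            // applied to every layer that does not override it
                .list()
                .layer(0, new DenseLayer.Builder().nIn(10).nOut(10).activation(Activation.TANH).build())
                .layer(1, new OutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE)
                        .activation(Activation.IDENTITY).nIn(10).nOut(10).build())
                .build();

        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.init();   // the updater configuration is now attached to both layers
    }
}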

Example 1: setLearningRate

import org.nd4j.linalg.learning.config.IUpdater; // import the required package/class
private static void setLearningRate(MultiLayerNetwork net, int layerNumber, double newLr, ISchedule newLrSchedule, boolean refreshUpdater) {

        Layer l = net.getLayer(layerNumber).conf().getLayer();
        if (l instanceof BaseLayer) {
            BaseLayer bl = (BaseLayer) l;
            IUpdater u = bl.getIUpdater();
            if (u != null && u.hasLearningRate()) {
                if (newLrSchedule != null) {
                    u.setLrAndSchedule(Double.NaN, newLrSchedule);
                } else {
                    u.setLrAndSchedule(newLr, null);
                }
            }

            //Need to refresh the updater - if we change the LR (or schedule) we may rebuild the updater blocks, which are
            // built by creating blocks of params with the same configuration
            if (refreshUpdater) {
                refreshUpdater(net);
            }
        }
    }
 
Developer: deeplearning4j, Project: deeplearning4j, Lines: 22, Source: NetworkUtils.java
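
A hedged usage sketch built only from the IUpdater calls shown above (getIUpdater, hasLearningRate, setLrAndSchedule): set a fixed learning rate on one layer of an already initialized MultiLayerNetwork. Unlike the original helper, this sketch does not refresh the updater state afterwards, and the class/method names are illustrative only.

import org.deeplearning4j.nn.conf.layers.BaseLayer;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.learning.config.IUpdater;

public class LearningRateAdjuster {
    /** Sets a fixed learning rate (no schedule) on one layer, if its updater has a learning rate. */
    public static void setFixedLearningRate(MultiLayerNetwork net, int layerNumber, double newLr) {
        Layer l = net.getLayer(layerNumber).conf().getLayer();
        if (l instanceof BaseLayer) {
            IUpdater u = ((BaseLayer) l).getIUpdater();
            if (u != null && u.hasLearningRate()) {
                u.setLrAndSchedule(newLr, null);   // fixed LR, schedule cleared
            }
        }
    }
}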

Example 2: getGraphConfCNN

import org.nd4j.linalg.learning.config.IUpdater; // import the required package/class
private static ComputationGraphConfiguration getGraphConfCNN(int seed, IUpdater updater) {
    Nd4j.getRandom().setSeed(seed);
    ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
                    .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                    .weightInit(WeightInit.XAVIER).updater(updater).seed(seed).graphBuilder()
                    .addInputs("in")
                    .addLayer("0", new ConvolutionLayer.Builder().nOut(3).kernelSize(2, 2).stride(1, 1)
                                    .padding(0, 0).activation(Activation.TANH).build(), "in")
                    .addLayer("1", new ConvolutionLayer.Builder().nOut(3).kernelSize(2, 2).stride(1, 1)
                                    .padding(0, 0).activation(Activation.TANH).build(), "0")
                    .addLayer("2", new OutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).nOut(10)
                                    .build(), "1")
                    .setOutputs("2").setInputTypes(InputType.convolutional(10, 10, 3)).pretrain(false)
                    .backprop(true).build();
    return conf;
}
 
Developer: deeplearning4j, Project: deeplearning4j, Lines: 17, Source: TestCompareParameterAveragingSparkVsSingleMachine.java

Example 3: testUpdaters

import org.nd4j.linalg.learning.config.IUpdater; // import the required package/class
@Test
public void testUpdaters() {
    SparkDl4jMultiLayer sparkNet = getBasicNetwork();
    MultiLayerNetwork netCopy = sparkNet.getNetwork().clone();

    netCopy.fit(data);
    IUpdater expectedUpdater = ((BaseLayer) netCopy.conf().getLayer()).getIUpdater();
    double expectedLR = ((Nesterovs)((BaseLayer) netCopy.conf().getLayer()).getIUpdater()).getLearningRate();
    double expectedMomentum = ((Nesterovs)((BaseLayer) netCopy.conf().getLayer()).getIUpdater()).getMomentum();

    IUpdater actualUpdater = ((BaseLayer) sparkNet.getNetwork().conf().getLayer()).getIUpdater();
    sparkNet.fit(sparkData);
    double actualLR = ((Nesterovs)((BaseLayer) sparkNet.getNetwork().conf().getLayer()).getIUpdater()).getLearningRate();
    double actualMomentum = ((Nesterovs)((BaseLayer) sparkNet.getNetwork().conf().getLayer()).getIUpdater()).getMomentum();

    assertEquals(expectedUpdater, actualUpdater);
    assertEquals(expectedLR, actualLR, 0.01);
    assertEquals(expectedMomentum, actualMomentum, 0.01);

}
 
Developer: deeplearning4j, Project: deeplearning4j, Lines: 21, Source: TestSparkMultiLayerParameterAveraging.java
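
The downcast to Nesterovs in the test above works because Nesterovs is one concrete IUpdater implementation. A minimal sketch of constructing it and reading the same fields back (the 0.1/0.9 values are arbitrary):

import org.nd4j.linalg.learning.config.IUpdater;
import org.nd4j.linalg.learning.config.Nesterovs;

public class NesterovsFields {
    public static void main(String[] args) {
        IUpdater updater = new Nesterovs(0.1, 0.9);   // learning rate 0.1, momentum 0.9
        Nesterovs n = (Nesterovs) updater;            // same downcast pattern as in the test
        System.out.println("lr=" + n.getLearningRate() + ", momentum=" + n.getMomentum());
    }
}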

Example 4: getUpdater

import org.nd4j.linalg.learning.config.IUpdater; // import the required package/class
@OptionMetadata(
  displayName = "updater",
  description = "The updater to use (default = SGD).",
  commandLineParamName = "updater",
  commandLineParamSynopsis = "-updater <string>",
  displayOrder = 12
)
public IUpdater getUpdater() {
  return iUpdater;
}
 
Developer: Waikato, Project: wekaDeeplearning4j, Lines: 11, Source: NeuralNetConfiguration.java

Example 5: updaterConfigurationsEquals

import org.nd4j.linalg.learning.config.IUpdater; // import the required package/class
public static boolean updaterConfigurationsEquals(Layer layer1, String param1, Layer layer2, String param2) {
    org.deeplearning4j.nn.conf.layers.Layer l1 = layer1.conf().getLayer();
    org.deeplearning4j.nn.conf.layers.Layer l2 = layer2.conf().getLayer();
    IUpdater u1 = l1.getUpdaterByParam(param1);
    IUpdater u2 = l2.getUpdaterByParam(param2);

    //For updaters to be equal (and hence combinable), we require that:
    //(a) The updater-specific configurations are equal (inc. LR, LR/momentum schedules etc)
    //(b) If one or more of the params are pretrainable params, they are in the same layer
    //    This last point is necessary as we don't want to modify the pretrain gradient/updater state during
    //    backprop, or modify the pretrain gradient/updater state of one layer while training another
    if (!u1.equals(u2)) {
        //Different updaters or different config
        return false;
    }

    boolean isPretrainParam1 = layer1.conf().getLayer().isPretrainParam(param1);
    boolean isPretrainParam2 = layer2.conf().getLayer().isPretrainParam(param2);
    if (isPretrainParam1 || isPretrainParam2) {
        //One or both of params are pretrainable.
        //Either layers differ -> don't want to combine a pretrain updaters across layers
        //Or one is pretrain and the other isn't -> don't want to combine pretrain updaters within a layer
        return layer1 == layer2 && isPretrainParam1 && isPretrainParam2;
    }

    return true;
}
 
Developer: deeplearning4j, Project: deeplearning4j, Lines: 28, Source: UpdaterUtils.java
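
A hedged sketch of the equality check this helper relies on: two layers built from the same global updater configuration should report equal (equals(), not ==) per-parameter updaters. The layer sizes and the "W" weight parameter key are assumptions used for illustration.

import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.learning.config.IUpdater;

public class UpdaterEqualityCheck {
    public static void main(String[] args) {
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .updater(new Adam(1e-3))   // same updater configuration propagated to both layers
                .list()
                .layer(0, new DenseLayer.Builder().nIn(10).nOut(10).build())
                .layer(1, new DenseLayer.Builder().nIn(10).nOut(10).build())
                .build();

        Layer l0 = conf.getConf(0).getLayer();
        Layer l1 = conf.getConf(1).getLayer();
        IUpdater u0 = l0.getUpdaterByParam("W");   // assumed weight parameter key
        IUpdater u1 = l1.getUpdaterByParam("W");
        System.out.println("equal updater config: " + u0.equals(u1));   // expected: true
    }
}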

Example 6: getUpdaterByParam

import org.nd4j.linalg.learning.config.IUpdater; // import the required package/class
@Override
public IUpdater getUpdaterByParam(String paramName) {
    switch (paramName) {
        case BatchNormalizationParamInitializer.BETA:
        case BatchNormalizationParamInitializer.GAMMA:
            return iUpdater;
        case BatchNormalizationParamInitializer.GLOBAL_MEAN:
        case BatchNormalizationParamInitializer.GLOBAL_VAR:
            return new NoOp();
        default:
            throw new IllegalArgumentException("Unknown parameter: \"" + paramName + "\"");
    }
}
 
Developer: deeplearning4j, Project: deeplearning4j, Lines: 14, Source: BatchNormalization.java
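
A hedged sketch of what the lookup above returns for a configured BatchNormalization layer: the trainable gamma/beta parameters use the layer's updater, while the running mean/variance get a NoOp updater so the optimizer never touches them. The nOut value and the per-layer updater call are assumptions for illustration.

import org.deeplearning4j.nn.conf.layers.BatchNormalization;
import org.deeplearning4j.nn.params.BatchNormalizationParamInitializer;
import org.nd4j.linalg.learning.config.Adam;

public class BatchNormUpdaterLookup {
    public static void main(String[] args) {
        BatchNormalization bn = new BatchNormalization.Builder()
                .nOut(10)
                .updater(new Adam(1e-3))
                .build();

        // gamma/beta are trained with the configured updater
        System.out.println(bn.getUpdaterByParam(BatchNormalizationParamInitializer.GAMMA));
        // the running statistics are excluded from gradient updates (NoOp)
        System.out.println(bn.getUpdaterByParam(BatchNormalizationParamInitializer.GLOBAL_MEAN));
    }
}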

Example 7: getUpdaterByParam

import org.nd4j.linalg.learning.config.IUpdater; // import the required package/class
/**
 * Get the updater for the given parameter. Typically the same updater will be used for all trainable
 * parameters, but this is not necessarily the case
 *
 * @param paramName    Parameter name
 * @return             IUpdater for the parameter
 */
@Override
public IUpdater getUpdaterByParam(String paramName) {
    if(biasUpdater != null && initializer().isBiasParam(this, paramName)){
        return biasUpdater;
    }
    return iUpdater;
}
 
Developer: deeplearning4j, Project: deeplearning4j, Lines: 15, Source: BaseLayer.java
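
A hedged sketch of how this resolution plays out when a separate bias updater is configured on a layer; the per-layer updater/biasUpdater builder calls and the "W"/"b" parameter keys are assumptions used for illustration, not part of the original example.

import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.learning.config.Sgd;

public class BiasUpdaterResolution {
    public static void main(String[] args) {
        DenseLayer layer = new DenseLayer.Builder()
                .nIn(10).nOut(10)
                .updater(new Adam(1e-3))       // assumed to apply to the weights ("W")
                .biasUpdater(new Sgd(1e-2))    // assumed to apply to the biases ("b")
                .build();

        System.out.println(layer.getUpdaterByParam("W"));   // expected: the Adam configuration
        System.out.println(layer.getUpdaterByParam("b"));   // expected: the Sgd configuration
    }
}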

Example 8: getUpdaterByParam

import org.nd4j.linalg.learning.config.IUpdater; // import the required package/class
@Override
public IUpdater getUpdaterByParam(String paramName) {
    // center loss utilizes alpha directly for this so any updater can be used for other layers
    switch (paramName) {
        case CenterLossParamInitializer.CENTER_KEY:
            return new NoOp();
        default:
            return iUpdater;
    }
}
 
Developer: deeplearning4j, Project: deeplearning4j, Lines: 11, Source: CenterLossOutputLayer.java

Example 9: getConf

import org.nd4j.linalg.learning.config.IUpdater; // import the required package/class
private static MultiLayerConfiguration getConf(int seed, IUpdater updater) {
    Nd4j.getRandom().setSeed(seed);
    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                    .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                    .weightInit(WeightInit.XAVIER).updater(updater).seed(seed).list()
                    .layer(0, new DenseLayer.Builder().nIn(10).nOut(10).build()).layer(1, new OutputLayer.Builder()
                                    .lossFunction(LossFunctions.LossFunction.MSE).nIn(10).nOut(10).build())
                    .pretrain(false).backprop(true).build();
    return conf;
}
 
Developer: deeplearning4j, Project: deeplearning4j, Lines: 11, Source: TestCompareParameterAveragingSparkVsSingleMachine.java

Example 10: getConfCNN

import org.nd4j.linalg.learning.config.IUpdater; // import the required package/class
private static MultiLayerConfiguration getConfCNN(int seed, IUpdater updater) {
    Nd4j.getRandom().setSeed(seed);
    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                    .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                    .weightInit(WeightInit.XAVIER).updater(updater).seed(seed).list()
                    .layer(0, new ConvolutionLayer.Builder().nOut(3).kernelSize(2, 2).stride(1, 1).padding(0, 0)
                                    .activation(Activation.TANH).build())
                    .layer(1, new ConvolutionLayer.Builder().nOut(3).kernelSize(2, 2).stride(1, 1).padding(0, 0)
                                    .activation(Activation.TANH).build())
                    .layer(2, new OutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).nOut(10)
                                    .build())
                    .setInputType(InputType.convolutional(10, 10, 3)).pretrain(false).backprop(true).build();
    return conf;
}
 
Developer: deeplearning4j, Project: deeplearning4j, Lines: 15, Source: TestCompareParameterAveragingSparkVsSingleMachine.java

Example 11: getGraphConf

import org.nd4j.linalg.learning.config.IUpdater; // import the required package/class
private static ComputationGraphConfiguration getGraphConf(int seed, IUpdater updater) {
    Nd4j.getRandom().setSeed(seed);
    ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
                    .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                    .weightInit(WeightInit.XAVIER).updater(updater).seed(seed).graphBuilder()
                    .addInputs("in")
                    .addLayer("0", new DenseLayer.Builder().nIn(10).nOut(10).build(), "in").addLayer("1",
                                    new OutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).nIn(10)
                                                    .nOut(10).build(),
                                    "0")
                    .setOutputs("1").pretrain(false).backprop(true).build();
    return conf;
}
 
Developer: deeplearning4j, Project: deeplearning4j, Lines: 14, Source: TestCompareParameterAveragingSparkVsSingleMachine.java

Example 12: setUpdater

import org.nd4j.linalg.learning.config.IUpdater; // import the required package/class
public void setUpdater(IUpdater updater) {
  iUpdater = updater;
}
 
Developer: Waikato, Project: wekaDeeplearning4j, Lines: 4, Source: NeuralNetConfiguration.java

Example 13: copyConfigToLayer

import org.nd4j.linalg.learning.config.IUpdater; // import the required package/class
private void copyConfigToLayer(String layerName, Layer layer) {

            if (layer.getIDropout() == null)
                layer.setIDropout(idropOut);

            if (layer instanceof BaseLayer) {
                BaseLayer bLayer = (BaseLayer) layer;
                if (Double.isNaN(bLayer.getL1()))
                    bLayer.setL1(l1);
                if (Double.isNaN(bLayer.getL2()))
                    bLayer.setL2(l2);
                if (bLayer.getActivationFn() == null)
                    bLayer.setActivationFn(activationFn);
                if (bLayer.getWeightInit() == null)
                    bLayer.setWeightInit(weightInit);
                if (Double.isNaN(bLayer.getBiasInit()))
                    bLayer.setBiasInit(biasInit);

                //Configure weight noise:
                if(weightNoise != null && ((BaseLayer) layer).getWeightNoise() == null){
                    ((BaseLayer) layer).setWeightNoise(weightNoise.clone());
                }

                //Configure updaters:
                if(iUpdater != null && bLayer.getIUpdater() == null){
                    bLayer.setIUpdater(iUpdater);
                }
                if(biasUpdater != null && bLayer.getBiasUpdater() == null){
                    bLayer.setBiasUpdater(biasUpdater);
                }

                if(bLayer.getIUpdater() == null && iUpdater == null && bLayer.initializer().numParams(bLayer) > 0){
                    //No updater set anywhere
                    IUpdater u = new Sgd();
                    bLayer.setIUpdater(u);
                    log.warn("*** No updater configuration is set for layer {} - defaulting to {} ***", layerName, u);
                }

                if (bLayer.getGradientNormalization() == null)
                    bLayer.setGradientNormalization(gradientNormalization);
                if (Double.isNaN(bLayer.getGradientNormalizationThreshold()))
                    bLayer.setGradientNormalizationThreshold(gradientNormalizationThreshold);
            }

            if (layer instanceof ActivationLayer){
                ActivationLayer al = (ActivationLayer)layer;
                if(al.getActivationFn() == null)
                    al.setActivationFn(activationFn);
            }
        }
 
Developer: deeplearning4j, Project: deeplearning4j, Lines: 51, Source: NeuralNetConfiguration.java
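
A hedged sketch of the inheritance behaviour implemented by copyConfigToLayer above: a layer that does not set its own updater picks up the global one when the configuration is built. Reading the result back through getIUpdater mirrors the other examples on this page; the sizes are placeholders.

import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.BaseLayer;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.learning.config.IUpdater;

public class GlobalUpdaterInheritance {
    public static void main(String[] args) {
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .updater(new Adam(1e-3))   // global updater; the layer below sets none of its own
                .list()
                .layer(0, new DenseLayer.Builder().nIn(10).nOut(10).build())
                .build();

        IUpdater layerUpdater = ((BaseLayer) conf.getConf(0).getLayer()).getIUpdater();
        System.out.println(layerUpdater);   // expected: the Adam configuration copied into the layer
    }
}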

Example 14: getUpdaterByParam

import org.nd4j.linalg.learning.config.IUpdater; // import the required package/class
@Override
public IUpdater getUpdaterByParam(String paramName) {
    return null;
}
 
Developer: deeplearning4j, Project: deeplearning4j, Lines: 5, Source: FrozenLayer.java

Example 15: biasUpdater

import org.nd4j.linalg.learning.config.IUpdater; // import the required package/class
/**
 * Gradient updater configuration, for the biases only. If not set, biases will use the updater as
 * set by {@link #updater(IUpdater)}
 *
 * @param updater Updater to use for bias parameters
 */
public Builder biasUpdater(IUpdater updater){
    this.biasUpdater = updater;
    return this;
}
 
Developer: deeplearning4j, Project: deeplearning4j, Lines: 11, Source: NeuralNetConfiguration.java
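
A hedged usage sketch of the builder method above: configure one updater for the weights and a separate one for the biases at the network level, then check which updater a bias parameter resolves to. The concrete updaters, layer sizes, and the "b" parameter key are arbitrary/assumed.

import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.learning.config.Sgd;
import org.nd4j.linalg.lossfunctions.LossFunctions;

public class BiasUpdaterUsage {
    public static void main(String[] args) {
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .updater(new Adam(1e-3))       // weights and anything not overridden
                .biasUpdater(new Sgd(1e-2))    // biases only
                .list()
                .layer(0, new DenseLayer.Builder().nIn(10).nOut(10).build())
                .layer(1, new OutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE)
                        .activation(Activation.IDENTITY).nIn(10).nOut(10).build())
                .build();

        System.out.println(conf.getConf(0).getLayer().getUpdaterByParam("b"));   // expected: the Sgd configuration
    }
}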


Note: The org.nd4j.linalg.learning.config.IUpdater examples in this article were compiled by 纯净天空 from GitHub/MSDocs and other open-source code and documentation platforms. The snippets are drawn from open-source projects contributed by their respective authors, who retain copyright; please consult each project's license before redistributing or using the code, and do not republish without permission.