

Python Momentum.get_updates Method Code Examples

This article collects typical usage examples of the Python method pylearn2.training_algorithms.learning_rule.Momentum.get_updates. If you are unsure what Momentum.get_updates does or how to use it, the selected code examples below may help. You can also explore further usage examples of the containing class, pylearn2.training_algorithms.learning_rule.Momentum.


Two code examples of Momentum.get_updates are shown below, sorted by popularity by default.
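Before the full excerpts, here is a minimal, self-contained sketch (not taken from either example below) of how Momentum.get_updates is typically called; the toy parameter W and the quadratic cost are illustrative assumptions, not part of the original examples.

from collections import OrderedDict

import numpy as np
import theano
import theano.tensor as T
from pylearn2.training_algorithms.learning_rule import Momentum

# A toy shared parameter and a quadratic cost, only to obtain a gradient.
W = theano.shared(np.zeros((3, 3), dtype=theano.config.floatX), name='W')
cost = T.sqr(W).sum()

# get_updates expects a scalar learning rate, an OrderedDict mapping each
# parameter to its gradient, and a (possibly empty) dict of per-parameter
# learning-rate scalers.
learning_rate = theano.shared(np.asarray(0.01, dtype=theano.config.floatX),
                              name='learning_rate')
grads = OrderedDict([(W, T.grad(cost, W))])
lr_scalers = OrderedDict()

rule = Momentum(init_momentum=0.9)
updates = rule.get_updates(learning_rate, grads, lr_scalers)

# The returned OrderedDict (velocity and parameter updates) can be passed
# directly to theano.function, which is what SGD does in the examples below.
train_step = theano.function([], cost, updates=updates)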

Example 1: SGD

# Required import: from pylearn2.training_algorithms.learning_rule import Momentum [as alias]
# Or: from pylearn2.training_algorithms.learning_rule.Momentum import get_updates [as alias]

#......... part of the code is omitted here .........
            else:
                cur_params = model.generator.get_params()

            def check():
                for param in params:
                    if param not in cur_params:
                        assert param not in updates

            cur_grads = OrderedDict()
            for param in cur_params:
                cur_grads[param] = grads[param]

            for param in grads:
                if grads[param].name is None and cost_value is not None:
                    grads[param].name = ('grad(%(costname)s, %(paramname)s)' %
                                         {'costname': cost_value.name,
                                          'paramname': param.name})
                assert grads[param].dtype == param.dtype

            cur_lr_scalers = OrderedDict()
            for param in cur_params:
                if param in lr_scalers:
                    lr_scaler = lr_scalers[param]
                    cur_lr_scalers[param] = lr_scaler

            log.info('Parameter and initial learning rate summary:')
            for param in cur_params:
                param_name = param.name
                if param_name is None:
                    param_name = 'anon_param'
                lr = learning_rate.get_value() * cur_lr_scalers.get(param,1.)
                log.info('\t' + param_name + ': ' + str(lr))

            updates.update(self.learning_rule.get_updates(
                    learning_rate, cur_grads, cur_lr_scalers))
            check()

            for param in cur_params:
                if updates[param].name is None:
                    updates[param].name = 'sgd_update(' + param.name + ')'
            check()
            model.modify_updates(updates)
            check()
            for param in cur_params:
                update = updates[param]
                if update.name is None:
                    update.name = 'censor(sgd_update(' + param.name + '))'
                for update_val in get_debug_values(update):
                    if np.any(np.isinf(update_val)):
                        raise ValueError("debug value of %s contains infs" %
                                update.name)
                    if np.any(np.isnan(update_val)):
                        raise ValueError("debug value of %s contains nans" %
                                update.name)

            check()

            if dont_you_fucking_dare_touch_the_generator:
                for param in model.generator.get_params():
                    assert param not in updates

            with log_timing(log, 'Compiling sgd_update'):
                return function(theano_args,
                                           updates=updates,
                                           name='sgd_update',
                                           on_unused_input='ignore',
Developer: AdityoSanjaya, Project: adversarial, Lines of code: 70, Source file: sgd_alt.py

Example 2: SGD

# Required import: from pylearn2.training_algorithms.learning_rule import Momentum [as alias]
# Or: from pylearn2.training_algorithms.learning_rule.Momentum import get_updates [as alias]

#......... part of the code is omitted here .........
            if param.name is None:
                param.name = 'sgd_params[%d]' % i

        grads, updates = self.cost.get_gradients(model, nested_args,
                                                 ** fixed_var_descr.fixed_vars)

        for param in grads:
            assert param in params
        for param in params:
            assert param in grads

        for param in grads:
            if grads[param].name is None and cost_value is not None:
                grads[param].name = ('grad(%(costname)s, %(paramname)s)' %
                                     {'costname': cost_value.name,
                                      'paramname': param.name})

        lr_scalers = model.get_lr_scalers()

        for key in lr_scalers:
            if key not in params:
                raise ValueError("Tried to scale the learning rate on " +\
                        str(key)+" which is not an optimization parameter.")

        log.info('Parameter and initial learning rate summary:')
        for param in params:
            param_name = param.name
            if param_name is None:
                param_name = 'anon_param'
            lr = learning_rate.get_value() * lr_scalers.get(param,1.)
            log.info('\t' + param_name + ': ' + str(lr))

        if self.learning_rule:
            updates.update(self.learning_rule.get_updates(
                learning_rate, grads, lr_scalers))
        else:
            # Use standard SGD updates with fixed learning rate.
            updates.update( dict(safe_zip(params, [param - learning_rate * \
                lr_scalers.get(param, 1.) * grads[param]
                                    for param in params])))

        for param in params:
            if updates[param].name is None:
                updates[param].name = 'sgd_update(' + param.name + ')'
        model.censor_updates(updates)
        for param in params:
            update = updates[param]
            if update.name is None:
                update.name = 'censor(sgd_update(' + param.name + '))'
            for update_val in get_debug_values(update):
                if np.any(np.isinf(update_val)):
                    raise ValueError("debug value of %s contains infs" % update.name)
                if np.any(np.isnan(update_val)):
                    raise ValueError("debug value of %s contains nans" % update.name)


        with log_timing(log, 'Compiling sgd_update'):
            self.sgd_update = function(theano_args,
                                       updates=updates,
                                       name='sgd_update',
                                       on_unused_input='ignore',
                                       mode=self.theano_function_mode)
        self.params = params

    def train(self, dataset):
        if not hasattr(self, 'sgd_update'):
Developer: antran89, Project: pylearn2, Lines of code: 70, Source file: sgd.py
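For context, the following is a hedged sketch of how a Momentum learning rule is typically wired into the SGD training algorithm whose setup method both examples excerpt; the learning rate, momentum, and batch size values are placeholders, and model and dataset would come from your own training setup.

from pylearn2.training_algorithms.sgd import SGD
from pylearn2.training_algorithms.learning_rule import Momentum

# SGD delegates parameter updates to the learning rule: inside setup() it calls
# self.learning_rule.get_updates(learning_rate, grads, lr_scalers), as shown above.
algorithm = SGD(learning_rate=0.01,
                learning_rule=Momentum(init_momentum=0.5),
                batch_size=100)

# algorithm.setup(model, dataset)   # compiles the sgd_update Theano function
# algorithm.train(dataset)          # then performs one pass of SGD updates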


Note: The pylearn2.training_algorithms.learning_rule.Momentum.get_updates examples in this article were compiled by 纯净天空 from open-source code and documentation platforms such as GitHub and MSDocs. The snippets were selected from open-source projects contributed by their respective developers; copyright of the source code remains with the original authors, and redistribution or use should follow the license of the corresponding project. Please do not reproduce without permission.