This article collects typical usage examples of the Python method pylearn2.training_algorithms.learning_rule.Momentum.get_updates. If you have been wondering what exactly Momentum.get_updates does and how to use it, the curated code examples here may help. You can also look further into the enclosing class, pylearn2.training_algorithms.learning_rule.Momentum, for related usage.
Two code examples of Momentum.get_updates are shown below, sorted by popularity by default. You can upvote the examples you like or find useful; your votes help the system recommend better Python code examples.
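Before the full examples, here is a minimal, self-contained sketch of calling Momentum.get_updates directly on a toy quadratic cost. It assumes a working Theano/pylearn2 installation; the names W, rule, and step are illustrative, not part of the library API.

# Minimal sketch: driving Momentum.get_updates by hand on a toy cost.
# Assumes theano and pylearn2 are installed; W, rule, step are illustrative names.
from collections import OrderedDict

import numpy as np
import theano
import theano.tensor as T
from pylearn2.training_algorithms.learning_rule import Momentum

W = theano.shared(np.ones(3, dtype=theano.config.floatX), name='W')
cost = T.sqr(W).sum()                    # simple quadratic cost

learning_rate = theano.shared(np.asarray(0.1, dtype=theano.config.floatX))
grads = OrderedDict([(W, T.grad(cost, W))])

rule = Momentum(init_momentum=0.5)
updates = rule.get_updates(learning_rate, grads, lr_scalers={})

step = theano.function([], cost, updates=updates)
for _ in range(5):
    print(step())                        # the printed cost should trend downward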
Example 1: SGD
# Required import: from pylearn2.training_algorithms.learning_rule import Momentum [as alias]
# Or: from pylearn2.training_algorithms.learning_rule.Momentum import get_updates [as alias]
#......... (part of this method is omitted in the original listing) .........
else:
    cur_params = model.generator.get_params()

def check():
    # Frozen parameters must never appear in the update dictionary.
    for param in params:
        if param not in cur_params:
            assert param not in updates

cur_grads = OrderedDict()
for param in cur_params:
    cur_grads[param] = grads[param]

for param in grads:
    if grads[param].name is None and cost_value is not None:
        grads[param].name = ('grad(%(costname)s, %(paramname)s)' %
                             {'costname': cost_value.name,
                              'paramname': param.name})
    assert grads[param].dtype == param.dtype

cur_lr_scalers = OrderedDict()
for param in cur_params:
    if param in lr_scalers:
        lr_scaler = lr_scalers[param]
        cur_lr_scalers[param] = lr_scaler

log.info('Parameter and initial learning rate summary:')
for param in cur_params:
    param_name = param.name
    if param_name is None:
        param_name = 'anon_param'
    lr = learning_rate.get_value() * cur_lr_scalers.get(param, 1.)
    log.info('\t' + param_name + ': ' + str(lr))

# The learning rule (here, Momentum) turns gradients into update expressions.
updates.update(self.learning_rule.get_updates(
    learning_rate, cur_grads, cur_lr_scalers))
check()

for param in cur_params:
    if updates[param].name is None:
        updates[param].name = 'sgd_update(' + param.name + ')'
check()
model.modify_updates(updates)
check()
for param in cur_params:
    update = updates[param]
    if update.name is None:
        update.name = 'censor(sgd_update(' + param.name + '))'
    for update_val in get_debug_values(update):
        if np.any(np.isinf(update_val)):
            raise ValueError("debug value of %s contains infs" %
                             update.name)
        if np.any(np.isnan(update_val)):
            raise ValueError("debug value of %s contains nans" %
                             update.name)
check()

if dont_you_fucking_dare_touch_the_generator:
    for param in model.generator.get_params():
        assert param not in updates

with log_timing(log, 'Compiling sgd_update'):
    return function(theano_args,
                    updates=updates,
                    name='sgd_update',
                    on_unused_input='ignore',
                    # ... (listing truncated in the original page)
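What get_updates returns for each parameter is, in effect, the classical momentum rule. The framework-free sketch below restates it on plain numpy arrays in its standard (non-Nesterov) form as I read the library source; the function name and arguments are illustrative only.

import numpy as np

def momentum_step(param, vel, grad, learning_rate, momentum=0.9, lr_scale=1.0):
    """One classical-momentum step, mirroring what Momentum.get_updates
    builds symbolically: the velocity accumulates scaled gradients and
    the parameter moves along the velocity."""
    scaled_lr = learning_rate * lr_scale
    vel_new = momentum * vel - scaled_lr * grad
    param_new = param + vel_new
    return param_new, vel_new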
Example 2: SGD
# Required import: from pylearn2.training_algorithms.learning_rule import Momentum [as alias]
# Or: from pylearn2.training_algorithms.learning_rule.Momentum import get_updates [as alias]
#......... (the beginning of this method is omitted in the original listing) .........
        if param.name is None:
            param.name = 'sgd_params[%d]' % i

    grads, updates = self.cost.get_gradients(model, nested_args,
                                             ** fixed_var_descr.fixed_vars)

    # Every gradient must correspond to an optimization parameter, and vice versa.
    for param in grads:
        assert param in params
    for param in params:
        assert param in grads

    for param in grads:
        if grads[param].name is None and cost_value is not None:
            grads[param].name = ('grad(%(costname)s, %(paramname)s)' %
                                 {'costname': cost_value.name,
                                  'paramname': param.name})

    lr_scalers = model.get_lr_scalers()
    for key in lr_scalers:
        if key not in params:
            raise ValueError("Tried to scale the learning rate on " +
                             str(key) + " which is not an optimization parameter.")

    log.info('Parameter and initial learning rate summary:')
    for param in params:
        param_name = param.name
        if param_name is None:
            param_name = 'anon_param'
        lr = learning_rate.get_value() * lr_scalers.get(param, 1.)
        log.info('\t' + param_name + ': ' + str(lr))

    if self.learning_rule:
        # Delegate update construction to the learning rule (e.g. Momentum).
        updates.update(self.learning_rule.get_updates(
            learning_rate, grads, lr_scalers))
    else:
        # Use standard SGD updates with fixed learning rate.
        updates.update(dict(safe_zip(params,
                                     [param - learning_rate *
                                      lr_scalers.get(param, 1.) * grads[param]
                                      for param in params])))

    for param in params:
        if updates[param].name is None:
            updates[param].name = 'sgd_update(' + param.name + ')'
    model.censor_updates(updates)
    for param in params:
        update = updates[param]
        if update.name is None:
            update.name = 'censor(sgd_update(' + param.name + '))'
        for update_val in get_debug_values(update):
            if np.any(np.isinf(update_val)):
                raise ValueError("debug value of %s contains infs" %
                                 update.name)
            if np.any(np.isnan(update_val)):
                raise ValueError("debug value of %s contains nans" %
                                 update.name)

    with log_timing(log, 'Compiling sgd_update'):
        self.sgd_update = function(theano_args,
                                   updates=updates,
                                   name='sgd_update',
                                   on_unused_input='ignore',
                                   mode=self.theano_function_mode)

    self.params = params

def train(self, dataset):
    if not hasattr(self, 'sgd_update'):
        # ... (the rest of this method is omitted in the original listing)
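For context, here is a hedged sketch of how a Momentum learning rule is typically wired into the SGD training algorithm whose setup method the examples above excerpt. my_dataset and my_model are placeholders for a pylearn2 Dataset and Model you would supply yourself.

# Hedged sketch: wiring Momentum into SGD in pylearn2.
# my_dataset and my_model are assumed to exist; hyperparameters are illustrative.
from pylearn2.train import Train
from pylearn2.training_algorithms.sgd import SGD
from pylearn2.training_algorithms.learning_rule import Momentum
from pylearn2.termination_criteria import EpochCounter

algorithm = SGD(learning_rate=0.01,
                batch_size=100,
                learning_rule=Momentum(init_momentum=0.5),
                termination_criterion=EpochCounter(max_epochs=10))

Train(dataset=my_dataset, model=my_model, algorithm=algorithm).main_loop()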