本文整理汇总了Python中theano.compat.python2x.OrderedDict.get方法的典型用法代码示例。如果您正苦于以下问题：Python OrderedDict.get方法的具体用法？Python OrderedDict.get怎么用？Python OrderedDict.get使用的例子？那么恭喜您，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在类theano.compat.python2x.OrderedDict的用法示例。
在下文中一共展示了OrderedDict.get方法的1个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: get_func
# 需要导入模块: from theano.compat.python2x import OrderedDict [as 别名]
# 或者: get 是 OrderedDict 的实例方法，无法单独导入，需在 OrderedDict 实例上调用（如 d.get(key, default)）
def get_func(learn_discriminator, learn_generator, dont_you_fucking_dare_touch_the_generator=False):
    """Build and compile an update function that trains exactly one sub-model.

    The parameters of either ``model.discriminator`` or ``model.generator``
    (never both) are updated. Their gradients are passed to
    ``self.learning_rule.get_updates``. The resulting updates are named,
    passed through ``model.modify_updates``, checked for inf/nan debug values,
    and then compiled.

    This function reads many names from an enclosing scope instead of taking
    them as arguments: ``model``, ``params``, ``grads``, ``cost_value``,
    ``lr_scalers``, ``learning_rate``, ``theano_args``, ``self``, ``log``,
    ``np``, ``get_debug_values``, ``log_timing`` and ``function``.
    NOTE(review): it looks like a closure inside a pylearn2-style training
    algorithm's setup method. ``params`` presumably holds every model
    parameter, and ``grads`` maps parameters to gradient expressions. Confirm
    against the enclosing code.

    Parameters
    ----------
    learn_discriminator : bool
        Update the discriminator's parameters.
    learn_generator : bool
        Update the generator's parameters. Exactly one of the two flags must
        be truthy.
    dont_you_fucking_dare_touch_the_generator : bool, optional
        If true, additionally assert that no generator parameter received an
        update before compiling.

    Returns
    -------
    The result of ``function(theano_args, updates=..., name='sgd_update',
    on_unused_input='ignore', mode=self.theano_function_mode)``, presumably
    a compiled ``theano.function``.

    Raises
    ------
    AssertionError
        If both or neither flag is set, if any gradient's dtype differs from
        its parameter's, or if an update exists for a parameter outside the
        selected sub-model (including any generator parameter when the guard
        flag is set).
    KeyError
        If the learning rule yields no update for a selected parameter.
    ValueError
        If a debug value of an update contains infs or nans.
    """
    updates = OrderedDict()
    assert bool(learn_discriminator) != bool(learn_generator), (
        'exactly one of learn_discriminator / learn_generator must be set')

    sub_model = model.discriminator if learn_discriminator else model.generator
    cur_params = sub_model.get_params()

    def display_name(param):
        # Unnamed parameters get a placeholder, so that building the update
        # names below cannot fail on ``None + str``.
        return 'anon_param' if param.name is None else param.name

    def check():
        # Guard: no parameter outside the selected sub-model may be updated.
        for param in params:
            if param not in cur_params:
                assert param not in updates

    # Only the selected sub-model's gradients are handed to the learning rule.
    cur_grads = OrderedDict((param, grads[param]) for param in cur_params)

    # Name every unnamed gradient expression (for all params, not just the
    # selected ones) to make compiled graphs easier to debug.
    for param in grads:
        if grads[param].name is None and cost_value is not None:
            grads[param].name = ('grad(%(costname)s, %(paramname)s)' %
                                 {'costname': cost_value.name,
                                  'paramname': param.name})
        assert grads[param].dtype == param.dtype

    # Per-parameter learning-rate multipliers, restricted to selected params.
    cur_lr_scalers = OrderedDict(
        (param, lr_scalers[param]) for param in cur_params if param in lr_scalers)

    log.info('Parameter and initial learning rate summary:')
    for param in cur_params:
        lr = learning_rate.get_value() * cur_lr_scalers.get(param, 1.)
        log.info('\t' + display_name(param) + ': ' + str(lr))

    updates.update(self.learning_rule.get_updates(
        learning_rate, cur_grads, cur_lr_scalers))
    check()

    for param in cur_params:
        if updates[param].name is None:
            updates[param].name = 'sgd_update(' + display_name(param) + ')'
    check()

    # The model may rewrite updates (e.g. constraints); re-check afterwards.
    model.modify_updates(updates)
    check()

    for param in cur_params:
        update = updates[param]
        if update.name is None:
            update.name = 'censor(sgd_update(' + display_name(param) + '))'
        for update_val in get_debug_values(update):
            if np.any(np.isinf(update_val)):
                raise ValueError("debug value of %s contains infs" %
                                 update.name)
            if np.any(np.isnan(update_val)):
                raise ValueError("debug value of %s contains nans" %
                                 update.name)
    check()

    if dont_you_fucking_dare_touch_the_generator:
        for param in model.generator.get_params():
            assert param not in updates

    with log_timing(log, 'Compiling sgd_update'):
        return function(theano_args,
                        updates=updates,
                        name='sgd_update',
                        on_unused_input='ignore',
                        mode=self.theano_function_mode)