This page collects typical usage examples of the Python method config.cfg.weight_decay. If you have been wondering what cfg.weight_decay does, how to call it, or where to find working examples, the curated code samples below may help. You can also explore further usage examples from the config.cfg module it belongs to.
Three code examples of cfg.weight_decay are shown below, sorted by popularity by default. You can upvote the examples you like or find useful; your votes help the system recommend better Python code samples.
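All three examples assume a project-level config module that exposes a cfg object. The attribute names in the sketch below are taken from the examples on this page, but every value is an illustrative placeholder, not the original project's setting:

# config.py -- minimal sketch of the config module the examples assume.
# Attribute names come from the examples below; all values here are
# illustrative placeholders.
class Config:
    weight_decay = 1e-5        # L2 penalty consumed by all three examples
    model = 'CPN50'            # network constructor name (Example 3; placeholder)
    output_shape = (64, 48)    # model output shape (Example 3; placeholder)
    num_class = 17             # number of output classes (Example 3; placeholder)
    lr = 5e-4                  # base learning rate (Example 3)
    lr_dec_epoch = [90, 120]   # epochs at which the LR decays (Example 3)
    lr_gamma = 0.1             # LR decay factor (Example 3)
    batch_size = 32            # per-GPU batch size (Example 3)

cfg = Config()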
Example 1: resnet_arg_scope
# Required import: from config import cfg [as alias]
# Alternatively: from config.cfg import weight_decay [as alias]
# TF 1.x contrib-era imports used by this example (paths assumed from
# typical repositories of that era):
import tensorflow.contrib.slim as slim
from tensorflow.contrib.framework import arg_scope
from tensorflow.contrib.layers.python.layers import initializers, layers, regularizers
from tensorflow.python.framework import ops
from tensorflow.python.ops import nn_ops
from config import cfg

def resnet_arg_scope(bn_is_training,
                     bn_trainable,
                     trainable=True,
                     weight_decay=cfg.weight_decay,
                     batch_norm_decay=0.99,
                     batch_norm_epsilon=1e-9,
                     batch_norm_scale=True):
    # Batch-norm settings shared by every conv layer in the scope.
    batch_norm_params = {
        'is_training': bn_is_training,
        'decay': batch_norm_decay,
        'epsilon': batch_norm_epsilon,
        'scale': batch_norm_scale,
        'trainable': bn_trainable,
        'updates_collections': ops.GraphKeys.UPDATE_OPS
    }
    # cfg.weight_decay feeds the L2 regularizer applied to all conv weights.
    with arg_scope(
            [slim.conv2d],
            weights_regularizer=regularizers.l2_regularizer(weight_decay),
            weights_initializer=initializers.variance_scaling_initializer(),
            trainable=trainable,
            activation_fn=nn_ops.relu,
            normalizer_fn=layers.batch_norm,
            normalizer_params=batch_norm_params):
        with arg_scope([layers.batch_norm], **batch_norm_params) as arg_sc:
            return arg_sc
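The returned scope dictionary is typically applied with slim.arg_scope when building the backbone, so every slim.conv2d picks up the L2 regularizer derived from cfg.weight_decay. A hypothetical usage sketch (the inputs tensor and layer parameters are assumptions, not from the original repository):

import tensorflow as tf

# Hypothetical usage of the scope returned above (TF 1.x graph mode).
inputs = tf.placeholder(tf.float32, [None, 224, 224, 3])  # assumed input shape
with slim.arg_scope(resnet_arg_scope(bn_is_training=True, bn_trainable=True)):
    net = slim.conv2d(inputs, 64, [7, 7], stride=2, scope='conv1')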
Example 2: resnet_arg_scope
# Required import: from config import cfg [as alias]
# Alternatively: from config.cfg import weight_decay [as alias]
# (TF 1.x imports are the same as in Example 1.)

def resnet_arg_scope(bn_is_training,
                     bn_trainable,
                     trainable=True,
                     weight_decay=cfg.weight_decay,
                     weight_init=initializers.variance_scaling_initializer(),
                     batch_norm_decay=0.99,
                     batch_norm_epsilon=1e-9,
                     batch_norm_scale=True):
    # Batch-norm settings shared by every layer in the scope.
    batch_norm_params = {
        'is_training': bn_is_training,
        'decay': batch_norm_decay,
        'epsilon': batch_norm_epsilon,
        'scale': batch_norm_scale,
        'trainable': bn_trainable,
        'updates_collections': ops.GraphKeys.UPDATE_OPS
    }
    # cfg.weight_decay again feeds the L2 regularizer; this variant also
    # scopes slim.conv2d_transpose and takes the initializer as a parameter.
    with arg_scope(
            [slim.conv2d, slim.conv2d_transpose],
            weights_regularizer=regularizers.l2_regularizer(weight_decay),
            weights_initializer=weight_init,
            trainable=trainable,
            activation_fn=nn_ops.relu,
            normalizer_fn=layers.batch_norm,
            normalizer_params=batch_norm_params):
        with arg_scope([layers.batch_norm], **batch_norm_params) as arg_sc:
            return arg_sc
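Compared with Example 1, this variant also places slim.conv2d_transpose under the same scope, so the deconvolution layers of an upsampling head share the L2 regularizer and batch-norm settings, and it exposes the weight initializer as a caller-supplied parameter instead of hard-coding it.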
Example 3: main
# Required import: from config import cfg [as alias]
# Alternatively: from config.cfg import weight_decay [as alias]
# Imports assumed for a typical PyTorch training script; the project-level
# helpers (network, mkdir_p, Logger, MscocoMulti, adjust_learning_rate,
# train, save_model) come from the original repository and are assumed
# to be importable there.
import torch
import torch.backends.cudnn as cudnn
from os.path import isdir, isfile, join

from config import cfg
import network  # model definitions (project module)

def main(args):
    # create checkpoint dir
    if not isdir(args.checkpoint):
        mkdir_p(args.checkpoint)

    # create model
    model = network.__dict__[cfg.model](cfg.output_shape, cfg.num_class, pretrained=True)
    model = torch.nn.DataParallel(model).cuda()

    # define loss function (criterion) and optimizer
    criterion1 = torch.nn.MSELoss().cuda()              # for global loss
    criterion2 = torch.nn.MSELoss(reduce=False).cuda()  # for refine loss; reduce=False
                                                        # is the legacy spelling of
                                                        # reduction='none'
    # cfg.weight_decay sets the optimizer's L2 penalty.
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=cfg.lr,
                                 weight_decay=cfg.weight_decay)

    if args.resume:
        if isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            pretrained_dict = checkpoint['state_dict']
            model.load_state_dict(pretrained_dict)
            args.start_epoch = checkpoint['epoch']
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
            logger = Logger(join(args.checkpoint, 'log.txt'), resume=True)
        else:
            # note: logger stays undefined on this path in the original code
            print("=> no checkpoint found at '{}'".format(args.resume))
    else:
        logger = Logger(join(args.checkpoint, 'log.txt'))
        logger.set_names(['Epoch', 'LR', 'Train Loss'])

    cudnn.benchmark = True
    # Parameter count in MB, assuming 4 bytes (float32) per parameter.
    print('    Total params: %.2fMB'
          % (sum(p.numel() for p in model.parameters()) / (1024 * 1024) * 4))

    train_loader = torch.utils.data.DataLoader(
        MscocoMulti(cfg),
        batch_size=cfg.batch_size * args.num_gpus, shuffle=True,
        num_workers=args.workers, pin_memory=True)

    for epoch in range(args.start_epoch, args.epochs):
        lr = adjust_learning_rate(optimizer, epoch, cfg.lr_dec_epoch, cfg.lr_gamma)
        print('\nEpoch: %d | LR: %.8f' % (epoch + 1, lr))

        # train for one epoch
        train_loss = train(train_loader, model, [criterion1, criterion2], optimizer)
        print('train_loss: ', train_loss)

        # append logger file
        logger.append([epoch + 1, lr, train_loss])

        save_model({
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
        }, checkpoint=args.checkpoint)

    logger.close()
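The helper adjust_learning_rate is project code not shown on this page. A minimal step-decay sketch consistent with how it is called above (signature and behavior are assumptions inferred from the call site, not the original implementation):

def adjust_learning_rate(optimizer, epoch, dec_epochs, gamma):
    # Assumed behavior, inferred from the call
    # adjust_learning_rate(optimizer, epoch, cfg.lr_dec_epoch, cfg.lr_gamma):
    # multiply the base LR by gamma once per decay boundary already passed.
    lr = cfg.lr * (gamma ** sum(epoch >= e for e in dec_epochs))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr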