This article collects typical usage examples of the torch.nn.SyncBatchNorm method in Python. If you are wondering what nn.SyncBatchNorm does, how to use it, or want to see it in real code, the curated method examples below may help. You can also explore further usage examples of the containing module, torch.nn.
The following presents 15 code examples of nn.SyncBatchNorm, sorted by popularity by default. You can upvote the examples you like or find useful; your feedback helps the system recommend better Python code examples.
Example 1: forward
# Required import: from torch import nn [as alias]
# or: from torch.nn import SyncBatchNorm [as alias]
def forward(self, input):
    if get_world_size() == 1 or not self.training:
        return super().forward(input)

    assert input.shape[0] > 0, "SyncBatchNorm does not support empty inputs"
    C = input.shape[1]
    mean = torch.mean(input, dim=[0, 2, 3])
    meansqr = torch.mean(input * input, dim=[0, 2, 3])

    vec = torch.cat([mean, meansqr], dim=0)
    vec = AllReduce.apply(vec) * (1.0 / dist.get_world_size())

    mean, meansqr = torch.split(vec, C)
    var = meansqr - mean * mean
    self.running_mean += self.momentum * (mean.detach() - self.running_mean)
    self.running_var += self.momentum * (var.detach() - self.running_var)

    invstd = torch.rsqrt(var + self.eps)
    scale = self.weight * invstd
    bias = self.bias - mean * scale
    scale = scale.reshape(1, -1, 1, 1)
    bias = bias.reshape(1, -1, 1, 1)
    return input * scale + bias
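As a quick sanity check (not part of the original snippet), the scale/bias form returned at the end of this forward is algebraically identical to the textbook batch-norm expression; the small standalone sketch below verifies the equivalence on random data.

import torch

# Check that input * scale + bias above equals
# (input - mean) / sqrt(var + eps) * weight + bias.
x = torch.randn(4, 3, 8, 8)
weight, beta, eps = torch.rand(3), torch.rand(3), 1e-5

mean = x.mean(dim=[0, 2, 3])
var = (x * x).mean(dim=[0, 2, 3]) - mean * mean
invstd = torch.rsqrt(var + eps)

scale = (weight * invstd).reshape(1, -1, 1, 1)
shift = (beta - mean * weight * invstd).reshape(1, -1, 1, 1)

reference = ((x - mean.reshape(1, -1, 1, 1)) * invstd.reshape(1, -1, 1, 1)
             * weight.reshape(1, -1, 1, 1) + beta.reshape(1, -1, 1, 1))
print(torch.allclose(x * scale + shift, reference, atol=1e-5))  # True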
Example 2: fuse_module
# Required import: from torch import nn [as alias]
# or: from torch.nn import SyncBatchNorm [as alias]
def fuse_module(m):
    last_conv = None
    last_conv_name = None

    for name, child in m.named_children():
        if isinstance(child, (nn.BatchNorm2d, nn.SyncBatchNorm)):
            if last_conv is None:  # only fuse BN that is after Conv
                continue
            fused_conv = fuse_conv_bn(last_conv, child)
            m._modules[last_conv_name] = fused_conv
            # To reduce changes, set BN as Identity instead of deleting it.
            m._modules[name] = nn.Identity()
            last_conv = None
        elif isinstance(child, nn.Conv2d):
            last_conv = child
            last_conv_name = name
        else:
            fuse_module(child)
    return m
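The helper fuse_conv_bn is not shown on this page. A typical implementation folds the BN affine transform into the convolution weights; the sketch below (the actual helper in the source repository may differ) also shows a hypothetical call to fuse_module on a small model, checking that outputs are unchanged in eval mode.

import torch
from torch import nn

def fuse_conv_bn(conv, bn):
    # Sketch of a standard conv+BN fusion: fold gamma / sqrt(var + eps) into the
    # conv weight and absorb the running mean and beta into the bias.
    factor = bn.weight / torch.sqrt(bn.running_var + bn.eps)
    conv_b = conv.bias if conv.bias is not None else torch.zeros_like(bn.running_mean)
    conv.weight = nn.Parameter((conv.weight * factor.reshape(-1, 1, 1, 1)).detach())
    conv.bias = nn.Parameter(((conv_b - bn.running_mean) * factor + bn.bias).detach())
    return conv

# Hypothetical usage: fuse a model before inference and verify the outputs match.
model = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.BatchNorm2d(8), nn.ReLU())
model.eval()
x = torch.randn(1, 3, 16, 16)
before = model(x)
fuse_module(model)
print(torch.allclose(before, model(x), atol=1e-5))  # True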
Example 3: get_norm
# Required import: from torch import nn [as alias]
# or: from torch.nn import SyncBatchNorm [as alias]
def get_norm(norm):
    """
    Args:
        norm (str or callable):

    Returns:
        nn.Module or None: the normalization layer
    """
    support_norm_type = ['BN', 'SyncBN', 'FrozenBN', 'GN', 'nnSyncBN']
    assert norm in support_norm_type, 'Unknown norm type {}, support norm types are {}'.format(
        norm, support_norm_type)
    if isinstance(norm, str):
        if len(norm) == 0:
            return None
        norm = {
            "BN": nn.BatchNorm2d,
            "SyncBN": NaiveSyncBatchNorm,
            "FrozenBN": FrozenBatchNorm2d,
            "GN": groupNorm,
            "nnSyncBN": nn.SyncBatchNorm,  # keep for debugging
        }[norm]
    return norm
Example 4: get_norm
# Required import: from torch import nn [as alias]
# or: from torch.nn import SyncBatchNorm [as alias]
def get_norm(norm, out_channels):
    """
    Args:
        norm (str or callable): either one of BN, SyncBN, FrozenBN, GN;
            or a callable that takes a channel number and returns
            the normalization layer as a nn.Module.

    Returns:
        nn.Module or None: the normalization layer
    """
    if isinstance(norm, str):
        if len(norm) == 0:
            return None
        norm = {
            "BN": BatchNorm2d,
            # Fixed in https://github.com/pytorch/pytorch/pull/36382
            "SyncBN": NaiveSyncBatchNorm if env.TORCH_VERSION <= (1, 5) else nn.SyncBatchNorm,
            "FrozenBN": FrozenBatchNorm2d,
            "GN": lambda channels: nn.GroupNorm(32, channels),
            # for debugging:
            "nnSyncBN": nn.SyncBatchNorm,
            "naiveSyncBN": NaiveSyncBatchNorm,
        }[norm]
    return norm(out_channels)
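A hypothetical call to this builder (assuming the aliases it references, such as BatchNorm2d, NaiveSyncBatchNorm and FrozenBatchNorm2d, are importable from the same module): the string key picks the normalization class, and out_channels is forwarded to its constructor.

norm_layer = get_norm("BN", 64)    # an instantiated BatchNorm2d(64)
gn_layer = get_norm("GN", 64)      # nn.GroupNorm(32, 64)
frozen = get_norm("FrozenBN", 64)  # FrozenBatchNorm2d(64); statistics never update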
Example 5: get_norm
# Required import: from torch import nn [as alias]
# or: from torch.nn import SyncBatchNorm [as alias]
def get_norm(norm, out_channels):
    """
    Args:
        norm (str or callable):

    Returns:
        nn.Module or None: the normalization layer
    """
    if isinstance(norm, str):
        if len(norm) == 0:
            return None
        norm = {
            "BN": BatchNorm2d,
            "SyncBN": NaiveSyncBatchNorm,
            "FrozenBN": FrozenBatchNorm2d,
            "GN": lambda channels: nn.GroupNorm(32, channels),
            "nnSyncBN": nn.SyncBatchNorm,  # keep for debugging
        }[norm]
    return norm(out_channels)
Example 6: forward
# Required import: from torch import nn [as alias]
# or: from torch.nn import SyncBatchNorm [as alias]
def forward(self, input):
    if comm.get_world_size() == 1 or not self.training:
        return super().forward(input)

    assert input.shape[0] > 0, "SyncBatchNorm does not support empty inputs"
    C = input.shape[1]
    mean = torch.mean(input, dim=[0, 2, 3])
    meansqr = torch.mean(input * input, dim=[0, 2, 3])

    vec = torch.cat([mean, meansqr], dim=0)
    vec = AllReduce.apply(vec) * (1.0 / dist.get_world_size())

    mean, meansqr = torch.split(vec, C)
    var = meansqr - mean * mean
    self.running_mean += self.momentum * (mean.detach() - self.running_mean)
    self.running_var += self.momentum * (var.detach() - self.running_var)

    invstd = torch.rsqrt(var + self.eps)
    scale = self.weight * invstd
    bias = self.bias - mean * scale
    scale = scale.reshape(1, -1, 1, 1)
    bias = bias.reshape(1, -1, 1, 1)
    return input * scale + bias
Example 7: forward
# Required import: from torch import nn [as alias]
# or: from torch.nn import SyncBatchNorm [as alias]
def forward(self, input):
    if comm.get_world_size() == 1 or not self.training:
        return super().forward(input)

    assert input.shape[0] > 0, "SyncBatchNorm does not support empty input"
    C = input.shape[1]
    mean = torch.mean(input, dim=[0, 2, 3])
    meansqr = torch.mean(input * input, dim=[0, 2, 3])

    vec = torch.cat([mean, meansqr], dim=0)
    vec = AllReduce.apply(vec) * (1.0 / dist.get_world_size())

    mean, meansqr = torch.split(vec, C)
    var = meansqr - mean * mean
    self.running_mean += self.momentum * (mean.detach() - self.running_mean)
    self.running_var += self.momentum * (var.detach() - self.running_var)

    invstd = torch.rsqrt(var + self.eps)
    scale = self.weight * invstd
    bias = self.bias - mean * scale
    scale = scale.reshape(1, -1, 1, 1)
    bias = bias.reshape(1, -1, 1, 1)
    return input * scale + bias
Example 8: set_batch_norm_attr
# Required import: from torch import nn [as alias]
# or: from torch.nn import SyncBatchNorm [as alias]
def set_batch_norm_attr(self, named_modules, attr, value):
    for m in named_modules:
        if isinstance(m[1], nn.BatchNorm2d) or isinstance(m[1], nn.SyncBatchNorm):
            setattr(m[1], attr, value)
Example 9: convert_frozen_batchnorm
# Required import: from torch import nn [as alias]
# or: from torch.nn import SyncBatchNorm [as alias]
def convert_frozen_batchnorm(cls, module):
    """
    Convert BatchNorm/SyncBatchNorm in module into FrozenBatchNorm.

    Args:
        module (torch.nn.Module):

    Returns:
        If module is BatchNorm/SyncBatchNorm, returns a new module.
        Otherwise, in-place convert module and return it.

    Similar to convert_sync_batchnorm in
    https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/batchnorm.py
    """
    bn_module = nn.modules.batchnorm
    bn_module = (bn_module.BatchNorm2d, bn_module.SyncBatchNorm)
    res = module
    if isinstance(module, bn_module):
        res = cls(module.num_features)
        if module.affine:
            res.weight.data = module.weight.data.clone().detach()
            res.bias.data = module.bias.data.clone().detach()
        res.running_mean.data = module.running_mean.data
        res.running_var.data = module.running_var.data + module.eps
    else:
        for name, child in module.named_children():
            new_child = cls.convert_frozen_batchnorm(child)
            if new_child is not child:
                res.add_module(name, new_child)
    return res
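A hypothetical usage, assuming convert_frozen_batchnorm is exposed as a classmethod on a FrozenBatchNorm2d class (as in detectron2-style codebases): it walks a module tree and swaps every BatchNorm2d/SyncBatchNorm for a frozen copy whose statistics no longer update during training.

import torchvision

# Hypothetical: freeze all BN layers of a backbone before fine-tuning.
backbone = torchvision.models.resnet18()
backbone = FrozenBatchNorm2d.convert_frozen_batchnorm(backbone)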
Example 10: _set_batch_norm_attr
# Required import: from torch import nn [as alias]
# or: from torch.nn import SyncBatchNorm [as alias]
def _set_batch_norm_attr(named_modules, attr, value):
    for m in named_modules:
        if isinstance(m[1], (nn.BatchNorm2d, nn.SyncBatchNorm)):
            setattr(m[1], attr, value)
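A hypothetical call to the helper above; named_modules() yields (name, module) pairs, which is why the helper indexes m[1]. The attribute name and value here are made up for illustration.

from torch import nn

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
_set_batch_norm_attr(model.named_modules(), 'momentum', 0.01)
print(model[1].momentum)  # 0.01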
Example 11: __init__
# Required import: from torch import nn [as alias]
# or: from torch.nn import SyncBatchNorm [as alias]
def __init__(self, args):
    self.args = args
    self.device = torch.device(args.device)
    # image transform
    input_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([.485, .456, .406], [.229, .224, .225]),
    ])

    # dataset and dataloader
    val_dataset = get_segmentation_dataset(args.dataset, split='val', mode='testval', transform=input_transform)
    val_sampler = make_data_sampler(val_dataset, False, args.distributed)
    val_batch_sampler = make_batch_data_sampler(val_sampler, images_per_batch=1)
    self.val_loader = data.DataLoader(dataset=val_dataset,
                                      batch_sampler=val_batch_sampler,
                                      num_workers=args.workers,
                                      pin_memory=True)

    # create network
    BatchNorm2d = nn.SyncBatchNorm if args.distributed else nn.BatchNorm2d
    self.model = get_segmentation_model(model=args.model, dataset=args.dataset, backbone=args.backbone,
                                        aux=args.aux, pretrained=True, pretrained_base=False,
                                        local_rank=args.local_rank,
                                        norm_layer=BatchNorm2d).to(self.device)
    if args.distributed:
        self.model = nn.parallel.DistributedDataParallel(self.model,
                                                         device_ids=[args.local_rank],
                                                         output_device=args.local_rank)
    self.model.to(self.device)

    self.metric = SegmentationMetric(val_dataset.num_class)
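Example 11 injects nn.SyncBatchNorm as the norm_layer factory when running distributed. When a model has already been built with plain BatchNorm2d, PyTorch's built-in converter is a common alternative; a minimal sketch follows (build_model and local_rank are placeholders, and torch.distributed must already be initialized).

import torch
from torch import nn

model = build_model()                                    # placeholder model builder
model = nn.SyncBatchNorm.convert_sync_batchnorm(model)   # BatchNorm*d -> SyncBatchNorm
model = nn.parallel.DistributedDataParallel(
    model.cuda(), device_ids=[local_rank], output_device=local_rank)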
Example 12: __init__
# Required import: from torch import nn [as alias]
# or: from torch.nn import SyncBatchNorm [as alias]
def __init__(self, in_channels, out_channels, **kwargs):
    super(BasicConv2d, self).__init__()
    self.conv = nn.Conv2d(in_channels, out_channels, **kwargs)
    if cfg['GN']:
        self.bn = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-5)
    elif cfg['syncBN']:
        self.bn = nn.SyncBatchNorm(out_channels, eps=1e-5)
    else:
        self.bn = nn.BatchNorm2d(out_channels, eps=1e-5)
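A hypothetical construction of the block above, assuming the module-level cfg dict it reads is configured with 'GN' and 'syncBN' flags in the same module:

cfg = {'GN': False, 'syncBN': True}                      # assumed configuration dict
layer = BasicConv2d(64, 128, kernel_size=3, padding=1)   # Conv2d followed by nn.SyncBatchNorm(128)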
Example 13: __init__
# Required import: from torch import nn [as alias]
# or: from torch.nn import SyncBatchNorm [as alias]
def __init__(self, C_in, C_out, affine=True, bn=False, **kwargs):
    super(Normal_Relu_Conv, self).__init__()
    if not bn:
        op = nn.Sequential(
            # nn.ReLU(),
            nn.Conv2d(C_in, C_in, bias=True, **kwargs),
        )
    else:
        if cfg['GN']:
            bn_layer = nn.GroupNorm(32, C_out)
        elif cfg["syncBN"]:
            bn_layer = nn.SyncBatchNorm(C_out)
        else:
            bn_layer = nn.BatchNorm2d(C_out)
        op = nn.Sequential(
            # nn.ReLU(),
            nn.Conv2d(C_in, C_in, bias=False, **kwargs),
            bn_layer,
        )

    if RELU_FIRST:
        self.op = nn.Sequential()
        self.op.add_module('0', nn.ReLU())
        for i in range(1, len(op) + 1):
            self.op.add_module(str(i), op[i - 1])
    else:
        self.op = op
        self.op.add_module(str(len(op)), nn.ReLU())
    # self.op = op
Example 14: forward
# Required import: from torch import nn [as alias]
# or: from torch.nn import SyncBatchNorm [as alias]
def forward(self, weights, temp_coeff=1.0):
    gumbel = -1e-3 * torch.log(-torch.log(torch.rand_like(weights))).to(weights.device)
    weights = _GumbelSoftMax.apply((weights + gumbel) / temp_coeff)
    return weights

# class D_Conv(nn.Module):
#     """ Deformable Conv V2 """
#     def __init__(self, C_in, C_out, kernel_size, padding, affine=True, bn=False):
#         super(D_Conv, self).__init__()
#         if bn:
#             if cfg["syncBN"]:
#                 bn_layer = nn.SyncBatchNorm(C_out)
#             else:
#                 bn_layer = nn.BatchNorm2d(C_out)
#             self.op = nn.Sequential(
#                 nn.ReLU(inplace=False),
#                 DCN(
#                     C_in, C_in, kernel_size=kernel_size, padding=padding, stride=1, deformable_groups=C_in, groups=C_in
#                 ),
#                 nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False),
#                 bn_layer,
#             )
#         else:
#             self.op = nn.Sequential(
#                 nn.ReLU(inplace=False),
#                 DCN(
#                     C_in, C_in, kernel_size=kernel_size, padding=padding, stride=1, deformable_groups=C_in, groups=C_in
#                 ),
#                 nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=True),
#             )
#
#     def forward(self, x):
#         return self.op(x)
Example 15: __init__
# Required import: from torch import nn [as alias]
# or: from torch.nn import SyncBatchNorm [as alias]
def __init__(self, kernel_size, full_input_size, full_output_size, curr_vtx_id=None, args=None):
    super(DynamicReLUConvBN, self).__init__()
    self.args = args
    padding = 1 if kernel_size == 3 else 0

    # assign layers.
    self.relu = nn.ReLU(inplace=False)
    self.conv = DynamicConv2d(
        full_input_size, full_output_size, kernel_size, padding=padding, bias=False,
        dynamic_conv_method=args.dynamic_conv_method, dynamic_conv_dropoutw=args.dynamic_conv_dropoutw
    )
    self.curr_vtx_id = curr_vtx_id

    tracking_stat = args.wsbn_track_stat
    if args.wsbn_sync:
        # logging.debug("Using sync bn.")
        self.bn = SyncBatchNorm(full_output_size, momentum=base_ops.BN_MOMENTUM, eps=base_ops.BN_EPSILON,
                                track_running_stats=tracking_stat)
    else:
        self.bn = nn.BatchNorm2d(full_output_size, momentum=base_ops.BN_MOMENTUM, eps=base_ops.BN_EPSILON,
                                 track_running_stats=tracking_stat)

    self.bn_train = args.wsbn_train  # store whether BN should stay in train mode.
    if self.bn_train:
        self.bn.train()
    else:
        self.bn.eval()

    # for dynamic channel
    self.channel_drop = ChannelDropout(args.channel_dropout_method, args.channel_dropout_dropouto)
    self.output_size = full_output_size
    self.current_outsize = full_output_size  # may change to a different value at runtime.
    self.current_insize = full_input_size