This article collects typical usage examples of the Python class torch.nn.modules.batchnorm._BatchNorm. If you have been wondering what batchnorm._BatchNorm is for, how to use it, or what working examples look like, the curated code samples below may help. You can also read up on the containing module, torch.nn.modules.batchnorm, for further details.
The following shows 5 code examples of batchnorm._BatchNorm, sorted by popularity by default. You can upvote the examples you like or find useful; your feedback helps the system recommend better Python code samples.
Example 1: init_weights
# Required import: from torch.nn.modules import batchnorm [as alias]
# Or: from torch.nn.modules.batchnorm import _BatchNorm [as alias]
def init_weights(self, pretrained=None):
    """Initialize the weights in the backbone.

    Args:
        pretrained (str, optional): Path to pre-trained weights.
            Defaults to None.
    """
    if isinstance(pretrained, str):
        logger = get_root_logger()
        load_checkpoint(self, pretrained, strict=False, logger=logger)
    elif pretrained is None:
        # Kaiming init for convolutions; unit weight for norm layers.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                kaiming_init(m)
            elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
                constant_init(m, 1)

        # Zero-init the last norm layer of each residual block so the
        # block initially acts as identity, which helps early training.
        if self.zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    constant_init(m.norm3, 0)
                elif isinstance(m, BasicBlock):
                    constant_init(m.norm2, 0)
    else:
        raise TypeError('pretrained must be a str or None')
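The helpers kaiming_init and constant_init above come from mmcv.cnn (get_root_logger and load_checkpoint likewise come from the mmcv/mmdetection ecosystem). If you want the example self-contained without mmcv, here is a minimal sketch of equivalent initializers built on torch.nn.init; the names and defaults mirror mmcv's, but treat the exact signatures as an assumption:

import torch.nn as nn

def kaiming_init(module, a=0, mode='fan_out', nonlinearity='relu', bias=0):
    # Kaiming-normal weights, constant bias (assumed mmcv-like defaults).
    nn.init.kaiming_normal_(module.weight, a=a, mode=mode,
                            nonlinearity=nonlinearity)
    if getattr(module, 'bias', None) is not None:
        nn.init.constant_(module.bias, bias)

def constant_init(module, val, bias=0):
    # Fill the weight (e.g. BN gamma) with `val` and the bias (BN beta) with `bias`.
    if getattr(module, 'weight', None) is not None:
        nn.init.constant_(module.weight, val)
    if getattr(module, 'bias', None) is not None:
        nn.init.constant_(module.bias, bias)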
Example 2: init_weights
# Required import: from torch.nn.modules import batchnorm [as alias]
# Or: from torch.nn.modules.batchnorm import _BatchNorm [as alias]
def init_weights(self, pretrained=None):
    if isinstance(pretrained, str):
        logger = logging.getLogger()
        load_checkpoint(self, pretrained, strict=False, logger=logger)
    elif pretrained is None:
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                kaiming_init(m)
            elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
                constant_init(m, 1)

        # Zero-init the offset-prediction convs of deformable convolutions.
        if self.dcn is not None:
            for m in self.modules():
                if isinstance(m, Bottleneck) and hasattr(m, 'conv2_offset'):
                    constant_init(m.conv2_offset, 0)

        if self.zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    constant_init(m.norm3, 0)
                elif isinstance(m, BasicBlock):
                    constant_init(m.norm2, 0)
    else:
        raise TypeError('pretrained must be a str or None')
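Example 2 differs from Example 1 only in the DCN branch: zero-initializing conv2_offset makes every deformable convolution start with zero offsets, so it initially behaves like a regular convolution and the offsets are learned from a stable baseline.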
Example 3: init_weights
# Required import: from torch.nn.modules import batchnorm [as alias]
# Or: from torch.nn.modules.batchnorm import _BatchNorm [as alias]
def init_weights(self, pretrained=None):
    if isinstance(pretrained, str):
        logger = logging.getLogger()
        load_checkpoint(self, pretrained, strict=False, logger=logger)
    elif pretrained is None:
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                kaiming_init(m)
            elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
                constant_init(m, 1)

        if self.zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    constant_init(m.norm3, 0)
                elif isinstance(m, BasicBlock):
                    constant_init(m.norm2, 0)
    else:
        raise TypeError('pretrained must be a str or None')
Example 4: __init__
# Required import: from torch.nn.modules import batchnorm [as alias]
# Or: from torch.nn.modules.batchnorm import _BatchNorm [as alias]
def __init__(self, c, k, stage_num=3):
    super(EMAU, self).__init__()
    self.stage_num = stage_num

    # Bases mu: Kaiming-style normal init, then L2-normalized along the
    # channel dimension, stored as a non-trainable buffer.
    mu = torch.Tensor(1, c, k)
    mu.normal_(0, math.sqrt(2. / k))
    mu = self._l2norm(mu, dim=1)
    self.register_buffer('mu', mu)

    self.conv1 = nn.Conv2d(c, c, 1)
    self.conv2 = nn.Sequential(
        nn.Conv2d(c, c, 1, bias=False),
        norm_layer(c))

    for m in self.modules():
        if isinstance(m, nn.Conv2d):
            n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            m.weight.data.normal_(0, math.sqrt(2. / n))
        elif isinstance(m, _BatchNorm):
            m.weight.data.fill_(1)
            if m.bias is not None:
                m.bias.data.zero_()
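Example 4 refers to self._l2norm and norm_layer, which are defined elsewhere in that file. As a rough sketch, _l2norm presumably L2-normalizes the input along the given dimension (the epsilon below is an assumption for numerical stability), and norm_layer is assumed to be a module-level alias for a BatchNorm variant:

def _l2norm(self, inp, dim):
    # Normalize `inp` to unit L2 norm along `dim`; the epsilon guards
    # against division by zero (value assumed, not from the original).
    return inp / (1e-6 + inp.norm(dim=dim, keepdim=True))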
Example 5: group_weight
# Required import: from torch.nn.modules import batchnorm [as alias]
# Or: from torch.nn.modules.batchnorm import _BatchNorm [as alias]
def group_weight(module):
    # Split parameters into a weight-decay group (conv/linear weights) and
    # a no-decay group (biases and normalization parameters).
    group_decay = []
    group_no_decay = []
    for m in module.modules():
        if isinstance(m, nn.Linear):
            group_decay.append(m.weight)
            if m.bias is not None:
                group_no_decay.append(m.bias)
        elif isinstance(m, Conv2d):
            group_decay.append(m.weight)
            if m.bias is not None:
                group_no_decay.append(m.bias)
        elif isinstance(m, _BatchNorm):
            if m.weight is not None:
                group_no_decay.append(m.weight)
            if m.bias is not None:
                group_no_decay.append(m.bias)
        elif isinstance(m, GroupNorm):
            if m.weight is not None:
                group_no_decay.append(m.weight)
            if m.bias is not None:
                group_no_decay.append(m.bias)

    # Every parameter must land in exactly one of the two groups.
    assert len(list(module.parameters())) == len(group_decay) + len(group_no_decay)
    return group_decay, group_no_decay
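A typical way to consume the two groups is to build optimizer parameter groups in which weight decay is disabled for the no-decay group; the optimizer choice and hyperparameters below are placeholders, not part of the original example:

group_decay, group_no_decay = group_weight(model)
optimizer = torch.optim.SGD(
    [{'params': group_decay},
     {'params': group_no_decay, 'weight_decay': 0.0}],
    lr=0.01, momentum=0.9, weight_decay=1e-4)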