

Python PyTorch SyncBatchNorm.convert_sync_batchnorm Usage and Code Examples


This article briefly introduces the usage of torch.nn.SyncBatchNorm.convert_sync_batchnorm in Python.

Usage:

classmethod convert_sync_batchnorm(module, process_group=None)

Parameters

  • module (torch.nn.Module) – module containing one or more BatchNorm*D layers

  • process_group (optional) – process group to scope the synchronization; the default is the whole world (all processes)

Returns

The original module with the converted torch.nn.SyncBatchNorm layers. If the original module is itself a BatchNorm*D layer, a new torch.nn.SyncBatchNorm layer object will be returned instead.

This is a helper function that converts all BatchNorm*D layers in the model to torch.nn.SyncBatchNorm layers.
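As a quick illustration of the two return cases above, here is a minimal sketch (not part of the original documentation); it can be run without initializing any process group, since the conversion itself only rebuilds the layers:

>>> import torch
>>> # A bare BatchNorm layer is replaced by a brand-new SyncBatchNorm object
>>> bn = torch.nn.BatchNorm1d(100)
>>> type(torch.nn.SyncBatchNorm.convert_sync_batchnorm(bn))
<class 'torch.nn.modules.batchnorm.SyncBatchNorm'>
>>> # A container module is converted recursively and the same module is returned
>>> net = torch.nn.Sequential(torch.nn.Linear(20, 100), torch.nn.BatchNorm1d(100))
>>> net = torch.nn.SyncBatchNorm.convert_sync_batchnorm(net)
>>> type(net[1])
<class 'torch.nn.modules.batchnorm.SyncBatchNorm'>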

Example:

>>> import torch
>>> import torch.distributed as dist
>>> # Network with nn.BatchNorm layer
>>> module = torch.nn.Sequential(
>>>            torch.nn.Linear(20, 100),
>>>            torch.nn.BatchNorm1d(100),
>>>          ).cuda()
>>> # creating process group (optional)
>>> # ranks is a list of int identifying rank ids.
>>> ranks = list(range(8))
>>> r1, r2 = ranks[:4], ranks[4:]
>>> # Note: every rank calls into new_group for every
>>> # process group created, even if that rank is not
>>> # part of the group.
>>> process_groups = [torch.distributed.new_group(pids) for pids in [r1, r2]]
>>> process_group = process_groups[0 if dist.get_rank() <= 3 else 1]
>>> sync_bn_module = torch.nn.SyncBatchNorm.convert_sync_batchnorm(module, process_group)
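Note that the example above assumes the default process group has already been initialized (e.g. via dist.init_process_group) with at least 8 ranks. In practice, the conversion is typically done right before wrapping the model in DistributedDataParallel. The sketch below illustrates that workflow under the same assumption of an initialized process group; the single-node mapping of rank to GPU index is also an assumption, not part of the original example:

>>> import torch
>>> import torch.distributed as dist
>>> from torch.nn.parallel import DistributedDataParallel as DDP
>>> # Assumes dist.init_process_group("nccl") has already run and that
>>> # rank == GPU index (single-node assumption for this sketch).
>>> local_rank = dist.get_rank()
>>> torch.cuda.set_device(local_rank)
>>> model = torch.nn.Sequential(
>>>            torch.nn.Conv2d(3, 16, 3),
>>>            torch.nn.BatchNorm2d(16),  # becomes SyncBatchNorm below
>>>            torch.nn.ReLU(),
>>>          ).cuda(local_rank)
>>> # Convert before wrapping in DDP so that batch statistics are
>>> # synchronized across all processes during training.
>>> model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
>>> model = DDP(model, device_ids=[local_rank])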



Note: This article was compiled by 纯净天空 from the original English work torch.nn.SyncBatchNorm.convert_sync_batchnorm on pytorch.org. Unless otherwise stated, the copyright of the original code belongs to its authors; please do not reproduce or copy this translation without permission.