

Python PyTorch SyncBatchNorm.convert_sync_batchnorm Usage and Code Example

This article briefly introduces the usage of torch.nn.SyncBatchNorm.convert_sync_batchnorm in Python.

Usage:

classmethod convert_sync_batchnorm(module, process_group=None)

Parameters

  • module (torch.nn.Module) - module containing one or more BatchNorm*D layers

  • process_group (optional) - process group to scope the synchronization; the default is the whole world

Returns

The original module with its converted torch.nn.SyncBatchNorm layers. If the original module is itself a BatchNorm*D layer, a new torch.nn.SyncBatchNorm layer object will be returned instead.

This is a helper function that converts all BatchNorm*D layers in a model to torch.nn.SyncBatchNorm layers.
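The second case in the return description is worth noting: when the module passed in is itself a BatchNorm*D layer, a brand-new SyncBatchNorm object is returned rather than the input module. A minimal sketch of that behavior, runnable on CPU with no process group:

```python
import torch

# Converting a bare BatchNorm2d layer returns a new
# torch.nn.SyncBatchNorm configured with the same num_features.
bn = torch.nn.BatchNorm2d(16)
sync_bn = torch.nn.SyncBatchNorm.convert_sync_batchnorm(bn)
print(type(sync_bn).__name__)   # SyncBatchNorm
print(sync_bn.num_features)     # 16
```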

Example:

>>> # Network with nn.BatchNorm layer
>>> module = torch.nn.Sequential(
>>>            torch.nn.Linear(20, 100),
>>>            torch.nn.BatchNorm1d(100),
>>>          ).cuda()
>>> # creating process group (optional)
>>> # ranks is a list of int identifying rank ids.
>>> ranks = list(range(8))
>>> r1, r2 = ranks[:4], ranks[4:]
>>> # Note: every rank calls into new_group for every
>>> # process group created, even if that rank is not
>>> # part of the group.
>>> process_groups = [torch.distributed.new_group(pids) for pids in [r1, r2]]
>>> process_group = process_groups[0 if torch.distributed.get_rank() <= 3 else 1]
>>> sync_bn_module = torch.nn.SyncBatchNorm.convert_sync_batchnorm(module, process_group)
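The example above needs a CUDA device and an initialized process group. To observe just the conversion itself, here is a CPU-only sketch: conversion recurses through container modules, replacing each BatchNorm*D child while leaving other layers untouched.

```python
import torch

module = torch.nn.Sequential(
    torch.nn.Linear(20, 100),
    torch.nn.BatchNorm1d(100),
)
converted = torch.nn.SyncBatchNorm.convert_sync_batchnorm(module)

# The Linear layer is left as-is; only the BatchNorm1d child is replaced,
# and the replacement keeps the same num_features configuration.
assert isinstance(converted[0], torch.nn.Linear)
assert isinstance(converted[1], torch.nn.SyncBatchNorm)
assert converted[1].num_features == 100
```

Note that a SyncBatchNorm produced this way still requires torch.distributed to be initialized before it is used in a forward pass during training.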

Note: this article was selected and compiled by 純淨天空 from the original English work torch.nn.SyncBatchNorm.convert_sync_batchnorm on pytorch.org. Unless otherwise specially stated, copyright of the original code belongs to the original author; please do not reprint or copy this translation without permission or authorization.