本文整理匯總了Python中apex.parallel模組的典型用法代碼示例。如果您正苦於以下問題：Python apex.parallel的具體用法？Python apex.parallel怎麼用？Python apex.parallel使用的例子？那麼，這裡精選的代碼示例或許可以為您提供幫助。您也可以進一步了解該模組所在套件apex的用法示例。
在下文中一共展示了apex.parallel模組的3個代碼示例，這些例子默認根據受歡迎程度排序。您可以為喜歡或者覺得有用的代碼點讚，您的評價將有助於系統推薦出更棒的Python代碼示例。
示例1: check_ddp_wrapped
# 需要導入模塊: import apex [as 別名]
# 或者: from apex import parallel [as 別名]
def check_ddp_wrapped(model: nn.Module) -> bool:
    """
    Report whether ``model`` is wrapped in a data-parallel container.

    The recognized containers are ``torch.nn.DataParallel`` and
    ``torch.nn.parallel.DistributedDataParallel``. When the optional Apex
    package can be imported, ``apex.parallel.DistributedDataParallel`` is
    recognized as well.

    Args:
        model: Any module, wrapped or not.
    Returns:
        True if ``model`` is an instance of one of the wrapper types above.
    """
    wrapper_types = [nn.DataParallel, nn.parallel.DistributedDataParallel]
    try:
        # Apex is an optional dependency: only include its DDP wrapper when
        # the import succeeds, and silently skip it otherwise.
        from apex.parallel import DistributedDataParallel as apex_DDP
    except ImportError:
        pass
    else:
        wrapper_types.append(apex_DDP)
    return isinstance(model, tuple(wrapper_types))
示例2: main
# 需要導入模塊: import apex [as 別名]
# 或者: from apex import parallel [as 別名]
def main(cfgs):
    """
    Set up one process of NCCL distributed training and run ``train``.

    Args:
        cfgs: Configuration mapping. Keys read here: 'logger', 'local_rank',
            'log_dir', 'pth_dir', 'auxiliary' (optional), 'network',
            'criterion', 'optimizer', 'scheduler', 'dataset', 'transforms',
            'sampler', 'loader' and 'frequency'. The WORLD_SIZE environment
            variable must also be set (``init_method='env://'``).
    """
    Logger.init(**cfgs['logger'])
    local_rank = cfgs['local_rank']
    world_size = int(os.environ['WORLD_SIZE'])
    Log.info('rank: {}, world_size: {}'.format(local_rank, world_size))
    log_dir = cfgs['log_dir']
    pth_dir = cfgs['pth_dir']
    if local_rank == 0:
        assure_dir(log_dir)
        assure_dir(pth_dir)

    # Bind this process to its own GPU *before* anything calls `.cuda()`.
    # Otherwise every rank allocates its model on the default device (GPU 0).
    torch.cuda.set_device(local_rank)
    dist.init_process_group(backend='nccl', init_method='env://')
    cudnn.benchmark = True
    # Seed before the network is built so weight initialization is
    # reproducible.
    torch.manual_seed(666)
    torch.cuda.manual_seed(666)

    aux_config = cfgs.get('auxiliary', None)
    network = ModuleBuilder(cfgs['network'], aux_config)
    # Replace BatchNorm layers with Apex SyncBatchNorm *before* creating the
    # optimizer and the DDP wrapper. The conversion creates new parameter
    # tensors, and both the optimizer and DDP must reference those, not the
    # discarded originals.
    network = apex.parallel.convert_syncbn_model(network).cuda()
    criterion = build_criterion(cfgs['criterion'], aux_config).cuda()
    optimizer = optim.SGD(network.parameters(), **cfgs['optimizer'])
    scheduler = PolyLRScheduler(optimizer, **cfgs['scheduler'])

    dataset = build_dataset(**cfgs['dataset'], **cfgs['transforms'])
    # The sampler shards data across *all* processes, so it needs the global
    # rank. local_rank only identifies the GPU within a node; the two values
    # coincide on a single node.
    sampler = DistributedSampler4Iter(dataset, world_size=world_size,
                                      rank=dist.get_rank(), **cfgs['sampler'])
    train_loader = DataLoader(dataset, sampler=sampler, **cfgs['loader'])

    model = DistributedDataParallel(network)
    torch.cuda.empty_cache()
    train(local_rank, world_size, pth_dir, cfgs['frequency'], criterion,
          train_loader, model, optimizer, scheduler)
示例3: get_nn_from_ddp_module
# 需要導入模塊: import apex [as 別名]
# 或者: from apex import parallel [as 別名]
def get_nn_from_ddp_module(model: nn.Module) -> nn.Module:
    """
    Strip a data-parallel wrapper from ``model``, if there is one.

    A model is treated as wrapped when ``check_ddp_wrapped`` reports it as a
    torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel, or
    apex.parallel.DistributedDataParallel instance. Only one level of
    wrapping is removed.

    Args:
        model: A plain model, or a DataParallel/DDP wrapper around one.
    Returns:
        The wrapper's ``.module`` when ``model`` is wrapped, otherwise
        ``model`` itself.
    """
    return model.module if check_ddp_wrapped(model) else model