本文整理汇总了Python中apex.parallel模块的典型用法代码示例。如果您正苦于以下问题：Python apex.parallel模块的具体用法？Python apex.parallel怎么用？Python apex.parallel的使用示例？那么，这里精选的代码示例或许可以为您提供帮助。您也可以进一步了解该模块所在包apex的用法示例。
在下文中一共展示了apex.parallel模块的3个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: check_ddp_wrapped
# 需要导入模块: import apex [as 别名]
# 或者: from apex import parallel [as 别名]
def check_ddp_wrapped(model: nn.Module) -> bool:
    """
    Return True if ``model`` is wrapped in a data-parallel container.

    The recognised wrappers are ``torch.nn.DataParallel`` and
    ``torch.nn.parallel.DistributedDataParallel``. When NVIDIA Apex can be
    imported, ``apex.parallel.DistributedDataParallel`` is recognised too.

    Args:
        model: The module to inspect.

    Returns:
        Whether ``model`` is an instance of one of the wrapper types above.
    """
    wrapper_types = [nn.DataParallel, nn.parallel.DistributedDataParallel]
    try:
        # Apex is an optional dependency: only consider its DDP wrapper
        # when the package is importable.
        from apex.parallel import DistributedDataParallel as apex_DDP
    except ImportError:
        pass
    else:
        wrapper_types.append(apex_DDP)
    return isinstance(model, tuple(wrapper_types))
示例2: main
# 需要导入模块: import apex [as 别名]
# 或者: from apex import parallel [as 别名]
def main(cfgs):
    """
    Run one distributed training process bound to GPU ``cfgs['local_rank']``.

    Args:
        cfgs: Configuration mapping. The function reads these keys:
            'logger', 'local_rank', 'log_dir', 'pth_dir', 'auxiliary'
            (optional), 'network', 'criterion', 'optimizer', 'scheduler',
            'dataset', 'transforms', 'sampler', 'loader' and 'frequency'.

    The ``WORLD_SIZE`` environment variable must be set. The process group
    is initialised with ``init_method='env://'``, so the launcher must also
    provide the rendezvous variables. This presumably means
    torch.distributed.launch or torchrun (confirm).
    """
    Logger.init(**cfgs['logger'])
    local_rank = cfgs['local_rank']
    world_size = int(os.environ['WORLD_SIZE'])
    Log.info('rank: {}, world_size: {}'.format(local_rank, world_size))

    log_dir = cfgs['log_dir']
    pth_dir = cfgs['pth_dir']
    if local_rank == 0:
        assure_dir(log_dir)
        assure_dir(pth_dir)

    # Bind this process to its GPU *before* any `.cuda()` call. Otherwise
    # `.cuda()` targets the default device, and every rank ends up
    # allocating the network and criterion on cuda:0.
    torch.cuda.set_device(local_rank)
    dist.init_process_group(backend='nccl', init_method='env://')

    # Seed before the network is built, so that weight initialisation is
    # also covered by the seed.
    cudnn.benchmark = True
    torch.manual_seed(666)
    torch.cuda.manual_seed(666)

    aux_config = cfgs.get('auxiliary', None)
    network = ModuleBuilder(cfgs['network'], aux_config)
    # convert_syncbn_model swaps BatchNorm layers for new SyncBatchNorm
    # modules that own new parameter objects. It must therefore run before
    # the optimizer collects the parameters and before the DDP wrapper is
    # built. Converting afterwards leaves the new BN weights and biases out
    # of the optimizer, and out of what DDP set up at construction.
    network = apex.parallel.convert_syncbn_model(network).cuda()
    criterion = build_criterion(cfgs['criterion'], aux_config).cuda()
    optimizer = optim.SGD(network.parameters(), **cfgs['optimizer'])
    scheduler = PolyLRScheduler(optimizer, **cfgs['scheduler'])

    dataset = build_dataset(**cfgs['dataset'], **cfgs['transforms'])
    # Shard by the global rank. local_rank only identifies the GPU within a
    # node and repeats across nodes; the two are identical on a single node.
    sampler = DistributedSampler4Iter(dataset, world_size=world_size,
                                      rank=dist.get_rank(), **cfgs['sampler'])
    train_loader = DataLoader(dataset, sampler=sampler, **cfgs['loader'])

    model = DistributedDataParallel(network)
    torch.cuda.empty_cache()
    train(local_rank, world_size, pth_dir, cfgs['frequency'], criterion,
          train_loader, model, optimizer, scheduler)
示例3: get_nn_from_ddp_module
# 需要导入模块: import apex [as 别名]
# 或者: from apex import parallel [as 别名]
def get_nn_from_ddp_module(model: nn.Module) -> nn.Module:
    """
    Unwrap a data-parallel container and return the model inside it.

    Wrapper detection is delegated to ``check_ddp_wrapped``. That covers
    torch.nn.DataParallel and torch.nn.parallel.DistributedDataParallel, and
    also apex.parallel.DistributedDataParallel when Apex is installed. Only
    one level of wrapping is removed.

    Args:
        model: A model, possibly wrapped in a (Distributed)DataParallel wrapper.

    Returns:
        ``model.module`` if ``model`` is such a wrapper, otherwise ``model``
        unchanged.
    """
    return model.module if check_ddp_wrapped(model) else model