

Python apex.parallel Method Code Examples

This article collects typical usage examples of the apex.parallel method in Python. If you are wondering what exactly apex.parallel does, how it is used, or what real-world examples look like, the curated code examples below may help. You can also explore further usage examples from the apex package, where this method lives.


The following presents 3 code examples of the apex.parallel method, sorted by popularity by default.

Example 1: check_ddp_wrapped

# Required import: import apex [as alias]
# Or: from apex import parallel [as alias]
from torch import nn

def check_ddp_wrapped(model: nn.Module) -> bool:
    """
    Checks whether model is wrapped with DataParallel/DistributedDataParallel.
    """
    parallel_wrappers = nn.DataParallel, nn.parallel.DistributedDataParallel

    # Check whether Apex is installed and if it is,
    # add Apex's DistributedDataParallel to list of checked types
    try:
        from apex.parallel import DistributedDataParallel as apex_DDP

        parallel_wrappers = parallel_wrappers + (apex_DDP,)
    except ImportError:
        pass

    return isinstance(model, parallel_wrappers) 
Author: catalyst-team, Project: catalyst, Source: distributed.py
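
A quick usage sketch (it assumes only that torch is installed; Apex is optional here, since the helper degrades gracefully when the import fails):

import torch.nn as nn

model = nn.Linear(4, 2)
print(check_ddp_wrapped(model))                   # False: plain module
print(check_ddp_wrapped(nn.DataParallel(model)))  # True: wrapped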

Example 2: main

# Required import: import apex [as alias]
# Or: from apex import parallel [as alias]
import os

import apex
import torch
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.optim as optim
from apex.parallel import DistributedDataParallel
from torch.utils.data import DataLoader

# Logger, Log, assure_dir, ModuleBuilder, build_criterion, PolyLRScheduler,
# build_dataset, DistributedSampler4Iter and train are project-specific
# helpers defined elsewhere in the pkuseg codebase.
def main(cfgs):
    Logger.init(**cfgs['logger'])

    local_rank = cfgs['local_rank']
    world_size = int(os.environ['WORLD_SIZE'])
    Log.info('rank: {}, world_size: {}'.format(local_rank, world_size))

    log_dir = cfgs['log_dir']
    pth_dir = cfgs['pth_dir']
    if local_rank == 0:
        assure_dir(log_dir)
        assure_dir(pth_dir)

    # Bind this process to its GPU and join the process group before any
    # CUDA allocation; otherwise every rank would allocate on device 0.
    cudnn.benchmark = True
    torch.manual_seed(666)
    torch.cuda.manual_seed(666)
    torch.cuda.set_device(local_rank)
    dist.init_process_group(backend='nccl', init_method='env://')

    aux_config = cfgs.get('auxiliary', None)
    network = ModuleBuilder(cfgs['network'], aux_config).cuda()
    # Convert BatchNorm layers to SyncBatchNorm before the optimizer is
    # created and before wrapping in DistributedDataParallel: the conversion
    # clones the BatchNorm parameters, so anything built against the old
    # module would reference stale tensors.
    network = apex.parallel.convert_syncbn_model(network)
    criterion = build_criterion(cfgs['criterion'], aux_config).cuda()
    optimizer = optim.SGD(network.parameters(), **cfgs['optimizer'])
    scheduler = PolyLRScheduler(optimizer, **cfgs['scheduler'])

    dataset = build_dataset(**cfgs['dataset'], **cfgs['transforms'])
    sampler = DistributedSampler4Iter(dataset, world_size=world_size,
                                      rank=local_rank, **cfgs['sampler'])
    train_loader = DataLoader(dataset, sampler=sampler, **cfgs['loader'])

    model = DistributedDataParallel(network)

    torch.cuda.empty_cache()
    train(local_rank, world_size, pth_dir, cfgs['frequency'], criterion,
          train_loader, model, optimizer, scheduler)
Author: cv-benchmarks, Project: pkuseg, Source: train.py
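
Stripped of the project plumbing, the Apex pattern this example relies on is just two calls: convert BatchNorm layers to SyncBatchNorm, then wrap the model for distributed training. A minimal sketch, assuming a CUDA machine with the process group already initialized (e.g. via torch.distributed.launch); the toy network here is a placeholder:

import apex
import torch.nn as nn
from apex.parallel import DistributedDataParallel

net = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU()).cuda()
net = apex.parallel.convert_syncbn_model(net)  # BatchNorm2d -> Apex SyncBatchNorm
model = DistributedDataParallel(net)           # all-reduce gradients across ranks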

Example 3: get_nn_from_ddp_module

# Required import: import apex [as alias]
# Or: from apex import parallel [as alias]
from torch import nn

# check_ddp_wrapped is defined in Example 1 (same source file).
def get_nn_from_ddp_module(model: nn.Module) -> nn.Module:
    """
    Return the real model from a torch.nn.DataParallel,
    torch.nn.parallel.DistributedDataParallel, or
    apex.parallel.DistributedDataParallel wrapper.

    Args:
        model: A model, or DataParallel wrapper.

    Returns:
        A model
    """
    if check_ddp_wrapped(model):
        model = model.module
    return model 
Author: catalyst-team, Project: catalyst, Source: distributed.py
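
The two helpers compose naturally: check_ddp_wrapped detects the wrapper and get_nn_from_ddp_module strips it. A minimal sketch (torch only, runs on CPU):

import torch.nn as nn

net = nn.Linear(4, 2)
wrapped = nn.DataParallel(net)

assert get_nn_from_ddp_module(wrapped) is net  # wrapper removed
assert get_nn_from_ddp_module(net) is net      # plain modules pass through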


Note: the apex.parallel method examples in this article were compiled by 純淨天空 from open-source code and documentation platforms such as GitHub and MSDocs. The snippets are selected from open-source projects contributed by various developers; copyright of the source code remains with the original authors. Refer to each project's License before redistributing or using the code, and do not repost without permission.