当前位置: 首页>>代码示例>>Python>>正文


Python apex.parallel方法代码示例

本文整理汇总了Python中apex.parallel方法的典型用法代码示例。如果您正苦于以下问题:Python apex.parallel方法的具体用法?Python apex.parallel怎么用?Python apex.parallel使用的例子?那么恭喜您, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在apex的用法示例。


在下文中一共展示了apex.parallel方法的3个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: check_ddp_wrapped

# 需要导入模块: import apex [as 别名]
# 或者: from apex import parallel [as 别名]
def check_ddp_wrapped(model: nn.Module) -> bool:
    """
    Checks whether model is wrapped with DataParallel/DistributedDataParallel.
    """
    parallel_wrappers = nn.DataParallel, nn.parallel.DistributedDataParallel

    # Check whether Apex is installed and if it is,
    # add Apex's DistributedDataParallel to list of checked types
    try:
        from apex.parallel import DistributedDataParallel as apex_DDP

        parallel_wrappers = parallel_wrappers + (apex_DDP,)
    except ImportError:
        pass

    return isinstance(model, parallel_wrappers) 
开发者ID:catalyst-team,项目名称:catalyst,代码行数:18,代码来源:distributed.py

示例2: main

# 需要导入模块: import apex [as 别名]
# 或者: from apex import parallel [as 别名]
def main(cfgs):
    Logger.init(**cfgs['logger'])

    local_rank = cfgs['local_rank']
    world_size = int(os.environ['WORLD_SIZE'])
    Log.info('rank: {}, world_size: {}'.format(local_rank, world_size))

    log_dir = cfgs['log_dir']
    pth_dir = cfgs['pth_dir']
    if local_rank == 0:
        assure_dir(log_dir)
        assure_dir(pth_dir)

    aux_config = cfgs.get('auxiliary', None)
    network = ModuleBuilder(cfgs['network'], aux_config).cuda()
    criterion = build_criterion(cfgs['criterion'], aux_config).cuda()
    optimizer = optim.SGD(network.parameters(), **cfgs['optimizer'])
    scheduler = PolyLRScheduler(optimizer, **cfgs['scheduler'])

    dataset = build_dataset(**cfgs['dataset'], **cfgs['transforms'])
    sampler = DistributedSampler4Iter(dataset, world_size=world_size, 
                                      rank=local_rank, **cfgs['sampler'])
    train_loader = DataLoader(dataset, sampler=sampler, **cfgs['loader'])

    cudnn.benchmark = True
    torch.manual_seed(666)
    torch.cuda.manual_seed(666)
    torch.cuda.set_device(local_rank)
    dist.init_process_group(backend='nccl', init_method='env://')

    model = DistributedDataParallel(network)
    model = apex.parallel.convert_syncbn_model(model)

    torch.cuda.empty_cache()
    train(local_rank, world_size, pth_dir, cfgs['frequency'], criterion, 
          train_loader, model, optimizer, scheduler) 
开发者ID:cv-benchmarks,项目名称:pkuseg,代码行数:38,代码来源:train.py

示例3: get_nn_from_ddp_module

# 需要导入模块: import apex [as 别名]
# 或者: from apex import parallel [as 别名]
def get_nn_from_ddp_module(model: nn.Module) -> nn.Module:
    """
    Return a real model from a torch.nn.DataParallel,
    torch.nn.parallel.DistributedDataParallel, or
    apex.parallel.DistributedDataParallel.

    Args:
        model: A model, or DataParallel wrapper.

    Returns:
        A model
    """
    if check_ddp_wrapped(model):
        model = model.module
    return model 
开发者ID:catalyst-team,项目名称:catalyst,代码行数:17,代码来源:distributed.py


注:本文中的apex.parallel方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。