当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python PyTorch DistributedModelParallel用法及代码示例


本文简要介绍python语言中 torchrec.distributed.model_parallel.DistributedModelParallel 的用法。

用法:

class torchrec.distributed.model_parallel.DistributedModelParallel(module: torch.nn.modules.module.Module, env: Optional[torchrec.distributed.types.ShardingEnv] = None, device: Optional[torch.device] = None, plan: Optional[torchrec.distributed.types.ShardingPlan] = None, sharders: Optional[List[torchrec.distributed.types.ModuleSharder[torch.nn.modules.module.Module]]] = None, init_data_parallel: bool = True, init_parameters: bool = True, data_parallel_wrapper: Optional[torchrec.distributed.model_parallel.DataParallelWrapper] = None)

参数

  • module(nn.Module) -要包装的模块。

  • env(可选的[ShardingEnv]) -具有进程组的分片环境。

  • device(可选的[torch.device]) -计算设备,默认为 cpu。

  • plan(可选的[ShardingPlan]) -计划在分片时使用,默认为 EmbeddingShardingPlanner.collective_plan()

  • sharders(可选的[List[ModuleSharder[nn.Module]]]) -ModuleSharders可用于分片,默认为EmbeddingBagCollectionSharder(),

  • init_data_parallel(bool) -data-parallel 模块可以是惰性的,即它们将参数初始化延迟到第一次前向传递。传递True 以延迟数据并行模块的初始化。首先进行前向传递,然后调用 DistributedModelParallel.init_data_parallel()。

  • init_parameters(bool) -初始化仍在元设备上的模块的参数。

  • data_parallel_wrapper(可选的[DataParallelWrapper]) -数据并行模块的自定义包装器。

基础:torch.nn.modules.module.Module torchrec.optim.fused.FusedOptimizerModule

模型并行性的入口点。

例子:

@torch.no_grad()
def init_weights(m):
    if isinstance(m, nn.Linear):
        m.weight.fill_(1.0)
    elif isinstance(m, EmbeddingBagCollection):
        for param in m.parameters():
            init.kaiming_normal_(param)

m = MyModel(device='meta')
m = DistributedModelParallel(m)
m.apply(init_weights)

相关用法


注:本文由纯净天空筛选整理自pytorch.org大神的英文原创作品 torchrec.distributed.model_parallel.DistributedModelParallel。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。