當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python PyTorch DistributedModelParallel用法及代碼示例


本文簡要介紹python語言中 torchrec.distributed.model_parallel.DistributedModelParallel 的用法。

用法:

class torchrec.distributed.model_parallel.DistributedModelParallel(module: torch.nn.modules.module.Module, env: Optional[torchrec.distributed.types.ShardingEnv] = None, device: Optional[torch.device] = None, plan: Optional[torchrec.distributed.types.ShardingPlan] = None, sharders: Optional[List[torchrec.distributed.types.ModuleSharder[torch.nn.modules.module.Module]]] = None, init_data_parallel: bool = True, init_parameters: bool = True, data_parallel_wrapper: Optional[torchrec.distributed.model_parallel.DataParallelWrapper] = None)

參數

  • module(nn.Module) -要包裝的模塊。

  • env(可選的[ShardingEnv]) -具有進程組的分片環境。

  • device(可選的[torch.device]) -計算設備,默認為 cpu。

  • plan(可選的[ShardingPlan]) -計劃在分片時使用,默認為 EmbeddingShardingPlanner.collective_plan()

  • sharders(可選的[List[ModuleSharder[nn.Module]]]) -ModuleSharders可用於分片,默認為EmbeddingBagCollectionSharder(),

  • init_data_parallel(bool) -data-parallel 模塊可以是惰性的,即它們將參數初始化延遲到第一次前向傳遞。傳遞True 以延遲數據並行模塊的初始化。首先進行前向傳遞,然後調用 DistributedModelParallel.init_data_parallel()。

  • init_parameters(bool) -初始化仍在元設備上的模塊的參數。

  • data_parallel_wrapper(可選的[DataParallelWrapper]) -數據並行模塊的自定義包裝器。

基礎:torch.nn.modules.module.Module torchrec.optim.fused.FusedOptimizerModule

模型並行性的入口點。

例子:

@torch.no_grad()
def init_weights(m):
    if isinstance(m, nn.Linear):
        m.weight.fill_(1.0)
    elif isinstance(m, EmbeddingBagCollection):
        for param in m.parameters():
            init.kaiming_normal_(param)

m = MyModel(device='meta')
m = DistributedModelParallel(m)
m.apply(init_weights)

相關用法


注:本文由純淨天空篩選整理自pytorch.org大神的英文原創作品 torchrec.distributed.model_parallel.DistributedModelParallel。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。