本文簡要介紹python語言中 torchrec.distributed.model_parallel.DistributedModelParallel
的用法。
用法:
class torchrec.distributed.model_parallel.DistributedModelParallel(module: torch.nn.modules.module.Module, env: Optional[torchrec.distributed.types.ShardingEnv] = None, device: Optional[torch.device] = None, plan: Optional[torchrec.distributed.types.ShardingPlan] = None, sharders: Optional[List[torchrec.distributed.types.ModuleSharder[torch.nn.modules.module.Module]]] = None, init_data_parallel: bool = True, init_parameters: bool = True, data_parallel_wrapper: Optional[torchrec.distributed.model_parallel.DataParallelWrapper] = None)
module(nn.Module) -要包裝的模塊。
env(可選的[ShardingEnv]) -具有進程組的分片環境。
device(可選的[torch.device]) -計算設備,默認為 cpu。
plan(可選的[ShardingPlan]) -計劃在分片時使用,默認為
EmbeddingShardingPlanner.collective_plan()
。sharders(可選的[List[ModuleSharder[nn.Module]]]) -
ModuleSharders
可用於分片,默認為EmbeddingBagCollectionSharder()
,init_data_parallel(bool) -data-parallel 模塊可以是惰性的,即它們將參數初始化延遲到第一次前向傳遞。傳遞
True
以延遲數據並行模塊的初始化。首先進行前向傳遞,然後調用 DistributedModelParallel.init_data_parallel()。init_parameters(bool) -初始化仍在元設備上的模塊的參數。
data_parallel_wrapper(可選的[DataParallelWrapper]) -數據並行模塊的自定義包裝器。
基礎:
torch.nn.modules.module.Module
,torchrec.optim.fused.FusedOptimizerModule
模型並行性的入口點。
例子:
@torch.no_grad() def init_weights(m): if isinstance(m, nn.Linear): m.weight.fill_(1.0) elif isinstance(m, EmbeddingBagCollection): for param in m.parameters(): init.kaiming_normal_(param) m = MyModel(device='meta') m = DistributedModelParallel(m) m.apply(init_weights)
參數:
相關用法
- Python PyTorch DistributedModelParallel.named_parameters用法及代碼示例
- Python PyTorch DistributedModelParallel.state_dict用法及代碼示例
- Python PyTorch DistributedModelParallel.named_buffers用法及代碼示例
- Python PyTorch DistributedDataParallel用法及代碼示例
- Python PyTorch DistributedDataParallel.register_comm_hook用法及代碼示例
- Python PyTorch DistributedSampler用法及代碼示例
- Python PyTorch DistributedDataParallel.join用法及代碼示例
- Python PyTorch DistributedDataParallel.no_sync用法及代碼示例
- Python PyTorch DistributedOptimizer用法及代碼示例
- Python PyTorch Dirichlet用法及代碼示例
- Python PyTorch DeQuantize用法及代碼示例
- Python PyTorch DenseArch用法及代碼示例
- Python PyTorch DeepFM用法及代碼示例
- Python PyTorch DataFrameMaker用法及代碼示例
- Python PyTorch DLRM用法及代碼示例
- Python PyTorch Dropout用法及代碼示例
- Python PyTorch Dropout3d用法及代碼示例
- Python PyTorch DataParallel用法及代碼示例
- Python PyTorch Decompressor用法及代碼示例
- Python PyTorch Dropout2d用法及代碼示例
- Python PyTorch DeepFM.forward用法及代碼示例
- Python PyTorch Demultiplexer用法及代碼示例
- Python PyTorch DatasetFolder.find_classes用法及代碼示例
- Python PyTorch frexp用法及代碼示例
- Python PyTorch jvp用法及代碼示例
注:本文由純淨天空篩選整理自pytorch.org大神的英文原創作品 torchrec.distributed.model_parallel.DistributedModelParallel。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。