本文简要介绍python语言中 torchrec.distributed.model_parallel.DistributedModelParallel
的用法。
用法:
class torchrec.distributed.model_parallel.DistributedModelParallel(module: torch.nn.modules.module.Module, env: Optional[torchrec.distributed.types.ShardingEnv] = None, device: Optional[torch.device] = None, plan: Optional[torchrec.distributed.types.ShardingPlan] = None, sharders: Optional[List[torchrec.distributed.types.ModuleSharder[torch.nn.modules.module.Module]]] = None, init_data_parallel: bool = True, init_parameters: bool = True, data_parallel_wrapper: Optional[torchrec.distributed.model_parallel.DataParallelWrapper] = None)
module(nn.Module) -要包装的模块。
env(可选的[ShardingEnv]) -具有进程组的分片环境。
device(可选的[torch.device]) -计算设备,默认为 cpu。
plan(可选的[ShardingPlan]) -计划在分片时使用,默认为
EmbeddingShardingPlanner.collective_plan()
。sharders(可选的[List[ModuleSharder[nn.Module]]]) -
ModuleSharders
可用于分片,默认为EmbeddingBagCollectionSharder()
,init_data_parallel(bool) -data-parallel 模块可以是惰性的,即它们将参数初始化延迟到第一次前向传递。传递
True
以延迟数据并行模块的初始化。首先进行前向传递,然后调用 DistributedModelParallel.init_data_parallel()。init_parameters(bool) -初始化仍在元设备上的模块的参数。
data_parallel_wrapper(可选的[DataParallelWrapper]) -数据并行模块的自定义包装器。
基础:
torch.nn.modules.module.Module
,torchrec.optim.fused.FusedOptimizerModule
模型并行性的入口点。
例子:
@torch.no_grad() def init_weights(m): if isinstance(m, nn.Linear): m.weight.fill_(1.0) elif isinstance(m, EmbeddingBagCollection): for param in m.parameters(): init.kaiming_normal_(param) m = MyModel(device='meta') m = DistributedModelParallel(m) m.apply(init_weights)
参数:
相关用法
- Python PyTorch DistributedModelParallel.named_parameters用法及代码示例
- Python PyTorch DistributedModelParallel.state_dict用法及代码示例
- Python PyTorch DistributedModelParallel.named_buffers用法及代码示例
- Python PyTorch DistributedDataParallel用法及代码示例
- Python PyTorch DistributedDataParallel.register_comm_hook用法及代码示例
- Python PyTorch DistributedSampler用法及代码示例
- Python PyTorch DistributedDataParallel.join用法及代码示例
- Python PyTorch DistributedDataParallel.no_sync用法及代码示例
- Python PyTorch DistributedOptimizer用法及代码示例
- Python PyTorch Dirichlet用法及代码示例
- Python PyTorch DeQuantize用法及代码示例
- Python PyTorch DenseArch用法及代码示例
- Python PyTorch DeepFM用法及代码示例
- Python PyTorch DataFrameMaker用法及代码示例
- Python PyTorch DLRM用法及代码示例
- Python PyTorch Dropout用法及代码示例
- Python PyTorch Dropout3d用法及代码示例
- Python PyTorch DataParallel用法及代码示例
- Python PyTorch Decompressor用法及代码示例
- Python PyTorch Dropout2d用法及代码示例
- Python PyTorch DeepFM.forward用法及代码示例
- Python PyTorch Demultiplexer用法及代码示例
- Python PyTorch DatasetFolder.find_classes用法及代码示例
- Python PyTorch frexp用法及代码示例
- Python PyTorch jvp用法及代码示例
注:本文由纯净天空筛选整理自pytorch.org大神的英文原创作品 torchrec.distributed.model_parallel.DistributedModelParallel。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。