本文簡要介紹python語言中 torch.nn.DataParallel
的用法。
用法:
class torch.nn.DataParallel(module, device_ids=None, output_device=None, dim=0)
module(torch.nn.Module) -要並行化的模塊
device_ids(Python列表:int或者torch.device) -CUDA 設備(默認:所有設備)
output_device(int或者torch.device) -輸出的設備位置(默認值:device_ids[0])
~DataParallel.module(torch.nn.Module) -要並行化的模塊
在模塊級別實現數據並行。
該容器並行化給定
module
的應用程序,方法是通過在批處理維度中分塊將輸入拆分到指定的設備(其他對象將在每個設備上複製一次)。在前向傳遞中,模塊在每個設備上複製,每個副本處理一部分輸入。在向後傳遞期間,來自每個副本的梯度被匯總到原始模塊中。批量大小應大於使用的 GPU 數量。
警告
建議使用
DistributedDataParallel
代替此類來進行multi-GPU訓練,即使隻有一個節點。請參閱:使用 nn.parallel.DistributedDataParallel 而不是多處理或 nn.DataParallel 和分布式數據並行。允許將任意位置和關鍵字輸入傳遞到DataParallel,但某些類型會經過特殊處理。張量將分散在指定的dim(默認為 0)上。 tuple、list 和 dict 類型將被淺拷貝。其他類型將在不同線程之間共享,並且如果寫入模型的正向傳遞中可能會損壞。
在運行
DataParallel
模塊之前,並行化的module
必須在device_ids[0]
上擁有其參數和緩衝區。警告
在每一個前鋒中,
module
是複製的在每個設備上,因此對正在運行的模塊的任何更新forward
會迷路。例如,如果module
有一個計數器屬性,在每個遞增forward
,它將始終保持在初始值,因為更新是在之後銷毀的副本上完成的forward
.然而,DataParallel
保證副本上device[0]
將使其參數和緩衝區與基本並行化共享存儲module
.所以到位更新參數或緩衝區device[0]
將被記錄。例如:,torch.nn.BatchNorm2d和torch.nn.utils.spectral_norm依靠這種行為來更新緩衝區。警告
module
及其子模塊上定義的前向和後向鉤子將被調用len(device_ids)
次,每次的輸入都位於特定設備上。特別地,鉤子僅保證相對於相應設備上的操作以正確的順序執行。例如,不能保證通過register_forward_pre_hook()
設置的鉤子在all
len(device_ids)
forward()
調用之前執行,但每個這樣的鉤子都在該設備的相應forward()
調用之前執行。警告
當
module
在forward()
中返回標量(即 0 維張量)時,此包裝器將返回一個長度等於數據並行中使用的設備數量的向量,其中包含來自每個設備的結果。注意
在
DataParallel
包的Module
中使用pack sequence -> recurrent network -> unpack sequence
模式有一個微妙之處。有關詳細信息,請參閱常見問題解答中的“我的循環網絡無法使用數據並行性”部分。例子:
>>> net = torch.nn.DataParallel(model, device_ids=[0, 1, 2]) >>> output = net(input_var) # input_var can be on any device, including CPU
參數:
變量:
相關用法
- Python PyTorch DataFrameMaker用法及代碼示例
- Python PyTorch DatasetFolder.find_classes用法及代碼示例
- Python PyTorch DeQuantize用法及代碼示例
- Python PyTorch DistributedModelParallel用法及代碼示例
- Python PyTorch DistributedDataParallel用法及代碼示例
- Python PyTorch DenseArch用法及代碼示例
- Python PyTorch DeepFM用法及代碼示例
- Python PyTorch DistributedDataParallel.register_comm_hook用法及代碼示例
- Python PyTorch DLRM用法及代碼示例
- Python PyTorch DistributedSampler用法及代碼示例
- Python PyTorch DistributedDataParallel.join用法及代碼示例
- Python PyTorch Dropout用法及代碼示例
- Python PyTorch DistributedModelParallel.named_parameters用法及代碼示例
- Python PyTorch Dropout3d用法及代碼示例
- Python PyTorch DistributedModelParallel.state_dict用法及代碼示例
- Python PyTorch DistributedDataParallel.no_sync用法及代碼示例
- Python PyTorch Decompressor用法及代碼示例
- Python PyTorch Dropout2d用法及代碼示例
- Python PyTorch DistributedModelParallel.named_buffers用法及代碼示例
- Python PyTorch DeepFM.forward用法及代碼示例
- Python PyTorch Dirichlet用法及代碼示例
- Python PyTorch Demultiplexer用法及代碼示例
- Python PyTorch DistributedOptimizer用法及代碼示例
- Python PyTorch frexp用法及代碼示例
- Python PyTorch jvp用法及代碼示例
注:本文由純淨天空篩選整理自pytorch.org大神的英文原創作品 torch.nn.DataParallel。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。