當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python PyTorch DataParallel用法及代碼示例


本文簡要介紹python語言中 torch.nn.DataParallel 的用法。

用法:

class torch.nn.DataParallel(module, device_ids=None, output_device=None, dim=0)

參數

  • module(torch.nn.Module) -要並行化的模塊

  • device_ids(Python列表:int或者torch.device) -CUDA 設備(默認:所有設備)

  • output_device(int或者torch.device) -輸出的設備位置(默認值:device_ids[0])

變量

~DataParallel.module(torch.nn.Module) -要並行化的模塊

在模塊級別實現數據並行。

該容器並行化給定module 的應用程序,方法是通過在批處理維度中分塊將輸入拆分到指定的設備(其他對象將在每個設備上複製一次)。在前向傳遞中,模塊在每個設備上複製,每個副本處理一部分輸入。在向後傳遞期間,來自每個副本的梯度被匯總到原始模塊中。

批量大小應大於使用的 GPU 數量。

警告

建議使用 DistributedDataParallel 代替此類來進行multi-GPU訓練,即使隻有一個節點。請參閱:使用 nn.parallel.DistributedDataParallel 而不是多處理或 nn.DataParallel 和分布式數據並行。

允許將任意位置和關鍵字輸入傳遞到DataParallel,但某些類型會經過特殊處理。張量將分散在指定的dim(默認為 0)上。 tuple、list 和 dict 類型將被淺拷貝。其他類型將在不同線程之間共享,並且如果寫入模型的正向傳遞中可能會損壞。

在運行 DataParallel 模塊之前,並行化的 module 必須在 device_ids[0] 上擁有其參數和緩衝區。

警告

在每一個前鋒中,module複製的在每個設備上,因此對正在運行的模塊的任何更新forward會迷路。例如,如果module有一個計數器屬性,在每個遞增forward,它將始終保持在初始值,因為更新是在之後銷毀的副本上完成的forward.然而,DataParallel保證副本上device[0]將使其參數和緩衝區與基本並行化共享存儲module.所以到位更新參數或緩衝區device[0]將被記錄。例如:,torch.nn.BatchNorm2dtorch.nn.utils.spectral_norm依靠這種行為來更新緩衝區。

警告

module 及其子模塊上定義的前向和後向鉤子將被調用 len(device_ids) 次,每次的輸入都位於特定設備上。特別地,鉤子僅保證相對於相應設備上的操作以正確的順序執行。例如,不能保證通過 register_forward_pre_hook() 設置的鉤子在 all len(device_ids) forward() 調用之前執行,但每個這樣的鉤子都在該設備的相應 forward() 調用之前執行。

警告

moduleforward() 中返回標量(即 0 維張量)時,此包裝器將返回一個長度等於數據並行中使用的設備數量的向量,其中包含來自每個設備的結果。

注意

DataParallel 包的 Module 中使用 pack sequence -> recurrent network -> unpack sequence 模式有一個微妙之處。有關詳細信息,請參閱常見問題解答中的“我的循環網絡無法使用數據並行性”部分。

例子:

>>> net = torch.nn.DataParallel(model, device_ids=[0, 1, 2])
>>> output = net(input_var)  # input_var can be on any device, including CPU

相關用法


注:本文由純淨天空篩選整理自pytorch.org大神的英文原創作品 torch.nn.DataParallel。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。