當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python PyTorch RemoteModule用法及代碼示例


本文簡要介紹python語言中 torch.distributed.nn.api.remote_module.RemoteModule 的用法。

用法:

class torch.distributed.nn.api.remote_module.RemoteModule(remote_device, module_cls, args=None, kwargs=None)

參數

  • remote_device(str) -我們要放置此模塊的目標工作人員上的設備。格式應為“<workername>/<device>”,其中設備字段可以解析為torch.device類型。例如,“trainer0/cpu”、“trainer0”、“ps0/cuda:0”。另外,設備字段可以是可選的,默認值為“cpu”。

  • module_cls(torch.nn.Module) -

    要遠程創建的模塊的類。例如,

    >>> class MyModule(nn.Module):
    >>>     def forward(input):
    >>>         return input + 1
    >>>
    >>> module_cls = MyModule
  • args(Sequence,可選的) -要傳遞給 module_cls 的參數。

  • kwargs(字典,可選的) -kwargs 被傳遞給 module_cls

返回

包裝由用戶提供的 module_cls 創建的 Module 的遠程模塊實例,它具有阻塞 forward 方法和異步 forward_async 方法,該方法返回用戶提供的 forward 調用的未來模塊在遠端。

RemoteModule 實例隻能在 RPC 初始化後創建。它在指定的遠程節點上創建用戶指定的模塊。它的行為類似於常規的 nn.Module,隻不過 forward 方法是在遠程節點上執行的。它負責自動梯度記錄,以確保向後傳遞將梯度傳播回相應的遠程模塊。

它基於 module_clsforward 方法的簽名生成兩個方法 forward_asyncforwardforward_async 異步運行並返回一個 Future。 forward_asyncforward 的參數與 module_cls 返回的模塊的 forward 方法相同。

例如,如果 module_cls 返回 nn.Linear 的實例,該實例具有 forward 方法簽名: def forward(input: Tensor) -> Tensor: ,則生成的 RemoteModule 將具有 2 個帶有簽名的方法:

def forward(input: Tensor) -> Tensor:
def forward_async(input: Tensor) -> Future[Tensor]:

例子:

在兩個不同的進程中運行以下代碼:

>>> # On worker 0:
>>> import torch
>>> import torch.distributed.rpc as rpc
>>> from torch import nn, Tensor
>>> from torch.distributed.nn.api.remote_module import RemoteModule
>>>
>>> rpc.init_rpc("worker0", rank=0, world_size=2)
>>> remote_linear_module = RemoteModule(
>>>     "worker1/cpu", nn.Linear, args=(20, 30),
>>> )
>>> input = torch.randn(128, 20)
>>> ret_fut = remote_linear_module.forward_async(input)
>>> ret = ret_fut.wait()
>>> rpc.shutdown()
>>> # On worker 1:
>>> import torch
>>> import torch.distributed.rpc as rpc
>>>
>>> rpc.init_rpc("worker1", rank=1, world_size=2)
>>> rpc.shutdown()

此外,與DistributedDataParallel(DDP)結合的更實際的例子可以在這個tutorial中找到。

相關用法


注:本文由純淨天空篩選整理自pytorch.org大神的英文原創作品 torch.distributed.nn.api.remote_module.RemoteModule。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。