当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python PyTorch RemoteModule用法及代码示例


本文简要介绍python语言中 torch.distributed.nn.api.remote_module.RemoteModule 的用法。

用法:

class torch.distributed.nn.api.remote_module.RemoteModule(remote_device, module_cls, args=None, kwargs=None)

参数

  • remote_device(str) -我们要放置此模块的目标工作人员上的设备。格式应为“<workername>/<device>”,其中设备字段可以解析为torch.device类型。例如,“trainer0/cpu”、“trainer0”、“ps0/cuda:0”。另外,设备字段可以是可选的,默认值为“cpu”。

  • module_cls(torch.nn.Module) -

    要远程创建的模块的类。例如,

    >>> class MyModule(nn.Module):
    >>>     def forward(input):
    >>>         return input + 1
    >>>
    >>> module_cls = MyModule
  • args(Sequence,可选的) -要传递给 module_cls 的参数。

  • kwargs(字典,可选的) -kwargs 被传递给 module_cls

返回

包装由用户提供的 module_cls 创建的 Module 的远程模块实例,它具有阻塞 forward 方法和异步 forward_async 方法,该方法返回用户提供的 forward 调用的未来模块在远端。

RemoteModule 实例只能在 RPC 初始化后创建。它在指定的远程节点上创建用户指定的模块。它的行为类似于常规的 nn.Module,只不过 forward 方法是在远程节点上执行的。它负责自动梯度记录,以确保向后传递将梯度传播回相应的远程模块。

它基于 module_clsforward 方法的签名生成两个方法 forward_asyncforwardforward_async 异步运行并返回一个 Future。 forward_asyncforward 的参数与 module_cls 返回的模块的 forward 方法相同。

例如,如果 module_cls 返回 nn.Linear 的实例,该实例具有 forward 方法签名: def forward(input: Tensor) -> Tensor: ,则生成的 RemoteModule 将具有 2 个带有签名的方法:

def forward(input: Tensor) -> Tensor:
def forward_async(input: Tensor) -> Future[Tensor]:

例子:

在两个不同的进程中运行以下代码:

>>> # On worker 0:
>>> import torch
>>> import torch.distributed.rpc as rpc
>>> from torch import nn, Tensor
>>> from torch.distributed.nn.api.remote_module import RemoteModule
>>>
>>> rpc.init_rpc("worker0", rank=0, world_size=2)
>>> remote_linear_module = RemoteModule(
>>>     "worker1/cpu", nn.Linear, args=(20, 30),
>>> )
>>> input = torch.randn(128, 20)
>>> ret_fut = remote_linear_module.forward_async(input)
>>> ret = ret_fut.wait()
>>> rpc.shutdown()
>>> # On worker 1:
>>> import torch
>>> import torch.distributed.rpc as rpc
>>>
>>> rpc.init_rpc("worker1", rank=1, world_size=2)
>>> rpc.shutdown()

此外,与DistributedDataParallel(DDP)结合的更实际的例子可以在这个tutorial中找到。

相关用法


注:本文由纯净天空筛选整理自pytorch.org大神的英文原创作品 torch.distributed.nn.api.remote_module.RemoteModule。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。