Python PyTorch async_execution: usage and code examples

This article briefly introduces the usage of torch.distributed.rpc.functions.async_execution in Python.



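A decorator for a function, indicating that the function's return value is guaranteed to be a Future object and that the function can run asynchronously on the RPC callee. More specifically, the callee extracts the Future returned by the wrapped function and installs the subsequent processing steps as a callback on that Future. When the Future completes, the installed callback reads its value and sends it back as the RPC response. This also means that the returned Future exists only on the callee side and is never sent through RPC. This decorator is useful when the execution of the wrapped function (fn) needs to pause and resume, for example because it contains rpc_async() or waits for other signals.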
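
The callee-side flow described above can be pictured with a toy model built only on the Python standard library. This is a sketch under stated assumptions, not PyTorch's implementation: `toy_async_execution`, `toy_callee_dispatch`, `ASYNC_MARK`, and `send_response` are hypothetical names, and `concurrent.futures.Future` stands in for the RPC Future.

```python
from concurrent.futures import Future

# Hypothetical marker attribute; the name is illustrative, not PyTorch's.
ASYNC_MARK = "_toy_async_execution"


def toy_async_execution(fn):
    """Mark fn as a function that returns a Future (toy stand-in)."""
    setattr(fn, ASYNC_MARK, True)
    return fn


def toy_callee_dispatch(fn, args, send_response):
    """Run fn on the 'callee' and deliver its result via send_response."""
    result = fn(*args)
    if getattr(fn, ASYNC_MARK, False):
        # The Future stays on the callee. Only its value is sent back,
        # and only once the Future has completed.
        result.add_done_callback(lambda fut: send_response(fut.result()))
    else:
        send_response(result)


pending = Future()  # stands in for a nested rpc_async() still in flight


@toy_async_execution
def async_add(x):
    out = Future()
    # When `pending` completes, finish `out` with pending's value plus x.
    pending.add_done_callback(lambda fut: out.set_result(fut.result() + x))
    return out


responses = []
toy_callee_dispatch(async_add, (1,), responses.append)
before = list(responses)  # nothing sent yet: the function returned early
pending.set_result(2)     # the nested call "arrives"
after = list(responses)   # the callback has now sent the response
print(before, after)
```

In this model the decorated function returns immediately, and the response is produced later by the callback, which mirrors the pause-and-resume behaviour the paragraph describes.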


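To enable asynchronous execution, applications must pass the function object returned by this decorator to the RPC APIs. If RPC detects the attributes installed by this decorator, it knows that the function returns a Future object and handles it accordingly. However, this does not mean the decorator has to be the outermost one when defining a function. For example, when combined with @staticmethod or @classmethod, @rpc.functions.async_execution needs to be the inner decorator so that the target function is still recognized as a static or class function. The target function can still execute asynchronously because, when accessed, the static or class method preserves the attributes installed by @rpc.functions.async_execution.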
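
The claim that a static or class method keeps the attributes set by an inner decorator can be checked in plain Python. The sketch below uses a hypothetical decorator `toy_mark_async` and a made-up attribute name `_toy_async`; it only illustrates the Python attribute lookup involved and makes no claim about which attribute PyTorch actually installs.

```python
import functools


def toy_mark_async(fn):
    """Hypothetical decorator that tags its wrapper with an attribute."""
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        return fn(*args, **kwargs)

    wrapper._toy_async = True  # illustrative attribute name
    return wrapper


class Demo:
    @staticmethod
    @toy_mark_async  # inner decorator, as the text requires
    def static_fn(x):
        return x + 1

    @classmethod
    @toy_mark_async  # inner decorator, as the text requires
    def class_fn(cls, x):
        return x + 2


# Accessing the methods through the class still exposes the marker:
# the static method yields the wrapper itself, and the bound class
# method forwards attribute lookups to its underlying function.
static_marked = getattr(Demo.static_fn, "_toy_async", False)
class_marked = getattr(Demo.class_fn, "_toy_async", False)
static_result = Demo.static_fn(1)
class_result = Demo.class_fn(1)
print(static_marked, class_marked, static_result, class_result)
```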


The returned Future object can come from rpc_async(), then(), or the Future constructor. The example below directly uses the Future returned by then().

>>> import torch
>>> from torch.distributed import rpc
>>> # omitting setup and shutdown RPC
>>> # On all workers
>>> @rpc.functions.async_execution
>>> def async_add_chained(to, x, y, z):
>>>     # This function runs on "worker1" and returns immediately when
>>>     # the callback is installed through the `then(cb)` API. In the
>>>     # meantime, the `rpc_async` to "worker2" can run concurrently.
>>>     # When the return value of that `rpc_async` arrives at
>>>     # "worker1", "worker1" will run the lambda function accordingly
>>>     # and set the value for the previously returned `Future`, which
>>>     # will then trigger RPC to send the result back to "worker0".
>>>     return rpc.rpc_async(to, torch.add, args=(x, y)).then(
>>>         lambda fut: fut.wait() + z
>>>     )
>>> # On worker0
>>> ret = rpc.rpc_sync(
>>>     "worker1",
>>>     async_add_chained,
>>>     args=("worker2", torch.ones(2), 1, 1)
>>> )
>>> print(ret)  # prints tensor([3., 3.])
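
The Future sources named above can also be tried locally, without any RPC setup. The following is a minimal sketch, assuming PyTorch is installed; a manually completed `torch.futures.Future` stands in for the reply of an `rpc_async()` call, and the variable names are illustrative.

```python
import torch
from torch.futures import Future

# The Future constructor: this Future is completed manually below.
src = Future()

# then() returns a new Future that holds the callback's return value.
chained = src.then(lambda fut: fut.wait() + 1)

# A separately constructed Future, filled in from inside a callback
# with set_result().
manual = Future()
src.then(lambda fut: manual.set_result(fut.wait() * 3))

# Completing `src` stands in for an rpc_async() reply arriving.
src.set_result(torch.ones(2))

chained_value = chained.wait()
manual_value = manual.wait()
print(chained_value, manual_value)
```

Either `chained` or `manual` is the kind of object a function decorated with `@rpc.functions.async_execution` would return.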

When combined with TorchScript decorators, this decorator must be the outermost one.

>>> import torch
>>> from torch import Tensor
>>> from torch.futures import Future
>>> from torch.distributed import rpc
>>> # omitting setup and shutdown RPC
>>> # On all workers
>>> @torch.jit.script
>>> def script_add(x: Tensor, y: Tensor) -> Tensor:
>>>     return x + y
>>> @rpc.functions.async_execution
>>> @torch.jit.script
>>> def async_add(to: str, x: Tensor, y: Tensor) -> Future[Tensor]:
>>>     return rpc.rpc_async(to, script_add, (x, y))
>>> # On worker0
>>> ret = rpc.rpc_sync(
>>>     "worker1",
>>>     async_add,
>>>     args=("worker2", torch.ones(2), 1)
>>> )
>>> print(ret)  # prints tensor([2., 2.])


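When combined with a static or class method, this decorator must be the inner one.

>>> import torch
>>> from torch.distributed import rpc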
>>> # omitting setup and shutdown RPC
>>> # On all workers
>>> class AsyncExecutionClass:
>>>     @staticmethod
>>>     @rpc.functions.async_execution
>>>     def static_async_add(to, x, y, z):
>>>         return rpc.rpc_async(to, torch.add, args=(x, y)).then(
>>>             lambda fut: fut.wait() + z
>>>         )
>>>     @classmethod
>>>     @rpc.functions.async_execution
>>>     def class_async_add(cls, to, x, y, z):
>>>         ret_fut = torch.futures.Future()
>>>         rpc.rpc_async(to, torch.add, args=(x, y)).then(
>>>             lambda fut: ret_fut.set_result(fut.wait() + z)
>>>         )
>>>         return ret_fut
>>>     @rpc.functions.async_execution
>>>     def bound_async_add(self, to, x, y, z):
>>>         return rpc.rpc_async(to, torch.add, args=(x, y)).then(
>>>             lambda fut: fut.wait() + z
>>>         )
>>> # On worker0
>>> ret = rpc.rpc_sync(
>>>     "worker1",
>>>     AsyncExecutionClass.static_async_add,
>>>     args=("worker2", torch.ones(2), 1, 2)
>>> )
>>> print(ret)  # prints tensor([4., 4.])
>>> ret = rpc.rpc_sync(
>>>     "worker1",
>>>     AsyncExecutionClass.class_async_add,
>>>     args=("worker2", torch.ones(2), 1, 2)
>>> )
>>> print(ret)  # prints tensor([4., 4.])

This decorator also works with the RRef helpers, i.e. torch.distributed.rpc.RRef.rpc_sync(), torch.distributed.rpc.RRef.rpc_async(), and torch.distributed.rpc.RRef.remote().

>>> import torch
>>> from torch.distributed import rpc
>>> # reuse the AsyncExecutionClass class above
>>> rref = rpc.remote("worker1", AsyncExecutionClass)
>>> ret = rref.rpc_sync().static_async_add("worker2", torch.ones(2), 1, 2)
>>> print(ret)  # prints tensor([4., 4.])
>>> rref = rpc.remote("worker1", AsyncExecutionClass)
>>> ret = rref.rpc_async().static_async_add("worker2", torch.ones(2), 1, 2).wait()
>>> print(ret)  # prints tensor([4., 4.])
>>> rref = rpc.remote("worker1", AsyncExecutionClass)
>>> ret = rref.remote().static_async_add("worker2", torch.ones(2), 1, 2).to_here()
>>> print(ret)  # prints tensor([4., 4.])


