當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python PyTorch fork用法及代碼示例


本文簡要介紹python語言中 torch.jit.fork 的用法。

用法:

torch.jit.fork(func, *args, **kwargs)

參數

  • func(可調用的或者torch.nn.Module) -將被調用的 Python 函數或 torch.nn.Module。如果在 TorchScript 中執行,它將異步執行,否則不會。跟蹤的 fork 調用將在 IR 中捕獲。

  • *args-調用 func 的參數。

  • **kwargs-調用 func 的參數。

返回

func 執行的引用。值 T 隻能通過強製完成 functorch.jit.wait 來訪問。

返回類型

torch.jit.Future[T]

創建一個執行 func 的異步任務和對此執行結果的值的引用。 fork 將立即返回,因此func 的返回值可能尚未計算。要強製完成任務並訪問返回值,請在 Future 上調用 torch.jit.wait。使用返回 Tfunc 調用的 fork 被鍵入為 torch.jit.Future[T]fork 調用可以任意嵌套,並且可以使用位置和關鍵字參數調用。隻有在 TorchScript 中運行時才會發生異步執行。如果在純 python 中運行,fork 將不會並行執行。 fork 在跟蹤時調用時也不會並行執行,但是 forkwait 調用將在導出的 IR 圖中捕獲。

警告

fork 任務將不確定地執行。我們建議隻為不修改其輸入、模塊屬性或全局狀態的純函數生成並行 fork 任務。

示例(fork 一個自由函數):

import torch
from torch import Tensor
def foo(a : Tensor, b : int) -> Tensor:
    return a + b
def bar(a):
    fut : torch.jit.Future[Tensor] = torch.jit.fork(foo, a, b=2)
    return torch.jit.wait(fut)
script_bar = torch.jit.script(bar)
input = torch.tensor(2)
# only the scripted version executes asynchronously
assert script_bar(input) == bar(input)
# trace is not run asynchronously, but fork is captured in IR
graph = torch.jit.trace(bar, (input,)).graph
assert "fork" in str(graph)

示例(fork 一個模塊方法):

import torch
from torch import Tensor
class AddMod(torch.nn.Module):
    def forward(self, a: Tensor, b : int):
        return a + b
class Mod(torch.nn.Module):
    def __init__(self):
        super(self).__init__()
        self.mod = AddMod()
    def forward(self, input):
        fut = torch.jit.fork(self.mod, a, b=2)
        return torch.jit.wait(fut)
input = torch.tensor(2)
mod = Mod()
assert mod(input) == torch.jit.script(mod).forward(input)

相關用法


注:本文由純淨天空篩選整理自pytorch.org大神的英文原創作品 torch.jit.fork。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。