当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python PyTorch fork用法及代码示例


本文简要介绍python语言中 torch.jit.fork 的用法。

用法:

torch.jit.fork(func, *args, **kwargs)

参数

  • func(可调用的或者torch.nn.Module) -将被调用的 Python 函数或 torch.nn.Module。如果在 TorchScript 中执行,它将异步执行,否则不会。跟踪的 fork 调用将在 IR 中捕获。

  • *args-调用 func 的参数。

  • **kwargs-调用 func 的参数。

返回

func 执行的引用。值 T 只能通过强制完成 functorch.jit.wait 来访问。

返回类型

torch.jit.Future[T]

创建一个执行 func 的异步任务和对此执行结果的值的引用。 fork 将立即返回,因此func 的返回值可能尚未计算。要强制完成任务并访问返回值,请在 Future 上调用 torch.jit.wait。使用返回 Tfunc 调用的 fork 被键入为 torch.jit.Future[T]fork 调用可以任意嵌套,并且可以使用位置和关键字参数调用。只有在 TorchScript 中运行时才会发生异步执行。如果在纯 python 中运行,fork 将不会并行执行。 fork 在跟踪时调用时也不会并行执行,但是 forkwait 调用将在导出的 IR 图中捕获。

警告

fork 任务将不确定地执行。我们建议只为不修改其输入、模块属性或全局状态的纯函数生成并行 fork 任务。

示例(fork 一个自由函数):

import torch
from torch import Tensor
def foo(a : Tensor, b : int) -> Tensor:
    return a + b
def bar(a):
    fut : torch.jit.Future[Tensor] = torch.jit.fork(foo, a, b=2)
    return torch.jit.wait(fut)
script_bar = torch.jit.script(bar)
input = torch.tensor(2)
# only the scripted version executes asynchronously
assert script_bar(input) == bar(input)
# trace is not run asynchronously, but fork is captured in IR
graph = torch.jit.trace(bar, (input,)).graph
assert "fork" in str(graph)

示例(fork 一个模块方法):

import torch
from torch import Tensor
class AddMod(torch.nn.Module):
    def forward(self, a: Tensor, b : int):
        return a + b
class Mod(torch.nn.Module):
    def __init__(self):
        super(self).__init__()
        self.mod = AddMod()
    def forward(self, input):
        fut = torch.jit.fork(self.mod, a, b=2)
        return torch.jit.wait(fut)
input = torch.tensor(2)
mod = Mod()
assert mod(input) == torch.jit.script(mod).forward(input)

相关用法


注:本文由纯净天空筛选整理自pytorch.org大神的英文原创作品 torch.jit.fork。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。