当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python PyTorch script用法及代码示例


本文简要介绍python语言中 torch.jit.script 的用法。

用法:

torch.jit.script(obj, optimize=None, _frames_up=0, _rcb=None, example_inputs=None)

参数

  • obj(可调用,类,或nn.Module) -nn.Module 、函数、类类型、字典或要编译的列表。

  • example_inputs(联盟[List[元组],字典[可调用,List[元组]],None]) - 提供示例输入来注释函数的参数或nn.Module.

返回

如果 objnn.Module ,则 script 返回一个 ScriptModule 对象。返回的 ScriptModule 将具有与原始 nn.Module 相同的一组 sub-modules 和参数。如果 obj 是独立函数,则将返回 ScriptFunction 。如果 objdict ,则 script 返回 torch._C.ScriptDict 的实例。如果 objlist ,则 script 返回 torch._C.ScriptList 的实例。

编写函数或 nn.Module 脚本将检查源代码,使用 TorchScript 编译器将其编译为 TorchScript 代码,并返回 ScriptModule ScriptFunction 。 TorchScript 本身是 Python 语言的子集,因此并非 Python 中的所有函数都有效,但我们提供了足够的函数来计算张量并执行 control-dependent 操作。有关完整指南,请参阅TorchScript 语言参考。

编写字典或列表脚本会将其中的数据复制到 TorchScript 实例中,随后可以通过引用在 Python 和 TorchScript 之间传递,且复制开销为零。

torch.jit.script 可用作模块、函数、字典和列表的函数

并作为 TorchScript 类和函数的装饰器 @torch.jit.script

编写函数脚本

@torch.jit.script 装饰器将通过编译函数体构造一个 ScriptFunction

示例(编写函数脚本):

import torch

@torch.jit.script
def foo(x, y):
    if x.max() > y.max():
        r = x
    else:
        r = y
    return r

print(type(foo))  # torch.jit.ScriptFunction

# See the compiled graph as Python code
print(foo.code)

# Call the function using the TorchScript interpreter
foo(torch.ones(2, 2), torch.ones(2, 2))
**使用example_inputs编写函数脚本#

示例输入可用于注释函数参数。

示例(在编写脚本之前注释函数):

import torch

def test_sum(a, b):
    return a + b

# Annotate the arguments to be int
scripted_fn = torch.jit.script(test_sum, example_inputs=[(3, 4)])

print(type(scripted_fn))  # torch.jit.ScriptFunction

# See the compiled graph as Python code
print(scripted_fn.code)

# Call the function using the TorchScript interpreter
scripted_fn(20, 100)
编写 nn.Module 脚本

默认情况下,编写 nn.Module 脚本将编译 forward 方法并递归编译 forward 调用的任何方法、子模块和函数。如果 nn.Module 仅使用 TorchScript 支持的函数,则无需更改原始模块代码。 script 将构造 ScriptModule ,它具有原始模块的属性、参数和方法的副本。

示例(使用参数编写简单模块的脚本):

import torch

class MyModule(torch.nn.Module):
    def __init__(self, N, M):
        super(MyModule, self).__init__()
        # This parameter will be copied to the new ScriptModule
        self.weight = torch.nn.Parameter(torch.rand(N, M))

        # When this submodule is used, it will be compiled
        self.linear = torch.nn.Linear(N, M)

    def forward(self, input):
        output = self.weight.mv(input)

        # This calls the `forward` method of the `nn.Linear` module, which will
        # cause the `self.linear` submodule to be compiled to a `ScriptModule` here
        output = self.linear(output)
        return output

scripted_module = torch.jit.script(MyModule(2, 3))

示例(使用跟踪的子模块编写模块脚本):

import torch
import torch.nn as nn
import torch.nn.functional as F

class MyModule(nn.Module):
    def __init__(self):
        super(MyModule, self).__init__()
        # torch.jit.trace produces a ScriptModule's conv1 and conv2
        self.conv1 = torch.jit.trace(nn.Conv2d(1, 20, 5), torch.rand(1, 1, 16, 16))
        self.conv2 = torch.jit.trace(nn.Conv2d(20, 20, 5), torch.rand(1, 20, 16, 16))

    def forward(self, input):
        input = F.relu(self.conv1(input))
        input = F.relu(self.conv2(input))
        return input

scripted_module = torch.jit.script(MyModule())

要编译 forward 以外的方法(并递归编译它调用的任何内容),请将 @torch.jit.export 装饰器添加到方法中。要退出编译,请使用 @torch.jit.ignore @torch.jit.unused

示例(模块中的导出和忽略方法):

import torch
import torch.nn as nn

class MyModule(nn.Module):
    def __init__(self):
        super(MyModule, self).__init__()

    @torch.jit.export
    def some_entry_point(self, input):
        return input + 10

    @torch.jit.ignore
    def python_only_fn(self, input):
        # This function won't be compiled, so any
        # Python APIs can be used
        import pdb
        pdb.set_trace()

    def forward(self, input):
        if self.training:
            self.python_only_fn(input)
        return input * 99

scripted_module = torch.jit.script(MyModule())
print(scripted_module.some_entry_point(torch.randn(2, 2)))
print(scripted_module(torch.randn(2, 2)))

示例(使用 example_inputs 注释 nn.Module 的前向):

import torch
import torch.nn as nn
from typing import NamedTuple

class MyModule(NamedTuple):
result: List[int]

class TestNNModule(torch.nn.Module):
    def forward(self, a) -> MyModule:
        result = MyModule(result=a)
        return result

pdt_model = TestNNModule()

# Runs the pdt_model in eager model with the inputs provided and annotates the arguments of forward
scripted_model = torch.jit.script(pdt_model, example_inputs={pdt_model: [([10, 20, ], ), ], })

# Run the scripted_model with actual inputs
print(scripted_model([20]))

相关用法


注:本文由纯净天空筛选整理自pytorch.org大神的英文原创作品 torch.jit.script。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。