本文简要介绍python语言中 torch.jit.script
的用法。
用法:
torch.jit.script(obj, optimize=None, _frames_up=0, _rcb=None, example_inputs=None)
obj(可调用,类,或
nn.Module
) -nn.Module
、函数、类类型、字典或要编译的列表。example_inputs(联盟[List[元组],字典[可调用,List[元组]],None]) - 提供示例输入来注释函数的参数或
nn.Module
.
如果
obj
是nn.Module
,则script
返回一个ScriptModule
对象。返回的ScriptModule
将具有与原始nn.Module
相同的一组 sub-modules 和参数。如果obj
是独立函数,则将返回ScriptFunction
。如果obj
是dict
,则script
返回torch._C.ScriptDict
的实例。如果obj
是list
,则script
返回torch._C.ScriptList
的实例。编写函数或
nn.Module
脚本将检查源代码,使用 TorchScript 编译器将其编译为 TorchScript 代码,并返回ScriptModule
或ScriptFunction
。 TorchScript 本身是 Python 语言的子集,因此并非 Python 中的所有函数都有效,但我们提供了足够的函数来计算张量并执行 control-dependent 操作。有关完整指南,请参阅TorchScript 语言参考。编写字典或列表脚本会将其中的数据复制到 TorchScript 实例中,随后可以通过引用在 Python 和 TorchScript 之间传递,且复制开销为零。
torch.jit.script
可用作模块、函数、字典和列表的函数并作为 TorchScript 类和函数的装饰器
@torch.jit.script
。
- 编写函数脚本
@torch.jit.script
装饰器将通过编译函数体构造一个ScriptFunction
。示例(编写函数脚本):
import torch @torch.jit.script def foo(x, y): if x.max() > y.max(): r = x else: r = y return r print(type(foo)) # torch.jit.ScriptFunction # See the compiled graph as Python code print(foo.code) # Call the function using the TorchScript interpreter foo(torch.ones(2, 2), torch.ones(2, 2))
- **使用example_inputs编写函数脚本#
示例输入可用于注释函数参数。
示例(在编写脚本之前注释函数):
import torch def test_sum(a, b): return a + b # Annotate the arguments to be int scripted_fn = torch.jit.script(test_sum, example_inputs=[(3, 4)]) print(type(scripted_fn)) # torch.jit.ScriptFunction # See the compiled graph as Python code print(scripted_fn.code) # Call the function using the TorchScript interpreter scripted_fn(20, 100)
- 编写 nn.Module 脚本
默认情况下,编写
nn.Module
脚本将编译forward
方法并递归编译forward
调用的任何方法、子模块和函数。如果nn.Module
仅使用 TorchScript 支持的函数,则无需更改原始模块代码。script
将构造ScriptModule
,它具有原始模块的属性、参数和方法的副本。示例(使用参数编写简单模块的脚本):
import torch class MyModule(torch.nn.Module): def __init__(self, N, M): super(MyModule, self).__init__() # This parameter will be copied to the new ScriptModule self.weight = torch.nn.Parameter(torch.rand(N, M)) # When this submodule is used, it will be compiled self.linear = torch.nn.Linear(N, M) def forward(self, input): output = self.weight.mv(input) # This calls the `forward` method of the `nn.Linear` module, which will # cause the `self.linear` submodule to be compiled to a `ScriptModule` here output = self.linear(output) return output scripted_module = torch.jit.script(MyModule(2, 3))
示例(使用跟踪的子模块编写模块脚本):
import torch import torch.nn as nn import torch.nn.functional as F class MyModule(nn.Module): def __init__(self): super(MyModule, self).__init__() # torch.jit.trace produces a ScriptModule's conv1 and conv2 self.conv1 = torch.jit.trace(nn.Conv2d(1, 20, 5), torch.rand(1, 1, 16, 16)) self.conv2 = torch.jit.trace(nn.Conv2d(20, 20, 5), torch.rand(1, 20, 16, 16)) def forward(self, input): input = F.relu(self.conv1(input)) input = F.relu(self.conv2(input)) return input scripted_module = torch.jit.script(MyModule())
要编译
forward
以外的方法(并递归编译它调用的任何内容),请将@torch.jit.export
装饰器添加到方法中。要退出编译,请使用@torch.jit.ignore
或@torch.jit.unused
。示例(模块中的导出和忽略方法):
import torch import torch.nn as nn class MyModule(nn.Module): def __init__(self): super(MyModule, self).__init__() @torch.jit.export def some_entry_point(self, input): return input + 10 @torch.jit.ignore def python_only_fn(self, input): # This function won't be compiled, so any # Python APIs can be used import pdb pdb.set_trace() def forward(self, input): if self.training: self.python_only_fn(input) return input * 99 scripted_module = torch.jit.script(MyModule()) print(scripted_module.some_entry_point(torch.randn(2, 2))) print(scripted_module(torch.randn(2, 2)))
示例(使用 example_inputs 注释 nn.Module 的前向):
import torch import torch.nn as nn from typing import NamedTuple class MyModule(NamedTuple): result: List[int] class TestNNModule(torch.nn.Module): def forward(self, a) -> MyModule: result = MyModule(result=a) return result pdt_model = TestNNModule() # Runs the pdt_model in eager model with the inputs provided and annotates the arguments of forward scripted_model = torch.jit.script(pdt_model, example_inputs={pdt_model: [([10, 20, ], ), ], }) # Run the scripted_model with actual inputs print(scripted_model([20]))
参数:
返回:
相关用法
- Python PyTorch scatter_object_list用法及代码示例
- Python PyTorch saved_tensors_hooks用法及代码示例
- Python PyTorch sqrt用法及代码示例
- Python PyTorch skippable用法及代码示例
- Python PyTorch squeeze用法及代码示例
- Python PyTorch square用法及代码示例
- Python PyTorch save_on_cpu用法及代码示例
- Python PyTorch skip_init用法及代码示例
- Python PyTorch simple_space_split用法及代码示例
- Python PyTorch sum用法及代码示例
- Python PyTorch sub用法及代码示例
- Python PyTorch sparse_csr_tensor用法及代码示例
- Python PyTorch sentencepiece_numericalizer用法及代码示例
- Python PyTorch symeig用法及代码示例
- Python PyTorch sinh用法及代码示例
- Python PyTorch sinc用法及代码示例
- Python PyTorch std_mean用法及代码示例
- Python PyTorch spectral_norm用法及代码示例
- Python PyTorch slogdet用法及代码示例
- Python PyTorch symbolic_trace用法及代码示例
- Python PyTorch shutdown用法及代码示例
- Python PyTorch sgn用法及代码示例
- Python PyTorch set_flush_denormal用法及代码示例
- Python PyTorch set_default_dtype用法及代码示例
- Python PyTorch signbit用法及代码示例
注:本文由纯净天空筛选整理自pytorch.org大神的英文原创作品 torch.jit.script。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。