本文簡要介紹python語言中 torch.jit.script
的用法。
用法:
torch.jit.script(obj, optimize=None, _frames_up=0, _rcb=None, example_inputs=None)
obj(可調用,類,或
nn.Module
) -nn.Module
、函數、類類型、字典或要編譯的列表。example_inputs(聯盟[List[元組],字典[可調用,List[元組]],None]) - 提供示例輸入來注釋函數的參數或
nn.Module
.
如果
obj
是nn.Module
,則script
返回一個ScriptModule
對象。返回的ScriptModule
將具有與原始nn.Module
相同的一組 sub-modules 和參數。如果obj
是獨立函數,則將返回ScriptFunction
。如果obj
是dict
,則script
返回torch._C.ScriptDict
的實例。如果obj
是list
,則script
返回torch._C.ScriptList
的實例。編寫函數或
nn.Module
腳本將檢查源代碼,使用 TorchScript 編譯器將其編譯為 TorchScript 代碼,並返回ScriptModule
或ScriptFunction
。 TorchScript 本身是 Python 語言的子集,因此並非 Python 中的所有函數都有效,但我們提供了足夠的函數來計算張量並執行 control-dependent 操作。有關完整指南,請參閱TorchScript 語言參考。編寫字典或列表腳本會將其中的數據複製到 TorchScript 實例中,隨後可以通過引用在 Python 和 TorchScript 之間傳遞,且複製開銷為零。
torch.jit.script
可用作模塊、函數、字典和列表的函數並作為 TorchScript 類和函數的裝飾器
@torch.jit.script
。
- 編寫函數腳本
@torch.jit.script
裝飾器將通過編譯函數體構造一個ScriptFunction
。示例(編寫函數腳本):
import torch @torch.jit.script def foo(x, y): if x.max() > y.max(): r = x else: r = y return r print(type(foo)) # torch.jit.ScriptFunction # See the compiled graph as Python code print(foo.code) # Call the function using the TorchScript interpreter foo(torch.ones(2, 2), torch.ones(2, 2))
- **使用example_inputs編寫函數腳本#
示例輸入可用於注釋函數參數。
示例(在編寫腳本之前注釋函數):
import torch def test_sum(a, b): return a + b # Annotate the arguments to be int scripted_fn = torch.jit.script(test_sum, example_inputs=[(3, 4)]) print(type(scripted_fn)) # torch.jit.ScriptFunction # See the compiled graph as Python code print(scripted_fn.code) # Call the function using the TorchScript interpreter scripted_fn(20, 100)
- 編寫 nn.Module 腳本
默認情況下,編寫
nn.Module
腳本將編譯forward
方法並遞歸編譯forward
調用的任何方法、子模塊和函數。如果nn.Module
僅使用 TorchScript 支持的函數,則無需更改原始模塊代碼。script
將構造ScriptModule
,它具有原始模塊的屬性、參數和方法的副本。示例(使用參數編寫簡單模塊的腳本):
import torch class MyModule(torch.nn.Module): def __init__(self, N, M): super(MyModule, self).__init__() # This parameter will be copied to the new ScriptModule self.weight = torch.nn.Parameter(torch.rand(N, M)) # When this submodule is used, it will be compiled self.linear = torch.nn.Linear(N, M) def forward(self, input): output = self.weight.mv(input) # This calls the `forward` method of the `nn.Linear` module, which will # cause the `self.linear` submodule to be compiled to a `ScriptModule` here output = self.linear(output) return output scripted_module = torch.jit.script(MyModule(2, 3))
示例(使用跟蹤的子模塊編寫模塊腳本):
import torch import torch.nn as nn import torch.nn.functional as F class MyModule(nn.Module): def __init__(self): super(MyModule, self).__init__() # torch.jit.trace produces a ScriptModule's conv1 and conv2 self.conv1 = torch.jit.trace(nn.Conv2d(1, 20, 5), torch.rand(1, 1, 16, 16)) self.conv2 = torch.jit.trace(nn.Conv2d(20, 20, 5), torch.rand(1, 20, 16, 16)) def forward(self, input): input = F.relu(self.conv1(input)) input = F.relu(self.conv2(input)) return input scripted_module = torch.jit.script(MyModule())
要編譯
forward
以外的方法(並遞歸編譯它調用的任何內容),請將@torch.jit.export
裝飾器添加到方法中。要退出編譯,請使用@torch.jit.ignore
或@torch.jit.unused
。示例(模塊中的導出和忽略方法):
import torch import torch.nn as nn class MyModule(nn.Module): def __init__(self): super(MyModule, self).__init__() @torch.jit.export def some_entry_point(self, input): return input + 10 @torch.jit.ignore def python_only_fn(self, input): # This function won't be compiled, so any # Python APIs can be used import pdb pdb.set_trace() def forward(self, input): if self.training: self.python_only_fn(input) return input * 99 scripted_module = torch.jit.script(MyModule()) print(scripted_module.some_entry_point(torch.randn(2, 2))) print(scripted_module(torch.randn(2, 2)))
示例(使用 example_inputs 注釋 nn.Module 的前向):
import torch import torch.nn as nn from typing import NamedTuple class MyModule(NamedTuple): result: List[int] class TestNNModule(torch.nn.Module): def forward(self, a) -> MyModule: result = MyModule(result=a) return result pdt_model = TestNNModule() # Runs the pdt_model in eager model with the inputs provided and annotates the arguments of forward scripted_model = torch.jit.script(pdt_model, example_inputs={pdt_model: [([10, 20, ], ), ], }) # Run the scripted_model with actual inputs print(scripted_model([20]))
參數:
返回:
相關用法
- Python PyTorch scatter_object_list用法及代碼示例
- Python PyTorch saved_tensors_hooks用法及代碼示例
- Python PyTorch sqrt用法及代碼示例
- Python PyTorch skippable用法及代碼示例
- Python PyTorch squeeze用法及代碼示例
- Python PyTorch square用法及代碼示例
- Python PyTorch save_on_cpu用法及代碼示例
- Python PyTorch skip_init用法及代碼示例
- Python PyTorch simple_space_split用法及代碼示例
- Python PyTorch sum用法及代碼示例
- Python PyTorch sub用法及代碼示例
- Python PyTorch sparse_csr_tensor用法及代碼示例
- Python PyTorch sentencepiece_numericalizer用法及代碼示例
- Python PyTorch symeig用法及代碼示例
- Python PyTorch sinh用法及代碼示例
- Python PyTorch sinc用法及代碼示例
- Python PyTorch std_mean用法及代碼示例
- Python PyTorch spectral_norm用法及代碼示例
- Python PyTorch slogdet用法及代碼示例
- Python PyTorch symbolic_trace用法及代碼示例
- Python PyTorch shutdown用法及代碼示例
- Python PyTorch sgn用法及代碼示例
- Python PyTorch set_flush_denormal用法及代碼示例
- Python PyTorch set_default_dtype用法及代碼示例
- Python PyTorch signbit用法及代碼示例
注:本文由純淨天空篩選整理自pytorch.org大神的英文原創作品 torch.jit.script。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。