當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python PyTorch script用法及代碼示例


本文簡要介紹python語言中 torch.jit.script 的用法。

用法:

torch.jit.script(obj, optimize=None, _frames_up=0, _rcb=None, example_inputs=None)

參數

  • obj(可調用,類,或nn.Module) -nn.Module 、函數、類類型、字典或要編譯的列表。

  • example_inputs(聯盟[List[元組],字典[可調用,List[元組]],None]) - 提供示例輸入來注釋函數的參數或nn.Module.

返回

如果 objnn.Module ,則 script 返回一個 ScriptModule 對象。返回的 ScriptModule 將具有與原始 nn.Module 相同的一組 sub-modules 和參數。如果 obj 是獨立函數,則將返回 ScriptFunction 。如果 objdict ,則 script 返回 torch._C.ScriptDict 的實例。如果 objlist ,則 script 返回 torch._C.ScriptList 的實例。

編寫函數或 nn.Module 腳本將檢查源代碼,使用 TorchScript 編譯器將其編譯為 TorchScript 代碼,並返回 ScriptModule ScriptFunction 。 TorchScript 本身是 Python 語言的子集,因此並非 Python 中的所有函數都有效,但我們提供了足夠的函數來計算張量並執行 control-dependent 操作。有關完整指南,請參閱TorchScript 語言參考。

編寫字典或列表腳本會將其中的數據複製到 TorchScript 實例中,隨後可以通過引用在 Python 和 TorchScript 之間傳遞,且複製開銷為零。

torch.jit.script 可用作模塊、函數、字典和列表的函數

並作為 TorchScript 類和函數的裝飾器 @torch.jit.script

編寫函數腳本

@torch.jit.script 裝飾器將通過編譯函數體構造一個 ScriptFunction

示例(編寫函數腳本):

import torch

@torch.jit.script
def foo(x, y):
    if x.max() > y.max():
        r = x
    else:
        r = y
    return r

print(type(foo))  # torch.jit.ScriptFunction

# See the compiled graph as Python code
print(foo.code)

# Call the function using the TorchScript interpreter
foo(torch.ones(2, 2), torch.ones(2, 2))
**使用example_inputs編寫函數腳本#

示例輸入可用於注釋函數參數。

示例(在編寫腳本之前注釋函數):

import torch

def test_sum(a, b):
    return a + b

# Annotate the arguments to be int
scripted_fn = torch.jit.script(test_sum, example_inputs=[(3, 4)])

print(type(scripted_fn))  # torch.jit.ScriptFunction

# See the compiled graph as Python code
print(scripted_fn.code)

# Call the function using the TorchScript interpreter
scripted_fn(20, 100)
編寫 nn.Module 腳本

默認情況下,編寫 nn.Module 腳本將編譯 forward 方法並遞歸編譯 forward 調用的任何方法、子模塊和函數。如果 nn.Module 僅使用 TorchScript 支持的函數,則無需更改原始模塊代碼。 script 將構造 ScriptModule ,它具有原始模塊的屬性、參數和方法的副本。

示例(使用參數編寫簡單模塊的腳本):

import torch

class MyModule(torch.nn.Module):
    def __init__(self, N, M):
        super(MyModule, self).__init__()
        # This parameter will be copied to the new ScriptModule
        self.weight = torch.nn.Parameter(torch.rand(N, M))

        # When this submodule is used, it will be compiled
        self.linear = torch.nn.Linear(N, M)

    def forward(self, input):
        output = self.weight.mv(input)

        # This calls the `forward` method of the `nn.Linear` module, which will
        # cause the `self.linear` submodule to be compiled to a `ScriptModule` here
        output = self.linear(output)
        return output

scripted_module = torch.jit.script(MyModule(2, 3))

示例(使用跟蹤的子模塊編寫模塊腳本):

import torch
import torch.nn as nn
import torch.nn.functional as F

class MyModule(nn.Module):
    def __init__(self):
        super(MyModule, self).__init__()
        # torch.jit.trace produces a ScriptModule's conv1 and conv2
        self.conv1 = torch.jit.trace(nn.Conv2d(1, 20, 5), torch.rand(1, 1, 16, 16))
        self.conv2 = torch.jit.trace(nn.Conv2d(20, 20, 5), torch.rand(1, 20, 16, 16))

    def forward(self, input):
        input = F.relu(self.conv1(input))
        input = F.relu(self.conv2(input))
        return input

scripted_module = torch.jit.script(MyModule())

要編譯 forward 以外的方法(並遞歸編譯它調用的任何內容),請將 @torch.jit.export 裝飾器添加到方法中。要退出編譯,請使用 @torch.jit.ignore @torch.jit.unused

示例(模塊中的導出和忽略方法):

import torch
import torch.nn as nn

class MyModule(nn.Module):
    def __init__(self):
        super(MyModule, self).__init__()

    @torch.jit.export
    def some_entry_point(self, input):
        return input + 10

    @torch.jit.ignore
    def python_only_fn(self, input):
        # This function won't be compiled, so any
        # Python APIs can be used
        import pdb
        pdb.set_trace()

    def forward(self, input):
        if self.training:
            self.python_only_fn(input)
        return input * 99

scripted_module = torch.jit.script(MyModule())
print(scripted_module.some_entry_point(torch.randn(2, 2)))
print(scripted_module(torch.randn(2, 2)))

示例(使用 example_inputs 注釋 nn.Module 的前向):

import torch
import torch.nn as nn
from typing import NamedTuple

class MyModule(NamedTuple):
result: List[int]

class TestNNModule(torch.nn.Module):
    def forward(self, a) -> MyModule:
        result = MyModule(result=a)
        return result

pdt_model = TestNNModule()

# Runs the pdt_model in eager model with the inputs provided and annotates the arguments of forward
scripted_model = torch.jit.script(pdt_model, example_inputs={pdt_model: [([10, 20, ], ), ], })

# Run the scripted_model with actual inputs
print(scripted_model([20]))

相關用法


注:本文由純淨天空篩選整理自pytorch.org大神的英文原創作品 torch.jit.script。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。