當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python PyTorch trace用法及代碼示例


本文簡要介紹python語言中 torch.jit.trace 的用法。

用法:

torch.jit.trace(func, example_inputs, optimize=None, check_trace=True, check_inputs=None, check_tolerance=1e-05, strict=True, _force_outplace=False, _module_class=None, _compilation_unit=<torch.jit.CompilationUnit object>)

參數

  • func(可調用的或者torch.nn.Module) -Python 函數或 torch.nn.Module 將與 example_inputs 一起運行。 func 參數和返回值必須是張量或(可能是嵌套的)包含張量的元組。當模塊通過 torch.jit.trace 時,僅運行和跟蹤 forward 方法(有關詳細信息,請參閱 torch.jit.trace )。

  • example_inputs(tuple或者torch.Tensor) -在跟蹤時將傳遞給函數的示例輸入元組。假設跟蹤的操作支持這些類型和形狀,則可以使用不同類型和形狀的輸入運行生成的跟蹤。 example_inputs 也可以是單個張量,在這種情況下,它會自動包裝在一個元組中。

關鍵字參數

  • check_trace(bool, 可選的) -檢查通過跟蹤代碼運行的相同輸入是否產生相同的輸出。默認值:True。例如,如果您的網絡包含非確定性操作,或者您確定網絡是正確的(盡管檢查器失敗),您可能希望禁用此函數。

  • check_inputs(元組列表,可選的) -輸入參數的元組列表,用於檢查跟蹤是否符合預期。每個元組等效於將在 example_inputs 中指定的一組輸入參數。為了獲得最佳結果,請傳入一組檢查輸入,這些輸入代表您希望網絡看到的形狀和輸入類型的空間。如果未指定,則使用原始example_inputs 進行檢查

  • check_tolerance(float,可選的) -在檢查程序中使用的浮點比較容差。如果由於已知原因(例如運算符融合)導致結果在數值上出現分歧,這可以用來放寬檢查器的嚴格性。

  • strict(bool, 可選的) -是否在嚴格模式下運行跟蹤器(默認值:True)。僅當您希望跟蹤器記錄您的可變容器類型(當前為 list /dict )並且您確定您在問題中使用的容器是 constant 結構並且不會被用作控製流(if,for)條件。

返回

如果 funcnn.Modulenn.Moduleforward ,則 trace 返回一個 ScriptModule 對象,其中包含一個包含跟蹤代碼的 forward 方法。返回的 ScriptModule 將具有與原始 nn.Module 相同的一組 sub-modules 和參數。如果 func 是獨立函數,則 trace 返回 ScriptFunction

跟蹤函數並返回將使用just-in-time 編譯優化的可執行文件或 ScriptFunction 。對於僅在 Tensor 和列表、字典和 Tensor 的元組上運行的代碼,跟蹤是理想的。

使用 torch.jit.tracetorch.jit.trace_module ,您可以將現有模塊或 Python 函數轉換為 TorchScript ScriptFunction ScriptModule 。您必須提供示例輸入,然後我們運行該函數,記錄對所有張量執行的操作。

  • 獨立函數的結果記錄產生 ScriptFunction

  • nn.Module.forwardnn.Module 的結果記錄產生 ScriptModule

該模塊還包含原始模塊所具有的任何參數。

警告

跟蹤僅正確記錄不依賴數據的函數和模塊(例如,張量中的數據沒有條件)並且沒有任何未跟蹤的外部依賴項(例如,執行輸入/輸出或訪問全局變量)。跟蹤僅記錄在給定張量上運行給定函數時完成的操作。因此,返回的ScriptModule 將始終在任何輸入上運行相同的跟蹤圖。當您的模塊需要根據輸入和/或模塊狀態運行不同的操作集時,這會產生一些重要的影響。例如,

  • 跟蹤不會記錄任何 control-flow,如 if-statements 或循環。當這個 control-flow 在您的模塊中保持不變時,這很好,它通常會內聯 control-flow 決策。但有時 control-flow 實際上是模型本身的一部分。例如,循環網絡是輸入序列(可能是動態的)長度上的循環。

  • 在返回的 ScriptModule 中,無論 ScriptModule 處於哪種模式,在 trainingeval 模式下具有不同行為的操作將始終像在跟蹤期間所處的模式一樣。

在這種情況下,跟蹤將不合適, scripting 是更好的選擇。如果您跟蹤此類模型,您可能會在後續調用模型時默默地得到不正確的結果。當執行可能導致生成錯誤跟蹤的操作時,跟蹤器將嘗試發出警告。

示例(跟蹤函數):

import torch

def foo(x, y):
    return 2 * x + y

# Run `foo` with the provided inputs and record the tensor operations
traced_foo = torch.jit.trace(foo, (torch.rand(3), torch.rand(3)))

# `traced_foo` can now be run with the TorchScript interpreter or saved
# and loaded in a Python-free environment

示例(跟蹤現有模塊):

import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv = nn.Conv2d(1, 1, 3)

    def forward(self, x):
        return self.conv(x)

n = Net()
example_weight = torch.rand(1, 1, 3, 3)
example_forward_input = torch.rand(1, 1, 3, 3)

# Trace a specific method and construct `ScriptModule` with
# a single `forward` method
module = torch.jit.trace(n.forward, example_forward_input)

# Trace a module (implicitly traces `forward`) and construct a
# `ScriptModule` with a single `forward` method
module = torch.jit.trace(n, example_forward_input)

相關用法


注:本文由純淨天空篩選整理自pytorch.org大神的英文原創作品 torch.jit.trace。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。