本文簡要介紹python語言中 torch.jit.trace
的用法。
用法:
torch.jit.trace(func, example_inputs, optimize=None, check_trace=True, check_inputs=None, check_tolerance=1e-05, strict=True, _force_outplace=False, _module_class=None, _compilation_unit=<torch.jit.CompilationUnit object>)
func(可調用的或者torch.nn.Module) -Python 函數或
torch.nn.Module
將與example_inputs
一起運行。func
參數和返回值必須是張量或(可能是嵌套的)包含張量的元組。當模塊通過torch.jit.trace
時,僅運行和跟蹤forward
方法(有關詳細信息,請參閱torch.jit.trace
)。example_inputs(tuple或者torch.Tensor) -在跟蹤時將傳遞給函數的示例輸入元組。假設跟蹤的操作支持這些類型和形狀,則可以使用不同類型和形狀的輸入運行生成的跟蹤。
example_inputs
也可以是單個張量,在這種情況下,它會自動包裝在一個元組中。
check_trace(
bool
, 可選的) -檢查通過跟蹤代碼運行的相同輸入是否產生相同的輸出。默認值:True
。例如,如果您的網絡包含非確定性操作,或者您確定網絡是正確的(盡管檢查器失敗),您可能希望禁用此函數。check_inputs(元組列表,可選的) -輸入參數的元組列表,用於檢查跟蹤是否符合預期。每個元組等效於將在
example_inputs
中指定的一組輸入參數。為了獲得最佳結果,請傳入一組檢查輸入,這些輸入代表您希望網絡看到的形狀和輸入類型的空間。如果未指定,則使用原始example_inputs
進行檢查check_tolerance(float,可選的) -在檢查程序中使用的浮點比較容差。如果由於已知原因(例如運算符融合)導致結果在數值上出現分歧,這可以用來放寬檢查器的嚴格性。
strict(
bool
, 可選的) -是否在嚴格模式下運行跟蹤器(默認值:True
)。僅當您希望跟蹤器記錄您的可變容器類型(當前為list
/dict
)並且您確定您在問題中使用的容器是constant
結構並且不會被用作控製流(if,for)條件。
如果
func
是nn.Module
或nn.Module
的forward
,則trace
返回一個ScriptModule
對象,其中包含一個包含跟蹤代碼的forward
方法。返回的ScriptModule
將具有與原始nn.Module
相同的一組 sub-modules 和參數。如果func
是獨立函數,則trace
返回ScriptFunction
。跟蹤函數並返回將使用just-in-time 編譯優化的可執行文件或
ScriptFunction
。對於僅在Tensor
和列表、字典和Tensor
的元組上運行的代碼,跟蹤是理想的。使用
torch.jit.trace
和torch.jit.trace_module
,您可以將現有模塊或 Python 函數轉換為 TorchScriptScriptFunction
或ScriptModule
。您必須提供示例輸入,然後我們運行該函數,記錄對所有張量執行的操作。獨立函數的結果記錄產生
ScriptFunction
。nn.Module.forward
或nn.Module
的結果記錄產生ScriptModule
。
該模塊還包含原始模塊所具有的任何參數。
警告
跟蹤僅正確記錄不依賴數據的函數和模塊(例如,張量中的數據沒有條件)並且沒有任何未跟蹤的外部依賴項(例如,執行輸入/輸出或訪問全局變量)。跟蹤僅記錄在給定張量上運行給定函數時完成的操作。因此,返回的
ScriptModule
將始終在任何輸入上運行相同的跟蹤圖。當您的模塊需要根據輸入和/或模塊狀態運行不同的操作集時,這會產生一些重要的影響。例如,跟蹤不會記錄任何 control-flow,如 if-statements 或循環。當這個 control-flow 在您的模塊中保持不變時,這很好,它通常會內聯 control-flow 決策。但有時 control-flow 實際上是模型本身的一部分。例如,循環網絡是輸入序列(可能是動態的)長度上的循環。
在返回的
ScriptModule
中,無論ScriptModule
處於哪種模式,在training
和eval
模式下具有不同行為的操作將始終像在跟蹤期間所處的模式一樣。
在這種情況下,跟蹤將不合適,
scripting
是更好的選擇。如果您跟蹤此類模型,您可能會在後續調用模型時默默地得到不正確的結果。當執行可能導致生成錯誤跟蹤的操作時,跟蹤器將嘗試發出警告。示例(跟蹤函數):
import torch def foo(x, y): return 2 * x + y # Run `foo` with the provided inputs and record the tensor operations traced_foo = torch.jit.trace(foo, (torch.rand(3), torch.rand(3))) # `traced_foo` can now be run with the TorchScript interpreter or saved # and loaded in a Python-free environment
示例(跟蹤現有模塊):
import torch import torch.nn as nn class Net(nn.Module): def __init__(self): super(Net, self).__init__() self.conv = nn.Conv2d(1, 1, 3) def forward(self, x): return self.conv(x) n = Net() example_weight = torch.rand(1, 1, 3, 3) example_forward_input = torch.rand(1, 1, 3, 3) # Trace a specific method and construct `ScriptModule` with # a single `forward` method module = torch.jit.trace(n.forward, example_forward_input) # Trace a module (implicitly traces `forward`) and construct a # `ScriptModule` with a single `forward` method module = torch.jit.trace(n, example_forward_input)
參數:
關鍵字參數:
返回:
相關用法
- Python PyTorch trace_module用法及代碼示例
- Python PyTorch trace用法及代碼示例
- Python PyTorch transpose用法及代碼示例
- Python PyTorch trapezoid用法及代碼示例
- Python PyTorch trunc用法及代碼示例
- Python PyTorch triu_indices用法及代碼示例
- Python PyTorch triangular_solve用法及代碼示例
- Python PyTorch tril_indices用法及代碼示例
- Python PyTorch tril用法及代碼示例
- Python PyTorch triu用法及代碼示例
- Python PyTorch tensorinv用法及代碼示例
- Python PyTorch tensor用法及代碼示例
- Python PyTorch to_map_style_dataset用法及代碼示例
- Python PyTorch topk用法及代碼示例
- Python PyTorch tensorsolve用法及代碼示例
- Python PyTorch tile用法及代碼示例
- Python PyTorch tanh用法及代碼示例
- Python PyTorch take_along_dim用法及代碼示例
- Python PyTorch tensor_split用法及代碼示例
- Python PyTorch t用法及代碼示例
- Python PyTorch take用法及代碼示例
- Python PyTorch tensordot用法及代碼示例
- Python PyTorch tan用法及代碼示例
- Python PyTorch frexp用法及代碼示例
- Python PyTorch jvp用法及代碼示例
注:本文由純淨天空篩選整理自pytorch.org大神的英文原創作品 torch.jit.trace。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。