本文簡要介紹python語言中 torch.jit.trace_module
的用法。
用法:
torch.jit.trace_module(mod, inputs, optimize=None, check_trace=True, check_inputs=None, check_tolerance=1e-05, strict=True, _force_outplace=False, _module_class=None, _compilation_unit=<torch.jit.CompilationUnit object>)
mod(torch.nn.Module) -
torch.nn.Module
包含名稱在inputs
中指定的方法。給定的方法將被編譯為單個ScriptModule
的一部分。inputs(dict) -包含由
mod
中的方法名稱索引的示例輸入的字典。輸入將在跟蹤時傳遞給名稱對應於輸入鍵的方法。{ 'forward' : example_forward_input, 'method2': example_method2_input}
check_trace(
bool
, 可選的) -檢查通過跟蹤代碼運行的相同輸入是否產生相同的輸出。默認值:True
。例如,如果您的網絡包含非確定性操作,或者您確定網絡是正確的(盡管檢查器失敗),您可能希望禁用此函數。check_inputs(字典列表,可選的) -輸入參數的字典列表,應用於檢查跟蹤是否符合預期。每個元組等效於將在
inputs
中指定的一組輸入參數。為了獲得最佳結果,請傳入一組檢查輸入,這些輸入代表您希望網絡看到的形狀和輸入類型的空間。如果未指定,則使用原始inputs
進行檢查check_tolerance(float,可選的) -在檢查程序中使用的浮點比較容差。如果由於已知原因(例如運算符融合)導致結果在數值上出現分歧,這可以用來放寬檢查器的嚴格性。
一個
ScriptModule
對象,帶有一個包含跟蹤代碼的forward
方法。當func
是torch.nn.Module
時,返回的ScriptModule
將具有與func
相同的一組 sub-modules 和參數。跟蹤一個模塊並返回一個可執行文件
ScriptModule
,它將使用just-in-time 編譯進行優化。將模塊傳遞給torch.jit.trace
時,僅運行和跟蹤forward
方法。使用trace_module
,您可以將方法名稱字典指定為示例輸入以跟蹤以下參數(請參閱inputs
)。有關跟蹤的更多信息,請參閱
torch.jit.trace
。示例(使用多種方法跟蹤模塊):
import torch import torch.nn as nn class Net(nn.Module): def __init__(self): super(Net, self).__init__() self.conv = nn.Conv2d(1, 1, 3) def forward(self, x): return self.conv(x) def weighted_kernel_sum(self, weight): return weight * self.conv.weight n = Net() example_weight = torch.rand(1, 1, 3, 3) example_forward_input = torch.rand(1, 1, 3, 3) # Trace a specific method and construct `ScriptModule` with # a single `forward` method module = torch.jit.trace(n.forward, example_forward_input) # Trace a module (implicitly traces `forward`) and construct a # `ScriptModule` with a single `forward` method module = torch.jit.trace(n, example_forward_input) # Trace specific methods on a module (specified in `inputs`), constructs # a `ScriptModule` with `forward` and `weighted_kernel_sum` methods inputs = {'forward' : example_forward_input, 'weighted_kernel_sum' : example_weight} module = torch.jit.trace_module(n, inputs)
參數:
關鍵字參數:
返回:
相關用法
- Python PyTorch trace用法及代碼示例
- Python PyTorch transpose用法及代碼示例
- Python PyTorch trapezoid用法及代碼示例
- Python PyTorch trunc用法及代碼示例
- Python PyTorch triu_indices用法及代碼示例
- Python PyTorch triangular_solve用法及代碼示例
- Python PyTorch tril_indices用法及代碼示例
- Python PyTorch tril用法及代碼示例
- Python PyTorch triu用法及代碼示例
- Python PyTorch tensorinv用法及代碼示例
- Python PyTorch tensor用法及代碼示例
- Python PyTorch to_map_style_dataset用法及代碼示例
- Python PyTorch topk用法及代碼示例
- Python PyTorch tensorsolve用法及代碼示例
- Python PyTorch tile用法及代碼示例
- Python PyTorch tanh用法及代碼示例
- Python PyTorch take_along_dim用法及代碼示例
- Python PyTorch tensor_split用法及代碼示例
- Python PyTorch t用法及代碼示例
- Python PyTorch take用法及代碼示例
- Python PyTorch tensordot用法及代碼示例
- Python PyTorch tan用法及代碼示例
- Python PyTorch frexp用法及代碼示例
- Python PyTorch jvp用法及代碼示例
- Python PyTorch cholesky用法及代碼示例
注:本文由純淨天空篩選整理自pytorch.org大神的英文原創作品 torch.jit.trace_module。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。