本文简要介绍python语言中 torch.jit.trace_module
的用法。
用法:
torch.jit.trace_module(mod, inputs, optimize=None, check_trace=True, check_inputs=None, check_tolerance=1e-05, strict=True, _force_outplace=False, _module_class=None, _compilation_unit=<torch.jit.CompilationUnit object>)
mod(torch.nn.Module) -
torch.nn.Module
包含名称在inputs
中指定的方法。给定的方法将被编译为单个ScriptModule
的一部分。inputs(dict) -包含由
mod
中的方法名称索引的示例输入的字典。输入将在跟踪时传递给名称对应于输入键的方法。{ 'forward' : example_forward_input, 'method2': example_method2_input}
check_trace(
bool
, 可选的) -检查通过跟踪代码运行的相同输入是否产生相同的输出。默认值:True
。例如,如果您的网络包含非确定性操作,或者您确定网络是正确的(尽管检查器失败),您可能希望禁用此函数。check_inputs(字典列表,可选的) -输入参数的字典列表,应用于检查跟踪是否符合预期。每个元组等效于将在
inputs
中指定的一组输入参数。为了获得最佳结果,请传入一组检查输入,这些输入代表您希望网络看到的形状和输入类型的空间。如果未指定,则使用原始inputs
进行检查check_tolerance(float,可选的) -在检查程序中使用的浮点比较容差。如果由于已知原因(例如运算符融合)导致结果在数值上出现分歧,这可以用来放宽检查器的严格性。
一个
ScriptModule
对象,带有一个包含跟踪代码的forward
方法。当func
是torch.nn.Module
时,返回的ScriptModule
将具有与func
相同的一组 sub-modules 和参数。跟踪一个模块并返回一个可执行文件
ScriptModule
,它将使用just-in-time 编译进行优化。将模块传递给torch.jit.trace
时,仅运行和跟踪forward
方法。使用trace_module
,您可以将方法名称字典指定为示例输入以跟踪以下参数(请参阅inputs
)。有关跟踪的更多信息,请参阅
torch.jit.trace
。示例(使用多种方法跟踪模块):
import torch import torch.nn as nn class Net(nn.Module): def __init__(self): super(Net, self).__init__() self.conv = nn.Conv2d(1, 1, 3) def forward(self, x): return self.conv(x) def weighted_kernel_sum(self, weight): return weight * self.conv.weight n = Net() example_weight = torch.rand(1, 1, 3, 3) example_forward_input = torch.rand(1, 1, 3, 3) # Trace a specific method and construct `ScriptModule` with # a single `forward` method module = torch.jit.trace(n.forward, example_forward_input) # Trace a module (implicitly traces `forward`) and construct a # `ScriptModule` with a single `forward` method module = torch.jit.trace(n, example_forward_input) # Trace specific methods on a module (specified in `inputs`), constructs # a `ScriptModule` with `forward` and `weighted_kernel_sum` methods inputs = {'forward' : example_forward_input, 'weighted_kernel_sum' : example_weight} module = torch.jit.trace_module(n, inputs)
参数:
关键字参数:
返回:
相关用法
- Python PyTorch trace用法及代码示例
- Python PyTorch transpose用法及代码示例
- Python PyTorch trapezoid用法及代码示例
- Python PyTorch trunc用法及代码示例
- Python PyTorch triu_indices用法及代码示例
- Python PyTorch triangular_solve用法及代码示例
- Python PyTorch tril_indices用法及代码示例
- Python PyTorch tril用法及代码示例
- Python PyTorch triu用法及代码示例
- Python PyTorch tensorinv用法及代码示例
- Python PyTorch tensor用法及代码示例
- Python PyTorch to_map_style_dataset用法及代码示例
- Python PyTorch topk用法及代码示例
- Python PyTorch tensorsolve用法及代码示例
- Python PyTorch tile用法及代码示例
- Python PyTorch tanh用法及代码示例
- Python PyTorch take_along_dim用法及代码示例
- Python PyTorch tensor_split用法及代码示例
- Python PyTorch t用法及代码示例
- Python PyTorch take用法及代码示例
- Python PyTorch tensordot用法及代码示例
- Python PyTorch tan用法及代码示例
- Python PyTorch frexp用法及代码示例
- Python PyTorch jvp用法及代码示例
- Python PyTorch cholesky用法及代码示例
注:本文由纯净天空筛选整理自pytorch.org大神的英文原创作品 torch.jit.trace_module。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。