本文简要介绍python语言中 torch.jit.trace
的用法。
用法:
torch.jit.trace(func, example_inputs, optimize=None, check_trace=True, check_inputs=None, check_tolerance=1e-05, strict=True, _force_outplace=False, _module_class=None, _compilation_unit=<torch.jit.CompilationUnit object>)
func(可调用的或者torch.nn.Module) -Python 函数或
torch.nn.Module
将与example_inputs
一起运行。func
参数和返回值必须是张量或(可能是嵌套的)包含张量的元组。当模块通过torch.jit.trace
时,仅运行和跟踪forward
方法(有关详细信息,请参阅torch.jit.trace
)。example_inputs(tuple或者torch.Tensor) -在跟踪时将传递给函数的示例输入元组。假设跟踪的操作支持这些类型和形状,则可以使用不同类型和形状的输入运行生成的跟踪。
example_inputs
也可以是单个张量,在这种情况下,它会自动包装在一个元组中。
check_trace(
bool
, 可选的) -检查通过跟踪代码运行的相同输入是否产生相同的输出。默认值:True
。例如,如果您的网络包含非确定性操作,或者您确定网络是正确的(尽管检查器失败),您可能希望禁用此函数。check_inputs(元组列表,可选的) -输入参数的元组列表,用于检查跟踪是否符合预期。每个元组等效于将在
example_inputs
中指定的一组输入参数。为了获得最佳结果,请传入一组检查输入,这些输入代表您希望网络看到的形状和输入类型的空间。如果未指定,则使用原始example_inputs
进行检查check_tolerance(float,可选的) -在检查程序中使用的浮点比较容差。如果由于已知原因(例如运算符融合)导致结果在数值上出现分歧,这可以用来放宽检查器的严格性。
strict(
bool
, 可选的) -是否在严格模式下运行跟踪器(默认值:True
)。仅当您希望跟踪器记录您的可变容器类型(当前为list
/dict
)并且您确定您在问题中使用的容器是constant
结构并且不会被用作控制流(if,for)条件。
如果
func
是nn.Module
或nn.Module
的forward
,则trace
返回一个ScriptModule
对象,其中包含一个包含跟踪代码的forward
方法。返回的ScriptModule
将具有与原始nn.Module
相同的一组 sub-modules 和参数。如果func
是独立函数,则trace
返回ScriptFunction
。跟踪函数并返回将使用just-in-time 编译优化的可执行文件或
ScriptFunction
。对于仅在Tensor
和列表、字典和Tensor
的元组上运行的代码,跟踪是理想的。使用
torch.jit.trace
和torch.jit.trace_module
,您可以将现有模块或 Python 函数转换为 TorchScriptScriptFunction
或ScriptModule
。您必须提供示例输入,然后我们运行该函数,记录对所有张量执行的操作。独立函数的结果记录产生
ScriptFunction
。nn.Module.forward
或nn.Module
的结果记录产生ScriptModule
。
该模块还包含原始模块所具有的任何参数。
警告
跟踪仅正确记录不依赖数据的函数和模块(例如,张量中的数据没有条件)并且没有任何未跟踪的外部依赖项(例如,执行输入/输出或访问全局变量)。跟踪仅记录在给定张量上运行给定函数时完成的操作。因此,返回的
ScriptModule
将始终在任何输入上运行相同的跟踪图。当您的模块需要根据输入和/或模块状态运行不同的操作集时,这会产生一些重要的影响。例如,跟踪不会记录任何 control-flow,如 if-statements 或循环。当这个 control-flow 在您的模块中保持不变时,这很好,它通常会内联 control-flow 决策。但有时 control-flow 实际上是模型本身的一部分。例如,循环网络是输入序列(可能是动态的)长度上的循环。
在返回的
ScriptModule
中,无论ScriptModule
处于哪种模式,在training
和eval
模式下具有不同行为的操作将始终像在跟踪期间所处的模式一样。
在这种情况下,跟踪将不合适,
scripting
是更好的选择。如果您跟踪此类模型,您可能会在后续调用模型时默默地得到不正确的结果。当执行可能导致生成错误跟踪的操作时,跟踪器将尝试发出警告。示例(跟踪函数):
import torch def foo(x, y): return 2 * x + y # Run `foo` with the provided inputs and record the tensor operations traced_foo = torch.jit.trace(foo, (torch.rand(3), torch.rand(3))) # `traced_foo` can now be run with the TorchScript interpreter or saved # and loaded in a Python-free environment
示例(跟踪现有模块):
import torch import torch.nn as nn class Net(nn.Module): def __init__(self): super(Net, self).__init__() self.conv = nn.Conv2d(1, 1, 3) def forward(self, x): return self.conv(x) n = Net() example_weight = torch.rand(1, 1, 3, 3) example_forward_input = torch.rand(1, 1, 3, 3) # Trace a specific method and construct `ScriptModule` with # a single `forward` method module = torch.jit.trace(n.forward, example_forward_input) # Trace a module (implicitly traces `forward`) and construct a # `ScriptModule` with a single `forward` method module = torch.jit.trace(n, example_forward_input)
参数:
关键字参数:
返回:
相关用法
- Python PyTorch trace_module用法及代码示例
- Python PyTorch trace用法及代码示例
- Python PyTorch transpose用法及代码示例
- Python PyTorch trapezoid用法及代码示例
- Python PyTorch trunc用法及代码示例
- Python PyTorch triu_indices用法及代码示例
- Python PyTorch triangular_solve用法及代码示例
- Python PyTorch tril_indices用法及代码示例
- Python PyTorch tril用法及代码示例
- Python PyTorch triu用法及代码示例
- Python PyTorch tensorinv用法及代码示例
- Python PyTorch tensor用法及代码示例
- Python PyTorch to_map_style_dataset用法及代码示例
- Python PyTorch topk用法及代码示例
- Python PyTorch tensorsolve用法及代码示例
- Python PyTorch tile用法及代码示例
- Python PyTorch tanh用法及代码示例
- Python PyTorch take_along_dim用法及代码示例
- Python PyTorch tensor_split用法及代码示例
- Python PyTorch t用法及代码示例
- Python PyTorch take用法及代码示例
- Python PyTorch tensordot用法及代码示例
- Python PyTorch tan用法及代码示例
- Python PyTorch frexp用法及代码示例
- Python PyTorch jvp用法及代码示例
注:本文由纯净天空筛选整理自pytorch.org大神的英文原创作品 torch.jit.trace。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。