当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python PyTorch trace用法及代码示例


本文简要介绍python语言中 torch.jit.trace 的用法。

用法:

torch.jit.trace(func, example_inputs, optimize=None, check_trace=True, check_inputs=None, check_tolerance=1e-05, strict=True, _force_outplace=False, _module_class=None, _compilation_unit=<torch.jit.CompilationUnit object>)

参数

  • func(可调用的或者torch.nn.Module) -Python 函数或 torch.nn.Module 将与 example_inputs 一起运行。 func 参数和返回值必须是张量或(可能是嵌套的)包含张量的元组。当模块通过 torch.jit.trace 时,仅运行和跟踪 forward 方法(有关详细信息,请参阅 torch.jit.trace )。

  • example_inputs(tuple或者torch.Tensor) -在跟踪时将传递给函数的示例输入元组。假设跟踪的操作支持这些类型和形状,则可以使用不同类型和形状的输入运行生成的跟踪。 example_inputs 也可以是单个张量,在这种情况下,它会自动包装在一个元组中。

关键字参数

  • check_trace(bool, 可选的) -检查通过跟踪代码运行的相同输入是否产生相同的输出。默认值:True。例如,如果您的网络包含非确定性操作,或者您确定网络是正确的(尽管检查器失败),您可能希望禁用此函数。

  • check_inputs(元组列表,可选的) -输入参数的元组列表,用于检查跟踪是否符合预期。每个元组等效于将在 example_inputs 中指定的一组输入参数。为了获得最佳结果,请传入一组检查输入,这些输入代表您希望网络看到的形状和输入类型的空间。如果未指定,则使用原始example_inputs 进行检查

  • check_tolerance(float,可选的) -在检查程序中使用的浮点比较容差。如果由于已知原因(例如运算符融合)导致结果在数值上出现分歧,这可以用来放宽检查器的严格性。

  • strict(bool, 可选的) -是否在严格模式下运行跟踪器(默认值:True)。仅当您希望跟踪器记录您的可变容器类型(当前为 list /dict )并且您确定您在问题中使用的容器是 constant 结构并且不会被用作控制流(if,for)条件。

返回

如果 funcnn.Modulenn.Moduleforward ,则 trace 返回一个 ScriptModule 对象,其中包含一个包含跟踪代码的 forward 方法。返回的 ScriptModule 将具有与原始 nn.Module 相同的一组 sub-modules 和参数。如果 func 是独立函数,则 trace 返回 ScriptFunction

跟踪函数并返回将使用just-in-time 编译优化的可执行文件或 ScriptFunction 。对于仅在 Tensor 和列表、字典和 Tensor 的元组上运行的代码,跟踪是理想的。

使用 torch.jit.tracetorch.jit.trace_module ,您可以将现有模块或 Python 函数转换为 TorchScript ScriptFunction ScriptModule 。您必须提供示例输入,然后我们运行该函数,记录对所有张量执行的操作。

  • 独立函数的结果记录产生 ScriptFunction

  • nn.Module.forwardnn.Module 的结果记录产生 ScriptModule

该模块还包含原始模块所具有的任何参数。

警告

跟踪仅正确记录不依赖数据的函数和模块(例如,张量中的数据没有条件)并且没有任何未跟踪的外部依赖项(例如,执行输入/输出或访问全局变量)。跟踪仅记录在给定张量上运行给定函数时完成的操作。因此,返回的ScriptModule 将始终在任何输入上运行相同的跟踪图。当您的模块需要根据输入和/或模块状态运行不同的操作集时,这会产生一些重要的影响。例如,

  • 跟踪不会记录任何 control-flow,如 if-statements 或循环。当这个 control-flow 在您的模块中保持不变时,这很好,它通常会内联 control-flow 决策。但有时 control-flow 实际上是模型本身的一部分。例如,循环网络是输入序列(可能是动态的)长度上的循环。

  • 在返回的 ScriptModule 中,无论 ScriptModule 处于哪种模式,在 trainingeval 模式下具有不同行为的操作将始终像在跟踪期间所处的模式一样。

在这种情况下,跟踪将不合适, scripting 是更好的选择。如果您跟踪此类模型,您可能会在后续调用模型时默默地得到不正确的结果。当执行可能导致生成错误跟踪的操作时,跟踪器将尝试发出警告。

示例(跟踪函数):

import torch

def foo(x, y):
    return 2 * x + y

# Run `foo` with the provided inputs and record the tensor operations
traced_foo = torch.jit.trace(foo, (torch.rand(3), torch.rand(3)))

# `traced_foo` can now be run with the TorchScript interpreter or saved
# and loaded in a Python-free environment

示例(跟踪现有模块):

import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv = nn.Conv2d(1, 1, 3)

    def forward(self, x):
        return self.conv(x)

n = Net()
example_weight = torch.rand(1, 1, 3, 3)
example_forward_input = torch.rand(1, 1, 3, 3)

# Trace a specific method and construct `ScriptModule` with
# a single `forward` method
module = torch.jit.trace(n.forward, example_forward_input)

# Trace a module (implicitly traces `forward`) and construct a
# `ScriptModule` with a single `forward` method
module = torch.jit.trace(n, example_forward_input)

相关用法


注:本文由纯净天空筛选整理自pytorch.org大神的英文原创作品 torch.jit.trace。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。