当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python PyTorch export用法及代码示例


本文简要介绍python语言中 torch.onnx.export 的用法。

用法:

torch.onnx.export(model, args, f, export_params=True, verbose=False, training=<TrainingMode.EVAL: 0>, input_names=None, output_names=None, operator_export_type=None, opset_version=None, _retain_param_name=None, do_constant_folding=True, example_outputs=None, strip_doc_string=None, dynamic_axes=None, keep_initializers_as_inputs=None, custom_opsets=None, enable_onnx_checker=None, use_external_data_format=None)

参数

  • model(torch.nn.Module,torch.jit.ScriptModule或者torch.jit.ScriptFunction) -要导出的模型。

  • args(tuple或者torch.Tensor) -

    args 可以构造为:

    1. 只有一个元组的参数:

      args = (x, y, z)

    元组应包含模型输入,以便 model(*args) 是模型的有效调用。任何非张量参数都将 hard-coded 进入导出模型;任何张量参数都将成为导出模型的输入,按照它们在元组中出现的顺序。

    1. 张量:

      args = torch.Tensor([1])

    这相当于该张量的一元元组。

    1. 以命名参数字典结尾的参数元组:

      args = (x,
              {'y': input_y,
               'z': input_z})

    元组中除了最后一个元素之外的所有元素都将作为非关键字参数传递,命名参数将从最后一个元素开始设置。如果字典中不存在命名参数,则为其分配默认值,如果未提供默认值,则为 None。

    注意

    如果字典是 args 元组的最后一个元素,它将被解释为包含命名参数。为了将 dict 作为最后一个非关键字 arg 传递,请提供一个空 dict 作为 args 元组的最后一个元素。例如,而不是:

    torch.onnx.export(
        model,
        (x,
         # WRONG: will be interpreted as named arguments
         {y: z}),
        "test.onnx.pb")

    写:

    torch.onnx.export(
        model,
        (x,
         {y: z},
         {}),
        "test.onnx.pb")
  • f-file-like 对象(这样 f.fileno() 返回文件说明符)或包含文件名的字符串。二进制协议缓冲区将写入此文件。

  • export_params(bool,默认真) -如果为 True,则将导出所有参数。如果要导出未经训练的模型,请将其设置为 False。在这种情况下,导出的模型将首先将其所有参数作为参数,其顺序由 model.state_dict().values() 指定

  • verbose(bool,默认假) -如果为 True,则打印正在导出到标准输出的模型的说明。此外,最终的 ONNX 图将包括来自导出模型的字段 doc_string`,其中提到了 model 的源代码位置。

  • training(枚举,默认 TrainingMode.EVAL) -

    • TrainingMode.EVAL :以推理模式导出模型。

    • TrainingMode.PRESERVE :如果 model.training 为 False,则在推理模式下导出模型;如果 model.training 为 True,则在训练模式下导出模型。

    • TrainingMode.TRAINING:在训练模式下导出模型。禁用可能干扰训练的优化。

  • input_names(str 列表,默认空列表) -按顺序分配给图的输入节点的名称。

  • output_names(str 列表,默认空列表) -按顺序分配给图的输出节点的名称。

  • operator_export_type(枚举,默认无) -

    None 通常意味着 OperatorExportTypes.ONNX 。但是,如果 PyTorch 是使用 -DPYTORCH_ONNX_CAFFE2_BUNDLE 构建的,则 None 表示 OperatorExportTypes.ONNX_ATEN_FALLBACK

    • OperatorExportTypes.ONNX :将所有操作导出为常规 ONNX 操作(在默认 opset 域中)。

    • OperatorExportTypes.ONNX_FALLTHROUGH:尝试将默认 opset 域中的所有操作转换为标准 ONNX 操作。如果无法这样做(例如,因为尚未添加将特定 torch op 转换为 ONNX 的支持),则退回到将 op 导出到自定义 opset 域而不进行转换。适用于custom ops 以及 ATen 操作。为了使导出的模型可用,运行时必须支持这些非标准操作。

    • OperatorExportTypes.ONNX_ATEN :所有 ATen 操作(在 TorchScript 命名空间 “aten” 中)都导出为 ATen 操作(在 opset 域 “org.pytorch.aten” 中)。 ATen 是 PyTorch 的内置张量库,因此这指示运行时使用 PyTorch 对这些操作的实现。

      警告

      以这种方式导出的模型可能只能由 Caffe2 运行。

      如果运算符实现中的数字差异导致 PyTorch 和 Caffe2 之间的行为存在巨大差异(这在未经训练的模型上更常见),这可能很有用。

    • OperatorExportTypes.ONNX_ATEN_FALLBACK :尝试将每个 ATen 操作(在 TorchScript 命名空间 “aten” 中)导出为常规 ONNX 操作。如果我们无法这样做(例如,因为尚未添加将特定 torch op 转换为 ONNX 的支持),则退回到导出 ATen op。有关上下文,请参阅 OperatorExportTypes.ONNX_ATEN 的文档。例如:

      graph(%0 : Float):
        %3 : int = prim::Constant[value=0]()
        # conversion unsupported
        %4 : Float = aten::triu(%0, %3)
        # conversion supported
        %5 : Float = aten::mul(%4, %0)
        return (%5)

      假设 ONNX 不支持aten::triu,这将被导出为:

      graph(%0 : Float):
        %1 : Long() = onnx::Constant[value={0}]()
        # not converted
        %2 : Float = aten::ATen[operator="triu"](%0, %1)
        # converted
        %3 : Float = onnx::Mul(%2, %0)
        return (%3)

      如果操作位于 TorchScript 命名空间 “quantized” 中,它将在 ONNX opset 域 “caffe2” 中导出。这些操作由 Quantization 中说明的模块生成。

      警告

      以这种方式导出的模型可能只能由 Caffe2 运行。

  • opset_version(int,默认 9) -必须是 == _onnx_main_opset or in _onnx_stable_opsets ,在 torch/onnx/symbolic_helper.py 中定义。

  • _retain_param_name(bool,默认真) -[已弃用并被忽略。将在下一个PyTorch版本中删除]

  • do_constant_folding(bool,默认假) -应用constant-folding 优化。 Constant-folding 将使用预先计算的常量节点替换一些具有所有常量输入的操作。

  • example_outputs(T或者T 的元组,其中 T 是张量或者可转换为张量,默认无) - [已弃用并被忽略。将在下一个 PyTorch 版本中删除],导出 ScriptModule 或 ScriptFunction 时必须提供,否则忽略。用于确定输出的类型和形状,而不跟踪模型的执行。单个对象被视为等同于一个元素的元组。

  • strip_doc_string(bool,默认真) -[已弃用并被忽略。将在下一个PyTorch版本中删除]

  • dynamic_axes(字典<字符串,字典<python:int,字符串>>或者字典<字符串,list(int)>,默认空字典) -

    默认情况下,导出的模型将所有输入和输出张量的形状设置为与args(以及需要该参数时的example_outputs)中给出的完全匹配。要将张量轴指定为动态的(即仅在运行时已知),请将 dynamic_axes 设置为具有模式的字典:

    • KEY (str):输入或输出名称。每个名称还必须在 input_namesoutput_names 中提供。

    • VALUE(字典或列表):如果是字典,则键是轴索引,值是轴名称。如果是列表,则每个元素都是一个轴索引。

    例如:

    class SumModule(torch.nn.Module):
        def forward(self, x):
            return torch.sum(x, dim=1)
    
    torch.onnx.export(SumModule(), (torch.ones(2, 2),), "onnx.pb",
                      input_names=["x"], output_names=["sum"])

    产生:

    input {
      name: "x"
      ...
          shape {
            dim {
              dim_value: 2  # axis 0
            }
            dim {
              dim_value: 2  # axis 1
    ...
    output {
      name: "sum"
      ...
          shape {
            dim {
              dim_value: 2  # axis 0
    ...

    尽管:

    torch.onnx.export(SumModule(), (torch.ones(2, 2),), "onnx.pb",
                      input_names=["x"], output_names=["sum"],
                      dynamic_axes={
                          # dict value: manually named axes
                          "x": {0: "my_custom_axis_name"},
                          # list value: automatic names
                          "sum": [0],
                      })

    产生:

    input {
      name: "x"
      ...
          shape {
            dim {
              dim_param: "my_custom_axis_name"  # axis 0
            }
            dim {
              dim_value: 2  # axis 1
    ...
    output {
      name: "sum"
      ...
          shape {
            dim {
              dim_param: "sum_dynamic_axes_1"  # axis 0
    ...
  • keep_initializers_as_inputs(bool,默认无) -

    如果为 True,则导出图中的所有初始化程序(通常对应于参数)也将作为输入添加到图中。如果为 False,则初始化器不会作为输入添加到图中,并且仅将非参数输入添加为输入。这可能允许后端/运行时进行更好的优化(例如常量折叠)。

    如果 opset_version < 9,则初始化程序必须是图形输入的一部分,并且此参数将被忽略,并且行为将等效于将此参数设置为 True。

    如果没有,则自动选择行为,如下所示:

    • 如果 operator_export_type=OperatorExportTypes.ONNX ,则行为等效于将此参数设置为 False。

    • 否则,该行为等同于将此参数设置为 True。

  • custom_opsets(字典<str,整数>,默认空字典) -

    用来表示的字典

    带有模式的字典:

    • KEY (str): opset域名

    • VALUE (int):opset 版本

    如果自定义 opset 被 model 引用但未在本字典中提及,则 opset 版本设置为 1。

  • enable_onnx_checker(bool,默认真) -已弃用和忽略。将在下一个 Pytorch 版本中删除。

  • use_external_data_format(bool,默认假) -[已弃用并被忽略。将在下一个 Pytorch 版本中删除。] 如果为 True,则某些模型参数存储在外部数据文件中,而不是存储在 ONNX 模型文件本身中。由于 Protocol Buffers 施加的大小限制,大于 2GB 的模型无法导出到一个文件中。详情请参阅onnx.proto。如果为 True,则参数 f 必须是指定模型位置的字符串。外部数据文件将存储在与 f 相同的目录中。除非 operator_export_type=OperatorExportTypes.ONNX ,否则该参数将被忽略。

抛出

ONNXCheckerError-如果 ONNX 检查器检测到无效的 ONNX 图。即使提出此问题,仍会将模型导出到文件f

将模型导出为 ONNX 格式。如果 model 不是 torch.jit.ScriptModule 也不是 torch.jit.ScriptFunction ,则运行 model 一次,以便将其转换为要导出的 TorchScript 图(相当于 torch.jit.trace() )。因此,这对动态控制流的支持与 torch.jit.trace() 相同。

相关用法


注:本文由纯净天空筛选整理自pytorch.org大神的英文原创作品 torch.onnx.export。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。