當前位置: 首頁>>編程示例 >>用法及示例精選 >>正文


Python PyTorch export用法及代碼示例

本文簡要介紹python語言中 torch.onnx.export 的用法。

用法:

torch.onnx.export(model, args, f, export_params=True, verbose=False, training=<TrainingMode.EVAL: 0>, input_names=None, output_names=None, operator_export_type=None, opset_version=None, _retain_param_name=None, do_constant_folding=True, example_outputs=None, strip_doc_string=None, dynamic_axes=None, keep_initializers_as_inputs=None, custom_opsets=None, enable_onnx_checker=None, use_external_data_format=None)

參數

  • model(torch.nn.Module,torch.jit.ScriptModule或者torch.jit.ScriptFunction) -要導出的模型。

  • args(tuple或者torch.Tensor) -

    args 可以構造為:

    1. 隻有一個元組的參數:

      args = (x, y, z)

    元組應包含模型輸入,以便 model(*args) 是模型的有效調用。任何非張量參數都將 hard-coded 進入導出模型;任何張量參數都將成為導出模型的輸入,按照它們在元組中出現的順序。

    1. 張量:

      args = torch.Tensor([1])

    這相當於該張量的一元元組。

    1. 以命名參數字典結尾的參數元組:

      args = (x,
              {'y': input_y,
               'z': input_z})

    元組中除了最後一個元素之外的所有元素都將作為非關鍵字參數傳遞,命名參數將從最後一個元素開始設置。如果字典中不存在命名參數,則為其分配默認值,如果未提供默認值,則為 None。

    注意

    如果字典是 args 元組的最後一個元素,它將被解釋為包含命名參數。為了將 dict 作為最後一個非關鍵字 arg 傳遞,請提供一個空 dict 作為 args 元組的最後一個元素。例如,而不是:

    torch.onnx.export(
        model,
        (x,
         # WRONG: will be interpreted as named arguments
         {y: z}),
        "test.onnx.pb")

    寫:

    torch.onnx.export(
        model,
        (x,
         {y: z},
         {}),
        "test.onnx.pb")
  • f-file-like 對象(這樣 f.fileno() 返回文件說明符)或包含文件名的字符串。二進製協議緩衝區將寫入此文件。

  • export_params(bool,默認真) -如果為 True,則將導出所有參數。如果要導出未經訓練的模型,請將其設置為 False。在這種情況下,導出的模型將首先將其所有參數作為參數,其順序由 model.state_dict().values() 指定

  • verbose(bool,默認假) -如果為 True,則打印正在導出到標準輸出的模型的說明。此外,最終的 ONNX 圖將包括來自導出模型的字段 doc_string`,其中提到了 model 的源代碼位置。

  • training(枚舉,默認 TrainingMode.EVAL) -

    • TrainingMode.EVAL :以推理模式導出模型。

    • TrainingMode.PRESERVE :如果 model.training 為 False,則在推理模式下導出模型;如果 model.training 為 True,則在訓練模式下導出模型。

    • TrainingMode.TRAINING:在訓練模式下導出模型。禁用可能幹擾訓練的優化。

  • input_names(str 列表,默認空列表) -按順序分配給圖的輸入節點的名稱。

  • output_names(str 列表,默認空列表) -按順序分配給圖的輸出節點的名稱。

  • operator_export_type(枚舉,默認無) -

    None 通常意味著 OperatorExportTypes.ONNX 。但是,如果 PyTorch 是使用 -DPYTORCH_ONNX_CAFFE2_BUNDLE 構建的,則 None 表示 OperatorExportTypes.ONNX_ATEN_FALLBACK

    • OperatorExportTypes.ONNX :將所有操作導出為常規 ONNX 操作(在默認 opset 域中)。

    • OperatorExportTypes.ONNX_FALLTHROUGH:嘗試將默認 opset 域中的所有操作轉換為標準 ONNX 操作。如果無法這樣做(例如,因為尚未添加將特定 torch op 轉換為 ONNX 的支持),則退回到將 op 導出到自定義 opset 域而不進行轉換。適用於custom ops 以及 ATen 操作。為了使導出的模型可用,運行時必須支持這些非標準操作。

    • OperatorExportTypes.ONNX_ATEN :所有 ATen 操作(在 TorchScript 命名空間 “aten” 中)都導出為 ATen 操作(在 opset 域 “org.pytorch.aten” 中)。 ATen 是 PyTorch 的內置張量庫,因此這指示運行時使用 PyTorch 對這些操作的實現。

      警告

      以這種方式導出的模型可能隻能由 Caffe2 運行。

      如果運算符實現中的數字差異導致 PyTorch 和 Caffe2 之間的行為存在巨大差異(這在未經訓練的模型上更常見),這可能很有用。

    • OperatorExportTypes.ONNX_ATEN_FALLBACK :嘗試將每個 ATen 操作(在 TorchScript 命名空間 “aten” 中)導出為常規 ONNX 操作。如果我們無法這樣做(例如,因為尚未添加將特定 torch op 轉換為 ONNX 的支持),則退回到導出 ATen op。有關上下文,請參閱 OperatorExportTypes.ONNX_ATEN 的文檔。例如:

      graph(%0 : Float):
        %3 : int = prim::Constant[value=0]()
        # conversion unsupported
        %4 : Float = aten::triu(%0, %3)
        # conversion supported
        %5 : Float = aten::mul(%4, %0)
        return (%5)

      假設 ONNX 不支持aten::triu,這將被導出為:

      graph(%0 : Float):
        %1 : Long() = onnx::Constant[value={0}]()
        # not converted
        %2 : Float = aten::ATen[operator="triu"](%0, %1)
        # converted
        %3 : Float = onnx::Mul(%2, %0)
        return (%3)

      如果操作位於 TorchScript 命名空間 “quantized” 中,它將在 ONNX opset 域 “caffe2” 中導出。這些操作由 Quantization 中說明的模塊生成。

      警告

      以這種方式導出的模型可能隻能由 Caffe2 運行。

  • opset_version(int,默認 9) -必須是 == _onnx_main_opset or in _onnx_stable_opsets ,在 torch/onnx/symbolic_helper.py 中定義。

  • _retain_param_name(bool,默認真) -[已棄用並被忽略。將在下一個PyTorch版本中刪除]

  • do_constant_folding(bool,默認假) -應用constant-folding 優化。 Constant-folding 將使用預先計算的常量節點替換一些具有所有常量輸入的操作。

  • example_outputs(T或者T 的元組,其中 T 是張量或者可轉換為張量,默認無) - [已棄用並被忽略。將在下一個 PyTorch 版本中刪除],導出 ScriptModule 或 ScriptFunction 時必須提供,否則忽略。用於確定輸出的類型和形狀,而不跟蹤模型的執行。單個對象被視為等同於一個元素的元組。

  • strip_doc_string(bool,默認真) -[已棄用並被忽略。將在下一個PyTorch版本中刪除]

  • dynamic_axes(字典<字符串,字典<python:int,字符串>>或者字典<字符串,list(int)>,默認空字典) -

    默認情況下,導出的模型將所有輸入和輸出張量的形狀設置為與args(以及需要該參數時的example_outputs)中給出的完全匹配。要將張量軸指定為動態的(即僅在運行時已知),請將 dynamic_axes 設置為具有模式的字典:

    • KEY (str):輸入或輸出名稱。每個名稱還必須在 input_namesoutput_names 中提供。

    • VALUE(字典或列表):如果是字典,則鍵是軸索引,值是軸名稱。如果是列表,則每個元素都是一個軸索引。

    例如:

    class SumModule(torch.nn.Module):
        def forward(self, x):
            return torch.sum(x, dim=1)
    
    torch.onnx.export(SumModule(), (torch.ones(2, 2),), "onnx.pb",
                      input_names=["x"], output_names=["sum"])

    產生:

    input {
      name: "x"
      ...
          shape {
            dim {
              dim_value: 2  # axis 0
            }
            dim {
              dim_value: 2  # axis 1
    ...
    output {
      name: "sum"
      ...
          shape {
            dim {
              dim_value: 2  # axis 0
    ...

    盡管:

    torch.onnx.export(SumModule(), (torch.ones(2, 2),), "onnx.pb",
                      input_names=["x"], output_names=["sum"],
                      dynamic_axes={
                          # dict value: manually named axes
                          "x": {0: "my_custom_axis_name"},
                          # list value: automatic names
                          "sum": [0],
                      })

    產生:

    input {
      name: "x"
      ...
          shape {
            dim {
              dim_param: "my_custom_axis_name"  # axis 0
            }
            dim {
              dim_value: 2  # axis 1
    ...
    output {
      name: "sum"
      ...
          shape {
            dim {
              dim_param: "sum_dynamic_axes_1"  # axis 0
    ...
  • keep_initializers_as_inputs(bool,默認無) -

    如果為 True,則導出圖中的所有初始化程序(通常對應於參數)也將作為輸入添加到圖中。如果為 False,則初始化器不會作為輸入添加到圖中,並且僅將非參數輸入添加為輸入。這可能允許後端/運行時進行更好的優化(例如常量折疊)。

    如果 opset_version < 9,則初始化程序必須是圖形輸入的一部分,並且此參數將被忽略,並且行為將等效於將此參數設置為 True。

    如果沒有,則自動選擇行為,如下所示:

    • 如果 operator_export_type=OperatorExportTypes.ONNX ,則行為等效於將此參數設置為 False。

    • 否則,該行為等同於將此參數設置為 True。

  • custom_opsets(字典<str,整數>,默認空字典) -

    用來表示的字典

    帶有模式的字典:

    • KEY (str): opset域名

    • VALUE (int):opset 版本

    如果自定義 opset 被 model 引用但未在本字典中提及,則 opset 版本設置為 1。

  • enable_onnx_checker(bool,默認真) -已棄用和忽略。將在下一個 Pytorch 版本中刪除。

  • use_external_data_format(bool,默認假) -[已棄用並被忽略。將在下一個 Pytorch 版本中刪除。] 如果為 True,則某些模型參數存儲在外部數據文件中,而不是存儲在 ONNX 模型文件本身中。由於 Protocol Buffers 施加的大小限製,大於 2GB 的模型無法導出到一個文件中。詳情請參閱onnx.proto。如果為 True,則參數 f 必須是指定模型位置的字符串。外部數據文件將存儲在與 f 相同的目錄中。除非 operator_export_type=OperatorExportTypes.ONNX ,否則該參數將被忽略。

拋出

ONNXCheckerError-如果 ONNX 檢查器檢測到無效的 ONNX 圖。即使提出此問題,仍會將模型導出到文件f

將模型導出為 ONNX 格式。如果 model 不是 torch.jit.ScriptModule 也不是 torch.jit.ScriptFunction ,則運行 model 一次,以便將其轉換為要導出的 TorchScript 圖(相當於 torch.jit.trace() )。因此,這對動態控製流的支持與 torch.jit.trace() 相同。

相關用法


注:本文由純淨天空篩選整理自pytorch.org大神的英文原創作品 torch.onnx.export。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。