當前位置: 首頁>>編程示例 >>用法及示例精選 >>正文


Python PyTorch load用法及代碼示例

本文簡要介紹python語言中 torch.utils.cpp_extension.load 的用法。

用法:

torch.utils.cpp_extension.load(name, sources, extra_cflags=None, extra_cuda_cflags=None, extra_ldflags=None, extra_include_paths=None, build_directory=None, verbose=False, with_cuda=None, is_python_module=True, is_standalone=False, keep_intermediates=True)

參數

  • name-要構建的擴展名。這必須與 pybind11 模塊的名稱相同!

  • sources-C++ 源文件的相對或絕對路徑列表。

  • extra_cflags-要轉發到構建的編譯器標誌的可選列表。

  • extra_cuda_cflags-在構建 CUDA 源時轉發到 nvcc 的編譯器標誌的可選列表。

  • extra_ldflags-要轉發到構建的鏈接器標誌的可選列表。

  • extra_include_paths-要轉發到構建的包含目錄的可選列表。

  • build_directory-用作構建工作區的可選路徑。

  • verbose-如果 True ,打開加載步驟的詳細記錄。

  • with_cuda-確定是否將 CUDA 標頭和庫添加到構建中。如果設置為 None(默認),則此值將根據 sources 中是否存在 .cu.cuh 自動確定。將其設置為 True` 以強製包含 CUDA 標頭和庫。

  • is_python_module-如果True(默認),將生成的共享庫作為 Python 模塊導入。如果 False ,行為取決於 is_standalone

  • is_standalone-如果False(默認)將構建的擴展作為普通動態庫加載到進程中。如果 True ,構建一個獨立的可執行文件。

返回

將加載的 PyTorch 擴展作為 Python 模塊返回。

如果 is_python_moduleFalse 並且 is_standaloneFalse

什麽都不返回。 (共享庫作為副作用加載到進程中。)

如果 is_standaloneTrue

返回可執行文件的路徑。 (在 Windows 上,TORCH_LIB_PATH 作為副作用添加到 PATH 環境變量中。)

返回類型

如果 is_python_moduleTrue

加載PyTorch C++ 擴展just-in-time (JIT)。

要加載擴展,會發出一個 Ninja 構建文件,用於將給定的源代碼編譯成動態庫。這個庫隨後作為一個模塊加載到當前的 Python 進程中,並從這個函數返回,準備使用。

默認情況下,生成文件和生成的庫編譯到的目錄是 <tmp>/torch_extensions/<name> ,其中 <tmp> 是當前平台上的臨時文件夾, <name> 是擴展名。可以通過兩種方式覆蓋此位置。首先,如果設置了TORCH_EXTENSIONS_DIR 環境變量,它將替換<tmp>/torch_extensions,並且所有擴展都將編譯到該目錄的子文件夾中。其次,如果提供了該函數的build_directory 參數,它將覆蓋整個路徑,即庫將直接編譯到該文件夾中。

要編譯源代碼,使用默認係統編譯器 (c++),可以通過設置 CXX 環境變量來覆蓋它。要將附加參數傳遞給編譯過程,可以提供extra_cflagsextra_ldflags。例如,要使用優化編譯擴展,請傳遞 extra_cflags=['-O3'] 。您還可以使用extra_cflags 傳遞進一步的包含目錄。

提供了混合編譯的 CUDA 支持。隻需將 CUDA 源文件(.cu.cuh)與其他源一起傳遞。將使用 nvcc 而不是 C++ 編譯器檢測和編譯此類文件。這包括將 CUDA lib64 目錄作為庫目錄傳遞,並鏈接 cudart 。您可以通過 extra_cuda_cflags 將其他標誌傳遞給 nvcc,就像使用 C++ 的 extra_cflags 一樣。使用了各種用於查找 CUDA 安裝目錄的啟發式方法,通常可以正常工作。如果沒有,設置 CUDA_HOME 環境變量是最安全的選擇。

示例

>>> from torch.utils.cpp_extension import load
>>> module = load(
        name='extension',
        sources=['extension.cpp', 'extension_kernel.cu'],
        extra_cflags=['-O2'],
        verbose=True)

相關用法


注:本文由純淨天空篩選整理自pytorch.org大神的英文原創作品 torch.utils.cpp_extension.load。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。