当前位置: 首页>>编程示例 >>用法及示例精选 >>正文


Python PyTorch load用法及代码示例

本文简要介绍python语言中 torch.utils.cpp_extension.load 的用法。

用法:

torch.utils.cpp_extension.load(name, sources, extra_cflags=None, extra_cuda_cflags=None, extra_ldflags=None, extra_include_paths=None, build_directory=None, verbose=False, with_cuda=None, is_python_module=True, is_standalone=False, keep_intermediates=True)

参数

  • name-要构建的扩展名。这必须与 pybind11 模块的名称相同!

  • sources-C++ 源文件的相对或绝对路径列表。

  • extra_cflags-要转发到构建的编译器标志的可选列表。

  • extra_cuda_cflags-在构建 CUDA 源时转发到 nvcc 的编译器标志的可选列表。

  • extra_ldflags-要转发到构建的链接器标志的可选列表。

  • extra_include_paths-要转发到构建的包含目录的可选列表。

  • build_directory-用作构建工作区的可选路径。

  • verbose-如果 True ,打开加载步骤的详细记录。

  • with_cuda-确定是否将 CUDA 标头和库添加到构建中。如果设置为 None(默认),则此值将根据 sources 中是否存在 .cu.cuh 自动确定。将其设置为 True` 以强制包含 CUDA 标头和库。

  • is_python_module-如果True(默认),将生成的共享库作为 Python 模块导入。如果 False ,行为取决于 is_standalone

  • is_standalone-如果False(默认)将构建的扩展作为普通动态库加载到进程中。如果 True ,构建一个独立的可执行文件。

返回

将加载的 PyTorch 扩展作为 Python 模块返回。

如果 is_python_moduleFalse 并且 is_standaloneFalse

什么都不返回。 (共享库作为副作用加载到进程中。)

如果 is_standaloneTrue

返回可执行文件的路径。 (在 Windows 上,TORCH_LIB_PATH 作为副作用添加到 PATH 环境变量中。)

返回类型

如果 is_python_moduleTrue

加载PyTorch C++ 扩展just-in-time (JIT)。

要加载扩展,会发出一个 Ninja 构建文件,用于将给定的源代码编译成动态库。这个库随后作为一个模块加载到当前的 Python 进程中,并从这个函数返回,准备使用。

默认情况下,生成文件和生成的库编译到的目录是 <tmp>/torch_extensions/<name> ,其中 <tmp> 是当前平台上的临时文件夹, <name> 是扩展名。可以通过两种方式覆盖此位置。首先,如果设置了TORCH_EXTENSIONS_DIR 环境变量,它将替换<tmp>/torch_extensions,并且所有扩展都将编译到该目录的子文件夹中。其次,如果提供了该函数的build_directory 参数,它将覆盖整个路径,即库将直接编译到该文件夹中。

要编译源代码,使用默认系统编译器 (c++),可以通过设置 CXX 环境变量来覆盖它。要将附加参数传递给编译过程,可以提供extra_cflagsextra_ldflags。例如,要使用优化编译扩展,请传递 extra_cflags=['-O3'] 。您还可以使用extra_cflags 传递进一步的包含目录。

提供了混合编译的 CUDA 支持。只需将 CUDA 源文件(.cu.cuh)与其他源一起传递。将使用 nvcc 而不是 C++ 编译器检测和编译此类文件。这包括将 CUDA lib64 目录作为库目录传递,并链接 cudart 。您可以通过 extra_cuda_cflags 将其他标志传递给 nvcc,就像使用 C++ 的 extra_cflags 一样。使用了各种用于查找 CUDA 安装目录的启发式方法,通常可以正常工作。如果没有,设置 CUDA_HOME 环境变量是最安全的选择。

示例

>>> from torch.utils.cpp_extension import load
>>> module = load(
        name='extension',
        sources=['extension.cpp', 'extension_kernel.cu'],
        extra_cflags=['-O2'],
        verbose=True)

相关用法


注:本文由纯净天空筛选整理自pytorch.org大神的英文原创作品 torch.utils.cpp_extension.load。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。