当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python PyTorch CUDAExtension用法及代码示例


本文简要介绍python语言中 torch.utils.cpp_extension.CUDAExtension 的用法。

用法:

torch.utils.cpp_extension.CUDAExtension(name, sources, *args, **kwargs)

为 CUDA/C++ 创建 setuptools.Extension

创建setuptools.Extension 的便捷方法,使用最少(但通常足够)的参数来构建 CUDA/C++ 扩展。这包括 CUDA 包含路径、库路径和运行时库。

所有参数都转发到setuptools.Extension 构造函数。

示例

>>> from setuptools import setup
>>> from torch.utils.cpp_extension import BuildExtension, CUDAExtension
>>> setup(
        name='cuda_extension',
        ext_modules=[
            CUDAExtension(
                    name='cuda_extension',
                    sources=['extension.cpp', 'extension_kernel.cu'],
                    extra_compile_args={'cxx': ['-g'],
                                        'nvcc': ['-O2']})
        ],
        cmdclass={
            'build_ext': BuildExtension
        })

计算能力:

默认情况下,扩展程序将被编译为在扩展程序构建过程中可见的所有卡牌以及 PTX 上运行。如果以后安装了新卡,则可能需要重新编译扩展。如果可见卡的计算能力 (CC) 比您的 nvcc 可以为其构建 fully-compiled 二进制文件的最新版本更新,Pytorch 将使 nvcc 回退到使用您的 nvcc 支持的最新版本 PTX 构建内核(见下文有关 PTX 的详细信息)。

您可以使用 TORCH_CUDA_ARCH_LIST 覆盖默认行为,以明确指定您希望扩展支持哪些 CC:

TORCH_CUDA_ARCH_LIST=”6.1 8.6” python build_my_extension.py TORCH_CUDA_ARCH_LIST=”5.2 6.0 6.1 7.0 7.5 8.0 8.6+PTX” python build_my_extension.py

+PTX 选项使扩展内核二进制文件包含指定 CC 的 PTX 指令。 PTX 是一种中间表示,它允许内核为任何 CC >= 指定的 CC 运行时编译(例如,8.6+PTX 生成的 PTX 可以为任何 CC >= 8.6 的 GPU 运行时编译)。这提高了二进制文件的前向兼容性。但是,依靠较旧的 PTX 通过runtime-compiling 为较新的 CC 提供前向兼容可能会适度降低这些较新 CC 的性能。如果您知道要定位的 GPU 的确切 CC,最好单独指定它们。例如,如果您希望您的扩展在 8.0 和 8.6 上运行,“8.0+PTX” 可以正常工作,因为它包含可以为 8.6 运行时编译的 PTX,但“8.0 8.6”会更好。

请注意,虽然可以包含所有受支持的拱门,但包含的拱门越多,构建过程就越慢,因为它将为每个拱门构建单独的内核映像。

相关用法


注:本文由纯净天空筛选整理自pytorch.org大神的英文原创作品 torch.utils.cpp_extension.CUDAExtension。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。