當前位置: 首頁>>編程示例 >>用法及示例精選 >>正文


Python mxnet.rtc.CudaModule用法及代碼示例

用法:

class mxnet.rtc.CudaModule(source, options=(), exports=())

參數

  • source(str) - 完整的源代碼。
  • options(tuple of str) - 編譯器標誌。例如,使用 “-I/usr/local/cuda/include” 添加 cuda 標頭以包含路徑。
  • exports(tuple of str) - 導出內核名稱。

基礎:object

從 Python 編譯和運行 CUDA 代碼。

在 CUDA 7.5 中,您需要在內核定義前添加“extern “C””以避免名稱混淆:

source = r'''
extern "C" __global__ void axpy(const float *x, float *y, float alpha) {
    int i = threadIdx.x + blockIdx.x * blockDim.x;
    y[i] += alpha * x[i];
}
'''
module = mx.rtc.CudaModule(source)
func = module.get_kernel("axpy", "const float *x, float *y, float alpha")
x = mx.nd.ones((10,), ctx=mx.gpu(0))
y = mx.nd.zeros((10,), ctx=mx.gpu(0))
func.launch([x, y, 3.0], mx.gpu(0), (1, 1, 1), (10, 1, 1))
print(y)

從 CUDA 8.0 開始,您可以改為按名稱導出函數。這也允許您使用模板:

source = r'''
template<typename DType>
__global__ void axpy(const DType *x, DType *y, DType alpha) {
    int i = threadIdx.x + blockIdx.x * blockDim.x;
    y[i] += alpha * x[i];
}
'''
module = mx.rtc.CudaModule(source, exports=['axpy<float>', 'axpy<double>'])
func32 = module.get_kernel("axpy<float>", "const float *x, float *y, float alpha")
x = mx.nd.ones((10,), dtype='float32', ctx=mx.gpu(0))
y = mx.nd.zeros((10,), dtype='float32', ctx=mx.gpu(0))
func32.launch([x, y, 3.0], mx.gpu(0), (1, 1, 1), (10, 1, 1))
print(y)

func64 = module.get_kernel("axpy<double>", "const double *x, double *y, double alpha")
x = mx.nd.ones((10,), dtype='float64', ctx=mx.gpu(0))
y = mx.nd.zeros((10,), dtype='float64', ctx=mx.gpu(0))
func32.launch([x, y, 3.0], mx.gpu(0), (1, 1, 1), (10, 1, 1))
print(y)

相關用法


注:本文由純淨天空篩選整理自apache.org大神的英文原創作品 mxnet.rtc.CudaModule。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。