1#ifndef PROTEUS_CORE_CUDA_H
2#define PROTEUS_CORE_CUDA_H
8#include <llvm/ADT/StringRef.h>
10#include <unordered_map>
19 void **,
const void *) =
nullptr;
22 cudaStream_t) =
nullptr;
27 void *DevPtr =
nullptr;
30 assert(DevPtr &&
"Expected non-null device pointer for global");
35 "Ensure the CUDA runtime is properly linked.");
39 dim3 BlockDim,
void **KernelArgs,
40 uint64_t ShmemSize, CUstream Stream) {
43 KernelArgs, ShmemSize, Stream);
46 reportFatalError(
"__proteus_cudaLaunchKernel_ptr is not initialized. Ensure "
47 "the CUDA runtime is properly linked.");
51 StringRef
KernelName,
const void *Image,
bool RelinkGlobalsByCopy,
52 const std::unordered_map<std::string, GlobalVarInfo> &VarNameToGlobalInfo) {
53 CUfunction KernelFunc;
57 if (RelinkGlobalsByCopy) {
58 for (
auto &[GlobalName, GVI] : VarNameToGlobalInfo) {
61 " without a concrete device address");
68 uint64_t PtrVal = (uint64_t)GVI.DevAddr;
79 dim3 BlockDim,
void **KernelArgs,
80 uint64_t ShmemSize, CUstream Stream) {
83 auto CUresultToCudaError = [](CUresult Res) -> cudaError_t {
87 case CUDA_ERROR_INVALID_VALUE:
88 return cudaErrorInvalidValue;
89 case CUDA_ERROR_LAUNCH_OUT_OF_RESOURCES:
90 return cudaErrorLaunchOutOfResources;
91 case CUDA_ERROR_LAUNCH_TIMEOUT:
92 return cudaErrorLaunchTimeout;
93 case CUDA_ERROR_LAUNCH_FAILED:
94 return cudaErrorLaunchFailure;
95 case CUDA_ERROR_SHARED_OBJECT_INIT_FAILED:
96 return cudaErrorSharedObjectInitFailed;
97 case CUDA_ERROR_INVALID_HANDLE:
98 return cudaErrorInvalidResourceHandle;
99 case CUDA_ERROR_NOT_READY:
100 return cudaErrorNotReady;
101 case CUDA_ERROR_ILLEGAL_ADDRESS:
102 return cudaErrorIllegalAddress;
104 return cudaErrorUnknown;
109 constexpr size_t DefaultShmemSize = 48 * 1024;
111 if (ShmemSize >= DefaultShmemSize) {
113 KernelFunc, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
117 CUresult Res =
cuLaunchKernel(KernelFunc, GridDim.x, GridDim.y, GridDim.z,
118 BlockDim.x, BlockDim.y, BlockDim.z, ShmemSize,
119 Stream, KernelArgs,
nullptr);
120 return static_cast<cudaError_t
>(CUresultToCudaError(Res));
CUresult CUDAAPI cuModuleGetFunction(CUfunction *Hfunc, CUmodule Hmod, const char *Name)
Definition CUDADriverAPI.cpp:120
CUresult CUDAAPI cuModuleLoadData(CUmodule *Module, const void *Image)
Definition CUDADriverAPI.cpp:106
CUresult CUDAAPI cuMemcpyHtoD(CUdeviceptr DstDevice, const void *SrcHost, size_t ByteCount)
Definition CUDADriverAPI.cpp:136
CUresult CUDAAPI cuModuleGetGlobal(CUdeviceptr *Dptr, size_t *Bytes, CUmodule Hmod, const char *Name)
Definition CUDADriverAPI.cpp:128
CUresult CUDAAPI cuFuncSetAttribute(CUfunction Hfunc, CUfunction_attribute Attrib, int Value)
Definition CUDADriverAPI.cpp:152
CUresult CUDAAPI cuLaunchKernel(CUfunction F, unsigned int GridDimX, unsigned int GridDimY, unsigned int GridDimZ, unsigned int BlockDimX, unsigned int BlockDimY, unsigned int BlockDimZ, unsigned int SharedMemBytes, CUstream HStream, void **KernelParams, void **Extra)
Definition CUDADriverAPI.cpp:160
void char * KernelName
Definition CompilerInterfaceDevice.cpp:59
#define proteusCuErrCheck(CALL)
Definition UtilsCUDA.h:28
Definition MemoryCache.h:27
cudaError_t(* __proteus_cudaGetSymbolAddress_ptr)(void **, const void *)
Definition CoreDeviceCUDA.h:18
void reportFatalError(const llvm::Twine &Reason, const char *FILE, unsigned Line)
Definition Error.cpp:14
cudaError_t launchKernelDirect(void *KernelFunc, dim3 GridDim, dim3 BlockDim, void **KernelArgs, uint64_t ShmemSize, CUstream Stream)
Definition CoreDeviceCUDA.h:38
cudaError_t launchKernelFunction(CUfunction KernelFunc, dim3 GridDim, dim3 BlockDim, void **KernelArgs, uint64_t ShmemSize, CUstream Stream)
Definition CoreDeviceCUDA.h:78
void * resolveDeviceGlobalAddr(const void *Addr)
Definition CoreDeviceCUDA.h:26
cudaError_t(* __proteus_cudaLaunchKernel_ptr)(const void *, dim3, dim3, void **, size_t, cudaStream_t)
Definition CoreDeviceCUDA.h:20
CUfunction getKernelFunctionFromImage(StringRef KernelName, const void *Image, bool RelinkGlobalsByCopy, const std::unordered_map< std::string, GlobalVarInfo > &VarNameToGlobalInfo)
Definition CoreDeviceCUDA.h:50