Proteus
Programmable JIT compilation and optimization for C/C++ using LLVM
Loading...
Searching...
No Matches
DispatcherCUDA.hpp
Go to the documentation of this file.
1#ifndef PROTEUS_FRONTEND_DISPATCHER_CUDA_HPP
2#define PROTEUS_FRONTEND_DISPATCHER_CUDA_HPP
3
4#if PROTEUS_ENABLE_CUDA
5
8
9namespace proteus {
10
11class DispatcherCUDA : public Dispatcher {
12public:
13 static DispatcherCUDA &instance() {
14 static DispatcherCUDA D;
15 return D;
16 }
17
18 void compile(std::unique_ptr<Module> M) override {
19 // CMake finds LIBDEVICE_BC_PATH.
20 auto LibDeviceBuffer = llvm::MemoryBuffer::getFile(LIBDEVICE_BC_PATH);
21 auto LibDeviceModule = llvm::parseBitcodeFile(
22 LibDeviceBuffer->get()->getMemBufferRef(), M->getContext());
23
24 llvm::Linker linker(*M);
25 linker.linkInModule(std::move(LibDeviceModule.get()));
26
27 Library = Jit.compileOnly(*M);
28 if (!Library)
29 PROTEUS_FATAL_ERROR("Expected non-null object library");
30 }
31
32 DispatchResult launch(StringRef KernelName, LaunchDims GridDim,
33 LaunchDims BlockDim, ArrayRef<void *> KernelArgs,
34 uint64_t ShmemSize, void *Stream) override {
36 KernelName, Library->getBufferStart(),
37 /*RelinkGlobalsByCopy*/ false,
38 /* VarNameToDevPtr */ {});
39
40 dim3 CudaGridDim = {GridDim.X, GridDim.Y, GridDim.Z};
41 dim3 CudaBlockDim = {BlockDim.X, BlockDim.Y, BlockDim.Z};
42 cudaStream_t CudaStream = reinterpret_cast<cudaStream_t>(Stream);
43
44 void **KernelArgsPtrs = const_cast<void **>(KernelArgs.data());
45 return proteus::launchKernelFunction(KernelFunc, CudaGridDim, CudaBlockDim,
46 KernelArgsPtrs, ShmemSize, CudaStream);
47 }
48
49protected:
50 void *getFunctionAddress(StringRef) override {
51 PROTEUS_FATAL_ERROR("CUDA does not support getFunctionAddress");
52 }
53
54private:
55 JitEngineDeviceCUDA &Jit;
56 DispatcherCUDA() : Jit(JitEngineDeviceCUDA::instance()) {}
57};
58
59} // namespace proteus
60
61#endif
62
63#endif // PROTEUS_FRONTEND_DISPATCHER_CUDA_HPP
void char * KernelName
Definition CompilerInterfaceDevice.cpp:50
auto & Jit
Definition CompilerInterfaceDevice.cpp:54
#define PROTEUS_FATAL_ERROR(x)
Definition Error.h:4
Definition Dispatcher.cpp:14
cudaError_t launchKernelFunction(CUfunction KernelFunc, dim3 GridDim, dim3 BlockDim, void **KernelArgs, uint64_t ShmemSize, CUstream Stream)
Definition CoreDeviceCUDA.hpp:51
CUfunction getKernelFunctionFromImage(StringRef KernelName, const void *Image, bool RelinkGlobalsByCopy, const std::unordered_map< std::string, const void * > &VarNameToDevPtr)
Definition CoreDeviceCUDA.hpp:27
Definition Dispatcher.hpp:14
unsigned Z
Definition Dispatcher.hpp:15
unsigned Y
Definition Dispatcher.hpp:15
unsigned X
Definition Dispatcher.hpp:15