Proteus
Programmable JIT compilation and optimization for C/C++ using LLVM
Loading...
Searching...
No Matches
DispatcherCUDA.h
Go to the documentation of this file.
#ifndef PROTEUS_FRONTEND_DISPATCHER_CUDA_H
#define PROTEUS_FRONTEND_DISPATCHER_CUDA_H

#if PROTEUS_ENABLE_CUDA

// NOTE(review): the original #include list was lost in extraction (Doxygen
// lines 5-10). Reconstructed from the symbols referenced below -- verify the
// exact paths against the repository before committing.
#include <llvm/Bitcode/BitcodeReader.h>
#include <llvm/Linker/Linker.h>
#include <llvm/Support/MemoryBuffer.h>

#include "proteus/CoreDeviceCUDA.h" // launchKernelFunction, getKernelFunctionFromImage
#include "proteus/Dispatcher.h"
#include "proteus/Error.h"
#include "proteus/Hashing.h"
#include "proteus/JitEngineDeviceCUDA.h"
#include "proteus/MemoryCache.h"
#include "proteus/TimeTracing.h"

namespace proteus {

/// Dispatcher for the CUDA backend: compiles JIT modules to device objects
/// through JitEngineDeviceCUDA, caches the objects and resolved kernel
/// functions, and launches kernels via the CUDA driver API.
class DispatcherCUDA : public Dispatcher {
public:
  /// Meyers-singleton accessor; the single instance is created on first use.
  static DispatcherCUDA &instance() {
    static DispatcherCUDA D;
    return D;
  }

  /// Link libdevice into \p Mod, compile it to a device object, store the
  /// object in the object cache under \p ModuleHash, and return the object
  /// buffer. \p DisableIROpt skips IR-level optimization in the JIT engine.
  std::unique_ptr<MemoryBuffer> compile(std::unique_ptr<LLVMContext> Ctx,
                                        std::unique_ptr<Module> Mod,
                                        const HashT &ModuleHash,
                                        bool DisableIROpt = false) override {
    TIMESCOPE(DispatcherCUDA, compile);
    // This is necessary to ensure Ctx outlives M. Setting [[maybe_unused]] can
    // trigger a lifetime bug.
    auto CtxOwner = std::move(Ctx);
    auto ModOwner = std::move(Mod);

    // CMake finds LIBDEVICE_BC_PATH.
    auto LibDeviceBuffer = llvm::MemoryBuffer::getFile(LIBDEVICE_BC_PATH);
    // getFile returns an ErrorOr; dereferencing an error state is UB, so fail
    // loudly instead of reading a bogus buffer.
    if (!LibDeviceBuffer)
      reportFatalError("Failed to read libdevice bitcode at " LIBDEVICE_BC_PATH);

    auto LibDeviceModule = llvm::parseBitcodeFile(
        LibDeviceBuffer->get()->getMemBufferRef(), ModOwner->getContext());
    // parseBitcodeFile returns an Expected; calling get() on an error aborts
    // (unchecked Expected), so check explicitly.
    if (!LibDeviceModule)
      reportFatalError("Failed to parse libdevice bitcode");

    llvm::Linker linker(*ModOwner);
    // linkInModule returns true on error; LinkOnlyNeeded pulls in just the
    // libdevice symbols the module references.
    if (linker.linkInModule(std::move(LibDeviceModule.get()),
                            llvm::Linker::Flags::LinkOnlyNeeded))
      reportFatalError("Failed to link libdevice into JIT module");

    std::unique_ptr<MemoryBuffer> ObjectModule =
        Jit.compileOnly(*ModOwner, DisableIROpt);
    if (!ObjectModule)
      reportFatalError("Expected non-null object library");

    // Cache the compiled object so later lookups by hash skip compilation.
    ObjectCache->store(
        ModuleHash, CacheEntry::staticObject(ObjectModule->getMemBufferRef()));

    return ObjectModule;
  }

  /// Return the cached compiled library for \p ModuleHash, or null if absent.
  std::unique_ptr<CompiledLibrary>
  lookupCompiledLibrary(const HashT &ModuleHash) override {
    return ObjectCache->lookup(ModuleHash);
  }

  /// Launch \p KernelFunc with the given grid/block dimensions, kernel
  /// argument array, dynamic shared-memory size, and stream.
  DispatchResult launch(void *KernelFunc, LaunchDims GridDim,
                        LaunchDims BlockDim, void *KernelArgs[],
                        uint64_t ShmemSize, void *Stream) override {
    TIMESCOPE(DispatcherCUDA, launch);
    dim3 CudaGridDim = {GridDim.X, GridDim.Y, GridDim.Z};
    dim3 CudaBlockDim = {BlockDim.X, BlockDim.Y, BlockDim.Z};
    cudaStream_t CudaStream = reinterpret_cast<cudaStream_t>(Stream);

    // NOTE(review): the call expression on the original line 63 was lost in
    // extraction; reconstructed from the surviving argument list and the
    // launchKernelFunction declaration in CoreDeviceCUDA.h -- verify.
    return launchKernelFunction(
        reinterpret_cast<cudaFunction_t>(KernelFunc), CudaGridDim, CudaBlockDim,
        KernelArgs, ShmemSize, CudaStream);
  }

  /// Device architecture string reported by the underlying JIT engine.
  StringRef getDeviceArch() const override { return Jit.getDeviceArch(); }

  /// Resolve the device function pointer for \p KernelName inside \p Library,
  /// memoized in CodeCache keyed by hash(KernelName, ModuleHash).
  void *getFunctionAddress(const std::string &KernelName,
                           const HashT &ModuleHash,
                           CompiledLibrary &Library) override {
    TIMESCOPE(DispatcherCUDA, getFunctionAddress);
    auto GetKernelFunc = [&]() {
      // Hash the kernel name to get a unique id.
      HashT HashValue = hash(KernelName, ModuleHash);

      if (auto KernelFunc = CodeCache.lookup(HashValue))
        return KernelFunc;

      // NOTE(review): the call head on the original line 81 was lost in
      // extraction; reconstructed from the surviving argument list and the
      // getKernelFunctionFromImage declaration in CoreDeviceCUDA.h -- verify.
      auto KernelFunc = getKernelFunctionFromImage(
          KernelName, Library.ObjectModule->getBufferStart(),
          /*RelinkGlobalsByCopy*/ false,
          /* VarNameToGlobalInfo */ {});

      CodeCache.insert(HashValue, KernelFunc, KernelName);

      return KernelFunc;
    };

    auto KernelFunc = GetKernelFunc();
    return KernelFunc;
  }

  /// Dynamic libraries are not supported by the CUDA dispatcher.
  void registerDynamicLibrary(const HashT &, const std::string &) override {
    reportFatalError("Dispatch CUDA does not support registerDynamicLibrary");
  }

  /// Store an externally-provided object image in the object cache.
  void registerObject(const HashT &HashValue,
                      const llvm::MemoryBufferRef &Obj) override {
    ObjectCache->store(HashValue, CacheEntry::staticObject(Obj));
  }

  /// Print cache statistics and the kernel trace on teardown.
  ~DispatcherCUDA() {
    CodeCache.printStats();
    CodeCache.printKernelTrace();
    ObjectCache->printStats();
  }

private:
  JitEngineDeviceCUDA &Jit;
  // Private ctor: singleton, constructed only via instance().
  DispatcherCUDA()
      : Dispatcher("DispatcherCUDA", TargetModelType::CUDA),
        Jit(JitEngineDeviceCUDA::instance()) {}
  // Memoizes resolved CUfunction handles per (kernel name, module hash).
  MemoryCache<CUfunction> CodeCache{"DispatcherCUDA"};
};

} // namespace proteus

#endif

#endif // PROTEUS_FRONTEND_DISPATCHER_CUDA_H
const char * KernelName
Definition CompilerInterfaceDevice.cpp:55
JitEngineHost & Jit
Definition CompilerInterfaceHost.cpp:26
#define TIMESCOPE(...)
Definition TimeTracing.h:66
Definition MemoryCache.h:27
TargetModelType
Definition TargetModel.h:8
HashT hash(FirstT &&First, RestTs &&...Rest)
Definition Hashing.h:168
void reportFatalError(const llvm::Twine &Reason, const char *FILE, unsigned Line)
Definition Error.cpp:14
cudaError_t launchKernelFunction(CUfunction KernelFunc, dim3 GridDim, dim3 BlockDim, void **KernelArgs, uint64_t ShmemSize, CUstream Stream)
Definition CoreDeviceCUDA.h:78
CUfunction getKernelFunctionFromImage(StringRef KernelName, const void *Image, bool RelinkGlobalsByCopy, const std::unordered_map< std::string, GlobalVarInfo > &VarNameToGlobalInfo)
Definition CoreDeviceCUDA.h:50
Definition Dispatcher.h:22
unsigned Z
Definition Dispatcher.h:23
unsigned Y
Definition Dispatcher.h:23
unsigned X
Definition Dispatcher.h:23