Proteus
Programmable JIT compilation and optimization for C/C++ using LLVM
Loading...
Searching...
No Matches
DispatcherCUDA.h
Go to the documentation of this file.
1#ifndef PROTEUS_FRONTEND_DISPATCHER_CUDA_H
2#define PROTEUS_FRONTEND_DISPATCHER_CUDA_H
3
4#if PROTEUS_ENABLE_CUDA
5
9
10namespace proteus {
11
12class DispatcherCUDA : public Dispatcher {
13public:
14 static DispatcherCUDA &instance() {
15 static DispatcherCUDA D;
16 return D;
17 }
18
19 std::unique_ptr<MemoryBuffer> compile(std::unique_ptr<LLVMContext> Ctx,
20 std::unique_ptr<Module> Mod,
21 const HashT &ModuleHash,
22 bool DisableIROpt = false) override {
23 // This is necessary to ensure Ctx outlives M. Setting [[maybe_unused]] can
24 // trigger a lifetime bug.
25 auto CtxOwner = std::move(Ctx);
26 auto ModOwner = std::move(Mod);
27
28 // CMake finds LIBDEVICE_BC_PATH.
29 auto LibDeviceBuffer = llvm::MemoryBuffer::getFile(LIBDEVICE_BC_PATH);
30 auto LibDeviceModule = llvm::parseBitcodeFile(
31 LibDeviceBuffer->get()->getMemBufferRef(), ModOwner->getContext());
32
33 llvm::Linker linker(*ModOwner);
34 linker.linkInModule(std::move(LibDeviceModule.get()),
35 llvm::Linker::Flags::LinkOnlyNeeded);
36
37 std::unique_ptr<MemoryBuffer> ObjectModule =
38 Jit.compileOnly(*ModOwner, DisableIROpt);
39 if (!ObjectModule)
40 reportFatalError("Expected non-null object library");
41
42 ObjectCache->store(
43 ModuleHash, CacheEntry::staticObject(ObjectModule->getMemBufferRef()));
44
45 return ObjectModule;
46 }
47
48 std::unique_ptr<CompiledLibrary>
49 lookupCompiledLibrary(const HashT &ModuleHash) override {
50 return ObjectCache->lookup(ModuleHash);
51 }
52
53 DispatchResult launch(void *KernelFunc, LaunchDims GridDim,
54 LaunchDims BlockDim, void *KernelArgs[],
55 uint64_t ShmemSize, void *Stream) override {
56 dim3 CudaGridDim = {GridDim.X, GridDim.Y, GridDim.Z};
57 dim3 CudaBlockDim = {BlockDim.X, BlockDim.Y, BlockDim.Z};
58 cudaStream_t CudaStream = reinterpret_cast<cudaStream_t>(Stream);
59
61 reinterpret_cast<cudaFunction_t>(KernelFunc), CudaGridDim, CudaBlockDim,
62 KernelArgs, ShmemSize, CudaStream);
63 }
64
65 StringRef getDeviceArch() const override { return Jit.getDeviceArch(); }
66
67 void *getFunctionAddress(const std::string &KernelName,
68 const HashT &ModuleHash,
69 CompiledLibrary &Library) override {
70 auto GetKernelFunc = [&]() {
71 // Hash the kernel name to get a unique id.
72 HashT HashValue = hash(KernelName, ModuleHash);
73
74 if (auto KernelFunc = CodeCache.lookup(HashValue))
75 return KernelFunc;
76
78 KernelName, Library.ObjectModule->getBufferStart(),
79 /*RelinkGlobalsByCopy*/ false,
80 /* VarNameToGlobalInfo */ {});
81
82 CodeCache.insert(HashValue, KernelFunc, KernelName);
83
84 return KernelFunc;
85 };
86
87 auto KernelFunc = GetKernelFunc();
88 return KernelFunc;
89 }
90
91 void registerDynamicLibrary(const HashT &, const std::string &) override {
92 reportFatalError("Dispatch CUDA does not support registerDynamicLibrary");
93 }
94
95 ~DispatcherCUDA() {
96 CodeCache.printStats();
97 CodeCache.printKernelTrace();
98 ObjectCache->printStats();
99 }
100
101private:
102 JitEngineDeviceCUDA &Jit;
103 DispatcherCUDA()
104 : Dispatcher("DispatcherCUDA", TargetModelType::CUDA),
105 Jit(JitEngineDeviceCUDA::instance()) {}
106 MemoryCache<CUfunction> CodeCache{"DispatcherCUDA"};
107};
108
109} // namespace proteus
110
111#endif
112
113#endif // PROTEUS_FRONTEND_DISPATCHER_CUDA_H
void char * KernelName
Definition CompilerInterfaceDevice.cpp:54
JitEngineHost & Jit
Definition CompilerInterfaceHost.cpp:25
Definition MemoryCache.h:26
TargetModelType
Definition TargetModel.h:8
HashT hash(FirstT &&First, RestTs &&...Rest)
Definition Hashing.h:142
void reportFatalError(const llvm::Twine &Reason, const char *FILE, unsigned Line)
Definition Error.cpp:14
cudaError_t launchKernelFunction(CUfunction KernelFunc, dim3 GridDim, dim3 BlockDim, void **KernelArgs, uint64_t ShmemSize, CUstream Stream)
Definition CoreDeviceCUDA.h:78
CUfunction getKernelFunctionFromImage(StringRef KernelName, const void *Image, bool RelinkGlobalsByCopy, const std::unordered_map< std::string, GlobalVarInfo > &VarNameToGlobalInfo)
Definition CoreDeviceCUDA.h:50
Definition Dispatcher.h:21
unsigned Z
Definition Dispatcher.h:22
unsigned Y
Definition Dispatcher.h:22
unsigned X
Definition Dispatcher.h:22