Proteus
Programmable JIT compilation and optimization for C/C++ using LLVM
Loading...
Searching...
No Matches
DispatcherCUDA.hpp
Go to the documentation of this file.
1#ifndef PROTEUS_FRONTEND_DISPATCHER_CUDA_HPP
2#define PROTEUS_FRONTEND_DISPATCHER_CUDA_HPP
3
4#if PROTEUS_ENABLE_CUDA
5
8
9namespace proteus {
10
// Dispatcher implementation whose constructor sets TargetModel to
// TargetModelType::CUDA. The constructor is private; instance() hands out a
// reference to a function-local static object.
//
// NOTE(review): this listing carries embedded source line numbers, and the
// numbering skips 29, 32, 52, 59-61, 75, 85, 90 and 94. The statements on
// those lines are not visible here, so the comments below describe only what
// the remaining lines show and mark every gap. Confirm against the real header.
11class DispatcherCUDA : public Dispatcher {
12public:
   // Returns a reference to the function-local static DispatcherCUDA object.
13 static DispatcherCUDA &instance() {
14 static DispatcherCUDA D;
15 return D;
16 }
17
   // Compiles Mod to an object buffer and records it in ObjectCache.
   //
   // Ctx and Mod are moved into locals so this function owns both for the
   // duration of the call. The libdevice bitcode file at LIBDEVICE_BC_PATH is
   // read, linked with the LinkOnlyNeeded flag, and the module is then handed
   // to Jit.compileOnly together with DisableIROpt. The resulting buffer is
   // stored in ObjectCache under ModuleHash and returned to the caller.
   // A null result from compileOnly is reported through PROTEUS_FATAL_ERROR.
18 std::unique_ptr<MemoryBuffer>
19 compile([[maybe_unused]] std::unique_ptr<LLVMContext> Ctx,
20 std::unique_ptr<Module> Mod, HashT ModuleHash,
21 bool DisableIROpt = false) override {
22 // This is necessary to ensure Ctx outlives M. Setting [[maybe_unused]] can
23 // trigger a lifetime bug.
24 auto CtxOwner = std::move(Ctx);
25 auto ModOwner = std::move(Mod);
26
27 // CMake finds LIBDEVICE_BC_PATH.
28 auto LibDeviceBuffer = llvm::MemoryBuffer::getFile(LIBDEVICE_BC_PATH);
   // NOTE(review): source line 29 is missing here; it presumably begins the
   // call that produces LibDeviceModule from the buffer and the module's
   // context (the next line is the tail of that call) -- TODO confirm.
   // NOTE(review): LibDeviceBuffer is dereferenced with no visible check that
   // reading the file succeeded -- confirm whether a missing libdevice file
   // should be reported explicitly.
30 LibDeviceBuffer->get()->getMemBufferRef(), ModOwner->getContext());
31
   // NOTE(review): source line 32 is missing here; it presumably declares
   // `linker` (likely bound to the module owned by ModOwner) -- TODO confirm.
   // NOTE(review): no failure check on LibDeviceModule is visible before
   // .get(), and the return value of linkInModule is not used -- verify that
   // parse/link failures are handled somewhere.
33 linker.linkInModule(std::move(LibDeviceModule.get()),
34 llvm::Linker::Flags::LinkOnlyNeeded);
35
36 std::unique_ptr<MemoryBuffer> ObjectModule =
37 Jit.compileOnly(*ModOwner, DisableIROpt);
38 if (!ObjectModule)
39 PROTEUS_FATAL_ERROR("Expected non-null object library");
40
   // Record the compiled object under the module hash before returning it.
41 ObjectCache.store(ModuleHash, ObjectModule->getMemBufferRef());
42
43 return ObjectModule;
44 }
45
   // Returns the result of ObjectCache.lookup for ModuleHash.
46 std::unique_ptr<CompiledLibrary>
47 lookupCompiledLibrary(HashT ModuleHash) override {
48 return ObjectCache.lookup(ModuleHash);
49 }
50
   // Prepares CUDA launch parameters from the generic dispatcher arguments:
   // GridDim and BlockDim are copied into dim3 values, Stream is cast to
   // cudaStream_t, and the KernelArgs storage is exposed as a void** with
   // constness cast away.
   //
   // NOTE(review): source line 52 (part of the parameter list, presumably
   // declaring BlockDim and KernelArgs) and source lines 59-61 (presumably the
   // actual kernel launch and the return statement) are missing from this
   // listing -- TODO confirm against the real header.
51 DispatchResult launch(void *KernelFunc, LaunchDims GridDim,
53 uint64_t ShmemSize, void *Stream) override {
54 dim3 CudaGridDim = {GridDim.X, GridDim.Y, GridDim.Z};
55 dim3 CudaBlockDim = {BlockDim.X, BlockDim.Y, BlockDim.Z};
56 cudaStream_t CudaStream = reinterpret_cast<cudaStream_t>(Stream);
57
58 void **KernelArgsPtrs = const_cast<void **>(KernelArgs.data());
62 }
63
   // Forwards to Jit.getDeviceArch().
64 StringRef getDeviceArch() const override { return Jit.getDeviceArch(); }
65
   // Resolves a kernel function for KernelName within the module identified by
   // ModuleHash, consulting CodeCache first.
   //
   // The lambda hashes (KernelName, ModuleHash) into a cache key and returns
   // the cached entry when CodeCache.lookup yields one. Otherwise a kernel
   // function is obtained using KernelName and the start of
   // Library.ObjectModule's buffer, inserted into CodeCache under the key, and
   // returned.
66 void *getFunctionAddress(StringRef KernelName, HashT ModuleHash,
67 CompiledLibrary &Library) override {
68 auto GetKernelFunc = [&]() {
69 // Hash the kernel name to get a unique id.
70 HashT HashValue = hash(KernelName, ModuleHash);
71
72 if (auto KernelFunc = CodeCache.lookup(HashValue))
73 return KernelFunc;
74
   // NOTE(review): source line 75 is missing here; it presumably declares
   // KernelFunc and begins the call whose arguments follow -- TODO confirm
   // which function is called.
76 KernelName, Library.ObjectModule->getBufferStart(),
77 /*RelinkGlobalsByCopy*/ false,
78 /* VarNameToGlobalInfo */ {});
79
80 CodeCache.insert(HashValue, KernelFunc, KernelName);
81
82 return KernelFunc;
83 };
84
   // NOTE(review): source line 85 is missing here; it presumably invokes
   // GetKernelFunc and binds the result to KernelFunc -- TODO confirm.
86 return KernelFunc;
87 }
88
   // Not supported by this dispatcher; both parameters are unnamed and unused.
   // NOTE(review): source line 90 is missing; it presumably opens the
   // error-reporting call that takes the message below -- TODO confirm.
89 void registerDynamicLibrary(HashT, const SmallString<128> &) override {
91 "Dispatch CUDA does not support registerDynamicLibrary");
92 }
93
   // NOTE(review): source line 94 is missing; the body below prints the
   // statistics of both caches and presumably belongs to the destructor --
   // TODO confirm.
95 CodeCache.printStats();
96 ObjectCache.printStats();
97 }
98
99private:
   // JIT engine used for compilation and device-architecture queries; bound in
   // the constructor to JitEngineDeviceCUDA::instance().
100 JitEngineDeviceCUDA &Jit;
   // Private constructor: binds Jit and tags this dispatcher as the CUDA
   // target model.
101 DispatcherCUDA() : Jit(JitEngineDeviceCUDA::instance()) {
102 TargetModel = TargetModelType::CUDA;
103 }
   // Cache of CUfunction entries, keyed in getFunctionAddress by the hash of
   // (KernelName, ModuleHash).
104 MemoryCache<CUfunction> CodeCache{"DispatcherCUDA"};
   // Cache of compiled object buffers, keyed by module hash (see compile and
   // lookupCompiledLibrary).
105 StorageCache ObjectCache{"DispatcherCUDA"};
106};
107
108} // namespace proteus
109
110#endif
111
112#endif // PROTEUS_FRONTEND_DISPATCHER_CUDA_HPP
void char * KernelName
Definition CompilerInterfaceDevice.cpp:52
auto & Jit
Definition CompilerInterfaceDevice.cpp:56
#define PROTEUS_FATAL_ERROR(x)
Definition Error.h:7
Definition StorageCache.cpp:24
HashT hash(FirstT &&First, RestTs &&...Rest)
Definition Hashing.hpp:126
T getRuntimeConstantValue(void *Arg)
Definition CompilerInterfaceRuntimeConstantInfo.h:114
cudaError_t launchKernelFunction(CUfunction KernelFunc, dim3 GridDim, dim3 BlockDim, void **KernelArgs, uint64_t ShmemSize, CUstream Stream)
Definition CoreDeviceCUDA.hpp:56
CUfunction getKernelFunctionFromImage(StringRef KernelName, const void *Image, bool RelinkGlobalsByCopy, const std::unordered_map< std::string, GlobalVarInfo > &VarNameToGlobalInfo)
Definition CoreDeviceCUDA.hpp:28
Definition Dispatcher.hpp:16
unsigned Z
Definition Dispatcher.hpp:17
unsigned Y
Definition Dispatcher.hpp:17
unsigned X
Definition Dispatcher.hpp:17