Proteus
Programmable JIT compilation and optimization for C/C++ using LLVM
Loading...
Searching...
No Matches
DispatcherCUDA.h
Go to the documentation of this file.
1#ifndef PROTEUS_FRONTEND_DISPATCHER_CUDA_H
2#define PROTEUS_FRONTEND_DISPATCHER_CUDA_H
3
4#if PROTEUS_ENABLE_CUDA
5
6#include "proteus/Error.h"
12
13#include <llvm/Bitcode/BitcodeReader.h>
14#include <llvm/Linker/Linker.h>
15#include <llvm/Support/MemoryBuffer.h>
16
17namespace proteus {
18
19class DispatcherCUDA : public Dispatcher {
20public:
21 static DispatcherCUDA &instance() {
22 static DispatcherCUDA D;
23 return D;
24 }
25
26 std::unique_ptr<MemoryBuffer> compile(std::unique_ptr<LLVMContext> Ctx,
27 std::unique_ptr<Module> Mod,
28 const HashT &ModuleHash,
29 bool DisableIROpt = false) override {
30 TIMESCOPE(DispatcherCUDA, compile);
31 // This is necessary to ensure Ctx outlives M. Setting [[maybe_unused]] can
32 // trigger a lifetime bug.
33 auto CtxOwner = std::move(Ctx);
34 auto ModOwner = std::move(Mod);
35
36 const auto &Toolchain = resolveCUDAToolchain();
37 auto LibDeviceBuffer = llvm::MemoryBuffer::getFile(Toolchain.LibDevicePath);
38 if (!LibDeviceBuffer || !LibDeviceBuffer.get())
39 reportFatalError("DispatchCUDA: failed to read libdevice from " +
40 Toolchain.LibDevicePath + " (" + Toolchain.Origin + ")");
41
42 auto LibDeviceModule = llvm::parseBitcodeFile(
43 LibDeviceBuffer->get()->getMemBufferRef(), ModOwner->getContext());
44 if (!LibDeviceModule)
45 reportFatalError("DispatchCUDA: failed to parse libdevice from " +
46 Toolchain.LibDevicePath + " (" + Toolchain.Origin + ")");
47
48 llvm::Linker linker(*ModOwner);
49 linker.linkInModule(std::move(LibDeviceModule.get()),
50 llvm::Linker::Flags::LinkOnlyNeeded);
51
52 std::unique_ptr<MemoryBuffer> ObjectModule =
53 Jit.compileOnly(*ModOwner, DisableIROpt);
54 if (!ObjectModule)
55 reportFatalError("Expected non-null object library");
56
57 ObjectCache->store(
58 ModuleHash, CacheEntry::staticObject(ObjectModule->getMemBufferRef()));
59
60 return ObjectModule;
61 }
62
63 std::unique_ptr<CompiledLibrary>
64 lookupCompiledLibrary(const HashT &ModuleHash) override {
65 return ObjectCache->lookup(ModuleHash);
66 }
67
68 DispatchResult launch(void *KernelFunc, LaunchDims GridDim,
69 LaunchDims BlockDim, void *KernelArgs[],
70 uint64_t ShmemSize, void *Stream) override {
71 TIMESCOPE(DispatcherCUDA, launch);
72 dim3 CudaGridDim = {GridDim.X, GridDim.Y, GridDim.Z};
73 dim3 CudaBlockDim = {BlockDim.X, BlockDim.Y, BlockDim.Z};
74 cudaStream_t CudaStream = reinterpret_cast<cudaStream_t>(Stream);
75
77 reinterpret_cast<cudaFunction_t>(KernelFunc), CudaGridDim, CudaBlockDim,
78 KernelArgs, ShmemSize, CudaStream);
79 }
80
81 StringRef getDeviceArch() const override { return Jit.getDeviceArch(); }
82
83 void *getFunctionAddress(const std::string &KernelName,
84 const HashT &ModuleHash,
85 CompiledLibrary &Library) override {
86 TIMESCOPE(DispatcherCUDA, getFunctionAddress);
87 auto GetKernelFunc = [&]() {
88 // Hash the kernel name to get a unique id.
89 HashT HashValue = hash(KernelName, ModuleHash);
90
91 if (auto KernelFunc = CodeCache.lookup(HashValue))
92 return KernelFunc;
93
95 KernelName, Library.ObjectModule->getBufferStart(),
96 /*RelinkGlobalsByCopy*/ false,
97 /* VarNameToGlobalInfo */ {});
98
99 CodeCache.insert(HashValue, KernelFunc, KernelName);
100
101 return KernelFunc;
102 };
103
104 auto KernelFunc = GetKernelFunc();
105 return KernelFunc;
106 }
107
108 void registerDynamicLibrary(const HashT &, const std::string &) override {
109 reportFatalError("Dispatch CUDA does not support registerDynamicLibrary");
110 }
111
112 void registerObject(const HashT &HashValue,
113 const llvm::MemoryBufferRef &Obj) override {
114 ObjectCache->store(HashValue, CacheEntry::staticObject(Obj));
115 }
116
117 ~DispatcherCUDA() {
118 CodeCache.printStats();
119 CodeCache.printKernelTrace();
120 ObjectCache->printStats();
121 }
122
123private:
124 JitEngineDeviceCUDA &Jit;
125 DispatcherCUDA()
126 : Dispatcher("DispatcherCUDA", TargetModelType::CUDA),
127 Jit(JitEngineDeviceCUDA::instance()) {}
128 MemoryCache<CUfunction> CodeCache{"DispatcherCUDA"};
129};
130
131} // namespace proteus
132
133#endif
134
135#endif // PROTEUS_FRONTEND_DISPATCHER_CUDA_H
void char * KernelName
Definition CompilerInterfaceDevice.cpp:59
JitEngineHost & Jit
Definition CompilerInterfaceHost.cpp:26
#define TIMESCOPE(...)
Definition TimeTracing.h:66
Definition MemoryCache.h:27
TargetModelType
Definition TargetModel.h:8
HashT hash(FirstT &&First, RestTs &&...Rest)
Definition Hashing.h:168
void reportFatalError(const llvm::Twine &Reason, const char *FILE, unsigned Line)
Definition Error.cpp:14
cudaError_t launchKernelFunction(CUfunction KernelFunc, dim3 GridDim, dim3 BlockDim, void **KernelArgs, uint64_t ShmemSize, CUstream Stream)
Definition CoreDeviceCUDA.h:78
CUfunction getKernelFunctionFromImage(StringRef KernelName, const void *Image, bool RelinkGlobalsByCopy, const std::unordered_map< std::string, GlobalVarInfo > &VarNameToGlobalInfo)
Definition CoreDeviceCUDA.h:50
const ResolvedCUDAToolchain & resolveCUDAToolchain()
Definition CUDAToolchain.cpp:245
Definition Dispatcher.h:22
unsigned Z
Definition Dispatcher.h:23
unsigned Y
Definition Dispatcher.h:23
unsigned X
Definition Dispatcher.h:23