Proteus
Programmable JIT compilation and optimization for C/C++ using LLVM
Loading...
Searching...
No Matches
DispatcherHIP.hpp
Go to the documentation of this file.
1#ifndef PROTEUS_FRONTEND_DISPATCHER_HIP_HPP
2#define PROTEUS_FRONTEND_DISPATCHER_HIP_HPP
3
4#if PROTEUS_ENABLE_HIP
5
6#include "proteus/Error.h"
9
10namespace proteus {
11
12class DispatcherHIP : public Dispatcher {
13public:
14 static DispatcherHIP &instance() {
15 static DispatcherHIP D;
16 return D;
17 }
18
19 std::unique_ptr<MemoryBuffer>
20 compile([[maybe_unused]] std::unique_ptr<LLVMContext> Ctx,
21 std::unique_ptr<Module> M, HashT ModuleHash) override {
22 std::unique_ptr<MemoryBuffer> ObjectModule = Jit.compileOnly(*M);
23 if (!ObjectModule)
24 PROTEUS_FATAL_ERROR("Expected non-null object library");
25
26 StorageCache.store(ModuleHash, ObjectModule->getMemBufferRef());
27
28 return ObjectModule;
29 }
30
31 std::unique_ptr<MemoryBuffer> lookupObjectModule(HashT ModuleHash) override {
32 return StorageCache.lookup(ModuleHash);
33 }
34
35 DispatchResult launch(StringRef KernelName, LaunchDims GridDim,
36 LaunchDims BlockDim, ArrayRef<void *> KernelArgs,
37 uint64_t ShmemSize, void *Stream,
38 std::optional<MemoryBufferRef> ObjectModule) override {
39 auto *KernelFunc = getFunctionAddress(KernelName, ObjectModule);
40
41 dim3 HipGridDim = {GridDim.X, GridDim.Y, GridDim.Z};
42 dim3 HipBlockDim = {BlockDim.X, BlockDim.Y, BlockDim.Z};
43 hipStream_t HipStream = reinterpret_cast<hipStream_t>(Stream);
44
45 void **KernelArgsPtrs = const_cast<void **>(KernelArgs.data());
47 reinterpret_cast<hipFunction_t>(KernelFunc), HipGridDim, HipBlockDim,
48 KernelArgsPtrs, ShmemSize, HipStream);
49 }
50
51 DispatchResult launch(void *KernelFunc, LaunchDims GridDim,
52 LaunchDims BlockDim, ArrayRef<void *> KernelArgs,
53 uint64_t ShmemSize, void *Stream) override {
54 dim3 HipGridDim = {GridDim.X, GridDim.Y, GridDim.Z};
55 dim3 HipBlockDim = {BlockDim.X, BlockDim.Y, BlockDim.Z};
56 hipStream_t HipStream = reinterpret_cast<hipStream_t>(Stream);
57
58 void **KernelArgsPtrs = const_cast<void **>(KernelArgs.data());
60 reinterpret_cast<hipFunction_t>(KernelFunc), HipGridDim, HipBlockDim,
61 KernelArgsPtrs, ShmemSize, HipStream);
62 }
63
64 StringRef getTargetArch() const override { return Jit.getDeviceArch(); }
65
66 ~DispatcherHIP() {
67 CodeCache.printStats();
68 StorageCache.printStats();
69 }
70
71 void *
72 getFunctionAddress(StringRef KernelName,
73 std::optional<MemoryBufferRef> ObjectModule) override {
74 auto GetKernelFunc = [&]() {
75 // Hash the kernel name to get a unique id.
76 HashT HashValue = hash(KernelName);
77
78 if (auto KernelFunc = CodeCache.lookup(HashValue))
79 return KernelFunc;
80
82 KernelName, ObjectModule->getBufferStart(),
83 /*RelinkGlobalsByCopy*/ false,
84 /* VarNameToDevPtr */ {});
85
86 CodeCache.insert(HashValue, KernelFunc, KernelName);
87
88 return KernelFunc;
89 };
90
91 auto KernelFunc = GetKernelFunc();
92 return KernelFunc;
93 }
94
95private:
96 JitEngineDeviceHIP &Jit;
97 DispatcherHIP() : Jit(JitEngineDeviceHIP::instance()) {
98 TargetModel = TargetModelType::HIP;
99 }
100 JitCache<hipFunction_t> CodeCache;
101 JitStorageCache<hipFunction_t> StorageCache;
102};
103
104} // namespace proteus
105
106#endif
107
108#endif // PROTEUS_FRONTEND_DISPATCHER_HIP_HPP
void char * KernelName
Definition CompilerInterfaceDevice.cpp:50
auto & Jit
Definition CompilerInterfaceDevice.cpp:54
#define PROTEUS_FATAL_ERROR(x)
Definition Error.h:7
Definition CppJitModule.cpp:21
HashT hash(FirstT &&First, RestTs &&...Rest)
Definition Hashing.hpp:126
cudaError_t launchKernelFunction(CUfunction KernelFunc, dim3 GridDim, dim3 BlockDim, void **KernelArgs, uint64_t ShmemSize, CUstream Stream)
Definition CoreDeviceCUDA.hpp:51
CUfunction getKernelFunctionFromImage(StringRef KernelName, const void *Image, bool RelinkGlobalsByCopy, const std::unordered_map< std::string, const void * > &VarNameToDevPtr)
Definition CoreDeviceCUDA.hpp:27
Definition Dispatcher.hpp:15
unsigned Z
Definition Dispatcher.hpp:16
unsigned Y
Definition Dispatcher.hpp:16
unsigned X
Definition Dispatcher.hpp:16