Proteus
Programmable JIT compilation and optimization for C/C++ using LLVM
Loading...
Searching...
No Matches
DispatcherHIP.h
Go to the documentation of this file.
1#ifndef PROTEUS_FRONTEND_DISPATCHER_HIP_H
2#define PROTEUS_FRONTEND_DISPATCHER_HIP_H
3
4#if PROTEUS_ENABLE_HIP
5
6#include "proteus/Error.h"
11
12#include <llvm/Bitcode/BitcodeReader.h>
13#include <llvm/Linker/Linker.h>
14#include <llvm/Support/FileSystem.h>
15#include <llvm/Support/MemoryBuffer.h>
16
17namespace proteus {
18
19class DispatcherHIP : public Dispatcher {
20public:
21 static DispatcherHIP &instance() {
22 static DispatcherHIP D;
23 return D;
24 }
25
26 std::unique_ptr<MemoryBuffer> compile(std::unique_ptr<LLVMContext> Ctx,
27 std::unique_ptr<Module> Mod,
28 const HashT &ModuleHash,
29 bool DisableIROpt = false) override {
30 TIMESCOPE(DispatcherHIP, compile);
31 // This is necessary to ensure Ctx outlives M. Setting [[maybe_unused]] can
32 // trigger a lifetime bug.
33 auto CtxOwner = std::move(Ctx);
34 auto ModOwner = std::move(Mod);
35
36 auto LoadBitcode = [&](const llvm::SmallString<256> &Path) {
37 auto BufferOrErr = llvm::MemoryBuffer::getFile(Path);
38 if (!BufferOrErr || !BufferOrErr.get())
39 reportFatalError("DispatchHIP: failed to read ROCm bitcode file: " +
40 Path.str().str());
41 auto Parsed = llvm::parseBitcodeFile(
42 BufferOrErr->get()->getMemBufferRef(), ModOwner->getContext());
43 if (!Parsed)
44 reportFatalError("DispatchHIP: failed to parse ROCm bitcode file: " +
45 Path.str().str());
46 return std::move(Parsed.get());
47 };
48
49 auto AppendBitcodePath =
50 [&](llvm::SmallVectorImpl<llvm::SmallString<256>> &Paths,
51 llvm::StringRef Filename) {
52 llvm::SmallString<256> Path{PROTEUS_ROCM_BITCODE_DIR};
53 llvm::sys::path::append(Path, Filename);
54 Paths.push_back(std::move(Path));
55 };
56
57 auto Exists = [&](llvm::StringRef Filename) -> bool {
58 llvm::SmallString<256> Path{PROTEUS_ROCM_BITCODE_DIR};
59 llvm::sys::path::append(Path, Filename);
60 return llvm::sys::fs::exists(Path);
61 };
62
63 auto PickFirstExisting =
64 [&](std::initializer_list<llvm::StringRef> Candidates)
65 -> llvm::StringRef {
66 for (auto C : Candidates) {
67 if (Exists(C))
68 return C;
69 }
70 return {};
71 };
72
73 // Link ROCm device libraries (ocml/ockl + oclc config) so HIPRTC can
74 // resolve __ocml_* calls produced by math lowering.
75 llvm::SmallVector<llvm::SmallString<256>, 8> LibsToLink;
76 AppendBitcodePath(LibsToLink, "ocml.bc");
77 AppendBitcodePath(LibsToLink, "ockl.bc");
78
79 // ABI: prefer the newest available.
80 if (auto Abi = PickFirstExisting({"oclc_abi_version_600.bc",
81 "oclc_abi_version_500.bc",
82 "oclc_abi_version_400.bc"});
83 !Abi.empty()) {
84 AppendBitcodePath(LibsToLink, Abi);
85 } else {
87 std::string("DispatchHIP: missing oclc ABI bitcode under ") +
88 PROTEUS_ROCM_BITCODE_DIR +
89 " (expected oclc_abi_version_{600,500,400}.bc)");
90 }
91
92 // ISA: derived from device arch like "gfx90a" -> "90a".
93 const std::string DeviceArch = Jit.getDeviceArch().str();
94 if (!llvm::StringRef{DeviceArch}.starts_with("gfx"))
95 reportFatalError("DispatchHIP: unexpected HIP device arch: " +
96 DeviceArch);
97 const llvm::StringRef IsaSuffix = llvm::StringRef{DeviceArch}.drop_front(3);
98 const std::string IsaFile = ("oclc_isa_version_" + IsaSuffix + ".bc").str();
99 if (!Exists(IsaFile))
100 reportFatalError(std::string("DispatchHIP: missing ISA bitcode file ") +
101 IsaFile + " under " + PROTEUS_ROCM_BITCODE_DIR +
102 " (DeviceArch=" + DeviceArch + ")");
103 AppendBitcodePath(LibsToLink, IsaFile);
104
105 // Math/FP mode defaults (safe defaults, can be revisited later).
106 AppendBitcodePath(LibsToLink, "oclc_unsafe_math_off.bc");
107 AppendBitcodePath(LibsToLink, "oclc_finite_only_off.bc");
108 AppendBitcodePath(LibsToLink, "oclc_daz_opt_off.bc");
109 AppendBitcodePath(LibsToLink, "oclc_correctly_rounded_sqrt_on.bc");
110
111 // Wavefront size selection: RDNA is typically wave32; CDNA/gfx9 wave64.
112 const bool IsWave32 = llvm::StringRef{DeviceArch}.starts_with("gfx10") ||
113 llvm::StringRef{DeviceArch}.starts_with("gfx11") ||
114 llvm::StringRef{DeviceArch}.starts_with("gfx12");
115 AppendBitcodePath(LibsToLink, IsWave32 ? "oclc_wavefrontsize64_off.bc"
116 : "oclc_wavefrontsize64_on.bc");
117
118 llvm::Linker Linker{*ModOwner};
119 for (const auto &Path : LibsToLink) {
120 auto LibMod = LoadBitcode(Path);
121 Linker.linkInModule(std::move(LibMod),
122 llvm::Linker::Flags::LinkOnlyNeeded);
123 }
124
125 std::unique_ptr<MemoryBuffer> ObjectModule =
126 Jit.compileOnly(*ModOwner, DisableIROpt);
127 if (!ObjectModule)
128 reportFatalError("Expected non-null object library");
129
130 ObjectCache->store(
131 ModuleHash, CacheEntry::staticObject(ObjectModule->getMemBufferRef()));
132
133 return ObjectModule;
134 }
135
136 std::unique_ptr<CompiledLibrary>
137 lookupCompiledLibrary(const HashT &ModuleHash) override {
138 return ObjectCache->lookup(ModuleHash);
139 }
140
141 DispatchResult launch(void *KernelFunc, LaunchDims GridDim,
142 LaunchDims BlockDim, void *KernelArgs[],
143 uint64_t ShmemSize, void *Stream) override {
144 TIMESCOPE(DispatcherHIP, launch);
145 dim3 HipGridDim = {GridDim.X, GridDim.Y, GridDim.Z};
146 dim3 HipBlockDim = {BlockDim.X, BlockDim.Y, BlockDim.Z};
147 hipStream_t HipStream = reinterpret_cast<hipStream_t>(Stream);
148
150 reinterpret_cast<hipFunction_t>(KernelFunc), HipGridDim, HipBlockDim,
151 KernelArgs, ShmemSize, HipStream);
152 }
153
154 StringRef getDeviceArch() const override { return Jit.getDeviceArch(); }
155
156 ~DispatcherHIP() {
157 CodeCache.printStats();
158 CodeCache.printKernelTrace();
159 ObjectCache->printStats();
160 }
161
162 void *getFunctionAddress(const std::string &KernelName,
163 const HashT &ModuleHash,
164 CompiledLibrary &Library) override {
165 TIMESCOPE(DispatcherHIP, getFunctionAddress);
166 auto GetKernelFunc = [&]() {
167 // Hash the kernel name to get a unique id.
168 HashT HashValue = hash(KernelName, ModuleHash);
169
170 if (auto KernelFunc = CodeCache.lookup(HashValue))
171 return KernelFunc;
172
173 auto KernelFunc = proteus::getKernelFunctionFromImage(
174 KernelName, Library.ObjectModule->getBufferStart(),
175 /*RelinkGlobalsByCopy*/ false,
176 /* VarNameToGlobalInfo */ {});
177
178 CodeCache.insert(HashValue, KernelFunc, KernelName);
179
180 return KernelFunc;
181 };
182
183 auto KernelFunc = GetKernelFunc();
184 return KernelFunc;
185 }
186
187 void registerDynamicLibrary(const HashT &, const std::string &) override {
188 reportFatalError("Dispatch HIP does not support registerDynamicLibrary");
189 }
190
191 void registerObject(const HashT &HashValue,
192 const llvm::MemoryBufferRef &Obj) override {
193 ObjectCache->store(HashValue, CacheEntry::staticObject(Obj));
194 }
195
196private:
197 JitEngineDeviceHIP &Jit;
198 DispatcherHIP()
199 : Dispatcher("DispatcherHIP", TargetModelType::HIP),
200 Jit(JitEngineDeviceHIP::instance()) {}
201 MemoryCache<hipFunction_t> CodeCache{"DispatcherHIP"};
202};
203
204} // namespace proteus
205
206#endif
207
208#endif // PROTEUS_FRONTEND_DISPATCHER_HIP_H
void char * KernelName
Definition CompilerInterfaceDevice.cpp:55
JitEngineHost & Jit
Definition CompilerInterfaceHost.cpp:26
#define TIMESCOPE(...)
Definition TimeTracing.h:66
Definition MemoryCache.h:27
TargetModelType
Definition TargetModel.h:8
HashT hash(FirstT &&First, RestTs &&...Rest)
Definition Hashing.h:168
void reportFatalError(const llvm::Twine &Reason, const char *FILE, unsigned Line)
Definition Error.cpp:14
cudaError_t launchKernelFunction(CUfunction KernelFunc, dim3 GridDim, dim3 BlockDim, void **KernelArgs, uint64_t ShmemSize, CUstream Stream)
Definition CoreDeviceCUDA.h:78
CUfunction getKernelFunctionFromImage(StringRef KernelName, const void *Image, bool RelinkGlobalsByCopy, const std::unordered_map< std::string, GlobalVarInfo > &VarNameToGlobalInfo)
Definition CoreDeviceCUDA.h:50
Definition Dispatcher.h:22
unsigned Z
Definition Dispatcher.h:23
unsigned Y
Definition Dispatcher.h:23
unsigned X
Definition Dispatcher.h:23