Proteus
Programmable JIT compilation and optimization for C/C++ using LLVM
Loading...
Searching...
No Matches
DispatcherHIP.h
Go to the documentation of this file.
1#ifndef PROTEUS_FRONTEND_DISPATCHER_HIP_H
2#define PROTEUS_FRONTEND_DISPATCHER_HIP_H
3
4#if PROTEUS_ENABLE_HIP
5
6#include "proteus/Error.h"
12
13#include <llvm/Bitcode/BitcodeReader.h>
14#include <llvm/Linker/Linker.h>
15#include <llvm/Support/FileSystem.h>
16#include <llvm/Support/MemoryBuffer.h>
17
18namespace proteus {
19
20class DispatcherHIP : public Dispatcher {
21public:
22 static DispatcherHIP &instance() {
23 static DispatcherHIP D;
24 return D;
25 }
26
27 std::unique_ptr<MemoryBuffer> compile(std::unique_ptr<LLVMContext> Ctx,
28 std::unique_ptr<Module> Mod,
29 const HashT &ModuleHash,
30 bool DisableIROpt = false) override {
31 TIMESCOPE(DispatcherHIP, compile);
32 // This is necessary to ensure Ctx outlives M. Setting [[maybe_unused]] can
33 // trigger a lifetime bug.
34 auto CtxOwner = std::move(Ctx);
35 auto ModOwner = std::move(Mod);
36 const auto &Toolchain = resolveHIPToolchain();
37
38 auto LoadBitcode = [&](const llvm::SmallString<256> &Path) {
39 auto BufferOrErr = llvm::MemoryBuffer::getFile(Path);
40 if (!BufferOrErr || !BufferOrErr.get())
41 reportFatalError("DispatchHIP: failed to read ROCm bitcode file: " +
42 Path.str().str() + " (" + Toolchain.Origin + ")");
43 auto Parsed = llvm::parseBitcodeFile(
44 BufferOrErr->get()->getMemBufferRef(), ModOwner->getContext());
45 if (!Parsed)
46 reportFatalError("DispatchHIP: failed to parse ROCm bitcode file: " +
47 Path.str().str() + " (" + Toolchain.Origin + ")");
48 return std::move(Parsed.get());
49 };
50
51 auto AppendBitcodePath =
52 [&](llvm::SmallVectorImpl<llvm::SmallString<256>> &Paths,
53 llvm::StringRef Filename) {
54 llvm::SmallString<256> Path{Toolchain.DeviceLibDir};
55 llvm::sys::path::append(Path, Filename);
56 Paths.push_back(std::move(Path));
57 };
58
59 auto Exists = [&](llvm::StringRef Filename) -> bool {
60 llvm::SmallString<256> Path{Toolchain.DeviceLibDir};
61 llvm::sys::path::append(Path, Filename);
62 return llvm::sys::fs::exists(Path);
63 };
64
65 auto PickFirstExisting =
66 [&](std::initializer_list<llvm::StringRef> Candidates)
67 -> llvm::StringRef {
68 for (auto C : Candidates) {
69 if (Exists(C))
70 return C;
71 }
72 return {};
73 };
74
75 // Link ROCm device libraries (ocml/ockl + oclc config) so HIPRTC can
76 // resolve __ocml_* calls produced by math lowering.
77 llvm::SmallVector<llvm::SmallString<256>, 8> LibsToLink;
78 AppendBitcodePath(LibsToLink, "ocml.bc");
79 AppendBitcodePath(LibsToLink, "ockl.bc");
80
81 // ABI: prefer the newest available.
82 if (auto Abi = PickFirstExisting({"oclc_abi_version_600.bc",
83 "oclc_abi_version_500.bc",
84 "oclc_abi_version_400.bc"});
85 !Abi.empty()) {
86 AppendBitcodePath(LibsToLink, Abi);
87 } else {
89 std::string("DispatchHIP: missing oclc ABI bitcode under ") +
90 Toolchain.DeviceLibDir + " (" + Toolchain.Origin +
91 "; expected oclc_abi_version_{600,500,400}.bc)");
92 }
93
94 // ISA: derived from device arch like "gfx90a" -> "90a".
95 const std::string DeviceArch = Jit.getDeviceArch().str();
96 if (!llvm::StringRef{DeviceArch}.starts_with("gfx"))
97 reportFatalError("DispatchHIP: unexpected HIP device arch: " +
98 DeviceArch);
99 const llvm::StringRef IsaSuffix = llvm::StringRef{DeviceArch}.drop_front(3);
100 const std::string IsaFile = ("oclc_isa_version_" + IsaSuffix + ".bc").str();
101 if (!Exists(IsaFile))
102 reportFatalError(std::string("DispatchHIP: missing ISA bitcode file ") +
103 IsaFile + " under " + Toolchain.DeviceLibDir + " (" +
104 Toolchain.Origin + "; DeviceArch=" + DeviceArch + ")");
105 AppendBitcodePath(LibsToLink, IsaFile);
106
107 // Math/FP mode defaults (safe defaults, can be revisited later).
108 AppendBitcodePath(LibsToLink, "oclc_unsafe_math_off.bc");
109 AppendBitcodePath(LibsToLink, "oclc_finite_only_off.bc");
110 AppendBitcodePath(LibsToLink, "oclc_daz_opt_off.bc");
111 AppendBitcodePath(LibsToLink, "oclc_correctly_rounded_sqrt_on.bc");
112
113 // Wavefront size selection: RDNA is typically wave32; CDNA/gfx9 wave64.
114 const bool IsWave32 = llvm::StringRef{DeviceArch}.starts_with("gfx10") ||
115 llvm::StringRef{DeviceArch}.starts_with("gfx11") ||
116 llvm::StringRef{DeviceArch}.starts_with("gfx12");
117 AppendBitcodePath(LibsToLink, IsWave32 ? "oclc_wavefrontsize64_off.bc"
118 : "oclc_wavefrontsize64_on.bc");
119
120 llvm::Linker Linker{*ModOwner};
121 for (const auto &Path : LibsToLink) {
122 auto LibMod = LoadBitcode(Path);
123 Linker.linkInModule(std::move(LibMod),
124 llvm::Linker::Flags::LinkOnlyNeeded);
125 }
126
127 std::unique_ptr<MemoryBuffer> ObjectModule =
128 Jit.compileOnly(*ModOwner, DisableIROpt);
129 if (!ObjectModule)
130 reportFatalError("Expected non-null object library");
131
132 ObjectCache->store(
133 ModuleHash, CacheEntry::staticObject(ObjectModule->getMemBufferRef()));
134
135 return ObjectModule;
136 }
137
138 std::unique_ptr<CompiledLibrary>
139 lookupCompiledLibrary(const HashT &ModuleHash) override {
140 return ObjectCache->lookup(ModuleHash);
141 }
142
143 DispatchResult launch(void *KernelFunc, LaunchDims GridDim,
144 LaunchDims BlockDim, void *KernelArgs[],
145 uint64_t ShmemSize, void *Stream) override {
146 TIMESCOPE(DispatcherHIP, launch);
147 dim3 HipGridDim = {GridDim.X, GridDim.Y, GridDim.Z};
148 dim3 HipBlockDim = {BlockDim.X, BlockDim.Y, BlockDim.Z};
149 hipStream_t HipStream = reinterpret_cast<hipStream_t>(Stream);
150
152 reinterpret_cast<hipFunction_t>(KernelFunc), HipGridDim, HipBlockDim,
153 KernelArgs, ShmemSize, HipStream);
154 }
155
156 StringRef getDeviceArch() const override { return Jit.getDeviceArch(); }
157
158 ~DispatcherHIP() {
159 CodeCache.printStats();
160 CodeCache.printKernelTrace();
161 ObjectCache->printStats();
162 }
163
164 void *getFunctionAddress(const std::string &KernelName,
165 const HashT &ModuleHash,
166 CompiledLibrary &Library) override {
167 TIMESCOPE(DispatcherHIP, getFunctionAddress);
168 auto GetKernelFunc = [&]() {
169 // Hash the kernel name to get a unique id.
170 HashT HashValue = hash(KernelName, ModuleHash);
171
172 if (auto KernelFunc = CodeCache.lookup(HashValue))
173 return KernelFunc;
174
175 auto KernelFunc = proteus::getKernelFunctionFromImage(
176 KernelName, Library.ObjectModule->getBufferStart(),
177 /*RelinkGlobalsByCopy*/ false,
178 /* VarNameToGlobalInfo */ {});
179
180 CodeCache.insert(HashValue, KernelFunc, KernelName);
181
182 return KernelFunc;
183 };
184
185 auto KernelFunc = GetKernelFunc();
186 return KernelFunc;
187 }
188
189 void registerDynamicLibrary(const HashT &, const std::string &) override {
190 reportFatalError("Dispatch HIP does not support registerDynamicLibrary");
191 }
192
193 void registerObject(const HashT &HashValue,
194 const llvm::MemoryBufferRef &Obj) override {
195 ObjectCache->store(HashValue, CacheEntry::staticObject(Obj));
196 }
197
198private:
199 JitEngineDeviceHIP &Jit;
200 DispatcherHIP()
201 : Dispatcher("DispatcherHIP", TargetModelType::HIP),
202 Jit(JitEngineDeviceHIP::instance()) {}
203 MemoryCache<hipFunction_t> CodeCache{"DispatcherHIP"};
204};
205
206} // namespace proteus
207
208#endif
209
210#endif // PROTEUS_FRONTEND_DISPATCHER_HIP_H
void char * KernelName
Definition CompilerInterfaceDevice.cpp:59
JitEngineHost & Jit
Definition CompilerInterfaceHost.cpp:26
#define TIMESCOPE(...)
Definition TimeTracing.h:66
Definition MemoryCache.h:27
TargetModelType
Definition TargetModel.h:8
HashT hash(FirstT &&First, RestTs &&...Rest)
Definition Hashing.h:168
void reportFatalError(const llvm::Twine &Reason, const char *FILE, unsigned Line)
Definition Error.cpp:14
cudaError_t launchKernelFunction(CUfunction KernelFunc, dim3 GridDim, dim3 BlockDim, void **KernelArgs, uint64_t ShmemSize, CUstream Stream)
Definition CoreDeviceCUDA.h:78
const ResolvedHIPToolchain & resolveHIPToolchain()
Definition HIPToolchain.cpp:273
CUfunction getKernelFunctionFromImage(StringRef KernelName, const void *Image, bool RelinkGlobalsByCopy, const std::unordered_map< std::string, GlobalVarInfo > &VarNameToGlobalInfo)
Definition CoreDeviceCUDA.h:50
Definition Dispatcher.h:22
unsigned Z
Definition Dispatcher.h:23
unsigned Y
Definition Dispatcher.h:23
unsigned X
Definition Dispatcher.h:23