Proteus
Programmable JIT compilation and optimization for C/C++ using LLVM
Loading...
Searching...
No Matches
CoreLLVMDevice.h
Go to the documentation of this file.
1#ifndef PROTEUS_CORE_LLVM_DEVICE_H
2#define PROTEUS_CORE_LLVM_DEVICE_H
3
4#if PROTEUS_ENABLE_HIP
6#endif
7
8#if PROTEUS_ENABLE_CUDA
10#endif
11
12#if defined(PROTEUS_ENABLE_HIP) || defined(PROTEUS_ENABLE_CUDA)
13
20
21#include <llvm/Analysis/CallGraph.h>
22#include <llvm/Bitcode/BitcodeReader.h>
23#include <llvm/Bitcode/BitcodeWriter.h>
24#include <llvm/IR/Attributes.h>
25#include <llvm/IR/ConstantRange.h>
26#include <llvm/IR/ReplaceConstant.h>
27#include <llvm/IR/Verifier.h>
28#include <llvm/Object/ELFObjectFile.h>
29#include <llvm/Transforms/Utils/Cloning.h>
30
31namespace proteus {
32
33inline void setKernelDims(Module &M, dim3 &GridDim, dim3 &BlockDim) {
34 auto ReplaceIntrinsicDim = [&](ArrayRef<StringRef> IntrinsicNames,
35 uint32_t DimValue) {
36 auto CollectCallUsers = [](Function &F) {
37 SmallVector<CallInst *> CallUsers;
38 for (auto *User : F.users()) {
39 auto *Call = dyn_cast<CallInst>(User);
40 if (!Call)
41 continue;
42 CallUsers.push_back(Call);
43 }
44
45 return CallUsers;
46 };
47
48 for (auto IntrinsicName : IntrinsicNames) {
49
50 Function *IntrinsicFunction = M.getFunction(IntrinsicName);
51 if (!IntrinsicFunction)
52 continue;
53
54 auto TraceOut = [](Function *F, Value *C) {
55 SmallString<128> S;
56 raw_svector_ostream OS(S);
57 OS << "[DimSpec] Replace call to " << F->getName() << " with constant "
58 << *C << "\n";
59
60 return S;
61 };
62
63 for (auto *Call : CollectCallUsers(*IntrinsicFunction)) {
64 Value *ConstantValue =
65 ConstantInt::get(Type::getInt32Ty(M.getContext()), DimValue);
66 Call->replaceAllUsesWith(ConstantValue);
67 if (Config::get().traceSpecializations())
68 Logger::trace(TraceOut(IntrinsicFunction, ConstantValue));
69 Call->eraseFromParent();
70 }
71 }
72 };
73
74 ReplaceIntrinsicDim(detail::gridDimXFnName(), GridDim.x);
75 ReplaceIntrinsicDim(detail::gridDimYFnName(), GridDim.y);
76 ReplaceIntrinsicDim(detail::gridDimZFnName(), GridDim.z);
77
78 ReplaceIntrinsicDim(detail::blockDimXFnName(), BlockDim.x);
79 ReplaceIntrinsicDim(detail::blockDimYFnName(), BlockDim.y);
80 ReplaceIntrinsicDim(detail::blockDimZFnName(), BlockDim.z);
81}
82
83inline void setKernelDimsRange(Module &M, dim3 &GridDim, dim3 &BlockDim) {
84 auto AttachRange = [&](ArrayRef<StringRef> IntrinsicNames,
85 uint32_t DimValue) {
86 if (DimValue == 0) {
87 reportFatalError("Dimension value cannot be zero");
88 }
89
90 for (auto IntrinsicName : IntrinsicNames) {
91 Function *IntrinsicFunction = M.getFunction(IntrinsicName);
92 if (!IntrinsicFunction || IntrinsicFunction->use_empty())
93 continue;
94
95 auto TraceOut = [](Function *IntrinsicF, uint32_t DimValue) {
96 SmallString<128> S;
97 raw_svector_ostream OS(S);
98 OS << "[DimSpec] Range " << IntrinsicF->getName() << " [0," << DimValue
99 << ")\n";
100 return S;
101 };
102
103 for (auto *U : IntrinsicFunction->users()) {
104 auto *Call = dyn_cast<CallInst>(U);
105 if (!Call)
106 continue;
107
108 auto *RetTy = dyn_cast<IntegerType>(Call->getType());
109 if (!RetTy)
110 continue;
111
112 unsigned BitWidth = RetTy->getBitWidth();
113 ConstantRange Range(APInt(BitWidth, 0), APInt(BitWidth, DimValue));
114
115#if LLVM_VERSION_MAJOR >= 19
116 AttrBuilder Builder{M.getContext()};
117 Builder.addRangeAttr(Range);
118 Call->removeRetAttr(Attribute::Range);
119 Call->setAttributes(
120 Call->getAttributes().addRetAttributes(M.getContext(), Builder));
121#else
122 // LLVM 18 (ROCm 6.2.x) does not expose the Range attribute; use range
123 // metadata instead.
124 LLVMContext &Ctx = M.getContext();
125 Metadata *RangeMD[] = {
126 ConstantAsMetadata::get(ConstantInt::get(RetTy, 0)),
127 ConstantAsMetadata::get(ConstantInt::get(RetTy, DimValue))};
128 MDNode *RangeNode = MDNode::get(Ctx, RangeMD);
129 Call->setMetadata(LLVMContext::MD_range, RangeNode);
130#endif
131
132 if (Config::get().traceSpecializations())
133 Logger::trace(TraceOut(IntrinsicFunction, DimValue));
134 }
135 }
136 };
137
138 AttachRange(detail::threadIdxXFnName(), BlockDim.x);
139 AttachRange(detail::threadIdxYFnName(), BlockDim.y);
140 AttachRange(detail::threadIdxZFnName(), BlockDim.z);
141
142 AttachRange(detail::blockIdxXFnName(), GridDim.x);
143 AttachRange(detail::blockIdxYFnName(), GridDim.y);
144 AttachRange(detail::blockIdxZFnName(), GridDim.z);
145}
146
147inline void replaceGlobalVariablesWithPointers(
148 Module &M,
149 const std::unordered_map<std::string, GlobalVarInfo> &VarNameToGlobalInfo) {
150 TIMESCOPE("proteus::replaceGlobalVariablesWithPointers");
151 // Re-link globals to fixed addresses provided by registered
152 // variables.
153 for (auto RegisterVar : VarNameToGlobalInfo) {
154 auto &VarName = RegisterVar.first;
155 auto *GV = M.getNamedGlobal(VarName);
156 // Skip linking if the GV does not exist in the module.
157 if (!GV)
158 continue;
159
160 // This will convert constant users of GV to instructions so that we can
161 // replace with the GV ptr.
162 convertUsersOfConstantsToInstructions({GV});
163
164 Constant *Addr =
165 ConstantInt::get(Type::getInt64Ty(M.getContext()), 0xDEADBEEFDEADBEEF);
166 auto *CE = ConstantExpr::getIntToPtr(
167 Addr, PointerType::get(M.getContext(), GV->getAddressSpace()));
168 auto *GVarPtr = new GlobalVariable(
169 M, PointerType::get(M.getContext(), GV->getAddressSpace()), false,
170 GlobalValue::ExternalLinkage, CE, GV->getName() + "$ptr", nullptr,
171 GV->getThreadLocalMode(), GV->getAddressSpace(), true);
172
173 // Find all Constant users that refer to the global variable.
174 SmallPtrSet<Value *, 16> ValuesToReplace;
175 SmallVector<Value *> Worklist;
176 // Seed with the global variable.
177 Worklist.push_back(GV);
178 ValuesToReplace.insert(GV);
179 while (!Worklist.empty()) {
180 Value *V = Worklist.pop_back_val();
181 for (auto *User : V->users()) {
182 if (auto *C = dyn_cast<Constant>(User)) {
183 if (ValuesToReplace.insert(C).second)
184 Worklist.push_back(C);
185
186 continue;
187 }
188
189 // Skip instructions to be handled when replacing.
190 if (isa<Instruction>(User))
191 continue;
192
193 reportFatalError("Expected Instruction or Constant user for Value: " +
194 toString(*V) + " , User: " + toString(*User));
195 }
196 }
197
198 for (Value *V : ValuesToReplace) {
199 SmallPtrSet<Instruction *, 16> Insts;
200 // Find instruction users to replace value.
201 for (User *U : V->users()) {
202 if (auto *I = dyn_cast<Instruction>(U)) {
203 Insts.insert(I);
204 }
205 }
206
207 // Replace value in instructions.
208 for (auto *I : Insts) {
209 IRBuilder Builder{I};
210 auto *Load = Builder.CreateLoad(GV->getType(), GVarPtr);
211 Value *Replacement = Load;
212 Type *ExpectedTy = V->getType();
213 if (Load->getType() != ExpectedTy)
214 Replacement =
215 Builder.CreatePointerBitCastOrAddrSpaceCast(Load, ExpectedTy);
216
217 I->replaceUsesOfWith(V, Replacement);
218 }
219 }
220 }
221
222 if (Config::get().ProteusDebugOutput) {
223 if (verifyModule(M, &errs()))
224 reportFatalError("Broken module found, JIT compilation aborted!");
225 }
226}
227
228inline void relinkGlobalsObject(
229 MemoryBufferRef Object,
230 const std::unordered_map<std::string, GlobalVarInfo> &VarNameToGlobalInfo) {
231 TIMESCOPE("proteus::relinkGlobalsObject");
232 auto ObjOrErr = object::ObjectFile::createObjectFile(Object);
233 if (auto E = ObjOrErr.takeError())
234 reportFatalError("Cannot create the device elf: " + toString(std::move(E)));
235 auto &DeviceObjFile = **ObjOrErr;
236 auto *DeviceElf = dyn_cast<object::ELFObjectFileBase>(&DeviceObjFile);
237 if (!DeviceElf)
238 reportFatalError("Expected ELF object file");
239
240 for (auto &[GlobalName, GVI] : VarNameToGlobalInfo) {
241 for (auto &Symbol : DeviceElf->symbols()) {
242 auto SymbolNameOrErr = Symbol.getName();
243 if (!SymbolNameOrErr)
244 continue;
245 auto SymbolName = *SymbolNameOrErr;
246
247 if (!(SymbolName == (GlobalName + "$ptr")))
248 continue;
249
250 Expected<uint64_t> ValueOrErr = Symbol.getValue();
251 if (!ValueOrErr)
252 reportFatalError("Expected symbol value");
253 uint64_t SymbolValue = *ValueOrErr;
254
255 // Get the section containing the symbol
256 auto SectionOrErr = Symbol.getSection();
257 if (!SectionOrErr)
258 reportFatalError("Cannot retrieve section");
259 const auto &Section = *SectionOrErr;
260 if (Section == DeviceElf->section_end())
261 reportFatalError("Expected sybmol in section");
262
263 // Get the section's address and data
264 Expected<StringRef> SectionDataOrErr = Section->getContents();
265 if (!SectionDataOrErr)
266 reportFatalError("Error retrieving section data");
267 StringRef SectionData = *SectionDataOrErr;
268
269 // Calculate offset within the section
270 uint64_t SectionAddr = Section->getAddress();
271 uint64_t Offset = SymbolValue - SectionAddr;
272 if (Offset >= SectionData.size())
273 reportFatalError("Expected offset within section size");
274
275 uint64_t *Data = (uint64_t *)(SectionData.data() + Offset);
276 if (!GVI.DevAddr)
277 reportFatalError("Cannot set global Var " + GlobalName +
278 " without a concrete device address");
279
280 *Data = reinterpret_cast<uint64_t>(GVI.DevAddr);
281 break;
282 }
283 }
284}
285
// Applies the enabled JIT specializations to the kernel named FnName in
// module M and renames that kernel to FnName + Suffix.
//
// Steps visible in this listing, in order:
//   - look up the kernel function (asserted to be non-null),
//   - for every (function name, lambda type) entry of LambdaCalleeInfo, fetch
//     the JIT variables from the LambdaRegistry and look up the callee; a
//     missing callee is a fatal error,
//   - if SpecializeDims: fold gridDim.*/blockDim.* calls via setKernelDims,
//   - if SpecializeDimsRange: annotate index ranges via setKernelDimsRange,
//   - rename the kernel,
//   - if SpecializeLaunchBounds: set launch bounds using
//     BlockDim.x * BlockDim.y * BlockDim.z and MinBlocksPerSM.
//
// NOTE(review): this listing has lost several statements (listing lines 299,
// 307, 312, 336 and 338 are absent), so the text below is not compilable as
// shown. Restore them from the real header before changing this function.
286inline void specializeIR(
287 Module &M, StringRef FnName, StringRef Suffix, dim3 &BlockDim,
288 dim3 &GridDim, ArrayRef<RuntimeConstant> RCArray,
289 const SmallVector<std::pair<std::string, StringRef>> LambdaCalleeInfo,
290 bool SpecializeArgs, bool SpecializeDims, bool SpecializeDimsRange,
291 bool SpecializeLaunchBounds, int MinBlocksPerSM) {
292 TIMESCOPE("proteus::specializeIR");
293 Timer T(Config::get().ProteusEnableTimers);
294 Function *F = M.getFunction(FnName);
295
// The null check is an assert only; F is dereferenced further down
// (F->setName) regardless.
296 assert(F && "Expected non-null function!");
297 // Replace argument uses with runtime constants.
298 if (SpecializeArgs)
// NOTE(review): the statement guarded by the `if` above is missing here;
// presumably the argument-specialization transform taking RCArray — confirm.
300
301 auto &LR = LambdaRegistry::instance();
// The structured binding and the local F below shadow the outer FnName and F
// for the duration of the loop body.
302 for (auto &[FnName, LambdaType] : LambdaCalleeInfo) {
303 const SmallVector<RuntimeConstant> &RCVec = LR.getJitVariables(LambdaType);
304 Function *F = M.getFunction(FnName);
305 if (!F)
306 reportFatalError("Expected non-null Function");
// NOTE(review): a statement is missing before the closing brace; presumably
// the lambda-specialization transform that consumes RCVec — confirm.
308 }
309
310 // Run the shared array transform after any value specialization (arguments,
311 // captures) to propagate any constants.
// NOTE(review): the call described by the comment above is missing here.
313
314 // Replace uses of blockDim.* and gridDim.* with constants.
315 if (SpecializeDims)
316 setKernelDims(M, GridDim, BlockDim);
317 if (SpecializeDimsRange)
318 setKernelDimsRange(M, GridDim, BlockDim);
// Rename after the transforms above, which looked functions up by FnName.
319 F->setName(FnName + Suffix);
320
321 if (SpecializeLaunchBounds) {
// Total threads per block, passed as the first argument of
// setLaunchBoundsForKernel.
322 int BlockSize = BlockDim.x * BlockDim.y * BlockDim.z;
323 auto TraceOut = [](int BlockSize, int MinBlocksPerSM) {
324 SmallString<128> S;
325 raw_svector_ostream OS(S);
326 OS << "[LaunchBoundSpec] MaxThreads " << BlockSize << " MinBlocksPerSM "
327 << MinBlocksPerSM << "\n";
328
329 return S;
330 };
331 if (Config::get().traceSpecializations())
332 Logger::trace(TraceOut(BlockSize, MinBlocksPerSM));
333 setLaunchBoundsForKernel(*F, BlockSize, MinBlocksPerSM);
334 }
335
// NOTE(review): a statement is missing here (listing line 336), and the start
// of the statement that ends on the next visible line (listing line 338) is
// missing too; presumably a timer-output macro using T — confirm.
337
339 << "specializeIR " << T.elapsed() << " ms\n");
340}
341
342} // namespace proteus
343
344#endif
345
346#endif
const void const char * VarName
Definition CompilerInterfaceDevice.cpp:25
#define PROTEUS_TIMER_OUTPUT(x)
Definition Config.h:440
#define TIMESCOPE(...)
Definition TimeTracing.h:66
static Config & get()
Definition Config.h:334
static LambdaRegistry & instance()
Definition LambdaRegistry.h:21
static llvm::raw_ostream & outs(const std::string &Name)
Definition Logger.h:25
static void trace(llvm::StringRef Msg)
Definition Logger.h:30
static void transform(Module &M, Function &F, ArrayRef< RuntimeConstant > RCArray)
Definition TransformArgumentSpecialization.h:88
static void transform(Module &M, Function &F, const SmallVector< RuntimeConstant > &RCVec)
Definition TransformLambdaSpecialization.h:126
static void transform(Module &M)
Definition TransformSharedArray.h:30
const SmallVector< StringRef > & threadIdxXFnName()
Definition CoreLLVMCUDA.h:70
const SmallVector< StringRef > & gridDimYFnName()
Definition CoreLLVMCUDA.h:30
const SmallVector< StringRef > & threadIdxZFnName()
Definition CoreLLVMCUDA.h:80
const SmallVector< StringRef > & blockIdxZFnName()
Definition CoreLLVMCUDA.h:65
const SmallVector< StringRef > & gridDimZFnName()
Definition CoreLLVMCUDA.h:35
const SmallVector< StringRef > & gridDimXFnName()
Definition CoreLLVMCUDA.h:25
const SmallVector< StringRef > & blockIdxXFnName()
Definition CoreLLVMCUDA.h:55
const SmallVector< StringRef > & threadIdxYFnName()
Definition CoreLLVMCUDA.h:75
const SmallVector< StringRef > & blockIdxYFnName()
Definition CoreLLVMCUDA.h:60
const SmallVector< StringRef > & blockDimYFnName()
Definition CoreLLVMCUDA.h:45
const SmallVector< StringRef > & blockDimZFnName()
Definition CoreLLVMCUDA.h:50
const SmallVector< StringRef > & blockDimXFnName()
Definition CoreLLVMCUDA.h:40
Definition MemoryCache.h:27
void setLaunchBoundsForKernel(Function &F, int MaxThreadsPerSM, int MinBlocksPerSM=0)
Definition CoreLLVMCUDA.h:87
void reportFatalError(const llvm::Twine &Reason, const char *FILE, unsigned Line)
Definition Error.cpp:14
static int int Offset
Definition JitInterface.h:102
std::string toString(CodegenOption Option)
Definition Config.h:28
void runCleanupPassPipeline(Module &M)
Definition CoreLLVM.h:256