Proteus
Programmable JIT compilation and optimization for C/C++ using LLVM
Loading...
Searching...
No Matches
CoreLLVMDevice.hpp
Go to the documentation of this file.
1#ifndef PROTEUS_CORE_LLVM_DEVICE_HPP
2#define PROTEUS_CORE_LLVM_DEVICE_HPP
3
4#if PROTEUS_ENABLE_HIP
6#endif
7
8#if PROTEUS_ENABLE_CUDA
10#endif
11
12#if defined(PROTEUS_ENABLE_HIP) || defined(PROTEUS_ENABLE_CUDA)
13
14#include <llvm/Analysis/CallGraph.h>
15#include <llvm/IR/ReplaceConstant.h>
16#include <llvm/IR/Verifier.h>
17#include <llvm/Object/ELFObjectFile.h>
18#include <llvm/Transforms/Utils/Cloning.h>
19
25
26namespace proteus {
27
28inline void setKernelDims(Module &M, dim3 &GridDim, dim3 &BlockDim) {
29 auto ReplaceIntrinsicDim = [&](ArrayRef<StringRef> IntrinsicNames,
30 uint32_t DimValue) {
31 auto CollectCallUsers = [](Function &F) {
32 SmallVector<CallInst *> CallUsers;
33 for (auto *User : F.users()) {
34 auto *Call = dyn_cast<CallInst>(User);
35 if (!Call)
36 continue;
37 CallUsers.push_back(Call);
38 }
39
40 return CallUsers;
41 };
42
43 for (auto IntrinsicName : IntrinsicNames) {
44
45 Function *IntrinsicFunction = M.getFunction(IntrinsicName);
46 if (!IntrinsicFunction)
47 continue;
48
49 for (auto *Call : CollectCallUsers(*IntrinsicFunction)) {
50 Value *ConstantValue =
51 ConstantInt::get(Type::getInt32Ty(M.getContext()), DimValue);
52 Call->replaceAllUsesWith(ConstantValue);
53 Call->eraseFromParent();
54 }
55 }
56 };
57
58 ReplaceIntrinsicDim(detail::gridDimXFnName(), GridDim.x);
59 ReplaceIntrinsicDim(detail::gridDimYFnName(), GridDim.y);
60 ReplaceIntrinsicDim(detail::gridDimZFnName(), GridDim.z);
61
62 ReplaceIntrinsicDim(detail::blockDimXFnName(), BlockDim.x);
63 ReplaceIntrinsicDim(detail::blockDimYFnName(), BlockDim.y);
64 ReplaceIntrinsicDim(detail::blockDimZFnName(), BlockDim.z);
65
66 auto InsertAssume = [&](ArrayRef<StringRef> IntrinsicNames, int DimValue) {
67 for (auto IntrinsicName : IntrinsicNames) {
68 Function *IntrinsicFunction = M.getFunction(IntrinsicName);
69 if (!IntrinsicFunction || IntrinsicFunction->use_empty())
70 continue;
71
72 // Iterate over all uses of the intrinsic.
73 for (auto *U : IntrinsicFunction->users()) {
74 auto *Call = dyn_cast<CallInst>(U);
75 if (!Call)
76 continue;
77
78 // Insert the llvm.assume intrinsic.
79 IRBuilder<> Builder(Call->getNextNode());
80 Value *Bound = ConstantInt::get(Call->getType(), DimValue);
81 Value *Cmp = Builder.CreateICmpULT(Call, Bound);
82
83 Function *AssumeIntrinsic =
84 Intrinsic::getDeclaration(&M, Intrinsic::assume);
85 Builder.CreateCall(AssumeIntrinsic, Cmp);
86 }
87 }
88 };
89
90 // Inform LLVM about the range of possible values of threadIdx.*.
91 InsertAssume(detail::threadIdxXFnName(), BlockDim.x);
92 InsertAssume(detail::threadIdxYFnName(), BlockDim.y);
93 InsertAssume(detail::threadIdxZFnName(), BlockDim.z);
94
95 // Inform LLVdetailut the range of possible values of blockIdx.*.
96 InsertAssume(detail::blockIdxXFnName(), GridDim.x);
97 InsertAssume(detail::blockIdxYFnName(), GridDim.y);
98 InsertAssume(detail::blockIdxZFnName(), GridDim.z);
99}
100
101inline void replaceGlobalVariablesWithPointers(
102 Module &M,
103 const std::unordered_map<std::string, const void *> &VarNameToDevPtr) {
104 // Re-link globals to fixed addresses provided by registered
105 // variables.
106 for (auto RegisterVar : VarNameToDevPtr) {
107 auto &VarName = RegisterVar.first;
108 auto *GV = M.getNamedGlobal(VarName);
109 // Skip linking if the GV does not exist in the module.
110 if (!GV)
111 continue;
112
113 // This will convert constant users of GV to instructions so that we can
114 // replace with the GV ptr.
115 convertUsersOfConstantsToInstructions({GV});
116
117 Constant *Addr =
118 ConstantInt::get(Type::getInt64Ty(M.getContext()), 0xDEADBEEFDEADBEEF);
119 auto *CE = ConstantExpr::getIntToPtr(Addr, GV->getType()->getPointerTo());
120 auto *GVarPtr = new GlobalVariable(
121 M, GV->getType()->getPointerTo(), false, GlobalValue::ExternalLinkage,
122 CE, GV->getName() + "$ptr", nullptr, GV->getThreadLocalMode(),
123 GV->getAddressSpace(), true);
124
125 SmallVector<Instruction *> ToReplace;
126 for (auto *User : GV->users()) {
127 auto *Inst = dyn_cast<Instruction>(User);
128 if (!Inst)
129 PROTEUS_FATAL_ERROR("Expected Instruction User for GV");
130
131 ToReplace.push_back(Inst);
132 }
133
134 for (auto *Inst : ToReplace) {
135 IRBuilder Builder{Inst};
136 auto *Load = Builder.CreateLoad(GV->getType(), GVarPtr);
137 Inst->replaceUsesOfWith(GV, Load);
138 }
139 }
140}
141
142inline void relinkGlobalsObject(
143 MemoryBufferRef Object,
144 const std::unordered_map<std::string, const void *> &VarNameToDevPtr) {
145 Expected<object::ELF64LEObjectFile> DeviceElfOrErr =
146 object::ELF64LEObjectFile::create(Object);
147 if (DeviceElfOrErr.takeError())
148 PROTEUS_FATAL_ERROR("Cannot create the device elf");
149 auto &DeviceElf = *DeviceElfOrErr;
150
151 for (auto &[GlobalName, DevPtr] : VarNameToDevPtr) {
152 for (auto &Symbol : DeviceElf.symbols()) {
153 auto SymbolNameOrErr = Symbol.getName();
154 if (!SymbolNameOrErr)
155 continue;
156 auto SymbolName = *SymbolNameOrErr;
157
158 if (!SymbolName.equals(GlobalName + "$ptr"))
159 continue;
160
161 Expected<uint64_t> ValueOrErr = Symbol.getValue();
162 if (!ValueOrErr)
163 PROTEUS_FATAL_ERROR("Expected symbol value");
164 uint64_t SymbolValue = *ValueOrErr;
165
166 // Get the section containing the symbol
167 auto SectionOrErr = Symbol.getSection();
168 if (!SectionOrErr)
169 PROTEUS_FATAL_ERROR("Cannot retrieve section");
170 const auto &Section = *SectionOrErr;
171 if (Section == DeviceElf.section_end())
172 PROTEUS_FATAL_ERROR("Expected sybmol in section");
173
174 // Get the section's address and data
175 Expected<StringRef> SectionDataOrErr = Section->getContents();
176 if (!SectionDataOrErr)
177 PROTEUS_FATAL_ERROR("Error retrieving section data");
178 StringRef SectionData = *SectionDataOrErr;
179
180 // Calculate offset within the section
181 uint64_t SectionAddr = Section->getAddress();
182 uint64_t Offset = SymbolValue - SectionAddr;
183 if (Offset >= SectionData.size())
184 PROTEUS_FATAL_ERROR("Expected offset within section size");
185
186 uint64_t *Data = (uint64_t *)(SectionData.data() + Offset);
187 *Data = reinterpret_cast<uint64_t>(DevPtr);
188 break;
189 }
190 }
191}
192
193inline void specializeIR(
194 Module &M, StringRef FnName, StringRef Suffix, dim3 &BlockDim,
195 dim3 &GridDim, const SmallVector<int32_t> &RCIndices,
196 const SmallVector<RuntimeConstant> &RCVec,
197 const SmallVector<std::pair<std::string, StringRef>> LambdaCalleeInfo,
198 bool SpecializeArgs, bool SpecializeDims, bool SpecializeLaunchBounds) {
199 Function *F = M.getFunction(FnName);
200
201 assert(F && "Expected non-null function!");
202 // Replace argument uses with runtime constants.
203 if (SpecializeArgs)
205
206 auto &LR = LambdaRegistry::instance();
207 for (auto &[FnName, LambdaType] : LambdaCalleeInfo) {
208 const SmallVector<RuntimeConstant> &RCVec = LR.getJitVariables(LambdaType);
209 Function *F = M.getFunction(FnName);
210 if (!F)
211 PROTEUS_FATAL_ERROR("Expected non-null Function");
213 }
214
215 // Run the shared array transform after any value specialization (arguments,
216 // captures) to propagate any constants.
218
219 // Replace uses of blockDim.* and gridDim.* with constants.
220 if (SpecializeDims)
221 setKernelDims(M, GridDim, BlockDim);
222
223 F->setName(FnName + Suffix);
224
225 if (SpecializeLaunchBounds)
226 setLaunchBoundsForKernel(M, *F, GridDim.x * GridDim.y * GridDim.z,
227 BlockDim.x * BlockDim.y * BlockDim.z);
228
230
231 PROTEUS_DBG(Logger::logfile(FnName.str() + ".specialized.ll", M));
232}
233
234inline std::unique_ptr<Module> cloneKernelFromModule(Module &M, LLVMContext &C,
235 const std::string &Name,
236 CallGraph &CG) {
237 auto KernelModule = std::make_unique<Module>("JitModule", C);
238 KernelModule->setSourceFileName(M.getSourceFileName());
239 KernelModule->setDataLayout(M.getDataLayout());
240 KernelModule->setTargetTriple(M.getTargetTriple());
241 KernelModule->setModuleInlineAsm(M.getModuleInlineAsm());
242#if LLVM_VERSION_MAJOR >= 18
243 KernelModule->IsNewDbgInfoFormat = M.IsNewDbgInfoFormat;
244#endif
245
246 auto *KernelFunction = M.getFunction(Name);
247 if (!KernelFunction)
248 PROTEUS_FATAL_ERROR("Expected function " + Name);
249
250 SmallPtrSet<Function *, 8> ReachableFunctions;
251 SmallPtrSet<GlobalVariable *, 16> ReachableGlobals;
252 SmallPtrSet<Function *, 8> ReachableDeclarations;
253 SmallVector<Function *, 8> ToVisit;
254 ReachableFunctions.insert(KernelFunction);
255 ToVisit.push_back(KernelFunction);
256 while (!ToVisit.empty()) {
257 Function *VisitF = ToVisit.pop_back_val();
258 CallGraphNode *CGNode = CG[VisitF];
259
260 for (const auto &Callee : *CGNode) {
261 Function *CalleeF = Callee.second->getFunction();
262 if (!CalleeF)
263 continue;
264 if (CalleeF->isDeclaration()) {
265 ReachableDeclarations.insert(CalleeF);
266 continue;
267 }
268 if (ReachableFunctions.contains(CalleeF))
269 continue;
270 ReachableFunctions.insert(CalleeF);
271 ToVisit.push_back(CalleeF);
272 }
273 }
274
275 auto ProcessInstruction = [&](GlobalVariable &GV, const Instruction *I) {
276 const Function *ParentF = I->getFunction();
277 if (ReachableFunctions.contains(ParentF))
278 ReachableGlobals.insert(&GV);
279 };
280
281 for (auto &GV : M.globals()) {
282 for (const User *Usr : GV.users()) {
283 const Instruction *I = dyn_cast<Instruction>(Usr);
284
285 if (I) {
286 ProcessInstruction(GV, I);
287 continue;
288 }
289
290 // We follow non-instructions users to process them if those are
291 // instructions.
292 // TODO: We may need to follow deeper than just users of user and also
293 // expand to non-instruction users.
294 for (const User *NextUser : Usr->users()) {
295 I = dyn_cast<Instruction>(NextUser);
296 if (!I)
297 continue;
298
299 ProcessInstruction(GV, I);
300 }
301 }
302 }
303
304 ValueToValueMapTy VMap;
305
306 for (auto *GV : ReachableGlobals) {
307 // We will set the initializer later, after VMap has been populated.
308 GlobalVariable *NewGV =
309 new GlobalVariable(*KernelModule, GV->getValueType(), GV->isConstant(),
310 GV->getLinkage(), nullptr, GV->getName(), nullptr,
311 GV->getThreadLocalMode(), GV->getAddressSpace());
312 NewGV->copyAttributesFrom(GV);
313 VMap[GV] = NewGV;
314 }
315
316 for (auto *F : ReachableFunctions) {
317 auto *NewFunction = Function::Create(F->getFunctionType(), F->getLinkage(),
318 F->getAddressSpace(), F->getName(),
319 KernelModule.get());
320 NewFunction->copyAttributesFrom(F);
321 VMap[F] = NewFunction;
322 }
323
324 for (auto *F : ReachableDeclarations) {
325 auto *NewFunction = Function::Create(F->getFunctionType(), F->getLinkage(),
326 F->getAddressSpace(), F->getName(),
327 KernelModule.get());
328 NewFunction->copyAttributesFrom(F);
329 NewFunction->setLinkage(GlobalValue::ExternalLinkage);
330 VMap[F] = NewFunction;
331 }
332
333 for (GlobalVariable *GV : ReachableGlobals) {
334 if (GV->hasInitializer()) {
335 GlobalVariable *NewGV = cast<GlobalVariable>(VMap[GV]);
336 NewGV->setInitializer(MapValue(GV->getInitializer(), VMap));
337 }
338 }
339
340 for (auto *F : ReachableFunctions) {
341 SmallVector<ReturnInst *, 8> Returns;
342 auto *NewFunction = dyn_cast<Function>(VMap[F]);
343 Function::arg_iterator DestI = NewFunction->arg_begin();
344 for (const Argument &I : F->args())
345 if (VMap.count(&I) == 0) {
346 DestI->setName(I.getName());
347 VMap[&I] = &*DestI++;
348 }
349 llvm::CloneFunctionInto(NewFunction, F, VMap,
350 CloneFunctionChangeType::DifferentModule, Returns);
351 }
352
353 // Copy annotations from M into KernelModule now that VMap has been populated.
354 const std::string MetadataToCopy[] = {"llvm.annotations", "nvvm.annotations",
355 "nvvmir.version", "llvm.module.flags"};
356 for (auto &MetadataName : MetadataToCopy) {
357 NamedMDNode *NamedMD = M.getNamedMetadata(MetadataName);
358 if (!NamedMD)
359 continue;
360
361 auto *NewNamedMD = KernelModule->getOrInsertNamedMetadata(MetadataName);
362 for (unsigned I = 0, E = NamedMD->getNumOperands(); I < E; ++I) {
363 MDNode *MDEntry = NamedMD->getOperand(I);
364 bool ShouldClone = true;
365 // Skip if the operands of an MDNode refer to non-existing,
366 // unreachable global values.
367 for (auto &Operand : MDEntry->operands()) {
368 Metadata *MD = Operand.get();
369 auto *CMD = dyn_cast<ConstantAsMetadata>(MD);
370 if (!CMD)
371 continue;
372
373 auto *GV = dyn_cast<GlobalValue>(CMD->getValue());
374 if (!GV)
375 continue;
376
377 if (!VMap.count(GV)) {
378 ShouldClone = false;
379 break;
380 }
381 }
382
383 if (!ShouldClone)
384 continue;
385
386 NewNamedMD->addOperand(MapMetadata(MDEntry, VMap));
387 }
388 }
389
390#if PROTEUS_ENABLE_DEBUG
391 Logger::logfile(Name + ".mini.ll", *KernelModule);
392 if (verifyModule(*KernelModule, &errs()))
393 PROTEUS_FATAL_ERROR("Broken mini-module found, JIT compilation aborted!");
394#endif
395
396 return KernelModule;
397}
398
399} // namespace proteus
400
401#endif
402
403#endif
const char * VarName
Definition CompilerInterfaceDevice.cpp:20
void char int32_t * RCIndices
Definition CompilerInterfaceDevice.cpp:51
#define PROTEUS_DBG(x)
Definition Debug.h:7
#define PROTEUS_FATAL_ERROR(x)
Definition Error.h:4
static LambdaRegistry & instance()
Definition LambdaRegistry.hpp:21
static void logfile(const std::string &Filename, T &&Data)
Definition Logger.hpp:24
static void transform(Module &M, Function &F, const SmallVectorImpl< int32_t > &ArgPos, ArrayRef< RuntimeConstant > RC)
Definition TransformArgumentSpecialization.hpp:26
static void transform(Module &M, Function &F, const SmallVector< RuntimeConstant > &RCVec)
Definition TransformLambdaSpecialization.hpp:59
static void transform(Module &M)
Definition TransformSharedArray.hpp:30
const SmallVector< StringRef > & threadIdxXFnName()
Definition CoreLLVMCUDA.hpp:68
const SmallVector< StringRef > & gridDimYFnName()
Definition CoreLLVMCUDA.hpp:28
const SmallVector< StringRef > & threadIdxZFnName()
Definition CoreLLVMCUDA.hpp:78
const SmallVector< StringRef > & blockIdxZFnName()
Definition CoreLLVMCUDA.hpp:63
const SmallVector< StringRef > & gridDimZFnName()
Definition CoreLLVMCUDA.hpp:33
const SmallVector< StringRef > & gridDimXFnName()
Definition CoreLLVMCUDA.hpp:23
const SmallVector< StringRef > & blockIdxXFnName()
Definition CoreLLVMCUDA.hpp:53
const SmallVector< StringRef > & threadIdxYFnName()
Definition CoreLLVMCUDA.hpp:73
const SmallVector< StringRef > & blockIdxYFnName()
Definition CoreLLVMCUDA.hpp:58
const SmallVector< StringRef > & blockDimYFnName()
Definition CoreLLVMCUDA.hpp:43
const SmallVector< StringRef > & blockDimZFnName()
Definition CoreLLVMCUDA.hpp:48
const SmallVector< StringRef > & blockDimXFnName()
Definition CoreLLVMCUDA.hpp:38
Definition JitEngine.cpp:20
void setLaunchBoundsForKernel(Module &M, Function &F, size_t GridSize, int BlockSize)
Definition CoreLLVMCUDA.hpp:85
void runCleanupPassPipeline(Module &M)
Definition CoreLLVM.hpp:170