1#ifndef PROTEUS_CORE_LLVM_DEVICE_HPP
2#define PROTEUS_CORE_LLVM_DEVICE_HPP
12#if defined(PROTEUS_ENABLE_HIP) || defined(PROTEUS_ENABLE_CUDA)
14#include <llvm/Analysis/CallGraph.h>
15#include <llvm/IR/ReplaceConstant.h>
16#include <llvm/IR/Verifier.h>
17#include <llvm/Object/ELFObjectFile.h>
18#include <llvm/Transforms/Utils/Cloning.h>
/// Specializes the launch-dimension queries in \p M for the concrete launch
/// configuration given by \p GridDim and \p BlockDim.
///
/// Two local helpers are defined here:
///  - ReplaceIntrinsicDim folds every direct call to each named query
///    function into an i32 constant and erases the call.
///  - InsertAssume emits `llvm.assume(call <u DimValue)` right after every
///    call to each named function, giving the optimizer an upper bound.
/// NOTE(review): which name lists and dimension values each helper is invoked
/// with is on lines not shown in this excerpt.
28inline void setKernelDims(Module &M, dim3 &GridDim, dim3 &BlockDim) {
// NOTE(review): DimValue (used below) is declared on a parameter line not
// shown here. The replacement constant is always built as i32, whereas
// InsertAssume builds its bound with Call->getType(); replaceAllUsesWith
// requires matching types, so this assumes every listed function returns
// i32 — confirm for all targets.
29 auto ReplaceIntrinsicDim = [&](ArrayRef<StringRef> IntrinsicNames,
// Snapshot the CallInst users up front: the loop below erases each call,
// which would invalidate an in-place walk over F.users().
31 auto CollectCallUsers = [](Function &F) {
32 SmallVector<CallInst *> CallUsers;
33 for (
auto *User : F.users()) {
34 auto *Call = dyn_cast<CallInst>(User);
37 CallUsers.push_back(Call);
43 for (
auto IntrinsicName : IntrinsicNames) {
45 Function *IntrinsicFunction = M.getFunction(IntrinsicName);
46 if (!IntrinsicFunction)
49 for (
auto *Call : CollectCallUsers(*IntrinsicFunction)) {
// Replace the call's result with the known dimension, then drop the call.
50 Value *ConstantValue =
51 ConstantInt::get(Type::getInt32Ty(M.getContext()), DimValue);
52 Call->replaceAllUsesWith(ConstantValue);
53 Call->eraseFromParent();
// Bounds each call's result by inserting `llvm.assume(Call <u DimValue)`
// immediately after the call. Names missing from M, or functions without
// uses, hit the early-out check below (its body is not shown).
66 auto InsertAssume = [&](ArrayRef<StringRef> IntrinsicNames,
int DimValue) {
67 for (
auto IntrinsicName : IntrinsicNames) {
68 Function *IntrinsicFunction = M.getFunction(IntrinsicName);
69 if (!IntrinsicFunction || IntrinsicFunction->use_empty())
73 for (
auto *U : IntrinsicFunction->users()) {
74 auto *Call = dyn_cast<CallInst>(U);
// Inserting while iterating users() is fine here: the new icmp/assume use
// Call and the assume declaration, not IntrinsicFunction itself.
79 IRBuilder<> Builder(Call->getNextNode());
80 Value *Bound = ConstantInt::get(Call->getType(), DimValue);
81 Value *Cmp = Builder.CreateICmpULT(Call, Bound);
// NOTE(review): newer LLVM releases rename this API to
// Intrinsic::getOrInsertDeclaration and deprecate getDeclaration.
83 Function *AssumeIntrinsic =
84 Intrinsic::getDeclaration(&M, Intrinsic::assume);
85 Builder.CreateCall(AssumeIntrinsic, Cmp);
/// Redirects instruction uses of the module globals named in
/// \p VarNameToDevPtr through a companion pointer global `<name>$ptr`.
///
/// For each key, the companion global is created with the placeholder address
/// 0xDEADBEEFDEADBEEF as its initializer and is marked externally initialized.
/// Each collected instruction user of the original global then loads the
/// address from `<name>$ptr` and uses that instead. relinkGlobalsObject later
/// overwrites the `<name>$ptr` storage in the compiled object with the mapped
/// pointer. Only the map keys are read in the lines visible here.
101inline void replaceGlobalVariablesWithPointers(
103 const std::unordered_map<std::string, const void *> &VarNameToDevPtr) {
// NOTE(review): `auto RegisterVar` copies each map entry (including the
// std::string key) per iteration; `const auto &` would avoid the copy.
106 for (
auto RegisterVar : VarNameToDevPtr) {
107 auto &
VarName = RegisterVar.first;
108 auto *GV = M.getNamedGlobal(
VarName);
// Turn constant-expression users of GV into instructions so that the
// instruction-only rewrite below reaches those uses too.
115 convertUsersOfConstantsToInstructions({GV});
118 ConstantInt::get(Type::getInt64Ty(M.getContext()), 0xDEADBEEFDEADBEEF);
// NOTE(review): Type::getPointerTo() is being phased out in newer LLVM
// releases (opaque pointers); PointerType::get(Context, AddrSpace) replaces it.
119 auto *CE = ConstantExpr::getIntToPtr(Addr, GV->getType()->getPointerTo());
// The trailing `true` is isExternallyInitialized: it stops the optimizer from
// treating the placeholder initializer as the value every load will see.
120 auto *GVarPtr =
new GlobalVariable(
121 M, GV->getType()->getPointerTo(),
false, GlobalValue::ExternalLinkage,
122 CE, GV->getName() +
"$ptr",
nullptr, GV->getThreadLocalMode(),
123 GV->getAddressSpace(),
true);
// Collect instruction users first: replaceUsesOfWith below edits GV's use
// list, which would invalidate a live walk over GV->users().
125 SmallVector<Instruction *> ToReplace;
126 for (
auto *User : GV->users()) {
127 auto *Inst = dyn_cast<Instruction>(User);
131 ToReplace.push_back(Inst);
// Load the real address right before each user and substitute it for GV.
// NOTE(review): if a user can be a PHI node, inserting the load directly
// before it yields invalid IR (non-PHI ahead of a PHI) — confirm the filter
// on the lines not shown excludes PHIs.
134 for (
auto *Inst : ToReplace) {
135 IRBuilder Builder{Inst};
136 auto *Load = Builder.CreateLoad(GV->getType(), GVarPtr);
137 Inst->replaceUsesOfWith(GV, Load);
/// Patches a device ELF object so that each `<name>$ptr` symbol holds the
/// pointer mapped to `<name>` in \p VarNameToDevPtr.
///
/// For every (GlobalName, DevPtr) entry, this function:
///  - scans the symbol table of \p Object for `GlobalName + "$ptr"`;
///  - converts the symbol value to an offset inside its section's contents;
///  - stores DevPtr there as a 64-bit value.
/// The store goes directly into the bytes backing \p Object (const is cast
/// away), so the caller's buffer must be writable.
142inline void relinkGlobalsObject(
143 MemoryBufferRef Object,
144 const std::unordered_map<std::string, const void *> &VarNameToDevPtr) {
145 Expected<object::ELF64LEObjectFile> DeviceElfOrErr =
146 object::ELF64LEObjectFile::create(Object);
// NOTE(review): the Error returned by takeError() is only tested as a bool
// and is destroyed at the end of the condition, before the branch body runs.
// On failure it is left unchecked: builds with LLVM_ENABLE_ABI_BREAKING_CHECKS
// abort with an "unhandled Error" report, and other builds discard the
// diagnostic. Prefer `if (auto Err = DeviceElfOrErr.takeError())` and report
// toString(std::move(Err)).
147 if (DeviceElfOrErr.takeError())
149 auto &DeviceElf = *DeviceElfOrErr;
151 for (
auto &[GlobalName, DevPtr] : VarNameToDevPtr) {
152 for (
auto &Symbol : DeviceElf.symbols()) {
153 auto SymbolNameOrErr = Symbol.getName();
154 if (!SymbolNameOrErr)
156 auto SymbolName = *SymbolNameOrErr;
// Only the companion pointer symbol for GlobalName is patched.
// NOTE(review): StringRef::equals is deprecated in newer LLVM releases;
// `SymbolName == GlobalName + "$ptr"` is equivalent.
158 if (!SymbolName.equals(GlobalName +
"$ptr"))
161 Expected<uint64_t> ValueOrErr = Symbol.getValue();
164 uint64_t SymbolValue = *ValueOrErr;
167 auto SectionOrErr = Symbol.getSection();
170 const auto &Section = *SectionOrErr;
171 if (Section == DeviceElf.section_end())
175 Expected<StringRef> SectionDataOrErr = Section->getContents();
176 if (!SectionDataOrErr)
178 StringRef SectionData = *SectionDataOrErr;
// The symbol value minus the section address gives the symbol's byte offset
// within the section contents.
181 uint64_t SectionAddr = Section->getAddress();
182 uint64_t Offset = SymbolValue - SectionAddr;
// NOTE(review): this only requires Offset < size, but 8 bytes are written
// below, so an Offset in (size - 8, size) writes past the section. A
// bounds-safe test is
// `SectionData.size() < sizeof(uint64_t) || Offset > SectionData.size() - sizeof(uint64_t)`.
183 if (Offset >= SectionData.size())
// NOTE(review): the C-style cast drops const, and the store through
// uint64_t* may be misaligned (undefined behavior); std::memcpy avoids that.
// The value is written in host byte order, which matches this little-endian
// ELF only on little-endian hosts.
186 uint64_t *Data = (uint64_t *)(SectionData.data() + Offset);
187 *Data =
reinterpret_cast<uint64_t
>(DevPtr);
/// Specializes kernel \p FnName in \p M for a particular launch and renames it
/// to `FnName + Suffix`.
///
/// Visible steps:
///  - look up the kernel (asserted non-null);
///  - for each lambda callee in \p LambdaCalleeInfo (pairs of callee function
///    name and lambda type), fetch its JIT runtime constants from `LR`;
///  - call setKernelDims with \p GridDim / \p BlockDim;
///  - rename the kernel;
///  - when \p SpecializeLaunchBounds is set, issue a launch-bounds call that
///    receives `BlockDim.x * BlockDim.y * BlockDim.z`.
/// NOTE(review): `LR` is declared on a line not shown (presumably the
/// LambdaRegistry instance). The uses of \p RCIndices, \p RCVec,
/// \p SpecializeArgs and \p SpecializeDims are also not shown — confirm
/// whether setKernelDims is guarded by SpecializeDims.
/// NOTE(review): \p LambdaCalleeInfo is taken by const value, copying the
/// vector and every std::string in it on each call; a const reference would
/// be source-compatible for all callers.
193inline void specializeIR(
194 Module &M, StringRef FnName, StringRef Suffix, dim3 &BlockDim,
195 dim3 &GridDim,
const SmallVector<int32_t> &
RCIndices,
196 const SmallVector<RuntimeConstant> &RCVec,
197 const SmallVector<std::pair<std::string, StringRef>> LambdaCalleeInfo,
198 bool SpecializeArgs,
bool SpecializeDims,
bool SpecializeLaunchBounds) {
199 Function *F = M.getFunction(FnName);
201 assert(F &&
"Expected non-null function!");
// NOTE(review): inside this loop the binding FnName, the reference RCVec and
// the local F shadow the parameters / outer local of the same names.
// Presumably the rename further down (after the loop closes on a line not
// shown) uses the parameter FnName and the kernel F. Distinct names would
// make that easier to follow.
207 for (
auto &[FnName, LambdaType] : LambdaCalleeInfo) {
208 const SmallVector<RuntimeConstant> &RCVec = LR.getJitVariables(LambdaType);
209 Function *F = M.getFunction(FnName);
221 setKernelDims(M, GridDim, BlockDim);
223 F->setName(FnName + Suffix);
225 if (SpecializeLaunchBounds)
227 BlockDim.x * BlockDim.y * BlockDim.z);
/// Extracts kernel \p Name from \p M into a fresh module named "JitModule" in
/// context \p C.
///
/// The new module receives:
///  - the kernel plus every defined function reachable from it through the
///    call graph, with their bodies cloned;
///  - external declarations for reachable callees that are only declarations;
///  - the global variables used by instructions inside those functions;
///  - the entries of the named metadata llvm.annotations, nvvm.annotations,
///    nvvmir.version and llvm.module.flags that can be mapped into it.
/// Source file name, data layout, target triple, module inline asm and
/// (LLVM >= 18) the debug-info format flag are copied from \p M.
/// NOTE(review): the call graph `CG` used below is declared on a line not
/// shown here (presumably the remaining parameter), as is the function tail.
234inline std::unique_ptr<Module> cloneKernelFromModule(Module &M, LLVMContext &C,
235 const std::string &Name,
237 auto KernelModule = std::make_unique<Module>(
"JitModule", C);
238 KernelModule->setSourceFileName(M.getSourceFileName());
239 KernelModule->setDataLayout(M.getDataLayout());
240 KernelModule->setTargetTriple(M.getTargetTriple());
241 KernelModule->setModuleInlineAsm(M.getModuleInlineAsm());
242#if LLVM_VERSION_MAJOR >= 18
243 KernelModule->IsNewDbgInfoFormat = M.IsNewDbgInfoFormat;
246 auto *KernelFunction = M.getFunction(Name);
// Worklist walk over the call graph, starting from the kernel. Callees that
// are only declarations are collected separately (they are re-declared, not
// cloned); defined callees are added to ReachableFunctions and queued.
250 SmallPtrSet<Function *, 8> ReachableFunctions;
251 SmallPtrSet<GlobalVariable *, 16> ReachableGlobals;
252 SmallPtrSet<Function *, 8> ReachableDeclarations;
253 SmallVector<Function *, 8> ToVisit;
254 ReachableFunctions.insert(KernelFunction);
255 ToVisit.push_back(KernelFunction);
256 while (!ToVisit.empty()) {
257 Function *VisitF = ToVisit.pop_back_val();
258 CallGraphNode *CGNode = CG[VisitF];
260 for (
const auto &Callee : *CGNode) {
// getFunction() is null for call-graph edges without a known target (e.g.
// indirect calls); how that case is handled is on lines not shown.
261 Function *CalleeF = Callee.second->getFunction();
264 if (CalleeF->isDeclaration()) {
265 ReachableDeclarations.insert(CalleeF);
// Queue each defined callee at most once.
268 if (ReachableFunctions.contains(CalleeF))
270 ReachableFunctions.insert(CalleeF);
271 ToVisit.push_back(CalleeF);
// A global is reachable when an instruction inside a reachable function uses
// it.
275 auto ProcessInstruction = [&](GlobalVariable &GV,
const Instruction *I) {
276 const Function *ParentF = I->getFunction();
277 if (ReachableFunctions.contains(ParentF))
278 ReachableGlobals.insert(&GV);
281 for (
auto &GV : M.globals()) {
282 for (
const User *Usr : GV.users()) {
283 const Instruction *I = dyn_cast<Instruction>(Usr);
286 ProcessInstruction(GV, I);
// Non-instruction users (presumably constant expressions wrapping GV) are
// looked through one level: their own instruction users decide reachability.
// NOTE(review): deeper constant-expression nesting and references from other
// globals' initializers are not followed. A global reached only that way is
// not cloned, and MapValue/CloneFunctionInto then keep pointing at the
// original module's global (ValueMapper maps unmapped GlobalValues to
// themselves), which verifyModule rejects as a cross-module reference. The
// same likely applies to functions whose address is taken but never called.
294 for (
const User *NextUser : Usr->users()) {
295 I = dyn_cast<Instruction>(NextUser);
299 ProcessInstruction(GV, I);
// Old -> new value map shared by all of the cloning below.
304 ValueToValueMapTy VMap;
// Create the cloned globals first, without initializers (nullptr), so that
// initializers mapped later can refer to any cloned global or function. The
// VMap entry for each NewGV is set on a line not shown; the initializer loop
// below relies on it via cast<GlobalVariable>(VMap[GV]).
306 for (
auto *GV : ReachableGlobals) {
308 GlobalVariable *NewGV =
309 new GlobalVariable(*KernelModule, GV->getValueType(), GV->isConstant(),
310 GV->getLinkage(),
nullptr, GV->getName(),
nullptr,
311 GV->getThreadLocalMode(), GV->getAddressSpace());
312 NewGV->copyAttributesFrom(GV);
// Body-less shells for every reachable defined function; bodies are cloned
// only after all mappings exist.
316 for (
auto *F : ReachableFunctions) {
317 auto *NewFunction = Function::Create(F->getFunctionType(), F->getLinkage(),
318 F->getAddressSpace(), F->getName(),
320 NewFunction->copyAttributesFrom(F);
321 VMap[F] = NewFunction;
// Reachable external callees become declarations with external linkage in
// the new module, whatever their original linkage.
324 for (
auto *F : ReachableDeclarations) {
325 auto *NewFunction = Function::Create(F->getFunctionType(), F->getLinkage(),
326 F->getAddressSpace(), F->getName(),
328 NewFunction->copyAttributesFrom(F);
329 NewFunction->setLinkage(GlobalValue::ExternalLinkage);
330 VMap[F] = NewFunction;
// Initializers are mapped only now that every cloned global and function has
// a VMap entry.
333 for (GlobalVariable *GV : ReachableGlobals) {
334 if (GV->hasInitializer()) {
335 GlobalVariable *NewGV = cast<GlobalVariable>(VMap[GV]);
336 NewGV->setInitializer(MapValue(GV->getInitializer(), VMap));
// Clone each reachable body. Source arguments not yet in VMap are mapped
// positionally onto the new function's arguments, keeping their names.
// NOTE(review): the dyn_cast result is dereferenced unconditionally;
// cast<Function> would state the invariant and assert in debug builds.
340 for (
auto *F : ReachableFunctions) {
341 SmallVector<ReturnInst *, 8> Returns;
342 auto *NewFunction = dyn_cast<Function>(VMap[F]);
343 Function::arg_iterator DestI = NewFunction->arg_begin();
344 for (
const Argument &I : F->args())
345 if (VMap.count(&I) == 0) {
346 DestI->setName(I.getName());
347 VMap[&I] = &*DestI++;
349 llvm::CloneFunctionInto(NewFunction, F, VMap,
350 CloneFunctionChangeType::DifferentModule, Returns);
// Copy selected named metadata. Entries with a global-value operand that has
// no VMap entry are presumably skipped via ShouldClone (the branch bodies are
// not fully shown); the remaining entries are remapped into the new module.
354 const std::string MetadataToCopy[] = {
"llvm.annotations",
"nvvm.annotations",
355 "nvvmir.version",
"llvm.module.flags"};
356 for (
auto &MetadataName : MetadataToCopy) {
357 NamedMDNode *NamedMD = M.getNamedMetadata(MetadataName);
361 auto *NewNamedMD = KernelModule->getOrInsertNamedMetadata(MetadataName);
362 for (
unsigned I = 0, E = NamedMD->getNumOperands(); I < E; ++I) {
363 MDNode *MDEntry = NamedMD->getOperand(I);
364 bool ShouldClone =
true;
367 for (
auto &Operand : MDEntry->operands()) {
368 Metadata *MD = Operand.get();
369 auto *CMD = dyn_cast<ConstantAsMetadata>(MD);
373 auto *GV = dyn_cast<GlobalValue>(CMD->getValue());
377 if (!VMap.count(GV)) {
386 NewNamedMD->addOperand(MapMetadata(MDEntry, VMap));
// When PROTEUS_ENABLE_DEBUG is set, the extracted module is verified.
390#if PROTEUS_ENABLE_DEBUG
392 if (verifyModule(*KernelModule, &errs()))
const char * VarName
Definition CompilerInterfaceDevice.cpp:20
void char int32_t * RCIndices
Definition CompilerInterfaceDevice.cpp:51
#define PROTEUS_DBG(x)
Definition Debug.h:7
#define PROTEUS_FATAL_ERROR(x)
Definition Error.h:4
static LambdaRegistry & instance()
Definition LambdaRegistry.hpp:21
static void logfile(const std::string &Filename, T &&Data)
Definition Logger.hpp:24
const SmallVector< StringRef > & threadIdxXFnName()
Definition CoreLLVMCUDA.hpp:68
const SmallVector< StringRef > & gridDimYFnName()
Definition CoreLLVMCUDA.hpp:28
const SmallVector< StringRef > & threadIdxZFnName()
Definition CoreLLVMCUDA.hpp:78
const SmallVector< StringRef > & blockIdxZFnName()
Definition CoreLLVMCUDA.hpp:63
const SmallVector< StringRef > & gridDimZFnName()
Definition CoreLLVMCUDA.hpp:33
const SmallVector< StringRef > & gridDimXFnName()
Definition CoreLLVMCUDA.hpp:23
const SmallVector< StringRef > & blockIdxXFnName()
Definition CoreLLVMCUDA.hpp:53
const SmallVector< StringRef > & threadIdxYFnName()
Definition CoreLLVMCUDA.hpp:73
const SmallVector< StringRef > & blockIdxYFnName()
Definition CoreLLVMCUDA.hpp:58
const SmallVector< StringRef > & blockDimYFnName()
Definition CoreLLVMCUDA.hpp:43
const SmallVector< StringRef > & blockDimZFnName()
Definition CoreLLVMCUDA.hpp:48
const SmallVector< StringRef > & blockDimXFnName()
Definition CoreLLVMCUDA.hpp:38
Definition JitEngine.cpp:20
void setLaunchBoundsForKernel(Module &M, Function &F, size_t GridSize, int BlockSize)
Definition CoreLLVMCUDA.hpp:85
void runCleanupPassPipeline(Module &M)
Definition CoreLLVM.hpp:170