1#ifndef PROTEUS_CORE_LLVM_DEVICE_H
2#define PROTEUS_CORE_LLVM_DEVICE_H
12#if defined(PROTEUS_ENABLE_HIP) || defined(PROTEUS_ENABLE_CUDA)
21#include <llvm/Analysis/CallGraph.h>
22#include <llvm/Bitcode/BitcodeReader.h>
23#include <llvm/Bitcode/BitcodeWriter.h>
24#include <llvm/IR/Attributes.h>
25#include <llvm/IR/ConstantRange.h>
26#include <llvm/IR/ReplaceConstant.h>
27#include <llvm/IR/Verifier.h>
28#include <llvm/Object/ELFObjectFile.h>
29#include <llvm/Transforms/Utils/Cloning.h>
// Replaces calls to kernel-dimension intrinsic functions found in M with
// 32-bit integer constants and erases the replaced calls.
// NOTE(review): this listing omits some original lines (the numeric prefixes
// skip), so the comments describe only the statements that are visible here.
33inline void setKernelDims(Module &M, dim3 &GridDim, dim3 &BlockDim) {
// Per-dimension helper; takes the candidate intrinsic names. The remaining
// parameter(s) are on an omitted line -- presumably the DimValue used below;
// TODO confirm.
34 auto ReplaceIntrinsicDim = [&](ArrayRef<StringRef> IntrinsicNames,
// Snapshots the CallInst users of F into a vector. The caller erases the
// calls while walking the result, so it iterates this copy rather than the
// live use list.
36 auto CollectCallUsers = [](Function &F) {
37 SmallVector<CallInst *> CallUsers;
38 for (
auto *User : F.users()) {
// dyn_cast yields nullptr for users that are not CallInst; the handling of
// that case is on lines not visible here.
39 auto *Call = dyn_cast<CallInst>(User);
42 CallUsers.push_back(Call);
// Try every name in the list and look each one up in the module.
48 for (
auto IntrinsicName : IntrinsicNames) {
50 Function *IntrinsicFunction = M.getFunction(IntrinsicName);
// The intrinsic is not declared in this module; the branch body is not
// visible -- presumably the name is skipped.
51 if (!IntrinsicFunction)
// Builds the "[DimSpec] Replace call to <fn> with constant ..." trace message.
54 auto TraceOut = [](Function *F, Value *C) {
56 raw_svector_ostream OS(S);
57 OS <<
"[DimSpec] Replace call to " << F->getName() <<
" with constant "
// Substitute each collected call with an i32 constant holding DimValue.
63 for (
auto *Call : CollectCallUsers(*IntrinsicFunction)) {
64 Value *ConstantValue =
65 ConstantInt::get(Type::getInt32Ty(M.getContext()), DimValue);
66 Call->replaceAllUsesWith(ConstantValue);
// The call has no remaining uses after the replacement, so remove it.
69 Call->eraseFromParent();
// Annotates calls to kernel-dimension/index intrinsic functions in M with a
// value range built from 0 and DimValue, without replacing the calls.
// NOTE(review): this listing omits some original lines (the numeric prefixes
// skip), so the comments describe only the statements that are visible here.
83inline void setKernelDimsRange(Module &M, dim3 &GridDim, dim3 &BlockDim) {
// Per-dimension helper; the rest of its parameter list is on an omitted line
// -- presumably the DimValue used below; TODO confirm.
84 auto AttachRange = [&](ArrayRef<StringRef> IntrinsicNames,
90 for (
auto IntrinsicName : IntrinsicNames) {
91 Function *IntrinsicFunction = M.getFunction(IntrinsicName);
// Nothing to annotate when the intrinsic is absent or has no uses; the branch
// body is not visible -- presumably the name is skipped.
92 if (!IntrinsicFunction || IntrinsicFunction->use_empty())
// Builds the "[DimSpec] Range <fn> [0,<DimValue>..." trace message.
95 auto TraceOut = [](Function *IntrinsicF, uint32_t DimValue) {
97 raw_svector_ostream OS(S);
98 OS <<
"[DimSpec] Range " << IntrinsicF->getName() <<
" [0," << DimValue
103 for (
auto *U : IntrinsicFunction->users()) {
// Only direct call instructions are considered; the null check for
// non-call users is on lines not visible here.
104 auto *Call = dyn_cast<CallInst>(U);
// The range is expressed in the call's integer return type; handling of a
// non-integer return type is on lines not visible here.
108 auto *RetTy = dyn_cast<IntegerType>(Call->getType());
112 unsigned BitWidth = RetTy->getBitWidth();
// Range bounds 0 and DimValue at the return type's bit width.
// NOTE(review): LLVM's ConstantRange(Lower, Upper) is documented as
// half-open [Lower, Upper) -- confirm against the LLVM version in use.
113 ConstantRange Range(APInt(BitWidth, 0), APInt(BitWidth, DimValue));
// LLVM 19+: express the range as a return attribute on the call. Any
// existing Range return attribute is removed before the new one is added.
115#if LLVM_VERSION_MAJOR >= 19
116 AttrBuilder Builder{M.getContext()};
117 Builder.addRangeAttr(Range);
118 Call->removeRetAttr(Attribute::Range);
// Tail of a statement whose first line is not visible; presumably it stores
// the updated attribute list back on the call -- TODO confirm.
120 Call->getAttributes().addRetAttributes(M.getContext(), Builder));
// Presumably the pre-LLVM-19 branch (the #else is not visible): attach the
// same bounds as !range metadata on the call instead.
124 LLVMContext &Ctx = M.getContext();
125 Metadata *RangeMD[] = {
126 ConstantAsMetadata::get(ConstantInt::get(RetTy, 0)),
127 ConstantAsMetadata::get(ConstantInt::get(RetTy, DimValue))};
128 MDNode *RangeNode = MDNode::get(Ctx, RangeMD);
129 Call->setMetadata(LLVMContext::MD_range, RangeNode);
// For each global named in VarNameToGlobalInfo that exists in the module,
// creates a companion global "<name>$ptr" and rewrites instruction uses of
// the original global to go through a load from that companion.
// relinkGlobalsObject in this file looks up symbols with the same "$ptr"
// suffix and overwrites their contents with a device address.
// NOTE(review): this listing omits some original lines (the numeric prefixes
// skip; the first parameter line is among them), so the comments describe
// only the statements that are visible here.
147inline void replaceGlobalVariablesWithPointers(
149 const std::unordered_map<std::string, GlobalVarInfo> &VarNameToGlobalInfo) {
// NOTE(review): `auto RegisterVar` copies each map entry (key string and
// GlobalVarInfo) per iteration; `const auto &` would avoid the copy.
152 for (
auto RegisterVar : VarNameToGlobalInfo) {
153 auto &
VarName = RegisterVar.first;
// Look up the global by name; the handling of a missing global is on lines
// not visible here.
154 auto *GV = M.getNamedGlobal(
VarName);
// Expand constant-expression users of GV into instructions so the uses can
// be rewritten individually.
161 convertUsersOfConstantsToInstructions({GV});
// Placeholder 64-bit address used as the initializer of the "$ptr" global.
164 ConstantInt::get(Type::getInt64Ty(M.getContext()), 0xDEADBEEFDEADBEEF);
165 auto *CE = ConstantExpr::getIntToPtr(
166 Addr, PointerType::get(GV->getType(), GV->getAddressSpace()));
// New externally-linked global "<name>$ptr" initialized with the placeholder
// pointer, in the same address space and thread-local mode as GV.
// NOTE(review): per the LLVM GlobalVariable constructor the `false` is
// isConstant and the trailing `true` is isExternallyInitialized -- confirm.
167 auto *GVarPtr =
new GlobalVariable(
168 M, PointerType::get(GV->getType(), GV->getAddressSpace()),
false,
169 GlobalValue::ExternalLinkage, CE, GV->getName() +
"$ptr",
nullptr,
170 GV->getThreadLocalMode(), GV->getAddressSpace(),
true);
// Worklist traversal: collect GV plus every Constant that transitively uses
// it. The set's insert result guards against visiting a constant twice.
173 SmallPtrSet<Value *, 16> ValuesToReplace;
174 SmallVector<Value *> Worklist;
176 Worklist.push_back(GV);
177 ValuesToReplace.insert(GV);
178 while (!Worklist.empty()) {
179 Value *V = Worklist.pop_back_val();
180 for (
auto *User : V->users()) {
181 if (
auto *C = dyn_cast<Constant>(User)) {
182 if (ValuesToReplace.insert(C).second)
183 Worklist.push_back(C);
// Instruction users are handled in the rewrite loop below; the body of this
// branch is not visible here.
189 if (isa<Instruction>(User))
// Rewrite pass: for every collected value, gather its instruction users
// first and then patch each one.
197 for (Value *V : ValuesToReplace) {
198 SmallPtrSet<Instruction *, 16> Insts;
200 for (User *U : V->users()) {
201 if (
auto *I = dyn_cast<Instruction>(U)) {
207 for (
auto *I : Insts) {
// Builder is positioned at I, so the load is emitted next to the user it
// feeds.
208 IRBuilder Builder{I};
209 auto *Load = Builder.CreateLoad(GV->getType(), GVarPtr);
210 Value *Replacement = Load;
// Cast the loaded pointer when the replaced value has a different type.
211 Type *ExpectedTy = V->getType();
212 if (Load->getType() != ExpectedTy)
214 Builder.CreatePointerBitCastOrAddrSpaceCast(Load, ExpectedTy);
216 I->replaceUsesOfWith(V, Replacement);
// verifyModule returns true when the module is broken and prints diagnostics
// to errs(); the failure handling is on lines not visible here.
222 if (verifyModule(M, &errs()))
// Patches a compiled device object in place: for every global in
// VarNameToGlobalInfo, finds the ELF symbol "<GlobalName>$ptr" and
// overwrites the 8 bytes at its location with GVI.DevAddr.
// NOTE(review): this listing omits some original lines (the numeric prefixes
// skip), so the comments describe only the statements that are visible here.
227inline void relinkGlobalsObject(
228 MemoryBufferRef Object,
229 const std::unordered_map<std::string, GlobalVarInfo> &VarNameToGlobalInfo) {
// Parse the buffer as a 64-bit little-endian ELF object; the error branch
// body is not visible here.
230 Expected<object::ELF64LEObjectFile> DeviceElfOrErr =
231 object::ELF64LEObjectFile::create(Object);
232 if (
auto E = DeviceElfOrErr.takeError())
234 auto &DeviceElf = *DeviceElfOrErr;
// For each global, scan the whole symbol table for its "$ptr" companion.
236 for (
auto &[GlobalName, GVI] : VarNameToGlobalInfo) {
237 for (
auto &Symbol : DeviceElf.symbols()) {
238 auto SymbolNameOrErr = Symbol.getName();
239 if (!SymbolNameOrErr)
241 auto SymbolName = *SymbolNameOrErr;
// Only the symbol named "<GlobalName>$ptr" is of interest; the branch body
// is not visible -- presumably other symbols are skipped.
243 if (!(SymbolName == (GlobalName +
"$ptr")))
246 Expected<uint64_t> ValueOrErr = Symbol.getValue();
249 uint64_t SymbolValue = *ValueOrErr;
// Resolve the section that holds the symbol; the error check on
// SectionOrErr is on lines not visible here.
252 auto SectionOrErr = Symbol.getSection();
255 const auto &Section = *SectionOrErr;
256 if (Section == DeviceElf.section_end())
260 Expected<StringRef> SectionDataOrErr = Section->getContents();
261 if (!SectionDataOrErr)
263 StringRef SectionData = *SectionDataOrErr;
// Convert the symbol value into a byte offset within the section contents.
266 uint64_t SectionAddr = Section->getAddress();
267 uint64_t
Offset = SymbolValue - SectionAddr;
// NOTE(review): this bound only guarantees the first byte is inside the
// section, while the store below writes sizeof(uint64_t) bytes -- consider
// checking Offset + sizeof(uint64_t) <= SectionData.size().
268 if (
Offset >= SectionData.size())
// NOTE(review): the C-style cast drops the constness of StringRef::data(),
// so this relies on the underlying buffer being writable; it also assumes
// the location is suitably aligned for uint64_t -- confirm with callers.
271 uint64_t *Data = (uint64_t *)(SectionData.data() +
Offset);
// Tail of an error message whose start is not visible; presumably reported
// when GVI.DevAddr is null -- TODO confirm.
274 " without a concrete device address");
// Overwrite the placeholder stored at the symbol with the device address.
276 *Data =
reinterpret_cast<uint64_t
>(GVI.DevAddr);
// Specializes the function FnName in M: processes lambda callee info,
// applies launch-dimension specialization via setKernelDims /
// setKernelDimsRange, renames the function to FnName + Suffix, and handles
// launch bounds when SpecializeLaunchBounds is set.
// Note the parameter order here is BlockDim then GridDim, while the
// setKernelDims helpers take GridDim first; the calls below pass them in the
// helpers' order.
// NOTE(review): LambdaCalleeInfo is a const SmallVector taken by value, so it
// is copied on every call; a const reference would avoid that.
// NOTE(review): this listing omits some original lines (the numeric prefixes
// skip), so the comments describe only the statements that are visible here.
282inline void specializeIR(
283 Module &M, StringRef FnName, StringRef Suffix, dim3 &BlockDim,
284 dim3 &GridDim, ArrayRef<RuntimeConstant> RCArray,
285 const SmallVector<std::pair<std::string, StringRef>> LambdaCalleeInfo,
286 bool SpecializeArgs,
bool SpecializeDims,
bool SpecializeDimsRange,
287 bool SpecializeLaunchBounds,
int MinBlocksPerSM) {
289 Function *F = M.getFunction(FnName);
291 assert(F &&
"Expected non-null function!");
// NOTE(review): the structured binding `FnName` and the local `F` inside
// this loop shadow the parameter FnName and the outer F declared above.
297 for (
auto &[FnName, LambdaType] : LambdaCalleeInfo) {
// LR is declared on a line not visible here; it supplies the runtime
// constants registered for this lambda type.
298 const SmallVector<RuntimeConstant> &RCVec = LR.getJitVariables(LambdaType);
299 Function *F = M.getFunction(FnName);
// The guard for this call is on a line not visible here -- presumably
// `if (SpecializeDims)`; TODO confirm.
311 setKernelDims(M, GridDim, BlockDim);
312 if (SpecializeDimsRange)
313 setKernelDimsRange(M, GridDim, BlockDim);
// Rename the function so the specialized version carries the suffix.
314 F->setName(FnName + Suffix);
316 if (SpecializeLaunchBounds) {
// Total threads per block.
// NOTE(review): the product of the three dim3 components is stored in an int.
317 int BlockSize = BlockDim.x * BlockDim.y * BlockDim.z;
// Builds the "[LaunchBoundSpec] MaxThreads <n> MinBlocksPerSM <m>" trace
// message.
318 auto TraceOut = [](
int BlockSize,
int MinBlocksPerSM) {
320 raw_svector_ostream OS(S);
321 OS <<
"[LaunchBoundSpec] MaxThreads " << BlockSize <<
" MinBlocksPerSM "
322 << MinBlocksPerSM <<
"\n";
// Tail of a timing-output statement that reports the elapsed time of
// specializeIR in milliseconds; T is declared on a line not visible here.
334 <<
"specializeIR " << T.elapsed() <<
" ms\n");
const void const char * VarName
Definition CompilerInterfaceDevice.cpp:24
#define PROTEUS_TIMER_OUTPUT(x)
Definition TimeTracing.h:54
static Config & get()
Definition Config.h:334
static LambdaRegistry & instance()
Definition LambdaRegistry.h:21
static llvm::raw_ostream & outs(const std::string &Name)
Definition Logger.h:25
static void trace(llvm::StringRef Msg)
Definition Logger.h:30
const SmallVector< StringRef > & threadIdxXFnName()
Definition CoreLLVMCUDA.h:70
const SmallVector< StringRef > & gridDimYFnName()
Definition CoreLLVMCUDA.h:30
const SmallVector< StringRef > & threadIdxZFnName()
Definition CoreLLVMCUDA.h:80
const SmallVector< StringRef > & blockIdxZFnName()
Definition CoreLLVMCUDA.h:65
const SmallVector< StringRef > & gridDimZFnName()
Definition CoreLLVMCUDA.h:35
const SmallVector< StringRef > & gridDimXFnName()
Definition CoreLLVMCUDA.h:25
const SmallVector< StringRef > & blockIdxXFnName()
Definition CoreLLVMCUDA.h:55
const SmallVector< StringRef > & threadIdxYFnName()
Definition CoreLLVMCUDA.h:75
const SmallVector< StringRef > & blockIdxYFnName()
Definition CoreLLVMCUDA.h:60
const SmallVector< StringRef > & blockDimYFnName()
Definition CoreLLVMCUDA.h:45
const SmallVector< StringRef > & blockDimZFnName()
Definition CoreLLVMCUDA.h:50
const SmallVector< StringRef > & blockDimXFnName()
Definition CoreLLVMCUDA.h:40
Definition MemoryCache.h:26
void setLaunchBoundsForKernel(Function &F, int MaxThreadsPerSM, int MinBlocksPerSM=0)
Definition CoreLLVMCUDA.h:87
void reportFatalError(const llvm::Twine &Reason, const char *FILE, unsigned Line)
Definition Error.cpp:14
static int int Offset
Definition JitInterface.h:102
std::string toString(CodegenOption Option)
Definition Config.h:28
void runCleanupPassPipeline(Module &M)
Definition CoreLLVM.h:230