Proteus
Programmable JIT compilation and optimization for C/C++ using LLVM
Loading...
Searching...
No Matches
CoreLLVMDevice.hpp
Go to the documentation of this file.
1#ifndef PROTEUS_CORE_LLVM_DEVICE_HPP
2#define PROTEUS_CORE_LLVM_DEVICE_HPP
3
4#if PROTEUS_ENABLE_HIP
6#endif
7
8#if PROTEUS_ENABLE_CUDA
10#endif
11
12#if defined(PROTEUS_ENABLE_HIP) || defined(PROTEUS_ENABLE_CUDA)
13
14#include <llvm/Analysis/CallGraph.h>
15#include <llvm/Bitcode/BitcodeReader.h>
16#include <llvm/Bitcode/BitcodeWriter.h>
17#include <llvm/IR/Attributes.h>
18#include <llvm/IR/ConstantRange.h>
19#include <llvm/IR/ReplaceConstant.h>
20#include <llvm/IR/Verifier.h>
21#include <llvm/Object/ELFObjectFile.h>
22#include <llvm/Transforms/Utils/Cloning.h>
23
30
31namespace proteus {
32
// Specializes kernel dimensions in module M: calls to dimension intrinsics are
// replaced by 32-bit integer constants and the calls themselves are erased.
//
// NOTE(review): this listing elides several source lines (the embedded line
// numbers jump). The code that consumes GridDim and BlockDim is among the
// elided lines; presumably each dim3 component is passed as DimValue together
// with the matching list of intrinsic names -- confirm against the full header.
33inline void setKernelDims(Module &M, dim3 &GridDim, dim3 &BlockDim) {
// Helper: returns the users of F that pass the null check on Call, gathered in
// CallUsers. The definitions of Call and CallUsers are elided; presumably Call
// is a dyn_cast<CallInst> of the user -- TODO confirm.
36 auto CollectCallUsers = [](Function &F) {
38 for (auto *User : F.users()) {
40 if (!Call)
41 continue;
42 CallUsers.push_back(Call);
43 }
44
45 return CallUsers;
46 };
47
// Process every candidate intrinsic name for this dimension.
48 for (auto IntrinsicName : IntrinsicNames) {
49
// The guard for this `continue` is elided in the listing.
52 continue;
53
// Formats the trace message emitted for each replaced call.
54 auto TraceOut = [](Function *F, Value *C) {
57 OS << "[DimSpec] Replace call to " << F->getName() << " with constant "
58 << *C << "\n";
59
60 return S;
61 };
62
// All uses of the call result are redirected to an i32 constant holding
// DimValue; the call is erased afterwards. The loop header that yields Call is
// elided; presumably it iterates the snapshot returned by CollectCallUsers so
// that erasing calls does not disturb iteration over the user list -- confirm.
64 Value *ConstantValue =
65 ConstantInt::get(Type::getInt32Ty(M.getContext()), DimValue);
66 Call->replaceAllUsesWith(ConstantValue);
// Tracing is enabled when ProteusTraceOutput >= 1 (the traced statement is
// elided in the listing).
67 if (Config::get().ProteusTraceOutput >= 1)
69 Call->eraseFromParent();
70 }
71 }
72 };
73
77
81}
82
// Annotates calls to intrinsics in module M with the value range
// [0, DimValue) instead of replacing them with a constant. With LLVM >= 19 the
// range is attached as a return `Range` attribute; otherwise it is attached as
// `!range` metadata on the call.
//
// NOTE(review): this listing elides several source lines (the embedded line
// numbers jump), including the code that consumes GridDim and BlockDim and
// that selects the intrinsic names -- confirm against the full header.
83inline void setKernelDimsRange(Module &M, dim3 &GridDim, dim3 &BlockDim) {
// A zero dimension is rejected outright; [0, 0) would not describe a usable
// half-open range.
86 if (DimValue == 0) {
87 PROTEUS_FATAL_ERROR("Dimension value cannot be zero");
88 }
89
90 for (auto IntrinsicName : IntrinsicNames) {
// Nothing to annotate when the intrinsic is absent from the module or unused.
92 if (!IntrinsicFunction || IntrinsicFunction->use_empty())
93 continue;
94
// Tail of the trace-message lambda (its header is elided in the listing).
98 OS << "[DimSpec] Range " << IntrinsicF->getName() << " [0," << DimValue
99 << ")\n";
100 return S;
101 };
102
103 for (auto *U : IntrinsicFunction->users()) {
// Only direct call instructions are annotated; other kinds of users are
// skipped.
104 auto *Call = dyn_cast<CallInst>(U);
105 if (!Call)
106 continue;
107
// A range only makes sense for integer results; calls with any other return
// type are skipped.
108 auto *RetTy = dyn_cast<IntegerType>(Call->getType());
109 if (!RetTy)
110 continue;
111
// BitWidth is presumably used to build the `Range` value on the elided line
// that follows -- TODO confirm.
112 unsigned BitWidth = RetTy->getBitWidth();
114
115#if LLVM_VERSION_MAJOR >= 19
// Any existing return Range attribute is removed before the new one is added.
116 AttrBuilder Builder{M.getContext()};
117 Builder.addRangeAttr(Range);
118 Call->removeRetAttr(Attribute::Range);
119 Call->setAttributes(
120 Call->getAttributes().addRetAttributes(M.getContext(), Builder));
121#else
122 // LLVM 18 (ROCm 6.2.x) does not expose the Range attribute; use range
123 // metadata instead.
// The metadata node holds the pair (0, DimValue) typed as the call's integer
// return type.
124 LLVMContext &Ctx = M.getContext();
125 Metadata *RangeMD[] = {
126 ConstantAsMetadata::get(ConstantInt::get(RetTy, 0)),
127 ConstantAsMetadata::get(ConstantInt::get(RetTy, DimValue))};
128 MDNode *RangeNode = MDNode::get(Ctx, RangeMD);
129 Call->setMetadata(LLVMContext::MD_range, RangeNode);
130#endif
131
// Tracing is enabled when ProteusTraceOutput >= 1 (the traced statement is
// elided in the listing).
132 if (Config::get().ProteusTraceOutput >= 1)
134 }
135 }
136 };
137
141
145}
146
148 Module &M,
149 const std::unordered_map<std::string, GlobalVarInfo> &VarNameToGlobalInfo) {
// Summary: for every registered variable name that resolves to a global in M,
// a companion global named "<name>$ptr" is created with a placeholder pointer
// initializer, and instruction uses of the original global (directly or through
// constants that refer to it) are rewritten to load the address from that
// companion global.
//
// NOTE(review): the line carrying this function's name and return type, and
// several interior lines, are elided in this listing (the embedded line
// numbers jump) -- confirm details against the full header.
150 // Re-link globals to fixed addresses provided by registered
151 // variables.
// NOTE(review): `auto RegisterVar` copies each map entry per iteration; only
// `.first` is read in the visible code, so a const reference looks sufficient
// -- confirm before changing.
152 for (auto RegisterVar : VarNameToGlobalInfo) {
153 auto &VarName = RegisterVar.first;
154 auto *GV = M.getNamedGlobal(VarName);
155 // Skip linking if the GV does not exist in the module.
156 if (!GV)
157 continue;
158
159 // This will convert constant users of GV to instructions so that we can
160 // replace with the GV ptr.
162
// Placeholder address used as the initializer of the "$ptr" global.
// relinkGlobalsObject in this file looks up the symbol "<name>$ptr" in the
// compiled object and overwrites its storage with the device address.
163 Constant *Addr =
164 ConstantInt::get(Type::getInt64Ty(M.getContext()), 0xDEADBEEFDEADBEEF);
165 auto *CE = ConstantExpr::getIntToPtr(Addr, GV->getType()->getPointerTo());
// The new global is non-constant, has external linkage, and reuses GV's
// thread-local mode and address space. The trailing `true` is presumably the
// isExternallyInitialized flag of the GlobalVariable constructor -- verify
// against the LLVM version in use.
166 auto *GVarPtr = new GlobalVariable(
167 M, GV->getType()->getPointerTo(), false, GlobalValue::ExternalLinkage,
168 CE, GV->getName() + "$ptr", nullptr, GV->getThreadLocalMode(),
169 GV->getAddressSpace(), true);
170
171 // Find all Constant users that refer to the global variable.
// Worklist traversal: starting from GV, every Constant user is visited once
// (the insert result gates re-queuing), so constants that refer to GV
// transitively also end up in ValuesToReplace.
173 SmallVector<Value *> Worklist;
174 // Seed with the global variable.
175 Worklist.push_back(GV);
176 ValuesToReplace.insert(GV);
177 while (!Worklist.empty()) {
178 Value *V = Worklist.pop_back_val();
179 for (auto *User : V->users()) {
180 if (auto *C = dyn_cast<Constant>(User)) {
181 if (ValuesToReplace.insert(C).second)
182 Worklist.push_back(C);
183
184 continue;
185 }
186
187 // Skip instructions to be handled when replacing.
189 continue;
190
// Message for a user that is neither a Constant nor an Instruction (the
// statement that reports it is elided in the listing).
192 "Expected Instruction or Constant user for Value: " + toString(*V) +
193 " , User: " + toString(*User));
194 }
195 }
196
197 for (Value *V : ValuesToReplace) {
199 // Find instruction users to replace value.
// Users are gathered into Insts first and rewritten in a second loop.
200 for (User *U : V->users()) {
201 if (auto *I = dyn_cast<Instruction>(U)) {
202 Insts.insert(I);
203 }
204 }
205
206 // Replace value in instructions.
207 for (auto *I : Insts) {
// The address is loaded from the "$ptr" global and, when its type differs from
// the type of the value being replaced, converted with a pointer bitcast or
// address-space cast. The Builder declaration is elided; presumably it inserts
// at instruction I -- TODO confirm.
209 auto *Load = Builder.CreateLoad(GV->getType(), GVarPtr);
210 Value *Replacement = Load;
211 Type *ExpectedTy = V->getType();
212 if (Load->getType() != ExpectedTy)
214 Builder.CreatePointerBitCastOrAddrSpaceCast(Load, ExpectedTy);
215
216 I->replaceUsesOfWith(V, Replacement);
217 }
218 }
219 }
220
// The module is verified only when debug output is enabled; a broken module is
// a fatal error.
221 if (Config::get().ProteusDebugOutput) {
222 if (verifyModule(M, &errs()))
223 PROTEUS_FATAL_ERROR("Broken module found, JIT compilation aborted!");
224 }
225}
226
// Patches a compiled device ELF object in place: for every registered global,
// the symbol named "<GlobalName>$ptr" is located and the 64-bit value stored at
// its position in the containing section's data is overwritten with
// GVI.DevAddr. Any failure while resolving the symbol value, its section or the
// section contents is a fatal error, as is a null device address.
//
// NOTE(review): this listing elides several source lines (the embedded line
// numbers jump), including the first parameter and the definitions of
// SymbolName, ValueOrErr, SectionDataOrErr, SectionData and Offset -- confirm
// details against the full header.
227inline void relinkGlobalsObject(
229 const std::unordered_map<std::string, GlobalVarInfo> &VarNameToGlobalInfo) {
// The object is parsed as a 64-bit little-endian ELF file.
231 object::ELF64LEObjectFile::create(Object);
232 if (auto E = DeviceElfOrErr.takeError())
233 PROTEUS_FATAL_ERROR("Cannot create the device elf: " +
234 toString(std::move(E)));
235 auto &DeviceElf = *DeviceElfOrErr;
236
// The symbol table is scanned linearly once per registered global.
237 for (auto &[GlobalName, GVI] : VarNameToGlobalInfo) {
238 for (auto &Symbol : DeviceElf.symbols()) {
// NOTE(review): symbols whose name cannot be read are skipped, but the failed
// Expected is dropped without its error being consumed; LLVM's error-handling
// rules may require consumeError(SymbolNameOrErr.takeError()) here -- verify.
239 auto SymbolNameOrErr = Symbol.getName();
240 if (!SymbolNameOrErr)
241 continue;
243
// Only the companion pointer symbol "<GlobalName>$ptr" is of interest.
244 if (!(SymbolName == (GlobalName + "$ptr")))
245 continue;
246
248 if (!ValueOrErr)
249 PROTEUS_FATAL_ERROR("Expected symbol value");
251
252 // Get the section containing the symbol
253 auto SectionOrErr = Symbol.getSection();
254 if (!SectionOrErr)
255 PROTEUS_FATAL_ERROR("Cannot retrieve section");
256 const auto &Section = *SectionOrErr;
// NOTE(review): the message below contains the typo "sybmol"; it is a runtime
// string and is left unchanged here.
257 if (Section == DeviceElf.section_end())
258 PROTEUS_FATAL_ERROR("Expected sybmol in section");
259
260 // Get the section's address and data
262 if (!SectionDataOrErr)
263 PROTEUS_FATAL_ERROR("Error retrieving section data");
265
266 // Calculate offset within the section
// Offset is presumably the symbol value minus SectionAddr (its definition is
// elided) -- TODO confirm.
267 uint64_t SectionAddr = Section->getAddress();
// NOTE(review): this check only guarantees that the first byte at Offset lies
// inside the section, while the store below writes sizeof(uint64_t) bytes;
// consider checking Offset + sizeof(uint64_t) <= SectionData.size().
269 if (Offset >= SectionData.size())
270 PROTEUS_FATAL_ERROR("Expected offset within section size");
271
// NOTE(review): the C-style cast yields a writable uint64_t pointer into the
// section data; if SectionData is a read-only view this also casts away
// constness, and the store assumes suitable alignment -- confirm.
272 uint64_t *Data = (uint64_t *)(SectionData.data() + Offset);
273 if (!GVI.DevAddr)
274 PROTEUS_FATAL_ERROR("Cannot set global Var " + GlobalName +
275 " without a concrete device address");
276
// Overwrite the stored pointer value with the device address, then stop
// scanning symbols for this global.
277 *Data = reinterpret_cast<uint64_t>(GVI.DevAddr);
278 break;
279 }
280 }
281}
282
// Entry point for IR specialization of the kernel named FnName in module M.
// Depending on the flags it specializes arguments, lambda callees, kernel
// dimensions (as constants and/or as value ranges) and launch bounds, and it
// renames the kernel to FnName + Suffix. Elapsed time is measured with a Timer
// and reported at the end.
//
// NOTE(review): this listing elides several source lines (the embedded line
// numbers jump), including one parameter line (GridDim is used below but not
// declared in the visible signature) and the calls that perform argument,
// lambda and shared-array transformation -- confirm against the full header.
// NOTE(review): LambdaCalleeInfo is taken by const value, which copies the
// vector on each call; a const reference looks sufficient -- confirm.
283inline void specializeIR(
284 Module &M, StringRef FnName, StringRef Suffix, dim3 &BlockDim,
286 const SmallVector<std::pair<std::string, StringRef>> LambdaCalleeInfo,
287 bool SpecializeArgs, bool SpecializeDims, bool SpecializeDimsRange,
288 bool SpecializeLaunchBounds, int MinBlocksPerSM) {
289 Timer T;
290 Function *F = M.getFunction(FnName);
291
// NOTE(review): a missing kernel is only caught by assert here, while the
// lambda loop below uses PROTEUS_FATAL_ERROR for the same condition; with
// asserts disabled a null F would be dereferenced at F->setName.
292 assert(F && "Expected non-null function!");
293 // Replace argument uses with runtime constants.
294 if (SpecializeArgs)
296
// For each (function name, lambda type) pair, the runtime constants registered
// for the lambda type are fetched and the callee is required to exist in M.
// Note that FnName and F declared in this loop shadow the outer kernel name
// and function.
298 for (auto &[FnName, LambdaType] : LambdaCalleeInfo) {
299 const SmallVector<RuntimeConstant> &RCVec = LR.getJitVariables(LambdaType);
300 Function *F = M.getFunction(FnName);
301 if (!F)
302 PROTEUS_FATAL_ERROR("Expected non-null Function");
304 }
305
306 // Run the shared array transform after any value specialization (arguments,
307 // captures) to propagate any constants.
309
310 // Replace uses of blockDim.* and gridDim.* with constants.
311 if (SpecializeDims)
312 setKernelDims(M, GridDim, BlockDim);
313 if (SpecializeDimsRange)
314 setKernelDimsRange(M, GridDim, BlockDim);
// The kernel is renamed with the specialization suffix.
315 F->setName(FnName + Suffix);
316
317 if (SpecializeLaunchBounds) {
// The total threads per block are passed as the maximum thread count.
// NOTE(review): the product of the dim3 components is narrowed to int --
// confirm that the value range makes this safe.
318 int BlockSize = BlockDim.x * BlockDim.y * BlockDim.z;
319 auto TraceOut = [](int BlockSize, int MinBlocksPerSM) {
322 OS << "[LaunchBoundSpec] MaxThreads " << BlockSize << " MinBlocksPerSM "
323 << MinBlocksPerSM << "\n";
324
325 return S;
326 };
327 if (Config::get().ProteusTraceOutput >= 1)
328 Logger::trace(TraceOut(BlockSize, MinBlocksPerSM));
329 setLaunchBoundsForKernel(*F, BlockSize, MinBlocksPerSM);
330 }
331
333
// Tail of the timing report statement (its beginning is elided in the listing).
335 << "specializeIR " << T.elapsed() << " ms\n");
336}
337
338} // namespace proteus
339
340#endif
341
342#endif
const void const char * VarName
Definition CompilerInterfaceDevice.cpp:21
#define PROTEUS_FATAL_ERROR(x)
Definition Error.h:7
#define PROTEUS_TIMER_OUTPUT(x)
Definition TimeTracing.hpp:57
static Config & get()
Definition Config.hpp:298
static LambdaRegistry & instance()
Definition LambdaRegistry.hpp:22
static llvm::raw_ostream & outs(const std::string &Name)
Definition Logger.hpp:25
static void trace(llvm::StringRef Msg)
Definition Logger.hpp:30
static void transform(Module &M, Function &F, ArrayRef< RuntimeConstant > RCArray)
Definition TransformArgumentSpecialization.hpp:88
static void transform(Module &M, Function &F, const SmallVector< RuntimeConstant > &RCVec)
Definition TransformLambdaSpecialization.hpp:125
static void transform(Module &M)
Definition TransformSharedArray.hpp:30
const SmallVector< StringRef > & threadIdxXFnName()
Definition CoreLLVMCUDA.hpp:70
const SmallVector< StringRef > & gridDimYFnName()
Definition CoreLLVMCUDA.hpp:30
const SmallVector< StringRef > & threadIdxZFnName()
Definition CoreLLVMCUDA.hpp:80
const SmallVector< StringRef > & blockIdxZFnName()
Definition CoreLLVMCUDA.hpp:65
const SmallVector< StringRef > & gridDimZFnName()
Definition CoreLLVMCUDA.hpp:35
const SmallVector< StringRef > & gridDimXFnName()
Definition CoreLLVMCUDA.hpp:25
const SmallVector< StringRef > & blockIdxXFnName()
Definition CoreLLVMCUDA.hpp:55
const SmallVector< StringRef > & threadIdxYFnName()
Definition CoreLLVMCUDA.hpp:75
const SmallVector< StringRef > & blockIdxYFnName()
Definition CoreLLVMCUDA.hpp:60
const SmallVector< StringRef > & blockDimYFnName()
Definition CoreLLVMCUDA.hpp:45
const SmallVector< StringRef > & blockDimZFnName()
Definition CoreLLVMCUDA.hpp:50
const SmallVector< StringRef > & blockDimXFnName()
Definition CoreLLVMCUDA.hpp:40
Definition StorageCache.cpp:24
void setLaunchBoundsForKernel(Function &F, int MaxThreadsPerSM, int MinBlocksPerSM=0)
Definition CoreLLVMCUDA.hpp:87
T getRuntimeConstantValue(void *Arg)
Definition CompilerInterfaceRuntimeConstantInfo.h:114
std::string toString(CodegenOption Option)
Definition Config.hpp:26
void runCleanupPassPipeline(Module &M)
Definition CoreLLVM.hpp:230