Proteus
Programmable JIT compilation and optimization for C/C++ using LLVM
Loading...
Searching...
No Matches
CoreLLVMDevice.hpp
Go to the documentation of this file.
1#ifndef PROTEUS_CORE_LLVM_DEVICE_HPP
2#define PROTEUS_CORE_LLVM_DEVICE_HPP
3
4#if PROTEUS_ENABLE_HIP
6#endif
7
8#if PROTEUS_ENABLE_CUDA
10#endif
11
12#if defined(PROTEUS_ENABLE_HIP) || defined(PROTEUS_ENABLE_CUDA)
13
14#include <llvm/Analysis/CallGraph.h>
15#include <llvm/Bitcode/BitcodeReader.h>
16#include <llvm/Bitcode/BitcodeWriter.h>
17#include <llvm/IR/ReplaceConstant.h>
18#include <llvm/IR/Verifier.h>
19#include <llvm/Object/ELFObjectFile.h>
20#include <llvm/Transforms/Utils/Cloning.h>
21
27
28namespace proteus {
29
// Specializes grid/block dimension queries in M to constants.
// The visible helper body walks a list of intrinsic names. For each call to
// such an intrinsic (the call loop header on source line 60 is not shown), it
// replaces all uses of the call's result with the i32 constant DimValue and
// then erases the call. specializeIR calls this when SpecializeDims is set.
// NOTE(review): this rendering drops source lines 31-32, 34, 36, 47-48, 53,
// 60, 65 and 71-77. These include the enclosing helper's header and the
// invocations that pass each dim3 component (lines 71-77), so the exact
// intrinsic-name lists used cannot be confirmed from here.
30inline void setKernelDims(Module &M, dim3 &GridDim, dim3 &BlockDim) {
// Gathers the CallInst users of F into a vector. This is presumably a
// snapshot, so calls can be erased below without mutating F's use list
// while it is being iterated. TODO: confirm against source line 60.
33 auto CollectCallUsers = [](Function &F) {
35 for (auto *User : F.users()) {
37 if (!Call)
38 continue;
39 CallUsers.push_back(Call);
40 }
41
42 return CallUsers;
43 };
44
45 for (auto IntrinsicName : IntrinsicNames) {
46
// Presumably skips intrinsics that are not declared in M (the lookup on
// source lines 47-48 is not shown).
49 continue;
50
// Formats the "[DimSpec]" trace message printed when ProteusTraceOutput >= 1.
51 auto TraceOut = [](Function *F, Value *C) {
54 OS << "[DimSpec] Replace call to " << F->getName() << " with constant "
55 << *C << "\n";
56
57 return S;
58 };
59
// The replacement is always an i32 constant.
// NOTE(review): replaceAllUsesWith requires the replacement to have the
// call's type. This therefore assumes every listed intrinsic returns i32;
// confirm for all supported CUDA/HIP intrinsics.
61 Value *ConstantValue =
62 ConstantInt::get(Type::getInt32Ty(M.getContext()), DimValue);
// Redirect every use of the call's result to the constant, optionally trace,
// then erase the now-unused call.
63 Call->replaceAllUsesWith(ConstantValue);
64 if (Config::get().ProteusTraceOutput >= 1)
66 Call->eraseFromParent();
67 }
68 }
69 };
70
74
78}
79
// Adds range hints instead of replacing dimension queries. After each call
// to a listed intrinsic, it inserts `llvm.assume(call <u DimValue)` (an
// unsigned, strict less-than bound) and leaves the call itself in place.
// specializeIR calls this when SpecializeDimsAssume is set.
// NOTE(review): source lines 81, 83, 88-89, 107, 111 and 117-124 are not
// shown in this rendering. They cover the helper header, the intrinsic
// lookup, the trace call and the per-dimension invocations. The comments
// below suggest threadIdx.* and blockIdx.* are the bounded values,
// presumably against BlockDim and GridDim respectively; confirm against the
// full source.
80inline void setKernelDimsAssume(Module &M, dim3 &GridDim, dim3 &BlockDim) {
82 for (auto IntrinsicName : IntrinsicNames) {
// Nothing to annotate if the intrinsic is absent or has no uses.
84 if (!IntrinsicFunction || IntrinsicFunction->use_empty())
85 continue;
86
// Formats the "[DimSpec]" trace message printed when ProteusTraceOutput >= 1.
87 auto TraceOut = [](Function *IntrinsicF, int DimValue) {
90 OS << "[DimSpec] Assume " << IntrinsicF->getName() << " with "
91 << DimValue << "\n";
92
93 return S;
94 };
95
96 // Iterate over all uses of the intrinsic.
// Iterating users() directly is safe here. The icmp and assume inserted
// below use Call, Bound, Cmp and the assume declaration, not
// IntrinsicFunction, so this use list is not modified during the loop.
97 for (auto *U : IntrinsicFunction->users()) {
98 auto *Call = dyn_cast<CallInst>(U);
99 if (!Call)
100 continue;
101
102 // Insert the llvm.assume intrinsic.
// The builder inserts before the instruction following Call, so the
// comparison and the assume land immediately after the call.
103 IRBuilder<> Builder(Call->getNextNode());
// Bound uses the call's own result type so the icmp operands match.
104 Value *Bound = ConstantInt::get(Call->getType(), DimValue);
105 Value *Cmp = Builder.CreateICmpULT(Call, Bound);
// NOTE(review): the assume declaration is looked up once per call and could
// be hoisted out of the loop. Newer LLVM releases also rename
// Intrinsic::getDeclaration to Intrinsic::getOrInsertDeclaration; check
// against the supported LLVM versions.
106
108 Intrinsic::getDeclaration(&M, Intrinsic::assume);
109 Builder.CreateCall(AssumeIntrinsic, Cmp);
110 if (Config::get().ProteusTraceOutput >= 1)
112 }
113 }
114 };
115
116 // Inform LLVM about the range of possible values of threadIdx.*.
120
121 // Inform LLVM about the range of possible values of blockIdx.*.
125}
126
// Re-targets accesses to registered device globals through an indirection
// slot. For every entry of VarNameToDevPtr whose name matches a global GV
// in M, it:
//   1. creates a non-constant, externally-initialized global "<name>$ptr"
//      with GV's address space and TLS mode, holding the placeholder
//      address 0xDEADBEEFDEADBEEF;
//   2. collects GV plus every Constant that transitively uses it;
//   3. in every instruction that uses one of those values, replaces the use
//      with a load from "<name>$ptr", cast to the expected pointer type when
//      the types differ.
// relinkGlobalsObject, below, searches a device ELF object for
// "<name>$ptr" symbols and writes the registered device pointer into them,
// presumably replacing this placeholder.
// When ProteusDebugOutput is set, the module is verified and a broken
// module is a fatal error.
// NOTE(review): source lines 141, 152, 168, 171, 178, 188 and 193 are not
// shown in this rendering.
127inline void replaceGlobalVariablesWithPointers(
128 Module &M,
129 const std::unordered_map<std::string, const void *> &VarNameToDevPtr) {
130 // Re-link globals to fixed addresses provided by registered
131 // variables.
// NOTE(review): `auto RegisterVar` copies each
// std::pair<const std::string, const void *>, including the string.
// `const auto &` would avoid the per-entry copy.
132 for (auto RegisterVar : VarNameToDevPtr) {
133 auto &VarName = RegisterVar.first;
134 auto *GV = M.getNamedGlobal(VarName);
135 // Skip linking if the GV does not exist in the module.
136 if (!GV)
137 continue;
138
139 // This will convert constant users of GV to instructions so that we can
140 // replace with the GV ptr.
142
// Placeholder initializer for the indirection slot.
143 Constant *Addr =
144 ConstantInt::get(Type::getInt64Ty(M.getContext()), 0xDEADBEEFDEADBEEF);
145 auto *CE = ConstantExpr::getIntToPtr(Addr, GV->getType()->getPointerTo());
// GlobalVariable constructor arguments, in order:
//   value type = pointer to GV's type, isConstant = false,
//   external linkage, the placeholder initializer, name "<GV>$ptr",
//   no insert-before, GV's thread-local mode, GV's address space,
//   isExternallyInitialized = true.
146 auto *GVarPtr = new GlobalVariable(
147 M, GV->getType()->getPointerTo(), false, GlobalValue::ExternalLinkage,
148 CE, GV->getName() + "$ptr", nullptr, GV->getThreadLocalMode(),
149 GV->getAddressSpace(), true);
150
151 // Find all Constant users that refer to the global variable.
153 SmallVector<Value *> Worklist;
154 // Seed with the global variable.
155 Worklist.push_back(GV);
156 ValuesToReplace.insert(GV);
157 while (!Worklist.empty()) {
158 Value *V = Worklist.pop_back_val();
159 for (auto *User : V->users()) {
// Constant users are followed transitively; insert().second ensures each
// one is queued only once.
160 if (auto *C = dyn_cast<Constant>(User)) {
161 if (ValuesToReplace.insert(C).second)
162 Worklist.push_back(C);
163
164 continue;
165 }
166
167 // Skip instructions to be handled when replacing.
169 continue;
// Any other kind of user is unexpected and is reported with the message
// below. The reporting call on source line 171 is not shown.
170
172 "Expected Instruction or Constant user for Value: " + toString(*V) +
173 " , User: " + toString(*User));
174 }
175 }
176
177 for (Value *V : ValuesToReplace) {
179 // Find instruction users to replace value.
180 for (User *U : V->users()) {
181 if (auto *I = dyn_cast<Instruction>(U)) {
182 Insts.insert(I);
183 }
184 }
185
186 // Replace value in instructions.
// A fresh load of "<name>$ptr" is emitted for each using instruction. The
// builder is set up on source line 188, which is not shown.
// NOTE(review): if V is a constant expression derived from GV (e.g. a GEP
// with a non-zero offset), replacing it with the loaded base pointer would
// drop that offset. This appears to rely on the conversion described at the
// top of the loop (source line 141) having already turned such constant
// users into instructions; confirm. Also confirm the load's insertion point
// is valid when I is a PHI node.
187 for (auto *I : Insts) {
189 auto *Load = Builder.CreateLoad(GV->getType(), GVarPtr);
190 Value *Replacement = Load;
191 Type *ExpectedTy = V->getType();
192 if (Load->getType() != ExpectedTy)
194 Builder.CreatePointerBitCastOrAddrSpaceCast(Load, ExpectedTy);
195
196 I->replaceUsesOfWith(V, Replacement);
197 }
198 }
199 }
200
201 if (Config::get().ProteusDebugOutput) {
202 if (verifyModule(M, &errs()))
203 PROTEUS_FATAL_ERROR("Broken module found, JIT compilation aborted!");
204 }
205}
206
// Patches the device pointers for registered globals into a compiled
// 64-bit little-endian ELF device object. For each (GlobalName, DevPtr)
// pair it:
//   - scans the object's symbols for "<GlobalName>$ptr";
//   - locates that symbol's bytes inside its containing section;
//   - overwrites them with the 64-bit value of DevPtr.
// Only the first matching symbol is patched; the inner loop breaks after
// writing. Failure to parse the ELF, or to resolve the symbol's value,
// section or section contents, is fatal.
// NOTE(review): source lines 208 (the object parameter), 210, 222, 227,
// 230, 241, 244 and 248 are not shown in this rendering. Whether the
// section contents are actually writable depends on those declarations.
207inline void relinkGlobalsObject(
209 const std::unordered_map<std::string, const void *> &VarNameToDevPtr) {
211 object::ELF64LEObjectFile::create(Object);
212 if (auto E = DeviceElfOrErr.takeError())
213 PROTEUS_FATAL_ERROR("Cannot create the device elf: " +
214 toString(std::move(E)));
215 auto &DeviceElf = *DeviceElfOrErr;
216
// NOTE(review): this is O(#globals x #symbols), and it builds a temporary
// std::string (GlobalName + "$ptr") for every comparison. Hoisting the
// concatenation above the inner loop would avoid the repeated allocation.
217 for (auto &[GlobalName, DevPtr] : VarNameToDevPtr) {
218 for (auto &Symbol : DeviceElf.symbols()) {
219 auto SymbolNameOrErr = Symbol.getName();
// NOTE(review): this drops the Error held by SymbolNameOrErr without
// consuming it. With LLVM_ENABLE_ABI_BREAKING_CHECKS, destroying an Expected
// whose error was never handled aborts the program. Call
// consumeError(SymbolNameOrErr.takeError()) before `continue`.
220 if (!SymbolNameOrErr)
221 continue;
223
224 if (!(SymbolName == (GlobalName + "$ptr")))
225 continue;
226
228 if (!ValueOrErr)
229 PROTEUS_FATAL_ERROR("Expected symbol value");
231
232 // Get the section containing the symbol
233 auto SectionOrErr = Symbol.getSection();
234 if (!SectionOrErr)
235 PROTEUS_FATAL_ERROR("Cannot retrieve section");
236 const auto &Section = *SectionOrErr;
// NOTE(review): "sybmol" in the following error message is a typo for
// "symbol".
237 if (Section == DeviceElf.section_end())
238 PROTEUS_FATAL_ERROR("Expected sybmol in section");
239
240 // Get the section's address and data
242 if (!SectionDataOrErr)
243 PROTEUS_FATAL_ERROR("Error retrieving section data");
245
246 // Calculate offset within the section
247 uint64_t SectionAddr = Section->getAddress();
// NOTE(review): the check below only guarantees Offset < size, but the
// write that follows stores 8 bytes. An Offset within the last 7 bytes of
// the section would therefore write past its end. A safe check is
//   SectionData.size() < sizeof(uint64_t) ||
//   Offset > SectionData.size() - sizeof(uint64_t)
249 if (Offset >= SectionData.size())
250 PROTEUS_FATAL_ERROR("Expected offset within section size");
251
// NOTE(review): the C-style cast writes through a possibly unaligned
// uint64_t*. If SectionData is a read-only StringRef, it also casts away
// const. std::memcpy of the pointer value would avoid the alignment and
// aliasing issues.
252 uint64_t *Data = (uint64_t *)(SectionData.data() + Offset);
253 *Data = reinterpret_cast<uint64_t>(DevPtr);
254 break;
255 }
256 }
257}
258
// Applies the configured JIT specializations to kernel FnName in M and
// renames it to FnName + Suffix. Visible steps, in order:
//   - if SpecializeArgs, argument specialization (the call on source line
//     271 is not shown);
//   - for each (callee name, lambda type) in LambdaCalleeInfo, fetch the
//     lambda's JIT variables from LR and specialize that callee (the call on
//     line 279 is not shown); a missing callee is fatal;
//   - the shared array transform (line 284, not shown), run after value
//     specialization so constants can propagate;
//   - setKernelDims / setKernelDimsAssume when SpecializeDims /
//     SpecializeDimsAssume are set;
//   - if SpecializeLaunchBounds, compute BlockSize = BlockDim.x * y * z and
//     presumably apply launch bounds (lines 303-304, not shown);
//   - finally, report the elapsed time of Timer T (line 309 partially shown).
// NOTE(review): the GridDim parameter is declared on source line 261 and LR
// on line 273; neither is shown here, and neither is line 307.
// NOTE(review): LambdaCalleeInfo is taken by const value, so the SmallVector
// and its std::string elements are copied on every call. A const reference
// would avoid this without affecting callers.
259inline void specializeIR(
260 Module &M, StringRef FnName, StringRef Suffix, dim3 &BlockDim,
262 const SmallVector<std::pair<std::string, StringRef>> LambdaCalleeInfo,
263 bool SpecializeArgs, bool SpecializeDims, bool SpecializeDimsAssume,
264 bool SpecializeLaunchBounds) {
265 Timer T;
266 Function *F = M.getFunction(FnName);
267
// NOTE(review): assert() compiles out under NDEBUG. A missing kernel would
// then be dereferenced at F->setName below. The lambda loop uses
// PROTEUS_FATAL_ERROR for the same condition; using it here would be
// consistent.
268 assert(F && "Expected non-null function!");
269 // Replace argument uses with runtime constants.
270 if (SpecializeArgs)
272
// The structured binding below shadows the FnName parameter, and the inner
// F shadows the kernel's F. Both shadows end with the loop, so the later
// F->setName(FnName + Suffix) refers to the kernel and its original name.
274 for (auto &[FnName, LambdaType] : LambdaCalleeInfo) {
275 const SmallVector<RuntimeConstant> &RCVec = LR.getJitVariables(LambdaType);
276 Function *F = M.getFunction(FnName);
277 if (!F)
278 PROTEUS_FATAL_ERROR("Expected non-null Function");
280 }
281
282 // Run the shared array transform after any value specialization (arguments,
283 // captures) to propagate any constants.
285
286 // Replace uses of blockDim.* and gridDim.* with constants.
287 if (SpecializeDims)
288 setKernelDims(M, GridDim, BlockDim);
289 if (SpecializeDimsAssume)
290 setKernelDimsAssume(M, GridDim, BlockDim);
291 F->setName(FnName + Suffix);
292
293 if (SpecializeLaunchBounds) {
// Total threads per block, used for the launch-bounds specialization.
294 int BlockSize = BlockDim.x * BlockDim.y * BlockDim.z;
// Formats the "[LaunchBoundSpec]" trace message printed when
// ProteusTraceOutput >= 1.
295 auto TraceOut = [](int BlockSize) {
298 OS << "[LaunchBoundSpec] BlockSize " << BlockSize << "\n";
299
300 return S;
301 };
302 if (Config::get().ProteusTraceOutput >= 1)
305 }
306
308
310 << "specializeIR " << T.elapsed() << " ms\n");
311}
312
313} // namespace proteus
314
315#endif
316
317#endif
const char * VarName
Definition CompilerInterfaceDevice.cpp:20
#define PROTEUS_FATAL_ERROR(x)
Definition Error.h:7
#define PROTEUS_TIMER_OUTPUT(x)
Definition TimeTracing.hpp:57
static Config & get()
Definition Config.hpp:284
static LambdaRegistry & instance()
Definition LambdaRegistry.hpp:22
static llvm::raw_ostream & outs(const std::string &Name)
Definition Logger.hpp:25
static void trace(llvm::StringRef Msg)
Definition Logger.hpp:30
static void transform(Module &M, Function &F, ArrayRef< RuntimeConstant > RCArray)
Definition TransformArgumentSpecialization.hpp:88
static void transform(Module &M, Function &F, const SmallVector< RuntimeConstant > &RCVec)
Definition TransformLambdaSpecialization.hpp:125
static void transform(Module &M)
Definition TransformSharedArray.hpp:30
const SmallVector< StringRef > & threadIdxXFnName()
Definition CoreLLVMCUDA.hpp:70
const SmallVector< StringRef > & gridDimYFnName()
Definition CoreLLVMCUDA.hpp:30
const SmallVector< StringRef > & threadIdxZFnName()
Definition CoreLLVMCUDA.hpp:80
const SmallVector< StringRef > & blockIdxZFnName()
Definition CoreLLVMCUDA.hpp:65
const SmallVector< StringRef > & gridDimZFnName()
Definition CoreLLVMCUDA.hpp:35
const SmallVector< StringRef > & gridDimXFnName()
Definition CoreLLVMCUDA.hpp:25
const SmallVector< StringRef > & blockIdxXFnName()
Definition CoreLLVMCUDA.hpp:55
const SmallVector< StringRef > & threadIdxYFnName()
Definition CoreLLVMCUDA.hpp:75
const SmallVector< StringRef > & blockIdxYFnName()
Definition CoreLLVMCUDA.hpp:60
const SmallVector< StringRef > & blockDimYFnName()
Definition CoreLLVMCUDA.hpp:45
const SmallVector< StringRef > & blockDimZFnName()
Definition CoreLLVMCUDA.hpp:50
const SmallVector< StringRef > & blockDimXFnName()
Definition CoreLLVMCUDA.hpp:40
Definition BuiltinsCUDA.cpp:4
void setLaunchBoundsForKernel(Function &F, int MaxThreadsPerSM, int MinBlocksPerSM=0)
Definition CoreLLVMCUDA.hpp:87
T getRuntimeConstantValue(void *Arg)
Definition CompilerInterfaceRuntimeConstantInfo.h:114
std::string toString(CodegenOption Option)
Definition Config.hpp:26
void runCleanupPassPipeline(Module &M)
Definition CoreLLVM.hpp:230