Proteus
Programmable JIT compilation and optimization for C/C++ using LLVM
Loading...
Searching...
No Matches
CoreLLVMDevice.hpp
Go to the documentation of this file.
1#ifndef PROTEUS_CORE_LLVM_DEVICE_HPP
2#define PROTEUS_CORE_LLVM_DEVICE_HPP
3
4#if PROTEUS_ENABLE_HIP
6#endif
7
8#if PROTEUS_ENABLE_CUDA
10#endif
11
12#if defined(PROTEUS_ENABLE_HIP) || defined(PROTEUS_ENABLE_CUDA)
13
14#include <llvm/Analysis/CallGraph.h>
15#include <llvm/Bitcode/BitcodeReader.h>
16#include <llvm/Bitcode/BitcodeWriter.h>
17#include <llvm/IR/ReplaceConstant.h>
18#include <llvm/IR/Verifier.h>
19#include <llvm/Object/ELFObjectFile.h>
20#include <llvm/Transforms/Utils/Cloning.h>
21
27
28namespace proteus {
29
30inline void setKernelDims(Module &M, dim3 &GridDim, dim3 &BlockDim) {
31 auto ReplaceIntrinsicDim = [&](ArrayRef<StringRef> IntrinsicNames,
32 uint32_t DimValue) {
33 auto CollectCallUsers = [](Function &F) {
34 SmallVector<CallInst *> CallUsers;
35 for (auto *User : F.users()) {
36 auto *Call = dyn_cast<CallInst>(User);
37 if (!Call)
38 continue;
39 CallUsers.push_back(Call);
40 }
41
42 return CallUsers;
43 };
44
45 for (auto IntrinsicName : IntrinsicNames) {
46
47 Function *IntrinsicFunction = M.getFunction(IntrinsicName);
48 if (!IntrinsicFunction)
49 continue;
50
51 auto TraceOut = [](Function *F, Value *C) {
52 SmallString<128> S;
53 raw_svector_ostream OS(S);
54 OS << "[DimSpec] Replace call to " << F->getName() << " with constant "
55 << *C << "\n";
56
57 return S;
58 };
59
60 for (auto *Call : CollectCallUsers(*IntrinsicFunction)) {
61 Value *ConstantValue =
62 ConstantInt::get(Type::getInt32Ty(M.getContext()), DimValue);
63 Call->replaceAllUsesWith(ConstantValue);
64 if (Config::get().ProteusTraceOutput)
65 Logger::trace(TraceOut(IntrinsicFunction, ConstantValue));
66 Call->eraseFromParent();
67 }
68 }
69 };
70
71 ReplaceIntrinsicDim(detail::gridDimXFnName(), GridDim.x);
72 ReplaceIntrinsicDim(detail::gridDimYFnName(), GridDim.y);
73 ReplaceIntrinsicDim(detail::gridDimZFnName(), GridDim.z);
74
75 ReplaceIntrinsicDim(detail::blockDimXFnName(), BlockDim.x);
76 ReplaceIntrinsicDim(detail::blockDimYFnName(), BlockDim.y);
77 ReplaceIntrinsicDim(detail::blockDimZFnName(), BlockDim.z);
78
79 auto InsertAssume = [&](ArrayRef<StringRef> IntrinsicNames, int DimValue) {
80 for (auto IntrinsicName : IntrinsicNames) {
81 Function *IntrinsicFunction = M.getFunction(IntrinsicName);
82 if (!IntrinsicFunction || IntrinsicFunction->use_empty())
83 continue;
84
85 auto TraceOut = [](Function *IntrinsicF, int DimValue) {
86 SmallString<128> S;
87 raw_svector_ostream OS(S);
88 OS << "[DimSpec] Assume " << IntrinsicF->getName() << " with "
89 << DimValue << "\n";
90
91 return S;
92 };
93
94 // Iterate over all uses of the intrinsic.
95 for (auto *U : IntrinsicFunction->users()) {
96 auto *Call = dyn_cast<CallInst>(U);
97 if (!Call)
98 continue;
99
100 // Insert the llvm.assume intrinsic.
101 IRBuilder<> Builder(Call->getNextNode());
102 Value *Bound = ConstantInt::get(Call->getType(), DimValue);
103 Value *Cmp = Builder.CreateICmpULT(Call, Bound);
104
105 Function *AssumeIntrinsic =
106 Intrinsic::getDeclaration(&M, Intrinsic::assume);
107 Builder.CreateCall(AssumeIntrinsic, Cmp);
108 if (Config::get().ProteusTraceOutput)
109 Logger::trace(TraceOut(IntrinsicFunction, DimValue));
110 }
111 }
112 };
113
114 // Inform LLVM about the range of possible values of threadIdx.*.
115 InsertAssume(detail::threadIdxXFnName(), BlockDim.x);
116 InsertAssume(detail::threadIdxYFnName(), BlockDim.y);
117 InsertAssume(detail::threadIdxZFnName(), BlockDim.z);
118
119 // Inform LLVdetailut the range of possible values of blockIdx.*.
120 InsertAssume(detail::blockIdxXFnName(), GridDim.x);
121 InsertAssume(detail::blockIdxYFnName(), GridDim.y);
122 InsertAssume(detail::blockIdxZFnName(), GridDim.z);
123}
124
/// Redirects uses of registered global variables through pointer globals.
///
/// For every key of \p VarNameToDevPtr that names a global variable in \p M, a
/// new global called "<name>$ptr" is created with a placeholder address as its
/// initializer. Every instruction that uses the original global, or a constant
/// reachable from it, is then rewritten to use a value loaded from that new
/// global.
///
/// \param M               Module that is rewritten in place.
/// \param VarNameToDevPtr Registered variables; only the names (keys) are read
///                        by this function, the mapped pointers are not used.
125inline void replaceGlobalVariablesWithPointers(
126 Module &M,
127 const std::unordered_map<std::string, const void *> &VarNameToDevPtr) {
128 // Re-link globals to fixed addresses provided by registered
129 // variables.
// NOTE(review): RegisterVar is declared by value, so every iteration copies
// the map entry even though only its key is read below.
130 for (auto RegisterVar : VarNameToDevPtr) {
131 auto &VarName = RegisterVar.first;
132 auto *GV = M.getNamedGlobal(VarName);
133 // Skip linking if the GV does not exist in the module.
134 if (!GV)
135 continue;
136
137 // This will convert constant users of GV to instructions so that we can
138 // replace with the GV ptr.
139 convertUsersOfConstantsToInstructions({GV});
140
// The initializer is a fixed placeholder address, not the registered device
// pointer. relinkGlobalsObject in this file looks up "<name>$ptr" symbols and
// stores the device pointer there — presumably that is where this placeholder
// gets overwritten; confirm against the caller.
141 Constant *Addr =
142 ConstantInt::get(Type::getInt64Ty(M.getContext()), 0xDEADBEEFDEADBEEF);
143 auto *CE = ConstantExpr::getIntToPtr(Addr, GV->getType()->getPointerTo());
// New non-constant, externally linked global named "<GV name>$ptr" that
// reuses GV's thread-local mode and address space. The trailing `true` is the
// constructor's externally-initialized flag.
144 auto *GVarPtr = new GlobalVariable(
145 M, GV->getType()->getPointerTo(), false, GlobalValue::ExternalLinkage,
146 CE, GV->getName() + "$ptr", nullptr, GV->getThreadLocalMode(),
147 GV->getAddressSpace(), true);
148
149 // Find all Constant users that refer to the global variable.
// Worklist traversal over users: constants are followed transitively, and the
// insert() result ensures each constant is queued only once.
150 SmallPtrSet<Value *, 16> ValuesToReplace;
151 SmallVector<Value *> Worklist;
152 // Seed with the global variable.
153 Worklist.push_back(GV);
154 ValuesToReplace.insert(GV);
155 while (!Worklist.empty()) {
156 Value *V = Worklist.pop_back_val();
157 for (auto *User : V->users()) {
158 if (auto *C = dyn_cast<Constant>(User)) {
159 if (ValuesToReplace.insert(C).second)
160 Worklist.push_back(C);
161
162 continue;
163 }
164
165 // Skip instructions to be handled when replacing.
166 if (isa<Instruction>(User))
167 continue;
// NOTE(review): the listing numbering jumps from 167 to 170 here; the opening
// of the statement that reports this message is not present in this view.

170 "Expected Instruction or Constant user for Value: " + toString(*V) +
171 " , User: " + toString(*User));
172 }
173 }
174
175 for (Value *V : ValuesToReplace) {
176 SmallPtrSet<Instruction *, 16> Insts;
177 // Find instruction users to replace value.
// Users are gathered into a set first, so the user list of V is not walked
// while replaceUsesOfWith below modifies it.
178 for (User *U : V->users()) {
179 if (auto *I = dyn_cast<Instruction>(U)) {
180 Insts.insert(I);
181 }
182 }
183
184 // Replace value in instructions.
185 for (auto *I : Insts) {
// The builder is positioned at I, so each user gets its own load of the
// pointer global; a cast is added only when the loaded type differs from the
// type of the value being replaced.
186 IRBuilder Builder{I};
187 auto *Load = Builder.CreateLoad(GV->getType(), GVarPtr);
188 Value *Replacement = Load;
189 Type *ExpectedTy = V->getType();
190 if (Load->getType() != ExpectedTy)
191 Replacement =
192 Builder.CreatePointerBitCastOrAddrSpaceCast(Load, ExpectedTy);
193
194 I->replaceUsesOfWith(V, Replacement);
195 }
196 }
197 }
198
// When PROTEUS_ENABLE_DEBUG is set, the rewritten module is verified and a
// fatal error is raised if verification reports a problem.
199#if PROTEUS_ENABLE_DEBUG
200 if (verifyModule(M, &errs()))
201 PROTEUS_FATAL_ERROR("Broken module found, JIT compilation aborted!");
202#endif
203}
204
205inline void relinkGlobalsObject(
206 MemoryBufferRef Object,
207 const std::unordered_map<std::string, const void *> &VarNameToDevPtr) {
208 Expected<object::ELF64LEObjectFile> DeviceElfOrErr =
209 object::ELF64LEObjectFile::create(Object);
210 if (auto E = DeviceElfOrErr.takeError())
211 PROTEUS_FATAL_ERROR("Cannot create the device elf: " +
212 toString(std::move(E)));
213 auto &DeviceElf = *DeviceElfOrErr;
214
215 for (auto &[GlobalName, DevPtr] : VarNameToDevPtr) {
216 for (auto &Symbol : DeviceElf.symbols()) {
217 auto SymbolNameOrErr = Symbol.getName();
218 if (!SymbolNameOrErr)
219 continue;
220 auto SymbolName = *SymbolNameOrErr;
221
222 if (!SymbolName.equals(GlobalName + "$ptr"))
223 continue;
224
225 Expected<uint64_t> ValueOrErr = Symbol.getValue();
226 if (!ValueOrErr)
227 PROTEUS_FATAL_ERROR("Expected symbol value");
228 uint64_t SymbolValue = *ValueOrErr;
229
230 // Get the section containing the symbol
231 auto SectionOrErr = Symbol.getSection();
232 if (!SectionOrErr)
233 PROTEUS_FATAL_ERROR("Cannot retrieve section");
234 const auto &Section = *SectionOrErr;
235 if (Section == DeviceElf.section_end())
236 PROTEUS_FATAL_ERROR("Expected sybmol in section");
237
238 // Get the section's address and data
239 Expected<StringRef> SectionDataOrErr = Section->getContents();
240 if (!SectionDataOrErr)
241 PROTEUS_FATAL_ERROR("Error retrieving section data");
242 StringRef SectionData = *SectionDataOrErr;
243
244 // Calculate offset within the section
245 uint64_t SectionAddr = Section->getAddress();
246 uint64_t Offset = SymbolValue - SectionAddr;
247 if (Offset >= SectionData.size())
248 PROTEUS_FATAL_ERROR("Expected offset within section size");
249
250 uint64_t *Data = (uint64_t *)(SectionData.data() + Offset);
251 *Data = reinterpret_cast<uint64_t>(DevPtr);
252 break;
253 }
254 }
255}
256
/// Applies the enabled specializations to kernel \p FnName inside \p M.
///
/// Steps visible in this listing, in order: look up the kernel, process each
/// lambda callee, optionally replace launch-dimension queries
/// (setKernelDims), rename the kernel to FnName + Suffix, optionally set
/// launch bounds, then emit timing and debug output.
///
/// NOTE(review): the listing numbering skips 268, 276, 281, 305 and 307, so
/// several statements of this function are not visible here. The remaining
/// comments mark where they belong; restore them from the real header before
/// relying on this text.
///
/// \param M                      Module that is rewritten in place.
/// \param FnName                 Name of the kernel function to specialize.
/// \param Suffix                 Appended to FnName to form the new kernel
///                               name.
/// \param BlockDim               Block dimensions of the launch.
/// \param GridDim                Grid dimensions of the launch.
/// \param RCArray                Runtime constants; not referenced by any
///                               visible line — presumably consumed by the
///                               statement missing after `if (SpecializeArgs)`.
/// \param LambdaCalleeInfo       Pairs of (function name, lambda type). Taken
///                               by const value, so the vector is copied at
///                               each call.
/// \param SpecializeArgs         Enables argument specialization.
/// \param SpecializeDims         Enables setKernelDims.
/// \param SpecializeLaunchBounds Enables setLaunchBoundsForKernel.
257inline void specializeIR(
258 Module &M, StringRef FnName, StringRef Suffix, dim3 &BlockDim,
259 dim3 &GridDim, ArrayRef<RuntimeConstant> RCArray,
260 const SmallVector<std::pair<std::string, StringRef>> LambdaCalleeInfo,
261 bool SpecializeArgs, bool SpecializeDims, bool SpecializeLaunchBounds) {
// T is read at the end through T.elapsed() to report the time spent here.
262 Timer T;
263 Function *F = M.getFunction(FnName);
264
// NOTE(review): the null check is an assert only; when asserts are compiled
// out, F is dereferenced below without a check.
265 assert(F && "Expected non-null function!");
266 // Replace argument uses with runtime constants.
267 if (SpecializeArgs)
// NOTE(review): the statement guarded by this `if` (listing line 268) is
// missing from this view.
269
270 auto &LR = LambdaRegistry::instance();
// The binding FnName and the local F below shadow the parameter FnName and the
// kernel F for the duration of this loop.
271 for (auto &[FnName, LambdaType] : LambdaCalleeInfo) {
272 const SmallVector<RuntimeConstant> &RCVec = LR.getJitVariables(LambdaType);
273 Function *F = M.getFunction(FnName);
274 if (!F)
275 PROTEUS_FATAL_ERROR("Expected non-null Function");
// NOTE(review): listing line 276 is missing; RCVec has no visible use, so the
// statement that consumes it is presumably the missing one.
277 }
278
279 // Run the shared array transform after any value specialization (arguments,
280 // captures) to propagate any constants.
// NOTE(review): the statement this comment describes (listing line 281) is
// missing from this view.
282
283 // Replace uses of blockDim.* and gridDim.* with constants.
// setKernelDims takes the grid first and the block second, the reverse of this
// function's own parameter order.
284 if (SpecializeDims)
285 setKernelDims(M, GridDim, BlockDim);
286
// Rename the kernel to FnName followed by Suffix.
287 F->setName(FnName + Suffix);
288
289 if (SpecializeLaunchBounds) {
// NOTE(review): both products are computed in the member type of dim3 before
// being converted to size_t / int — confirm they cannot overflow for the
// launch sizes that are supported.
290 size_t GridSize = GridDim.x * GridDim.y * GridDim.z;
291 int BlockSize = BlockDim.x * BlockDim.y * BlockDim.z;
292 auto TraceOut = [](size_t GridSize, int BlockSize) {
293 SmallString<128> S;
294 raw_svector_ostream OS(S);
295 OS << "[LaunchBoundSpec] GridSize " << GridSize << " BlockSize "
296 << BlockSize << "\n";
297
298 return S;
299 };
300 if (Config::get().ProteusTraceOutput)
301 Logger::trace(TraceOut(GridSize, BlockSize));
302 setLaunchBoundsForKernel(M, *F, GridSize, BlockSize);
303 }
304
// NOTE(review): listing lines 305 and 307 are missing; line 308 below is the
// tail of a statement whose beginning is not visible.
306
308 << "specializeIR " << T.elapsed() << " ms\n");
// FnName is the parameter again here, so the file is named after the original
// kernel name, without Suffix.
309 PROTEUS_DBG(Logger::logfile(FnName.str() + ".specialized.ll", M));
310}
311
312} // namespace proteus
313
314#endif
315
316#endif
const char * VarName
Definition CompilerInterfaceDevice.cpp:20
#define PROTEUS_DBG(x)
Definition Debug.h:10
#define PROTEUS_FATAL_ERROR(x)
Definition Error.h:4
#define PROTEUS_TIMER_OUTPUT(x)
Definition TimeTracing.hpp:57
static Config & get()
Definition Config.hpp:112
static LambdaRegistry & instance()
Definition LambdaRegistry.hpp:21
static llvm::raw_ostream & outs(const std::string &Name)
Definition Logger.hpp:25
static void trace(llvm::StringRef Msg)
Definition Logger.hpp:30
static void logfile(const std::string &Filename, T &&Data)
Definition Logger.hpp:33
static void transform(Module &M, Function &F, ArrayRef< RuntimeConstant > RCArray)
Definition TransformArgumentSpecialization.hpp:27
static void transform(Module &M, Function &F, const SmallVector< RuntimeConstant > &RCVec)
Definition TransformLambdaSpecialization.hpp:59
static void transform(Module &M)
Definition TransformSharedArray.hpp:30
const SmallVector< StringRef > & threadIdxXFnName()
Definition CoreLLVMCUDA.hpp:70
const SmallVector< StringRef > & gridDimYFnName()
Definition CoreLLVMCUDA.hpp:30
const SmallVector< StringRef > & threadIdxZFnName()
Definition CoreLLVMCUDA.hpp:80
const SmallVector< StringRef > & blockIdxZFnName()
Definition CoreLLVMCUDA.hpp:65
const SmallVector< StringRef > & gridDimZFnName()
Definition CoreLLVMCUDA.hpp:35
const SmallVector< StringRef > & gridDimXFnName()
Definition CoreLLVMCUDA.hpp:25
const SmallVector< StringRef > & blockIdxXFnName()
Definition CoreLLVMCUDA.hpp:55
const SmallVector< StringRef > & threadIdxYFnName()
Definition CoreLLVMCUDA.hpp:75
const SmallVector< StringRef > & blockIdxYFnName()
Definition CoreLLVMCUDA.hpp:60
const SmallVector< StringRef > & blockDimYFnName()
Definition CoreLLVMCUDA.hpp:45
const SmallVector< StringRef > & blockDimZFnName()
Definition CoreLLVMCUDA.hpp:50
const SmallVector< StringRef > & blockDimXFnName()
Definition CoreLLVMCUDA.hpp:40
Definition Dispatcher.cpp:14
void setLaunchBoundsForKernel(Module &M, Function &F, size_t, int BlockSize)
Definition CoreLLVMCUDA.hpp:87
std::string toString(CodegenOption Option)
Definition Config.hpp:23
void runCleanupPassPipeline(Module &M)
Definition CoreLLVM.hpp:178