1#ifndef PROTEUS_CORE_LLVM_DEVICE_HPP
2#define PROTEUS_CORE_LLVM_DEVICE_HPP
12#if defined(PROTEUS_ENABLE_HIP) || defined(PROTEUS_ENABLE_CUDA)
14#include <llvm/Analysis/CallGraph.h>
15#include <llvm/Bitcode/BitcodeReader.h>
16#include <llvm/Bitcode/BitcodeWriter.h>
17#include <llvm/IR/ReplaceConstant.h>
18#include <llvm/IR/Verifier.h>
19#include <llvm/Object/ELFObjectFile.h>
20#include <llvm/Transforms/Utils/Cloning.h>
// Specializes module M for a concrete launch configuration given by GridDim
// and BlockDim, using the helper lambdas defined below.
// NOTE(review): this listing omits several source lines (the embedded line
// numbers jump), so the comments here describe only the visible statements;
// the code that invokes the lambdas with specific intrinsic names is not shown.
30inline void setKernelDims(Module &M, dim3 &GridDim, dim3 &BlockDim) {
// ReplaceIntrinsicDim: for every name in IntrinsicNames, looks the function up
// in M and replaces each call to it with a 32-bit integer constant built from
// DimValue, then erases the call. DimValue is not declared on a visible line;
// presumably it is the lambda's second parameter -- TODO confirm.
31 auto ReplaceIntrinsicDim = [&](ArrayRef<StringRef> IntrinsicNames,
// CollectCallUsers: copies the CallInst users of F into a vector. The caller
// erases those calls afterwards, so presumably the snapshot exists to avoid
// walking F's use list while it is being modified -- TODO confirm.
// Any null check on the dyn_cast result is on lines not shown here.
33 auto CollectCallUsers = [](Function &F) {
34 SmallVector<CallInst *> CallUsers;
35 for (
auto *User : F.users()) {
36 auto *Call = dyn_cast<CallInst>(User);
39 CallUsers.push_back(Call);
45 for (
auto IntrinsicName : IntrinsicNames) {
47 Function *IntrinsicFunction = M.getFunction(IntrinsicName);
// The statement guarded by this null check is not visible (presumably it skips
// names that are absent from the module).
48 if (!IntrinsicFunction)
// TraceOut: formats a "[DimSpec] Replace call to <fn> with constant ..."
// message into a string stream; where the message is emitted is not visible.
51 auto TraceOut = [](Function *F, Value *C) {
53 raw_svector_ostream OS(S);
54 OS <<
"[DimSpec] Replace call to " << F->getName() <<
" with constant "
// Replace every collected call with the i32 constant and delete the call.
60 for (
auto *Call : CollectCallUsers(*IntrinsicFunction)) {
61 Value *ConstantValue =
62 ConstantInt::get(Type::getInt32Ty(M.getContext()), DimValue);
63 Call->replaceAllUsesWith(ConstantValue);
66 Call->eraseFromParent();
// InsertAssume: for every name in IntrinsicNames that exists in M and has at
// least one use, emits `llvm.assume(call <u DimValue)` for each call user,
// i.e. it records an unsigned upper bound on the intrinsic's result instead of
// replacing the call.
79 auto InsertAssume = [&](ArrayRef<StringRef> IntrinsicNames,
int DimValue) {
80 for (
auto IntrinsicName : IntrinsicNames) {
81 Function *IntrinsicFunction = M.getFunction(IntrinsicName);
// The guarded statement is not visible; presumably missing or unused
// intrinsics are skipped.
82 if (!IntrinsicFunction || IntrinsicFunction->use_empty())
// TraceOut: formats a "[DimSpec] Assume <fn> with ..." message; the rest of
// the message and its destination are on lines not shown.
85 auto TraceOut = [](Function *IntrinsicF,
int DimValue) {
87 raw_svector_ostream OS(S);
88 OS <<
"[DimSpec] Assume " << IntrinsicF->getName() <<
" with "
95 for (
auto *U : IntrinsicFunction->users()) {
96 auto *Call = dyn_cast<CallInst>(U);
// The builder is positioned at the instruction that follows the call, so the
// compare and the assume are inserted right after the call they constrain.
// NOTE(review): handling of a null dyn_cast result is not visible here.
101 IRBuilder<> Builder(Call->getNextNode());
// The bound uses the call's own result type so both compare operands match.
102 Value *Bound = ConstantInt::get(Call->getType(), DimValue);
103 Value *Cmp = Builder.CreateICmpULT(Call, Bound);
105 Function *AssumeIntrinsic =
106 Intrinsic::getDeclaration(&M, Intrinsic::assume);
107 Builder.CreateCall(AssumeIntrinsic, Cmp);
// For each variable name in VarNameToDevPtr that names a global in the module,
// creates a companion global named "<name>$ptr" that holds a pointer to the
// original global's type, and rewrites instruction users so that they load the
// address from that companion global instead of referencing the original.
// The first parameter (presumably `Module &M`) is on a line not shown here.
// NOTE(review): several source lines are missing from this listing; comments
// describe only the visible statements.
125inline void replaceGlobalVariablesWithPointers(
126 const std::unordered_map<std::string, const void *> &VarNameToDevPtr) {
// NOTE(review): iterating by value copies each (std::string, pointer) pair;
// `const auto &` would avoid the copy. Only the key is read on visible lines.
130 for (
auto RegisterVar : VarNameToDevPtr) {
131 auto &
VarName = RegisterVar.first;
132 auto *GV = M.getNamedGlobal(
VarName);
// Turn constant-expression users of GV into instructions (LLVM utility from
// ReplaceConstant.h) so they can be rewritten per instruction further down.
139 convertUsersOfConstantsToInstructions({GV});
// 0xDEADBEEFDEADBEEF is a placeholder 64-bit address used as the initializer
// of the "$ptr" global (the variable it is assigned to is on an omitted line,
// presumably `Addr`). relinkGlobalsObject in this file looks up symbols with
// the same "$ptr" suffix and stores the device pointer at their location.
142 ConstantInt::get(Type::getInt64Ty(M.getContext()), 0xDEADBEEFDEADBEEF);
143 auto *CE = ConstantExpr::getIntToPtr(Addr, GV->getType()->getPointerTo());
// New global: non-constant, external linkage, initialized with the placeholder
// pointer, same thread-local mode and address space as GV, and flagged as
// externally initialized (the trailing `true`).
144 auto *GVarPtr =
new GlobalVariable(
145 M, GV->getType()->getPointerTo(),
false, GlobalValue::ExternalLinkage,
146 CE, GV->getName() +
"$ptr",
nullptr, GV->getThreadLocalMode(),
147 GV->getAddressSpace(),
true);
// Worklist walk: starting from GV, collect GV and every Constant that uses it,
// directly or through other constants, into ValuesToReplace.
150 SmallPtrSet<Value *, 16> ValuesToReplace;
151 SmallVector<Value *> Worklist;
153 Worklist.push_back(GV);
154 ValuesToReplace.insert(GV);
155 while (!Worklist.empty()) {
156 Value *V = Worklist.pop_back_val();
157 for (
auto *User : V->users()) {
158 if (
auto *C = dyn_cast<Constant>(User)) {
// insert().second is true only for constants not seen before, so each constant
// is queued at most once.
159 if (ValuesToReplace.insert(C).second)
160 Worklist.push_back(C);
// Instruction users are accepted here; the guarded statement is not visible.
// Any other kind of user leads to the "Expected Instruction or Constant user"
// message below (presumably a fatal error -- TODO confirm).
166 if (isa<Instruction>(User))
170 "Expected Instruction or Constant user for Value: " +
toString(*V) +
// For every collected value, gather its instruction users into Insts and then
// rewrite each one. The insertion into Insts is on lines not shown; collecting
// first presumably avoids modifying V's use list while iterating over it.
175 for (Value *V : ValuesToReplace) {
176 SmallPtrSet<Instruction *, 16> Insts;
178 for (User *U : V->users()) {
179 if (
auto *I = dyn_cast<Instruction>(U)) {
185 for (
auto *I : Insts) {
// Insert a load of the real address from the "$ptr" global immediately before
// the using instruction.
186 IRBuilder Builder{I};
187 auto *Load = Builder.CreateLoad(GV->getType(), GVarPtr);
188 Value *Replacement = Load;
// If the value being replaced has a different type than the loaded pointer,
// cast the load to that type (the assignment target is on an omitted line,
// presumably `Replacement`).
189 Type *ExpectedTy = V->getType();
190 if (Load->getType() != ExpectedTy)
192 Builder.CreatePointerBitCastOrAddrSpaceCast(Load, ExpectedTy);
194 I->replaceUsesOfWith(V, Replacement);
// When PROTEUS_ENABLE_DEBUG is set, verify the rewritten module and report
// problems to stderr; the failure action is not visible here.
199#if PROTEUS_ENABLE_DEBUG
200 if (verifyModule(M, &errs()))
// Patches a compiled device object in place: for every (GlobalName, DevPtr)
// entry, finds the ELF symbol named "<GlobalName>$ptr" and overwrites the
// 8 bytes at that symbol's location in its section with the value of DevPtr.
// Object is parsed as a 64-bit little-endian ELF file.
// NOTE(review): several source lines are missing from this listing; the
// actions taken on the error paths are not visible.
205inline void relinkGlobalsObject(
206 MemoryBufferRef Object,
207 const std::unordered_map<std::string, const void *> &VarNameToDevPtr) {
208 Expected<object::ELF64LEObjectFile> DeviceElfOrErr =
209 object::ELF64LEObjectFile::create(Object);
// Parsing failure is detected here; the handling statement is not shown.
210 if (
auto E = DeviceElfOrErr.takeError())
213 auto &DeviceElf = *DeviceElfOrErr;
// Every global triggers a full scan of the symbol table, so the cost is
// (number of globals) x (number of symbols).
215 for (
auto &[GlobalName, DevPtr] : VarNameToDevPtr) {
216 for (
auto &Symbol : DeviceElf.symbols()) {
217 auto SymbolNameOrErr = Symbol.getName();
218 if (!SymbolNameOrErr)
220 auto SymbolName = *SymbolNameOrErr;
// Only the companion "<GlobalName>$ptr" symbol is of interest.
// NOTE(review): `GlobalName + "$ptr"` builds a new std::string for every
// symbol visited; it could be computed once per global.
222 if (!SymbolName.equals(GlobalName +
"$ptr"))
225 Expected<uint64_t> ValueOrErr = Symbol.getValue();
228 uint64_t SymbolValue = *ValueOrErr;
231 auto SectionOrErr = Symbol.getSection();
234 const auto &Section = *SectionOrErr;
// section_end() indicates the symbol is not associated with any section.
235 if (Section == DeviceElf.section_end())
239 Expected<StringRef> SectionDataOrErr = Section->getContents();
240 if (!SectionDataOrErr)
242 StringRef SectionData = *SectionDataOrErr;
// Convert the symbol's value into a byte offset within the section contents.
245 uint64_t SectionAddr = Section->getAddress();
246 uint64_t Offset = SymbolValue - SectionAddr;
// NOTE(review): this only checks that the first byte is inside the section;
// the store below writes sizeof(uint64_t) bytes, so a check of the form
// `Offset + sizeof(uint64_t) > SectionData.size()` looks safer -- confirm.
247 if (Offset >= SectionData.size())
// The C-style cast drops the const of StringRef::data() and writes straight
// into the buffer behind Object, so that buffer must be writable.
// NOTE(review): the store assumes the location is suitably aligned for
// uint64_t -- confirm, or use memcpy.
250 uint64_t *Data = (uint64_t *)(SectionData.data() + Offset);
251 *Data =
reinterpret_cast<uint64_t
>(DevPtr);
// Specializes kernel FnName inside module M: the visible statements look up
// the kernel, iterate over the lambda callees in LambdaCalleeInfo, apply
// setKernelDims with the launch dimensions, rename the kernel to
// FnName + Suffix, and compute grid/block sizes when SpecializeLaunchBounds is
// set. How RCArray, SpecializeArgs and SpecializeDims are used is on lines not
// shown in this listing.
// NOTE(review): LambdaCalleeInfo is taken by const value, which copies the
// vector on every call; a const reference would avoid that.
257inline void specializeIR(
258 Module &M, StringRef FnName, StringRef Suffix, dim3 &BlockDim,
259 dim3 &GridDim, ArrayRef<RuntimeConstant> RCArray,
260 const SmallVector<std::pair<std::string, StringRef>> LambdaCalleeInfo,
261 bool SpecializeArgs,
bool SpecializeDims,
bool SpecializeLaunchBounds) {
263 Function *F = M.getFunction(FnName);
// The kernel must already exist in the module. This is checked by assert only,
// so the check disappears in builds that define NDEBUG.
265 assert(F &&
"Expected non-null function!");
// NOTE(review): inside this loop the binding `FnName` and the local `F` shadow
// the function parameter FnName and the kernel pointer F declared above.
271 for (
auto &[FnName, LambdaType] : LambdaCalleeInfo) {
// LR is declared on a line not shown (presumably the lambda registry); it
// supplies the runtime constants recorded for this lambda type.
272 const SmallVector<RuntimeConstant> &RCVec = LR.getJitVariables(LambdaType);
273 Function *F = M.getFunction(FnName);
// Any condition guarding this call (e.g. SpecializeDims) is not visible here.
285 setKernelDims(M, GridDim, BlockDim);
// Rename the kernel so the specialized version carries the given suffix.
287 F->setName(FnName + Suffix);
289 if (SpecializeLaunchBounds) {
// NOTE(review): the products are evaluated in the type of the dim3 members
// before conversion to size_t / int. If those members are 32-bit unsigned (as
// in CUDA/HIP), a large grid can wrap before widening -- confirm.
290 size_t GridSize = GridDim.x * GridDim.y * GridDim.z;
291 int BlockSize = BlockDim.x * BlockDim.y * BlockDim.z;
// TraceOut: formats "[LaunchBoundSpec] GridSize <n> BlockSize <m>\n"; where
// the message is emitted is not visible.
292 auto TraceOut = [](
size_t GridSize,
int BlockSize) {
294 raw_svector_ostream OS(S);
295 OS <<
"[LaunchBoundSpec] GridSize " << GridSize <<
" BlockSize "
296 << BlockSize <<
"\n";
// Tail of a timing report that prints "specializeIR <elapsed> ms"; the timer T
// and the enclosing macro call start on lines not shown.
308 <<
"specializeIR " << T.elapsed() <<
" ms\n");
const char * VarName
Definition CompilerInterfaceDevice.cpp:20
#define PROTEUS_DBG(x)
Definition Debug.h:10
#define PROTEUS_FATAL_ERROR(x)
Definition Error.h:4
#define PROTEUS_TIMER_OUTPUT(x)
Definition TimeTracing.hpp:57
static Config & get()
Definition Config.hpp:112
static LambdaRegistry & instance()
Definition LambdaRegistry.hpp:21
static llvm::raw_ostream & outs(const std::string &Name)
Definition Logger.hpp:25
static void trace(llvm::StringRef Msg)
Definition Logger.hpp:30
static void logfile(const std::string &Filename, T &&Data)
Definition Logger.hpp:33
const SmallVector< StringRef > & threadIdxXFnName()
Definition CoreLLVMCUDA.hpp:70
const SmallVector< StringRef > & gridDimYFnName()
Definition CoreLLVMCUDA.hpp:30
const SmallVector< StringRef > & threadIdxZFnName()
Definition CoreLLVMCUDA.hpp:80
const SmallVector< StringRef > & blockIdxZFnName()
Definition CoreLLVMCUDA.hpp:65
const SmallVector< StringRef > & gridDimZFnName()
Definition CoreLLVMCUDA.hpp:35
const SmallVector< StringRef > & gridDimXFnName()
Definition CoreLLVMCUDA.hpp:25
const SmallVector< StringRef > & blockIdxXFnName()
Definition CoreLLVMCUDA.hpp:55
const SmallVector< StringRef > & threadIdxYFnName()
Definition CoreLLVMCUDA.hpp:75
const SmallVector< StringRef > & blockIdxYFnName()
Definition CoreLLVMCUDA.hpp:60
const SmallVector< StringRef > & blockDimYFnName()
Definition CoreLLVMCUDA.hpp:45
const SmallVector< StringRef > & blockDimZFnName()
Definition CoreLLVMCUDA.hpp:50
const SmallVector< StringRef > & blockDimXFnName()
Definition CoreLLVMCUDA.hpp:40
Definition Dispatcher.cpp:14
void setLaunchBoundsForKernel(Module &M, Function &F, size_t, int BlockSize)
Definition CoreLLVMCUDA.hpp:87
std::string toString(CodegenOption Option)
Definition Config.hpp:23
void runCleanupPassPipeline(Module &M)
Definition CoreLLVM.hpp:178