11#ifndef PROTEUS_JITENGINEDEVICE_HPP
12#define PROTEUS_JITENGINEDEVICE_HPP
16#include <llvm/ADT/SmallPtrSet.h>
17#include <llvm/Analysis/CallGraph.h>
22#include <llvm/ADT/SmallVector.h>
23#include <llvm/ADT/StringRef.h>
24#include <llvm/Analysis/TargetTransformInfo.h>
25#include <llvm/Bitcode/BitcodeWriter.h>
26#include <llvm/CodeGen/CommandFlags.h>
27#include <llvm/CodeGen/MachineModuleInfo.h>
28#include <llvm/Config/llvm-config.h>
29#include <llvm/Demangle/Demangle.h>
30#include <llvm/ExecutionEngine/Orc/ThreadSafeModule.h>
31#include <llvm/IR/Constants.h>
32#include <llvm/IR/GlobalVariable.h>
33#include <llvm/IR/Instruction.h>
34#include <llvm/IR/Instructions.h>
35#include <llvm/IR/LLVMContext.h>
36#include <llvm/IR/LegacyPassManager.h>
37#include <llvm/IR/Module.h>
38#include <llvm/IR/ReplaceConstant.h>
39#include <llvm/IR/Type.h>
40#include <llvm/IR/Verifier.h>
41#include <llvm/IRReader/IRReader.h>
42#include <llvm/Linker/Linker.h>
43#include <llvm/MC/TargetRegistry.h>
44#include <llvm/Object/ELFObjectFile.h>
45#include <llvm/Passes/PassBuilder.h>
46#include <llvm/Support/Error.h>
47#include <llvm/Support/MemoryBuffer.h>
48#include <llvm/Support/MemoryBufferRef.h>
49#include <llvm/Target/TargetMachine.h>
50#include <llvm/Transforms/IPO/Internalize.h>
51#include <llvm/Transforms/Utils/Cloning.h>
52#include <llvm/Transforms/Utils/ModuleUtils.h>
82 std::unique_ptr<LLVMContext> Ctx;
85 std::optional<SmallVector<std::unique_ptr<Module>>> ExtractedModules;
86 std::optional<HashT> ExtractedModuleHash;
87 std::optional<CallGraph> ModuleCallGraph;
88 std::unique_ptr<MemoryBuffer> DeviceBinary;
89 std::unordered_map<std::string, GlobalVarInfo> VarNameToGlobalInfo;
98 LinkedModuleIds(LinkedModuleIds), LinkedModule(
nullptr),
115 if (ExtractedModules->size() == 1) {
116 LinkedModule = ExtractedModules->front().get();
117 if (
auto E = LinkedModule->materializeAll())
130 LinkedModule = ExtractedModules->front().get();
134 <<
"getLinkedModule " <<
T.elapsed() <<
" ms\n");
137 return *LinkedModule;
146 for (
auto &M : ExtractedModules.value())
152 ExtractedModules = std::move(
Modules);
159 if (ExtractedModuleHash)
160 ExtractedModuleHash =
hashCombine(ExtractedModuleHash.value(), HashValue);
162 ExtractedModuleHash = HashValue;
166 if (!ModuleCallGraph.has_value()) {
169 ModuleCallGraph.emplace(
CallGraph(*LinkedModule));
171 return ModuleCallGraph.value();
178 return DeviceBinary->getMemBufferRef();
185 LinkedModuleIds.push_back(
ModuleId);
194 std::call_once(Flag, [&]() {
199 auto TraceOut = [](std::unordered_map<std::string, GlobalVarInfo>
200 &VarNameToGlobalInfo) {
205 <<
" DevAddr:" <<
GVI.DevAddr <<
" VarSize:" <<
GVI.VarSize
213 GlobalsMapped =
true;
218 return VarNameToGlobalInfo;
225 std::optional<void *> Kernel;
226 std::unique_ptr<LLVMContext> Ctx;
229 std::optional<std::unique_ptr<Module>> ExtractedModule;
230 std::optional<std::unique_ptr<MemoryBuffer>> Bitcode;
231 std::optional<std::reference_wrapper<BinaryInfo>> BinInfo;
232 std::optional<HashT> StaticHash;
233 std::optional<SmallVector<std::pair<std::string, StringRef>>>
246 assert(
Kernel.has_value() &&
"Expected Kernel is inited");
250 const std::string &
getName()
const {
return Name; }
252 bool hasModule()
const {
return ExtractedModule.has_value(); }
256 ExtractedModule = std::move(Mod);
268 StaticHash =
hash(Name);
269 StaticHash =
hashCombine(StaticHash.value(), ModuleHash);
293 std::pair<std::unique_ptr<Module>, std::unique_ptr<MemoryBuffer>>
297 static_cast<ImplT &
>(*this).tryExtractKernelModule(BinInfo,
KernelName,
299 std::unique_ptr<MemoryBuffer> Bitcode =
nullptr;
306 static_cast<ImplT &
>(*this).extractModules(BinInfo);
333 <<
T.elapsed() <<
" ms\n");
354 Bitcode = MemoryBuffer::getMemBufferCopy(
CloneStr);
362 Bitcode = MemoryBuffer::getMemBufferCopy(
BitcodeStr);
365 return std::make_pair(std::move(
KernelModule), std::move(Bitcode));
394 <<
"Extract kernel module " <<
T.elapsed() <<
" ms\n");
428 <<
"=== LAMBDA MATCHING\n"
429 <<
"Caller trigger " <<
KernelInfo.getName() <<
" -> "
435 <<
" Trying F " <<
demangle(F.getName().str()) <<
"\n ");
439 LambdaCalleeInfo.emplace_back(F.getName(),
443 KernelInfo.setLambdaCalleeInfo(std::move(LambdaCalleeInfo));
481 std::optional<std::reference_wrapper<JITKernelInfo>>
500 HashT ModuleHash =
static_cast<ImplT &
>(*this).getModuleHash(BinInfo);
518 void *resolveDeviceGlobalAddr(
const void *
Addr) {
519 return static_cast<ImplT &
>(*this).resolveDeviceGlobalAddr(
Addr);
523 proteus::setKernelDims(M, GridDim, BlockDim);
531 return static_cast<ImplT &
>(*this).launchKernelFunction(
536 const std::unordered_map<std::string, GlobalVarInfo>
537 &VarNameToGlobalInfo) {
539 proteus::relinkGlobalsObject(
Object, VarNameToGlobalInfo);
544 std::unordered_map<std::string, GlobalVarInfo> &VarNameToGlobalInfo) {
546 return static_cast<ImplT &
>(*this).getKernelFunctionFromImage(
578template <
typename ImplT>
583template <
typename ImplT>
596 BinInfo.mapGlobals();
606 GridDim.y, GridDim.z, BlockDim.x, BlockDim.y, BlockDim.z);
609 CodeCache.lookup(HashValue);
618 std::string Suffix = mangleSuffix(HashValue);
625 relinkGlobalsObject(
CompiledLib->ObjectModule->getMemBufferRef(),
626 BinInfo.getVarNameToGlobalInfo());
630 BinInfo.getVarNameToGlobalInfo());
640 std::unique_ptr<MemoryBuffer>
ObjBuf =
nullptr;
646 if (!
Compiler.isCompilationPending(HashValue)) {
652 GridDim, RCVec,
KernelInfo.getLambdaCalleeInfo(),
653 BinInfo.getVarNameToGlobalInfo(), GlobalLinkedBinaries, DeviceArch,
663 HashValue,
Config::get().ProteusAsyncTestBlocking);
672 GridDim, RCVec,
KernelInfo.getLambdaCalleeInfo(),
673 BinInfo.getVarNameToGlobalInfo(), GlobalLinkedBinaries, DeviceArch,
685 BinInfo.getVarNameToGlobalInfo());
689 ObjectCache.store(HashValue,
ObjBuf->getMemBufferRef());
696template <
typename ImplT>
702 <<
"Register fatbinary Handle " <<
Handle <<
" FatbinWrapper "
704 <<
" ModuleId " <<
ModuleId <<
"\n");
713 for (
int I = 0;
Ptr !=
nullptr;
716 <<
"I " <<
I <<
" PrelinkedFatbin " <<
Ptr <<
"\n");
717 GlobalLinkedBinaries.insert(
Ptr);
738template <
typename ImplT>
743 <<
" To Handle " <<
Handle <<
"\n");
748 if (JITKernelInfoMap.contains(
Kernel)) {
750 <<
"Warning: duplicate register function for kernel " +
756 if (!HandleToBinaryInfo.count(
Handle))
761 <<
"Register function " <<
KernelName <<
" with binary handle "
764 JITKernelInfoMap[
Kernel] =
768template <
typename ImplT>
773 <<
" Binary " << (
void *)
FatbinWrapper->Binary <<
" ModuleId "
776 if (!HandleToBinaryInfo.count(CurHandle))
779 HandleToBinaryInfo[CurHandle].addModuleId(
ModuleId);
781 GlobalLinkedModuleIds.push_back(
ModuleId);
void const char * ModuleId
Definition CompilerInterfaceDevice.cpp:33
void * FatbinWrapper
Definition CompilerInterfaceDevice.cpp:32
const void const char * VarName
Definition CompilerInterfaceDevice.cpp:21
void char * KernelName
Definition CompilerInterfaceDevice.cpp:52
void * Kernel
Definition CompilerInterfaceDevice.cpp:52
const void const char uint64_t VarSize
Definition CompilerInterfaceDevice.cpp:22
ArrayRef< RuntimeConstantInfo * > RCInfoArray
Definition CompilerInterfaceHost.cpp:25
#define PROTEUS_DBG(x)
Definition Debug.h:9
#define PROTEUS_FATAL_ERROR(x)
Definition Error.h:7
void getLambdaJitValues(StringRef FnName, SmallVector< RuntimeConstant > &LambdaJitValuesVec)
Definition JitEngineHost.cpp:178
#define TIMESCOPE(x)
Definition TimeTracing.hpp:64
#define PROTEUS_TIMER_OUTPUT(x)
Definition TimeTracing.hpp:57
Definition JitEngineDevice.hpp:79
FatbinWrapperT * getFatbinWrapper() const
Definition JitEngineDevice.hpp:102
void mapGlobals()
Definition JitEngineDevice.hpp:193
void setExtractedModules(SmallVector< std::unique_ptr< Module > > &Modules)
Definition JitEngineDevice.hpp:151
std::unordered_map< std::string, GlobalVarInfo > & getVarNameToGlobalInfo()
Definition JitEngineDevice.hpp:217
MemoryBufferRef getDeviceBinary()
Definition JitEngineDevice.hpp:175
bool hasModuleHash() const
Definition JitEngineDevice.hpp:155
std::unique_ptr< LLVMContext > & getLLVMContext()
Definition JitEngineDevice.hpp:104
Module & getLinkedModule()
Definition JitEngineDevice.hpp:107
auto & getModuleIds()
Definition JitEngineDevice.hpp:221
void registerGlobalVar(const char *VarName, const void *Addr, uint64_t VarSize)
Definition JitEngineDevice.hpp:188
bool hasLinkedModule() const
Definition JitEngineDevice.hpp:106
bool hasDeviceBinary()
Definition JitEngineDevice.hpp:174
const SmallVector< std::reference_wrapper< Module > > getExtractedModules() const
Definition JitEngineDevice.hpp:142
void updateModuleHash(HashT HashValue)
Definition JitEngineDevice.hpp:158
HashT getModuleHash() const
Definition JitEngineDevice.hpp:156
bool hasExtractedModules() const
Definition JitEngineDevice.hpp:140
void addModuleId(const char *ModuleId)
Definition JitEngineDevice.hpp:184
CallGraph & getCallGraph()
Definition JitEngineDevice.hpp:165
BinaryInfo(FatbinWrapperT *FatbinWrapper, SmallVector< std::string > &&LinkedModuleIds)
Definition JitEngineDevice.hpp:95
void setModuleHash(HashT HashValue)
Definition JitEngineDevice.hpp:157
void setDeviceBinary(std::unique_ptr< MemoryBuffer > DeviceBinaryBuffer)
Definition JitEngineDevice.hpp:180
Definition CompilationTask.hpp:18
static CompilerAsync & instance(int NumThreads)
Definition CompilerAsync.hpp:49
void joinAllThreads()
Definition CompilerAsync.hpp:88
std::unique_ptr< MemoryBuffer > compile(CompilationTask &&CT)
Definition CompilerSync.hpp:21
static CompilerSync & instance()
Definition CompilerSync.hpp:16
int ProteusTraceOutput
Definition Config.hpp:313
static Config & get()
Definition Config.hpp:298
bool ProteusRelinkGlobalsByCopy
Definition Config.hpp:307
bool ProteusDumpLLVMIR
Definition Config.hpp:306
bool ProteusUseStoredCache
Definition Config.hpp:304
const CodeGenerationConfig & getCGConfig(llvm::StringRef KName="") const
Definition Config.hpp:317
Definition Hashing.hpp:20
std::string toString() const
Definition Hashing.hpp:28
Definition JitEngineDevice.hpp:224
bool hasBitcode()
Definition JitEngineDevice.hpp:259
const std::string & getName() const
Definition JitEngineDevice.hpp:250
void createStaticHash(HashT ModuleHash)
Definition JitEngineDevice.hpp:267
JITKernelInfo(void *Kernel, BinaryInfo &BinInfo, char const *Name, ArrayRef< RuntimeConstantInfo * > RCInfoArray)
Definition JitEngineDevice.hpp:237
void * getKernel() const
Definition JitEngineDevice.hpp:245
bool hasModule() const
Definition JitEngineDevice.hpp:252
const HashT getStaticHash() const
Definition JitEngineDevice.hpp:266
BinaryInfo & getBinaryInfo() const
Definition JitEngineDevice.hpp:254
ArrayRef< RuntimeConstantInfo * > getRCInfoArray() const
Definition JitEngineDevice.hpp:251
Module & getModule() const
Definition JitEngineDevice.hpp:253
void setModule(std::unique_ptr< llvm::Module > Mod)
Definition JitEngineDevice.hpp:255
void setLambdaCalleeInfo(SmallVector< std::pair< std::string, StringRef > > &&LambdaInfo)
Definition JitEngineDevice.hpp:274
const auto & getLambdaCalleeInfo()
Definition JitEngineDevice.hpp:273
bool hasLambdaCalleeInfo()
Definition JitEngineDevice.hpp:272
MemoryBufferRef getBitcode()
Definition JitEngineDevice.hpp:263
bool hasStaticHash() const
Definition JitEngineDevice.hpp:265
std::unique_ptr< LLVMContext > & getLLVMContext()
Definition JitEngineDevice.hpp:249
void setBitcode(std::unique_ptr< MemoryBuffer > ExtractedBitcode)
Definition JitEngineDevice.hpp:260
Definition JitEngineDevice.hpp:282
MemoryBufferRef getBitcode(JITKernelInfo &KernelInfo)
Definition JitEngineDevice.hpp:407
~JitEngineDevice()
Definition JitEngineDevice.hpp:561
typename DeviceTraits< ImplT >::DeviceError_t DeviceError_t
Definition JitEngineDevice.hpp:284
void extractModuleAndBitcode(JITKernelInfo &KernelInfo)
Definition JitEngineDevice.hpp:368
void registerLinkedBinary(FatbinWrapperT *FatbinWrapper, const char *ModuleId)
Definition JitEngineDevice.hpp:769
JitEngineDevice()
Definition JitEngineDevice.hpp:559
std::unordered_map< const void *, BinaryInfo > HandleToBinaryInfo
Definition JitEngineDevice.hpp:473
DenseMap< const void *, JITKernelInfo > JITKernelInfoMap
Definition JitEngineDevice.hpp:570
typename DeviceTraits< ImplT >::DeviceStream_t DeviceStream_t
Definition JitEngineDevice.hpp:285
void registerFatBinaryEnd()
Definition JitEngineDevice.hpp:728
MemoryCache< KernelFunction_t > CodeCache
Definition JitEngineDevice.hpp:566
Module & getModule(JITKernelInfo &KernelInfo)
Definition JitEngineDevice.hpp:397
bool containsJITKernelInfo(const void *Func)
Definition JitEngineDevice.hpp:477
std::optional< std::reference_wrapper< JITKernelInfo > > getJITKernelInfo(const void *Func)
Definition JitEngineDevice.hpp:482
StorageCache ObjectCache
Definition JitEngineDevice.hpp:567
std::pair< std::unique_ptr< Module >, std::unique_ptr< MemoryBuffer > > extractKernelModule(BinaryInfo &BinInfo, StringRef KernelName, LLVMContext &Ctx)
Definition JitEngineDevice.hpp:294
SmallPtrSet< void *, 8 > GlobalLinkedBinaries
Definition JitEngineDevice.hpp:475
void registerFunction(void *Handle, void *Kernel, char *KernelName, ArrayRef< RuntimeConstantInfo * > RCInfoArray)
Definition JitEngineDevice.hpp:739
void insertRegisterVar(void *Handle, const char *VarName, const void *Addr, uint64_t VarSize)
Definition JitEngineDevice.hpp:454
void finalize()
Definition JitEngineDevice.hpp:506
void * CurHandle
Definition JitEngineDevice.hpp:471
void registerFatBinary(void *Handle, FatbinWrapperT *FatbinWrapper, const char *ModuleId)
Definition JitEngineDevice.hpp:697
DeviceError_t compileAndRun(JITKernelInfo &KernelInfo, dim3 GridDim, dim3 BlockDim, void **KernelArgs, uint64_t ShmemSize, typename DeviceTraits< ImplT >::DeviceStream_t Stream)
Definition JitEngineDevice.hpp:585
std::unordered_map< std::string, FatbinWrapperT * > ModuleIdToFatBinary
Definition JitEngineDevice.hpp:472
typename DeviceTraits< ImplT >::KernelFunction_t KernelFunction_t
Definition JitEngineDevice.hpp:286
SmallVector< std::string > GlobalLinkedModuleIds
Definition JitEngineDevice.hpp:474
StringRef getDeviceArch() const
Definition JitEngineDevice.hpp:512
std::string DeviceArch
Definition JitEngineDevice.hpp:568
HashT getStaticHash(JITKernelInfo &KernelInfo)
Definition JitEngineDevice.hpp:489
void getLambdaJitValues(JITKernelInfo &KernelInfo, SmallVector< RuntimeConstant > &LambdaJitValuesVec)
Definition JitEngineDevice.hpp:417
Definition JitEngine.hpp:34
Definition LambdaRegistry.hpp:20
std::optional< DenseMap< StringRef, SmallVector< RuntimeConstant > >::iterator > matchJitVariableMap(StringRef FnName)
Definition LambdaRegistry.hpp:28
static LambdaRegistry & instance()
Definition LambdaRegistry.hpp:22
static llvm::raw_ostream & outs(const std::string &Name)
Definition Logger.hpp:25
static void trace(llvm::StringRef Msg)
Definition Logger.hpp:30
static llvm::raw_ostream & logs(const std::string &Name)
Definition Logger.hpp:19
Definition MemoryCache.hpp:27
void printStats()
Definition MemoryCache.hpp:61
Definition StorageCache.hpp:30
void printStats()
Definition StorageCache.cpp:86
Definition TimeTracing.hpp:36
Definition StorageCache.cpp:24
std::unique_ptr< Module > cloneKernelFromModules(ArrayRef< std::reference_wrapper< Module > > Mods, StringRef EntryName)
Definition Cloning.h:497
HashT hash(FirstT &&First, RestTs &&...Rest)
Definition Hashing.hpp:126
cudaError_t launchKernelDirect(void *KernelFunc, dim3 GridDim, dim3 BlockDim, void **KernelArgs, uint64_t ShmemSize, CUstream Stream)
Definition CoreDeviceCUDA.hpp:21
T getRuntimeConstantValue(void *Arg)
Definition CompilerInterfaceRuntimeConstantInfo.h:114
cudaError_t launchKernelFunction(CUfunction KernelFunc, dim3 GridDim, dim3 BlockDim, void **KernelArgs, uint64_t ShmemSize, CUstream Stream)
Definition CoreDeviceCUDA.hpp:56
HashT hashCombine(HashT A, HashT B)
Definition Hashing.hpp:121
void pruneIR(Module &M, bool UnsetExternallyInitialized=true)
Definition CoreLLVM.hpp:253
std::string toString(CodegenOption Option)
Definition Config.hpp:26
void * resolveDeviceGlobalAddr(const void *Addr)
Definition CoreDeviceCUDA.hpp:13
void internalize(Module &M, StringRef PreserveFunctionName)
Definition CoreLLVM.hpp:288
CUfunction getKernelFunctionFromImage(StringRef KernelName, const void *Image, bool RelinkGlobalsByCopy, const std::unordered_map< std::string, GlobalVarInfo > &VarNameToGlobalInfo)
Definition CoreDeviceCUDA.hpp:28
void runCleanupPassPipeline(Module &M)
Definition CoreLLVM.hpp:230
std::unique_ptr< Module > linkModules(LLVMContext &Ctx, SmallVector< std::unique_ptr< Module > > LinkedModules)
Definition CoreLLVM.hpp:213
Definition Hashing.hpp:147
Definition JitEngineDevice.hpp:280
Definition JitEngineDevice.hpp:72
const char * Binary
Definition JitEngineDevice.hpp:75
void ** PrelinkedFatbins
Definition JitEngineDevice.hpp:76
int32_t Magic
Definition JitEngineDevice.hpp:73
int32_t Version
Definition JitEngineDevice.hpp:74
Definition GlobalVarInfo.hpp:5