11#ifndef PROTEUS_JITENGINEDEVICE_HPP
12#define PROTEUS_JITENGINEDEVICE_HPP
16#include <llvm/ADT/SmallPtrSet.h>
17#include <llvm/Analysis/CallGraph.h>
22#include <llvm/ADT/SmallVector.h>
23#include <llvm/ADT/StringRef.h>
24#include <llvm/Analysis/TargetTransformInfo.h>
25#include <llvm/Bitcode/BitcodeWriter.h>
26#include <llvm/CodeGen/CommandFlags.h>
27#include <llvm/CodeGen/MachineModuleInfo.h>
28#include <llvm/Config/llvm-config.h>
29#include <llvm/Demangle/Demangle.h>
30#include <llvm/ExecutionEngine/Orc/ThreadSafeModule.h>
31#include <llvm/IR/Constants.h>
32#include <llvm/IR/GlobalVariable.h>
33#include <llvm/IR/Instruction.h>
34#include <llvm/IR/Instructions.h>
35#include <llvm/IR/LLVMContext.h>
36#include <llvm/IR/LegacyPassManager.h>
37#include <llvm/IR/Module.h>
38#include <llvm/IR/ReplaceConstant.h>
39#include <llvm/IR/Type.h>
40#include <llvm/IR/Verifier.h>
41#include <llvm/IRReader/IRReader.h>
42#include <llvm/Linker/Linker.h>
43#include <llvm/MC/TargetRegistry.h>
44#include <llvm/Object/ELFObjectFile.h>
45#include <llvm/Passes/PassBuilder.h>
46#include <llvm/Support/Error.h>
47#include <llvm/Support/MemoryBuffer.h>
48#include <llvm/Support/MemoryBufferRef.h>
49#include <llvm/Target/TargetMachine.h>
50#include <llvm/Transforms/IPO/Internalize.h>
51#include <llvm/Transforms/Utils/Cloning.h>
52#include <llvm/Transforms/Utils/ModuleUtils.h>
81 SmallVector<std::string> LinkedModuleIds;
82 std::unique_ptr<Module> ExtractedModule;
83 std::optional<HashT> ExtractedModuleHash;
84 std::optional<CallGraph> ModuleCallGraph;
89 SmallVector<std::string> &&LinkedModuleIds)
91 ModuleCallGraph(
std::nullopt) {}
// Reports whether a module is currently held: true exactly when the
// ExtractedModule pointer is non-null. Does not modify any state.
95 bool hasModule()
const {
return (ExtractedModule !=
nullptr); }
// Returns a reference to the module owned by ExtractedModule.
// The pointer is dereferenced without a null check, so this must only be
// called once a module is present (hasModule() returns true).
96 Module &
getModule()
const {
return *ExtractedModule; }
98 ExtractedModule = std::move(Module);
105 if (ExtractedModuleHash)
106 ExtractedModuleHash =
hashCombine(ExtractedModuleHash.value(), HashValue);
108 ExtractedModuleHash = HashValue;
112 if (!ModuleCallGraph.has_value()) {
113 ModuleCallGraph.emplace(CallGraph(*ExtractedModule));
115 return ModuleCallGraph.value();
119 LinkedModuleIds.push_back(
ModuleId);
126 std::optional<void *> Kernel;
128 SmallVector<int32_t> RCTypes;
129 SmallVector<int32_t> RCIndices;
130 std::optional<std::unique_ptr<Module>> ExtractedModule;
131 std::optional<std::reference_wrapper<BinaryInfo>> BinInfo;
132 std::optional<HashT> StaticHash;
133 std::optional<SmallVector<std::pair<std::string, StringRef>>>
138 int32_t *RCIndices, int32_t *RCTypes, int32_t
NumRCs)
142 ExtractedModule(
std::nullopt), BinInfo(BinInfo),
143 LambdaCalleeInfo(
std::nullopt) {}
147 assert(
Kernel.has_value() &&
"Expected Kernel is inited");
// Returns the stored Name by const reference (no copy is made).
150 const std::string &
getName()
const {
return Name; }
// Reports whether the optional ExtractedModule currently holds a value.
// NOTE(review): this tests only that the optional is engaged, not that the
// pointer stored inside it is non-null -- confirm that a null module is never
// stored.
153 bool hasModule()
const {
return ExtractedModule.has_value(); }
// Returns a reference to the module held inside the optional ExtractedModule.
// Neither the optional nor the pointer it contains is checked here, so this
// must only be called after a module has been stored (hasModule() is true).
154 Module &
getModule()
const {
return *ExtractedModule->get(); }
157 ExtractedModule = std::move(Mod);
162 StaticHash =
hash(Name);
163 StaticHash =
hashCombine(StaticHash.value(), ModuleHash);
169 SmallVector<std::pair<std::string, StringRef>> &&LambdaInfo) {
170 LambdaCalleeInfo = std::move(LambdaInfo);
190 void **KernelArgs, uint64_t ShmemSize,
202 std::unique_ptr<Module> ExtractedModule =
203 static_cast<ImplT &
>(*this).extractModule(BinInfo);
205 pruneIR(*ExtractedModule);
208 BinInfo.
setModule(std::move(ExtractedModule));
212 std::unique_ptr<Module> KernelModule{
nullptr};
214 if (
Config.PROTEUS_USE_LIGHTWEIGHT_KERNEL_CLONE) {
215 KernelModule = std::move(proteus::cloneKernelFromModule(
219 KernelModule = llvm::CloneModule(BinModule);
222 internalize(*KernelModule, KernelInfo.
getName());
225 KernelInfo.
setModule(std::move(KernelModule));
230 SmallVector<RuntimeConstant> &LambdaJitValuesVec) {
238 Module &KernelModule =
getModule(KernelInfo);
240 <<
"=== LAMBDA MATCHING\n"
241 <<
"Caller trigger " << KernelInfo.
getName() <<
" -> "
242 << demangle(KernelInfo.
getName()) <<
"\n");
244 SmallVector<std::pair<std::string, StringRef>> LambdaCalleeInfo;
245 for (
auto &F : KernelModule.getFunctionList()) {
247 <<
" Trying F " << demangle(F.getName().str()) <<
"\n ");
251 LambdaCalleeInfo.emplace_back(F.getName(),
252 OptionalMapIt.value()->first);
259 const SmallVector<RuntimeConstant> &Values =
261 LambdaJitValuesVec.insert(LambdaJitValuesVec.end(), Values.begin(),
287 std::optional<std::reference_wrapper<JITKernelInfo>>
306 HashT ModuleHash =
static_cast<ImplT &
>(*this).getModuleHash(BinInfo);
313 if (
Config.PROTEUS_ASYNC_COMPILATION)
321 void *resolveDeviceGlobalAddr(
const void *Addr) {
322 return static_cast<ImplT &
>(*this).resolveDeviceGlobalAddr(Addr);
325 void setLaunchBoundsForKernel(Module &M, Function &F,
size_t GridSize,
327 static_cast<ImplT &
>(*this).setLaunchBoundsForKernel(M, F, GridSize,
331 void setKernelDims(Module &M, dim3 &GridDim, dim3 &BlockDim) {
332 proteus::setKernelDims(M, GridDim, BlockDim);
336 dim3 BlockDim,
void **KernelArgs,
340 return static_cast<ImplT &
>(*this).launchKernelFunction(
341 KernelFunc, GridDim, BlockDim, KernelArgs, ShmemSize, Stream);
344 void relinkGlobalsObject(MemoryBufferRef Object) {
349 std::unique_ptr<MemoryBuffer> codegenObject(Module &M, StringRef
DeviceArch) {
350 return static_cast<ImplT &
>(*this).codegenObject(M,
DeviceArch);
356 return static_cast<ImplT &
>(*this).getKernelFunctionFromImage(
KernelName,
364 void pruneIR(Module &M);
366 void internalize(Module &M, StringRef
KernelName);
368 void specializeIR(Module &M, StringRef FnName, StringRef Suffix,
369 dim3 &BlockDim, dim3 &GridDim,
371 const SmallVector<RuntimeConstant> &RCVec);
373 void replaceGlobalVariablesWithPointers(Module &M);
387 std::unique_ptr<Module>
389 std::unique_ptr<Module> LTOModule =
nullptr);
396template <
typename ImplT>
398 Module &M, StringRef FnName, StringRef Suffix, dim3 &BlockDim,
399 dim3 &GridDim,
const SmallVector<int32_t> &
RCIndices,
400 const SmallVector<RuntimeConstant> &RCVec) {
403 proteus::specializeIR(M, FnName, Suffix, BlockDim, GridDim,
RCIndices, RCVec,
404 Config.PROTEUS_SPECIALIZE_ARGS,
405 Config.PROTEUS_SPECIALIZE_DIMS,
406 Config.PROTEUS_SET_LAUNCH_BOUNDS);
408#if PROTEUS_ENABLE_DEBUG
410 << M <<
"=== End Final Module\n";
411 if (verifyModule(M, &errs()))
418template <
typename ImplT>
void JitEngineDevice<ImplT>::pruneIR(Module &M) {
423template <
typename ImplT>
424void JitEngineDevice<ImplT>::internalize(Module &M, StringRef
KernelName) {
428template <
typename ImplT>
429void JitEngineDevice<ImplT>::replaceGlobalVariablesWithPointers(Module &M) {
432 proteus::replaceGlobalVariablesWithPointers(M, VarNameToDevPtr);
434#if PROTEUS_ENABLE_DEBUG
435 Logger::logs(
"proteus") <<
"=== Linked M\n" << M <<
"=== End of Linked M\n";
436 if (verifyModule(M, &errs()))
438 "After linking, broken module found, JIT compilation aborted!");
444template <
typename ImplT>
445typename DeviceTraits<ImplT>::DeviceError_t
447 JITKernelInfo &KernelInfo, dim3 GridDim, dim3 BlockDim,
void **KernelArgs,
455 static std::once_flag Flag;
456 std::call_once(Flag, [&]() {
457 for (
auto &[GlobalName, HostAddr] : VarNameToDevPtr) {
459 VarNameToDevPtr.at(GlobalName) = DevPtr;
463 SmallVector<RuntimeConstant> RCVec;
464 SmallVector<RuntimeConstant> LambdaJitValuesVec;
466 getRuntimeConstantValues(KernelArgs, KernelInfo.
getRCIndices(),
471 hash(getStaticHash(KernelInfo), RCVec, LambdaJitValuesVec, GridDim.x,
472 GridDim.x, GridDim.y, GridDim.z, BlockDim.x, BlockDim.y, BlockDim.z);
475 CodeCache.lookup(HashValue);
483 std::string Suffix = mangleSuffix(HashValue);
484 std::string KernelMangled = (KernelInfo.
getName() + Suffix);
486 if (Config.PROTEUS_USE_STORED_CACHE) {
496 auto CacheBuf = StorageCache.lookup(HashValue);
498 if (!Config.PROTEUS_RELINK_GLOBALS_BY_COPY)
499 relinkGlobalsObject(CacheBuf->getMemBufferRef());
504 CodeCache.insert(HashValue, KernelFunc, KernelInfo.
getName(), RCVec);
511 Module &KernelModule = getModule(KernelInfo);
512 std::unique_ptr<MemoryBuffer> ObjBuf =
nullptr;
514 if (Config.PROTEUS_ASYNC_COMPILATION) {
518 if (!Compiler.isCompilationPending(HashValue)) {
523 KernelModule, HashValue, KernelInfo.
getName(), Suffix, BlockDim,
526 GlobalLinkedBinaries, DeviceArch,
527 Config.PROTEUS_USE_HIP_RTC_CODEGEN,
528 Config.PROTEUS_DUMP_LLVM_IR,
529 Config.PROTEUS_RELINK_GLOBALS_BY_COPY,
530 Config.PROTEUS_SPECIALIZE_ARGS,
531 Config.PROTEUS_SPECIALIZE_DIMS,
532 Config.PROTEUS_SET_LAUNCH_BOUNDS});
538 ObjBuf = Compiler.takeCompilationResult(HashValue,
539 Config.PROTEUS_ASYNC_TEST_BLOCKING);
542 KernelArgs, ShmemSize, Stream);
547 KernelModule, HashValue, KernelInfo.
getName(), Suffix, BlockDim,
551 Config.PROTEUS_USE_HIP_RTC_CODEGEN,
552 Config.PROTEUS_DUMP_LLVM_IR,
553 Config.PROTEUS_RELINK_GLOBALS_BY_COPY,
554 Config.PROTEUS_SPECIALIZE_ARGS,
555 Config.PROTEUS_SPECIALIZE_DIMS,
556 Config.PROTEUS_SET_LAUNCH_BOUNDS});
560 KernelMangled, ObjBuf->getBufferStart(),
561 Config.PROTEUS_RELINK_GLOBALS_BY_COPY, VarNameToDevPtr);
563 CodeCache.insert(HashValue, KernelFunc, KernelInfo.
getName(), RCVec);
564 if (Config.PROTEUS_USE_STORED_CACHE) {
565 StorageCache.store(HashValue, ObjBuf->getMemBufferRef());
572template <
typename ImplT>
578 <<
"Register fatbinary Handle " << Handle <<
" FatbinWrapper "
580 <<
" ModuleId " <<
ModuleId <<
"\n");
588 for (
int I = 0; Ptr !=
nullptr;
591 <<
"I " << I <<
" PrelinkedFatbin " << Ptr <<
"\n");
592 GlobalLinkedBinaries.insert(Ptr);
612template <
typename ImplT>
619 <<
" To Handle " << Handle <<
"\n");
624 if (JITKernelInfoMap.contains(
Kernel)) {
626 <<
"Warning: duplicate register function for kernel " +
632 if (!HandleToBinaryInfo.count(Handle))
634 BinaryInfo &BinInfo = HandleToBinaryInfo[Handle];
636 JITKernelInfoMap[
Kernel] =
640template <
typename ImplT>
645 <<
" Binary " << (
void *)
FatbinWrapper->Binary <<
" ModuleId "
648 if (!HandleToBinaryInfo.count(CurHandle))
651 HandleToBinaryInfo[CurHandle].addModuleId(
ModuleId);
653 GlobalLinkedModuleIds.push_back(
ModuleId);
658template <
typename ImplT>
660 SmallVector<std::unique_ptr<Module>> &LinkedModules,
661 std::unique_ptr<Module> LTOModule) {
662 if (LinkedModules.empty())
672 Linker IRLinker(*LinkedModule);
675 for (
auto &F : *LTOModule) {
676 if (F.hasInternalLinkage())
677 F.setLinkage(GlobalValue::ExternalLinkage);
680 if (IRLinker.linkInModule(std::move(LTOModule),
681 Linker::Flags::LinkOnlyNeeded))
void const char * ModuleId
Definition CompilerInterfaceDevice.cpp:31
void * FatbinWrapper
Definition CompilerInterfaceDevice.cpp:30
void char int32_t int32_t * RCTypes
Definition CompilerInterfaceDevice.cpp:51
const char * VarName
Definition CompilerInterfaceDevice.cpp:20
void char int32_t * RCIndices
Definition CompilerInterfaceDevice.cpp:51
void char * KernelName
Definition CompilerInterfaceDevice.cpp:50
void char int32_t int32_t int32_t NumRCs
Definition CompilerInterfaceDevice.cpp:51
void * Kernel
Definition CompilerInterfaceDevice.cpp:50
#define PROTEUS_DBG(x)
Definition Debug.h:7
#define PROTEUS_FATAL_ERROR(x)
Definition Error.h:4
void getLambdaJitValues(Module &M, StringRef FnName, SmallVector< RuntimeConstant > &LambdaJitValuesVec)
Definition JitEngineHost.cpp:273
#define TIMESCOPE(x)
Definition TimeTracing.hpp:35
Definition JitEngineDevice.hpp:78
FatbinWrapperT * getFatbinWrapper() const
Definition JitEngineDevice.hpp:93
bool hasModuleHash() const
Definition JitEngineDevice.hpp:101
void setModule(std::unique_ptr< Module > Module)
Definition JitEngineDevice.hpp:97
bool hasModule() const
Definition JitEngineDevice.hpp:95
auto & getModuleIds()
Definition JitEngineDevice.hpp:122
Module & getModule() const
Definition JitEngineDevice.hpp:96
void updateModuleHash(HashT HashValue)
Definition JitEngineDevice.hpp:104
HashT getModuleHash() const
Definition JitEngineDevice.hpp:102
void addModuleId(const char *ModuleId)
Definition JitEngineDevice.hpp:118
CallGraph & getCallGraph()
Definition JitEngineDevice.hpp:111
BinaryInfo(FatbinWrapperT *FatbinWrapper, SmallVector< std::string > &&LinkedModuleIds)
Definition JitEngineDevice.hpp:88
void setModuleHash(HashT HashValue)
Definition JitEngineDevice.hpp:103
Definition CompilationTask.hpp:17
static CompilerAsync & instance(int NumThreads)
Definition CompilerAsync.hpp:49
void joinAllThreads()
Definition CompilerAsync.hpp:86
std::unique_ptr< MemoryBuffer > compile(CompilationTask &&CT)
Definition CompilerSync.hpp:21
static CompilerSync & instance()
Definition CompilerSync.hpp:16
Definition Hashing.hpp:19
std::string toString() const
Definition Hashing.hpp:27
Definition JitEngineDevice.hpp:125
const std::string & getName() const
Definition JitEngineDevice.hpp:150
void createStaticHash(HashT ModuleHash)
Definition JitEngineDevice.hpp:161
void * getKernel() const
Definition JitEngineDevice.hpp:146
bool hasModule() const
Definition JitEngineDevice.hpp:153
JITKernelInfo(void *Kernel, BinaryInfo &BinInfo, char const *Name, int32_t *RCIndices, int32_t *RCTypes, int32_t NumRCs)
Definition JitEngineDevice.hpp:137
const auto & getRCTypes() const
Definition JitEngineDevice.hpp:152
const HashT getStaticHash() const
Definition JitEngineDevice.hpp:160
BinaryInfo & getBinaryInfo() const
Definition JitEngineDevice.hpp:155
Module & getModule() const
Definition JitEngineDevice.hpp:154
const auto & getRCIndices() const
Definition JitEngineDevice.hpp:151
void setModule(std::unique_ptr< llvm::Module > Mod)
Definition JitEngineDevice.hpp:156
void setLambdaCalleeInfo(SmallVector< std::pair< std::string, StringRef > > &&LambdaInfo)
Definition JitEngineDevice.hpp:168
const auto & getLambdaCalleeInfo()
Definition JitEngineDevice.hpp:167
bool hasLambdaCalleeInfo()
Definition JitEngineDevice.hpp:166
bool hasStaticHash() const
Definition JitEngineDevice.hpp:159
Definition JitCache.hpp:32
void printStats()
Definition JitCache.hpp:64
Definition JitEngineDevice.hpp:176
JitCache< KernelFunction_t > CodeCache
Definition JitEngineDevice.hpp:383
~JitEngineDevice()
Definition JitEngineDevice.hpp:378
LLVMContext & getLLVMContext()
Definition JitEngineDevice.hpp:391
typename DeviceTraits< ImplT >::DeviceError_t DeviceError_t
Definition JitEngineDevice.hpp:184
void registerLinkedBinary(FatbinWrapperT *FatbinWrapper, const char *ModuleId)
Definition JitEngineDevice.hpp:641
JitEngineDevice()
Definition JitEngineDevice.hpp:376
std::unordered_map< const void *, BinaryInfo > HandleToBinaryInfo
Definition JitEngineDevice.hpp:279
DenseMap< const void *, JITKernelInfo > JITKernelInfoMap
Definition JitEngineDevice.hpp:393
typename DeviceTraits< ImplT >::DeviceStream_t DeviceStream_t
Definition JitEngineDevice.hpp:185
void registerFatBinaryEnd()
Definition JitEngineDevice.hpp:602
std::unique_ptr< Module > linkJitModule(SmallVector< std::unique_ptr< Module > > &LinkedModules, std::unique_ptr< Module > LTOModule=nullptr)
Definition JitEngineDevice.hpp:659
Module & getModule(JITKernelInfo &KernelInfo)
Definition JitEngineDevice.hpp:193
JitStorageCache< KernelFunction_t > StorageCache
Definition JitEngineDevice.hpp:384
bool containsJITKernelInfo(const void *Func)
Definition JitEngineDevice.hpp:283
std::optional< std::reference_wrapper< JITKernelInfo > > getJITKernelInfo(const void *Func)
Definition JitEngineDevice.hpp:288
SmallPtrSet< void *, 8 > GlobalLinkedBinaries
Definition JitEngineDevice.hpp:281
std::unordered_map< std::string, const void * > VarNameToDevPtr
Definition JitEngineDevice.hpp:386
void registerFunction(void *Handle, void *Kernel, char *KernelName, int32_t *RCIndices, int32_t *RCTypes, int32_t NumRCs)
Definition JitEngineDevice.hpp:613
void finalize()
Definition JitEngineDevice.hpp:312
void * CurHandle
Definition JitEngineDevice.hpp:277
void registerFatBinary(void *Handle, FatbinWrapperT *FatbinWrapper, const char *ModuleId)
Definition JitEngineDevice.hpp:573
DeviceError_t compileAndRun(JITKernelInfo &KernelInfo, dim3 GridDim, dim3 BlockDim, void **KernelArgs, uint64_t ShmemSize, typename DeviceTraits< ImplT >::DeviceStream_t Stream)
Definition JitEngineDevice.hpp:446
std::unordered_map< std::string, FatbinWrapperT * > ModuleIdToFatBinary
Definition JitEngineDevice.hpp:278
typename DeviceTraits< ImplT >::KernelFunction_t KernelFunction_t
Definition JitEngineDevice.hpp:186
void insertRegisterVar(const char *VarName, const void *Addr)
Definition JitEngineDevice.hpp:266
SmallVector< std::string > GlobalLinkedModuleIds
Definition JitEngineDevice.hpp:280
std::string DeviceArch
Definition JitEngineDevice.hpp:385
HashT getStaticHash(JITKernelInfo &KernelInfo)
Definition JitEngineDevice.hpp:295
void getLambdaJitValues(JITKernelInfo &KernelInfo, SmallVector< RuntimeConstant > &LambdaJitValuesVec)
Definition JitEngineDevice.hpp:229
Definition JitEngine.hpp:44
struct proteus::JitEngine::@0 Config
void runCleanupPassPipeline(Module &M)
Definition JitEngine.cpp:76
Definition JitStorageCache.hpp:36
void printStats()
Definition JitStorageCache.hpp:65
Definition LambdaRegistry.hpp:19
std::optional< DenseMap< StringRef, SmallVector< RuntimeConstant > >::iterator > matchJitVariableMap(StringRef FnName)
Definition LambdaRegistry.hpp:27
const SmallVector< RuntimeConstant > & getJitVariables(StringRef LambdaTypeRef)
Definition LambdaRegistry.hpp:74
static LambdaRegistry & instance()
Definition LambdaRegistry.hpp:21
bool empty()
Definition LambdaRegistry.hpp:78
static llvm::raw_ostream & logs(const std::string &Name)
Definition Logger.hpp:18
Definition JitEngine.cpp:20
std::unique_ptr< Module > linkModules(LLVMContext &Ctx, SmallVector< std::unique_ptr< Module > > &LinkedModules)
Definition CoreLLVM.hpp:153
HashT hash(FirstT &&First, RestTs &&...Rest)
Definition Hashing.hpp:73
cudaError_t launchKernelDirect(void *KernelFunc, dim3 GridDim, dim3 BlockDim, void **KernelArgs, uint64_t ShmemSize, CUstream Stream)
Definition CoreDeviceCUDA.hpp:20
cudaError_t launchKernelFunction(CUfunction KernelFunc, dim3 GridDim, dim3 BlockDim, void **KernelArgs, uint64_t ShmemSize, CUstream Stream)
Definition CoreDeviceCUDA.hpp:51
CUfunction getKernelFunctionFromImage(StringRef KernelName, const void *Image, bool RelinkGlobalsByCopy, const std::unordered_map< std::string, const void * > &VarNameToDevPtr)
Definition CoreDeviceCUDA.hpp:27
HashT hashCombine(HashT A, HashT B)
Definition Hashing.hpp:68
void pruneIR(Module &M, bool UnsetExternallyInitialized=true)
Definition CoreLLVM.hpp:193
void * resolveDeviceGlobalAddr(const void *Addr)
Definition CoreDeviceCUDA.hpp:12
void internalize(Module &M, StringRef PreserveFunctionName)
Definition CoreLLVM.hpp:228
Definition Hashing.hpp:94
Definition JitEngineDevice.hpp:174
Definition JitEngineDevice.hpp:71
const char * Binary
Definition JitEngineDevice.hpp:74
void ** PrelinkedFatbins
Definition JitEngineDevice.hpp:75
int32_t Magic
Definition JitEngineDevice.hpp:72
int32_t Version
Definition JitEngineDevice.hpp:73