Proteus
Programmable JIT compilation and optimization for C/C++ using LLVM
Loading...
Searching...
No Matches
JitEngineDevice.hpp
Go to the documentation of this file.
1//===-- JitEngineDevice.hpp -- Base JIT Engine Device header impl. --===//
2//
3// Part of the Proteus Project, under the Apache License v2.0 with LLVM
4// Exceptions. See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9//===----------------------------------------------------------------------===//
10
11#ifndef PROTEUS_JITENGINEDEVICE_HPP
12#define PROTEUS_JITENGINEDEVICE_HPP
13
14#include <cstdint>
15#include <functional>
16#include <llvm/ADT/SmallPtrSet.h>
17#include <llvm/Analysis/CallGraph.h>
18#include <memory>
19#include <optional>
20#include <string>
21
22#include <llvm/ADT/SmallVector.h>
23#include <llvm/ADT/StringRef.h>
24#include <llvm/Analysis/TargetTransformInfo.h>
25#include <llvm/Bitcode/BitcodeWriter.h>
26#include <llvm/CodeGen/CommandFlags.h>
27#include <llvm/CodeGen/MachineModuleInfo.h>
28#include <llvm/Config/llvm-config.h>
29#include <llvm/Demangle/Demangle.h>
30#include <llvm/ExecutionEngine/Orc/ThreadSafeModule.h>
31#include <llvm/IR/Constants.h>
32#include <llvm/IR/GlobalVariable.h>
33#include <llvm/IR/Instruction.h>
34#include <llvm/IR/Instructions.h>
35#include <llvm/IR/LLVMContext.h>
36#include <llvm/IR/LegacyPassManager.h>
37#include <llvm/IR/Module.h>
38#include <llvm/IR/ReplaceConstant.h>
39#include <llvm/IR/Type.h>
40#include <llvm/IR/Verifier.h>
41#include <llvm/IRReader/IRReader.h>
42#include <llvm/Linker/Linker.h>
43#include <llvm/MC/TargetRegistry.h>
44#include <llvm/Object/ELFObjectFile.h>
45#include <llvm/Passes/PassBuilder.h>
46#include <llvm/Support/Error.h>
47#include <llvm/Support/MemoryBuffer.h>
48#include <llvm/Support/MemoryBufferRef.h>
49#include <llvm/Target/TargetMachine.h>
50#include <llvm/Transforms/IPO/Internalize.h>
51#include <llvm/Transforms/Utils/Cloning.h>
52#include <llvm/Transforms/Utils/ModuleUtils.h>
53
56#include "proteus/Cloning.h"
61#include "proteus/CoreLLVM.hpp"
62#include "proteus/Debug.h"
63#include "proteus/Hashing.hpp"
64#include "proteus/JitEngine.hpp"
66#include "proteus/Utils.h"
67
68namespace proteus {
69
70using namespace llvm;
71
78
80private:
81 FatbinWrapperT *FatbinWrapper;
82 std::unique_ptr<LLVMContext> Ctx;
83 SmallVector<std::string> LinkedModuleIds;
84 Module *LinkedModule;
85 std::optional<SmallVector<std::unique_ptr<Module>>> ExtractedModules;
86 std::optional<HashT> ExtractedModuleHash;
87 std::optional<CallGraph> ModuleCallGraph;
88 std::unique_ptr<MemoryBuffer> DeviceBinary;
89 std::unordered_map<std::string, GlobalVarInfo> VarNameToGlobalInfo;
90 bool GlobalsMapped;
91 std::once_flag Flag;
92
93public:
94 BinaryInfo() = default;
95 BinaryInfo(FatbinWrapperT *FatbinWrapper,
96 SmallVector<std::string> &&LinkedModuleIds)
98 LinkedModuleIds(LinkedModuleIds), LinkedModule(nullptr),
99 ExtractedModules(std::nullopt), ModuleCallGraph(std::nullopt),
100 DeviceBinary(nullptr), GlobalsMapped(false) {}
101
103
104 std::unique_ptr<LLVMContext> &getLLVMContext() { return Ctx; }
105
106 bool hasLinkedModule() const { return (LinkedModule != nullptr); }
108 if (!LinkedModule) {
109 if (!hasExtractedModules())
110 PROTEUS_FATAL_ERROR("Expected extracted modules");
111
112 Timer T;
113 // Avoid linking when there's a single module by moving it instead and
114 // making sure it's materialized for call graph analysis.
115 if (ExtractedModules->size() == 1) {
116 LinkedModule = ExtractedModules->front().get();
117 if (auto E = LinkedModule->materializeAll())
118 PROTEUS_FATAL_ERROR("Error materializing " + toString(std::move(E)));
119 } else {
120 // By the LLVM API, linkModules takes ownership of module pointers in
121 // ExtractedModules and returns a new unique ptr to the linked module.
122 // We update ExtractedModules to contain and own only the generated
123 // LinkedModule.
125 proteus::linkModules(*Ctx, std::move(ExtractedModules.value()));
127 NewExtractedModules.emplace_back(std::move(GeneratedLinkedModule));
129
130 LinkedModule = ExtractedModules->front().get();
131 }
132
134 << "getLinkedModule " << T.elapsed() << " ms\n");
135 }
136
137 return *LinkedModule;
138 }
139
140 bool hasExtractedModules() const { return ExtractedModules.has_value(); }
143 // This should be called only once when cloning the kernel module to
144 // cache.
146 for (auto &M : ExtractedModules.value())
147 ModulesRef.emplace_back(*M);
148
149 return ModulesRef;
150 }
151 void setExtractedModules(SmallVector<std::unique_ptr<Module>> &Modules) {
152 ExtractedModules = std::move(Modules);
153 }
154
155 bool hasModuleHash() const { return ExtractedModuleHash.has_value(); }
156 HashT getModuleHash() const { return ExtractedModuleHash.value(); }
157 void setModuleHash(HashT HashValue) { ExtractedModuleHash = HashValue; }
158 void updateModuleHash(HashT HashValue) {
159 if (ExtractedModuleHash)
160 ExtractedModuleHash = hashCombine(ExtractedModuleHash.value(), HashValue);
161 else
162 ExtractedModuleHash = HashValue;
163 }
164
166 if (!ModuleCallGraph.has_value()) {
167 if (!LinkedModule)
168 PROTEUS_FATAL_ERROR("Expected non-null linked module");
169 ModuleCallGraph.emplace(CallGraph(*LinkedModule));
170 }
171 return ModuleCallGraph.value();
172 }
173
174 bool hasDeviceBinary() { return (DeviceBinary != nullptr); }
176 if (!hasDeviceBinary())
177 PROTEUS_FATAL_ERROR("Expeced non-null device binary");
178 return DeviceBinary->getMemBufferRef();
179 }
180 void setDeviceBinary(std::unique_ptr<MemoryBuffer> DeviceBinaryBuffer) {
181 DeviceBinary = std::move(DeviceBinaryBuffer);
182 }
183
184 void addModuleId(const char *ModuleId) {
185 LinkedModuleIds.push_back(ModuleId);
186 }
187
188 void registerGlobalVar(const char *VarName, const void *Addr,
190 VarNameToGlobalInfo.emplace(VarName, GlobalVarInfo(Addr, nullptr, VarSize));
191 }
192
193 void mapGlobals() {
194 std::call_once(Flag, [&]() {
195 for (auto &[GlobalName, GVI] : VarNameToGlobalInfo) {
196 void *DevPtr = resolveDeviceGlobalAddr(GVI.HostAddr);
197 VarNameToGlobalInfo.at(GlobalName).DevAddr = DevPtr;
198 }
199 auto TraceOut = [](std::unordered_map<std::string, GlobalVarInfo>
200 &VarNameToGlobalInfo) {
203 for (auto &[GlobalName, GVI] : VarNameToGlobalInfo) {
204 OS << "[GVarInfo]: " << GlobalName << " HAddr:" << GVI.HostAddr
205 << " DevAddr:" << GVI.DevAddr << " VarSize:" << GVI.VarSize
206 << "\n";
207 }
208
209 return S;
210 };
212 Logger::trace(TraceOut(VarNameToGlobalInfo));
213 GlobalsMapped = true;
214 });
215 }
216
217 std::unordered_map<std::string, GlobalVarInfo> &getVarNameToGlobalInfo() {
218 return VarNameToGlobalInfo;
219 }
220
221 auto &getModuleIds() { return LinkedModuleIds; }
222};
223
225 std::optional<void *> Kernel;
226 std::unique_ptr<LLVMContext> Ctx;
227 std::string Name;
229 std::optional<std::unique_ptr<Module>> ExtractedModule;
230 std::optional<std::unique_ptr<MemoryBuffer>> Bitcode;
231 std::optional<std::reference_wrapper<BinaryInfo>> BinInfo;
232 std::optional<HashT> StaticHash;
233 std::optional<SmallVector<std::pair<std::string, StringRef>>>
234 LambdaCalleeInfo;
235
236public:
237 JITKernelInfo(void *Kernel, BinaryInfo &BinInfo, char const *Name,
239 : Kernel(Kernel), Ctx(std::make_unique<LLVMContext>()), Name(Name),
240 RCInfoArray(RCInfoArray), ExtractedModule(std::nullopt),
241 Bitcode{std::nullopt}, BinInfo(BinInfo),
242 LambdaCalleeInfo(std::nullopt) {}
243
244 JITKernelInfo() = default;
245 void *getKernel() const {
246 assert(Kernel.has_value() && "Expected Kernel is inited");
247 return Kernel.value();
248 }
249 std::unique_ptr<LLVMContext> &getLLVMContext() { return Ctx; }
250 const std::string &getName() const { return Name; }
252 bool hasModule() const { return ExtractedModule.has_value(); }
253 Module &getModule() const { return *ExtractedModule->get(); }
254 BinaryInfo &getBinaryInfo() const { return BinInfo.value(); }
255 void setModule(std::unique_ptr<llvm::Module> Mod) {
256 ExtractedModule = std::move(Mod);
257 }
258
259 bool hasBitcode() { return Bitcode.has_value(); }
260 void setBitcode(std::unique_ptr<MemoryBuffer> ExtractedBitcode) {
261 Bitcode = std::move(ExtractedBitcode);
262 }
263 MemoryBufferRef getBitcode() { return Bitcode.value()->getMemBufferRef(); }
264
265 bool hasStaticHash() const { return StaticHash.has_value(); }
266 const HashT getStaticHash() const { return StaticHash.value(); }
267 void createStaticHash(HashT ModuleHash) {
268 StaticHash = hash(Name);
269 StaticHash = hashCombine(StaticHash.value(), ModuleHash);
270 }
271
272 bool hasLambdaCalleeInfo() { return LambdaCalleeInfo.has_value(); }
273 const auto &getLambdaCalleeInfo() { return LambdaCalleeInfo.value(); }
275 SmallVector<std::pair<std::string, StringRef>> &&LambdaInfo) {
276 LambdaCalleeInfo = std::move(LambdaInfo);
277 }
278};
279
280template <typename ImplT> struct DeviceTraits;
281
282template <typename ImplT> class JitEngineDevice : public JitEngine {
283public:
287
292
293 std::pair<std::unique_ptr<Module>, std::unique_ptr<MemoryBuffer>>
295 LLVMContext &Ctx) {
296 std::unique_ptr<Module> KernelModule =
297 static_cast<ImplT &>(*this).tryExtractKernelModule(BinInfo, KernelName,
298 Ctx);
299 std::unique_ptr<MemoryBuffer> Bitcode = nullptr;
300
301 // If there is no ready-made kernel module from AOT, extract per-TU or the
302 // single linked module and clone the kernel module.
303 if (!KernelModule) {
304 Timer T;
305 if (!BinInfo.hasExtractedModules())
306 static_cast<ImplT &>(*this).extractModules(BinInfo);
307
308 std::unique_ptr<Module> KernelModuleTmp = nullptr;
309 switch (Config::get().ProteusKernelClone) {
311 auto &LinkedModule = BinInfo.getLinkedModule();
312 KernelModule = llvm::CloneModule(LinkedModule);
313 break;
314 }
316 auto &LinkedModule = BinInfo.getLinkedModule();
319 break;
320 }
324 break;
325 }
326 default:
327 PROTEUS_FATAL_ERROR("Unsupported kernel cloning option");
328 }
329
331 << "Cloning "
332 << toString(Config::get().ProteusKernelClone) << " "
333 << T.elapsed() << " ms\n");
334 }
335
336 // Internalize and cleanup to simplify the module and prepare it for
337 // optimization.
338 internalize(*KernelModule, KernelName);
340
341 // If the module is not in the provided context due to cloning, roundtrip
342 // it using bitcode. Re-use the roundtrip bitcode to return it.
343 if (&KernelModule->getContext() != &Ctx) {
350 if (auto E = ExpectedKernelModule.takeError())
351 PROTEUS_FATAL_ERROR("Error parsing bitcode: " + toString(std::move(E)));
352
354 Bitcode = MemoryBuffer::getMemBufferCopy(CloneStr);
355 } else {
356 // Parse the kernel module to create the bitcode since it has not been
357 // created by roundtripping.
361 auto BitcodeStr = StringRef{BitcodeBuffer.data(), BitcodeBuffer.size()};
362 Bitcode = MemoryBuffer::getMemBufferCopy(BitcodeStr);
363 }
364
365 return std::make_pair(std::move(KernelModule), std::move(Bitcode));
366 }
367
370
371 if (KernelInfo.hasModule() && KernelInfo.hasBitcode())
372 return;
373
374 if (KernelInfo.hasModule())
375 PROTEUS_FATAL_ERROR("Unexpected KernelInfo has module but not bitcode");
376
377 if (KernelInfo.hasBitcode())
378 PROTEUS_FATAL_ERROR("Unexpected KernelInfo has bitcode but not module");
379
380 BinaryInfo &BinInfo = KernelInfo.getBinaryInfo();
381
382 Timer T;
384 BinInfo, KernelInfo.getName(), *KernelInfo.getLLVMContext());
385
386 if (!KernelModule)
387 PROTEUS_FATAL_ERROR("Expected non-null kernel module");
388 if (!BitcodeBuffer)
389 PROTEUS_FATAL_ERROR("Expected non-null kernel bitcode");
390
391 KernelInfo.setModule(std::move(KernelModule));
392 KernelInfo.setBitcode(std::move(BitcodeBuffer));
394 << "Extract kernel module " << T.elapsed() << " ms\n");
395 }
396
398 if (!KernelInfo.hasModule())
400
401 if (!KernelInfo.hasModule())
402 PROTEUS_FATAL_ERROR("Expected module in KernelInfo");
403
404 return KernelInfo.getModule();
405 }
406
408 if (!KernelInfo.hasBitcode())
410
411 if (!KernelInfo.hasBitcode())
412 PROTEUS_FATAL_ERROR("Expected bitcode in KernelInfo");
413
414 return KernelInfo.getBitcode();
415 }
416
420 if (LR.empty()) {
421 KernelInfo.setLambdaCalleeInfo({});
422 return;
423 }
424
425 if (!KernelInfo.hasLambdaCalleeInfo()) {
427 PROTEUS_DBG(Logger::logs("proteus")
428 << "=== LAMBDA MATCHING\n"
429 << "Caller trigger " << KernelInfo.getName() << " -> "
430 << demangle(KernelInfo.getName()) << "\n");
431
433 for (auto &F : KernelModule.getFunctionList()) {
434 PROTEUS_DBG(Logger::logs("proteus")
435 << " Trying F " << demangle(F.getName().str()) << "\n ");
436 auto OptionalMapIt =
438 if (OptionalMapIt)
439 LambdaCalleeInfo.emplace_back(F.getName(),
440 OptionalMapIt.value()->first);
441 }
442
443 KernelInfo.setLambdaCalleeInfo(std::move(LambdaCalleeInfo));
444 }
445
446 for (auto &[FnName, LambdaType] : KernelInfo.getLambdaCalleeInfo()) {
448 LR.getJitVariables(LambdaType);
449 LambdaJitValuesVec.insert(LambdaJitValuesVec.end(), Values.begin(),
450 Values.end());
451 }
452 }
453
454 void insertRegisterVar(void *Handle, const char *VarName, const void *Addr,
456 if (!HandleToBinaryInfo.count(Handle))
457 PROTEUS_FATAL_ERROR("Expected Handle in map");
459
461 }
462
464 const char *ModuleId);
466 const char *ModuleId);
468 void registerFunction(void *Handle, void *Kernel, char *KernelName,
470
471 void *CurHandle = nullptr;
472 std::unordered_map<std::string, FatbinWrapperT *> ModuleIdToFatBinary;
473 std::unordered_map<const void *, BinaryInfo> HandleToBinaryInfo;
476
477 bool containsJITKernelInfo(const void *Func) {
478 return JITKernelInfoMap.contains(Func);
479 }
480
481 std::optional<std::reference_wrapper<JITKernelInfo>>
482 getJITKernelInfo(const void *Func) {
484 return std::nullopt;
485 }
486 return JITKernelInfoMap[Func];
487 }
488
490 if (KernelInfo.hasStaticHash())
491 return KernelInfo.getStaticHash();
492
493 BinaryInfo &BinInfo = KernelInfo.getBinaryInfo();
494
495 if (BinInfo.hasModuleHash()) {
496 KernelInfo.createStaticHash(BinInfo.getModuleHash());
497 return KernelInfo.getStaticHash();
498 }
499
500 HashT ModuleHash = static_cast<ImplT &>(*this).getModuleHash(BinInfo);
501
502 KernelInfo.createStaticHash(BinInfo.getModuleHash());
503 return KernelInfo.getStaticHash();
504 }
505
506 void finalize() {
507 if (Config::get().ProteusAsyncCompilation)
508 CompilerAsync::instance(Config::get().ProteusAsyncThreads)
510 }
511
513
514private:
515 //------------------------------------------------------------------
516 // Begin Methods implemented in the derived device engine class.
517 //------------------------------------------------------------------
518 void *resolveDeviceGlobalAddr(const void *Addr) {
519 return static_cast<ImplT &>(*this).resolveDeviceGlobalAddr(Addr);
520 }
521
522 void setKernelDims(Module &M, dim3 &GridDim, dim3 &BlockDim) {
523 proteus::setKernelDims(M, GridDim, BlockDim);
524 }
525
526 DeviceError_t launchKernelFunction(KernelFunction_t KernelFunc, dim3 GridDim,
527 dim3 BlockDim, void **KernelArgs,
531 return static_cast<ImplT &>(*this).launchKernelFunction(
532 KernelFunc, GridDim, BlockDim, KernelArgs, ShmemSize, Stream);
533 }
534
535 void relinkGlobalsObject(MemoryBufferRef Object,
536 const std::unordered_map<std::string, GlobalVarInfo>
537 &VarNameToGlobalInfo) {
539 proteus::relinkGlobalsObject(Object, VarNameToGlobalInfo);
540 }
541
542 KernelFunction_t getKernelFunctionFromImage(
543 StringRef KernelName, const void *Image,
544 std::unordered_map<std::string, GlobalVarInfo> &VarNameToGlobalInfo) {
546 return static_cast<ImplT &>(*this).getKernelFunctionFromImage(
547 KernelName, Image, VarNameToGlobalInfo);
548 }
549
550 //------------------------------------------------------------------
551 // End Methods implemented in the derived device engine class.
552 //------------------------------------------------------------------
553
554 void pruneIR(Module &M);
555
556 void internalize(Module &M, StringRef KernelName);
557
558protected:
560
565
567 StorageCache ObjectCache{"JitEngineDevice"};
568 std::string DeviceArch;
569
571};
572
573template <typename ImplT> void JitEngineDevice<ImplT>::pruneIR(Module &M) {
574 TIMESCOPE("pruneIR");
576}
577
578template <typename ImplT>
581}
582
583template <typename ImplT>
586 JITKernelInfo &KernelInfo, dim3 GridDim, dim3 BlockDim, void **KernelArgs,
588 TIMESCOPE("compileAndRun");
589
590 auto &BinInfo = KernelInfo.getBinaryInfo();
591
592 // Lazy initialize the map of device global variables to device pointers by
593 // resolving the host address to the device address. For HIP it is fine to
594 // do this earlier (e.g., instertRegisterVar), but CUDA can't. So, we
595 // initialize this here the first time we need to compile a kernel.
596 BinInfo.mapGlobals();
597
599 getRuntimeConstantValues(KernelArgs, KernelInfo.getRCInfoArray());
600
603
604 HashT HashValue =
605 hash(getStaticHash(KernelInfo), RCVec, LambdaJitValuesVec, GridDim.x,
606 GridDim.y, GridDim.z, BlockDim.x, BlockDim.y, BlockDim.z);
607
609 CodeCache.lookup(HashValue);
610 if (KernelFunc)
611 return launchKernelFunction(KernelFunc, GridDim, BlockDim, KernelArgs,
613
614 // NOTE: we don't need a suffix to differentiate kernels, each
615 // specialization will be in its own module uniquely identify by HashValue.
616 // It exists only for debugging purposes to verify that the jitted kernel
617 // executes.
618 std::string Suffix = mangleSuffix(HashValue);
619 std::string KernelMangled = (KernelInfo.getName() + Suffix);
620
622 auto CompiledLib = ObjectCache.lookup(HashValue);
623 if (CompiledLib) {
625 relinkGlobalsObject(CompiledLib->ObjectModule->getMemBufferRef(),
626 BinInfo.getVarNameToGlobalInfo());
627
629 KernelMangled, CompiledLib->ObjectModule->getBufferStart(),
630 BinInfo.getVarNameToGlobalInfo());
631
632 CodeCache.insert(HashValue, KernelFunc, KernelInfo.getName());
633
634 return launchKernelFunction(KernelFunc, GridDim, BlockDim, KernelArgs,
636 }
637 }
638
640 std::unique_ptr<MemoryBuffer> ObjBuf = nullptr;
641
642 if (Config::get().ProteusAsyncCompilation) {
643 auto &Compiler = CompilerAsync::instance(Config::get().ProteusAsyncThreads);
644 // If there is no compilation pending for the specialization, post the
645 // compilation task to the compiler.
646 if (!Compiler.isCompilationPending(HashValue)) {
647 PROTEUS_DBG(Logger::logs("proteus") << "Compile async for HashValue "
648 << HashValue.toString() << "\n");
649
651 KernelBitcode, HashValue, KernelInfo.getName(), Suffix, BlockDim,
652 GridDim, RCVec, KernelInfo.getLambdaCalleeInfo(),
653 BinInfo.getVarNameToGlobalInfo(), GlobalLinkedBinaries, DeviceArch,
654 /*CodeGenConfig */ Config::get().getCGConfig(KernelInfo.getName()),
655 /*DumpIR*/ Config::get().ProteusDumpLLVMIR,
656 /*RelinkGlobalsByCopy*/ Config::get().ProteusRelinkGlobalsByCopy});
657 }
658
659 // Compilation is pending, try to get the compilation result buffer. If
660 // buffer is null, compilation is not done, so execute the AOT version
661 // directly.
662 ObjBuf = Compiler.takeCompilationResult(
663 HashValue, Config::get().ProteusAsyncTestBlocking);
664 if (!ObjBuf) {
665 return launchKernelDirect(KernelInfo.getKernel(), GridDim, BlockDim,
667 }
668 } else {
669 // Process through synchronous compilation.
671 KernelBitcode, HashValue, KernelInfo.getName(), Suffix, BlockDim,
672 GridDim, RCVec, KernelInfo.getLambdaCalleeInfo(),
673 BinInfo.getVarNameToGlobalInfo(), GlobalLinkedBinaries, DeviceArch,
674 /*CodeGenConfig */ Config::get().getCGConfig(KernelInfo.getName()),
675 /*DumpIR*/ Config::get().ProteusDumpLLVMIR,
676 /*RelinkGlobalsByCopy*/ Config::get().ProteusRelinkGlobalsByCopy});
677 }
678
679 if (!ObjBuf)
680 PROTEUS_FATAL_ERROR("Expected non-null object");
681
683 KernelMangled, ObjBuf->getBufferStart(),
685 BinInfo.getVarNameToGlobalInfo());
686
687 CodeCache.insert(HashValue, KernelFunc, KernelInfo.getName());
688 if (Config::get().ProteusUseStoredCache) {
689 ObjectCache.store(HashValue, ObjBuf->getMemBufferRef());
690 }
691
692 return launchKernelFunction(KernelFunc, GridDim, BlockDim, KernelArgs,
694}
695
696template <typename ImplT>
699 const char *ModuleId) {
700 CurHandle = Handle;
701 PROTEUS_DBG(Logger::logs("proteus")
702 << "Register fatbinary Handle " << Handle << " FatbinWrapper "
703 << FatbinWrapper << " Binary " << (void *)FatbinWrapper->Binary
704 << " ModuleId " << ModuleId << "\n");
705 if (FatbinWrapper->PrelinkedFatbins) {
706 // This is RDC compilation, just insert the FatbinWrapper and ignore the
707 // ModuleId coming from the link.stub.
708 HandleToBinaryInfo.try_emplace(Handle, FatbinWrapper,
710
711 // Initialize GlobalLinkedBinaries with prelinked fatbins.
712 void *Ptr = FatbinWrapper->PrelinkedFatbins[0];
713 for (int I = 0; Ptr != nullptr;
714 ++I, Ptr = FatbinWrapper->PrelinkedFatbins[I]) {
715 PROTEUS_DBG(Logger::logs("proteus")
716 << "I " << I << " PrelinkedFatbin " << Ptr << "\n");
717 GlobalLinkedBinaries.insert(Ptr);
718 }
719 } else {
720 // This is non-RDC compilation, associate the ModuleId of the JIT bitcode
721 // in the module with the FatbinWrapper.
722 ModuleIdToFatBinary[ModuleId] = FatbinWrapper;
723 HandleToBinaryInfo.try_emplace(Handle, FatbinWrapper,
725 }
726}
727
728template <typename ImplT> void JitEngineDevice<ImplT>::registerFatBinaryEnd() {
729 PROTEUS_DBG(Logger::logs("proteus") << "Register fatbinary end\n");
730 // Erase linked binaries for which we have LLVM IR code, those binaries are
731 // stored in the ModuleIdToFatBinary map.
732 for (auto &[ModuleId, FatbinWrapper] : ModuleIdToFatBinary)
733 GlobalLinkedBinaries.erase((void *)FatbinWrapper->Binary);
734
735 CurHandle = nullptr;
736}
737
738template <typename ImplT>
740 void *Handle, void *Kernel, char *KernelName,
742 PROTEUS_DBG(Logger::logs("proteus") << "Register function " << Kernel
743 << " To Handle " << Handle << "\n");
744 // NOTE: HIP RDC might call multiple times the registerFunction for the same
745 // kernel, which has weak linkage, when it comes from different translation
746 // units. Either the first or the second call can prevail and should be
747 // equivalent. We let the first one prevail.
748 if (JITKernelInfoMap.contains(Kernel)) {
749 PROTEUS_DBG(Logger::logs("proteus")
750 << "Warning: duplicate register function for kernel " +
751 std::string(KernelName)
752 << "\n");
753 return;
754 }
755
756 if (!HandleToBinaryInfo.count(Handle))
757 PROTEUS_FATAL_ERROR("Expected Handle in map");
758 BinaryInfo &BinInfo = HandleToBinaryInfo[Handle];
759
760 PROTEUS_DBG(Logger::logs("proteus")
761 << "Register function " << KernelName << " with binary handle "
762 << Handle << "\n");
763
764 JITKernelInfoMap[Kernel] =
766}
767
768template <typename ImplT>
770 const char *ModuleId) {
771 PROTEUS_DBG(Logger::logs("proteus")
772 << "Register linked binary FatBinary " << FatbinWrapper
773 << " Binary " << (void *)FatbinWrapper->Binary << " ModuleId "
774 << ModuleId << "\n");
775 if (CurHandle) {
776 if (!HandleToBinaryInfo.count(CurHandle))
777 PROTEUS_FATAL_ERROR("Expected CurHandle in map");
778
779 HandleToBinaryInfo[CurHandle].addModuleId(ModuleId);
780 } else
781 GlobalLinkedModuleIds.push_back(ModuleId);
782
783 ModuleIdToFatBinary[ModuleId] = FatbinWrapper;
784}
785
786} // namespace proteus
787
788#endif
void const char * ModuleId
Definition CompilerInterfaceDevice.cpp:33
void * FatbinWrapper
Definition CompilerInterfaceDevice.cpp:32
const void const char * VarName
Definition CompilerInterfaceDevice.cpp:21
void char * KernelName
Definition CompilerInterfaceDevice.cpp:52
void * Kernel
Definition CompilerInterfaceDevice.cpp:52
const void const char uint64_t VarSize
Definition CompilerInterfaceDevice.cpp:22
ArrayRef< RuntimeConstantInfo * > RCInfoArray
Definition CompilerInterfaceHost.cpp:25
#define PROTEUS_DBG(x)
Definition Debug.h:9
#define PROTEUS_FATAL_ERROR(x)
Definition Error.h:7
void getLambdaJitValues(StringRef FnName, SmallVector< RuntimeConstant > &LambdaJitValuesVec)
Definition JitEngineHost.cpp:178
#define TIMESCOPE(x)
Definition TimeTracing.hpp:64
#define PROTEUS_TIMER_OUTPUT(x)
Definition TimeTracing.hpp:57
Definition JitEngineDevice.hpp:79
FatbinWrapperT * getFatbinWrapper() const
Definition JitEngineDevice.hpp:102
void mapGlobals()
Definition JitEngineDevice.hpp:193
void setExtractedModules(SmallVector< std::unique_ptr< Module > > &Modules)
Definition JitEngineDevice.hpp:151
std::unordered_map< std::string, GlobalVarInfo > & getVarNameToGlobalInfo()
Definition JitEngineDevice.hpp:217
MemoryBufferRef getDeviceBinary()
Definition JitEngineDevice.hpp:175
bool hasModuleHash() const
Definition JitEngineDevice.hpp:155
std::unique_ptr< LLVMContext > & getLLVMContext()
Definition JitEngineDevice.hpp:104
Module & getLinkedModule()
Definition JitEngineDevice.hpp:107
auto & getModuleIds()
Definition JitEngineDevice.hpp:221
void registerGlobalVar(const char *VarName, const void *Addr, uint64_t VarSize)
Definition JitEngineDevice.hpp:188
bool hasLinkedModule() const
Definition JitEngineDevice.hpp:106
bool hasDeviceBinary()
Definition JitEngineDevice.hpp:174
const SmallVector< std::reference_wrapper< Module > > getExtractedModules() const
Definition JitEngineDevice.hpp:142
void updateModuleHash(HashT HashValue)
Definition JitEngineDevice.hpp:158
HashT getModuleHash() const
Definition JitEngineDevice.hpp:156
bool hasExtractedModules() const
Definition JitEngineDevice.hpp:140
void addModuleId(const char *ModuleId)
Definition JitEngineDevice.hpp:184
CallGraph & getCallGraph()
Definition JitEngineDevice.hpp:165
BinaryInfo(FatbinWrapperT *FatbinWrapper, SmallVector< std::string > &&LinkedModuleIds)
Definition JitEngineDevice.hpp:95
void setModuleHash(HashT HashValue)
Definition JitEngineDevice.hpp:157
void setDeviceBinary(std::unique_ptr< MemoryBuffer > DeviceBinaryBuffer)
Definition JitEngineDevice.hpp:180
Definition CompilationTask.hpp:18
static CompilerAsync & instance(int NumThreads)
Definition CompilerAsync.hpp:49
void joinAllThreads()
Definition CompilerAsync.hpp:88
std::unique_ptr< MemoryBuffer > compile(CompilationTask &&CT)
Definition CompilerSync.hpp:21
static CompilerSync & instance()
Definition CompilerSync.hpp:16
int ProteusTraceOutput
Definition Config.hpp:313
static Config & get()
Definition Config.hpp:298
bool ProteusRelinkGlobalsByCopy
Definition Config.hpp:307
bool ProteusDumpLLVMIR
Definition Config.hpp:306
bool ProteusUseStoredCache
Definition Config.hpp:304
const CodeGenerationConfig & getCGConfig(llvm::StringRef KName="") const
Definition Config.hpp:317
Definition Func.hpp:252
Definition Hashing.hpp:20
std::string toString() const
Definition Hashing.hpp:28
Definition JitEngineDevice.hpp:224
bool hasBitcode()
Definition JitEngineDevice.hpp:259
const std::string & getName() const
Definition JitEngineDevice.hpp:250
void createStaticHash(HashT ModuleHash)
Definition JitEngineDevice.hpp:267
JITKernelInfo(void *Kernel, BinaryInfo &BinInfo, char const *Name, ArrayRef< RuntimeConstantInfo * > RCInfoArray)
Definition JitEngineDevice.hpp:237
void * getKernel() const
Definition JitEngineDevice.hpp:245
bool hasModule() const
Definition JitEngineDevice.hpp:252
const HashT getStaticHash() const
Definition JitEngineDevice.hpp:266
BinaryInfo & getBinaryInfo() const
Definition JitEngineDevice.hpp:254
ArrayRef< RuntimeConstantInfo * > getRCInfoArray() const
Definition JitEngineDevice.hpp:251
Module & getModule() const
Definition JitEngineDevice.hpp:253
void setModule(std::unique_ptr< llvm::Module > Mod)
Definition JitEngineDevice.hpp:255
void setLambdaCalleeInfo(SmallVector< std::pair< std::string, StringRef > > &&LambdaInfo)
Definition JitEngineDevice.hpp:274
const auto & getLambdaCalleeInfo()
Definition JitEngineDevice.hpp:273
bool hasLambdaCalleeInfo()
Definition JitEngineDevice.hpp:272
MemoryBufferRef getBitcode()
Definition JitEngineDevice.hpp:263
bool hasStaticHash() const
Definition JitEngineDevice.hpp:265
std::unique_ptr< LLVMContext > & getLLVMContext()
Definition JitEngineDevice.hpp:249
void setBitcode(std::unique_ptr< MemoryBuffer > ExtractedBitcode)
Definition JitEngineDevice.hpp:260
Definition JitEngineDevice.hpp:282
MemoryBufferRef getBitcode(JITKernelInfo &KernelInfo)
Definition JitEngineDevice.hpp:407
~JitEngineDevice()
Definition JitEngineDevice.hpp:561
typename DeviceTraits< ImplT >::DeviceError_t DeviceError_t
Definition JitEngineDevice.hpp:284
void extractModuleAndBitcode(JITKernelInfo &KernelInfo)
Definition JitEngineDevice.hpp:368
void registerLinkedBinary(FatbinWrapperT *FatbinWrapper, const char *ModuleId)
Definition JitEngineDevice.hpp:769
JitEngineDevice()
Definition JitEngineDevice.hpp:559
std::unordered_map< const void *, BinaryInfo > HandleToBinaryInfo
Definition JitEngineDevice.hpp:473
DenseMap< const void *, JITKernelInfo > JITKernelInfoMap
Definition JitEngineDevice.hpp:570
typename DeviceTraits< ImplT >::DeviceStream_t DeviceStream_t
Definition JitEngineDevice.hpp:285
void registerFatBinaryEnd()
Definition JitEngineDevice.hpp:728
MemoryCache< KernelFunction_t > CodeCache
Definition JitEngineDevice.hpp:566
Module & getModule(JITKernelInfo &KernelInfo)
Definition JitEngineDevice.hpp:397
bool containsJITKernelInfo(const void *Func)
Definition JitEngineDevice.hpp:477
std::optional< std::reference_wrapper< JITKernelInfo > > getJITKernelInfo(const void *Func)
Definition JitEngineDevice.hpp:482
StorageCache ObjectCache
Definition JitEngineDevice.hpp:567
std::pair< std::unique_ptr< Module >, std::unique_ptr< MemoryBuffer > > extractKernelModule(BinaryInfo &BinInfo, StringRef KernelName, LLVMContext &Ctx)
Definition JitEngineDevice.hpp:294
SmallPtrSet< void *, 8 > GlobalLinkedBinaries
Definition JitEngineDevice.hpp:475
void registerFunction(void *Handle, void *Kernel, char *KernelName, ArrayRef< RuntimeConstantInfo * > RCInfoArray)
Definition JitEngineDevice.hpp:739
void insertRegisterVar(void *Handle, const char *VarName, const void *Addr, uint64_t VarSize)
Definition JitEngineDevice.hpp:454
void finalize()
Definition JitEngineDevice.hpp:506
void * CurHandle
Definition JitEngineDevice.hpp:471
void registerFatBinary(void *Handle, FatbinWrapperT *FatbinWrapper, const char *ModuleId)
Definition JitEngineDevice.hpp:697
DeviceError_t compileAndRun(JITKernelInfo &KernelInfo, dim3 GridDim, dim3 BlockDim, void **KernelArgs, uint64_t ShmemSize, typename DeviceTraits< ImplT >::DeviceStream_t Stream)
Definition JitEngineDevice.hpp:585
std::unordered_map< std::string, FatbinWrapperT * > ModuleIdToFatBinary
Definition JitEngineDevice.hpp:472
typename DeviceTraits< ImplT >::KernelFunction_t KernelFunction_t
Definition JitEngineDevice.hpp:286
SmallVector< std::string > GlobalLinkedModuleIds
Definition JitEngineDevice.hpp:474
StringRef getDeviceArch() const
Definition JitEngineDevice.hpp:512
std::string DeviceArch
Definition JitEngineDevice.hpp:568
HashT getStaticHash(JITKernelInfo &KernelInfo)
Definition JitEngineDevice.hpp:489
void getLambdaJitValues(JITKernelInfo &KernelInfo, SmallVector< RuntimeConstant > &LambdaJitValuesVec)
Definition JitEngineDevice.hpp:417
Definition JitEngine.hpp:34
Definition LambdaRegistry.hpp:20
std::optional< DenseMap< StringRef, SmallVector< RuntimeConstant > >::iterator > matchJitVariableMap(StringRef FnName)
Definition LambdaRegistry.hpp:28
static LambdaRegistry & instance()
Definition LambdaRegistry.hpp:22
static llvm::raw_ostream & outs(const std::string &Name)
Definition Logger.hpp:25
static void trace(llvm::StringRef Msg)
Definition Logger.hpp:30
static llvm::raw_ostream & logs(const std::string &Name)
Definition Logger.hpp:19
Definition MemoryCache.hpp:27
void printStats()
Definition MemoryCache.hpp:61
Definition StorageCache.hpp:30
void printStats()
Definition StorageCache.cpp:86
Definition TimeTracing.hpp:36
Definition Helpers.h:138
Definition StorageCache.cpp:24
std::unique_ptr< Module > cloneKernelFromModules(ArrayRef< std::reference_wrapper< Module > > Mods, StringRef EntryName)
Definition Cloning.h:497
HashT hash(FirstT &&First, RestTs &&...Rest)
Definition Hashing.hpp:126
cudaError_t launchKernelDirect(void *KernelFunc, dim3 GridDim, dim3 BlockDim, void **KernelArgs, uint64_t ShmemSize, CUstream Stream)
Definition CoreDeviceCUDA.hpp:21
T getRuntimeConstantValue(void *Arg)
Definition CompilerInterfaceRuntimeConstantInfo.h:114
cudaError_t launchKernelFunction(CUfunction KernelFunc, dim3 GridDim, dim3 BlockDim, void **KernelArgs, uint64_t ShmemSize, CUstream Stream)
Definition CoreDeviceCUDA.hpp:56
HashT hashCombine(HashT A, HashT B)
Definition Hashing.hpp:121
void pruneIR(Module &M, bool UnsetExternallyInitialized=true)
Definition CoreLLVM.hpp:253
std::string toString(CodegenOption Option)
Definition Config.hpp:26
void * resolveDeviceGlobalAddr(const void *Addr)
Definition CoreDeviceCUDA.hpp:13
void internalize(Module &M, StringRef PreserveFunctionName)
Definition CoreLLVM.hpp:288
CUfunction getKernelFunctionFromImage(StringRef KernelName, const void *Image, bool RelinkGlobalsByCopy, const std::unordered_map< std::string, GlobalVarInfo > &VarNameToGlobalInfo)
Definition CoreDeviceCUDA.hpp:28
void runCleanupPassPipeline(Module &M)
Definition CoreLLVM.hpp:230
std::unique_ptr< Module > linkModules(LLVMContext &Ctx, SmallVector< std::unique_ptr< Module > > LinkedModules)
Definition CoreLLVM.hpp:213
Definition Hashing.hpp:147
Definition JitEngineDevice.hpp:280
Definition JitEngineDevice.hpp:72
const char * Binary
Definition JitEngineDevice.hpp:75
void ** PrelinkedFatbins
Definition JitEngineDevice.hpp:76
int32_t Magic
Definition JitEngineDevice.hpp:73
int32_t Version
Definition JitEngineDevice.hpp:74
Definition GlobalVarInfo.hpp:5