Proteus
Programmable JIT compilation and optimization for C/C++ using LLVM
Loading...
Searching...
No Matches
JitEngineDevice.h
Go to the documentation of this file.
1//===-- JitEngineDevice.h -- Base JIT Engine Device header impl. --===//
2//
3// Part of the Proteus Project, under the Apache License v2.0 with LLVM
4// Exceptions. See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9//===----------------------------------------------------------------------===//
10
11#ifndef PROTEUS_JITENGINEDEVICE_H
12#define PROTEUS_JITENGINEDEVICE_H
13
15#include "proteus/Init.h"
23#include "proteus/impl/Debug.h"
28#include "proteus/impl/Utils.h"
29
30#include <llvm/ADT/SmallPtrSet.h>
31#include <llvm/ADT/SmallVector.h>
32#include <llvm/ADT/StringRef.h>
33#include <llvm/Analysis/CallGraph.h>
34#include <llvm/Analysis/TargetTransformInfo.h>
35#include <llvm/Bitcode/BitcodeWriter.h>
36#include <llvm/CodeGen/CommandFlags.h>
37#include <llvm/CodeGen/MachineModuleInfo.h>
38#include <llvm/Config/llvm-config.h>
39#include <llvm/Demangle/Demangle.h>
40#include <llvm/ExecutionEngine/Orc/ThreadSafeModule.h>
41#include <llvm/IR/Constants.h>
42#include <llvm/IR/GlobalVariable.h>
43#include <llvm/IR/Instruction.h>
44#include <llvm/IR/Instructions.h>
45#include <llvm/IR/LLVMContext.h>
46#include <llvm/IR/LegacyPassManager.h>
47#include <llvm/IR/Module.h>
48#include <llvm/IR/ReplaceConstant.h>
49#include <llvm/IR/Type.h>
50#include <llvm/IR/Verifier.h>
51#include <llvm/IRReader/IRReader.h>
52#include <llvm/Linker/Linker.h>
53#include <llvm/MC/TargetRegistry.h>
54#include <llvm/Object/ELFObjectFile.h>
55#include <llvm/Passes/PassBuilder.h>
56#include <llvm/Support/Error.h>
57#include <llvm/Support/MemoryBuffer.h>
58#include <llvm/Support/MemoryBufferRef.h>
59#include <llvm/Target/TargetMachine.h>
60#include <llvm/Transforms/IPO/Internalize.h>
61#include <llvm/Transforms/Utils/Cloning.h>
62#include <llvm/Transforms/Utils/ModuleUtils.h>
63
64#include <cstdint>
65#include <functional>
66#include <memory>
67#include <optional>
68#include <string>
69
70namespace proteus {
71
72using namespace llvm;
73
75 int32_t Magic;
76 int32_t Version;
77 const char *Binary;
79};
80
82private:
83 FatbinWrapperT *FatbinWrapper;
84 std::unique_ptr<LLVMContext> Ctx;
85 SmallVector<std::string> LinkedModuleIds;
86 Module *LinkedModule;
87 std::optional<SmallVector<std::unique_ptr<Module>>> ExtractedModules;
88 std::optional<HashT> ExtractedModuleHash;
89 std::optional<CallGraph> ModuleCallGraph;
90 std::unique_ptr<MemoryBuffer> DeviceBinary;
91 std::unordered_map<std::string, GlobalVarInfo> VarNameToGlobalInfo;
92 std::once_flag Flag;
93
94public:
95 BinaryInfo() = default;
96 BinaryInfo(FatbinWrapperT *FatbinWrapper,
97 SmallVector<std::string> &&LinkedModuleIds)
98 : FatbinWrapper(FatbinWrapper), Ctx(std::make_unique<LLVMContext>()),
99 LinkedModuleIds(LinkedModuleIds), LinkedModule(nullptr),
100 ExtractedModules(std::nullopt), ModuleCallGraph(std::nullopt),
101 DeviceBinary(nullptr) {}
102
104
105 std::unique_ptr<LLVMContext> &getLLVMContext() { return Ctx; }
106
107 bool hasLinkedModule() const { return (LinkedModule != nullptr); }
108 Module &getLinkedModule() {
109 if (!LinkedModule) {
110 if (!hasExtractedModules())
111 reportFatalError("Expected extracted modules");
112
113 Timer T;
114 // Avoid linking when there's a single module by moving it instead and
115 // making sure it's materialized for call graph analysis.
116 if (ExtractedModules->size() == 1) {
117 LinkedModule = ExtractedModules->front().get();
118 if (auto E = LinkedModule->materializeAll())
119 reportFatalError("Error materializing " + toString(std::move(E)));
120 } else {
121 // proteus::linkModules takes ownership of the module pointers in
122 // ExtractedModules and returns a new unique ptr to the linked module.
123 // We update ExtractedModules to contain and own only the generated
124 // LinkedModule.
125 auto GeneratedLinkedModule =
126 proteus::linkModules(*Ctx, std::move(ExtractedModules.value()));
127 SmallVector<std::unique_ptr<Module>> NewExtractedModules;
128 NewExtractedModules.emplace_back(std::move(GeneratedLinkedModule));
129 setExtractedModules(NewExtractedModules);
130
131 LinkedModule = ExtractedModules->front().get();
132 }
133
135 << "getLinkedModule " << T.elapsed() << " ms\n");
136 }
137
138 return *LinkedModule;
139 }
140
/// Whether extracted modules are present (ExtractedModules holds a value).
141 bool hasExtractedModules() const { return ExtractedModules.has_value(); }
142 const SmallVector<std::reference_wrapper<Module>>
144 // This should be called only once when cloning the kernel module to
145 // cache.
146 SmallVector<std::reference_wrapper<Module>> ModulesRef;
147 for (auto &M : ExtractedModules.value())
148 ModulesRef.emplace_back(*M);
149
150 return ModulesRef;
151 }
152 void setExtractedModules(SmallVector<std::unique_ptr<Module>> &Modules) {
153 ExtractedModules = std::move(Modules);
154 }
155
/// Whether a module hash has been set via setModuleHash()/updateModuleHash().
156 bool hasModuleHash() const { return ExtractedModuleHash.has_value(); }
158 if (!hasModuleHash())
159 reportFatalError("Expected module hash to be set");
160
161 return ExtractedModuleHash.value();
162 }
163 void setModuleHash(HashT HashValue) { ExtractedModuleHash = HashValue; }
164 void updateModuleHash(HashT HashValue) {
165 if (ExtractedModuleHash)
166 ExtractedModuleHash = hashCombine(ExtractedModuleHash.value(), HashValue);
167 else
168 ExtractedModuleHash = HashValue;
169 }
170
171 CallGraph &getCallGraph() {
172 if (!ModuleCallGraph.has_value()) {
173 if (!LinkedModule)
174 reportFatalError("Expected non-null linked module");
175 ModuleCallGraph.emplace(CallGraph(*LinkedModule));
176 }
177 return ModuleCallGraph.value();
178 }
179
180 bool hasDeviceBinary() { return (DeviceBinary != nullptr); }
181 MemoryBufferRef getDeviceBinary() {
182 if (!hasDeviceBinary())
183 reportFatalError("Expected non-null device binary");
184 return DeviceBinary->getMemBufferRef();
185 }
186 void setDeviceBinary(std::unique_ptr<MemoryBuffer> DeviceBinaryBuffer) {
187 DeviceBinary = std::move(DeviceBinaryBuffer);
188 }
189
190 void addModuleId(const char *ModuleId) {
191 LinkedModuleIds.push_back(ModuleId);
192 }
193
194 void insertGlobalVar(const char *VarName, const void *HostAddr,
195 const void *DeviceAddr, uint64_t VarSize) {
196 auto KV = VarNameToGlobalInfo.emplace(
197 VarName, GlobalVarInfo(HostAddr, DeviceAddr, VarSize));
198
199 auto TraceOut = [&KV]() {
200 auto GlobalName = KV.first->first;
201 auto &GVI = KV.first->second;
202
203 SmallString<128> S;
204 raw_svector_ostream OS(S);
205 OS << "[GVarInfo]: " << GlobalName << " HAddr:" << GVI.HostAddr
206 << " DevAddr:" << GVI.DevAddr << " VarSize:" << GVI.VarSize << "\n";
207
208 return S;
209 };
210
211 if (Config::get().traceSpecializations())
212 Logger::trace(TraceOut());
213 }
214
215 std::unordered_map<std::string, GlobalVarInfo> &getVarNameToGlobalInfo() {
216 return VarNameToGlobalInfo;
217 }
218
219 auto &getModuleIds() { return LinkedModuleIds; }
220};
221
// Kernel handle given at registration. registerFunction uses it as the
// JITKernelInfoMap key, and compileAndRun launches it directly when no JIT
// result is available yet.
223 std::optional<void *> Kernel;
// Per-kernel LLVM context; extractModuleAndBitcode() parses the kernel
// module into it.
224 std::unique_ptr<LLVMContext> Ctx;
// Kernel name as registered (demangled only for debug output).
225 std::string Name;
// Non-owning view of the runtime-constant descriptors; the referenced array
// must outlive this object.
226 ArrayRef<RuntimeConstantInfo *> RCInfoArray;
// Kernel module and its bitcode; extractModuleAndBitcode() sets both together.
227 std::optional<std::unique_ptr<Module>> ExtractedModule;
228 std::optional<std::unique_ptr<MemoryBuffer>> Bitcode;
// Binary the kernel was registered with; empty when default-constructed.
229 std::optional<std::reference_wrapper<BinaryInfo>> BinInfo;
// Cached hash(Name) combined with the binary's module hash
// (createStaticHash()).
230 std::optional<HashT> StaticHash;
// (function name, lambda type) pairs matched against LambdaRegistry. The
// StringRef side comes from the registry's map keys.
// NOTE(review): confirm the registry outlives this object.
231 std::optional<SmallVector<std::pair<std::string, StringRef>>>
232 LambdaCalleeInfo;
233
234public:
// Creates kernel info with a fresh per-kernel LLVM context; the module,
// bitcode and lambda callee info start empty.
235 JITKernelInfo(void *Kernel, BinaryInfo &BinInfo, char const *Name,
236 ArrayRef<RuntimeConstantInfo *> RCInfoArray)
237 : Kernel(Kernel), Ctx(std::make_unique<LLVMContext>()), Name(Name),
238 RCInfoArray(RCInfoArray), ExtractedModule(std::nullopt),
239 Bitcode{std::nullopt}, BinInfo(BinInfo),
240 LambdaCalleeInfo(std::nullopt) {}
241
242 JITKernelInfo() = default;
243 void *getKernel() const {
244 assert(Kernel.has_value() && "Expected Kernel is inited");
245 return Kernel.value();
246 }
247 std::unique_ptr<LLVMContext> &getLLVMContext() { return Ctx; }
248 const std::string &getName() const { return Name; }
249 ArrayRef<RuntimeConstantInfo *> getRCInfoArray() const { return RCInfoArray; }
250 bool hasModule() const { return ExtractedModule.has_value(); }
251 Module &getModule() const { return *ExtractedModule->get(); }
252 BinaryInfo &getBinaryInfo() const { return BinInfo.value(); }
253 void setModule(std::unique_ptr<llvm::Module> Mod) {
254 ExtractedModule = std::move(Mod);
255 }
256
257 bool hasBitcode() { return Bitcode.has_value(); }
258 void setBitcode(std::unique_ptr<MemoryBuffer> ExtractedBitcode) {
259 Bitcode = std::move(ExtractedBitcode);
260 }
261 MemoryBufferRef getBitcode() { return Bitcode.value()->getMemBufferRef(); }
262
263 bool hasStaticHash() const { return StaticHash.has_value(); }
264 const HashT getStaticHash() const { return StaticHash.value(); }
265 void createStaticHash(HashT ModuleHash) {
266 StaticHash = hash(Name);
267 StaticHash = hashCombine(StaticHash.value(), ModuleHash);
268 }
269
270 bool hasLambdaCalleeInfo() { return LambdaCalleeInfo.has_value(); }
271 const auto &getLambdaCalleeInfo() { return LambdaCalleeInfo.value(); }
273 SmallVector<std::pair<std::string, StringRef>> &&LambdaInfo) {
274 LambdaCalleeInfo = std::move(LambdaInfo);
275 }
276};
277
278template <typename ImplT> struct DeviceTraits;
279
280template <typename ImplT> class JitEngineDevice : public JitEngine {
281public:
285
287 compileAndRun(JITKernelInfo &KernelInfo, dim3 GridDim, dim3 BlockDim,
288 void **KernelArgs, uint64_t ShmemSize,
290
291 std::pair<std::unique_ptr<Module>, std::unique_ptr<MemoryBuffer>>
293 LLVMContext &Ctx) {
294 std::unique_ptr<Module> KernelModule =
295 static_cast<ImplT &>(*this).tryExtractKernelModule(BinInfo, KernelName,
296 Ctx);
297 std::unique_ptr<MemoryBuffer> Bitcode = nullptr;
298
299 // If there is no ready-made kernel module from AOT, extract per-TU or the
300 // single linked module and clone the kernel module.
301 if (!KernelModule) {
302 Timer T;
303 if (!BinInfo.hasExtractedModules())
304 static_cast<ImplT &>(*this).extractModules(BinInfo);
305
306 std::unique_ptr<Module> KernelModuleTmp = nullptr;
307 switch (Config::get().ProteusKernelClone) {
309 auto &LinkedModule = BinInfo.getLinkedModule();
310 KernelModule = llvm::CloneModule(LinkedModule);
311 break;
312 }
314 auto &LinkedModule = BinInfo.getLinkedModule();
315 KernelModule =
317 break;
318 }
320 KernelModule = proteus::cloneKernelFromModules(
322 break;
323 }
324 default:
325 reportFatalError("Unsupported kernel cloning option");
326 }
327
329 << "Cloning "
330 << toString(Config::get().ProteusKernelClone) << " "
331 << T.elapsed() << " ms\n");
332 }
333
334 // Internalize and cleanup to simplify the module and prepare it for
335 // optimization.
336 internalize(*KernelModule, KernelName);
337 proteus::runCleanupPassPipeline(*KernelModule);
338
339 // If the module is not in the provided context due to cloning, roundtrip
340 // it using bitcode. Re-use the roundtrip bitcode to return it.
341 if (&KernelModule->getContext() != &Ctx) {
342 SmallVector<char> CloneBuffer;
343 raw_svector_ostream OS(CloneBuffer);
344 WriteBitcodeToFile(*KernelModule, OS);
345 StringRef CloneStr = StringRef(CloneBuffer.data(), CloneBuffer.size());
346 auto ExpectedKernelModule =
347 parseBitcodeFile(MemoryBufferRef{CloneStr, KernelName}, Ctx);
348 if (auto E = ExpectedKernelModule.takeError())
349 reportFatalError("Error parsing bitcode: " + toString(std::move(E)));
350
351 KernelModule = std::move(*ExpectedKernelModule);
352 Bitcode = MemoryBuffer::getMemBufferCopy(CloneStr);
353 } else {
354 // Serialize the kernel module to create the bitcode, since it has not been
355 // created by roundtripping.
356 SmallVector<char> BitcodeBuffer;
357 raw_svector_ostream OS(BitcodeBuffer);
358 WriteBitcodeToFile(*KernelModule, OS);
359 auto BitcodeStr = StringRef{BitcodeBuffer.data(), BitcodeBuffer.size()};
360 Bitcode = MemoryBuffer::getMemBufferCopy(BitcodeStr);
361 }
362
363 return std::make_pair(std::move(KernelModule), std::move(Bitcode));
364 }
365
367 TIMESCOPE(__FUNCTION__)
368
369 if (KernelInfo.hasModule() && KernelInfo.hasBitcode())
370 return;
371
372 if (KernelInfo.hasModule())
373 reportFatalError("Unexpected KernelInfo has module but not bitcode");
374
375 if (KernelInfo.hasBitcode())
376 reportFatalError("Unexpected KernelInfo has bitcode but not module");
377
378 BinaryInfo &BinInfo = KernelInfo.getBinaryInfo();
379
380 Timer T;
381 auto [KernelModule, BitcodeBuffer] = extractKernelModule(
382 BinInfo, KernelInfo.getName(), *KernelInfo.getLLVMContext());
383
384 if (!KernelModule)
385 reportFatalError("Expected non-null kernel module");
386 if (!BitcodeBuffer)
387 reportFatalError("Expected non-null kernel bitcode");
388
389 KernelInfo.setModule(std::move(KernelModule));
390 KernelInfo.setBitcode(std::move(BitcodeBuffer));
392 << "Extract kernel module " << T.elapsed() << " ms\n");
393 }
394
395 Module &getModule(JITKernelInfo &KernelInfo) {
396 if (!KernelInfo.hasModule())
397 extractModuleAndBitcode(KernelInfo);
398
399 if (!KernelInfo.hasModule())
400 reportFatalError("Expected module in KernelInfo");
401
402 return KernelInfo.getModule();
403 }
404
405 MemoryBufferRef getBitcode(JITKernelInfo &KernelInfo) {
406 if (!KernelInfo.hasBitcode())
407 extractModuleAndBitcode(KernelInfo);
408
409 if (!KernelInfo.hasBitcode())
410 reportFatalError("Expected bitcode in KernelInfo");
411
412 return KernelInfo.getBitcode();
413 }
414
416 SmallVector<RuntimeConstant> &LambdaJitValuesVec) {
418 if (LR.empty()) {
419 KernelInfo.setLambdaCalleeInfo({});
420 return;
421 }
422
423 if (!KernelInfo.hasLambdaCalleeInfo()) {
424 Module &KernelModule = getModule(KernelInfo);
425 PROTEUS_DBG(Logger::logs("proteus")
426 << "=== LAMBDA MATCHING\n"
427 << "Caller trigger " << KernelInfo.getName() << " -> "
428 << demangle(KernelInfo.getName()) << "\n");
429
430 SmallVector<std::pair<std::string, StringRef>> LambdaCalleeInfo;
431 for (auto &F : KernelModule.getFunctionList()) {
432 PROTEUS_DBG(Logger::logs("proteus")
433 << " Trying F " << demangle(F.getName().str()) << "\n ");
434 auto OptionalMapIt =
436 if (OptionalMapIt)
437 LambdaCalleeInfo.emplace_back(F.getName(),
438 OptionalMapIt.value()->first);
439 }
440
441 KernelInfo.setLambdaCalleeInfo(std::move(LambdaCalleeInfo));
442 }
443
444 for (auto &[FnName, LambdaType] : KernelInfo.getLambdaCalleeInfo()) {
445 const SmallVector<RuntimeConstant> &Values =
446 LR.getJitVariables(LambdaType);
447 LambdaJitValuesVec.insert(LambdaJitValuesVec.end(), Values.begin(),
448 Values.end());
449 }
450 }
451
452 void registerVar(void *Handle, const char *VarName, const void *HostAddr,
453 uint64_t VarSize) {
454 if (!HandleToBinaryInfo.count(Handle))
455 reportFatalError("Expected Handle in map");
456 BinaryInfo &BinInfo = HandleToBinaryInfo[Handle];
457
458 void *DeviceAddr = resolveDeviceGlobalAddr(HostAddr);
459 assert(DeviceAddr &&
460 "Expected non-null device address for global variable");
461
462 BinInfo.insertGlobalVar(VarName, HostAddr, DeviceAddr, VarSize);
463 }
464
466 const char *ModuleId);
468 const char *ModuleId);
470 void registerFunction(void *Handle, void *Kernel, char *KernelName,
471 ArrayRef<RuntimeConstantInfo *> RCInfoArray);
472
473 std::unordered_map<std::string, FatbinWrapperT *> ModuleIdToFatBinary;
474 std::unordered_map<const void *, BinaryInfo> HandleToBinaryInfo;
475 SmallPtrSet<void *, 8> GlobalLinkedBinaries;
476
477 bool containsJITKernelInfo(const void *Func) {
478 return JITKernelInfoMap.contains(Func);
479 }
480
481 std::optional<std::reference_wrapper<JITKernelInfo>>
482 getJITKernelInfo(const void *Func) {
484 return std::nullopt;
485 }
486 return JITKernelInfoMap[Func];
487 }
488
490 if (KernelInfo.hasStaticHash())
491 return KernelInfo.getStaticHash();
492
493 BinaryInfo &BinInfo = KernelInfo.getBinaryInfo();
494
495 if (BinInfo.hasModuleHash()) {
496 KernelInfo.createStaticHash(BinInfo.getModuleHash());
497 return KernelInfo.getStaticHash();
498 }
499
500 HashT ModuleHash = static_cast<ImplT &>(*this).getModuleHash(BinInfo);
501
502 KernelInfo.createStaticHash(BinInfo.getModuleHash());
503 return KernelInfo.getStaticHash();
504 }
505
506public:
507 StringRef getDeviceArch() const { return DeviceArch; }
508
509protected:
512
513 for (auto &[Handle, FatbinInfo] : JitEngineInfo.FatbinaryMap) {
515 Handle, reinterpret_cast<FatbinWrapperT *>(FatbinInfo.FatbinWrapper),
516 FatbinInfo.ModuleId);
517
518 for (auto &LinkedBin : FatbinInfo.LinkedBinaries)
520 Handle, reinterpret_cast<FatbinWrapperT *>(LinkedBin.FatbinWrapper),
521 LinkedBin.ModuleId);
522
523 for (auto &Func : FatbinInfo.Functions)
524 registerFunction(Handle, Func.Kernel, Func.KernelName,
525 Func.RCInfoArray);
526
527 for (auto &Var : FatbinInfo.Vars)
528 registerVar(Var.Handle, Var.VarName, Var.HostAddr, Var.VarSize);
529 }
530
532
533 if (Config::get().ProteusUseStoredCache)
534 CacheChain.emplace("JitEngineDevice");
535
536 if (Config::get().ProteusAsyncCompilation)
538 std::make_unique<CompilerAsync>(Config::get().ProteusAsyncThreads);
539 }
540
542 // Thread joining is handled by CompilerAsync's shutdown guard to ensure it
543 // happens before static objects are destroyed. If this destructor does run,
544 // joinAllThreads() is idempotent.
545 if (AsyncCompiler)
546 AsyncCompiler->joinAllThreads();
547
548 if (Config::get().traceCacheStats())
551 if (Config::get().traceCacheStats() && CacheChain)
552 CacheChain->printStats();
553 }
554
556 std::optional<ObjectCacheChain> CacheChain;
557 std::string DeviceArch;
558
559 DenseMap<const void *, JITKernelInfo> JITKernelInfoMap;
560 std::unique_ptr<CompilerAsync> AsyncCompiler;
561};
562
563template <typename ImplT>
566 JITKernelInfo &KernelInfo, dim3 GridDim, dim3 BlockDim, void **KernelArgs,
567 uint64_t ShmemSize, typename DeviceTraits<ImplT>::DeviceStream_t Stream) {
568 TIMESCOPE("compileAndRun");
569
570 auto &BinInfo = KernelInfo.getBinaryInfo();
571
572 SmallVector<RuntimeConstant> RCVec =
573 getRuntimeConstantValues(KernelArgs, KernelInfo.getRCInfoArray());
574
575 SmallVector<RuntimeConstant> LambdaJitValuesVec;
576 getLambdaJitValues(KernelInfo, LambdaJitValuesVec);
577
578 HashT HashValue =
579 hash(getStaticHash(KernelInfo), RCVec, LambdaJitValuesVec, GridDim.x,
580 GridDim.y, GridDim.z, BlockDim.x, BlockDim.y, BlockDim.z);
581
582 typename DeviceTraits<ImplT>::KernelFunction_t KernelFunc =
583 CodeCache.lookup(HashValue);
584 if (KernelFunc)
585 return launchKernelFunction(KernelFunc, GridDim, BlockDim, KernelArgs,
586 ShmemSize, Stream);
587
588 // NOTE: we don't need a suffix to differentiate kernels, each
589 // specialization will be in its own module uniquely identified by HashValue.
590 // It exists only for debugging purposes to verify that the jitted kernel
591 // executes.
592 std::string Suffix = HashValue.toMangledSuffix();
593 std::string KernelMangled = (KernelInfo.getName() + Suffix);
594
595 if (CacheChain) {
596 auto CompiledLib = CacheChain->lookup(HashValue);
597 if (CompiledLib) {
599 relinkGlobalsObject(CompiledLib->ObjectModule->getMemBufferRef(),
600 BinInfo.getVarNameToGlobalInfo());
601
602 auto KernelFunc = proteus::getKernelFunctionFromImage(
603 KernelMangled, CompiledLib->ObjectModule->getBufferStart(),
605 BinInfo.getVarNameToGlobalInfo());
606
607 CodeCache.insert(HashValue, KernelFunc, KernelInfo.getName());
608
609 return launchKernelFunction(KernelFunc, GridDim, BlockDim, KernelArgs,
610 ShmemSize, Stream);
611 }
612 }
613
614 MemoryBufferRef KernelBitcode = getBitcode(KernelInfo);
615 std::unique_ptr<MemoryBuffer> ObjBuf = nullptr;
616
617 if (Config::get().ProteusAsyncCompilation) {
618 // If there is no compilation pending for the specialization, post the
619 // compilation task to the compiler.
620 if (!AsyncCompiler->isCompilationPending(HashValue)) {
621 PROTEUS_DBG(Logger::logs("proteus") << "Compile async for HashValue "
622 << HashValue.toString() << "\n");
623
624 AsyncCompiler->compile(CompilationTask{
625 KernelBitcode, HashValue, KernelInfo.getName(), Suffix, BlockDim,
626 GridDim, RCVec, KernelInfo.getLambdaCalleeInfo(),
627 BinInfo.getVarNameToGlobalInfo(), GlobalLinkedBinaries, DeviceArch,
628 /*CodeGenConfig */ Config::get().getCGConfig(KernelInfo.getName()),
629 /*DumpIR*/ Config::get().ProteusDumpLLVMIR,
630 /*RelinkGlobalsByCopy*/ Config::get().ProteusRelinkGlobalsByCopy});
631 }
632
633 // Compilation is pending, try to get the compilation result buffer. If
634 // buffer is null, compilation is not done, so execute the AOT version
635 // directly.
636 ObjBuf = AsyncCompiler->takeCompilationResult(
637 HashValue, Config::get().ProteusAsyncTestBlocking);
638 if (!ObjBuf) {
639 return launchKernelDirect(KernelInfo.getKernel(), GridDim, BlockDim,
640 KernelArgs, ShmemSize, Stream);
641 }
642 } else {
643 // Process through synchronous compilation.
645 KernelBitcode, HashValue, KernelInfo.getName(), Suffix, BlockDim,
646 GridDim, RCVec, KernelInfo.getLambdaCalleeInfo(),
647 BinInfo.getVarNameToGlobalInfo(), GlobalLinkedBinaries, DeviceArch,
648 /*CodeGenConfig */ Config::get().getCGConfig(KernelInfo.getName()),
649 /*DumpIR*/ Config::get().ProteusDumpLLVMIR,
650 /*RelinkGlobalsByCopy*/ Config::get().ProteusRelinkGlobalsByCopy});
651 }
652
653 if (!ObjBuf)
654 reportFatalError("Expected non-null object");
655
657 KernelMangled, ObjBuf->getBufferStart(),
659 BinInfo.getVarNameToGlobalInfo());
660
661 CodeCache.insert(HashValue, KernelFunc, KernelInfo.getName());
662 if (CacheChain)
663 CacheChain->store(HashValue,
664 CacheEntry::staticObject(ObjBuf->getMemBufferRef()));
665
666 return launchKernelFunction(KernelFunc, GridDim, BlockDim, KernelArgs,
667 ShmemSize, Stream);
668}
669
670template <typename ImplT>
673 const char *ModuleId) {
674 PROTEUS_DBG(Logger::logs("proteus")
675 << "Register fatbinary Handle " << Handle << " FatbinWrapper "
676 << FatbinWrapper << " Binary " << (void *)FatbinWrapper->Binary
677 << " ModuleId " << ModuleId << "\n");
678 if (FatbinWrapper->PrelinkedFatbins) {
679 // This is RDC compilation, just insert the FatbinWrapper and ignore the
680 // ModuleId coming from the link.stub.
681 HandleToBinaryInfo.try_emplace(Handle, FatbinWrapper,
682 SmallVector<std::string>{});
683
684 // Initialize GlobalLinkedBinaries with prelinked fatbins.
685 void *Ptr = FatbinWrapper->PrelinkedFatbins[0];
686 for (int I = 0; Ptr != nullptr;
687 ++I, Ptr = FatbinWrapper->PrelinkedFatbins[I]) {
688 PROTEUS_DBG(Logger::logs("proteus")
689 << "I " << I << " PrelinkedFatbin " << Ptr << "\n");
690 GlobalLinkedBinaries.insert(Ptr);
691 }
692 } else {
693 // This is non-RDC compilation, associate the ModuleId of the JIT bitcode
694 // in the module with the FatbinWrapper.
695 ModuleIdToFatBinary[ModuleId] = FatbinWrapper;
696 HandleToBinaryInfo.try_emplace(Handle, FatbinWrapper,
697 SmallVector<std::string>{ModuleId});
698 }
699}
700
701template <typename ImplT> void JitEngineDevice<ImplT>::finalizeRegistration() {
702 PROTEUS_DBG(Logger::logs("proteus") << "Finalize registration\n");
703 // Erase linked binaries for which we have LLVM IR code, those binaries are
704 // stored in the ModuleIdToFatBinary map.
705 for (auto &[ModuleId, FatbinWrapper] : ModuleIdToFatBinary)
706 GlobalLinkedBinaries.erase((void *)FatbinWrapper->Binary);
707}
708
709template <typename ImplT>
711 void *Handle, void *Kernel, char *KernelName,
712 ArrayRef<RuntimeConstantInfo *> RCInfoArray) {
713 PROTEUS_DBG(Logger::logs("proteus") << "Register function " << Kernel
714 << " To Handle " << Handle << "\n");
715 // NOTE: HIP RDC might call registerFunction multiple times for the same
716 // kernel, which has weak linkage, when it comes from different translation
717 // units. Either the first or the second call can prevail and should be
718 // equivalent. We let the first one prevail.
719 if (JITKernelInfoMap.contains(Kernel)) {
720 PROTEUS_DBG(Logger::logs("proteus")
721 << "Warning: duplicate register function for kernel " +
722 std::string(KernelName)
723 << "\n");
724 return;
725 }
726
727 if (!HandleToBinaryInfo.count(Handle))
728 reportFatalError("Expected Handle in map");
729 BinaryInfo &BinInfo = HandleToBinaryInfo[Handle];
730
731 PROTEUS_DBG(Logger::logs("proteus")
732 << "Register function " << KernelName << " with binary handle "
733 << Handle << "\n");
734
735 JITKernelInfoMap[Kernel] =
737}
738
739template <typename ImplT>
742 const char *ModuleId) {
743 PROTEUS_DBG(Logger::logs("proteus")
744 << "Register linked binary FatBinary " << FatbinWrapper
745 << " Binary " << (void *)FatbinWrapper->Binary << " ModuleId "
746 << ModuleId << "\n");
747 if (!HandleToBinaryInfo.count(Handle))
748 reportFatalError("Expected Handle in map");
749
750 HandleToBinaryInfo[Handle].addModuleId(ModuleId);
751 ModuleIdToFatBinary[ModuleId] = FatbinWrapper;
752}
753
754} // namespace proteus
755
756#endif
void const char * ModuleId
Definition CompilerInterfaceDevice.cpp:36
void * FatbinWrapper
Definition CompilerInterfaceDevice.cpp:35
const void const char * VarName
Definition CompilerInterfaceDevice.cpp:24
auto & JitEngineInfo
Definition CompilerInterfaceDevice.cpp:58
void char * KernelName
Definition CompilerInterfaceDevice.cpp:54
void * Kernel
Definition CompilerInterfaceDevice.cpp:54
const void const char uint64_t VarSize
Definition CompilerInterfaceDevice.cpp:25
const void * HostAddr
Definition CompilerInterfaceDevice.cpp:23
ArrayRef< RuntimeConstantInfo * > RCInfoArray
Definition CompilerInterfaceHost.cpp:26
#define PROTEUS_DBG(x)
Definition Debug.h:9
void getLambdaJitValues(StringRef FnName, SmallVector< RuntimeConstant > &LambdaJitValuesVec)
Definition JitEngineHost.cpp:178
#define TIMESCOPE(x)
Definition TimeTracing.h:59
#define PROTEUS_TIMER_OUTPUT(x)
Definition TimeTracing.h:54
Definition JitEngineDevice.h:81
FatbinWrapperT * getFatbinWrapper() const
Definition JitEngineDevice.h:103
void setExtractedModules(SmallVector< std::unique_ptr< Module > > &Modules)
Definition JitEngineDevice.h:152
std::unordered_map< std::string, GlobalVarInfo > & getVarNameToGlobalInfo()
Definition JitEngineDevice.h:215
MemoryBufferRef getDeviceBinary()
Definition JitEngineDevice.h:181
bool hasModuleHash() const
Definition JitEngineDevice.h:156
std::unique_ptr< LLVMContext > & getLLVMContext()
Definition JitEngineDevice.h:105
Module & getLinkedModule()
Definition JitEngineDevice.h:108
auto & getModuleIds()
Definition JitEngineDevice.h:219
bool hasLinkedModule() const
Definition JitEngineDevice.h:107
bool hasDeviceBinary()
Definition JitEngineDevice.h:180
const SmallVector< std::reference_wrapper< Module > > getExtractedModules() const
Definition JitEngineDevice.h:143
void updateModuleHash(HashT HashValue)
Definition JitEngineDevice.h:164
HashT getModuleHash() const
Definition JitEngineDevice.h:157
void insertGlobalVar(const char *VarName, const void *HostAddr, const void *DeviceAddr, uint64_t VarSize)
Definition JitEngineDevice.h:194
bool hasExtractedModules() const
Definition JitEngineDevice.h:141
void addModuleId(const char *ModuleId)
Definition JitEngineDevice.h:190
CallGraph & getCallGraph()
Definition JitEngineDevice.h:171
BinaryInfo(FatbinWrapperT *FatbinWrapper, SmallVector< std::string > &&LinkedModuleIds)
Definition JitEngineDevice.h:96
void setModuleHash(HashT HashValue)
Definition JitEngineDevice.h:163
void setDeviceBinary(std::unique_ptr< MemoryBuffer > DeviceBinaryBuffer)
Definition JitEngineDevice.h:186
Definition CompilationTask.h:19
std::unique_ptr< MemoryBuffer > compile(CompilationTask &&CT)
Definition CompilerSync.h:21
static CompilerSync & instance()
Definition CompilerSync.h:16
static Config & get()
Definition Config.h:334
bool ProteusRelinkGlobalsByCopy
Definition Config.h:343
bool ProteusDumpLLVMIR
Definition Config.h:342
const CodeGenerationConfig & getCGConfig(llvm::StringRef KName="") const
Definition Config.h:357
Definition Func.h:290
Definition Hashing.h:21
std::string toString() const
Definition Hashing.h:29
std::string toMangledSuffix() const
Definition Hashing.h:32
Definition JitEngineDevice.h:222
bool hasBitcode()
Definition JitEngineDevice.h:257
const std::string & getName() const
Definition JitEngineDevice.h:248
void createStaticHash(HashT ModuleHash)
Definition JitEngineDevice.h:265
JITKernelInfo(void *Kernel, BinaryInfo &BinInfo, char const *Name, ArrayRef< RuntimeConstantInfo * > RCInfoArray)
Definition JitEngineDevice.h:235
void * getKernel() const
Definition JitEngineDevice.h:243
bool hasModule() const
Definition JitEngineDevice.h:250
const HashT getStaticHash() const
Definition JitEngineDevice.h:264
BinaryInfo & getBinaryInfo() const
Definition JitEngineDevice.h:252
ArrayRef< RuntimeConstantInfo * > getRCInfoArray() const
Definition JitEngineDevice.h:249
Module & getModule() const
Definition JitEngineDevice.h:251
void setModule(std::unique_ptr< llvm::Module > Mod)
Definition JitEngineDevice.h:253
void setLambdaCalleeInfo(SmallVector< std::pair< std::string, StringRef > > &&LambdaInfo)
Definition JitEngineDevice.h:272
const auto & getLambdaCalleeInfo()
Definition JitEngineDevice.h:271
bool hasLambdaCalleeInfo()
Definition JitEngineDevice.h:270
MemoryBufferRef getBitcode()
Definition JitEngineDevice.h:261
bool hasStaticHash() const
Definition JitEngineDevice.h:263
std::unique_ptr< LLVMContext > & getLLVMContext()
Definition JitEngineDevice.h:247
void setBitcode(std::unique_ptr< MemoryBuffer > ExtractedBitcode)
Definition JitEngineDevice.h:258
Definition JitEngineDevice.h:280
MemoryBufferRef getBitcode(JITKernelInfo &KernelInfo)
Definition JitEngineDevice.h:405
~JitEngineDevice()
Definition JitEngineDevice.h:541
typename DeviceTraits< ImplT >::DeviceError_t DeviceError_t
Definition JitEngineDevice.h:282
void extractModuleAndBitcode(JITKernelInfo &KernelInfo)
Definition JitEngineDevice.h:366
JitEngineDevice()
Definition JitEngineDevice.h:510
std::unique_ptr< CompilerAsync > AsyncCompiler
Definition JitEngineDevice.h:560
std::unordered_map< const void *, BinaryInfo > HandleToBinaryInfo
Definition JitEngineDevice.h:474
DenseMap< const void *, JITKernelInfo > JITKernelInfoMap
Definition JitEngineDevice.h:559
typename DeviceTraits< ImplT >::DeviceStream_t DeviceStream_t
Definition JitEngineDevice.h:283
void finalizeRegistration()
Definition JitEngineDevice.h:701
MemoryCache< KernelFunction_t > CodeCache
Definition JitEngineDevice.h:555
Module & getModule(JITKernelInfo &KernelInfo)
Definition JitEngineDevice.h:395
bool containsJITKernelInfo(const void *Func)
Definition JitEngineDevice.h:477
std::optional< std::reference_wrapper< JITKernelInfo > > getJITKernelInfo(const void *Func)
Definition JitEngineDevice.h:482
std::pair< std::unique_ptr< Module >, std::unique_ptr< MemoryBuffer > > extractKernelModule(BinaryInfo &BinInfo, StringRef KernelName, LLVMContext &Ctx)
Definition JitEngineDevice.h:292
SmallPtrSet< void *, 8 > GlobalLinkedBinaries
Definition JitEngineDevice.h:475
void registerFunction(void *Handle, void *Kernel, char *KernelName, ArrayRef< RuntimeConstantInfo * > RCInfoArray)
Definition JitEngineDevice.h:710
std::optional< ObjectCacheChain > CacheChain
Definition JitEngineDevice.h:556
void registerFatBinary(void *Handle, FatbinWrapperT *FatbinWrapper, const char *ModuleId)
Definition JitEngineDevice.h:671
DeviceError_t compileAndRun(JITKernelInfo &KernelInfo, dim3 GridDim, dim3 BlockDim, void **KernelArgs, uint64_t ShmemSize, typename DeviceTraits< ImplT >::DeviceStream_t Stream)
Definition JitEngineDevice.h:565
void registerLinkedBinary(void *Handle, FatbinWrapperT *FatbinWrapper, const char *ModuleId)
Definition JitEngineDevice.h:740
std::unordered_map< std::string, FatbinWrapperT * > ModuleIdToFatBinary
Definition JitEngineDevice.h:473
typename DeviceTraits< ImplT >::KernelFunction_t KernelFunction_t
Definition JitEngineDevice.h:284
StringRef getDeviceArch() const
Definition JitEngineDevice.h:507
void registerVar(void *Handle, const char *VarName, const void *HostAddr, uint64_t VarSize)
Definition JitEngineDevice.h:452
std::string DeviceArch
Definition JitEngineDevice.h:557
HashT getStaticHash(JITKernelInfo &KernelInfo)
Definition JitEngineDevice.h:489
void getLambdaJitValues(JITKernelInfo &KernelInfo, SmallVector< RuntimeConstant > &LambdaJitValuesVec)
Definition JitEngineDevice.h:415
static JitEngineInfoRegistry & instance()
Definition JitEngineInfoRegistry.h:58
Definition JitEngine.h:32
Definition LambdaRegistry.h:19
std::optional< DenseMap< StringRef, SmallVector< RuntimeConstant > >::iterator > matchJitVariableMap(StringRef FnName)
Definition LambdaRegistry.h:27
const SmallVector< RuntimeConstant > & getJitVariables(StringRef LambdaTypeRef)
Definition LambdaRegistry.h:93
static LambdaRegistry & instance()
Definition LambdaRegistry.h:21
bool empty()
Definition LambdaRegistry.h:97
static llvm::raw_ostream & outs(const std::string &Name)
Definition Logger.h:25
static void trace(llvm::StringRef Msg)
Definition Logger.h:30
static llvm::raw_ostream & logs(const std::string &Name)
Definition Logger.h:19
Definition MemoryCache.h:30
void printStats()
Definition MemoryCache.h:61
void printKernelTrace()
Definition MemoryCache.h:74
Definition TimeTracing.h:40
uint64_t elapsed()
Definition TimeTracing.cpp:51
Definition CompiledLibrary.h:7
Definition MemoryCache.h:26
std::unique_ptr< Module > cloneKernelFromModules(ArrayRef< std::reference_wrapper< Module > > Mods, StringRef EntryName, function_ref< bool(const GlobalValue *)> ShouldCloneDefinition=nullptr)
Definition Cloning.h:513
HashT hash(FirstT &&First, RestTs &&...Rest)
Definition Hashing.h:142
void reportFatalError(const llvm::Twine &Reason, const char *FILE, unsigned Line)
Definition Error.cpp:14
cudaError_t launchKernelDirect(void *KernelFunc, dim3 GridDim, dim3 BlockDim, void **KernelArgs, uint64_t ShmemSize, CUstream Stream)
Definition CoreDeviceCUDA.h:38
cudaError_t launchKernelFunction(CUfunction KernelFunc, dim3 GridDim, dim3 BlockDim, void **KernelArgs, uint64_t ShmemSize, CUstream Stream)
Definition CoreDeviceCUDA.h:78
HashT hashCombine(HashT A, HashT B)
Definition Hashing.h:137
std::string toString(CodegenOption Option)
Definition Config.h:28
void * resolveDeviceGlobalAddr(const void *Addr)
Definition CoreDeviceCUDA.h:26
void internalize(Module &M, StringRef PreserveFunctionName)
Definition CoreLLVM.h:288
CUfunction getKernelFunctionFromImage(StringRef KernelName, const void *Image, bool RelinkGlobalsByCopy, const std::unordered_map< std::string, GlobalVarInfo > &VarNameToGlobalInfo)
Definition CoreDeviceCUDA.h:50
void runCleanupPassPipeline(Module &M)
Definition CoreLLVM.h:230
std::unique_ptr< Module > linkModules(LLVMContext &Ctx, SmallVector< std::unique_ptr< Module > > LinkedModules)
Definition CoreLLVM.h:213
Definition Hashing.h:158
static CacheEntry staticObject(MemoryBufferRef Buf)
Definition ObjectCache.h:31
Definition JitEngineDevice.h:278
Definition JitEngineDevice.h:74
const char * Binary
Definition JitEngineDevice.h:77
void ** PrelinkedFatbins
Definition JitEngineDevice.h:78
int32_t Magic
Definition JitEngineDevice.h:75
int32_t Version
Definition JitEngineDevice.h:76
Definition GlobalVarInfo.h:5
Definition Var.h:16