Proteus
Programmable JIT compilation and optimization for C/C++ using LLVM
Loading...
Searching...
No Matches
JitEngineDevice.h
Go to the documentation of this file.
1//===-- JitEngineDevice.h -- Base JIT Engine Device header impl. ----===//
2//
3// Part of the Proteus Project, under the Apache License v2.0 with LLVM
4// Exceptions. See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9//===----------------------------------------------------------------------===//
10
11#ifndef PROTEUS_JITENGINEDEVICE_H
12#define PROTEUS_JITENGINEDEVICE_H
13
15#include "proteus/Init.h"
16#include "proteus/TimeTracing.h"
24#include "proteus/impl/Debug.h"
28#include "proteus/impl/Utils.h"
29
30#include <llvm/ADT/SmallPtrSet.h>
31#include <llvm/ADT/SmallVector.h>
32#include <llvm/ADT/StringRef.h>
33#include <llvm/Analysis/CallGraph.h>
34#include <llvm/Analysis/TargetTransformInfo.h>
35#include <llvm/Bitcode/BitcodeWriter.h>
36#include <llvm/CodeGen/CommandFlags.h>
37#include <llvm/CodeGen/MachineModuleInfo.h>
38#include <llvm/Config/llvm-config.h>
39#include <llvm/Demangle/Demangle.h>
40#include <llvm/ExecutionEngine/Orc/ThreadSafeModule.h>
41#include <llvm/IR/Constants.h>
42#include <llvm/IR/GlobalVariable.h>
43#include <llvm/IR/Instruction.h>
44#include <llvm/IR/Instructions.h>
45#include <llvm/IR/LLVMContext.h>
46#include <llvm/IR/LegacyPassManager.h>
47#include <llvm/IR/Module.h>
48#include <llvm/IR/ReplaceConstant.h>
49#include <llvm/IR/Type.h>
50#include <llvm/IR/Verifier.h>
51#include <llvm/IRReader/IRReader.h>
52#include <llvm/Linker/Linker.h>
53#include <llvm/MC/TargetRegistry.h>
54#include <llvm/Object/ELFObjectFile.h>
55#include <llvm/Passes/PassBuilder.h>
56#include <llvm/Support/Error.h>
57#include <llvm/Support/MemoryBuffer.h>
58#include <llvm/Support/MemoryBufferRef.h>
59#include <llvm/Target/TargetMachine.h>
60#include <llvm/Transforms/IPO/Internalize.h>
61#include <llvm/Transforms/Utils/Cloning.h>
62#include <llvm/Transforms/Utils/ModuleUtils.h>
63
64#include <cstdint>
65#include <functional>
66#include <memory>
67#include <optional>
68#include <string>
69
70namespace proteus {
71
72using namespace llvm;
73
// NOTE(review): this rendering dropped the opening line of this struct
// (original line 74) and one member (original line 78). The struct is
// presumably `FatbinWrapperT`: registerFatBinary() reads
// `FatbinWrapper->Binary` and `FatbinWrapper->PrelinkedFatbins` -- confirm.
// Header fields; no code visible in this file checks their values.
75 int32_t Magic;
76 int32_t Version;
// Presumably points at the device fat binary image. finalizeRegistration()
// erases this pointer value from GlobalLinkedBinaries.
77 const char *Binary;
79};
80
// NOTE(review): the `class BinaryInfo {` line (original line 81) is missing
// from this rendering.
82private:
// Wrapper registered for this binary handle (set by the constructor).
83 FatbinWrapperT *FatbinWrapper;
// Per-binary LLVM context; getLinkedModule() links extracted modules in it.
84 std::unique_ptr<LLVMContext> Ctx;
// Module ids associated with this binary (constructor and addModuleId()).
85 SmallVector<std::string> LinkedModuleIds;
// Non-owning pointer to the linked module. It aliases an element of
// ExtractedModules, which keeps ownership (see getLinkedModule()).
86 Module *LinkedModule;
// Modules extracted from the binary; disengaged until setExtractedModules().
87 std::optional<SmallVector<std::unique_ptr<Module>>> ExtractedModules;
// Hash over the extracted modules (setModuleHash()/updateModuleHash()).
88 std::optional<HashT> ExtractedModuleHash;
// Built lazily from LinkedModule by getCallGraph().
89 std::optional<CallGraph> ModuleCallGraph;
// Device binary buffer stored by setDeviceBinary().
90 std::unique_ptr<MemoryBuffer> DeviceBinary;
// Global variable name -> host/device address and size (insertGlobalVar()).
91 std::unordered_map<std::string, GlobalVarInfo> VarNameToGlobalInfo;
// NOTE(review): not referenced by any code visible in this file.
92 std::once_flag Flag;
93
94public:
95 BinaryInfo() = default;
96 BinaryInfo(FatbinWrapperT *FatbinWrapper,
97 SmallVector<std::string> &&LinkedModuleIds)
98 : FatbinWrapper(FatbinWrapper), Ctx(std::make_unique<LLVMContext>()),
99 LinkedModuleIds(LinkedModuleIds), LinkedModule(nullptr),
100 ExtractedModules(std::nullopt), ModuleCallGraph(std::nullopt),
101 DeviceBinary(nullptr) {}
102
104
105 std::unique_ptr<LLVMContext> &getLLVMContext() { return Ctx; }
106
107 bool hasLinkedModule() const { return (LinkedModule != nullptr); }
// Returns the module from which kernels are cloned, creating it on first use.
// - With a single extracted module, that module is materialized and used
//   directly. LinkedModule only aliases ExtractedModules->front(), which
//   keeps ownership; nothing is actually moved despite the comment below.
// - With several extracted modules, they are linked by proteus::linkModules
//   into this BinaryInfo's Ctx. ExtractedModules is then replaced by a
//   one-element vector owning the linked result.
// Fatal error if no modules were extracted or materialization fails.
// NOTE(review): original lines 109 and 135 are missing from this rendering;
// line 136 is the tail of a timing-output statement whose head was on 135.
108 Module &getLinkedModule() {
110 if (!LinkedModule) {
111 if (!hasExtractedModules())
112 reportFatalError("Expected extracted modules");
113
114 Timer T(Config::get().ProteusEnableTimers);
115 // Avoid linking when there's a single module by moving it instead and
116 // making sure it's materialized for call graph analysis.
117 if (ExtractedModules->size() == 1) {
118 LinkedModule = ExtractedModules->front().get();
119 if (auto E = LinkedModule->materializeAll())
120 reportFatalError("Error materializing " + toString(std::move(E)));
121 } else {
122 // By the LLVM API, linkModules takes ownership of module pointers in
123 // ExtractedModules and returns a new unique ptr to the linked module.
124 // We update ExtractedModules to contain and own only the generated
125 // LinkedModule.
126 auto GeneratedLinkedModule =
127 proteus::linkModules(*Ctx, std::move(ExtractedModules.value()));
128 SmallVector<std::unique_ptr<Module>> NewExtractedModules;
129 NewExtractedModules.emplace_back(std::move(GeneratedLinkedModule));
130 setExtractedModules(NewExtractedModules);
131
132 LinkedModule = ExtractedModules->front().get();
133 }
134
136 << "getLinkedModule " << T.elapsed() << " ms\n");
137 }
138
139 return *LinkedModule;
140 }
141
142 bool hasExtractedModules() const { return ExtractedModules.has_value(); }
143 const SmallVector<std::reference_wrapper<Module>>
145 // This should be called only once when cloning the kernel module to
146 // cache.
147 SmallVector<std::reference_wrapper<Module>> ModulesRef;
148 for (auto &M : ExtractedModules.value())
149 ModulesRef.emplace_back(*M);
150
151 return ModulesRef;
152 }
153 void setExtractedModules(SmallVector<std::unique_ptr<Module>> &Modules) {
154 ExtractedModules = std::move(Modules);
155 }
156
157 bool hasModuleHash() const { return ExtractedModuleHash.has_value(); }
159 if (!hasModuleHash())
160 reportFatalError("Expected module hash to be set");
161
162 return ExtractedModuleHash.value();
163 }
164 void setModuleHash(HashT HashValue) { ExtractedModuleHash = HashValue; }
165 void updateModuleHash(HashT HashValue) {
166 if (ExtractedModuleHash)
167 ExtractedModuleHash = hashCombine(ExtractedModuleHash.value(), HashValue);
168 else
169 ExtractedModuleHash = HashValue;
170 }
171
172 CallGraph &getCallGraph() {
173 if (!ModuleCallGraph.has_value()) {
174 if (!LinkedModule)
175 reportFatalError("Expected non-null linked module");
176 ModuleCallGraph.emplace(CallGraph(*LinkedModule));
177 }
178 return ModuleCallGraph.value();
179 }
180
181 bool hasDeviceBinary() { return (DeviceBinary != nullptr); }
182 MemoryBufferRef getDeviceBinary() {
183 if (!hasDeviceBinary())
184 reportFatalError("Expected non-null device binary");
185 return DeviceBinary->getMemBufferRef();
186 }
187 void setDeviceBinary(std::unique_ptr<MemoryBuffer> DeviceBinaryBuffer) {
188 DeviceBinary = std::move(DeviceBinaryBuffer);
189 }
190
191 void addModuleId(const char *ModuleId) {
192 LinkedModuleIds.push_back(ModuleId);
193 }
194
195 void insertGlobalVar(const char *VarName, const void *HostAddr,
196 const void *DeviceAddr, uint64_t VarSize) {
197 auto KV = VarNameToGlobalInfo.emplace(
198 VarName, GlobalVarInfo(HostAddr, DeviceAddr, VarSize));
199
200 auto TraceOut = [&KV]() {
201 auto GlobalName = KV.first->first;
202 auto &GVI = KV.first->second;
203
204 SmallString<128> S;
205 raw_svector_ostream OS(S);
206 OS << "[GVarInfo]: " << GlobalName << " HAddr:" << GVI.HostAddr
207 << " DevAddr:" << GVI.DevAddr << " VarSize:" << GVI.VarSize << "\n";
208
209 return S;
210 };
211
212 if (Config::get().traceSpecializations())
213 Logger::trace(TraceOut());
214 }
215
216 std::unordered_map<std::string, GlobalVarInfo> &getVarNameToGlobalInfo() {
217 return VarNameToGlobalInfo;
218 }
219
220 auto &getModuleIds() { return LinkedModuleIds; }
221};
222
// NOTE(review): the class header (original line 223, presumably
// `class JITKernelInfo {`) is missing from this rendering.
// Per-kernel registration record. It names the kernel and its binary, and
// lazily holds the extracted module, bitcode and cached hashes.
// Host-side kernel handle (set by the constructor, read by getKernel()).
224 std::optional<void *> Kernel;
// LLVMContext owned by this kernel. extractModuleAndBitcode() extracts the
// kernel module into it.
225 std::unique_ptr<LLVMContext> Ctx;
// Kernel name. compileAndRun() appends the hash suffix to it to name the
// specialized kernel symbol.
226 std::string Name;
// Non-owning view of the runtime-constant descriptors; the referenced array
// must outlive this object.
227 ArrayRef<RuntimeConstantInfo *> RCInfoArray;
// Kernel module, stored by setModule().
228 std::optional<std::unique_ptr<Module>> ExtractedModule;
// Kernel bitcode, stored by setBitcode().
229 std::optional<std::unique_ptr<MemoryBuffer>> Bitcode;
// Binary this kernel belongs to (non-owning).
230 std::optional<std::reference_wrapper<BinaryInfo>> BinInfo;
// hashCombine(hash(Name), module hash), set by createStaticHash().
231 std::optional<HashT> StaticHash;
// (function name, lambda type) pairs found in the kernel module. The
// StringRef is non-owning; presumably it points at LambdaRegistry-owned
// keys -- confirm its lifetime.
232 std::optional<SmallVector<std::pair<std::string, StringRef>>>
233 LambdaCalleeInfo;
234
235public:
236 JITKernelInfo(void *Kernel, BinaryInfo &BinInfo, char const *Name,
237 ArrayRef<RuntimeConstantInfo *> RCInfoArray)
238 : Kernel(Kernel), Ctx(std::make_unique<LLVMContext>()), Name(Name),
239 RCInfoArray(RCInfoArray), ExtractedModule(std::nullopt),
240 Bitcode{std::nullopt}, BinInfo(BinInfo),
241 LambdaCalleeInfo(std::nullopt) {}
242
243 JITKernelInfo() = default;
244 void *getKernel() const {
245 assert(Kernel.has_value() && "Expected Kernel is inited");
246 return Kernel.value();
247 }
248 std::unique_ptr<LLVMContext> &getLLVMContext() { return Ctx; }
249 const std::string &getName() const { return Name; }
250 ArrayRef<RuntimeConstantInfo *> getRCInfoArray() const { return RCInfoArray; }
251 bool hasModule() const { return ExtractedModule.has_value(); }
252 Module &getModule() const { return *ExtractedModule->get(); }
253 BinaryInfo &getBinaryInfo() const { return BinInfo.value(); }
254 void setModule(std::unique_ptr<llvm::Module> Mod) {
255 ExtractedModule = std::move(Mod);
256 }
257
258 bool hasBitcode() { return Bitcode.has_value(); }
259 void setBitcode(std::unique_ptr<MemoryBuffer> ExtractedBitcode) {
260 Bitcode = std::move(ExtractedBitcode);
261 }
262 MemoryBufferRef getBitcode() { return Bitcode.value()->getMemBufferRef(); }
263
264 bool hasStaticHash() const { return StaticHash.has_value(); }
265 const HashT getStaticHash() const { return StaticHash.value(); }
266 void createStaticHash(HashT ModuleHash) {
267 StaticHash = hash(Name);
268 StaticHash = hashCombine(StaticHash.value(), ModuleHash);
269 }
270
271 bool hasLambdaCalleeInfo() { return LambdaCalleeInfo.has_value(); }
272 const auto &getLambdaCalleeInfo() { return LambdaCalleeInfo.value(); }
274 SmallVector<std::pair<std::string, StringRef>> &&LambdaInfo) {
275 LambdaCalleeInfo = std::move(LambdaInfo);
276 }
277};
278
// Per-backend traits. The code below uses DeviceTraits<ImplT>::DeviceStream_t
// and ::KernelFunction_t, and the index also lists DeviceError_t.
279template <typename ImplT> struct DeviceTraits;
280
// CRTP-style base: backend hooks (tryExtractKernelModule, extractModules,
// getModuleHash, ...) are reached through static_cast<ImplT &>(*this), so
// ImplT must derive from JitEngineDevice<ImplT>.
281template <typename ImplT> class JitEngineDevice : public JitEngine {
282public:
// NOTE(review): original lines 283-285 are missing from this rendering; per
// the index they declare the DeviceError_t, DeviceStream_t and
// KernelFunction_t aliases. The compileAndRun declaration below is also
// missing its return-type line (287) and Stream parameter line (290); it is
// defined out of line further down.
286
288 compileAndRun(JITKernelInfo &KernelInfo, dim3 GridDim, dim3 BlockDim,
289 void **KernelArgs, uint64_t ShmemSize,
291
// Produces the kernel module for KernelName inside context Ctx, plus its
// serialized bitcode:
// 1. Ask the backend (ImplT::tryExtractKernelModule) for a ready-made module.
// 2. Otherwise extract the binary's modules if needed and clone the kernel
//    according to Config::get().ProteusKernelClone.
// 3. Internalize for KernelName and run proteus::runCleanupPassPipeline.
// 4. Serialize to bitcode. If the module lives in a context other than Ctx
//    (e.g. after cloning), re-parse that bitcode into Ctx.
// Returns {module, bitcode}; the final if/else always assigns the bitcode.
// NOTE(review): original lines 293, 295, 310, 315, 318, 321, 323 and 330 are
// missing from this rendering. Per the index, 293 continues the signature
// `extractKernelModule(BinaryInfo &BinInfo, StringRef KernelName,`. Lines
// 310/315/321 presumably hold the `case` labels of the switch below --
// confirm against the real source.
292 std::pair<std::unique_ptr<Module>, std::unique_ptr<MemoryBuffer>>
294 LLVMContext &Ctx) {
296 std::unique_ptr<Module> KernelModule =
297 static_cast<ImplT &>(*this).tryExtractKernelModule(BinInfo, KernelName,
298 Ctx);
299 std::unique_ptr<MemoryBuffer> Bitcode = nullptr;
300
301 // If there is no ready-made kernel module from AOT, extract per-TU or the
302 // single linked module and clone the kernel module.
303 if (!KernelModule) {
304 Timer T(Config::get().ProteusEnableTimers);
305 if (!BinInfo.hasExtractedModules())
306 static_cast<ImplT &>(*this).extractModules(BinInfo);
307
// NOTE(review): KernelModuleTmp is not used by any visible line.
308 std::unique_ptr<Module> KernelModuleTmp = nullptr;
309 switch (Config::get().ProteusKernelClone) {
311 auto &LinkedModule = BinInfo.getLinkedModule();
312 KernelModule = llvm::CloneModule(LinkedModule);
313 break;
314 }
316 auto &LinkedModule = BinInfo.getLinkedModule();
317 KernelModule =
319 break;
320 }
322 KernelModule = proteus::cloneKernelFromModules(
324 break;
325 }
326 default:
327 reportFatalError("Unsupported kernel cloning option");
328 }
329
331 << "Cloning "
332 << toString(Config::get().ProteusKernelClone) << " "
333 << T.elapsed() << " ms\n");
334 }
335
336 // Internalize and cleanup to simplify the module and prepare it for
337 // optimization.
338 internalize(*KernelModule, KernelName);
339 proteus::runCleanupPassPipeline(*KernelModule);
340
341 // If the module is not in the provided context due to cloning, roundtrip
342 // it using bitcode. Re-use the roundtrip bitcode to return it.
343 if (&KernelModule->getContext() != &Ctx) {
344 SmallVector<char> CloneBuffer;
345 raw_svector_ostream OS(CloneBuffer);
346 WriteBitcodeToFile(*KernelModule, OS);
347 StringRef CloneStr = StringRef(CloneBuffer.data(), CloneBuffer.size());
348 auto ExpectedKernelModule =
349 parseBitcodeFile(MemoryBufferRef{CloneStr, KernelName}, Ctx);
350 if (auto E = ExpectedKernelModule.takeError())
351 reportFatalError("Error parsing bitcode: " + toString(std::move(E)));
352
353 KernelModule = std::move(*ExpectedKernelModule);
// CloneBuffer is a local, so the returned buffer must be a copy.
354 Bitcode = MemoryBuffer::getMemBufferCopy(CloneStr);
355 } else {
// NOTE(review): this branch serializes (writes) the module to bitcode; it
// does not parse anything, despite the wording of the comment below.
356 // Parse the kernel module to create the bitcode since it has not been
357 // created by roundtripping.
358 SmallVector<char> BitcodeBuffer;
359 raw_svector_ostream OS(BitcodeBuffer);
360 WriteBitcodeToFile(*KernelModule, OS);
361 auto BitcodeStr = StringRef{BitcodeBuffer.data(), BitcodeBuffer.size()};
362 Bitcode = MemoryBuffer::getMemBufferCopy(BitcodeStr);
363 }
364
365 return std::make_pair(std::move(KernelModule), std::move(Bitcode));
366 }
367
// NOTE(review): original lines 368-369 are missing from this rendering. Per
// the index, line 368 is the signature
// `void extractModuleAndBitcode(JITKernelInfo &KernelInfo)`. Line 393 (the
// head of the timing-output statement ending on 394) is also missing.
//
// Ensures KernelInfo holds both its kernel module and its bitcode:
// - returns early when both are present;
// - treats having only one of them as a fatal inconsistency;
// - otherwise extracts both via extractKernelModule() into the kernel's own
//   LLVMContext.
370
371 if (KernelInfo.hasModule() && KernelInfo.hasBitcode())
372 return;
373
374 if (KernelInfo.hasModule())
375 reportFatalError("Unexpected KernelInfo has module but not bitcode");
376
377 if (KernelInfo.hasBitcode())
378 reportFatalError("Unexpected KernelInfo has bitcode but not module");
379
380 BinaryInfo &BinInfo = KernelInfo.getBinaryInfo();
381
382 Timer T(Config::get().ProteusEnableTimers);
383 auto [KernelModule, BitcodeBuffer] = extractKernelModule(
384 BinInfo, KernelInfo.getName(), *KernelInfo.getLLVMContext());
385
386 if (!KernelModule)
387 reportFatalError("Expected non-null kernel module");
388 if (!BitcodeBuffer)
389 reportFatalError("Expected non-null kernel bitcode");
390
391 KernelInfo.setModule(std::move(KernelModule));
392 KernelInfo.setBitcode(std::move(BitcodeBuffer));
394 << "Extract kernel module " << T.elapsed() << " ms\n");
395 }
396
397 Module &getModule(JITKernelInfo &KernelInfo) {
398 if (!KernelInfo.hasModule())
399 extractModuleAndBitcode(KernelInfo);
400
401 if (!KernelInfo.hasModule())
402 reportFatalError("Expected module in KernelInfo");
403
404 return KernelInfo.getModule();
405 }
406
407 MemoryBufferRef getBitcode(JITKernelInfo &KernelInfo) {
408 if (!KernelInfo.hasBitcode())
409 extractModuleAndBitcode(KernelInfo);
410
411 if (!KernelInfo.hasBitcode())
412 reportFatalError("Expected bitcode in KernelInfo");
413
414 return KernelInfo.getBitcode();
415 }
416
// NOTE(review): original lines 417 and 419-420 are missing. Per the index,
// line 417 begins the signature
// `void getLambdaJitValues(JITKernelInfo &KernelInfo, ...`. `LR` is used
// below with LambdaRegistry methods (empty, getJitVariables), so 419-420
// presumably bind it to LambdaRegistry::instance() -- confirm. Line 438 (the
// initializer of OptionalMapIt, presumably LR.matchJitVariableMap(...)) is
// missing too.
//
// Appends to LambdaJitValuesVec the runtime-constant values registered for
// every lambda type that KernelInfo's module calls into. The
// (function name, lambda type) pairs are computed once, by scanning the
// kernel module, and cached in KernelInfo.
418 SmallVector<RuntimeConstant> &LambdaJitValuesVec) {
// With an empty registry, store an empty callee list so that
// KernelInfo.getLambdaCalleeInfo() is valid when compileAndRun() reads it.
// NOTE(review): that empty list then counts as cached. If lambdas are
// registered only after this kernel's first launch, the scan below is
// skipped for this kernel. Confirm registration always precedes launch.
421 if (LR.empty()) {
422 KernelInfo.setLambdaCalleeInfo({});
423 return;
424 }
425
426 if (!KernelInfo.hasLambdaCalleeInfo()) {
427 Module &KernelModule = getModule(KernelInfo);
428 PROTEUS_DBG(Logger::logs("proteus")
429 << "=== LAMBDA MATCHING\n"
430 << "Caller trigger " << KernelInfo.getName() << " -> "
431 << demangle(KernelInfo.getName()) << "\n");
432
433 SmallVector<std::pair<std::string, StringRef>> LambdaCalleeInfo;
434 for (auto &F : KernelModule.getFunctionList()) {
435 PROTEUS_DBG(Logger::logs("proteus")
436 << " Trying F " << demangle(F.getName().str()) << "\n ");
437 auto OptionalMapIt =
439 if (OptionalMapIt)
440 LambdaCalleeInfo.emplace_back(F.getName(),
441 OptionalMapIt.value()->first);
442 }
443
444 KernelInfo.setLambdaCalleeInfo(std::move(LambdaCalleeInfo));
445 }
446
447 for (auto &[FnName, LambdaType] : KernelInfo.getLambdaCalleeInfo()) {
448 const SmallVector<RuntimeConstant> &Values =
449 LR.getJitVariables(LambdaType);
450 LambdaJitValuesVec.insert(LambdaJitValuesVec.end(), Values.begin(),
451 Values.end());
452 }
453 }
454
455 void registerVar(void *Handle, const char *VarName, const void *HostAddr,
456 uint64_t VarSize) {
457 if (!HandleToBinaryInfo.count(Handle))
458 reportFatalError("Expected Handle in map");
459 BinaryInfo &BinInfo = HandleToBinaryInfo[Handle];
460
461 void *DeviceAddr = resolveDeviceGlobalAddr(HostAddr);
462 assert(DeviceAddr &&
463 "Expected non-null device address for global variable");
464
465 BinInfo.insertGlobalVar(VarName, HostAddr, DeviceAddr, VarSize);
466 }
467
// NOTE(review): original lines 468, 470 and 472 are missing. Per the index,
// 468 and 470 begin the declarations
// registerFatBinary(void *Handle, FatbinWrapperT *FatbinWrapper, ...) and
// registerLinkedBinary(void *Handle, FatbinWrapperT *FatbinWrapper, ...).
// Line 472 is presumably `void finalizeRegistration();` -- confirm.
469 const char *ModuleId);
471 const char *ModuleId);
473 void registerFunction(void *Handle, void *Kernel, char *KernelName,
474 ArrayRef<RuntimeConstantInfo *> RCInfoArray);
475
// Module id -> fat binary wrapper. Filled by registerFatBinary() (non-RDC)
// and by registerLinkedBinary().
476 std::unordered_map<std::string, FatbinWrapperT *> ModuleIdToFatBinary;
// Per-handle binary bookkeeping. std::unordered_map is node-based, so
// references to its BinaryInfo values (presumably held by JITKernelInfo)
// stay valid when it grows.
477 std::unordered_map<const void *, BinaryInfo> HandleToBinaryInfo;
// Prelinked fat binaries of RDC builds, collected by registerFatBinary().
// finalizeRegistration() removes those matching a wrapper in
// ModuleIdToFatBinary. The set is passed to every CompilationTask.
478 SmallPtrSet<void *, 8> GlobalLinkedBinaries;
479
480 bool containsJITKernelInfo(const void *Func) {
481 return JITKernelInfoMap.contains(Func);
482 }
483
// Returns the JITKernelInfo registered for kernel handle Func. It returns
// std::nullopt on the early-exit path, presumably taken when Func is
// unknown; the guarding condition (original line 486) is missing from this
// rendering.
// NOTE(review): the result refers into a DenseMap, whose insertions may
// rehash and move entries. Do not hold it across registerFunction() calls.
484 std::optional<std::reference_wrapper<JITKernelInfo>>
485 getJITKernelInfo(const void *Func) {
487 return std::nullopt;
488 }
// operator[] default-inserts for unknown keys; the (missing) check above is
// presumably what rules that out.
489 return JITKernelInfoMap[Func];
490 }
491
493 if (KernelInfo.hasStaticHash())
494 return KernelInfo.getStaticHash();
495
496 BinaryInfo &BinInfo = KernelInfo.getBinaryInfo();
497
498 if (BinInfo.hasModuleHash()) {
499 KernelInfo.createStaticHash(BinInfo.getModuleHash());
500 return KernelInfo.getStaticHash();
501 }
502
503 HashT ModuleHash = static_cast<ImplT &>(*this).getModuleHash(BinInfo);
504
505 KernelInfo.createStaticHash(BinInfo.getModuleHash());
506 return KernelInfo.getStaticHash();
507 }
508
509public:
510 StringRef getDeviceArch() const { return DeviceArch; }
511
512protected:
// NOTE(review): the constructor's opening lines (513-515) are missing from
// this rendering, as are 518, 523, 535 and 541. Judging by the argument
// lists below, 518 and 523 presumably open calls to registerFatBinary( and
// registerLinkedBinary(. Line 535 presumably calls finalizeRegistration(),
// and 541 presumably begins `AsyncCompiler =`. JitEngineInfo is declared on
// a missing line (presumably JitEngineInfoRegistry::instance()). Confirm all
// of these against the real source.
//
// Replays every registration recorded in JitEngineInfo.FatbinaryMap, in this
// order: the fat binary, its linked binaries, its kernels, then its global
// variables.
516
517 for (auto &[Handle, FatbinInfo] : JitEngineInfo.FatbinaryMap) {
519 Handle, reinterpret_cast<FatbinWrapperT *>(FatbinInfo.FatbinWrapper),
520 FatbinInfo.ModuleId);
521
522 for (auto &LinkedBin : FatbinInfo.LinkedBinaries)
524 Handle, reinterpret_cast<FatbinWrapperT *>(LinkedBin.FatbinWrapper),
525 LinkedBin.ModuleId);
526
527 for (auto &Func : FatbinInfo.Functions)
528 registerFunction(Handle, Func.Kernel, Func.KernelName,
529 Func.RCInfoArray);
530
// NOTE(review): registerVar() receives Var.Handle rather than the loop's
// Handle; presumably they are equal -- confirm.
531 for (auto &Var : FatbinInfo.Vars)
532 registerVar(Var.Handle, Var.VarName, Var.HostAddr, Var.VarSize);
533 }
534
536
// The persistent object cache chain exists only with ProteusUseStoredCache.
537 if (Config::get().ProteusUseStoredCache)
538 CacheChain.emplace("JitEngineDevice");
539
// Asynchronous mode uses a CompilerAsync with ProteusAsyncThreads threads.
540 if (Config::get().ProteusAsyncCompilation)
542 std::make_unique<CompilerAsync>(Config::get().ProteusAsyncThreads);
543 }
544
// NOTE(review): the destructor's signature line (545; `~JitEngineDevice()`
// per the index) and lines 553-554 are missing from this rendering. As
// shown, the `if` on 552 appears to guard the `if` on 555. The missing
// lines are presumably the real body of 552 (the index lists
// MemoryCache::printStats/printKernelTrace) -- confirm.
546 // Thread joining is handled by CompilerAsync's shutdown guard to ensure it
547 // happens before static objects are destroyed. If this destructor does run,
548 // joinAllThreads() is idempotent.
549 if (AsyncCompiler)
550 AsyncCompiler->joinAllThreads();
551
552 if (Config::get().traceCacheStats())
555 if (Config::get().traceCacheStats() && CacheChain)
556 CacheChain->printStats();
557 }
558
// NOTE(review): original line 559 is missing. Per the index it declares
// `MemoryCache<KernelFunction_t> CodeCache`, the in-memory cache that
// compileAndRun() consults first.
// Persistent cache of compiled objects; engaged only when
// ProteusUseStoredCache is set (see the constructor).
560 std::optional<ObjectCacheChain> CacheChain;
// Device architecture string forwarded to each CompilationTask.
561 std::string DeviceArch;
562
// Kernel handle -> JITKernelInfo, filled by registerFunction(). DenseMap
// entries may move on insertion, so references into it are not stable.
563 DenseMap<const void *, JITKernelInfo> JITKernelInfoMap;
// Background compiler. The constructor creates a CompilerAsync only when
// ProteusAsyncCompilation is set (its assignment line is missing from this
// rendering).
564 std::unique_ptr<CompilerAsync> AsyncCompiler;
565};
566
// Specializes, compiles (or reuses) and launches the kernel described by
// KernelInfo. The specialization is keyed by HashValue and is looked up in
// this order:
//   1. the in-memory CodeCache;
//   2. the persistent CacheChain, if enabled (globals are relinked first);
//   3. compilation. Asynchronous compilation launches the original kernel
//      via launchKernelDirect until the result is ready; otherwise
//      compilation is synchronous.
// A freshly compiled object is inserted in CodeCache and, if enabled, stored
// in CacheChain before the launch.
// NOTE(review): original lines 568-569 (return type and qualified name; the
// index shows `DeviceError_t compileAndRun(...)`), 588, 608, 614, 654, 666
// and 668 are missing from this rendering.
567template <typename ImplT>
570 JITKernelInfo &KernelInfo, dim3 GridDim, dim3 BlockDim, void **KernelArgs,
571 uint64_t ShmemSize, typename DeviceTraits<ImplT>::DeviceStream_t Stream) {
572 TIMESCOPE(JitEngineDevice, compileAndRun);
573
574 auto &BinInfo = KernelInfo.getBinaryInfo();
575
576 SmallVector<RuntimeConstant> RCVec =
577 getRuntimeConstantValues(KernelArgs, KernelInfo.getRCInfoArray());
578
579 SmallVector<RuntimeConstant> LambdaJitValuesVec;
580 getLambdaJitValues(KernelInfo, LambdaJitValuesVec);
581 const auto &CGConfig = Config::get().getCGConfig(KernelInfo.getName());
582 // Include codegen and runtime specialization policy in the cache key. If we
583 // do not specialize IR based on grid dimensions, avoid hashing on grid dims
584 // to eliminate repeated compilation overhead.
// NOTE(review): the last argument(s) of this hash(...) call were on the
// missing line 588.
585 HashT HashValue =
586 hash(getStaticHash(KernelInfo), RCVec, LambdaJitValuesVec, BlockDim.x,
587 BlockDim.y, BlockDim.z, hashCodeGenConfig(CGConfig),
589 if (CGConfig.specializeDims() || CGConfig.specializeDimsRange())
590 HashValue = hash(HashValue, GridDim.x, GridDim.y, GridDim.z);
591
592 typename DeviceTraits<ImplT>::KernelFunction_t KernelFunc =
593 CodeCache.lookup(HashValue);
594 if (KernelFunc)
595 return launchKernelFunction(KernelFunc, GridDim, BlockDim, KernelArgs,
596 ShmemSize, Stream);
597
598 // NOTE: we don't need a suffix to differentiate kernels, each
599 // specialization will be in its own module uniquely identify by HashValue.
600 // It exists only for debugging purposes to verify that the jitted kernel
601 // executes.
602 std::string Suffix = HashValue.toMangledSuffix();
603 std::string KernelMangled = (KernelInfo.getName() + Suffix);
604
605 if (CacheChain) {
606 auto CompiledLib = CacheChain->lookup(HashValue);
607 if (CompiledLib) {
609 relinkGlobalsObject(CompiledLib->ObjectModule->getMemBufferRef(),
610 BinInfo.getVarNameToGlobalInfo());
611
// NOTE(review): this `auto KernelFunc` shadows the KernelFunc declared at
// 592. It is harmless because this branch returns, but reusing the outer
// variable would be clearer.
612 auto KernelFunc = proteus::getKernelFunctionFromImage(
613 KernelMangled, CompiledLib->ObjectModule->getBufferStart(),
615 BinInfo.getVarNameToGlobalInfo());
616
617 CodeCache.insert(HashValue, KernelFunc, KernelInfo.getName());
618
619 return launchKernelFunction(KernelFunc, GridDim, BlockDim, KernelArgs,
620 ShmemSize, Stream);
621 }
622 }
623
// Non-owning view of the bitcode owned by KernelInfo (extracted lazily).
624 MemoryBufferRef KernelBitcode = getBitcode(KernelInfo);
625 std::unique_ptr<MemoryBuffer> ObjBuf = nullptr;
626
627 if (Config::get().ProteusAsyncCompilation) {
628 // If there is no compilation pending for the specialization, post the
629 // compilation task to the compiler.
630 if (!AsyncCompiler->isCompilationPending(HashValue)) {
631 PROTEUS_DBG(Logger::logs("proteus") << "Compile async for HashValue "
632 << HashValue.toString() << "\n");
633
634 AsyncCompiler->compile(CompilationTask{
635 KernelBitcode, HashValue, KernelInfo.getName(), Suffix, BlockDim,
636 GridDim, RCVec, KernelInfo.getLambdaCalleeInfo(),
637 BinInfo.getVarNameToGlobalInfo(), GlobalLinkedBinaries, DeviceArch,
638 /*CodeGenConfig */ CGConfig,
639 /*DumpIR*/ Config::get().ProteusDumpLLVMIR,
640 /*RelinkGlobalsByCopy*/ Config::get().ProteusRelinkGlobalsByCopy});
641 }
642
643 // Compilation is pending, try to get the compilation result buffer. If
644 // buffer is null, compilation is not done, so execute the AOT version
645 // directly.
646 ObjBuf = AsyncCompiler->takeCompilationResult(
647 HashValue, Config::get().ProteusAsyncTestBlocking);
648 if (!ObjBuf) {
649 return launchKernelDirect(KernelInfo.getKernel(), GridDim, BlockDim,
650 KernelArgs, ShmemSize, Stream);
651 }
652 } else {
653 // Process through synchronous compilation.
// NOTE(review): line 654 is missing; presumably
// `ObjBuf = CompilerSync::instance().compile(CompilationTask{` (the index
// lists CompilerSync::instance/compile) -- confirm.
655 KernelBitcode, HashValue, KernelInfo.getName(), Suffix, BlockDim,
656 GridDim, RCVec, KernelInfo.getLambdaCalleeInfo(),
657 BinInfo.getVarNameToGlobalInfo(), GlobalLinkedBinaries, DeviceArch,
658 /*CodeGenConfig */ CGConfig,
659 /*DumpIR*/ Config::get().ProteusDumpLLVMIR,
660 /*RelinkGlobalsByCopy*/ Config::get().ProteusRelinkGlobalsByCopy});
661 }
662
663 if (!ObjBuf)
664 reportFatalError("Expected non-null object");
665
// NOTE(review): lines 666 and 668 are missing. Lines 667-669 are presumably
// the arguments of a getKernelFunctionFromImage call assigning KernelFunc.
667 KernelMangled, ObjBuf->getBufferStart(),
669 BinInfo.getVarNameToGlobalInfo());
670
671 CodeCache.insert(HashValue, KernelFunc, KernelInfo.getName());
672 if (CacheChain)
673 CacheChain->store(HashValue,
674 CacheEntry::staticObject(ObjBuf->getMemBufferRef()));
675
676 return launchKernelFunction(KernelFunc, GridDim, BlockDim, KernelArgs,
677 ShmemSize, Stream);
678}
679
// Records a fat binary registered under Handle.
// - RDC build (FatbinWrapper->PrelinkedFatbins non-null): creates the
//   BinaryInfo with no module ids, and adds every entry of the
//   null-terminated PrelinkedFatbins array to GlobalLinkedBinaries.
// - Non-RDC build: maps ModuleId to FatbinWrapper and creates the BinaryInfo
//   with ModuleId as its only module id.
// try_emplace keeps the first BinaryInfo if Handle is registered again.
// NOTE(review): original lines 681-682 (the signature; the index shows
// `registerFatBinary(void *Handle, FatbinWrapperT *FatbinWrapper,
// const char *ModuleId)`) and 684 are missing from this rendering.
680template <typename ImplT>
683 const char *ModuleId) {
685 PROTEUS_DBG(Logger::logs("proteus")
686 << "Register fatbinary Handle " << Handle << " FatbinWrapper "
687 << FatbinWrapper << " Binary " << (void *)FatbinWrapper->Binary
688 << " ModuleId " << ModuleId << "\n");
689 if (FatbinWrapper->PrelinkedFatbins) {
690 // This is RDC compilation, just insert the FatbinWrapper and ignore the
691 // ModuleId coming from the link.stub.
692 HandleToBinaryInfo.try_emplace(Handle, FatbinWrapper,
693 SmallVector<std::string>{});
694
695 // Initialize GlobalLinkedBinaries with prelinked fatbins.
// The array is walked until its first null entry.
696 void *Ptr = FatbinWrapper->PrelinkedFatbins[0];
697 for (int I = 0; Ptr != nullptr;
698 ++I, Ptr = FatbinWrapper->PrelinkedFatbins[I]) {
699 PROTEUS_DBG(Logger::logs("proteus")
700 << "I " << I << " PrelinkedFatbin " << Ptr << "\n");
701 GlobalLinkedBinaries.insert(Ptr);
702 }
703 } else {
704 // This is non-RDC compilation, associate the ModuleId of the JIT bitcode
705 // in the module with the FatbinWrapper.
706 ModuleIdToFatBinary[ModuleId] = FatbinWrapper;
707 HandleToBinaryInfo.try_emplace(Handle, FatbinWrapper,
708 SmallVector<std::string>{ModuleId});
709 }
710}
711
712template <typename ImplT> void JitEngineDevice<ImplT>::finalizeRegistration() {
713 TIMESCOPE(JitEngineDevice, finalizeRegistration);
714 PROTEUS_DBG(Logger::logs("proteus") << "Finalize registration\n");
715 // Erase linked binaries for which we have LLVM IR code, those binaries are
716 // stored in the ModuleIdToFatBinary map.
717 for (auto &[ModuleId, FatbinWrapper] : ModuleIdToFatBinary)
718 GlobalLinkedBinaries.erase((void *)FatbinWrapper->Binary);
719}
720
// Registers kernel handle Kernel (named KernelName) as belonging to the
// binary registered under Handle. Duplicate registrations of the same Kernel
// are ignored (the first one wins). Fatal error if Handle is unknown.
// NOTE(review): original lines 722 (the qualified name) and 748 (the
// right-hand side of the JITKernelInfoMap assignment, presumably a
// JITKernelInfo built from Kernel, BinInfo, KernelName and RCInfoArray) are
// missing from this rendering.
721template <typename ImplT>
723 void *Handle, void *Kernel, char *KernelName,
724 ArrayRef<RuntimeConstantInfo *> RCInfoArray) {
725 PROTEUS_DBG(Logger::logs("proteus") << "Register function " << Kernel
726 << " To Handle " << Handle << "\n");
727 // NOTE: HIP RDC might call multiple times the registerFunction for the same
728 // kernel, which has weak linkage, when it comes from different translation
729 // units. Either the first or the second call can prevail and should be
730 // equivalent. We let the first one prevail.
731 if (JITKernelInfoMap.contains(Kernel)) {
732 PROTEUS_DBG(Logger::logs("proteus")
733 << "Warning: duplicate register function for kernel " +
734 std::string(KernelName)
735 << "\n");
736 return;
737 }
738
// NOTE(review): count() followed by operator[] hashes Handle twice; a
// single find() would do.
739 if (!HandleToBinaryInfo.count(Handle))
740 reportFatalError("Expected Handle in map");
741 BinaryInfo &BinInfo = HandleToBinaryInfo[Handle];
742
743 PROTEUS_DBG(Logger::logs("proteus")
744 << "Register function " << KernelName << " with binary handle "
745 << Handle << "\n");
746
// Inserting into the DenseMap may move existing entries, invalidating
// references previously returned by getJITKernelInfo().
747 JITKernelInfoMap[Kernel] =
749}
750
// Adds ModuleId to the BinaryInfo registered under Handle and maps ModuleId
// to FatbinWrapper. Fatal error if Handle is unknown.
// NOTE(review): original lines 752-753 (the signature; the index shows
// `registerLinkedBinary(void *Handle, FatbinWrapperT *FatbinWrapper,
// const char *ModuleId)`) and 755 are missing from this rendering.
751template <typename ImplT>
754 const char *ModuleId) {
756 PROTEUS_DBG(Logger::logs("proteus")
757 << "Register linked binary FatBinary " << FatbinWrapper
758 << " Binary " << (void *)FatbinWrapper->Binary << " ModuleId "
759 << ModuleId << "\n");
760 if (!HandleToBinaryInfo.count(Handle))
761 reportFatalError("Expected Handle in map");
762
763 HandleToBinaryInfo[Handle].addModuleId(ModuleId);
764 ModuleIdToFatBinary[ModuleId] = FatbinWrapper;
765}
766
767} // namespace proteus
768
769#endif
void const char * ModuleId
Definition CompilerInterfaceDevice.cpp:37
void * FatbinWrapper
Definition CompilerInterfaceDevice.cpp:36
const void const char * VarName
Definition CompilerInterfaceDevice.cpp:25
auto & JitEngineInfo
Definition CompilerInterfaceDevice.cpp:60
void char * KernelName
Definition CompilerInterfaceDevice.cpp:55
void * Kernel
Definition CompilerInterfaceDevice.cpp:55
JitEngineInfo registerFatBinary(Handle, FatbinWrapper, ModuleId)
const void const char uint64_t VarSize
Definition CompilerInterfaceDevice.cpp:26
const void * HostAddr
Definition CompilerInterfaceDevice.cpp:24
JitEngineInfo registerLinkedBinary(FatbinWrapper, ModuleId)
ArrayRef< RuntimeConstantInfo * > RCInfoArray
Definition CompilerInterfaceHost.cpp:27
#define PROTEUS_TIMER_OUTPUT(x)
Definition Config.h:440
#define PROTEUS_DBG(x)
Definition Debug.h:9
void getLambdaJitValues(StringRef FnName, SmallVector< RuntimeConstant > &LambdaJitValuesVec)
Definition JitEngineHost.cpp:178
#define TIMESCOPE(...)
Definition TimeTracing.h:66
Definition JitEngineDevice.h:81
FatbinWrapperT * getFatbinWrapper() const
Definition JitEngineDevice.h:103
void setExtractedModules(SmallVector< std::unique_ptr< Module > > &Modules)
Definition JitEngineDevice.h:153
std::unordered_map< std::string, GlobalVarInfo > & getVarNameToGlobalInfo()
Definition JitEngineDevice.h:216
MemoryBufferRef getDeviceBinary()
Definition JitEngineDevice.h:182
bool hasModuleHash() const
Definition JitEngineDevice.h:157
std::unique_ptr< LLVMContext > & getLLVMContext()
Definition JitEngineDevice.h:105
Module & getLinkedModule()
Definition JitEngineDevice.h:108
auto & getModuleIds()
Definition JitEngineDevice.h:220
bool hasLinkedModule() const
Definition JitEngineDevice.h:107
bool hasDeviceBinary()
Definition JitEngineDevice.h:181
const SmallVector< std::reference_wrapper< Module > > getExtractedModules() const
Definition JitEngineDevice.h:144
void updateModuleHash(HashT HashValue)
Definition JitEngineDevice.h:165
HashT getModuleHash() const
Definition JitEngineDevice.h:158
void insertGlobalVar(const char *VarName, const void *HostAddr, const void *DeviceAddr, uint64_t VarSize)
Definition JitEngineDevice.h:195
bool hasExtractedModules() const
Definition JitEngineDevice.h:142
void addModuleId(const char *ModuleId)
Definition JitEngineDevice.h:191
CallGraph & getCallGraph()
Definition JitEngineDevice.h:172
BinaryInfo(FatbinWrapperT *FatbinWrapper, SmallVector< std::string > &&LinkedModuleIds)
Definition JitEngineDevice.h:96
void setModuleHash(HashT HashValue)
Definition JitEngineDevice.h:164
void setDeviceBinary(std::unique_ptr< MemoryBuffer > DeviceBinaryBuffer)
Definition JitEngineDevice.h:187
Definition CompilationTask.h:19
std::unique_ptr< MemoryBuffer > compile(CompilationTask &&CT)
Definition CompilerSync.h:22
static CompilerSync & instance()
Definition CompilerSync.h:17
static Config & get()
Definition Config.h:334
bool ProteusRelinkGlobalsByCopy
Definition Config.h:343
bool ProteusDumpLLVMIR
Definition Config.h:342
const CodeGenerationConfig & getCGConfig(llvm::StringRef KName="") const
Definition Config.h:358
Definition Func.h:296
Definition Hashing.h:22
std::string toString() const
Definition Hashing.h:30
std::string toMangledSuffix() const
Definition Hashing.h:33
Definition JitEngineDevice.h:223
bool hasBitcode()
Definition JitEngineDevice.h:258
const std::string & getName() const
Definition JitEngineDevice.h:249
void createStaticHash(HashT ModuleHash)
Definition JitEngineDevice.h:266
JITKernelInfo(void *Kernel, BinaryInfo &BinInfo, char const *Name, ArrayRef< RuntimeConstantInfo * > RCInfoArray)
Definition JitEngineDevice.h:236
void * getKernel() const
Definition JitEngineDevice.h:244
bool hasModule() const
Definition JitEngineDevice.h:251
const HashT getStaticHash() const
Definition JitEngineDevice.h:265
BinaryInfo & getBinaryInfo() const
Definition JitEngineDevice.h:253
ArrayRef< RuntimeConstantInfo * > getRCInfoArray() const
Definition JitEngineDevice.h:250
Module & getModule() const
Definition JitEngineDevice.h:252
void setModule(std::unique_ptr< llvm::Module > Mod)
Definition JitEngineDevice.h:254
void setLambdaCalleeInfo(SmallVector< std::pair< std::string, StringRef > > &&LambdaInfo)
Definition JitEngineDevice.h:273
const auto & getLambdaCalleeInfo()
Definition JitEngineDevice.h:272
bool hasLambdaCalleeInfo()
Definition JitEngineDevice.h:271
MemoryBufferRef getBitcode()
Definition JitEngineDevice.h:262
bool hasStaticHash() const
Definition JitEngineDevice.h:264
std::unique_ptr< LLVMContext > & getLLVMContext()
Definition JitEngineDevice.h:248
void setBitcode(std::unique_ptr< MemoryBuffer > ExtractedBitcode)
Definition JitEngineDevice.h:259
Definition JitEngineDevice.h:281
MemoryBufferRef getBitcode(JITKernelInfo &KernelInfo)
Definition JitEngineDevice.h:407
~JitEngineDevice()
Definition JitEngineDevice.h:545
typename DeviceTraits< ImplT >::DeviceError_t DeviceError_t
Definition JitEngineDevice.h:283
void extractModuleAndBitcode(JITKernelInfo &KernelInfo)
Definition JitEngineDevice.h:368
JitEngineDevice()
Definition JitEngineDevice.h:513
std::unique_ptr< CompilerAsync > AsyncCompiler
Definition JitEngineDevice.h:564
std::unordered_map< const void *, BinaryInfo > HandleToBinaryInfo
Definition JitEngineDevice.h:477
DenseMap< const void *, JITKernelInfo > JITKernelInfoMap
Definition JitEngineDevice.h:563
typename DeviceTraits< ImplT >::DeviceStream_t DeviceStream_t
Definition JitEngineDevice.h:284
void finalizeRegistration()
Definition JitEngineDevice.h:712
MemoryCache< KernelFunction_t > CodeCache
Definition JitEngineDevice.h:559
Module & getModule(JITKernelInfo &KernelInfo)
Definition JitEngineDevice.h:397
bool containsJITKernelInfo(const void *Func)
Definition JitEngineDevice.h:480
std::optional< std::reference_wrapper< JITKernelInfo > > getJITKernelInfo(const void *Func)
Definition JitEngineDevice.h:485
std::pair< std::unique_ptr< Module >, std::unique_ptr< MemoryBuffer > > extractKernelModule(BinaryInfo &BinInfo, StringRef KernelName, LLVMContext &Ctx)
Definition JitEngineDevice.h:293
SmallPtrSet< void *, 8 > GlobalLinkedBinaries
Definition JitEngineDevice.h:478
void registerFunction(void *Handle, void *Kernel, char *KernelName, ArrayRef< RuntimeConstantInfo * > RCInfoArray)
Definition JitEngineDevice.h:722
std::optional< ObjectCacheChain > CacheChain
Definition JitEngineDevice.h:560
void registerFatBinary(void *Handle, FatbinWrapperT *FatbinWrapper, const char *ModuleId)
Definition JitEngineDevice.h:681
DeviceError_t compileAndRun(JITKernelInfo &KernelInfo, dim3 GridDim, dim3 BlockDim, void **KernelArgs, uint64_t ShmemSize, typename DeviceTraits< ImplT >::DeviceStream_t Stream)
Definition JitEngineDevice.h:569
void registerLinkedBinary(void *Handle, FatbinWrapperT *FatbinWrapper, const char *ModuleId)
Definition JitEngineDevice.h:752
std::unordered_map< std::string, FatbinWrapperT * > ModuleIdToFatBinary
Definition JitEngineDevice.h:476
typename DeviceTraits< ImplT >::KernelFunction_t KernelFunction_t
Definition JitEngineDevice.h:285
StringRef getDeviceArch() const
Definition JitEngineDevice.h:510
void registerVar(void *Handle, const char *VarName, const void *HostAddr, uint64_t VarSize)
Definition JitEngineDevice.h:455
std::string DeviceArch
Definition JitEngineDevice.h:561
HashT getStaticHash(JITKernelInfo &KernelInfo)
Definition JitEngineDevice.h:492
void getLambdaJitValues(JITKernelInfo &KernelInfo, SmallVector< RuntimeConstant > &LambdaJitValuesVec)
Definition JitEngineDevice.h:417
static JitEngineInfoRegistry & instance()
Definition JitEngineInfoRegistry.h:58
Definition JitEngine.h:32
Definition LambdaRegistry.h:19
std::optional< DenseMap< StringRef, SmallVector< RuntimeConstant > >::iterator > matchJitVariableMap(StringRef FnName)
Definition LambdaRegistry.h:32
const SmallVector< RuntimeConstant > & getJitVariables(StringRef LambdaTypeRef)
Definition LambdaRegistry.h:98
static LambdaRegistry & instance()
Definition LambdaRegistry.h:21
bool empty()
Definition LambdaRegistry.h:102
static llvm::raw_ostream & outs(const std::string &Name)
Definition Logger.h:25
static void trace(llvm::StringRef Msg)
Definition Logger.h:30
static llvm::raw_ostream & logs(const std::string &Name)
Definition Logger.h:19
Definition MemoryCache.h:31
void printStats()
Definition MemoryCache.h:62
void printKernelTrace()
Definition MemoryCache.h:76
Definition TimeTracing.h:33
uint64_t elapsed()
Definition TimeTracing.cpp:66
Definition CompiledLibrary.h:7
Definition MemoryCache.h:27
std::unique_ptr< Module > cloneKernelFromModules(ArrayRef< std::reference_wrapper< Module > > Mods, StringRef EntryName, function_ref< bool(const GlobalValue *)> ShouldCloneDefinition=nullptr)
Definition Cloning.h:513
HashT hash(FirstT &&First, RestTs &&...Rest)
Definition Hashing.h:168
void reportFatalError(const llvm::Twine &Reason, const char *FILE, unsigned Line)
Definition Error.cpp:14
cudaError_t launchKernelDirect(void *KernelFunc, dim3 GridDim, dim3 BlockDim, void **KernelArgs, uint64_t ShmemSize, CUstream Stream)
Definition CoreDeviceCUDA.h:38
HashT hashRuntimeSpecializationConfig(const CodeGenerationConfig &CGConfig)
Definition Hashing.h:152
cudaError_t launchKernelFunction(CUfunction KernelFunc, dim3 GridDim, dim3 BlockDim, void **KernelArgs, uint64_t ShmemSize, CUstream Stream)
Definition CoreDeviceCUDA.h:78
HashT hashCodeGenConfig(const CodeGenerationConfig &CGConfig)
Definition Hashing.h:142
HashT hashCombine(HashT A, HashT B)
Definition Hashing.h:138
std::string toString(CodegenOption Option)
Definition Config.h:28
void * resolveDeviceGlobalAddr(const void *Addr)
Definition CoreDeviceCUDA.h:26
void internalize(Module &M, StringRef PreserveFunctionName)
Definition CoreLLVM.h:315
CUfunction getKernelFunctionFromImage(StringRef KernelName, const void *Image, bool RelinkGlobalsByCopy, const std::unordered_map< std::string, GlobalVarInfo > &VarNameToGlobalInfo)
Definition CoreDeviceCUDA.h:50
void runCleanupPassPipeline(Module &M)
Definition CoreLLVM.h:256
std::unique_ptr< Module > linkModules(LLVMContext &Ctx, SmallVector< std::unique_ptr< Module > > LinkedModules)
Definition CoreLLVM.h:238
Definition Hashing.h:184
static CacheEntry staticObject(MemoryBufferRef Buf)
Definition ObjectCache.h:31
Definition JitEngineDevice.h:279
Definition JitEngineDevice.h:74
const char * Binary
Definition JitEngineDevice.h:77
void ** PrelinkedFatbins
Definition JitEngineDevice.h:78
int32_t Magic
Definition JitEngineDevice.h:75
int32_t Version
Definition JitEngineDevice.h:76
Definition GlobalVarInfo.h:5
Definition Var.h:15