Proteus
Programmable JIT compilation and optimization for C/C++ using LLVM
Loading...
Searching...
No Matches
JitEngineDevice.hpp
Go to the documentation of this file.
1//===-- JitEngineDevice.hpp -- Base JIT Engine Device header impl. --===//
2//
3// Part of the Proteus Project, under the Apache License v2.0 with LLVM
4// Exceptions. See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9//===----------------------------------------------------------------------===//
10
11#ifndef PROTEUS_JITENGINEDEVICE_HPP
12#define PROTEUS_JITENGINEDEVICE_HPP
13
14#include <cstdint>
15#include <functional>
16#include <llvm/ADT/SmallPtrSet.h>
17#include <llvm/Analysis/CallGraph.h>
18#include <memory>
19#include <optional>
20#include <string>
21
22#include <llvm/ADT/SmallVector.h>
23#include <llvm/ADT/StringRef.h>
24#include <llvm/Analysis/TargetTransformInfo.h>
25#include <llvm/Bitcode/BitcodeWriter.h>
26#include <llvm/CodeGen/CommandFlags.h>
27#include <llvm/CodeGen/MachineModuleInfo.h>
28#include <llvm/Config/llvm-config.h>
29#include <llvm/Demangle/Demangle.h>
30#include <llvm/ExecutionEngine/Orc/ThreadSafeModule.h>
31#include <llvm/IR/Constants.h>
32#include <llvm/IR/GlobalVariable.h>
33#include <llvm/IR/Instruction.h>
34#include <llvm/IR/Instructions.h>
35#include <llvm/IR/LLVMContext.h>
36#include <llvm/IR/LegacyPassManager.h>
37#include <llvm/IR/Module.h>
38#include <llvm/IR/ReplaceConstant.h>
39#include <llvm/IR/Type.h>
40#include <llvm/IR/Verifier.h>
41#include <llvm/IRReader/IRReader.h>
42#include <llvm/Linker/Linker.h>
43#include <llvm/MC/TargetRegistry.h>
44#include <llvm/Object/ELFObjectFile.h>
45#include <llvm/Passes/PassBuilder.h>
46#include <llvm/Support/Error.h>
47#include <llvm/Support/MemoryBuffer.h>
48#include <llvm/Support/MemoryBufferRef.h>
49#include <llvm/Target/TargetMachine.h>
50#include <llvm/Transforms/IPO/Internalize.h>
51#include <llvm/Transforms/Utils/Cloning.h>
52#include <llvm/Transforms/Utils/ModuleUtils.h>
53
54#include "proteus/Cloning.h"
59#include "proteus/CoreLLVM.hpp"
60#include "proteus/Debug.h"
61#include "proteus/Hashing.hpp"
62#include "proteus/JitCache.hpp"
63#include "proteus/JitEngine.hpp"
66#include "proteus/Utils.h"
67
68namespace proteus {
69
70using namespace llvm;
71
73 int32_t Magic;
74 int32_t Version;
75 const char *Binary;
77};
78
80private:
81 FatbinWrapperT *FatbinWrapper;
82 std::unique_ptr<LLVMContext> Ctx;
83 SmallVector<std::string> LinkedModuleIds;
84 Module *LinkedModule;
85 std::optional<SmallVector<std::unique_ptr<Module>>> ExtractedModules;
86 std::optional<HashT> ExtractedModuleHash;
87 std::optional<CallGraph> ModuleCallGraph;
88 std::unique_ptr<MemoryBuffer> DeviceBinary;
89
90public:
91 BinaryInfo() = default;
92 BinaryInfo(FatbinWrapperT *FatbinWrapper,
93 SmallVector<std::string> &&LinkedModuleIds)
94 : FatbinWrapper(FatbinWrapper), Ctx(std::make_unique<LLVMContext>()),
95 LinkedModuleIds(LinkedModuleIds), LinkedModule(nullptr),
96 ExtractedModules(std::nullopt), ModuleCallGraph(std::nullopt),
97 DeviceBinary(nullptr) {}
98
100
101 std::unique_ptr<LLVMContext> &getLLVMContext() { return Ctx; }
102
103 bool hasLinkedModule() const { return (LinkedModule != nullptr); }
104 Module &getLinkedModule() {
105 if (!LinkedModule) {
106 if (!hasExtractedModules())
107 PROTEUS_FATAL_ERROR("Expected extracted modules");
108
109 Timer T;
110 // Avoid linking when there's a single module by moving it instead and
111 // making sure it's materialized for call graph analysis.
112 if (ExtractedModules->size() == 1) {
113 LinkedModule = ExtractedModules->front().get();
114 if (auto E = LinkedModule->materializeAll())
115 PROTEUS_FATAL_ERROR("Error materializing " + toString(std::move(E)));
116 } else {
117 // By the LLVM API, linkModules takes ownership of module pointers in
118 // ExtractedModules and returns a new unique ptr to the linked module.
119 // We update ExtractedModules to contain and own only the generated
120 // LinkedModule.
121 auto GeneratedLinkedModule =
122 proteus::linkModules(*Ctx, std::move(ExtractedModules.value()));
123 SmallVector<std::unique_ptr<Module>> NewExtractedModules;
124 NewExtractedModules.emplace_back(std::move(GeneratedLinkedModule));
125 setExtractedModules(NewExtractedModules);
126
127 LinkedModule = ExtractedModules->front().get();
128 }
129
131 << "getLinkedModule " << T.elapsed() << " ms\n");
132 }
133
134 return *LinkedModule;
135 }
136
137 bool hasExtractedModules() const { return ExtractedModules.has_value(); }
138 const SmallVector<std::reference_wrapper<Module>>
140 // This should be called only once when cloning the kernel module to cache.
141 SmallVector<std::reference_wrapper<Module>> ModulesRef;
142 for (auto &M : ExtractedModules.value())
143 ModulesRef.emplace_back(*M);
144
145 return ModulesRef;
146 }
147 void setExtractedModules(SmallVector<std::unique_ptr<Module>> &Modules) {
148 ExtractedModules = std::move(Modules);
149 }
150
151 bool hasModuleHash() const { return ExtractedModuleHash.has_value(); }
152 HashT getModuleHash() const { return ExtractedModuleHash.value(); }
153 void setModuleHash(HashT HashValue) { ExtractedModuleHash = HashValue; }
154 void updateModuleHash(HashT HashValue) {
155 if (ExtractedModuleHash)
156 ExtractedModuleHash = hashCombine(ExtractedModuleHash.value(), HashValue);
157 else
158 ExtractedModuleHash = HashValue;
159 }
160
161 CallGraph &getCallGraph() {
162 if (!ModuleCallGraph.has_value()) {
163 if (!LinkedModule)
164 PROTEUS_FATAL_ERROR("Expected non-null linked module");
165 ModuleCallGraph.emplace(CallGraph(*LinkedModule));
166 }
167 return ModuleCallGraph.value();
168 }
169
170 bool hasDeviceBinary() { return (DeviceBinary != nullptr); }
171 MemoryBufferRef getDeviceBinary() {
172 if (!hasDeviceBinary())
173 PROTEUS_FATAL_ERROR("Expeced non-null device binary");
174 return DeviceBinary->getMemBufferRef();
175 }
176 void setDeviceBinary(std::unique_ptr<MemoryBuffer> DeviceBinaryBuffer) {
177 DeviceBinary = std::move(DeviceBinaryBuffer);
178 }
179
180 void addModuleId(const char *ModuleId) {
181 LinkedModuleIds.push_back(ModuleId);
182 }
183
184 auto &getModuleIds() { return LinkedModuleIds; }
185};
186
188 std::optional<void *> Kernel;
189 std::unique_ptr<LLVMContext> Ctx;
190 std::string Name;
191 ArrayRef<RuntimeConstantInfo *> RCInfoArray;
192 std::optional<std::unique_ptr<Module>> ExtractedModule;
193 std::optional<std::unique_ptr<MemoryBuffer>> Bitcode;
194 std::optional<std::reference_wrapper<BinaryInfo>> BinInfo;
195 std::optional<HashT> StaticHash;
196 std::optional<SmallVector<std::pair<std::string, StringRef>>>
197 LambdaCalleeInfo;
198
199public:
200 JITKernelInfo(void *Kernel, BinaryInfo &BinInfo, char const *Name,
201 ArrayRef<RuntimeConstantInfo *> RCInfoArray)
202 : Kernel(Kernel), Ctx(std::make_unique<LLVMContext>()), Name(Name),
203 RCInfoArray(RCInfoArray), ExtractedModule(std::nullopt),
204 Bitcode{std::nullopt}, BinInfo(BinInfo),
205 LambdaCalleeInfo(std::nullopt) {}
206
207 JITKernelInfo() = default;
208 void *getKernel() const {
209 assert(Kernel.has_value() && "Expected Kernel is inited");
210 return Kernel.value();
211 }
212 std::unique_ptr<LLVMContext> &getLLVMContext() { return Ctx; }
213 const std::string &getName() const { return Name; }
214 ArrayRef<RuntimeConstantInfo *> getRCInfoArray() const { return RCInfoArray; }
215 bool hasModule() const { return ExtractedModule.has_value(); }
216 Module &getModule() const { return *ExtractedModule->get(); }
217 BinaryInfo &getBinaryInfo() const { return BinInfo.value(); }
218 void setModule(std::unique_ptr<llvm::Module> Mod) {
219 ExtractedModule = std::move(Mod);
220 }
221
222 bool hasBitcode() { return Bitcode.has_value(); }
223 void setBitcode(std::unique_ptr<MemoryBuffer> ExtractedBitcode) {
224 Bitcode = std::move(ExtractedBitcode);
225 }
226 MemoryBufferRef getBitcode() { return Bitcode.value()->getMemBufferRef(); }
227
228 bool hasStaticHash() const { return StaticHash.has_value(); }
229 const HashT getStaticHash() const { return StaticHash.value(); }
230 void createStaticHash(HashT ModuleHash) {
231 StaticHash = hash(Name);
232 StaticHash = hashCombine(StaticHash.value(), ModuleHash);
233 }
234
235 bool hasLambdaCalleeInfo() { return LambdaCalleeInfo.has_value(); }
236 const auto &getLambdaCalleeInfo() { return LambdaCalleeInfo.value(); }
238 SmallVector<std::pair<std::string, StringRef>> &&LambdaInfo) {
239 LambdaCalleeInfo = std::move(LambdaInfo);
240 }
241};
242
243template <typename ImplT> struct DeviceTraits;
244
245template <typename ImplT> class JitEngineDevice : public JitEngine {
246
247public:
251
253 compileAndRun(JITKernelInfo &KernelInfo, dim3 GridDim, dim3 BlockDim,
254 void **KernelArgs, uint64_t ShmemSize,
256
257 std::pair<std::unique_ptr<Module>, std::unique_ptr<MemoryBuffer>>
259 LLVMContext &Ctx) {
260 std::unique_ptr<Module> KernelModule =
261 static_cast<ImplT &>(*this).tryExtractKernelModule(BinInfo, KernelName,
262 Ctx);
263 std::unique_ptr<MemoryBuffer> Bitcode = nullptr;
264
265 // If there is no ready-made kernel module from AOT, extract per-TU or the
266 // single linked module and clone the kernel module.
267 if (!KernelModule) {
268 Timer T;
269 if (!BinInfo.hasExtractedModules())
270 static_cast<ImplT &>(*this).extractModules(BinInfo);
271
272 std::unique_ptr<Module> KernelModuleTmp = nullptr;
273 switch (Config::get().ProteusKernelClone) {
275 auto &LinkedModule = BinInfo.getLinkedModule();
276 KernelModule = llvm::CloneModule(LinkedModule);
277 break;
278 }
280 auto &LinkedModule = BinInfo.getLinkedModule();
281 KernelModule =
283 break;
284 }
286 KernelModule = proteus::cloneKernelFromModules(
288 break;
289 }
290 default:
291 PROTEUS_FATAL_ERROR("Unsupported kernel cloning option");
292 }
293
295 << "Cloning "
296 << toString(Config::get().ProteusKernelClone) << " "
297 << T.elapsed() << " ms\n");
298 }
299
300 // Internalize and cleanup to simplify the module and prepare it for
301 // optimization.
302 internalize(*KernelModule, KernelName);
303 proteus::runCleanupPassPipeline(*KernelModule);
304
305 // If the module is not in the provided context due to cloning, roundtrip it
306 // using bitcode. Re-use the roundtrip bitcode to return it.
307 if (&KernelModule->getContext() != &Ctx) {
308 SmallVector<char> CloneBuffer;
309 raw_svector_ostream OS(CloneBuffer);
310 WriteBitcodeToFile(*KernelModule, OS);
311 StringRef CloneStr = StringRef(CloneBuffer.data(), CloneBuffer.size());
312 auto ExpectedKernelModule =
313 parseBitcodeFile(MemoryBufferRef{CloneStr, KernelName}, Ctx);
314 if (auto E = ExpectedKernelModule.takeError())
315 PROTEUS_FATAL_ERROR("Error parsing bitcode: " + toString(std::move(E)));
316
317 KernelModule = std::move(*ExpectedKernelModule);
318 Bitcode = MemoryBuffer::getMemBufferCopy(CloneStr);
319 } else {
320 // Parse the kernel module to create the bitcode since it has not been
321 // created by roundtripping.
322 SmallVector<char> BitcodeBuffer;
323 raw_svector_ostream OS(BitcodeBuffer);
324 WriteBitcodeToFile(*KernelModule, OS);
325 auto BitcodeStr = StringRef{BitcodeBuffer.data(), BitcodeBuffer.size()};
326 Bitcode = MemoryBuffer::getMemBufferCopy(BitcodeStr);
327 }
328
329 return std::make_pair(std::move(KernelModule), std::move(Bitcode));
330 }
331
333 TIMESCOPE(__FUNCTION__)
334
335 if (KernelInfo.hasModule() && KernelInfo.hasBitcode())
336 return;
337
338 if (KernelInfo.hasModule())
339 PROTEUS_FATAL_ERROR("Unexpected KernelInfo has module but not bitcode");
340
341 if (KernelInfo.hasBitcode())
342 PROTEUS_FATAL_ERROR("Unexpected KernelInfo has bitcode but not module");
343
344 BinaryInfo &BinInfo = KernelInfo.getBinaryInfo();
345
346 Timer T;
347 auto [KernelModule, BitcodeBuffer] = extractKernelModule(
348 BinInfo, KernelInfo.getName(), *KernelInfo.getLLVMContext());
349
350 if (!KernelModule)
351 PROTEUS_FATAL_ERROR("Expected non-null kernel module");
352 if (!BitcodeBuffer)
353 PROTEUS_FATAL_ERROR("Expected non-null kernel bitcode");
354
355 KernelInfo.setModule(std::move(KernelModule));
356 KernelInfo.setBitcode(std::move(BitcodeBuffer));
358 << "Extract kernel module " << T.elapsed() << " ms\n");
359 }
360
361 Module &getModule(JITKernelInfo &KernelInfo) {
362 if (!KernelInfo.hasModule())
363 extractModuleAndBitcode(KernelInfo);
364
365 if (!KernelInfo.hasModule())
366 PROTEUS_FATAL_ERROR("Expected module in KernelInfo");
367
368 return KernelInfo.getModule();
369 }
370
371 MemoryBufferRef getBitcode(JITKernelInfo &KernelInfo) {
372 if (!KernelInfo.hasBitcode())
373 extractModuleAndBitcode(KernelInfo);
374
375 if (!KernelInfo.hasBitcode())
376 PROTEUS_FATAL_ERROR("Expected bitcode in KernelInfo");
377
378 return KernelInfo.getBitcode();
379 }
380
382 SmallVector<RuntimeConstant> &LambdaJitValuesVec) {
384 if (LR.empty()) {
385 KernelInfo.setLambdaCalleeInfo({});
386 return;
387 }
388
389 if (!KernelInfo.hasLambdaCalleeInfo()) {
390 Module &KernelModule = getModule(KernelInfo);
391 PROTEUS_DBG(Logger::logs("proteus")
392 << "=== LAMBDA MATCHING\n"
393 << "Caller trigger " << KernelInfo.getName() << " -> "
394 << demangle(KernelInfo.getName()) << "\n");
395
396 SmallVector<std::pair<std::string, StringRef>> LambdaCalleeInfo;
397 for (auto &F : KernelModule.getFunctionList()) {
398 PROTEUS_DBG(Logger::logs("proteus")
399 << " Trying F " << demangle(F.getName().str()) << "\n ");
400 auto OptionalMapIt =
402 if (OptionalMapIt)
403 LambdaCalleeInfo.emplace_back(F.getName(),
404 OptionalMapIt.value()->first);
405 }
406
407 KernelInfo.setLambdaCalleeInfo(std::move(LambdaCalleeInfo));
408 }
409
410 for (auto &[FnName, LambdaType] : KernelInfo.getLambdaCalleeInfo()) {
411 const SmallVector<RuntimeConstant> &Values =
412 LR.getJitVariables(LambdaType);
413 LambdaJitValuesVec.insert(LambdaJitValuesVec.end(), Values.begin(),
414 Values.end());
415 }
416 }
417
418 void insertRegisterVar(const char *VarName, const void *Addr) {
419 VarNameToDevPtr[VarName] = Addr;
420 }
422 const char *ModuleId);
424 const char *ModuleId);
426 void registerFunction(void *Handle, void *Kernel, char *KernelName,
427 ArrayRef<RuntimeConstantInfo *> RCInfoArray);
428
429 void *CurHandle = nullptr;
430 std::unordered_map<std::string, FatbinWrapperT *> ModuleIdToFatBinary;
431 std::unordered_map<const void *, BinaryInfo> HandleToBinaryInfo;
432 SmallVector<std::string> GlobalLinkedModuleIds;
433 SmallPtrSet<void *, 8> GlobalLinkedBinaries;
434
435 bool containsJITKernelInfo(const void *Func) {
436 return JITKernelInfoMap.contains(Func);
437 }
438
439 std::optional<std::reference_wrapper<JITKernelInfo>>
440 getJITKernelInfo(const void *Func) {
442 return std::nullopt;
443 }
444 return JITKernelInfoMap[Func];
445 }
446
448 if (KernelInfo.hasStaticHash())
449 return KernelInfo.getStaticHash();
450
451 BinaryInfo &BinInfo = KernelInfo.getBinaryInfo();
452
453 if (BinInfo.hasModuleHash()) {
454 KernelInfo.createStaticHash(BinInfo.getModuleHash());
455 return KernelInfo.getStaticHash();
456 }
457
458 HashT ModuleHash = static_cast<ImplT &>(*this).getModuleHash(BinInfo);
459
460 KernelInfo.createStaticHash(BinInfo.getModuleHash());
461 return KernelInfo.getStaticHash();
462 }
463
464 void finalize() {
465 if (Config::get().ProteusAsyncCompilation)
466 CompilerAsync::instance(Config::get().ProteusAsyncThreads)
468 }
469
470private:
471 //------------------------------------------------------------------
472 // Begin Methods implemented in the derived device engine class.
473 //------------------------------------------------------------------
474 void *resolveDeviceGlobalAddr(const void *Addr) {
475 return static_cast<ImplT &>(*this).resolveDeviceGlobalAddr(Addr);
476 }
477
478 void setLaunchBoundsForKernel(Module &M, Function &F, size_t GridSize,
479 int BlockSize) {
480 static_cast<ImplT &>(*this).setLaunchBoundsForKernel(M, F, GridSize,
481 BlockSize);
482 }
483
484 void setKernelDims(Module &M, dim3 &GridDim, dim3 &BlockDim) {
485 proteus::setKernelDims(M, GridDim, BlockDim);
486 }
487
488 DeviceError_t launchKernelFunction(KernelFunction_t KernelFunc, dim3 GridDim,
489 dim3 BlockDim, void **KernelArgs,
490 uint64_t ShmemSize,
491 DeviceStream_t Stream) {
492 TIMESCOPE(__FUNCTION__);
493 return static_cast<ImplT &>(*this).launchKernelFunction(
494 KernelFunc, GridDim, BlockDim, KernelArgs, ShmemSize, Stream);
495 }
496
497 void relinkGlobalsObject(MemoryBufferRef Object) {
498 TIMESCOPE(__FUNCTION__);
499 proteus::relinkGlobalsObject(Object, VarNameToDevPtr);
500 }
501
502 KernelFunction_t getKernelFunctionFromImage(StringRef KernelName,
503 const void *Image) {
504 TIMESCOPE(__FUNCTION__);
505 return static_cast<ImplT &>(*this).getKernelFunctionFromImage(KernelName,
506 Image);
507 }
508
509 //------------------------------------------------------------------
510 // End Methods implemented in the derived device engine class.
511 //------------------------------------------------------------------
512
513 void pruneIR(Module &M);
514
515 void internalize(Module &M, StringRef KernelName);
516
517 void replaceGlobalVariablesWithPointers(Module &M);
518
519protected:
521
526
529 std::string DeviceArch;
530 std::unordered_map<std::string, const void *> VarNameToDevPtr;
531
532 DenseMap<const void *, JITKernelInfo> JITKernelInfoMap;
533};
534
535template <typename ImplT> void JitEngineDevice<ImplT>::pruneIR(Module &M) {
536 TIMESCOPE("pruneIR");
538}
539
540template <typename ImplT>
541void JitEngineDevice<ImplT>::internalize(Module &M, StringRef KernelName) {
543}
544
545template <typename ImplT>
546void JitEngineDevice<ImplT>::replaceGlobalVariablesWithPointers(Module &M) {
547 TIMESCOPE(__FUNCTION__)
548
549 proteus::replaceGlobalVariablesWithPointers(M, VarNameToDevPtr);
550
551#if PROTEUS_ENABLE_DEBUG
552 Logger::logs("proteus") << "=== Linked M\n" << M << "=== End of Linked M\n";
553 if (verifyModule(M, &errs()))
555 "After linking, broken module found, JIT compilation aborted!");
556 else
557 Logger::logs("proteus") << "Module verified!\n";
558#endif
559}
560
561template <typename ImplT>
562typename DeviceTraits<ImplT>::DeviceError_t
564 JITKernelInfo &KernelInfo, dim3 GridDim, dim3 BlockDim, void **KernelArgs,
565 uint64_t ShmemSize, typename DeviceTraits<ImplT>::DeviceStream_t Stream) {
566 TIMESCOPE("compileAndRun");
567
568 // Lazy initialize the map of device global variables to device pointers by
569 // resolving the host address to the device address. For HIP it is fine to do
570 // this earlier (e.g., insertRegisterVar), but CUDA can't. So, we initialize
571 // this here the first time we need to compile a kernel.
572 static std::once_flag Flag;
573 std::call_once(Flag, [&]() {
574 for (auto &[GlobalName, HostAddr] : VarNameToDevPtr) {
575 void *DevPtr = resolveDeviceGlobalAddr(HostAddr);
576 VarNameToDevPtr.at(GlobalName) = DevPtr;
577 }
578 });
579
580 SmallVector<RuntimeConstant> RCVec =
581 getRuntimeConstantValues(KernelArgs, KernelInfo.getRCInfoArray());
582
583 SmallVector<RuntimeConstant> LambdaJitValuesVec;
584 getLambdaJitValues(KernelInfo, LambdaJitValuesVec);
585
586 HashT HashValue =
587 hash(getStaticHash(KernelInfo), RCVec, LambdaJitValuesVec, GridDim.x,
588 GridDim.y, GridDim.z, BlockDim.x, BlockDim.y, BlockDim.z);
589
590 typename DeviceTraits<ImplT>::KernelFunction_t KernelFunc =
591 CodeCache.lookup(HashValue);
592 if (KernelFunc)
593 return launchKernelFunction(KernelFunc, GridDim, BlockDim, KernelArgs,
594 ShmemSize, Stream);
595
596 // NOTE: we don't need a suffix to differentiate kernels, each specialization
597 // will be in its own module uniquely identified by HashValue. It exists only
598 // for debugging purposes to verify that the jitted kernel executes.
599 std::string Suffix = mangleSuffix(HashValue);
600 std::string KernelMangled = (KernelInfo.getName() + Suffix);
601
603 // If there are device global variables, lookup the IR and codegen object
604 // before launching. Else, if there aren't device global variables, lookup
605 // the object and launch.
606
607 // TODO: Check for globals is very conservative and always re-builds from
608 // LLVM IR even if the Jit module does not use global variables. A better
609 // solution is to keep track of whether a kernel uses gvars (store a flag in
610 // the cache file?) and load the object in case it does not use any.
611 // TODO: Can we use RTC interfaces for fast linking on object files?
612 auto CacheBuf = StorageCache.lookup(HashValue);
613 if (CacheBuf) {
615 relinkGlobalsObject(CacheBuf->getMemBufferRef());
616
617 auto KernelFunc =
618 getKernelFunctionFromImage(KernelMangled, CacheBuf->getBufferStart());
619
620 CodeCache.insert(HashValue, KernelFunc, KernelInfo.getName(), RCVec);
621
622 return launchKernelFunction(KernelFunc, GridDim, BlockDim, KernelArgs,
623 ShmemSize, Stream);
624 }
625 }
626
627 MemoryBufferRef KernelBitcode = getBitcode(KernelInfo);
628 std::unique_ptr<MemoryBuffer> ObjBuf = nullptr;
629
630 if (Config::get().ProteusAsyncCompilation) {
631 auto &Compiler = CompilerAsync::instance(Config::get().ProteusAsyncThreads);
632 // If there is no compilation pending for the specialization, post the
633 // compilation task to the compiler.
634 if (!Compiler.isCompilationPending(HashValue)) {
635 PROTEUS_DBG(Logger::logs("proteus") << "Compile async for HashValue "
636 << HashValue.toString() << "\n");
637
638 Compiler.compile(CompilationTask{
639 KernelBitcode, HashValue, KernelInfo.getName(), Suffix, BlockDim,
640 GridDim, RCVec, KernelInfo.getLambdaCalleeInfo(), VarNameToDevPtr,
641 GlobalLinkedBinaries, DeviceArch,
642 /* CGOption */ Config::get().ProteusCodegen,
643 /* DumpIR */ Config::get().ProteusDumpLLVMIR,
644 /* RelinkGlobalsByCopy */ Config::get().ProteusRelinkGlobalsByCopy,
645 /*SpecializeArgs=*/Config::get().ProteusSpecializeArgs,
646 /*SpecializeDims=*/Config::get().ProteusSpecializeDims,
647 /*SpecializeLaunchBounds=*/
649 }
650
651 // Compilation is pending, try to get the compilation result buffer. If
652 // buffer is null, compilation is not done, so execute the AOT version
653 // directly.
654 ObjBuf = Compiler.takeCompilationResult(
655 HashValue, Config::get().ProteusAsyncTestBlocking);
656 if (!ObjBuf) {
657 return launchKernelDirect(KernelInfo.getKernel(), GridDim, BlockDim,
658 KernelArgs, ShmemSize, Stream);
659 }
660 } else {
661 // Process through synchronous compilation.
663 KernelBitcode, HashValue, KernelInfo.getName(), Suffix, BlockDim,
664 GridDim, RCVec, KernelInfo.getLambdaCalleeInfo(), VarNameToDevPtr,
665 GlobalLinkedBinaries, DeviceArch,
666 /* CGOption */ Config::get().ProteusCodegen,
667 /* DumpIR */ Config::get().ProteusDumpLLVMIR,
668 /* RelinkGlobalsByCopy */ Config::get().ProteusRelinkGlobalsByCopy,
669 /*SpecializeArgs=*/Config::get().ProteusSpecializeArgs,
670 /*SpecializeDims=*/Config::get().ProteusSpecializeDims,
671 /*SpecializeLaunchBounds=*/
673 }
674
675 if (!ObjBuf)
676 PROTEUS_FATAL_ERROR("Expected non-null object");
677
679 KernelMangled, ObjBuf->getBufferStart(),
680 Config::get().ProteusRelinkGlobalsByCopy, VarNameToDevPtr);
681
682 CodeCache.insert(HashValue, KernelFunc, KernelInfo.getName(), RCVec);
683 if (Config::get().ProteusUseStoredCache) {
684 StorageCache.store(HashValue, ObjBuf->getMemBufferRef());
685 }
686
687 return launchKernelFunction(KernelFunc, GridDim, BlockDim, KernelArgs,
688 ShmemSize, Stream);
689}
690
691template <typename ImplT>
694 const char *ModuleId) {
695 CurHandle = Handle;
696 PROTEUS_DBG(Logger::logs("proteus")
697 << "Register fatbinary Handle " << Handle << " FatbinWrapper "
698 << FatbinWrapper << " Binary " << (void *)FatbinWrapper->Binary
699 << " ModuleId " << ModuleId << "\n");
700 if (FatbinWrapper->PrelinkedFatbins) {
701 // This is RDC compilation, just insert the FatbinWrapper and ignore the
702 // ModuleId coming from the link.stub.
703 HandleToBinaryInfo.emplace(Handle, BinaryInfo{FatbinWrapper, {}});
704
705 // Initialize GlobalLinkedBinaries with prelinked fatbins.
706 void *Ptr = FatbinWrapper->PrelinkedFatbins[0];
707 for (int I = 0; Ptr != nullptr;
708 ++I, Ptr = FatbinWrapper->PrelinkedFatbins[I]) {
709 PROTEUS_DBG(Logger::logs("proteus")
710 << "I " << I << " PrelinkedFatbin " << Ptr << "\n");
711 GlobalLinkedBinaries.insert(Ptr);
712 }
713 } else {
714 // This is non-RDC compilation, associate the ModuleId of the JIT bitcode in
715 // the module with the FatbinWrapper.
716 ModuleIdToFatBinary[ModuleId] = FatbinWrapper;
717 HandleToBinaryInfo.emplace(Handle, BinaryInfo{FatbinWrapper, {ModuleId}});
718 }
719}
720
721template <typename ImplT> void JitEngineDevice<ImplT>::registerFatBinaryEnd() {
722 PROTEUS_DBG(Logger::logs("proteus") << "Register fatbinary end\n");
723 // Erase linked binaries for which we have LLVM IR code, those binaries are
724 // stored in the ModuleIdToFatBinary map.
725 for (auto &[ModuleId, FatbinWrapper] : ModuleIdToFatBinary)
726 GlobalLinkedBinaries.erase((void *)FatbinWrapper->Binary);
727
728 CurHandle = nullptr;
729}
730
731template <typename ImplT>
733 void *Handle, void *Kernel, char *KernelName,
734 ArrayRef<RuntimeConstantInfo *> RCInfoArray) {
735 PROTEUS_DBG(Logger::logs("proteus") << "Register function " << Kernel
736 << " To Handle " << Handle << "\n");
737 // NOTE: HIP RDC might call multiple times the registerFunction for the same
738 // kernel, which has weak linkage, when it comes from different translation
739 // units. Either the first or the second call can prevail and should be
740 // equivalent. We let the first one prevail.
741 if (JITKernelInfoMap.contains(Kernel)) {
742 PROTEUS_DBG(Logger::logs("proteus")
743 << "Warning: duplicate register function for kernel " +
744 std::string(KernelName)
745 << "\n");
746 return;
747 }
748
749 if (!HandleToBinaryInfo.count(Handle))
750 PROTEUS_FATAL_ERROR("Expected Handle in map");
751 BinaryInfo &BinInfo = HandleToBinaryInfo[Handle];
752
753 JITKernelInfoMap[Kernel] =
755}
756
757template <typename ImplT>
759 const char *ModuleId) {
760 PROTEUS_DBG(Logger::logs("proteus")
761 << "Register linked binary FatBinary " << FatbinWrapper
762 << " Binary " << (void *)FatbinWrapper->Binary << " ModuleId "
763 << ModuleId << "\n");
764 if (CurHandle) {
765 if (!HandleToBinaryInfo.count(CurHandle))
766 PROTEUS_FATAL_ERROR("Expected CurHandle in map");
767
768 HandleToBinaryInfo[CurHandle].addModuleId(ModuleId);
769 } else
770 GlobalLinkedModuleIds.push_back(ModuleId);
771
772 ModuleIdToFatBinary[ModuleId] = FatbinWrapper;
773}
774
775} // namespace proteus
776
777#endif
void const char * ModuleId
Definition CompilerInterfaceDevice.cpp:31
void * FatbinWrapper
Definition CompilerInterfaceDevice.cpp:30
const char * VarName
Definition CompilerInterfaceDevice.cpp:20
void char * KernelName
Definition CompilerInterfaceDevice.cpp:50
void * Kernel
Definition CompilerInterfaceDevice.cpp:50
ArrayRef< RuntimeConstantInfo * > RCInfoArray
Definition CompilerInterfaceHost.cpp:24
#define PROTEUS_DBG(x)
Definition Debug.h:10
#define PROTEUS_FATAL_ERROR(x)
Definition Error.h:4
void getLambdaJitValues(StringRef FnName, SmallVector< RuntimeConstant > &LambdaJitValuesVec)
Definition JitEngineHost.cpp:271
#define TIMESCOPE(x)
Definition TimeTracing.hpp:64
#define PROTEUS_TIMER_OUTPUT(x)
Definition TimeTracing.hpp:57
Definition JitEngineDevice.hpp:79
FatbinWrapperT * getFatbinWrapper() const
Definition JitEngineDevice.hpp:99
void setExtractedModules(SmallVector< std::unique_ptr< Module > > &Modules)
Definition JitEngineDevice.hpp:147
MemoryBufferRef getDeviceBinary()
Definition JitEngineDevice.hpp:171
bool hasModuleHash() const
Definition JitEngineDevice.hpp:151
std::unique_ptr< LLVMContext > & getLLVMContext()
Definition JitEngineDevice.hpp:101
Module & getLinkedModule()
Definition JitEngineDevice.hpp:104
auto & getModuleIds()
Definition JitEngineDevice.hpp:184
bool hasLinkedModule() const
Definition JitEngineDevice.hpp:103
bool hasDeviceBinary()
Definition JitEngineDevice.hpp:170
const SmallVector< std::reference_wrapper< Module > > getExtractedModules() const
Definition JitEngineDevice.hpp:139
void updateModuleHash(HashT HashValue)
Definition JitEngineDevice.hpp:154
HashT getModuleHash() const
Definition JitEngineDevice.hpp:152
bool hasExtractedModules() const
Definition JitEngineDevice.hpp:137
void addModuleId(const char *ModuleId)
Definition JitEngineDevice.hpp:180
CallGraph & getCallGraph()
Definition JitEngineDevice.hpp:161
BinaryInfo(FatbinWrapperT *FatbinWrapper, SmallVector< std::string > &&LinkedModuleIds)
Definition JitEngineDevice.hpp:92
void setModuleHash(HashT HashValue)
Definition JitEngineDevice.hpp:153
void setDeviceBinary(std::unique_ptr< MemoryBuffer > DeviceBinaryBuffer)
Definition JitEngineDevice.hpp:176
Definition CompilationTask.hpp:17
static CompilerAsync & instance(int NumThreads)
Definition CompilerAsync.hpp:49
void joinAllThreads()
Definition CompilerAsync.hpp:86
std::unique_ptr< MemoryBuffer > compile(CompilationTask &&CT)
Definition CompilerSync.hpp:21
static CompilerSync & instance()
Definition CompilerSync.hpp:16
CodegenOption ProteusCodegen
Definition Config.hpp:129
static Config & get()
Definition Config.hpp:112
bool ProteusRelinkGlobalsByCopy
Definition Config.hpp:123
bool ProteusSpecializeDims
Definition Config.hpp:120
bool ProteusDumpLLVMIR
Definition Config.hpp:122
bool ProteusUseStoredCache
Definition Config.hpp:117
bool ProteusSpecializeArgs
Definition Config.hpp:119
bool ProteusSpecializeLaunchBounds
Definition Config.hpp:118
Definition Func.hpp:19
Definition Hashing.hpp:19
std::string toString() const
Definition Hashing.hpp:27
Definition JitEngineDevice.hpp:187
bool hasBitcode()
Definition JitEngineDevice.hpp:222
const std::string & getName() const
Definition JitEngineDevice.hpp:213
void createStaticHash(HashT ModuleHash)
Definition JitEngineDevice.hpp:230
JITKernelInfo(void *Kernel, BinaryInfo &BinInfo, char const *Name, ArrayRef< RuntimeConstantInfo * > RCInfoArray)
Definition JitEngineDevice.hpp:200
void * getKernel() const
Definition JitEngineDevice.hpp:208
bool hasModule() const
Definition JitEngineDevice.hpp:215
const HashT getStaticHash() const
Definition JitEngineDevice.hpp:229
BinaryInfo & getBinaryInfo() const
Definition JitEngineDevice.hpp:217
ArrayRef< RuntimeConstantInfo * > getRCInfoArray() const
Definition JitEngineDevice.hpp:214
Module & getModule() const
Definition JitEngineDevice.hpp:216
void setModule(std::unique_ptr< llvm::Module > Mod)
Definition JitEngineDevice.hpp:218
void setLambdaCalleeInfo(SmallVector< std::pair< std::string, StringRef > > &&LambdaInfo)
Definition JitEngineDevice.hpp:237
const auto & getLambdaCalleeInfo()
Definition JitEngineDevice.hpp:236
bool hasLambdaCalleeInfo()
Definition JitEngineDevice.hpp:235
MemoryBufferRef getBitcode()
Definition JitEngineDevice.hpp:226
bool hasStaticHash() const
Definition JitEngineDevice.hpp:228
std::unique_ptr< LLVMContext > & getLLVMContext()
Definition JitEngineDevice.hpp:212
void setBitcode(std::unique_ptr< MemoryBuffer > ExtractedBitcode)
Definition JitEngineDevice.hpp:223
Definition JitCache.hpp:32
void printStats()
Definition JitCache.hpp:67
Definition JitEngineDevice.hpp:245
JitCache< KernelFunction_t > CodeCache
Definition JitEngineDevice.hpp:527
MemoryBufferRef getBitcode(JITKernelInfo &KernelInfo)
Definition JitEngineDevice.hpp:371
~JitEngineDevice()
Definition JitEngineDevice.hpp:522
typename DeviceTraits< ImplT >::DeviceError_t DeviceError_t
Definition JitEngineDevice.hpp:248
void extractModuleAndBitcode(JITKernelInfo &KernelInfo)
Definition JitEngineDevice.hpp:332
void registerLinkedBinary(FatbinWrapperT *FatbinWrapper, const char *ModuleId)
Definition JitEngineDevice.hpp:758
JitEngineDevice()
Definition JitEngineDevice.hpp:520
std::unordered_map< const void *, BinaryInfo > HandleToBinaryInfo
Definition JitEngineDevice.hpp:431
DenseMap< const void *, JITKernelInfo > JITKernelInfoMap
Definition JitEngineDevice.hpp:532
typename DeviceTraits< ImplT >::DeviceStream_t DeviceStream_t
Definition JitEngineDevice.hpp:249
void registerFatBinaryEnd()
Definition JitEngineDevice.hpp:721
Module & getModule(JITKernelInfo &KernelInfo)
Definition JitEngineDevice.hpp:361
JitStorageCache< KernelFunction_t > StorageCache
Definition JitEngineDevice.hpp:528
bool containsJITKernelInfo(const void *Func)
Definition JitEngineDevice.hpp:435
std::optional< std::reference_wrapper< JITKernelInfo > > getJITKernelInfo(const void *Func)
Definition JitEngineDevice.hpp:440
std::pair< std::unique_ptr< Module >, std::unique_ptr< MemoryBuffer > > extractKernelModule(BinaryInfo &BinInfo, StringRef KernelName, LLVMContext &Ctx)
Definition JitEngineDevice.hpp:258
SmallPtrSet< void *, 8 > GlobalLinkedBinaries
Definition JitEngineDevice.hpp:433
void registerFunction(void *Handle, void *Kernel, char *KernelName, ArrayRef< RuntimeConstantInfo * > RCInfoArray)
Definition JitEngineDevice.hpp:732
std::unordered_map< std::string, const void * > VarNameToDevPtr
Definition JitEngineDevice.hpp:530
void finalize()
Definition JitEngineDevice.hpp:464
void * CurHandle
Definition JitEngineDevice.hpp:429
void registerFatBinary(void *Handle, FatbinWrapperT *FatbinWrapper, const char *ModuleId)
Definition JitEngineDevice.hpp:692
DeviceError_t compileAndRun(JITKernelInfo &KernelInfo, dim3 GridDim, dim3 BlockDim, void **KernelArgs, uint64_t ShmemSize, typename DeviceTraits< ImplT >::DeviceStream_t Stream)
Definition JitEngineDevice.hpp:563
std::unordered_map< std::string, FatbinWrapperT * > ModuleIdToFatBinary
Definition JitEngineDevice.hpp:430
typename DeviceTraits< ImplT >::KernelFunction_t KernelFunction_t
Definition JitEngineDevice.hpp:250
void insertRegisterVar(const char *VarName, const void *Addr)
Definition JitEngineDevice.hpp:418
SmallVector< std::string > GlobalLinkedModuleIds
Definition JitEngineDevice.hpp:432
std::string DeviceArch
Definition JitEngineDevice.hpp:529
HashT getStaticHash(JITKernelInfo &KernelInfo)
Definition JitEngineDevice.hpp:447
void getLambdaJitValues(JITKernelInfo &KernelInfo, SmallVector< RuntimeConstant > &LambdaJitValuesVec)
Definition JitEngineDevice.hpp:381
Definition JitEngine.hpp:33
Definition JitStorageCache.hpp:36
void printStats()
Definition JitStorageCache.hpp:65
Definition LambdaRegistry.hpp:19
std::optional< DenseMap< StringRef, SmallVector< RuntimeConstant > >::iterator > matchJitVariableMap(StringRef FnName)
Definition LambdaRegistry.hpp:27
const SmallVector< RuntimeConstant > & getJitVariables(StringRef LambdaTypeRef)
Definition LambdaRegistry.hpp:74
static LambdaRegistry & instance()
Definition LambdaRegistry.hpp:21
bool empty()
Definition LambdaRegistry.hpp:78
static llvm::raw_ostream & outs(const std::string &Name)
Definition Logger.hpp:25
static llvm::raw_ostream & logs(const std::string &Name)
Definition Logger.hpp:19
Definition TimeTracing.hpp:36
uint64_t elapsed()
Definition TimeTracing.hpp:45
Definition Dispatcher.cpp:14
std::unique_ptr< Module > cloneKernelFromModules(ArrayRef< std::reference_wrapper< Module > > Mods, StringRef EntryName)
Definition Cloning.h:496
HashT hash(FirstT &&First, RestTs &&...Rest)
Definition Hashing.hpp:73
cudaError_t launchKernelDirect(void *KernelFunc, dim3 GridDim, dim3 BlockDim, void **KernelArgs, uint64_t ShmemSize, CUstream Stream)
Definition CoreDeviceCUDA.hpp:20
cudaError_t launchKernelFunction(CUfunction KernelFunc, dim3 GridDim, dim3 BlockDim, void **KernelArgs, uint64_t ShmemSize, CUstream Stream)
Definition CoreDeviceCUDA.hpp:51
CUfunction getKernelFunctionFromImage(StringRef KernelName, const void *Image, bool RelinkGlobalsByCopy, const std::unordered_map< std::string, const void * > &VarNameToDevPtr)
Definition CoreDeviceCUDA.hpp:27
HashT hashCombine(HashT A, HashT B)
Definition Hashing.hpp:68
void pruneIR(Module &M, bool UnsetExternallyInitialized=true)
Definition CoreLLVM.hpp:202
std::string toString(CodegenOption Option)
Definition Config.hpp:23
void * resolveDeviceGlobalAddr(const void *Addr)
Definition CoreDeviceCUDA.hpp:12
void internalize(Module &M, StringRef PreserveFunctionName)
Definition CoreLLVM.hpp:237
void runCleanupPassPipeline(Module &M)
Definition CoreLLVM.hpp:178
std::unique_ptr< Module > linkModules(LLVMContext &Ctx, SmallVector< std::unique_ptr< Module > > LinkedModules)
Definition CoreLLVM.hpp:161
Definition Hashing.hpp:94
Definition JitEngineDevice.hpp:243
Definition JitEngineDevice.hpp:72
const char * Binary
Definition JitEngineDevice.hpp:75
void ** PrelinkedFatbins
Definition JitEngineDevice.hpp:76
int32_t Magic
Definition JitEngineDevice.hpp:73
int32_t Version
Definition JitEngineDevice.hpp:74