Proteus
Programmable JIT compilation and optimization for C/C++ using LLVM
Loading...
Searching...
No Matches
JitEngineDevice.hpp
Go to the documentation of this file.
1//===-- JitEngineDevice.hpp -- Base JIT Engine Device header impl. --===//
2//
3// Part of the Proteus Project, under the Apache License v2.0 with LLVM
4// Exceptions. See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9//===----------------------------------------------------------------------===//
10
11#ifndef PROTEUS_JITENGINEDEVICE_HPP
12#define PROTEUS_JITENGINEDEVICE_HPP
13
14#include <cstdint>
15#include <functional>
16#include <llvm/ADT/SmallPtrSet.h>
17#include <llvm/Analysis/CallGraph.h>
18#include <memory>
19#include <optional>
20#include <string>
21
22#include <llvm/ADT/SmallVector.h>
23#include <llvm/ADT/StringRef.h>
24#include <llvm/Analysis/TargetTransformInfo.h>
25#include <llvm/Bitcode/BitcodeWriter.h>
26#include <llvm/CodeGen/CommandFlags.h>
27#include <llvm/CodeGen/MachineModuleInfo.h>
28#include <llvm/Config/llvm-config.h>
29#include <llvm/Demangle/Demangle.h>
30#include <llvm/ExecutionEngine/Orc/ThreadSafeModule.h>
31#include <llvm/IR/Constants.h>
32#include <llvm/IR/GlobalVariable.h>
33#include <llvm/IR/Instruction.h>
34#include <llvm/IR/Instructions.h>
35#include <llvm/IR/LLVMContext.h>
36#include <llvm/IR/LegacyPassManager.h>
37#include <llvm/IR/Module.h>
38#include <llvm/IR/ReplaceConstant.h>
39#include <llvm/IR/Type.h>
40#include <llvm/IR/Verifier.h>
41#include <llvm/IRReader/IRReader.h>
42#include <llvm/Linker/Linker.h>
43#include <llvm/MC/TargetRegistry.h>
44#include <llvm/Object/ELFObjectFile.h>
45#include <llvm/Passes/PassBuilder.h>
46#include <llvm/Support/Error.h>
47#include <llvm/Support/MemoryBuffer.h>
48#include <llvm/Support/MemoryBufferRef.h>
49#include <llvm/Target/TargetMachine.h>
50#include <llvm/Transforms/IPO/Internalize.h>
51#include <llvm/Transforms/Utils/Cloning.h>
52#include <llvm/Transforms/Utils/ModuleUtils.h>
53
58#include "proteus/CoreLLVM.hpp"
59#include "proteus/Debug.h"
60#include "proteus/Hashing.hpp"
61#include "proteus/JitCache.hpp"
62#include "proteus/JitEngine.hpp"
65#include "proteus/Utils.h"
66
67namespace proteus {
68
69using namespace llvm;
70
72 int32_t Magic;
73 int32_t Version;
74 const char *Binary;
76};
77
79private:
80 FatbinWrapperT *FatbinWrapper;
81 SmallVector<std::string> LinkedModuleIds;
82 std::unique_ptr<Module> ExtractedModule;
83 std::optional<HashT> ExtractedModuleHash;
84 std::optional<CallGraph> ModuleCallGraph;
85
86public:
87 BinaryInfo() = default;
88 BinaryInfo(FatbinWrapperT *FatbinWrapper,
89 SmallVector<std::string> &&LinkedModuleIds)
90 : FatbinWrapper(FatbinWrapper), LinkedModuleIds(LinkedModuleIds),
91 ModuleCallGraph(std::nullopt) {}
92
94
95 bool hasModule() const { return (ExtractedModule != nullptr); }
96 Module &getModule() const { return *ExtractedModule; }
97 void setModule(std::unique_ptr<Module> Module) {
98 ExtractedModule = std::move(Module);
99 }
100
101 bool hasModuleHash() const { return ExtractedModuleHash.has_value(); }
102 HashT getModuleHash() const { return ExtractedModuleHash.value(); }
103 void setModuleHash(HashT HashValue) { ExtractedModuleHash = HashValue; }
104 void updateModuleHash(HashT HashValue) {
105 if (ExtractedModuleHash)
106 ExtractedModuleHash = hashCombine(ExtractedModuleHash.value(), HashValue);
107 else
108 ExtractedModuleHash = HashValue;
109 }
110
111 CallGraph &getCallGraph() {
112 if (!ModuleCallGraph.has_value()) {
113 ModuleCallGraph.emplace(CallGraph(*ExtractedModule));
114 }
115 return ModuleCallGraph.value();
116 }
117
118 void addModuleId(const char *ModuleId) {
119 LinkedModuleIds.push_back(ModuleId);
120 }
121
122 auto &getModuleIds() { return LinkedModuleIds; }
123};
124
126 std::optional<void *> Kernel;
127 std::string Name;
128 SmallVector<int32_t> RCTypes;
129 SmallVector<int32_t> RCIndices;
130 std::optional<std::unique_ptr<Module>> ExtractedModule;
131 std::optional<std::reference_wrapper<BinaryInfo>> BinInfo;
132 std::optional<HashT> StaticHash;
133 std::optional<SmallVector<std::pair<std::string, StringRef>>>
134 LambdaCalleeInfo;
135
136public:
137 JITKernelInfo(void *Kernel, BinaryInfo &BinInfo, char const *Name,
138 int32_t *RCIndices, int32_t *RCTypes, int32_t NumRCs)
139 : Kernel(Kernel), Name(Name),
140 RCTypes{ArrayRef{RCTypes, static_cast<size_t>(NumRCs)}},
141 RCIndices{ArrayRef{RCIndices, static_cast<size_t>(NumRCs)}},
142 ExtractedModule(std::nullopt), BinInfo(BinInfo),
143 LambdaCalleeInfo(std::nullopt) {}
144
// Default-constructs an empty entry. The engine accesses JITKernelInfoMap via
// operator[], which requires a default-constructible value type.
145 JITKernelInfo() = default;
146 void *getKernel() const {
147 assert(Kernel.has_value() && "Expected Kernel is inited");
148 return Kernel.value();
149 }
150 const std::string &getName() const { return Name; }
151 const auto &getRCIndices() const { return RCIndices; }
152 const auto &getRCTypes() const { return RCTypes; }
153 bool hasModule() const { return ExtractedModule.has_value(); }
154 Module &getModule() const { return *ExtractedModule->get(); }
155 BinaryInfo &getBinaryInfo() const { return BinInfo.value(); }
156 void setModule(std::unique_ptr<llvm::Module> Mod) {
157 ExtractedModule = std::move(Mod);
158 }
159 bool hasStaticHash() const { return StaticHash.has_value(); }
160 const HashT getStaticHash() const { return StaticHash.value(); }
161 void createStaticHash(HashT ModuleHash) {
162 StaticHash = hash(Name);
163 StaticHash = hashCombine(StaticHash.value(), ModuleHash);
164 }
165
166 bool hasLambdaCalleeInfo() { return LambdaCalleeInfo.has_value(); }
167 const auto &getLambdaCalleeInfo() { return LambdaCalleeInfo.value(); }
169 SmallVector<std::pair<std::string, StringRef>> &&LambdaInfo) {
170 LambdaCalleeInfo = std::move(LambdaInfo);
171 }
172};
173
174template <typename ImplT> struct DeviceTraits;
175
176template <typename ImplT> class JitEngineDevice : public JitEngine {
177
178private:
179 // LLVMContext needs to destroy after all associated Module objects have been
180 // destroyed. Declared first to destroy last.
181 LLVMContext Ctx;
182
183public:
187
189 compileAndRun(JITKernelInfo &KernelInfo, dim3 GridDim, dim3 BlockDim,
190 void **KernelArgs, uint64_t ShmemSize,
192
193 Module &getModule(JITKernelInfo &KernelInfo) {
194 TIMESCOPE(__FUNCTION__)
195
196 if (KernelInfo.hasModule())
197 return KernelInfo.getModule();
198
199 BinaryInfo &BinInfo = KernelInfo.getBinaryInfo();
200
201 if (!BinInfo.hasModule()) {
202 std::unique_ptr<Module> ExtractedModule =
203 static_cast<ImplT &>(*this).extractModule(BinInfo);
204
205 pruneIR(*ExtractedModule);
206 runCleanupPassPipeline(*ExtractedModule);
207
208 BinInfo.setModule(std::move(ExtractedModule));
209 }
210
211 auto &BinModule = BinInfo.getModule();
212 std::unique_ptr<Module> KernelModule{nullptr};
213
214 if (Config.PROTEUS_USE_LIGHTWEIGHT_KERNEL_CLONE) {
215 KernelModule = std::move(proteus::cloneKernelFromModule(
216 BinModule, getLLVMContext(), KernelInfo.getName(),
217 BinInfo.getCallGraph()));
218 } else {
219 KernelModule = llvm::CloneModule(BinModule);
220 }
221
222 internalize(*KernelModule, KernelInfo.getName());
223 runCleanupPassPipeline(*KernelModule);
224
225 KernelInfo.setModule(std::move(KernelModule));
226 return KernelInfo.getModule();
227 }
228
230 SmallVector<RuntimeConstant> &LambdaJitValuesVec) {
232 if (LR.empty()) {
233 KernelInfo.setLambdaCalleeInfo({});
234 return;
235 }
236
237 if (!KernelInfo.hasLambdaCalleeInfo()) {
238 Module &KernelModule = getModule(KernelInfo);
239 PROTEUS_DBG(Logger::logs("proteus")
240 << "=== LAMBDA MATCHING\n"
241 << "Caller trigger " << KernelInfo.getName() << " -> "
242 << demangle(KernelInfo.getName()) << "\n");
243
244 SmallVector<std::pair<std::string, StringRef>> LambdaCalleeInfo;
245 for (auto &F : KernelModule.getFunctionList()) {
246 PROTEUS_DBG(Logger::logs("proteus")
247 << " Trying F " << demangle(F.getName().str()) << "\n ");
248 auto OptionalMapIt =
250 if (OptionalMapIt)
251 LambdaCalleeInfo.emplace_back(F.getName(),
252 OptionalMapIt.value()->first);
253 }
254
255 KernelInfo.setLambdaCalleeInfo(std::move(LambdaCalleeInfo));
256 }
257
258 for (auto &[FnName, LambdaType] : KernelInfo.getLambdaCalleeInfo()) {
259 const SmallVector<RuntimeConstant> &Values =
260 LR.getJitVariables(LambdaType);
261 LambdaJitValuesVec.insert(LambdaJitValuesVec.end(), Values.begin(),
262 Values.end());
263 }
264 }
265
266 void insertRegisterVar(const char *VarName, const void *Addr) {
267 VarNameToDevPtr[VarName] = Addr;
268 }
270 const char *ModuleId);
272 const char *ModuleId);
274 void registerFunction(void *Handle, void *Kernel, char *KernelName,
275 int32_t *RCIndices, int32_t *RCTypes, int32_t NumRCs);
276
277 void *CurHandle = nullptr;
278 std::unordered_map<std::string, FatbinWrapperT *> ModuleIdToFatBinary;
279 std::unordered_map<const void *, BinaryInfo> HandleToBinaryInfo;
280 SmallVector<std::string> GlobalLinkedModuleIds;
281 SmallPtrSet<void *, 8> GlobalLinkedBinaries;
282
283 bool containsJITKernelInfo(const void *Func) {
284 return JITKernelInfoMap.contains(Func);
285 }
286
287 std::optional<std::reference_wrapper<JITKernelInfo>>
288 getJITKernelInfo(const void *Func) {
289 if (!containsJITKernelInfo(Func)) {
290 return std::nullopt;
291 }
292 return JITKernelInfoMap[Func];
293 }
294
296 if (KernelInfo.hasStaticHash())
297 return KernelInfo.getStaticHash();
298
299 BinaryInfo &BinInfo = KernelInfo.getBinaryInfo();
300
301 if (BinInfo.hasModuleHash()) {
302 KernelInfo.createStaticHash(BinInfo.getModuleHash());
303 return KernelInfo.getStaticHash();
304 }
305
306 HashT ModuleHash = static_cast<ImplT &>(*this).getModuleHash(BinInfo);
307
308 KernelInfo.createStaticHash(BinInfo.getModuleHash());
309 return KernelInfo.getStaticHash();
310 }
311
312 void finalize() {
313 if (Config.PROTEUS_ASYNC_COMPILATION)
314 CompilerAsync::instance(Config.PROTEUS_ASYNC_THREADS).joinAllThreads();
315 }
316
317private:
318 //------------------------------------------------------------------
319 // Begin Methods implemented in the derived device engine class.
320 //------------------------------------------------------------------
321 void *resolveDeviceGlobalAddr(const void *Addr) {
322 return static_cast<ImplT &>(*this).resolveDeviceGlobalAddr(Addr);
323 }
324
325 void setLaunchBoundsForKernel(Module &M, Function &F, size_t GridSize,
326 int BlockSize) {
327 static_cast<ImplT &>(*this).setLaunchBoundsForKernel(M, F, GridSize,
328 BlockSize);
329 }
330
331 void setKernelDims(Module &M, dim3 &GridDim, dim3 &BlockDim) {
332 proteus::setKernelDims(M, GridDim, BlockDim);
333 }
334
335 DeviceError_t launchKernelFunction(KernelFunction_t KernelFunc, dim3 GridDim,
336 dim3 BlockDim, void **KernelArgs,
337 uint64_t ShmemSize,
338 DeviceStream_t Stream) {
339 TIMESCOPE(__FUNCTION__);
340 return static_cast<ImplT &>(*this).launchKernelFunction(
341 KernelFunc, GridDim, BlockDim, KernelArgs, ShmemSize, Stream);
342 }
343
344 void relinkGlobalsObject(MemoryBufferRef Object) {
345 TIMESCOPE(__FUNCTION__);
346 proteus::relinkGlobalsObject(Object, VarNameToDevPtr);
347 }
348
349 std::unique_ptr<MemoryBuffer> codegenObject(Module &M, StringRef DeviceArch) {
350 return static_cast<ImplT &>(*this).codegenObject(M, DeviceArch);
351 }
352
353 KernelFunction_t getKernelFunctionFromImage(StringRef KernelName,
354 const void *Image) {
355 TIMESCOPE(__FUNCTION__);
356 return static_cast<ImplT &>(*this).getKernelFunctionFromImage(KernelName,
357 Image);
358 }
359
360 //------------------------------------------------------------------
361 // End Methods implemented in the derived device engine class.
362 //------------------------------------------------------------------
363
364 void pruneIR(Module &M);
365
366 void internalize(Module &M, StringRef KernelName);
367
368 void specializeIR(Module &M, StringRef FnName, StringRef Suffix,
369 dim3 &BlockDim, dim3 &GridDim,
370 const SmallVector<int32_t> &RCIndices,
371 const SmallVector<RuntimeConstant> &RCVec);
372
373 void replaceGlobalVariablesWithPointers(Module &M);
374
375protected:
377
382
385 std::string DeviceArch;
386 std::unordered_map<std::string, const void *> VarNameToDevPtr;
387 std::unique_ptr<Module>
388 linkJitModule(SmallVector<std::unique_ptr<Module>> &LinkedModules,
389 std::unique_ptr<Module> LTOModule = nullptr);
390
391 LLVMContext &getLLVMContext() { return Ctx; }
392
393 DenseMap<const void *, JITKernelInfo> JITKernelInfoMap;
394};
395
396template <typename ImplT>
398 Module &M, StringRef FnName, StringRef Suffix, dim3 &BlockDim,
399 dim3 &GridDim, const SmallVector<int32_t> &RCIndices,
400 const SmallVector<RuntimeConstant> &RCVec) {
401 TIMESCOPE("specializeIR");
402
403 proteus::specializeIR(M, FnName, Suffix, BlockDim, GridDim, RCIndices, RCVec,
404 Config.PROTEUS_SPECIALIZE_ARGS,
405 Config.PROTEUS_SPECIALIZE_DIMS,
406 Config.PROTEUS_SET_LAUNCH_BOUNDS);
407
408#if PROTEUS_ENABLE_DEBUG
409 Logger::logs("proteus") << "=== Final Module\n"
410 << M << "=== End Final Module\n";
411 if (verifyModule(M, &errs()))
412 PROTEUS_FATAL_ERROR("Broken module found, JIT compilation aborted!");
413 else
414 Logger::logs("proteus") << "Module verified!\n";
415#endif
416}
417
418template <typename ImplT> void JitEngineDevice<ImplT>::pruneIR(Module &M) {
419 TIMESCOPE("pruneIR");
421}
422
423template <typename ImplT>
424void JitEngineDevice<ImplT>::internalize(Module &M, StringRef KernelName) {
426}
427
428template <typename ImplT>
429void JitEngineDevice<ImplT>::replaceGlobalVariablesWithPointers(Module &M) {
430 TIMESCOPE(__FUNCTION__)
431
432 proteus::replaceGlobalVariablesWithPointers(M, VarNameToDevPtr);
433
434#if PROTEUS_ENABLE_DEBUG
435 Logger::logs("proteus") << "=== Linked M\n" << M << "=== End of Linked M\n";
436 if (verifyModule(M, &errs()))
438 "After linking, broken module found, JIT compilation aborted!");
439 else
440 Logger::logs("proteus") << "Module verified!\n";
441#endif
442}
443
444template <typename ImplT>
445typename DeviceTraits<ImplT>::DeviceError_t
447 JITKernelInfo &KernelInfo, dim3 GridDim, dim3 BlockDim, void **KernelArgs,
448 uint64_t ShmemSize, typename DeviceTraits<ImplT>::DeviceStream_t Stream) {
449 TIMESCOPE("compileAndRun");
450
451 // Lazy initialize the map of device global variables to device pointers by
452 // resolving the host address to the device address. For HIP it is fine to do
453 // this earlier (e.g., instertRegisterVar), but CUDA can't. So, we initialize
454 // this here the first time we need to compile a kernel.
455 static std::once_flag Flag;
456 std::call_once(Flag, [&]() {
457 for (auto &[GlobalName, HostAddr] : VarNameToDevPtr) {
458 void *DevPtr = resolveDeviceGlobalAddr(HostAddr);
459 VarNameToDevPtr.at(GlobalName) = DevPtr;
460 }
461 });
462
463 SmallVector<RuntimeConstant> RCVec;
464 SmallVector<RuntimeConstant> LambdaJitValuesVec;
465
466 getRuntimeConstantValues(KernelArgs, KernelInfo.getRCIndices(),
467 KernelInfo.getRCTypes(), RCVec);
468 getLambdaJitValues(KernelInfo, LambdaJitValuesVec);
469
470 HashT HashValue =
471 hash(getStaticHash(KernelInfo), RCVec, LambdaJitValuesVec, GridDim.x,
472 GridDim.x, GridDim.y, GridDim.z, BlockDim.x, BlockDim.y, BlockDim.z);
473
474 typename DeviceTraits<ImplT>::KernelFunction_t KernelFunc =
475 CodeCache.lookup(HashValue);
476 if (KernelFunc)
477 return launchKernelFunction(KernelFunc, GridDim, BlockDim, KernelArgs,
478 ShmemSize, Stream);
479
480 // NOTE: we don't need a suffix to differentiate kernels, each specialization
481 // will be in its own module uniquely identify by HashValue. It exists only
482 // for debugging purposes to verify that the jitted kernel executes.
483 std::string Suffix = mangleSuffix(HashValue);
484 std::string KernelMangled = (KernelInfo.getName() + Suffix);
485
486 if (Config.PROTEUS_USE_STORED_CACHE) {
487 // If there device global variables, lookup the IR and codegen object
488 // before launching. Else, if there aren't device global variables, lookup
489 // the object and launch.
490
491 // TODO: Check for globals is very conservative and always re-builds from
492 // LLVM IR even if the Jit module does not use global variables. A better
493 // solution is to keep track of whether a kernel uses gvars (store a flag in
494 // the cache file?) and load the object in case it does not use any.
495 // TODO: Can we use RTC interfaces for fast linking on object files?
496 auto CacheBuf = StorageCache.lookup(HashValue);
497 if (CacheBuf) {
498 if (!Config.PROTEUS_RELINK_GLOBALS_BY_COPY)
499 relinkGlobalsObject(CacheBuf->getMemBufferRef());
500
501 auto KernelFunc =
502 getKernelFunctionFromImage(KernelMangled, CacheBuf->getBufferStart());
503
504 CodeCache.insert(HashValue, KernelFunc, KernelInfo.getName(), RCVec);
505
506 return launchKernelFunction(KernelFunc, GridDim, BlockDim, KernelArgs,
507 ShmemSize, Stream);
508 }
509 }
510
511 Module &KernelModule = getModule(KernelInfo);
512 std::unique_ptr<MemoryBuffer> ObjBuf = nullptr;
513
514 if (Config.PROTEUS_ASYNC_COMPILATION) {
515 auto &Compiler = CompilerAsync::instance(Config.PROTEUS_ASYNC_THREADS);
516 // If there is no compilation pending for the specialization, post the
517 // compilation task to the compiler.
518 if (!Compiler.isCompilationPending(HashValue)) {
519 PROTEUS_DBG(Logger::logs("proteus") << "Compile async for HashValue "
520 << HashValue.toString() << "\n");
521
522 Compiler.compile(CompilationTask{
523 KernelModule, HashValue, KernelInfo.getName(), Suffix, BlockDim,
524 GridDim, KernelInfo.getRCIndices(), RCVec,
525 KernelInfo.getLambdaCalleeInfo(), VarNameToDevPtr,
526 GlobalLinkedBinaries, DeviceArch,
527 /* UseRTC */ Config.PROTEUS_USE_HIP_RTC_CODEGEN,
528 /* DumpIR */ Config.PROTEUS_DUMP_LLVM_IR,
529 /* RelinkGlobalsByCopy */ Config.PROTEUS_RELINK_GLOBALS_BY_COPY,
530 /*SpecializeArgs=*/Config.PROTEUS_SPECIALIZE_ARGS,
531 /*SpecializeDims=*/Config.PROTEUS_SPECIALIZE_DIMS,
532 /*SpecializeLaunchBounds=*/Config.PROTEUS_SET_LAUNCH_BOUNDS});
533 }
534
535 // Compilation is pending, try to get the compilation result buffer. If
536 // buffer is null, compilation is not done, so execute the AOT version
537 // directly.
538 ObjBuf = Compiler.takeCompilationResult(HashValue,
539 Config.PROTEUS_ASYNC_TEST_BLOCKING);
540 if (!ObjBuf) {
541 return launchKernelDirect(KernelInfo.getKernel(), GridDim, BlockDim,
542 KernelArgs, ShmemSize, Stream);
543 }
544 } else {
545 // Process through synchronous compilation.
547 KernelModule, HashValue, KernelInfo.getName(), Suffix, BlockDim,
548 GridDim, KernelInfo.getRCIndices(), RCVec,
549 KernelInfo.getLambdaCalleeInfo(), VarNameToDevPtr, GlobalLinkedBinaries,
550 DeviceArch,
551 /* UseRTC */ Config.PROTEUS_USE_HIP_RTC_CODEGEN,
552 /* DumpIR */ Config.PROTEUS_DUMP_LLVM_IR,
553 /* RelinkGlobalsByCopy */ Config.PROTEUS_RELINK_GLOBALS_BY_COPY,
554 /*SpecializeArgs=*/Config.PROTEUS_SPECIALIZE_ARGS,
555 /*SpecializeDims=*/Config.PROTEUS_SPECIALIZE_DIMS,
556 /*SpecializeLaunchBounds=*/Config.PROTEUS_SET_LAUNCH_BOUNDS});
557 }
558
560 KernelMangled, ObjBuf->getBufferStart(),
561 Config.PROTEUS_RELINK_GLOBALS_BY_COPY, VarNameToDevPtr);
562
563 CodeCache.insert(HashValue, KernelFunc, KernelInfo.getName(), RCVec);
564 if (Config.PROTEUS_USE_STORED_CACHE) {
565 StorageCache.store(HashValue, ObjBuf->getMemBufferRef());
566 }
567
568 return launchKernelFunction(KernelFunc, GridDim, BlockDim, KernelArgs,
569 ShmemSize, Stream);
570}
571
572template <typename ImplT>
575 const char *ModuleId) {
576 CurHandle = Handle;
577 PROTEUS_DBG(Logger::logs("proteus")
578 << "Register fatbinary Handle " << Handle << " FatbinWrapper "
579 << FatbinWrapper << " Binary " << (void *)FatbinWrapper->Binary
580 << " ModuleId " << ModuleId << "\n");
581 if (FatbinWrapper->PrelinkedFatbins) {
582 // This is RDC compilation, just insert the FatbinWrapper and ignore the
583 // ModuleId coming from the link.stub.
584 HandleToBinaryInfo.emplace(Handle, BinaryInfo{FatbinWrapper, {}});
585
586 // Initialize GlobalLinkedBinaries with prelinked fatbins.
587 void *Ptr = FatbinWrapper->PrelinkedFatbins[0];
588 for (int I = 0; Ptr != nullptr;
589 ++I, Ptr = FatbinWrapper->PrelinkedFatbins[I]) {
590 PROTEUS_DBG(Logger::logs("proteus")
591 << "I " << I << " PrelinkedFatbin " << Ptr << "\n");
592 GlobalLinkedBinaries.insert(Ptr);
593 }
594 } else {
595 // This is non-RDC compilation, associate the ModuleId of the JIT bitcode in
596 // the module with the FatbinWrapper.
597 ModuleIdToFatBinary[ModuleId] = FatbinWrapper;
598 HandleToBinaryInfo.emplace(Handle, BinaryInfo{FatbinWrapper, {ModuleId}});
599 }
600}
601
602template <typename ImplT> void JitEngineDevice<ImplT>::registerFatBinaryEnd() {
603 PROTEUS_DBG(Logger::logs("proteus") << "Register fatbinary end\n");
604 // Erase linked binaries for which we have LLVM IR code, those binaries are
605 // stored in the ModuleIdToFatBinary map.
606 for (auto &[ModuleId, FatbinWrapper] : ModuleIdToFatBinary)
607 GlobalLinkedBinaries.erase((void *)FatbinWrapper->Binary);
608
609 CurHandle = nullptr;
610}
611
612template <typename ImplT>
614 char *KernelName,
615 int32_t *RCIndices,
616 int32_t *RCTypes,
617 int32_t NumRCs) {
618 PROTEUS_DBG(Logger::logs("proteus") << "Register function " << Kernel
619 << " To Handle " << Handle << "\n");
620 // NOTE: HIP RDC might call multiple times the registerFunction for the same
621 // kernel, which has weak linkage, when it comes from different translation
622 // units. Either the first or the second call can prevail and should be
623 // equivalent. We let the first one prevail.
624 if (JITKernelInfoMap.contains(Kernel)) {
625 PROTEUS_DBG(Logger::logs("proteus")
626 << "Warning: duplicate register function for kernel " +
627 std::string(KernelName)
628 << "\n");
629 return;
630 }
631
632 if (!HandleToBinaryInfo.count(Handle))
633 PROTEUS_FATAL_ERROR("Expected Handle in map");
634 BinaryInfo &BinInfo = HandleToBinaryInfo[Handle];
635
636 JITKernelInfoMap[Kernel] =
638}
639
640template <typename ImplT>
642 const char *ModuleId) {
643 PROTEUS_DBG(Logger::logs("proteus")
644 << "Register linked binary FatBinary " << FatbinWrapper
645 << " Binary " << (void *)FatbinWrapper->Binary << " ModuleId "
646 << ModuleId << "\n");
647 if (CurHandle) {
648 if (!HandleToBinaryInfo.count(CurHandle))
649 PROTEUS_FATAL_ERROR("Expected CurHandle in map");
650
651 HandleToBinaryInfo[CurHandle].addModuleId(ModuleId);
652 } else
653 GlobalLinkedModuleIds.push_back(ModuleId);
654
655 ModuleIdToFatBinary[ModuleId] = FatbinWrapper;
656}
657
658template <typename ImplT>
660 SmallVector<std::unique_ptr<Module>> &LinkedModules,
661 std::unique_ptr<Module> LTOModule) {
662 if (LinkedModules.empty())
663 PROTEUS_FATAL_ERROR("Expected jit module");
664
665 auto LinkedModule = proteus::linkModules(getLLVMContext(), LinkedModules);
666
667 // Last, link in the LTO module, if there is one. The LTO module includes code
668 // post-optimization, which reduces specialization opportunities for proteus
669 // (e.g., due to inlining). Due to that, we selectively link as needed from
670 // it, to import definitions outside the proteus-compiled bitcode.
671 if (LTOModule) {
672 Linker IRLinker(*LinkedModule);
673 // Remove internal linkage from functions in the LTO module to make them
674 // linkable.
675 for (auto &F : *LTOModule) {
676 if (F.hasInternalLinkage())
677 F.setLinkage(GlobalValue::ExternalLinkage);
678 }
679
680 if (IRLinker.linkInModule(std::move(LTOModule),
681 Linker::Flags::LinkOnlyNeeded))
682 PROTEUS_FATAL_ERROR("Linking failed");
683 }
684
685 return LinkedModule;
686}
687
688} // namespace proteus
689
690#endif
void const char * ModuleId
Definition CompilerInterfaceDevice.cpp:31
void * FatbinWrapper
Definition CompilerInterfaceDevice.cpp:30
void char int32_t int32_t * RCTypes
Definition CompilerInterfaceDevice.cpp:51
const char * VarName
Definition CompilerInterfaceDevice.cpp:20
void char int32_t * RCIndices
Definition CompilerInterfaceDevice.cpp:51
void char * KernelName
Definition CompilerInterfaceDevice.cpp:50
void char int32_t int32_t int32_t NumRCs
Definition CompilerInterfaceDevice.cpp:51
void * Kernel
Definition CompilerInterfaceDevice.cpp:50
#define PROTEUS_DBG(x)
Definition Debug.h:7
#define PROTEUS_FATAL_ERROR(x)
Definition Error.h:4
void getLambdaJitValues(Module &M, StringRef FnName, SmallVector< RuntimeConstant > &LambdaJitValuesVec)
Definition JitEngineHost.cpp:273
#define TIMESCOPE(x)
Definition TimeTracing.hpp:35
Definition JitEngineDevice.hpp:78
FatbinWrapperT * getFatbinWrapper() const
Definition JitEngineDevice.hpp:93
bool hasModuleHash() const
Definition JitEngineDevice.hpp:101
void setModule(std::unique_ptr< Module > Module)
Definition JitEngineDevice.hpp:97
bool hasModule() const
Definition JitEngineDevice.hpp:95
auto & getModuleIds()
Definition JitEngineDevice.hpp:122
Module & getModule() const
Definition JitEngineDevice.hpp:96
void updateModuleHash(HashT HashValue)
Definition JitEngineDevice.hpp:104
HashT getModuleHash() const
Definition JitEngineDevice.hpp:102
void addModuleId(const char *ModuleId)
Definition JitEngineDevice.hpp:118
CallGraph & getCallGraph()
Definition JitEngineDevice.hpp:111
BinaryInfo(FatbinWrapperT *FatbinWrapper, SmallVector< std::string > &&LinkedModuleIds)
Definition JitEngineDevice.hpp:88
void setModuleHash(HashT HashValue)
Definition JitEngineDevice.hpp:103
Definition CompilationTask.hpp:17
static CompilerAsync & instance(int NumThreads)
Definition CompilerAsync.hpp:49
void joinAllThreads()
Definition CompilerAsync.hpp:86
std::unique_ptr< MemoryBuffer > compile(CompilationTask &&CT)
Definition CompilerSync.hpp:21
static CompilerSync & instance()
Definition CompilerSync.hpp:16
Definition Hashing.hpp:19
std::string toString() const
Definition Hashing.hpp:27
Definition JitEngineDevice.hpp:125
const std::string & getName() const
Definition JitEngineDevice.hpp:150
void createStaticHash(HashT ModuleHash)
Definition JitEngineDevice.hpp:161
void * getKernel() const
Definition JitEngineDevice.hpp:146
bool hasModule() const
Definition JitEngineDevice.hpp:153
JITKernelInfo(void *Kernel, BinaryInfo &BinInfo, char const *Name, int32_t *RCIndices, int32_t *RCTypes, int32_t NumRCs)
Definition JitEngineDevice.hpp:137
const auto & getRCTypes() const
Definition JitEngineDevice.hpp:152
const HashT getStaticHash() const
Definition JitEngineDevice.hpp:160
BinaryInfo & getBinaryInfo() const
Definition JitEngineDevice.hpp:155
Module & getModule() const
Definition JitEngineDevice.hpp:154
const auto & getRCIndices() const
Definition JitEngineDevice.hpp:151
void setModule(std::unique_ptr< llvm::Module > Mod)
Definition JitEngineDevice.hpp:156
void setLambdaCalleeInfo(SmallVector< std::pair< std::string, StringRef > > &&LambdaInfo)
Definition JitEngineDevice.hpp:168
const auto & getLambdaCalleeInfo()
Definition JitEngineDevice.hpp:167
bool hasLambdaCalleeInfo()
Definition JitEngineDevice.hpp:166
bool hasStaticHash() const
Definition JitEngineDevice.hpp:159
Definition JitCache.hpp:32
void printStats()
Definition JitCache.hpp:64
Definition JitEngineDevice.hpp:176
JitCache< KernelFunction_t > CodeCache
Definition JitEngineDevice.hpp:383
~JitEngineDevice()
Definition JitEngineDevice.hpp:378
LLVMContext & getLLVMContext()
Definition JitEngineDevice.hpp:391
typename DeviceTraits< ImplT >::DeviceError_t DeviceError_t
Definition JitEngineDevice.hpp:184
void registerLinkedBinary(FatbinWrapperT *FatbinWrapper, const char *ModuleId)
Definition JitEngineDevice.hpp:641
JitEngineDevice()
Definition JitEngineDevice.hpp:376
std::unordered_map< const void *, BinaryInfo > HandleToBinaryInfo
Definition JitEngineDevice.hpp:279
DenseMap< const void *, JITKernelInfo > JITKernelInfoMap
Definition JitEngineDevice.hpp:393
typename DeviceTraits< ImplT >::DeviceStream_t DeviceStream_t
Definition JitEngineDevice.hpp:185
void registerFatBinaryEnd()
Definition JitEngineDevice.hpp:602
std::unique_ptr< Module > linkJitModule(SmallVector< std::unique_ptr< Module > > &LinkedModules, std::unique_ptr< Module > LTOModule=nullptr)
Definition JitEngineDevice.hpp:659
Module & getModule(JITKernelInfo &KernelInfo)
Definition JitEngineDevice.hpp:193
JitStorageCache< KernelFunction_t > StorageCache
Definition JitEngineDevice.hpp:384
bool containsJITKernelInfo(const void *Func)
Definition JitEngineDevice.hpp:283
std::optional< std::reference_wrapper< JITKernelInfo > > getJITKernelInfo(const void *Func)
Definition JitEngineDevice.hpp:288
SmallPtrSet< void *, 8 > GlobalLinkedBinaries
Definition JitEngineDevice.hpp:281
std::unordered_map< std::string, const void * > VarNameToDevPtr
Definition JitEngineDevice.hpp:386
void registerFunction(void *Handle, void *Kernel, char *KernelName, int32_t *RCIndices, int32_t *RCTypes, int32_t NumRCs)
Definition JitEngineDevice.hpp:613
void finalize()
Definition JitEngineDevice.hpp:312
void * CurHandle
Definition JitEngineDevice.hpp:277
void registerFatBinary(void *Handle, FatbinWrapperT *FatbinWrapper, const char *ModuleId)
Definition JitEngineDevice.hpp:573
DeviceError_t compileAndRun(JITKernelInfo &KernelInfo, dim3 GridDim, dim3 BlockDim, void **KernelArgs, uint64_t ShmemSize, typename DeviceTraits< ImplT >::DeviceStream_t Stream)
Definition JitEngineDevice.hpp:446
std::unordered_map< std::string, FatbinWrapperT * > ModuleIdToFatBinary
Definition JitEngineDevice.hpp:278
typename DeviceTraits< ImplT >::KernelFunction_t KernelFunction_t
Definition JitEngineDevice.hpp:186
void insertRegisterVar(const char *VarName, const void *Addr)
Definition JitEngineDevice.hpp:266
SmallVector< std::string > GlobalLinkedModuleIds
Definition JitEngineDevice.hpp:280
std::string DeviceArch
Definition JitEngineDevice.hpp:385
HashT getStaticHash(JITKernelInfo &KernelInfo)
Definition JitEngineDevice.hpp:295
void getLambdaJitValues(JITKernelInfo &KernelInfo, SmallVector< RuntimeConstant > &LambdaJitValuesVec)
Definition JitEngineDevice.hpp:229
Definition JitEngine.hpp:44
struct proteus::JitEngine::@0 Config
void runCleanupPassPipeline(Module &M)
Definition JitEngine.cpp:76
Definition JitStorageCache.hpp:36
void printStats()
Definition JitStorageCache.hpp:65
Definition LambdaRegistry.hpp:19
std::optional< DenseMap< StringRef, SmallVector< RuntimeConstant > >::iterator > matchJitVariableMap(StringRef FnName)
Definition LambdaRegistry.hpp:27
const SmallVector< RuntimeConstant > & getJitVariables(StringRef LambdaTypeRef)
Definition LambdaRegistry.hpp:74
static LambdaRegistry & instance()
Definition LambdaRegistry.hpp:21
bool empty()
Definition LambdaRegistry.hpp:78
static llvm::raw_ostream & logs(const std::string &Name)
Definition Logger.hpp:18
Definition JitEngine.cpp:20
std::unique_ptr< Module > linkModules(LLVMContext &Ctx, SmallVector< std::unique_ptr< Module > > &LinkedModules)
Definition CoreLLVM.hpp:153
HashT hash(FirstT &&First, RestTs &&...Rest)
Definition Hashing.hpp:73
cudaError_t launchKernelDirect(void *KernelFunc, dim3 GridDim, dim3 BlockDim, void **KernelArgs, uint64_t ShmemSize, CUstream Stream)
Definition CoreDeviceCUDA.hpp:20
cudaError_t launchKernelFunction(CUfunction KernelFunc, dim3 GridDim, dim3 BlockDim, void **KernelArgs, uint64_t ShmemSize, CUstream Stream)
Definition CoreDeviceCUDA.hpp:51
CUfunction getKernelFunctionFromImage(StringRef KernelName, const void *Image, bool RelinkGlobalsByCopy, const std::unordered_map< std::string, const void * > &VarNameToDevPtr)
Definition CoreDeviceCUDA.hpp:27
HashT hashCombine(HashT A, HashT B)
Definition Hashing.hpp:68
void pruneIR(Module &M, bool UnsetExternallyInitialized=true)
Definition CoreLLVM.hpp:193
void * resolveDeviceGlobalAddr(const void *Addr)
Definition CoreDeviceCUDA.hpp:12
void internalize(Module &M, StringRef PreserveFunctionName)
Definition CoreLLVM.hpp:228
Definition Hashing.hpp:94
Definition JitEngineDevice.hpp:174
Definition JitEngineDevice.hpp:71
const char * Binary
Definition JitEngineDevice.hpp:74
void ** PrelinkedFatbins
Definition JitEngineDevice.hpp:75
int32_t Magic
Definition JitEngineDevice.hpp:72
int32_t Version
Definition JitEngineDevice.hpp:73