Proteus
Programmable JIT compilation and optimization for C/C++ using LLVM
Loading...
Searching...
No Matches
CompilationTask.hpp
Go to the documentation of this file.
1#ifndef PROTEUS_COMPILATION_TASK_HPP
2#define PROTEUS_COMPILATION_TASK_HPP
3
4#include <llvm/Bitcode/BitcodeReader.h>
5#include <llvm/Bitcode/BitcodeWriter.h>
6
9#include "proteus/Debug.h"
10#include "proteus/Hashing.hpp"
11#include "proteus/Utils.h"
12
13namespace proteus {
14
15using namespace llvm;
16
18private:
// State captured at construction time and consumed by compile().

// Reference to the source kernel module; it is not owned by the task.
// cloneKernelModule() serializes it to bitcode and re-parses it on every call.
19 std::reference_wrapper<const Module> KernelModule;
// Returned by getHashValue(); also used in the IR dump file name and in the
// debug timing log inside compile().
20 HashT HashValue;
// Kernel name and suffix, both forwarded to proteus::specializeIR.
21 std::string KernelName;
22 std::string Suffix;
// Launch dimensions forwarded to proteus::specializeIR.
23 dim3 BlockDim;
24 dim3 GridDim;
// Runtime-constant indices and values forwarded to proteus::specializeIR.
25 SmallVector<int32_t> RCIndices;
26 SmallVector<RuntimeConstant> RCVec;
// Forwarded to proteus::specializeIR.
// NOTE(review): the StringRef half of each pair does not own its characters;
// presumably the referenced storage outlives this task -- TODO confirm.
27 SmallVector<std::pair<std::string, StringRef>> LambdaCalleeInfo;
// Passed to replaceGlobalVariablesWithPointers and, when RelinkGlobalsByCopy
// is false, to proteus::relinkGlobalsObject.
28 std::unordered_map<std::string, const void *> VarNameToDevPtr;
// Passed to proteus::codegenObject.
29 SmallPtrSet<void *, 8> GlobalLinkedBinaries;
// Passed to optimizeIR and proteus::codegenObject.
30 std::string DeviceArch;
// Passed to proteus::codegenObject; in HIP builds optimizeIR is skipped when
// this flag is set.
31 bool UseRTC;
// When set, compile() saves the module to a file under ".proteus-dump".
32 bool DumpIR;
// When false, compile() calls proteus::relinkGlobalsObject on the generated
// object buffer.
33 bool RelinkGlobalsByCopy;
// Specialization switches forwarded to proteus::specializeIR.
34 bool SpecializeArgs;
35 bool SpecializeDims;
36 bool SpecializeLaunchBounds;
37
37
38 std::unique_ptr<Module> cloneKernelModule(LLVMContext &Ctx) {
39 SmallVector<char, 4096> ModuleStr;
40 raw_svector_ostream OS(ModuleStr);
41 WriteBitcodeToFile(KernelModule, OS);
42 StringRef ModuleStrRef = StringRef{ModuleStr.data(), ModuleStr.size()};
43 auto BufferRef = MemoryBufferRef{ModuleStrRef, ""};
44 auto ClonedModule = parseBitcodeFile(BufferRef, Ctx);
45 if (auto E = ClonedModule.takeError()) {
46 PROTEUS_FATAL_ERROR("Failed to parse bitcode" + toString(std::move(E)));
47 }
48
49 return std::move(*ClonedModule);
50 }
51
52public:
54 const Module &Mod, HashT HashValue, const std::string &KernelName,
55 std::string &Suffix, dim3 BlockDim, dim3 GridDim,
56 const SmallVector<int32_t> &RCIndices,
57 const SmallVector<RuntimeConstant> &RCVec,
58 const SmallVector<std::pair<std::string, StringRef>> &LambdaCalleeInfo,
59 const std::unordered_map<std::string, const void *> &VarNameToDevPtr,
60 const SmallPtrSet<void *, 8> &GlobalLinkedBinaries,
61 const std::string &DeviceArch, bool UseRTC, bool DumpIR,
62 bool RelinkGlobalsByCopy, bool SpecializeArgs, bool SpecializeDims,
63 bool SpecializeLaunchBounds)
64 : KernelModule(Mod), HashValue(HashValue), KernelName(KernelName),
65 Suffix(Suffix), BlockDim(BlockDim), GridDim(GridDim),
66 RCIndices(RCIndices), RCVec(RCVec), LambdaCalleeInfo(LambdaCalleeInfo),
67 VarNameToDevPtr(VarNameToDevPtr),
68 GlobalLinkedBinaries(GlobalLinkedBinaries), DeviceArch(DeviceArch),
69 UseRTC(UseRTC), DumpIR(DumpIR),
70 RelinkGlobalsByCopy(RelinkGlobalsByCopy),
71 SpecializeArgs(SpecializeArgs), SpecializeDims(SpecializeDims),
72 SpecializeLaunchBounds(SpecializeLaunchBounds) {}
73
74 // Delete copy operations.
77
78 // Use default move operations.
79 CompilationTask(CompilationTask &&) noexcept = default;
80 CompilationTask &operator=(CompilationTask &&) noexcept = default;
81
82 HashT getHashValue() const { return HashValue; }
83
84 std::unique_ptr<MemoryBuffer> compile() {
85#if PROTEUS_ENABLE_DEBUG
86 auto Start = std::chrono::high_resolution_clock::now();
87#endif
88
89 LLVMContext Ctx;
90 std::unique_ptr<Module> M = cloneKernelModule(Ctx);
91
92 std::string KernelMangled = (KernelName + Suffix);
93
94 proteus::specializeIR(*M, KernelName, Suffix, BlockDim, GridDim, RCIndices,
95 RCVec, LambdaCalleeInfo, SpecializeArgs,
96 SpecializeDims, SpecializeLaunchBounds);
97
98 replaceGlobalVariablesWithPointers(*M, VarNameToDevPtr);
99
100 // For HIP RTC codegen do not run the optimization pipeline since HIP
101 // RTC internally runs it. For the rest of cases, that is CUDA or HIP
102 // with our own codegen instead of RTC, run the target-specific
103 // optimization pipeline to optimize the LLVM IR before handing over
104 // to codegen.
105#if PROTEUS_ENABLE_CUDA
106 optimizeIR(*M, DeviceArch, '3', 3);
107#elif PROTEUS_ENABLE_HIP
108 if (!UseRTC)
109 optimizeIR(*M, DeviceArch, '3', 3);
110#else
111#error "JitEngineDevice requires PROTEUS_ENABLE_CUDA or PROTEUS_ENABLE_HIP"
112#endif
113
114 if (DumpIR) {
115 const auto CreateDumpDirectory = []() {
116 const std::string DumpDirectory = ".proteus-dump";
117 std::filesystem::create_directory(DumpDirectory);
118 return DumpDirectory;
119 };
120
121 static const std::string DumpDirectory = CreateDumpDirectory();
122
123 saveToFile(DumpDirectory + "/device-jit-" + HashValue.toString() + ".ll",
124 *M);
125 }
126
127 auto ObjBuf =
128 proteus::codegenObject(*M, DeviceArch, GlobalLinkedBinaries, UseRTC);
129
130 if (!RelinkGlobalsByCopy)
131 proteus::relinkGlobalsObject(ObjBuf->getMemBufferRef(), VarNameToDevPtr);
132
133#if PROTEUS_ENABLE_DEBUG
134 auto End = std::chrono::high_resolution_clock::now();
135 auto Duration = End - Start;
136 auto Milliseconds =
137 std::chrono::duration_cast<std::chrono::milliseconds>(Duration).count();
138 Logger::logs("proteus") << "Compiled HashValue " << HashValue.toString()
139 << " for " << Milliseconds << "ms\n";
140#endif
141
142 return ObjBuf;
143 }
144};
145
146} // namespace proteus
147
148#endif
void char int32_t * RCIndices
Definition CompilerInterfaceDevice.cpp:51
void char * KernelName
Definition CompilerInterfaceDevice.cpp:50
#define PROTEUS_FATAL_ERROR(x)
Definition Error.h:4
void saveToFile(llvm::StringRef Filepath, T &&Data)
Definition Utils.h:23
Definition CompilationTask.hpp:17
CompilationTask & operator=(const CompilationTask &)=delete
CompilationTask(const Module &Mod, HashT HashValue, const std::string &KernelName, std::string &Suffix, dim3 BlockDim, dim3 GridDim, const SmallVector< int32_t > &RCIndices, const SmallVector< RuntimeConstant > &RCVec, const SmallVector< std::pair< std::string, StringRef > > &LambdaCalleeInfo, const std::unordered_map< std::string, const void * > &VarNameToDevPtr, const SmallPtrSet< void *, 8 > &GlobalLinkedBinaries, const std::string &DeviceArch, bool UseRTC, bool DumpIR, bool RelinkGlobalsByCopy, bool SpecializeArgs, bool SpecializeDims, bool SpecializeLaunchBounds)
Definition CompilationTask.hpp:53
HashT getHashValue() const
Definition CompilationTask.hpp:82
CompilationTask(CompilationTask &&) noexcept=default
CompilationTask(const CompilationTask &)=delete
std::unique_ptr< MemoryBuffer > compile()
Definition CompilationTask.hpp:84
Definition Hashing.hpp:19
std::string toString() const
Definition Hashing.hpp:27
static llvm::raw_ostream & logs(const std::string &Name)
Definition Logger.hpp:18
Definition JitEngine.cpp:20
void optimizeIR(Module &M, StringRef Arch, char OptLevel, unsigned CodegenOptLevel)
Definition CoreLLVM.hpp:147
std::unique_ptr< MemoryBuffer > codegenObject(Module &M, StringRef DeviceArch, SmallPtrSetImpl< void * > &GlobalLinkedBinaries, bool UseRTC=true)
Definition CoreLLVMCUDA.hpp:148