Proteus
Programmable JIT compilation and optimization for C/C++ using LLVM
Loading...
Searching...
No Matches
CompilationTask.hpp
Go to the documentation of this file.
1#ifndef PROTEUS_COMPILATION_TASK_HPP
2#define PROTEUS_COMPILATION_TASK_HPP
3
5#include "proteus/Config.hpp"
8#include "proteus/Debug.h"
9#include "proteus/Hashing.hpp"
10#include "proteus/Utils.h"
11
12#include <llvm/Bitcode/BitcodeReader.h>
13#include <llvm/Bitcode/BitcodeWriter.h>
14
15namespace proteus {
16
17using namespace llvm;
18
// State captured at construction; compile() and the private helpers read it.
20private:
// Kernel bitcode; parsed into a module by cloneKernelModule().
21 MemoryBufferRef Bitcode;
// Returned by getHashValue(); also used to name log and dump files in
// compile().
22 HashT HashValue;
// Kernel identity, launch dimensions, runtime constants and lambda callee
// info; all are forwarded to proteus::specializeIR in compile().
23 std::string KernelName;
24 std::string Suffix;
25 dim3 BlockDim;
26 dim3 GridDim;
27 SmallVector<RuntimeConstant> RCVec;
28 SmallVector<std::pair<std::string, StringRef>> LambdaCalleeInfo;
// Passed to replaceGlobalVariablesWithPointers and, when relinking is
// requested, to proteus::relinkGlobalsObject.
29 std::unordered_map<std::string, GlobalVarInfo> VarNameToGlobalInfo;
// Passed to proteus::codegenObject together with DeviceArch and CGOption.
30 SmallPtrSet<void *, 8> GlobalLinkedBinaries;
// Also passed to optimizeIR.
31 std::string DeviceArch;
// In HIP builds, invokeOptimizeIR optimizes only when this is
// CodegenOption::Serial.
32 CodegenOption CGOption;
// When true, compile() saves the optimized IR under ".proteus-dump".
33 bool DumpIR;
// When false, compile() calls proteus::relinkGlobalsObject on the object.
34 bool RelinkGlobalsByCopy;
// Initialized from CGConfig.minBlocksPerSM(<product of BlockDim x, y, z>);
// forwarded to proteus::specializeIR along with the Specialize* flags below.
35 int MinBlocksPerSM;
36 bool SpecializeArgs;
37 bool SpecializeDims;
38 bool SpecializeDimsRange;
39 bool SpecializeLaunchBounds;
// Optimization settings consumed by invokeOptimizeIR. When PassPipeline holds
// a value it is passed to optimizeIR in place of OptLevel.
40 char OptLevel;
41 unsigned CodegenOptLevel;
42 std::optional<std::string> PassPipeline;
43
43
44 std::unique_ptr<Module> cloneKernelModule(LLVMContext &Ctx) {
45 auto ClonedModule = parseBitcodeFile(Bitcode, Ctx);
46 if (auto E = ClonedModule.takeError()) {
47 reportFatalError("Failed to parse bitcode" + toString(std::move(E)));
48 }
49
50 return std::move(*ClonedModule);
51 }
52
53 void invokeOptimizeIR(Module &M) {
54#if PROTEUS_ENABLE_CUDA
55 // For CUDA we always run the optimization pipeline.
56 if (!PassPipeline)
57 optimizeIR(M, DeviceArch, OptLevel, CodegenOptLevel);
58 else
59 optimizeIR(M, DeviceArch, PassPipeline.value(), CodegenOptLevel);
60#elif PROTEUS_ENABLE_HIP
61 // For HIP we run the optimization pipeline only for Serial codegen. HIP RTC
62 // and Parallel codegen, which uses LTO, invoke optimization internally.
63 // TODO: Move optimizeIR inside the codegen routines?
64 if (CGOption == CodegenOption::Serial) {
65 if (!PassPipeline) {
66 optimizeIR(M, DeviceArch, OptLevel, CodegenOptLevel);
67 } else {
68 optimizeIR(M, DeviceArch, PassPipeline.value(), CodegenOptLevel);
69 }
70 }
71#else
72#error "JitEngineDevice requires PROTEUS_ENABLE_CUDA or PROTEUS_ENABLE_HIP"
73#endif
74 }
75
76public:
78 MemoryBufferRef Bitcode, HashT HashValue, const std::string &KernelName,
79 std::string &Suffix, dim3 BlockDim, dim3 GridDim,
80 const SmallVector<RuntimeConstant> &RCVec,
81 const SmallVector<std::pair<std::string, StringRef>> &LambdaCalleeInfo,
82 const std::unordered_map<std::string, GlobalVarInfo> &VarNameToGlobalInfo,
83 const SmallPtrSet<void *, 8> &GlobalLinkedBinaries,
84 const std::string &DeviceArch, const CodeGenerationConfig &CGConfig,
85 bool DumpIR, bool RelinkGlobalsByCopy)
86 : Bitcode(Bitcode), HashValue(HashValue), KernelName(KernelName),
87 Suffix(Suffix), BlockDim(BlockDim), GridDim(GridDim), RCVec(RCVec),
88 LambdaCalleeInfo(LambdaCalleeInfo),
89 VarNameToGlobalInfo(VarNameToGlobalInfo),
90 GlobalLinkedBinaries(GlobalLinkedBinaries), DeviceArch(DeviceArch),
91 CGOption(CGConfig.codeGenOption()), DumpIR(DumpIR),
92 RelinkGlobalsByCopy(RelinkGlobalsByCopy),
93 MinBlocksPerSM(
94 CGConfig.minBlocksPerSM(BlockDim.x * BlockDim.y * BlockDim.z)),
95 SpecializeArgs(CGConfig.specializeArgs()),
96 SpecializeDims(CGConfig.specializeDims()),
97 SpecializeDimsRange(CGConfig.specializeDimsRange()),
98 SpecializeLaunchBounds(CGConfig.specializeLaunchBounds()),
99 OptLevel(CGConfig.optLevel()),
100 CodegenOptLevel(CGConfig.codeGenOptLevel()),
101 PassPipeline(CGConfig.optPipeline()) {
102 if (Config::get().ProteusTraceOutput >= 1) {
103 llvm::SmallString<128> S;
104 llvm::raw_svector_ostream OS(S);
105 OS << "[KernelConfig] ID:" << KernelName << " ";
106 CGConfig.dump(OS);
107 OS << "\n";
108 Logger::trace(OS.str());
109 }
110 }
111
112 // Delete copy operations.
115
116 // Use default move operations. Together with the deleted copy operations
// this makes the task move-only.
117 CompilationTask(CompilationTask &&) noexcept = default;
118 CompilationTask &operator=(CompilationTask &&) noexcept = default;
119
120 HashT getHashValue() const { return HashValue; }
121
122 std::unique_ptr<MemoryBuffer> compile() {
123 struct TimerRAII {
124 std::chrono::high_resolution_clock::time_point Start, End;
125 HashT HashValue;
126 TimerRAII(HashT HashValue) : HashValue(HashValue) {
127 if (Config::get().ProteusDebugOutput) {
128 Start = std::chrono::high_resolution_clock::now();
129 }
130 }
131
132 ~TimerRAII() {
133 if (Config::get().ProteusDebugOutput) {
134 auto End = std::chrono::high_resolution_clock::now();
135 auto Duration = End - Start;
136 auto Milliseconds =
137 std::chrono::duration_cast<std::chrono::milliseconds>(Duration)
138 .count();
139 Logger::logs("proteus")
140 << "Compiled HashValue " << HashValue.toString() << " for "
141 << Milliseconds << "ms\n";
142 }
143 }
144 } Timer{HashValue};
145
146 LLVMContext Ctx;
147 std::unique_ptr<Module> M = cloneKernelModule(Ctx);
148
149 std::string KernelMangled = (KernelName + Suffix);
150
151 PROTEUS_DBG(Logger::logfile(HashValue.toString() + ".input.ll", *M));
152
153 proteus::specializeIR(*M, KernelName, Suffix, BlockDim, GridDim, RCVec,
154 LambdaCalleeInfo, SpecializeArgs, SpecializeDims,
155 SpecializeDimsRange, SpecializeLaunchBounds,
156 MinBlocksPerSM);
157
158 PROTEUS_DBG(Logger::logfile(HashValue.toString() + ".specialized.ll", *M));
159
160 replaceGlobalVariablesWithPointers(*M, VarNameToGlobalInfo);
161
162 invokeOptimizeIR(*M);
163 if (Config::get().ProteusTraceOutput == 2) {
164 llvm::outs() << "LLVM IR module post optimization " << *M << "\n";
165 }
166 if (DumpIR) {
167 const auto CreateDumpDirectory = []() {
168 const std::string DumpDirectory = ".proteus-dump";
169 std::filesystem::create_directory(DumpDirectory);
170 return DumpDirectory;
171 };
172
173 static const std::string DumpDirectory = CreateDumpDirectory();
174
175 saveToFile(DumpDirectory + "/device-jit-" + HashValue.toString() + ".ll",
176 *M);
177 }
178
179 auto ObjBuf =
180 proteus::codegenObject(*M, DeviceArch, GlobalLinkedBinaries, CGOption);
181
182 if (!RelinkGlobalsByCopy)
183 proteus::relinkGlobalsObject(ObjBuf->getMemBufferRef(),
184 VarNameToGlobalInfo);
185
186 return ObjBuf;
187 }
188};
189
190} // namespace proteus
191
192#endif
void char * KernelName
Definition CompilerInterfaceDevice.cpp:52
#define PROTEUS_DBG(x)
Definition Debug.h:9
void saveToFile(llvm::StringRef Filepath, T &&Data)
Definition Utils.h:24
Definition Config.hpp:130
void dump(T &OS) const
Definition Config.hpp:246
Definition CompilationTask.hpp:19
CompilationTask & operator=(const CompilationTask &)=delete
HashT getHashValue() const
Definition CompilationTask.hpp:120
CompilationTask(CompilationTask &&) noexcept=default
CompilationTask(const CompilationTask &)=delete
std::unique_ptr< MemoryBuffer > compile()
Definition CompilationTask.hpp:122
CompilationTask(MemoryBufferRef Bitcode, HashT HashValue, const std::string &KernelName, std::string &Suffix, dim3 BlockDim, dim3 GridDim, const SmallVector< RuntimeConstant > &RCVec, const SmallVector< std::pair< std::string, StringRef > > &LambdaCalleeInfo, const std::unordered_map< std::string, GlobalVarInfo > &VarNameToGlobalInfo, const SmallPtrSet< void *, 8 > &GlobalLinkedBinaries, const std::string &DeviceArch, const CodeGenerationConfig &CGConfig, bool DumpIR, bool RelinkGlobalsByCopy)
Definition CompilationTask.hpp:77
static Config & get()
Definition Config.hpp:300
Definition Hashing.hpp:21
std::string toString() const
Definition Hashing.hpp:29
static void trace(llvm::StringRef Msg)
Definition Logger.hpp:30
static void logfile(const std::string &Filename, T &&Data)
Definition Logger.hpp:33
Definition TimeTracing.hpp:40
Definition Helpers.h:141
Definition ObjectCacheChain.cpp:26
void optimizeIR(Module &M, StringRef Arch, char OptLevel, unsigned CodegenOptLevel)
Definition CoreLLVM.hpp:182
void reportFatalError(const llvm::Twine &Reason, const char *FILE, unsigned Line)
Definition Error.cpp:14
CodegenOption
Definition Config.hpp:14
std::unique_ptr< MemoryBuffer > codegenObject(Module &M, StringRef DeviceArch, SmallPtrSetImpl< void * > &GlobalLinkedBinaries, CodegenOption CGOption=CodegenOption::RTC)
Definition CoreLLVMCUDA.hpp:165
std::string toString(CodegenOption Option)
Definition Config.hpp:26