Proteus
Programmable JIT compilation and optimization for C/C++ using LLVM
Loading...
Searching...
No Matches
CompilationTask.hpp
Go to the documentation of this file.
1#ifndef PROTEUS_COMPILATION_TASK_HPP
2#define PROTEUS_COMPILATION_TASK_HPP
3
4#include <llvm/Bitcode/BitcodeReader.h>
5#include <llvm/Bitcode/BitcodeWriter.h>
6
8#include "proteus/Config.hpp"
10#include "proteus/Debug.h"
11#include "proteus/Hashing.hpp"
12#include "proteus/Utils.h"
13
14namespace proteus {
15
16using namespace llvm;
17
19private:
20 MemoryBufferRef Bitcode;
21 HashT HashValue;
22 std::string KernelName;
23 std::string Suffix;
24 dim3 BlockDim;
25 dim3 GridDim;
26 SmallVector<RuntimeConstant> RCVec;
27 SmallVector<std::pair<std::string, StringRef>> LambdaCalleeInfo;
28 std::unordered_map<std::string, const void *> VarNameToDevPtr;
29 SmallPtrSet<void *, 8> GlobalLinkedBinaries;
30 std::string DeviceArch;
31 CodegenOption CGOption;
32 bool DumpIR;
33 bool RelinkGlobalsByCopy;
34 int MinBlocksPerSM;
35 bool SpecializeArgs;
36 bool SpecializeDims;
37 bool SpecializeDimsAssume;
38 bool SpecializeLaunchBounds;
39 char OptLevel;
40 unsigned CodegenOptLevel;
41 std::optional<std::string> PassPipeline;
42
43 std::unique_ptr<Module> cloneKernelModule(LLVMContext &Ctx) {
44 auto ClonedModule = parseBitcodeFile(Bitcode, Ctx);
45 if (auto E = ClonedModule.takeError()) {
46 PROTEUS_FATAL_ERROR("Failed to parse bitcode" + toString(std::move(E)));
47 }
48
49 return std::move(*ClonedModule);
50 }
51
52 void invokeOptimizeIR(Module &M) {
53#if PROTEUS_ENABLE_CUDA
54 // For CUDA we always run the optimization pipeline.
55 if (!PassPipeline)
56 optimizeIR(M, DeviceArch, OptLevel, CodegenOptLevel);
57 else
58 optimizeIR(M, DeviceArch, PassPipeline.value(), CodegenOptLevel);
59#elif PROTEUS_ENABLE_HIP
60 // For HIP we run the optimization pipeline only for Serial codegen. HIP RTC
61 // and Parallel codegen, which uses LTO, invoke optimization internally.
62 // TODO: Move optimizeIR inside the codegen routines?
63 if (CGOption == CodegenOption::Serial) {
64 if (!PassPipeline) {
65 optimizeIR(M, DeviceArch, OptLevel, CodegenOptLevel);
66 } else {
67 optimizeIR(M, DeviceArch, PassPipeline.value(), CodegenOptLevel);
68 }
69 }
70#else
71#error "JitEngineDevice requires PROTEUS_ENABLE_CUDA or PROTEUS_ENABLE_HIP"
72#endif
73 }
74
75public:
77 MemoryBufferRef Bitcode, HashT HashValue, const std::string &KernelName,
78 std::string &Suffix, dim3 BlockDim, dim3 GridDim,
79 const SmallVector<RuntimeConstant> &RCVec,
80 const SmallVector<std::pair<std::string, StringRef>> &LambdaCalleeInfo,
81 const std::unordered_map<std::string, const void *> &VarNameToDevPtr,
82 const SmallPtrSet<void *, 8> &GlobalLinkedBinaries,
83 const std::string &DeviceArch, const CodeGenerationConfig &CGConfig,
84 bool DumpIR, bool RelinkGlobalsByCopy)
85 : Bitcode(Bitcode), HashValue(HashValue), KernelName(KernelName),
86 Suffix(Suffix), BlockDim(BlockDim), GridDim(GridDim), RCVec(RCVec),
87 LambdaCalleeInfo(LambdaCalleeInfo), VarNameToDevPtr(VarNameToDevPtr),
88 GlobalLinkedBinaries(GlobalLinkedBinaries), DeviceArch(DeviceArch),
89 CGOption(CGConfig.codeGenOption()), DumpIR(DumpIR),
90 RelinkGlobalsByCopy(RelinkGlobalsByCopy),
91 MinBlocksPerSM(
92 CGConfig.minBlocksPerSM(BlockDim.x * BlockDim.y * BlockDim.z)),
93 SpecializeArgs(CGConfig.specializeArgs()),
94 SpecializeDims(CGConfig.specializeDims()),
95 SpecializeDimsAssume(CGConfig.specializeDimsAssume()),
96 SpecializeLaunchBounds(CGConfig.specializeLaunchBounds()),
97 OptLevel(CGConfig.optLevel()),
98 CodegenOptLevel(CGConfig.codeGenOptLevel()),
99 PassPipeline(CGConfig.optPipeline()) {
100 if (Config::get().ProteusTraceOutput >= 1) {
101 llvm::SmallString<128> S;
102 llvm::raw_svector_ostream OS(S);
103 OS << "[KernelConfig] ID:" << KernelName << " ";
104 CGConfig.dump(OS);
105 OS << "\n";
106 Logger::trace(OS.str());
107 }
108 }
109
110 // Delete copy operations.
113
114 // Use default move operations.
115 CompilationTask(CompilationTask &&) noexcept = default;
116 CompilationTask &operator=(CompilationTask &&) noexcept = default;
117
118 HashT getHashValue() const { return HashValue; }
119
120 std::unique_ptr<MemoryBuffer> compile() {
121 struct TimerRAII {
122 std::chrono::high_resolution_clock::time_point Start, End;
123 HashT HashValue;
124 TimerRAII(HashT HashValue) : HashValue(HashValue) {
125 if (Config::get().ProteusDebugOutput) {
126 Start = std::chrono::high_resolution_clock::now();
127 }
128 }
129
130 ~TimerRAII() {
131 if (Config::get().ProteusDebugOutput) {
132 auto End = std::chrono::high_resolution_clock::now();
133 auto Duration = End - Start;
134 auto Milliseconds =
135 std::chrono::duration_cast<std::chrono::milliseconds>(Duration)
136 .count();
137 Logger::logs("proteus")
138 << "Compiled HashValue " << HashValue.toString() << " for "
139 << Milliseconds << "ms\n";
140 }
141 }
142 } Timer{HashValue};
143
144 LLVMContext Ctx;
145 std::unique_ptr<Module> M = cloneKernelModule(Ctx);
146
147 std::string KernelMangled = (KernelName + Suffix);
148
149 PROTEUS_DBG(Logger::logfile(HashValue.toString() + ".input.ll", *M));
150
151 proteus::specializeIR(*M, KernelName, Suffix, BlockDim, GridDim, RCVec,
152 LambdaCalleeInfo, SpecializeArgs, SpecializeDims,
153 SpecializeDimsAssume, SpecializeLaunchBounds,
154 MinBlocksPerSM);
155
156 PROTEUS_DBG(Logger::logfile(HashValue.toString() + ".specialized.ll", *M));
157
158 replaceGlobalVariablesWithPointers(*M, VarNameToDevPtr);
159
160 invokeOptimizeIR(*M);
161 if (Config::get().ProteusTraceOutput == 2) {
162 llvm::outs() << "LLVM IR module post optimization " << *M << "\n";
163 }
164 if (DumpIR) {
165 const auto CreateDumpDirectory = []() {
166 const std::string DumpDirectory = ".proteus-dump";
167 std::filesystem::create_directory(DumpDirectory);
168 return DumpDirectory;
169 };
170
171 static const std::string DumpDirectory = CreateDumpDirectory();
172
173 saveToFile(DumpDirectory + "/device-jit-" + HashValue.toString() + ".ll",
174 *M);
175 }
176
177 auto ObjBuf =
178 proteus::codegenObject(*M, DeviceArch, GlobalLinkedBinaries, CGOption);
179
180 if (!RelinkGlobalsByCopy)
181 proteus::relinkGlobalsObject(ObjBuf->getMemBufferRef(), VarNameToDevPtr);
182
183 return ObjBuf;
184 }
185};
186
187} // namespace proteus
188
189#endif
void char * KernelName
Definition CompilerInterfaceDevice.cpp:50
#define PROTEUS_DBG(x)
Definition Debug.h:9
#define PROTEUS_FATAL_ERROR(x)
Definition Error.h:7
void saveToFile(llvm::StringRef Filepath, T &&Data)
Definition Utils.h:23
Definition Config.hpp:130
void dump(T &OS) const
Definition Config.hpp:250
Definition CompilationTask.hpp:18
CompilationTask & operator=(const CompilationTask &)=delete
CompilationTask(MemoryBufferRef Bitcode, HashT HashValue, const std::string &KernelName, std::string &Suffix, dim3 BlockDim, dim3 GridDim, const SmallVector< RuntimeConstant > &RCVec, const SmallVector< std::pair< std::string, StringRef > > &LambdaCalleeInfo, const std::unordered_map< std::string, const void * > &VarNameToDevPtr, const SmallPtrSet< void *, 8 > &GlobalLinkedBinaries, const std::string &DeviceArch, const CodeGenerationConfig &CGConfig, bool DumpIR, bool RelinkGlobalsByCopy)
Definition CompilationTask.hpp:76
HashT getHashValue() const
Definition CompilationTask.hpp:118
CompilationTask(CompilationTask &&) noexcept=default
CompilationTask(const CompilationTask &)=delete
std::unique_ptr< MemoryBuffer > compile()
Definition CompilationTask.hpp:120
static Config & get()
Definition Config.hpp:304
Definition Hashing.hpp:20
std::string toString() const
Definition Hashing.hpp:28
static void trace(llvm::StringRef Msg)
Definition Logger.hpp:30
static void logfile(const std::string &Filename, T &&Data)
Definition Logger.hpp:33
Definition TimeTracing.hpp:36
Definition Helpers.h:138
Definition BuiltinsCUDA.cpp:4
void optimizeIR(Module &M, StringRef Arch, char OptLevel, unsigned CodegenOptLevel)
Definition CoreLLVM.hpp:182
CodegenOption
Definition Config.hpp:14
std::unique_ptr< MemoryBuffer > codegenObject(Module &M, StringRef DeviceArch, SmallPtrSetImpl< void * > &GlobalLinkedBinaries, CodegenOption CGOption=CodegenOption::RTC)
Definition CoreLLVMCUDA.hpp:160
std::string toString(CodegenOption Option)
Definition Config.hpp:26