Proteus
Programmable JIT compilation and optimization for C/C++ using LLVM
JitFrontend.hpp
#ifndef PROTEUS_JIT_DEV_HPP
#define PROTEUS_JIT_DEV_HPP

#include <llvm/ADT/SmallVector.h>
#include <llvm/ADT/StringRef.h>
#include <llvm/Bitcode/BitcodeWriter.h>
#include <llvm/IR/IRBuilder.h>
#include <llvm/IR/Module.h>
#include <llvm/IR/Verifier.h>
#include <llvm/Support/Debug.h>
#include <llvm/Support/MemoryBuffer.h>
#include <llvm/Support/raw_ostream.h>
#include <llvm/TargetParser/Host.h>
#include <llvm/TargetParser/Triple.h>

#include <deque>

// Proteus runtime components used below.
#include "proteus/CompiledLibrary.hpp"
#include "proteus/Dispatcher.hpp"
#include "proteus/Error.h"
#include "proteus/Func.hpp"
#include "proteus/Hashing.hpp"
#include "proteus/TargetModel.hpp"
#include "proteus/TypeMap.hpp"

#include <iostream>

namespace proteus {
using namespace llvm;

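// JitModule owns the LLVMContext/Module being built at runtime and the
// Dispatcher that compiles and runs it. A minimal host-side usage sketch
// (assuming the Func/Var IR-building API from Func.hpp is used for the body):
//
//   JitModule JM;                                   // defaults to the "host" target
//   auto &Add = JM.addFunction<int(int, int)>("add");
//   /* ... emit the body of "add" through the Func/Var API ... */
//   JM.compile();
//   int Sum = Add(1, 2);                            // runs the JIT-compiled code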
class JitModule {
private:
  std::unique_ptr<LLVMContext> Ctx;
  std::unique_ptr<Module> Mod;
  std::unique_ptr<CompiledLibrary> Library;

  std::deque<std::unique_ptr<FuncBase>> Functions;
  TargetModelType TargetModel;
  std::string TargetTriple;
  Dispatcher &Dispatch;

  HashT ModuleHash = 0;
  bool IsCompiled = false;

  template <typename... ArgT> struct KernelHandle;

  template <typename RetT, typename... ArgT>
  Func<RetT, ArgT...> &buildFuncFromArgsList(FunctionCallee FC,
                                             ArgTypeList<ArgT...>) {
    auto TypedFn = std::make_unique<Func<RetT, ArgT...>>(*this, FC, Dispatch);
    Func<RetT, ArgT...> &TypedFnRef = *TypedFn;
    std::unique_ptr<FuncBase> &Fn = Functions.emplace_back(std::move(TypedFn));
    Fn->declArgs<ArgT...>();
    return TypedFnRef;
  }

  template <typename... ArgT>
  KernelHandle<ArgT...> buildKernelFromArgsList(FunctionCallee FC,
                                                ArgTypeList<ArgT...>) {
    auto TypedFn = std::make_unique<Func<void, ArgT...>>(*this, FC, Dispatch);
    Func<void, ArgT...> &TypedFnRef = *TypedFn;
    std::unique_ptr<FuncBase> &Fn = Functions.emplace_back(std::move(TypedFn));

    Fn->declArgs<ArgT...>();

    setKernel(*Fn);
    return KernelHandle<ArgT...>{TypedFnRef, *this};
  }

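  // KernelHandle wraps a kernel Func plus its owning JitModule and exposes the
  // GPU-side operations (setLaunchBounds, launch). A rough usage sketch with
  // placeholder values; it assumes "cuda"/"hip" are accepted target strings and
  // that LaunchDims brace-initializes from x/y/z:
  //
  //   JitModule JM{"cuda"};
  //   auto K = JM.addKernel<void(double *, int)>("scale");
  //   /* ... emit the kernel body through the Func/Var API ... */
  //   K.setLaunchBounds(256);                       // optional, before compile()
  //   JM.compile();
  //   auto Res = K.launch(/*Grid=*/{NumBlocks, 1, 1}, /*Block=*/{256, 1, 1},
  //                       /*ShmemBytes=*/0, /*Stream=*/nullptr, DevPtr, N);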
  template <typename... ArgT> struct KernelHandle {
    Func<void, ArgT...> &F;
    JitModule &M;

    void setLaunchBounds([[maybe_unused]] int MaxThreadsPerBlock,
                         [[maybe_unused]] int MinBlocksPerSM = 0) {
      if (!M.isDeviceModule())
        PROTEUS_FATAL_ERROR("Expected a device module for setLaunchBounds");

      if (M.isCompiled())
        PROTEUS_FATAL_ERROR("setLaunchBounds must be called before compile()");

#if PROTEUS_ENABLE_CUDA || PROTEUS_ENABLE_HIP
      Function *Fn = F.getFunction();
      if (!Fn)
        PROTEUS_FATAL_ERROR("Expected non-null Function");

      setLaunchBoundsForKernel(*Fn, MaxThreadsPerBlock, MinBlocksPerSM);
#else
      PROTEUS_FATAL_ERROR("Unsupported target for setLaunchBounds");
#endif
    }

    // Launch with type-safety.
    [[nodiscard]] auto launch(LaunchDims Grid, LaunchDims Block,
                              uint64_t ShmemBytes, void *Stream, ArgT... Args) {
      // Pointers to the local parameter copies.
      void *Ptrs[sizeof...(ArgT)] = {(void *)&Args...};

      if (!M.isCompiled())
        M.compile();

      auto GetKernelFunc = [&]() {
        // Get the kernel func pointer directly from the Func object if
        // available.
        if (auto KernelFunc = F.getCompiledFunc()) {
          return KernelFunc;
        }

        // Get the kernel func pointer from the Dispatch and store it to the
        // Func object to avoid cache lookups.
        // TODO: Re-think caching and dispatchers.
        auto KernelFunc = reinterpret_cast<decltype(F.getCompiledFunc())>(
            M.Dispatch.getFunctionAddress(F.getName(), M.ModuleHash,
                                          M.getLibrary()));

        F.setCompiledFunc(KernelFunc);

        return KernelFunc;
      };

      return M.Dispatch.launch(reinterpret_cast<void *>(GetKernelFunc()), Grid,
                               Block, Ptrs, ShmemBytes, Stream);
    }

    FuncBase *operator->() { return &F; }
  };

  bool isDeviceModule() {
    return ((TargetModel == TargetModelType::CUDA) ||
            (TargetModel == TargetModelType::HIP));
  }

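  // Marks a function as a device kernel entry point: for CUDA by adding the
  // function to the module's nvvm.annotations metadata plus a "kernel"
  // function attribute, for HIP by switching it to the AMDGPU_KERNEL calling
  // convention.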
  void setKernel(FuncBase &F) {
    switch (TargetModel) {
    case TargetModelType::CUDA: {
      NamedMDNode *MD = Mod->getOrInsertNamedMetadata("nvvm.annotations");

      Metadata *MDVals[] = {
          ConstantAsMetadata::get(F.getFunction()),
          MDString::get(*Ctx, "kernel"),
          ConstantAsMetadata::get(ConstantInt::get(Type::getInt32Ty(*Ctx), 1))};
      // Append metadata to nvvm.annotations.
      MD->addOperand(MDNode::get(*Ctx, MDVals));

      // Add a function attribute for the kernel.
      F.getFunction()->addFnAttr(Attribute::get(*Ctx, "kernel"));
      return;
    }
    case TargetModelType::HIP:
      F.getFunction()->setCallingConv(CallingConv::AMDGPU_KERNEL);
      return;
    case TargetModelType::HOST:
      PROTEUS_FATAL_ERROR("Host does not support setKernel");
    default:
      PROTEUS_FATAL_ERROR("Unsupported target " + TargetTriple);
    }
  }

public:
  JitModule(StringRef Target = "host")
      : Ctx{std::make_unique<LLVMContext>()},
        Mod{std::make_unique<Module>("JitModule", *Ctx)},
        TargetModel{parseTargetModel(Target)},
        TargetTriple(getTargetTriple(TargetModel)),
        Dispatch(Dispatcher::getDispatcher(TargetModel)) {}

  // Disable copy and move construction and assignment.
  JitModule(const JitModule &) = delete;
  JitModule &operator=(const JitModule &) = delete;
  JitModule(JitModule &&) = delete;
  JitModule &operator=(JitModule &&) = delete;

  template <typename Sig> auto &addFunction(StringRef Name) {
    using RetT = typename FnSig<Sig>::RetT;
    using ArgT = typename FnSig<Sig>::ArgsTList;

    if (IsCompiled)
      PROTEUS_FATAL_ERROR(
          "The module is compiled, no further code can be added");

    Mod->setTargetTriple(TargetTriple);
    FunctionCallee FC = getFunctionCallee<RetT>(Name, ArgT{});

    Function *F = dyn_cast<Function>(FC.getCallee());
    if (!F)
      PROTEUS_FATAL_ERROR("Unexpected");

    return buildFuncFromArgsList<RetT>(FC, ArgT{});
  }

  bool isCompiled() const { return IsCompiled; }

  const Module &getModule() const { return *Mod; }

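  // addKernel declares a GPU kernel entry point: the signature must return
  // void and the module must target CUDA or HIP.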
  template <typename Sig> auto addKernel(StringRef Name) {
    using RetT = typename FnSig<Sig>::RetT;
    static_assert(std::is_void_v<RetT>, "Kernels must have void return type");
    using ArgT = typename FnSig<Sig>::ArgsTList;

    if (IsCompiled)
      PROTEUS_FATAL_ERROR(
          "The module is compiled, no further code can be added");

    if (!isDeviceModule())
      PROTEUS_FATAL_ERROR("Expected a device module for addKernel");

    Mod->setTargetTriple(TargetTriple);
    FunctionCallee FC = getFunctionCallee<RetT>(Name, ArgT{});
    Function *F = dyn_cast<Function>(FC.getCallee());
    if (!F)
      PROTEUS_FATAL_ERROR("Unexpected");

    return buildKernelFromArgsList(FC, ArgT{});
  }

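  // compile() finalizes the module: it optionally verifies the IR, hashes the
  // emitted bitcode to key the compiled-library cache, renames the JIT
  // functions with the hash suffix, and either reuses a previously compiled
  // library from the Dispatcher or compiles a new one.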
  void compile(bool Verify = false) {
    if (IsCompiled)
      return;

    if (Verify)
      if (verifyModule(*Mod, &errs())) {
        PROTEUS_FATAL_ERROR("Broken module found, JIT compilation aborted!");
      }

    SmallVector<char, 0> Buffer;
    raw_svector_ostream OS(Buffer);
    WriteBitcodeToFile(*Mod, OS);

    // Create a unique module hash based on the bitcode and append to all
    // function names to make them unique.
    // TODO: Is this necessary?
    ModuleHash = hash(StringRef{Buffer.data(), Buffer.size()});
    for (auto &JitF : Functions) {
      JitF->setName(JitF->getName().str() + "$" + ModuleHash.toString());
    }

    if ((Library = Dispatch.lookupCompiledLibrary(ModuleHash))) {
      IsCompiled = true;
      return;
    }

    Library = std::make_unique<CompiledLibrary>(
        Dispatch.compile(std::move(Ctx), std::move(Mod), ModuleHash));
    IsCompiled = true;
  }

  HashT getModuleHash() const { return ModuleHash; }

  Dispatcher &getDispatcher() const { return Dispatch; }

  TargetModelType getTargetModel() const { return TargetModel; }

  CompiledLibrary &getLibrary() {
    if (!IsCompiled)
      compile();

    if (!Library)
      PROTEUS_FATAL_ERROR("Expected non-null library after compilation");

    return *Library;
  }

  template <typename RetT, typename... ArgT>
  FunctionCallee getFunctionCallee(StringRef Name, ArgTypeList<ArgT...>) {
    return Mod->getOrInsertFunction(Name, TypeMap<RetT>::get(*Ctx),
                                    TypeMap<ArgT>::get(*Ctx)...);
  }

  void print() { Mod->print(outs(), nullptr); }
};

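// FuncBase::call<Sig>(...) emits a call instruction into the function body
// under construction: it resolves (or inserts) the callee in the owning
// JitModule, builds the call through the IRBuilder, and, for non-void
// signatures, stores the result into a fresh Var that is returned.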
template <typename Sig>
std::enable_if_t<!std::is_void_v<typename FnSig<Sig>::RetT>, Var &>
FuncBase::call(StringRef Name) {
  auto *F = getFunction();
  LLVMContext &Ctx = F->getContext();

  using RetT = typename FnSig<Sig>::RetT;
  using ArgT = typename FnSig<Sig>::ArgsTList;
  FunctionCallee Callee = J.getFunctionCallee<RetT>(Name, ArgT{});
  auto *Call = IRB.CreateCall(Callee);
  Var &Ret = declVarInternal("ret", TypeMap<RetT>::get(Ctx));
  Ret.storeValue(Call);
  return Ret;
}

template <typename Sig>
std::enable_if_t<std::is_void_v<typename FnSig<Sig>::RetT>, void>
FuncBase::call(StringRef Name) {
  using RetT = typename FnSig<Sig>::RetT;
  using ArgT = typename FnSig<Sig>::ArgsTList;

  FunctionCallee Callee = J.getFunctionCallee<RetT>(Name, ArgT{});
  IRB.CreateCall(Callee);
}

template <typename Sig, typename... ArgVars>
std::enable_if_t<!std::is_void_v<typename FnSig<Sig>::RetT>, Var &>
FuncBase::call(StringRef Name, ArgVars &...ArgsVars) {
  auto *F = getFunction();
  LLVMContext &Ctx = F->getContext();

  using RetT = typename FnSig<Sig>::RetT;
  using ArgT = typename FnSig<Sig>::ArgsTList;
  FunctionCallee Callee = J.getFunctionCallee<RetT>(Name, ArgT{});
  auto *Call = IRB.CreateCall(Callee, {ArgsVars.getValue()...});

  Var &Ret = declVarInternal("ret", TypeMap<RetT>::get(Ctx));
  Ret.storeValue(Call);
  return Ret;
}

template <typename Sig, typename... ArgVars>
std::enable_if_t<std::is_void_v<typename FnSig<Sig>::RetT>, void>
FuncBase::call(StringRef Name, ArgVars &...ArgsVars) {
  using RetT = typename FnSig<Sig>::RetT;
  using ArgT = typename FnSig<Sig>::ArgsTList;

  FunctionCallee Callee = J.getFunctionCallee<RetT>(Name, ArgT{});
  IRB.CreateCall(Callee, {ArgsVars.getValue()...});
}

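// Func<RetT, ArgT...>::operator() runs the JIT-compiled function on the host:
// it compiles the module on first use, resolves and caches the compiled
// function pointer through the Dispatcher, and forwards the arguments. GPU
// modules must go through KernelHandle::launch() instead.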
template <typename RetT, typename... ArgT>
RetT Func<RetT, ArgT...>::operator()(ArgT... Args) {
  if (!J.isCompiled())
    J.compile();

  if (!CompiledFunc) {
    CompiledFunc = reinterpret_cast<decltype(CompiledFunc)>(
        J.getDispatcher().getFunctionAddress(getName(), J.getModuleHash(),
                                             J.getLibrary()));
  }

  if (J.getTargetModel() != TargetModelType::HOST)
    PROTEUS_FATAL_ERROR(
        "Target is a GPU model, cannot directly run functions, use launch()");

  if constexpr (std::is_void_v<RetT>)
    Dispatch.run<RetT(ArgT...)>(reinterpret_cast<void *>(CompiledFunc),
                                Args...);
  else
    return Dispatch.run<RetT(ArgT...)>(reinterpret_cast<void *>(CompiledFunc),
                                       Args...);
}

} // namespace proteus

#endif