Proteus
Programmable JIT compilation and optimization for C/C++ using LLVM
Loading...
Searching...
No Matches
JitFrontend.hpp
Go to the documentation of this file.
1#ifndef PROTEUS_JIT_DEV_HPP
2#define PROTEUS_JIT_DEV_HPP
3
4#include <llvm/ADT/StringRef.h>
5#include <llvm/Bitcode/BitcodeWriter.h>
6#include <llvm/IR/IRBuilder.h>
7#include <llvm/IR/Module.h>
8#include <llvm/IR/Verifier.h>
9#include <llvm/Support/Debug.h>
10#include <llvm/Support/MemoryBuffer.h>
11#include <llvm/TargetParser/Host.h>
12#include <llvm/TargetParser/Triple.h>
13
14#include <deque>
15#include <type_traits>
16
18#include "proteus/Error.h"
23
24#include <iostream>
25
26namespace proteus {
27using namespace llvm;
28
29class JitModule {
30private:
// NOTE(review): this listing is a Doxygen-rendered copy of the header; the
// leading numbers are original source line numbers and several lines were
// dropped by the extraction (e.g. 47, 51, 57, 59-60, 119, 132, 146, 149,
// 156-157, 168, 175, 179, 198, 205, 221-223, 250, 261). Truncated
// signatures below must be recovered from the original JitFrontend.hpp
// before any code change.
//
// JitModule owns an LLVMContext/Module pair plus the JIT-built functions,
// and drives compilation and dispatch through a target-specific Dispatcher
// (host, CUDA, or HIP per TargetModelType).
31  std::unique_ptr<LLVMContext> Ctx;
32  std::unique_ptr<Module> Mod;
33  std::unique_ptr<CompiledLibrary> Library;
34
// Functions is a deque so FuncBase objects get stable addresses; references
// returned by addFunction/addKernel stay valid as more functions are added.
35  std::deque<std::unique_ptr<FuncBase>> Functions;
36  TargetModelType TargetModel;
37  std::string TargetTriple;
38  Dispatcher &Dispatch;
39
// ModuleHash is computed from the module bitcode in compile() and is used
// both to uniquify function names and as a cache key for compiled libraries.
40  HashT ModuleHash = 0;
41  bool IsCompiled = false;
42
43  template <typename... ArgT> struct KernelHandle;
44
// Creates a typed Func<RetT, ArgT...> wrapper for an already-declared
// FunctionCallee, stores it in Functions, and returns a stable reference.
// NOTE(review): line 47 (rest of the parameter list, presumably an
// ArgTypeList<ArgT...> tag) and line 51 are missing from this rendering.
45  template <typename RetT, typename... ArgT>
46  Func<RetT, ArgT...> &buildFuncFromArgsList(FunctionCallee FC,
48    auto TypedFn = std::make_unique<Func<RetT, ArgT...>>(*this, FC, Dispatch);
49    Func<RetT, ArgT...> &TypedFnRef = *TypedFn;
50    Functions.emplace_back(std::move(TypedFn));
52    return TypedFnRef;
53  }
54
// Kernel variant of the above: builds a void-returning Func, marks it as a
// kernel for the device target, and wraps it in a KernelHandle.
// NOTE(review): lines 57 and 59-60 are missing; TypedFnRef used on line 64
// is presumably declared in the dropped span — confirm against the header.
55  template <typename... ArgT>
56  KernelHandle<ArgT...> buildKernelFromArgsList(FunctionCallee FC,
58    auto TypedFn = std::make_unique<Func<void, ArgT...>>(*this, FC, Dispatch);
61    std::unique_ptr<FuncBase> &Fn = Functions.emplace_back(std::move(TypedFn));
62
63    setKernel(*Fn);
64    return KernelHandle<ArgT...>{TypedFnRef, *this};
65  }
66
// Thin handle pairing a kernel Func with its owning module; provides
// launch-bounds configuration and a type-safe launch() entry point.
67  template <typename... ArgT> struct KernelHandle {
68    Func<void, ArgT...> &F;
69    JitModule &M;
70
// Sets CUDA/HIP launch bounds on the kernel. Must be called on a device
// module and before compile(); fatal error otherwise.
71    void setLaunchBounds([[maybe_unused]] int MaxThreadsPerBlock,
72                         [[maybe_unused]] int MinBlocksPerSM = 0) {
73      if (!M.isDeviceModule())
74        PROTEUS_FATAL_ERROR("Expected a device module for setLaunchBounds");
75
76      if (M.isCompiled())
77        PROTEUS_FATAL_ERROR("setLaunchBounds must be called before compile()");
78
79#if PROTEUS_ENABLE_CUDA || PROTEUS_ENABLE_HIP
80      Function *Fn = F.getFunction();
81      if (!Fn)
82        PROTEUS_FATAL_ERROR("Expected non-null Function");
83
84      setLaunchBoundsForKernel(*Fn, MaxThreadsPerBlock, MinBlocksPerSM);
85#else
86      PROTEUS_FATAL_ERROR("Unsupported target for setLaunchBounds");
87#endif
88    }
89
90    // Launch with type-safety.
// Compiles the module lazily on first launch, resolves the kernel function
// pointer (cached on the Func after the first dispatcher lookup), and
// forwards to Dispatcher::launch.
// NOTE(review): line 119 (the remaining launch() arguments: Block, the
// kernel-args array built from Ptrs, ShmemBytes, Stream) is missing from
// this rendering — see the Dispatcher::launch index entry below.
91    [[nodiscard]] auto launch(LaunchDims Grid, LaunchDims Block,
92                              uint64_t ShmemBytes, void *Stream, ArgT... Args) {
93      // Pointers to the local parameter copies.
94      void *Ptrs[sizeof...(ArgT)] = {(void *)&Args...};
95
96      if (!M.isCompiled())
97        M.compile();
98
99      auto GetKernelFunc = [&]() {
100        // Get the kernel func pointer directly from the Func object if
101        // available.
102        if (auto KernelFunc = F.getCompiledFunc()) {
103          return KernelFunc;
104        }
105
106        // Get the kernel func pointer from the Dispatch and store it to the
107        // Func object to avoid cache lookups.
108        // TODO: Re-think caching and dispatchers.
109        auto KernelFunc = reinterpret_cast<decltype(F.getCompiledFunc())>(
110            M.Dispatch.getFunctionAddress(F.getName(), M.ModuleHash,
111                                          M.getLibrary()));
112
113        F.setCompiledFunc(KernelFunc);
114
115        return KernelFunc;
116      };
117
118      return M.Dispatch.launch(reinterpret_cast<void *>(GetKernelFunc()), Grid,
120    }
121
// Allows handle->setName(...) etc. by forwarding to the underlying FuncBase.
122    FuncBase *operator->() { return &F; }
123  };
124
// True for GPU targets (CUDA or HIP); host modules return false.
125  bool isDeviceModule() {
126    return ((TargetModel == TargetModelType::CUDA) ||
127            (TargetModel == TargetModelType::HIP));
128  }
129
// Marks F as a GPU kernel in the target-specific way: nvvm.annotations
// metadata plus a "kernel" attribute for CUDA, AMDGPU_KERNEL calling
// convention for HIP. Fatal error for host or unknown targets.
// NOTE(review): the case labels on original lines 132, 146, and 149
// (presumably TargetModelType::CUDA / HIP / HOST) were dropped by the
// extraction — recover from the original header.
130  void setKernel(FuncBase &F) {
131    switch (TargetModel) {
133      NamedMDNode *MD = Mod->getOrInsertNamedMetadata("nvvm.annotations");
134
135      Metadata *MDVals[] = {
136          ConstantAsMetadata::get(F.getFunction()),
137          MDString::get(*Ctx, "kernel"),
138          ConstantAsMetadata::get(ConstantInt::get(Type::getInt32Ty(*Ctx), 1))};
139      // Append metadata to nvvm.annotations.
140      MD->addOperand(MDNode::get(*Ctx, MDVals));
141
142      // Add a function attribute for the kernel.
143      F.getFunction()->addFnAttr(Attribute::get(*Ctx, "kernel"));
144      return;
145    }
147      F.getFunction()->setCallingConv(CallingConv::AMDGPU_KERNEL);
148      return;
150      PROTEUS_FATAL_ERROR("Host does not support setKernel");
151    default:
152      PROTEUS_FATAL_ERROR("Unsupported target " + TargetTriple);
153    }
154  }
155
156public:
// Constructor (signature on dropped line 157; per the index below it is
// JitModule(StringRef Target = "host")): creates a fresh context and module
// and binds the dispatcher for the parsed target model.
158      : Ctx{std::make_unique<LLVMContext>()},
159        Mod{std::make_unique<Module>("JitModule", *Ctx)},
160        TargetModel{parseTargetModel(Target)},
161        TargetTriple(getTargetTriple(TargetModel)),
162        Dispatch(Dispatcher::getDispatcher(TargetModel)) {}
163
164  // Disable copy and move constructors.
165  JitModule(const JitModule &) = delete;
166  JitModule &operator=(const JitModule &) = delete;
167  JitModule(JitModule &&) = delete;
169
// Declares a new JIT function with signature Sig and returns a typed
// Func<RetT, ArgT...>& for emitting its body. Must be called before
// compile(). NOTE(review): lines 175 (the PROTEUS_FATAL_ERROR call head)
// and 179 (the FunctionCallee FC declaration) are missing here.
170  template <typename Sig> auto &addFunction(StringRef Name) {
171    using RetT = typename FnSig<Sig>::RetT;
172    using ArgT = typename FnSig<Sig>::ArgsTList;
173
174    if (IsCompiled)
176          "The module is compiled, no further code can be added");
177
178    Mod->setTargetTriple(TargetTriple);
180
181    Function *F = dyn_cast<Function>(FC.getCallee());
182    if (!F)
183      PROTEUS_FATAL_ERROR("Unexpected");
184
185    return buildFuncFromArgsList<RetT>(FC, ArgT{});
186  }
187
188  bool isCompiled() const { return IsCompiled; }
189
190  const Module &getModule() const { return *Mod; }
191
// Declares a new GPU kernel (void return enforced at compile time) and
// returns a KernelHandle for launch configuration. Device modules only.
// NOTE(review): lines 198 and 205 are missing (error-call head and the
// FunctionCallee FC declaration), mirroring addFunction above.
192  template <typename Sig> auto addKernel(StringRef Name) {
193    using RetT = typename FnSig<Sig>::RetT;
194    static_assert(std::is_void_v<RetT>, "Kernels must have void return type");
195    using ArgT = typename FnSig<Sig>::ArgsTList;
196
197    if (IsCompiled)
199          "The module is compiled, no further code can be added");
200
201    if (!isDeviceModule())
202      PROTEUS_FATAL_ERROR("Expected a device module for addKernel");
203
204    Mod->setTargetTriple(TargetTriple);
206    Function *F = dyn_cast<Function>(FC.getCallee());
207    if (!F)
208      PROTEUS_FATAL_ERROR("Unexpected");
209
210    return buildKernelFromArgsList(FC, ArgT{});
211  }
212
// Compiles the module (idempotent). Optionally verifies IR first, hashes
// the bitcode to uniquify function names and key the library cache, then
// either reuses a cached CompiledLibrary or compiles via the dispatcher.
// Consumes Ctx and Mod on the compile path.
// NOTE(review): lines 221-223 (the Buffer declaration and the raw_*_ostream
// OS wrapping it for WriteBitcodeToFile) are missing from this rendering.
213  void compile(bool Verify = false) {
214    if (IsCompiled)
215      return;
216
217    if (Verify)
218      if (verifyModule(*Mod, &errs())) {
219        PROTEUS_FATAL_ERROR("Broken module found, JIT compilation aborted!");
220      }
221
224    WriteBitcodeToFile(*Mod, OS);
225
226    // Create a unique module hash based on the bitcode and append to all
227    // function names to make them unique.
228    // TODO: Is this necessary?
229    ModuleHash = hash(StringRef{Buffer.data(), Buffer.size()});
230    for (auto &JitF : Functions) {
231      JitF->setName(JitF->getName().str() + "$" + ModuleHash.toString());
232    }
233
234    if ((Library = Dispatch.lookupCompiledLibrary(ModuleHash))) {
235      IsCompiled = true;
236      return;
237    }
238
239    Library = std::make_unique<CompiledLibrary>(
240        Dispatch.compile(std::move(Ctx), std::move(Mod), ModuleHash));
241    IsCompiled = true;
242  }
243
244  HashT getModuleHash() const { return ModuleHash; }
245
246  Dispatcher &getDispatcher() const { return Dispatch; }
247
248  TargetModelType getTargetModel() const { return TargetModel; }
249
// getLibrary() (signature on dropped line 250 per the index below):
// compiles lazily if needed and returns the CompiledLibrary, which is
// guaranteed non-null after a successful compile.
251    if (!IsCompiled)
252      compile();
253
254    if (!Library)
255      PROTEUS_FATAL_ERROR("Expected non-null library after compilation");
256
257    return *Library;
258  }
259
// getFunctionCallee (signature continuation on dropped line 261 takes an
// ArgTypeList<ArgT...> tag per the index below): declares-or-fetches a
// function in Mod with types mapped from C++ via TypeMap.
260  template <typename RetT, typename... ArgT>
262    return Mod->getOrInsertFunction(Name, TypeMap<RetT>::get(*Ctx),
263                                    TypeMap<ArgT>::get(*Ctx)...);
264  }
265
// Debug helper: dumps the textual IR of the module to stdout.
266  void print() { Mod->print(outs(), nullptr); }
267};
268
// Emits an IR call to a zero-argument callee and captures the non-void
// result into a freshly declared Var<RetT>.
// NOTE(review): original lines 271-272 (the rest of the enable_if return
// type, Var<RetT>, and the function signature — per the index below it is
// call(StringRef Name), a FuncBase member) were dropped by the extraction;
// Callee is presumably resolved from Name in the dropped span — confirm
// against the original header.
269template <typename Sig>
270std::enable_if_t<!std::is_void_v<typename FnSig<Sig>::RetT>,
273  using RetT = typename FnSig<Sig>::RetT;
274
275  using ArgT = typename FnSig<Sig>::ArgsTList;
277  auto *Call = IRB.CreateCall(Callee);
278  Var<RetT> Ret = declVar<RetT>("ret");
279  Ret.storeValue(Call);
280  return Ret;
281}
282
// Void-returning overload of call(): emits the IR call and returns nothing.
// NOTE(review): original lines 285 (signature), 288, and 290 (presumably
// the Callee resolution and closing brace) are missing from this rendering
// — recover from the original header before editing.
283template <typename Sig>
284std::enable_if_t<std::is_void_v<typename FnSig<Sig>::RetT>, void>
286  using RetT = typename FnSig<Sig>::RetT;
287  using ArgT = typename FnSig<Sig>::ArgsTList;
289  IRB.CreateCall(Callee);
291
// call() overload taking argument Vars: loads each Var (loadPointer for
// pointer-typed ValueTypes, loadValue otherwise), emits the IR call, and
// captures the non-void result into a fresh Var<RetT>.
// NOTE(review): original lines 294-295 (rest of the return type and the
// signature introducing ArgsVars) and 299 (presumably Callee resolution)
// are missing from this rendering — recover from the original header.
292template <typename Sig, typename... ArgVars>
293std::enable_if_t<!std::is_void_v<typename FnSig<Sig>::RetT>,
296
297  using RetT = typename FnSig<Sig>::RetT;
298  using ArgT = typename FnSig<Sig>::ArgsTList;
// Per-argument lowering: pointers and values are loaded differently.
300  auto GetArgVal = [](auto &&Arg) {
301    using ArgVarT = std::decay_t<decltype(Arg)>;
302    if constexpr (std::is_pointer_v<typename ArgVarT::ValueType>)
303      return Arg.loadPointer();
304    else
305      return Arg.loadValue();
306  };
307
308  auto *Call = IRB.CreateCall(Callee, {GetArgVal(ArgsVars)...});
309
310  Var<RetT> Ret = declVar<RetT>("ret");
311  Ret.storeValue(Call);
312  return Ret;
313}
314
// Void-returning overload of the argument-taking call(): loads each Var
// the same way as the non-void variant and emits the call with no result.
// NOTE(review): original lines 317 (signature) and 321 (presumably Callee
// resolution) are missing from this rendering — recover from the header.
315template <typename Sig, typename... ArgVars>
316std::enable_if_t<std::is_void_v<typename FnSig<Sig>::RetT>, void>
318  using RetT = typename FnSig<Sig>::RetT;
319  using ArgT = typename FnSig<Sig>::ArgsTList;
320
// Per-argument lowering, identical to the non-void overload above.
322  auto GetArgVal = [](auto &&Arg) {
323    using ArgVarT = std::decay_t<decltype(Arg)>;
324    if constexpr (std::is_pointer_v<typename ArgVarT::ValueType>)
325      return Arg.loadPointer();
326    else
327      return Arg.loadValue();
328  };
329
330  IRB.CreateCall(Callee, {GetArgVal(ArgsVars)...});
331}
332
// Host-side invocation of a compiled JIT function (per the index below this
// is Func<RetT, ArgT...>::operator()(ArgT... Args)): compiles the module
// lazily, resolves and caches the function pointer, and runs it through the
// dispatcher. GPU-model modules are rejected — use KernelHandle::launch().
// NOTE(review): original lines 334 (the operator() signature) and 345 (the
// PROTEUS_FATAL_ERROR call head) are missing from this rendering.
333template <typename RetT, typename... ArgT>
335  if (!J.isCompiled())
336    J.compile();
337
// Resolve once and cache on the Func to avoid repeated dispatcher lookups.
338  if (!CompiledFunc) {
339    CompiledFunc = reinterpret_cast<decltype(CompiledFunc)>(
340        J.getDispatcher().getFunctionAddress(getName(), J.getModuleHash(),
341                                             J.getLibrary()));
342  }
343
// The check runs after compilation/lookup; direct execution is host-only.
344  if (J.getTargetModel() != TargetModelType::HOST)
346        "Target is a GPU model, cannot directly run functions, use launch()");
347
// void-returning functions cannot appear in a return expression, hence the
// constexpr split on RetT.
348  if constexpr (std::is_void_v<RetT>)
349    Dispatch.run<RetT(ArgT...)>(reinterpret_cast<void *>(CompiledFunc),
350                                Args...);
351  else
352    return Dispatch.run<RetT(ArgT...)>(reinterpret_cast<void *>(CompiledFunc),
353                                       Args...);
354}
355
356} // namespace proteus
357
358#endif
char int void ** Args
Definition CompilerInterfaceHost.cpp:20
#define PROTEUS_FATAL_ERROR(x)
Definition Error.h:7
Definition Dispatcher.hpp:54
virtual std::unique_ptr< CompiledLibrary > lookupCompiledLibrary(HashT ModuleHash)=0
virtual std::unique_ptr< MemoryBuffer > compile(std::unique_ptr< LLVMContext > Ctx, std::unique_ptr< Module > M, HashT ModuleHash, bool DisableIROpt=false)=0
virtual void * getFunctionAddress(StringRef FunctionName, HashT ModuleHash, CompiledLibrary &Library)=0
virtual DispatchResult launch(void *KernelFunc, LaunchDims GridDim, LaunchDims BlockDim, ArrayRef< void * > KernelArgs, uint64_t ShmemSize, void *Stream)=0
Definition Func.hpp:40
IRBuilder IRB
Definition Func.hpp:44
Function * getFunction()
Definition Func.cpp:66
std::string Name
Definition Func.hpp:47
JitModule & J
Definition Func.hpp:42
std::enable_if_t<!std::is_void_v< typename FnSig< Sig >::RetT >, Var< typename FnSig< Sig >::RetT > > call(StringRef Name)
Definition JitFrontend.hpp:272
Definition Func.hpp:252
RetT operator()(ArgT... Args)
Definition JitFrontend.hpp:334
void declArgs()
Definition Func.hpp:292
Definition Hashing.hpp:20
std::string toString() const
Definition Hashing.hpp:28
Definition JitFrontend.hpp:29
JitModule & operator=(JitModule &&)=delete
JitModule(const JitModule &)=delete
bool isCompiled() const
Definition JitFrontend.hpp:188
JitModule(StringRef Target="host")
Definition JitFrontend.hpp:157
TargetModelType getTargetModel() const
Definition JitFrontend.hpp:248
FunctionCallee getFunctionCallee(StringRef Name, ArgTypeList< ArgT... >)
Definition JitFrontend.hpp:261
const Module & getModule() const
Definition JitFrontend.hpp:190
HashT getModuleHash() const
Definition JitFrontend.hpp:244
JitModule(JitModule &&)=delete
JitModule & operator=(const JitModule &)=delete
CompiledLibrary & getLibrary()
Definition JitFrontend.hpp:250
Dispatcher & getDispatcher() const
Definition JitFrontend.hpp:246
auto & addFunction(StringRef Name)
Definition JitFrontend.hpp:170
auto addKernel(StringRef Name)
Definition JitFrontend.hpp:192
void print()
Definition JitFrontend.hpp:266
void compile(bool Verify=false)
Definition JitFrontend.hpp:213
Definition Helpers.h:138
Definition StorageCache.cpp:24
TargetModelType
Definition TargetModel.hpp:14
void setLaunchBoundsForKernel(Function &F, int MaxThreadsPerSM, int MinBlocksPerSM=0)
Definition CoreLLVMCUDA.hpp:87
HashT hash(FirstT &&First, RestTs &&...Rest)
Definition Hashing.hpp:126
T getRuntimeConstantValue(void *Arg)
Definition CompilerInterfaceRuntimeConstantInfo.h:114
TargetModelType parseTargetModel(StringRef Target)
Definition TargetModel.hpp:16
std::string getTargetTriple(TargetModelType Model)
Definition TargetModel.hpp:40
Definition Hashing.hpp:147
Definition Dispatcher.hpp:16
Definition Func.hpp:27
Definition CompiledLibrary.hpp:12
Definition Func.hpp:28
Definition TypeMap.hpp:13
Definition Var.hpp:94