Proteus
Programmable JIT compilation and optimization for C/C++ using LLVM
Loading...
Searching...
No Matches
JitFrontend.hpp
Go to the documentation of this file.
1#ifndef PROTEUS_JIT_DEV_HPP
2#define PROTEUS_JIT_DEV_HPP
3
4#include <llvm/ADT/StringRef.h>
5#include <llvm/Bitcode/BitcodeWriter.h>
6#include <llvm/IR/IRBuilder.h>
7#include <llvm/IR/Module.h>
8#include <llvm/IR/Verifier.h>
9#include <llvm/Support/Debug.h>
10#include <llvm/Support/MemoryBuffer.h>
11#include <llvm/TargetParser/Host.h>
12#include <llvm/TargetParser/Triple.h>
13
14#include <deque>
15
16#include "proteus/Error.h"
20
21#include <iostream>
22
23namespace proteus {
24using namespace llvm;
25
26class JitModule {
27private:
28 std::unique_ptr<LLVMContext> Ctx;
29 std::unique_ptr<Module> Mod;
30 std::unique_ptr<MemoryBuffer> ObjectModule;
31
32 std::deque<std::unique_ptr<FuncBase>> Functions;
33 TargetModelType TargetModel;
34 std::string TargetTriple;
35 Dispatcher &Dispatch;
36
37 HashT ModuleHash = 0;
38 bool IsCompiled = false;
39
40 template <typename... ArgT> struct KernelHandle {
41 Func<void, ArgT...> &F;
42 JitModule &M;
43
44 // Launch with type-safety.
45 [[nodiscard]] auto launch(LaunchDims Grid, LaunchDims Block,
46 uint64_t ShmemBytes, void *Stream, ArgT... Args) {
47 // Pointers to the local parameter copies.
48 void *Ptrs[sizeof...(ArgT)] = {(void *)&Args...};
49
50 if (!M.isCompiled())
51 M.compile();
52
53 auto GetKernelFunc = [&]() {
54 // Get the kernel func pointer directly from the Func object if
55 // available.
56 if (auto KernelFunc = F.getCompiledFunc()) {
57 return KernelFunc;
58 }
59
60 // Get the kernel func pointer from the Dispatch and store it to the
61 // Func object to avoid cache lookups.
62 // TODO: Re-think caching and dispatchers.
63 auto KernelFunc = reinterpret_cast<decltype(F.getCompiledFunc())>(
64 M.Dispatch.getFunctionAddress(F.getName(), M.getObjectModuleRef()));
65
66 F.setCompiledFunc(KernelFunc);
67
68 return KernelFunc;
69 };
70
71 return M.Dispatch.launch(reinterpret_cast<void *>(GetKernelFunc()), Grid,
72 Block, Ptrs, ShmemBytes, Stream);
73 }
74
75 FuncBase *operator->() { return &F; }
76 };
77
78 bool isDeviceModule() {
79 return ((TargetModel == TargetModelType::CUDA) ||
80 (TargetModel == TargetModelType::HIP));
81 }
82
83 void setKernel(FuncBase &F) {
84 switch (TargetModel) {
86 NamedMDNode *MD = Mod->getOrInsertNamedMetadata("nvvm.annotations");
87
88 Metadata *MDVals[] = {
89 ConstantAsMetadata::get(F.getFunction()),
90 MDString::get(*Ctx, "kernel"),
91 ConstantAsMetadata::get(ConstantInt::get(Type::getInt32Ty(*Ctx), 1))};
92 // Append metadata to nvvm.annotations.
93 MD->addOperand(MDNode::get(*Ctx, MDVals));
94
95 // Add a function attribute for the kernel.
96 F.getFunction()->addFnAttr(Attribute::get(*Ctx, "kernel"));
97 return;
98 }
100 F.getFunction()->setCallingConv(CallingConv::AMDGPU_KERNEL);
101 return;
103 PROTEUS_FATAL_ERROR("Host does not support setKernel");
104 default:
105 PROTEUS_FATAL_ERROR("Unsupported target " + TargetTriple);
106 }
107 }
108
109public:
110 JitModule(StringRef Target = "host")
111 : Ctx{std::make_unique<LLVMContext>()},
112 Mod{std::make_unique<Module>("JitModule", *Ctx)},
113 TargetModel{parseTargetModel(Target)},
114 TargetTriple(getTargetTriple(TargetModel)),
115 Dispatch(Dispatcher::getDispatcher(TargetModel)) {}
116
117 // Disable copy and move constructors.
118 JitModule(const JitModule &) = delete;
119 JitModule &operator=(const JitModule &) = delete;
120 JitModule(JitModule &&) = delete;
122
123 template <typename RetT, typename... ArgT>
124 Func<RetT, ArgT...> &addFunction(StringRef Name) {
125 if (IsCompiled)
127 "The module is compiled, no further code can be added");
128
129 Mod->setTargetTriple(TargetTriple);
130 FunctionCallee FC;
131 FC = Mod->getOrInsertFunction(Name, TypeMap<RetT>::get(*Ctx),
132 TypeMap<ArgT>::get(*Ctx)...);
133 Function *F = dyn_cast<Function>(FC.getCallee());
134 if (!F)
135 PROTEUS_FATAL_ERROR("Unexpected");
136 auto TypedFn = std::make_unique<Func<RetT, ArgT...>>(*this, FC, Dispatch);
137 Func<RetT, ArgT...> &TypedFnRef = *TypedFn;
138 std::unique_ptr<FuncBase> &Fn = Functions.emplace_back(std::move(TypedFn));
139
140 Fn->declArgs<ArgT...>();
141 return TypedFnRef;
142 }
143
144 bool isCompiled() const { return IsCompiled; }
145
146 const Module &getModule() const { return *Mod; }
147
148 template <typename... ArgT> KernelHandle<ArgT...> addKernel(StringRef Name) {
149 if (IsCompiled)
151 "The module is compiled, no further code can be added");
152
153 if (!isDeviceModule())
154 PROTEUS_FATAL_ERROR("Expected a device module for addKernel");
155
156 Mod->setTargetTriple(TargetTriple);
157 FunctionCallee FC;
158 FC = Mod->getOrInsertFunction(Name, TypeMap<void>::get(*Ctx),
159 TypeMap<ArgT>::get(*Ctx)...);
160 Function *F = dyn_cast<Function>(FC.getCallee());
161 if (!F)
162 PROTEUS_FATAL_ERROR("Unexpected");
163 auto TypedFn = std::make_unique<Func<void, ArgT...>>(*this, FC, Dispatch);
164 Func<void, ArgT...> &TypedFnRef = *TypedFn;
165 std::unique_ptr<FuncBase> &Fn = Functions.emplace_back(std::move(TypedFn));
166
167 Fn->declArgs<ArgT...>();
168
169 setKernel(*Fn);
170 return KernelHandle<ArgT...>{TypedFnRef, *this};
171 }
172
173 void compile(bool Verify = false) {
174 if (IsCompiled)
175 return;
176
177 if (Verify)
178 if (verifyModule(*Mod, &errs())) {
179 PROTEUS_FATAL_ERROR("Broken module found, JIT compilation aborted!");
180 }
181
182 SmallVector<char, 0> Buffer;
183 raw_svector_ostream OS(Buffer);
184 WriteBitcodeToFile(*Mod, OS);
185
186 // Create a unique module hash based on the bitcode and append to all
187 // function names to make them unique.
188 // TODO: This is not needed for GPU JIT modules since they are separate
189 // objects. However, CPU JIT modules end up in the same object through the
190 // ORC JIT singleton. Reconsider the CPU JIT process.
191 ModuleHash = hash(StringRef{Buffer.data(), Buffer.size()});
192 for (auto &JitF : Functions) {
193 JitF->setName(JitF->getName().str() + "$" + ModuleHash.toString());
194 }
195
196 if ((ObjectModule = Dispatch.lookupObjectModule(ModuleHash))) {
197 IsCompiled = true;
198 return;
199 }
200
201 ObjectModule = Dispatch.compile(std::move(Ctx), std::move(Mod), ModuleHash);
202 IsCompiled = true;
203 }
204
205 HashT getModuleHash() const { return ModuleHash; }
206
207 std::optional<MemoryBufferRef> getObjectModuleRef() const {
208 // For host JIT modules the ObjectModule is alway nullptr and unused by
209 // DispatcherHOST since it is unused by ORC JIT.
210 if (!ObjectModule)
211 return std::nullopt;
212
213 return ObjectModule->getMemBufferRef();
214 }
215
216 const Dispatcher &getDispatcher() const { return Dispatch; }
217
218 TargetModelType getTargetModel() const { return TargetModel; }
219
220 void print() { Mod->print(outs(), nullptr); }
221};
222
223template <typename RetT, typename... ArgT>
224std::enable_if_t<!std::is_void_v<RetT>, Var &> FuncBase::call(StringRef Name) {
225 auto *F = getFunction();
226 Module &M = *F->getParent();
227 LLVMContext &Ctx = F->getContext();
228 FunctionCallee Callee = M.getOrInsertFunction(Name, TypeMap<RetT>::get(Ctx),
229 TypeMap<ArgT>::get(Ctx)...);
230 Var &Ret = declVarInternal("ret", TypeMap<RetT>::get(Ctx));
231 auto *Call = IRB.CreateCall(Callee);
232 Ret.storeValue(Call);
233
234 return Ret;
235}
236
237template <typename RetT, typename... ArgT>
238std::enable_if_t<std::is_void_v<RetT>, void> FuncBase::call(StringRef Name) {
239 auto *F = getFunction();
240 Module &M = *F->getParent();
241 LLVMContext &Ctx = F->getContext();
242 FunctionCallee Callee = M.getOrInsertFunction(Name, TypeMap<RetT>::get(Ctx),
243 TypeMap<ArgT>::get(Ctx)...);
244 IRB.CreateCall(Callee);
245}
246
247template <typename RetT, typename... ArgT>
249 if (!J.isCompiled())
250 J.compile();
251
252 if (J.getTargetModel() != TargetModelType::HOST)
254 "Target is a GPU model, cannot directly run functions, use launch()");
255
256 return Dispatch.run<RetT>(getName(), J.getObjectModuleRef(), Args...);
257}
258
259} // namespace proteus
260
261#endif
char int void ** Args
Definition CompilerInterfaceHost.cpp:21
#define PROTEUS_FATAL_ERROR(x)
Definition Error.h:7
Definition Dispatcher.hpp:46
virtual std::unique_ptr< MemoryBuffer > lookupObjectModule(HashT ModuleHash)=0
virtual std::unique_ptr< MemoryBuffer > compile(std::unique_ptr< LLVMContext > Ctx, std::unique_ptr< Module > M, HashT ModuleHash)=0
virtual DispatchResult launch(StringRef KernelName, LaunchDims GridDim, LaunchDims BlockDim, ArrayRef< void * > KernelArgs, uint64_t ShmemSize, void *Stream, std::optional< MemoryBufferRef > ObjectModule)=0
auto run(StringRef FuncName, std::optional< MemoryBufferRef > ObjectModule, ArgT &&...Args)
Definition Dispatcher.hpp:75
virtual void * getFunctionAddress(StringRef FunctionName, std::optional< MemoryBufferRef > ObjectModule)=0
Definition Func.hpp:22
void declArgs()
Definition Func.hpp:109
std::enable_if_t<!std::is_void_v< RetT >, Var & > call(StringRef Name)
Definition JitFrontend.hpp:224
Var & declVarInternal(StringRef Name, Type *Ty, Type *PointerElemType=nullptr)
Definition Func.cpp:18
IRBuilder IRB
Definition Func.hpp:26
Function * getFunction()
Definition Func.cpp:60
std::string Name
Definition Func.hpp:31
Definition Func.hpp:165
RetT operator()(ArgT... Args)
Definition JitFrontend.hpp:248
Definition Hashing.hpp:20
std::string toString() const
Definition Hashing.hpp:28
Definition JitFrontend.hpp:26
const Dispatcher & getDispatcher() const
Definition JitFrontend.hpp:216
JitModule & operator=(JitModule &&)=delete
JitModule(const JitModule &)=delete
bool isCompiled() const
Definition JitFrontend.hpp:144
JitModule(StringRef Target="host")
Definition JitFrontend.hpp:110
Func< RetT, ArgT... > & addFunction(StringRef Name)
Definition JitFrontend.hpp:124
TargetModelType getTargetModel() const
Definition JitFrontend.hpp:218
std::optional< MemoryBufferRef > getObjectModuleRef() const
Definition JitFrontend.hpp:207
const Module & getModule() const
Definition JitFrontend.hpp:146
HashT getModuleHash() const
Definition JitFrontend.hpp:205
JitModule(JitModule &&)=delete
JitModule & operator=(const JitModule &)=delete
KernelHandle< ArgT... > addKernel(StringRef Name)
Definition JitFrontend.hpp:148
void print()
Definition JitFrontend.hpp:220
void compile(bool Verify=false)
Definition JitFrontend.hpp:173
Definition Helpers.h:76
Definition CppJitModule.cpp:21
TargetModelType
Definition TargetModel.hpp:14
HashT hash(FirstT &&First, RestTs &&...Rest)
Definition Hashing.hpp:126
TargetModelType parseTargetModel(StringRef Target)
Definition TargetModel.hpp:16
std::string getTargetTriple(TargetModelType Model)
Definition TargetModel.hpp:32
Definition Hashing.hpp:147
Definition Dispatcher.hpp:15
Definition TypeMap.hpp:13
Definition Var.hpp:15
void storeValue(Value *Val)
Definition Var.cpp:102