Proteus
Programmable JIT compilation and optimization for C/C++ using LLVM
Loading...
Searching...
No Matches
JitFrontend.hpp
Go to the documentation of this file.
1#ifndef PROTEUS_JIT_DEV_HPP
2#define PROTEUS_JIT_DEV_HPP
3
4#include <llvm/ADT/StringRef.h>
5#include <llvm/IR/IRBuilder.h>
6#include <llvm/IR/Module.h>
7#include <llvm/IR/Verifier.h>
8#include <llvm/Support/Debug.h>
9#include <llvm/Support/MemoryBuffer.h>
10#include <llvm/TargetParser/Host.h>
11#include <llvm/TargetParser/Triple.h>
12
13#include <deque>
14
15#include "proteus/Error.h"
19
20#include <iostream>
21
22namespace proteus {
23using namespace llvm;
24
// JitModule owns an LLVMContext and a Module in which functions and kernels
// are declared, then hands the Module to a target-specific Dispatcher for
// compilation (compile), execution (run) and kernel launches (launch).
25class JitModule {
26private:
 // Member declaration order matters: the constructor initializes
 // TargetTriple and Dispatch from TargetModel, and Mod from *Ctx.
27 std::unique_ptr<LLVMContext> Ctx;
28 std::unique_ptr<Module> Mod;
 // NOTE(review): not referenced by any member function of this class.
29 std::unique_ptr<MemoryBuffer> CompiledObject;
30
 // Func wrappers returned by reference from addFunction/addKernel; a deque
 // presumably so emplace_back does not invalidate earlier references.
31 std::deque<Func> Functions;
32 TargetModelType TargetModel;
33 std::string TargetTriple;
 // Obtained from Dispatcher::getDispatcher(TargetModel) in the constructor.
34 Dispatcher &Dispatch;
35
 // Set by compile(); run() and launch() are fatal errors while false.
36 bool IsCompiled = false;
37
 // Typed handle to a kernel Func. ArgT... are the parameter types given to
 // addKernel/getKernelHandle. Holds references to the Func and the owning
 // JitModule, so it must not outlive the JitModule.
38 template <typename... ArgT> struct KernelHandle {
39 Func &F;
40 JitModule &M;
41
42 // Launch with type-safety.
 // Args are copied into a tuple and the kernel-argument array holds the
 // addresses of the tuple elements; the tuple lives until M.launch returns.
43 [[nodiscard]] auto launch(LaunchDims Grid, LaunchDims Block,
44 uint64_t ShmemBytes, void *Stream, ArgT... Args) {
45 // Create the type-safe tuple.
46 auto Tup = std::make_tuple(static_cast<ArgT>(Args)...);
47
48 // Create the ArrayRef<void*> pointing at each tuple element.
49 std::array<void *, sizeof...(ArgT)> Ptrs;
50 std::apply(
51 [&](auto &...Elts) {
52 size_t I = 0;
53 ((Ptrs[I++] = (void *)&Elts), ...);
54 },
55 Tup);
56
57 // Call launch through module.
58 // TODO: should it use the dispatcher directly?
59 return M.launch(F, Grid, Block, Ptrs, ShmemBytes, Stream);
60 }
61
 // Gives access to the underlying Func (e.g. to build the kernel body).
62 Func *operator->() { return &F; }
63 };
64
 // Translates a target name into a TargetModelType. Recognized names are
 // "host"/"native", "cuda" and "hip"; any other name is a fatal error.
65 TargetModelType getTargetModel(StringRef Target) {
66 if (Target == "host" || Target == "native") {
68 }
69
70 if (Target == "cuda") {
72 }
73
74 if (Target == "hip") {
76 }
77
78 PROTEUS_FATAL_ERROR("Unsupported target " + Target);
79 }
80
 // Returns the LLVM target triple for Model: the running process's triple
 // (sys::getProcessTriple), "nvptx64-nvidia-cuda" or "amdgcn-amd-amdhsa".
 // Unsupported models are a fatal error.
81 std::string getTargetTriple(TargetModelType Model) {
82 switch (Model) {
84 return sys::getProcessTriple();
86 return "nvptx64-nvidia-cuda";
88 return "amdgcn-amd-amdhsa";
89 default:
90 PROTEUS_FATAL_ERROR("Unsupported target model");
91 }
92 }
93
 // True when this module targets CUDA or HIP.
94 bool isDeviceModule() {
95 return ((TargetModel == TargetModelType::CUDA) ||
96 (TargetModel == TargetModelType::HIP));
97 }
98
 // Marks F as a GPU kernel in a target-specific way. On the NVPTX path it
 // appends an (F, "kernel", i32 1) entry to the "nvvm.annotations" named
 // metadata and adds a "kernel" function attribute; on the AMDGPU path it
 // switches F to the AMDGPU_KERNEL calling convention. Host modules and
 // unknown targets are fatal errors.
99 void setKernel(Func &F) {
100 switch (TargetModel) {
102 NamedMDNode *MD = Mod->getOrInsertNamedMetadata("nvvm.annotations");
103
104 Metadata *MDVals[] = {
105 ConstantAsMetadata::get(F.getFunction()),
106 MDString::get(*Ctx, "kernel"),
107 ConstantAsMetadata::get(ConstantInt::get(Type::getInt32Ty(*Ctx), 1))};
108 // Append metadata to nvvm.annotations.
109 MD->addOperand(MDNode::get(*Ctx, MDVals));
110
111 // Add a function attribute for the kernel.
112 F.getFunction()->addFnAttr(Attribute::get(*Ctx, "kernel"));
113 return;
114 }
116 F.getFunction()->setCallingConv(CallingConv::AMDGPU_KERNEL);
117 return;
119 PROTEUS_FATAL_ERROR("Host does not support setKernel");
120 default:
121 PROTEUS_FATAL_ERROR("Unsupported target " + TargetTriple);
122 }
123 }
124
125public:
 // Creates an empty module for Target ("host" by default) and binds the
 // dispatcher for the resulting target model.
 // NOTE(review): not `explicit`, so a StringRef / const char * converts
 // implicitly to a JitModule.
126 JitModule(StringRef Target = "host")
127 : Ctx{std::make_unique<LLVMContext>()},
128 Mod{std::make_unique<Module>("JitModule", *Ctx)},
129 TargetModel{getTargetModel(Target)},
130 TargetTriple(getTargetTriple(TargetModel)),
131 Dispatch(Dispatcher::getDispatcher(TargetModel)) {}
132
133 // Disable copy and move constructors.
134 JitModule(const JitModule &) = delete;
135 JitModule &operator=(const JitModule &) = delete;
136 JitModule(JitModule &&) = delete;
138
 // Declares (or reuses) function Name with signature RetT(ArgT...) in the
 // module, stamps the module with TargetTriple, and returns a Func wrapper
 // whose arguments are declared via declArgs<ArgT...>(). Fatal if the
 // callee returned by getOrInsertFunction is not a plain Function
 // (presumably an existing Name with a different type — TODO confirm for
 // the LLVM version in use).
 // NOTE(review): dereferences Mod, which is null after compile().
139 template <typename RetT, typename... ArgT> Func &addFunction(StringRef Name) {
140 Mod->setTargetTriple(TargetTriple);
141 FunctionCallee FC;
142 FC = Mod->getOrInsertFunction(Name, TypeMap<RetT>::get(*Ctx),
143 TypeMap<ArgT>::get(*Ctx)...);
144 Function *F = dyn_cast<Function>(FC.getCallee());
145 if (!F)
146 PROTEUS_FATAL_ERROR("Unexpected");
147 auto &Fn = Functions.emplace_back(FC);
148
149 Fn.declArgs<ArgT...>();
150 return Fn;
151 }
152
 // Read-only view of the module being built.
 // NOTE(review): only valid before compile(); compile() moves Mod into the
 // dispatcher, leaving it null, so this would dereference null afterwards.
153 const Module &getModule() const { return *Mod; }
154
 // Like addFunction, but with a void return type and only for device
 // (CUDA/HIP) modules; the new function is tagged as a kernel via
 // setKernel and a typed KernelHandle is returned.
 // NOTE(review): dereferences Mod, which is null after compile().
155 template <typename... ArgT> KernelHandle<ArgT...> addKernel(StringRef Name) {
156 if (!isDeviceModule())
157 PROTEUS_FATAL_ERROR("Expected a device module for addKernel");
158
159 Mod->setTargetTriple(TargetTriple);
160 FunctionCallee FC;
161 FC = Mod->getOrInsertFunction(Name, TypeMap<void>::get(*Ctx),
162 TypeMap<ArgT>::get(*Ctx)...);
163 Function *F = dyn_cast<Function>(FC.getCallee());
164 if (!F)
165 PROTEUS_FATAL_ERROR("Unexpected");
166 auto &Fn = Functions.emplace_back(FC);
167
168 Fn.declArgs<ArgT...>();
169
170 setKernel(Fn);
171 return KernelHandle<ArgT...>{Fn, *this};
172 }
173
 // Optionally runs the LLVM verifier (diagnostics to errs(), fatal on a
 // broken module), then transfers ownership of the module to the
 // dispatcher and marks this JitModule as compiled.
 // NOTE(review): Mod is null afterwards. A second compile() passes a null
 // module to Dispatch.compile (or dereferences null when Verify is set);
 // consider rejecting the call when IsCompiled is already true.
174 void compile(bool Verify = false) {
175 if (Verify)
176 if (verifyModule(*Mod, &errs())) {
177 PROTEUS_FATAL_ERROR("Broken module found, JIT compilation aborted!");
178 }
179
180 Dispatch.compile(std::move(Mod));
181 IsCompiled = true;
182 }
183
 // Runs F through the dispatcher with Args and returns its Ret result.
 // Fatal if compile() has not been called.
184 template <typename Ret, typename... ArgT> Ret run(Func &F, ArgT... Args) {
185 if (!IsCompiled)
186 PROTEUS_FATAL_ERROR("Expected compiled JIT module");
187 return Dispatch.run<Ret>(F.getName(), Args...);
188 }
189
 // Returns a typed handle for the previously added function named Name
 // (linear search over Functions); fatal if no function matches.
 // NOTE(review): matching is by name only, so functions added with
 // addFunction (not kernels) also match, and ArgT... is not checked
 // against the function's signature.
190 template <typename... ArgT>
191 KernelHandle<ArgT...> getKernelHandle(StringRef Name) {
192 // Find the kernel function and return a kernel handle.
193 for (auto &Fn : Functions) {
194 if (Fn.getName() == Name)
195 return {Fn, *this};
196 }
197 PROTEUS_FATAL_ERROR("Kernel not found: " + Name);
198 // TODO: add type-checking to make sure parameters match the function
199 // signature.
200 }
201
 // Launches kernel F through the dispatcher with the given grid/block
 // dimensions, array of argument pointers, shared-memory size and stream,
 // returning the dispatcher's result. Fatal if compile() has not run.
202 auto launch(Func &F, LaunchDims GridDim, LaunchDims BlockDim,
203 ArrayRef<void *> KernelArgs, uint64_t ShmemSize, void *Stream) {
204 if (!IsCompiled)
205 PROTEUS_FATAL_ERROR("Expected compiled JIT module");
206 return Dispatch.launch(F.getName(), GridDim, BlockDim, KernelArgs,
207 ShmemSize, Stream);
208 }
209
 // Same as the Func overload, but selects the kernel by name. The name is
 // forwarded to the dispatcher without being checked against Functions.
210 auto launch(StringRef KernelName, LaunchDims GridDim, LaunchDims BlockDim,
211 ArrayRef<void *> KernelArgs, uint64_t ShmemSize, void *Stream) {
212 if (!IsCompiled)
213 PROTEUS_FATAL_ERROR("Expected compiled JIT module");
214 // TODO: check that KernelName is valid.
215 return Dispatch.launch(KernelName, GridDim, BlockDim, KernelArgs, ShmemSize,
216 Stream);
217 }
218
 // Prints the module's IR to outs().
 // NOTE(review): only valid before compile(), which leaves Mod null.
219 void print() { Mod->print(outs(), nullptr); }
220};
221
222template <typename RetT, typename... ArgT> void Func::call(StringRef Name) {
223 auto *F = getFunction();
224 Module &M = *F->getParent();
225 LLVMContext &Ctx = F->getContext();
226 FunctionCallee Callee = M.getOrInsertFunction(Name, TypeMap<RetT>::get(Ctx),
227 TypeMap<ArgT>::get(Ctx)...);
228 IRB.CreateCall(Callee);
229}
230
231} // namespace proteus
232
233#endif
void char * KernelName
Definition CompilerInterfaceDevice.cpp:50
char int void ** Args
Definition CompilerInterfaceHost.cpp:20
TargetModelType
Definition Dispatcher.hpp:12
#define PROTEUS_FATAL_ERROR(x)
Definition Error.h:4
Definition Dispatcher.hpp:45
Ret run(StringRef FuncName, ArgT... Args)
Definition Dispatcher.hpp:57
virtual void compile(std::unique_ptr< Module > M)=0
virtual DispatchResult launch(StringRef KernelName, LaunchDims GridDim, LaunchDims BlockDim, ArrayRef< void * > KernelArgs, uint64_t ShmemSize, void *Stream)=0
Definition Func.hpp:19
Function * getFunction()
Definition Func.cpp:58
StringRef getName()
Definition Func.hpp:117
void call(StringRef Name)
Definition JitFrontend.hpp:222
Definition JitFrontend.hpp:25
Ret run(Func &F, ArgT... Args)
Definition JitFrontend.hpp:184
auto launch(Func &F, LaunchDims GridDim, LaunchDims BlockDim, ArrayRef< void * > KernelArgs, uint64_t ShmemSize, void *Stream)
Definition JitFrontend.hpp:202
JitModule & operator=(JitModule &&)=delete
JitModule(const JitModule &)=delete
JitModule(StringRef Target="host")
Definition JitFrontend.hpp:126
Func & addFunction(StringRef Name)
Definition JitFrontend.hpp:139
const Module & getModule() const
Definition JitFrontend.hpp:153
JitModule(JitModule &&)=delete
JitModule & operator=(const JitModule &)=delete
auto launch(StringRef KernelName, LaunchDims GridDim, LaunchDims BlockDim, ArrayRef< void * > KernelArgs, uint64_t ShmemSize, void *Stream)
Definition JitFrontend.hpp:210
KernelHandle< ArgT... > addKernel(StringRef Name)
Definition JitFrontend.hpp:155
void print()
Definition JitFrontend.hpp:219
void compile(bool Verify=false)
Definition JitFrontend.hpp:174
KernelHandle< ArgT... > getKernelHandle(StringRef Name)
Definition JitFrontend.hpp:191
Definition Dispatcher.cpp:14
Definition Hashing.hpp:94
Definition Dispatcher.hpp:14
Definition TypeMap.hpp:13