Proteus
Programmable JIT compilation and optimization for C/C++ using LLVM
Loading...
Searching...
No Matches
CppJitModule.h
Go to the documentation of this file.
1#ifndef PROTEUS_CPPFRONTEND_H
2#define PROTEUS_CPPFRONTEND_H
3
5#include "proteus/Init.h"
6
7#include <algorithm>
8#include <sstream>
9#include <vector>
10
11namespace proteus {
12
13struct CompiledLibrary;
14class HashT;
15
17private:
// Target model of this module. getFunction() and getKernel() check it before
// handing out a handle, and instantiate() passes it to each CodeInstance.
18 TargetModelType TargetModel;
// Source text of the module. instantiate() passes it (by reference) to
// CodeInstance as the template code.
19 std::string Code;
// NOTE(review): presumably a hash identifying this module's code; its use is
// not visible in this header.
20 std::unique_ptr<HashT> ModuleHash;
// Extra arguments; instantiate() copies them into each CodeInstance.
// NOTE(review): presumably extra compiler flags taken from the constructor.
21 std::vector<std::string> ExtraArgs;
22
23 // Optimization level used when emitting IR.
24 static constexpr const char *FrontendOptLevelFlag = "-O3";
25
// Dispatcher through which FunctionHandle::run() and CodeInstance::run()
// invoke compiled functions.
26 Dispatcher &Dispatch;
// Compiled library; getLibrary() treats a null value after compilation as a
// fatal error.
27 std::unique_ptr<CompiledLibrary> Library;
// Starts out false; getLibrary(), getFunction() and getKernel() call
// compile() while it is still false.
28 bool IsCompiled = false;
29
30 // TODO: We don't cache CodeInstances so if a user re-creates the exact same
31 // instantiation it will create a new CodeInstance. This creation cost is
32 // mitigated because the dispatcher caches the compiled object so we will pay
33 // the overhead for building the instantiation code but not for compilation.
34 // Nevertheless, we return the CodeInstance object when created, so the user
35 // can avoid any re-creation overhead by using the returned object to run or
36 // launch. Re-think caching and dispatchers, and consider more restrictive
37 // interfaces.
38 struct CodeInstance {
39 TargetModelType TargetModel;
40 const std::string &TemplateCode;
41 std::vector<std::string> ExtraArgs;
42 std::string InstanceName;
43 std::unique_ptr<CppJitModule> InstanceModule;
44 std::string EntryFuncName;
45 void *FuncPtr = nullptr;
46
47 CodeInstance(TargetModelType TargetModel, const std::string &TemplateCode,
48 const std::vector<std::string> &ExtraArgs,
49 const std::string &InstanceName)
50 : TargetModel(TargetModel), TemplateCode(TemplateCode),
51 ExtraArgs(ExtraArgs), InstanceName(InstanceName) {
52 EntryFuncName = "__jit_instance_" + this->InstanceName;
53 // Replace characters '<', '>', ',' with $ to create a unique for the
54 // entry function.
55 std::replace_if(
56 EntryFuncName.begin(), EntryFuncName.end(),
57 [](char C) { return C == '<' || C == '>' || C == ','; }, '$');
58 }
59
60 // Compile-time type name (no RTTI).
61 template <class T> constexpr std::string_view typeName() {
62 // Apparently we are more interested in clang, but leaving the others for
63 // completeness.
64#if defined(__clang__)
65 // "std::string_view type_name() [T = int]"
66 std::string_view P = __PRETTY_FUNCTION__;
67 auto B = P.find("[T = ") + 5;
68 auto E = P.rfind(']');
69 return P.substr(B, E - B);
70#elif defined(__GNUC__)
71 // "... with T = int; ..."
72 std::string_view P = __PRETTY_FUNCTION__;
73 auto B = P.find("with T = ") + 9;
74 auto E = P.find(';', B);
75 return P.substr(B, E - B);
76#elif defined(_MSC_VER)
77 // "std::string_view __cdecl type_name<int>(void)"
78 std::string_view P = __FUNCSIG__;
79 auto B = P.find("type_name<") + 10;
80 auto E = P.find(">(void)", B);
81 return P.substr(B, B - E);
82#else
83 reportFatalError("Unsupported compiler");
84#endif
85 }
86
87 template <typename RetT, typename... ArgT, std::size_t... I>
88 std::string buildFunctionEntry(std::index_sequence<I...>) {
89 std::stringstream OS;
90
91 OS << "extern \"C\" " << typeName<RetT>() << " "
92 << ((!isHostTargetModel(TargetModel)) ? "__global__ " : "")
93 << EntryFuncName << "(";
94 ((OS << (I ? ", " : "")
95 << typeName<
96 std::decay_t<std::tuple_element_t<I, std::tuple<ArgT...>>>>()
97 << " Arg" << I),
98 ...);
99 OS << ')';
100
101 std::string ArgList;
102 ((ArgList += (I == 0 ? "" : ", ") + ("Arg" + std::to_string(I))), ...);
103 OS << "{ ";
104 if constexpr (!std::is_void_v<RetT>) {
105 OS << "return ";
106 }
107 OS << InstanceName << "(";
108 ((OS << (I == 0 ? "" : ", ") << "Arg" << std::to_string(I)), ...);
109 OS << "); }";
110
111 return OS.str();
112 }
113
114 template <typename RetOrSig, typename... ArgT> std::string buildCode() {
115 std::string FunctionCode = buildFunctionEntry<RetOrSig, ArgT...>(
116 std::index_sequence_for<ArgT...>{});
117
118 auto ReplaceAll = [](std::string &S, std::string_view From,
119 std::string_view To) {
120 if (From.empty())
121 return;
122 std::size_t Pos = 0;
123 while ((Pos = S.find(From, Pos)) != std::string::npos) {
124 S.replace(Pos, From.size(), To);
125 // Skip over the just-inserted text.
126 Pos += To.size();
127 }
128 };
129
130 std::string InstanceCode = TemplateCode;
131 // Demote kernels to device function to call the templated instance from
132 // the entry function.
133 ReplaceAll(InstanceCode, "__global__", "__device__");
134 InstanceCode = InstanceCode + FunctionCode;
135
136 return InstanceCode;
137 }
138
139 template <typename RetT, typename... ArgT> void compile() {
140 std::string InstanceCode = buildCode<RetT, ArgT...>();
141 InstanceModule =
142 std::make_unique<CppJitModule>(TargetModel, InstanceCode, ExtraArgs);
143 InstanceModule->compile();
144
145 FuncPtr = InstanceModule->getFunctionAddress(EntryFuncName);
146 }
147
148 template <typename... ArgT>
149 auto launch(LaunchDims GridDim, LaunchDims BlockDim, uint64_t ShmemSize,
150 void *Stream, ArgT... Args) {
151 if (!InstanceModule) {
152 compile<void, ArgT...>();
153 }
154
155 void *Ptrs[sizeof...(ArgT)] = {(void *)&Args...};
156
157 return InstanceModule->launch(FuncPtr, GridDim, BlockDim, Ptrs, ShmemSize,
158 Stream);
159 }
160
161 template <typename RetOrSig, typename... ArgT>
162 RetOrSig run(ArgT &&...Args) {
163 static_assert(!std::is_function_v<RetOrSig>,
164 "Function signature type is not yet supported");
165
166 if (!InstanceModule) {
167 compile<RetOrSig, ArgT...>();
168 }
169
170 if constexpr (std::is_void_v<RetOrSig>)
171 InstanceModule->Dispatch.run<RetOrSig(ArgT...)>(FuncPtr, Args...);
172 else
173 return InstanceModule->Dispatch.run<RetOrSig(ArgT...)>(FuncPtr,
174 Args...);
175 }
176 };
177
// Pairs an LLVM context with a module; this is the return type of
// compileCppToIR().
178 struct CompilationResult {
179 // Declare Ctx first to ensure it is destroyed after Mod.
// (Non-static members are destroyed in reverse order of declaration, so do
// not reorder these two.)
180 std::unique_ptr<llvm::LLVMContext> Ctx;
181 std::unique_ptr<llvm::Module> Mod;
182
// Only declared here; the definition is not part of this header.
183 ~CompilationResult();
184 };
185
// Returns the untyped pointer for the function called Name. getFunction(),
// getKernel() and CodeInstance::compile() store the result as the function
// pointer they later run or launch.
186 void *getFunctionAddress(const std::string &Name);
// Launches KernelFunc with the given grid/block dimensions, array of argument
// pointers and stream. KernelHandle::launch() and CodeInstance::launch() pass
// one pointer per kernel argument in KernelArgs.
// NOTE(review): ShmemSize is presumably the shared memory size in bytes.
187 void launch(void *KernelFunc, LaunchDims GridDim, LaunchDims BlockDim,
188 void *KernelArgs[], uint64_t ShmemSize, void *Stream);
189
190protected:
// Returns a CompilationResult (LLVM context + module).
// NOTE(review): presumably compiles Code to LLVM IR; defined out of line.
191 CompilationResult compileCppToIR();
193
194public:
// Creates a module for the given target model and source code. ExtraArgs is
// optional and defaults to empty.
195 explicit CppJitModule(TargetModelType TargetModel, const std::string &Code,
196 const std::vector<std::string> &ExtraArgs = {});
// Same as above, with the target given as a string instead of a
// TargetModelType.
197 explicit CppJitModule(const std::string &Target, const std::string &Code,
198 const std::vector<std::string> &ExtraArgs = {});
200
// Compiles the module. getLibrary(), getFunction() and getKernel() call this
// on demand while IsCompiled is false.
201 void compile();
202
204 if (!IsCompiled)
205 compile();
206
207 if (!Library)
208 reportFatalError("Expected non-null library after compilation");
209
210 return *Library;
211 }
212
213 template <typename... ArgT>
214 auto instantiate(const std::string &FuncName, ArgT... Args) {
215 std::string InstanceName = FuncName + "<";
216 bool First = true;
217 ((InstanceName +=
218 (First ? "" : ",") + std::string(std::forward<ArgT>(Args)),
219 First = false),
220 ...);
221
222 InstanceName += ">";
223
224 return CodeInstance{TargetModel, Code, ExtraArgs, InstanceName};
225 }
226
227 template <typename Sig> struct FunctionHandle;
228 template <typename RetT, typename... ArgT>
229 struct FunctionHandle<RetT(ArgT...)> {
231 void *FuncPtr;
232 explicit FunctionHandle(CppJitModule &M, void *FuncPtr)
233 : M(M), FuncPtr(FuncPtr) {}
234
235 RetT run(ArgT... Args) {
236 if constexpr (std::is_void_v<RetT>) {
237 M.Dispatch.template run<RetT(ArgT...)>(FuncPtr,
238 std::forward<ArgT>(Args)...);
239 } else {
240 return M.Dispatch.template run<RetT(ArgT...)>(
241 FuncPtr, std::forward<ArgT>(Args)...);
242 }
243 }
244 };
245
246 template <typename Sig> struct KernelHandle;
247 template <typename RetT, typename... ArgT>
248 struct KernelHandle<RetT(ArgT...)> {
250 void *FuncPtr = nullptr;
251 explicit KernelHandle(CppJitModule &M, void *FuncPtr)
252 : M(M), FuncPtr(FuncPtr) {
253 static_assert(std::is_void_v<RetT>, "Kernel function must return void");
254 }
255
256 auto launch(LaunchDims GridDim, LaunchDims BlockDim, uint64_t ShmemSize,
257 void *Stream, ArgT... Args) {
258 void *Ptrs[sizeof...(ArgT)] = {(void *)&Args...};
259
260 return M.launch(FuncPtr, GridDim, BlockDim, Ptrs, ShmemSize, Stream);
261 }
262 };
263 template <typename Sig>
264 FunctionHandle<Sig> getFunction(const std::string &Name) {
265 if (!IsCompiled)
266 compile();
267
268 if (!isHostTargetModel(TargetModel))
269 reportFatalError("Error: getFunction() applies only to host modules");
270
271 void *FuncPtr = getFunctionAddress(Name);
272
273 return FunctionHandle<Sig>(*this, FuncPtr);
274 }
275
276 template <typename Sig> KernelHandle<Sig> getKernel(const std::string &Name) {
277 if (!IsCompiled)
278 compile();
279
280 if (TargetModel == TargetModelType::HOST)
281 reportFatalError("Error: getKernel() applies only to device modules");
282
283 void *FuncPtr = getFunctionAddress(Name);
284
285 return KernelHandle<Sig>(*this, FuncPtr);
286 }
287};
288
289} // namespace proteus
290
291#endif
char int void ** Args
Definition CompilerInterfaceHost.cpp:22
Definition CppJitModule.h:16
void compile()
Definition CppJitModule.cpp:241
void compileCppToDynamicLibrary()
Definition CppJitModule.cpp:49
auto instantiate(const std::string &FuncName, ArgT... Args)
Definition CppJitModule.h:214
FunctionHandle< Sig > getFunction(const std::string &Name)
Definition CppJitModule.h:264
CompiledLibrary & getLibrary()
Definition CppJitModule.h:203
KernelHandle< Sig > getKernel(const std::string &Name)
Definition CppJitModule.h:276
CompilationResult compileCppToIR()
Definition CppJitModule.cpp:136
Definition Dispatcher.h:74
Definition MemoryCache.h:26
TargetModelType
Definition TargetModel.h:8
static int Pos
Definition JitInterface.h:102
bool isHostTargetModel(TargetModelType TargetModel)
Definition TargetModel.cpp:49
void reportFatalError(const llvm::Twine &Reason, const char *FILE, unsigned Line)
Definition Error.cpp:14
Definition Dispatcher.h:21
Definition CompiledLibrary.h:18
FunctionHandle(CppJitModule &M, void *FuncPtr)
Definition CppJitModule.h:232
void * FuncPtr
Definition CppJitModule.h:231
CppJitModule & M
Definition CppJitModule.h:230
RetT run(ArgT... Args)
Definition CppJitModule.h:235
Definition CppJitModule.h:227
auto launch(LaunchDims GridDim, LaunchDims BlockDim, uint64_t ShmemSize, void *Stream, ArgT... Args)
Definition CppJitModule.h:256
KernelHandle(CppJitModule &M, void *FuncPtr)
Definition CppJitModule.h:251
CppJitModule & M
Definition CppJitModule.h:249
Definition CppJitModule.h:246