Proteus
Programmable JIT compilation and optimization for C/C++ using LLVM
Loading...
Searching...
No Matches
CppJitModule.hpp
Go to the documentation of this file.
1#ifndef PROTEUS_CPPFRONTEND_HPP
2#define PROTEUS_CPPFRONTEND_HPP
3
5
6#include <algorithm>
7#include <sstream>
8#include <vector>
9
10namespace proteus {
11
12struct CompiledLibrary;
13class HashT;
14
16private:
17 TargetModelType TargetModel;
18 std::string Code;
19 std::unique_ptr<HashT> ModuleHash;
20 std::vector<std::string> ExtraArgs;
21
22 // Optimization level used when emitting IR.
23 static constexpr const char *FrontendOptLevelFlag = "-O3";
24
25 Dispatcher &Dispatch;
26 std::unique_ptr<CompiledLibrary> Library;
27 bool IsCompiled = false;
28
29 // TODO: We don't cache CodeInstances, so if a user re-creates the exact same
30 // instantiation it will create a new CodeInstance. This creation cost is
31 // mitigated because the dispatcher caches the compiled object, so we pay
32 // the overhead of building the instantiation code but not of compiling it.
33 // Nevertheless, we return the CodeInstance object when created, so the user
34 // can avoid any re-creation overhead by reusing the returned object to run or
35 // launch. Rethink caching and dispatchers, and consider more restrictive
36 // interfaces.
37 struct CodeInstance {
38 TargetModelType TargetModel;
39 const std::string &TemplateCode;
40 std::string InstanceName;
41 std::unique_ptr<CppJitModule> InstanceModule;
42 std::string EntryFuncName;
43 void *FuncPtr = nullptr;
44
45 CodeInstance(TargetModelType TargetModel, const std::string &TemplateCode,
46 const std::string &InstanceName)
47 : TargetModel(TargetModel), TemplateCode(TemplateCode),
48 InstanceName(InstanceName) {
49 EntryFuncName = "__jit_instance_" + this->InstanceName;
50 // Replace characters '<', '>', ',' with $ to create a unique for the
51 // entry function.
52 std::replace_if(
53 EntryFuncName.begin(), EntryFuncName.end(),
54 [](char C) { return C == '<' || C == '>' || C == ','; }, '$');
55 }
56
57 // Compile-time type name (no RTTI).
58 template <class T> constexpr std::string_view typeName() {
59 // Apparently we are more interested in clang, but leaving the others for
60 // completeness.
61#if defined(__clang__)
62 // "std::string_view type_name() [T = int]"
63 std::string_view P = __PRETTY_FUNCTION__;
64 auto B = P.find("[T = ") + 5;
65 auto E = P.rfind(']');
66 return P.substr(B, E - B);
67#elif defined(__GNUC__)
68 // "... with T = int; ..."
69 std::string_view P = __PRETTY_FUNCTION__;
70 auto B = P.find("with T = ") + 9;
71 auto E = P.find(';', B);
72 return P.substr(B, E - B);
73#elif defined(_MSC_VER)
74 // "std::string_view __cdecl type_name<int>(void)"
75 std::string_view P = __FUNCSIG__;
76 auto B = P.find("type_name<") + 10;
77 auto E = P.find(">(void)", B);
78 return P.substr(B, B - E);
79#else
80 reportFatalError("Unsupported compiler");
81#endif
82 }
83
84 template <typename RetT, typename... ArgT, std::size_t... I>
85 std::string buildFunctionEntry(std::index_sequence<I...>) {
86 std::stringstream OS;
87
88 OS << "extern \"C\" " << typeName<RetT>() << " "
89 << ((!isHostTargetModel(TargetModel)) ? "__global__ " : "")
90 << EntryFuncName << "(";
91 ((OS << (I ? ", " : "")
92 << typeName<
93 std::decay_t<std::tuple_element_t<I, std::tuple<ArgT...>>>>()
94 << " Arg" << I),
95 ...);
96 OS << ')';
97
98 std::string ArgList;
99 ((ArgList += (I == 0 ? "" : ", ") + ("Arg" + std::to_string(I))), ...);
100 OS << "{ ";
101 if constexpr (!std::is_void_v<RetT>) {
102 OS << "return ";
103 }
104 OS << InstanceName << "(";
105 ((OS << (I == 0 ? "" : ", ") << "Arg" << std::to_string(I)), ...);
106 OS << "); }";
107
108 return OS.str();
109 }
110
// Builds the source for this instance's module.
//
// Starts from a copy of TemplateCode and textually rewrites every
// "__global__" to "__device__", so a templated kernel becomes a device
// function that the generated entry function can call. FunctionCode holds
// that entry wrapper, produced by buildFunctionEntry().
// NOTE(review): no visible line appends FunctionCode to InstanceCode, yet
// compile() looks up EntryFuncName in the module built from the returned
// string. This listing skips source line 131, which presumably performs the
// append -- confirm against the actual header.
111 template <typename RetOrSig, typename... ArgT> std::string buildCode() {
112 std::string FunctionCode = buildFunctionEntry<RetOrSig, ArgT...>(
113 std::index_sequence_for<ArgT...>{});
114
// Replaces every non-overlapping occurrence of From in S, scanning left to
// right; an empty From is a no-op. This is a plain text substitution, so it
// also rewrites matches inside comments, string literals, or longer
// identifiers that contain From.
115 auto ReplaceAll = [](std::string &S, std::string_view From,
116 std::string_view To) {
117 if (From.empty())
118 return;
119 std::size_t Pos = 0;
120 while ((Pos = S.find(From, Pos)) != std::string::npos) {
121 S.replace(Pos, From.size(), To);
122 // Skip over the just-inserted text.
123 Pos += To.size();
124 }
125 };
126
127 std::string InstanceCode = TemplateCode;
128 // Demote kernels to device function to call the templated instance from
129 // the entry function.
130 ReplaceAll(InstanceCode, "__global__", "__device__");
132
133 return InstanceCode;
134 }
135
136 template <typename RetT, typename... ArgT> void compile() {
137 std::string InstanceCode = buildCode<RetT, ArgT...>();
138 InstanceModule =
139 std::make_unique<CppJitModule>(TargetModel, InstanceCode);
140 InstanceModule->compile();
141
142 FuncPtr = InstanceModule->getFunctionAddress(EntryFuncName);
143 }
144
145 template <typename... ArgT>
146 auto launch(LaunchDims GridDim, LaunchDims BlockDim, uint64_t ShmemSize,
147 void *Stream, ArgT... Args) {
148 if (!InstanceModule) {
149 compile<void, ArgT...>();
150 }
151
152 void *Ptrs[sizeof...(ArgT)] = {(void *)&Args...};
153
154 return InstanceModule->launch(FuncPtr, GridDim, BlockDim, Ptrs, ShmemSize,
155 Stream);
156 }
157
158 template <typename RetOrSig, typename... ArgT>
159 RetOrSig run(ArgT &&...Args) {
160 static_assert(!std::is_function_v<RetOrSig>,
161 "Function signature type is not yet supported");
162
163 if (!InstanceModule) {
164 compile<RetOrSig, ArgT...>();
165 }
166
167 if constexpr (std::is_void_v<RetOrSig>)
168 InstanceModule->Dispatch.run<RetOrSig(ArgT...)>(FuncPtr, Args...);
169 else
170 return InstanceModule->Dispatch.run<RetOrSig(ArgT...)>(FuncPtr,
171 Args...);
172 }
173 };
174
// Result of compileCppToIR(): owns an LLVM context and a module.
175 struct CompilationResult {
176 // Declare Ctx first to ensure it is destroyed after Mod.
// (Non-static members are destroyed in reverse declaration order; Mod is
// presumably created in Ctx and must not outlive it -- confirm in the .cpp.)
177 std::unique_ptr<LLVMContext> Ctx;
178 std::unique_ptr<Module> Mod;
179
// Declared here but defined out of line -- presumably so LLVMContext and
// Module may remain incomplete types in this header; confirm.
// NOTE(review): this user-declared destructor suppresses the implicit move
// constructor and move assignment, and the unique_ptr members make the
// implicit copy operations deleted, so the type is neither copyable nor
// movable; it can only be returned as a prvalue (C++17 guaranteed elision).
180 ~CompilationResult();
181 };
182
183 void *getFunctionAddress(const std::string &Name);
184 void launch(void *KernelFunc, LaunchDims GridDim, LaunchDims BlockDim,
185 void *KernelArgs[], uint64_t ShmemSize, void *Stream);
186
187protected:
188 CompilationResult compileCppToIR();
190
191public:
192 explicit CppJitModule(TargetModelType TargetModel, const std::string &Code,
193 const std::vector<std::string> &ExtraArgs = {});
194 explicit CppJitModule(const std::string &Target, const std::string &Code,
195 const std::vector<std::string> &ExtraArgs = {});
197
198 void compile();
199
201 if (!IsCompiled)
202 compile();
203
204 if (!Library)
205 reportFatalError("Expected non-null library after compilation");
206
207 return *Library;
208 }
209
210 template <typename... ArgT>
211 auto instantiate(const std::string &FuncName, ArgT... Args) {
212 std::string InstanceName = FuncName + "<";
213 bool First = true;
214 ((InstanceName +=
215 (First ? "" : ",") + std::string(std::forward<ArgT>(Args)),
216 First = false),
217 ...);
218
219 InstanceName += ">";
220
221 return CodeInstance{TargetModel, Code, InstanceName};
222 }
223
224 template <typename Sig> struct FunctionHandle;
225 template <typename RetT, typename... ArgT>
226 struct FunctionHandle<RetT(ArgT...)> {
228 void *FuncPtr;
229 explicit FunctionHandle(CppJitModule &M, void *FuncPtr)
230 : M(M), FuncPtr(FuncPtr) {}
231
232 RetT run(ArgT... Args) {
233 if constexpr (std::is_void_v<RetT>) {
234 M.Dispatch.template run<RetT(ArgT...)>(FuncPtr,
235 std::forward<ArgT>(Args)...);
236 } else {
237 return M.Dispatch.template run<RetT(ArgT...)>(
238 FuncPtr, std::forward<ArgT>(Args)...);
239 }
240 }
241 };
242
243 template <typename Sig> struct KernelHandle;
244 template <typename RetT, typename... ArgT>
245 struct KernelHandle<RetT(ArgT...)> {
247 void *FuncPtr = nullptr;
248 explicit KernelHandle(CppJitModule &M, void *FuncPtr)
249 : M(M), FuncPtr(FuncPtr) {
250 static_assert(std::is_void_v<RetT>, "Kernel function must return void");
251 }
252
254 void *Stream, ArgT... Args) {
255 void *Ptrs[sizeof...(ArgT)] = {(void *)&Args...};
256
257 return M.launch(FuncPtr, GridDim, BlockDim, Ptrs, ShmemSize, Stream);
258 }
259 };
260 template <typename Sig>
261 FunctionHandle<Sig> getFunction(const std::string &Name) {
262 if (!IsCompiled)
263 compile();
264
265 if (!isHostTargetModel(TargetModel))
266 reportFatalError("Error: getFunction() applies only to host modules");
267
268 void *FuncPtr = getFunctionAddress(Name);
269
270 return FunctionHandle<Sig>(*this, FuncPtr);
271 }
272
273 template <typename Sig> KernelHandle<Sig> getKernel(const std::string &Name) {
274 if (!IsCompiled)
275 compile();
276
277 if (TargetModel == TargetModelType::HOST)
278 reportFatalError("Error: getKernel() applies only to device modules");
279
280 void *FuncPtr = getFunctionAddress(Name);
281
282 return KernelHandle<Sig>(*this, FuncPtr);
283 }
284};
285
286} // namespace proteus
287
288#endif
char int void ** Args
Definition CompilerInterfaceHost.cpp:20
Definition CppJitModule.hpp:15
void compile()
Definition CppJitModule.cpp:240
void compileCppToDynamicLibrary()
Definition CppJitModule.cpp:48
auto instantiate(const std::string &FuncName, ArgT... Args)
Definition CppJitModule.hpp:211
FunctionHandle< Sig > getFunction(const std::string &Name)
Definition CppJitModule.hpp:261
CompiledLibrary & getLibrary()
Definition CppJitModule.hpp:200
KernelHandle< Sig > getKernel(const std::string &Name)
Definition CppJitModule.hpp:273
CompilationResult compileCppToIR()
Definition CppJitModule.cpp:135
Definition Dispatcher.hpp:60
Definition ObjectCacheChain.cpp:26
TargetModelType
Definition TargetModel.hpp:12
static int Pos
Definition JitInterface.hpp:105
bool isHostTargetModel(TargetModelType TargetModel)
Definition TargetModel.hpp:53
void reportFatalError(const llvm::Twine &Reason, const char *FILE, unsigned Line)
Definition Error.cpp:14
T getRuntimeConstantValue(void *Arg)
Definition CompilerInterfaceRuntimeConstantInfo.h:113
Definition Dispatcher.hpp:19
Definition CompiledLibrary.hpp:18
FunctionHandle(CppJitModule &M, void *FuncPtr)
Definition CppJitModule.hpp:229
void * FuncPtr
Definition CppJitModule.hpp:228
CppJitModule & M
Definition CppJitModule.hpp:227
RetT run(ArgT... Args)
Definition CppJitModule.hpp:232
Definition CppJitModule.hpp:224
auto launch(LaunchDims GridDim, LaunchDims BlockDim, uint64_t ShmemSize, void *Stream, ArgT... Args)
Definition CppJitModule.hpp:253
KernelHandle(CppJitModule &M, void *FuncPtr)
Definition CppJitModule.hpp:248
CppJitModule & M
Definition CppJitModule.hpp:246
Definition CppJitModule.hpp:243