Proteus
Programmable JIT compilation and optimization for C/C++ using LLVM
Loading...
Searching...
No Matches
CppJitModule.h
Go to the documentation of this file.
1#ifndef PROTEUS_CPPFRONTEND_H
2#define PROTEUS_CPPFRONTEND_H
3
7
8#include <algorithm>
9#include <optional>
10#include <sstream>
11#include <unordered_map>
12#include <vector>
13
14namespace proteus {
15
16struct CompiledLibrary;
17class HashT;
18
20private:
// Target model of this module; checked by getFunction()/getKernel()/
// getKernelAddress() and handed to every CodeInstance.
21 TargetModelType TargetModel;
// C++ source of the module. CodeInstances keep a reference to it
// (TemplateCode), so it must outlive them.
22 std::string Code;
// NOTE(review): not used in this header; presumably a hash of the module
// contents used for caching — confirm in CppJitModule.cpp.
23 std::unique_ptr<HashT> ModuleHash;
// Extra arguments given at construction; forwarded unchanged to instances.
24 std::vector<std::string> ExtraArgs;
// Compiler backend given at construction; forwarded unchanged to instances.
25 CppJitCompilerBackend CompilerBackend;
26
// Dispatcher through which compiled functions are invoked (Dispatch.run<>).
27 Dispatcher &Dispatch;
// Compiled library; getLibrary() reports a fatal error if it is still null
// after compile().
28 std::unique_ptr<CompiledLibrary> Library;
// Lazy-compilation flag; accessors call compile() while it is false
// (presumably compile() sets it — not visible in this header).
29 bool IsCompiled = false;
30
// Function attributes recorded on a CodeInstance before it is compiled.
// buildGPUHostLauncher() reads MaxDynamicSharedMemorySize from here.
31 struct FuncAttributeState {
// Requested max dynamic shared memory size; when unset, the generated host
// launcher does not emit the attribute call at all.
32 std::optional<int> MaxDynamicSharedMemorySize;
33
// Records Attr = Value. Negative values, and any attribute the switch does
// not return from, are reported as fatal errors.
34 void set(JitFuncAttribute Attr, int Value) {
35 if (Value < 0)
36 reportFatalError("Function attribute value must be non-negative");
37
// NOTE(review): the case label(s) of this switch are missing from this
// listing; presumably each supported attribute stores Value in its field
// before the return below. Confirm against the original source.
38 switch (Attr) {
41 return;
42 }
43
44 reportFatalError("Unsupported function attribute");
45 }
46 };
47
// TODO: CodeInstances are cached by instance name in InstantiationCache, so
// instantiating the same template arguments again returns the existing
// CodeInstance instead of creating a new one. The dispatcher additionally
// caches compiled objects, so an instance whose module must be rebuilt pays
// the overhead of building the instantiation code but not of compilation.
// instantiate() returns the CodeInstance by reference, so callers can keep it
// and use it directly to run or launch. Re-think caching and dispatchers, and
// consider more restrictive interfaces.
56 struct CodeInstance {
// Target model of the parent module; selects the GPU-specific code paths.
57 TargetModelType TargetModel;
// Reference to the parent module's source (not a copy); the referenced string
// must outlive this instance.
58 const std::string &TemplateCode;
// Extra arguments forwarded to the per-instance CppJitModule.
59 std::vector<std::string> ExtraArgs;
// Compiler backend forwarded to the per-instance CppJitModule.
60 CppJitCompilerBackend CompilerBackend;
// Instantiation expression, e.g. "foo<int,4>" as built by instantiate().
61 std::string InstanceName;
// Module with the generated instance code; created lazily by compile() and
// dropped by setFuncAttribute() to force recompilation.
62 std::unique_ptr<CppJitModule> InstanceModule;
// Symbol of the generated extern "C" entry function (set in the constructor).
63 std::string EntryFuncName;
// Address of the entry function in InstanceModule; null until compiled.
64 void *FuncPtr = nullptr;
// Attributes applied by the generated GPU host launcher.
65 FuncAttributeState FuncAttributes;
66
67 CodeInstance(TargetModelType TargetModel, const std::string &TemplateCode,
68 const std::vector<std::string> &ExtraArgs,
69 CppJitCompilerBackend CompilerBackend,
70 const std::string &InstanceName,
71 FuncAttributeState FuncAttributes = {})
72 : TargetModel(TargetModel), TemplateCode(TemplateCode),
73 ExtraArgs(ExtraArgs), CompilerBackend(CompilerBackend),
74 InstanceName(InstanceName),
75 FuncAttributes(std::move(FuncAttributes)) {
76 EntryFuncName = "__proteus_instance_" + this->InstanceName;
77 // Replace characters '<', '>', ',' with $ to create a unique for the
78 // entry function.
79 std::replace_if(
80 EntryFuncName.begin(), EntryFuncName.end(),
81 [](char C) { return C == '<' || C == '>' || C == ','; }, '$');
82 }
83
84 bool useHostLaunchEntry() const {
85 // Instantiated GPU kernels compile through the host+device path.
86 return TargetModel == TargetModelType::CUDA ||
87 TargetModel == TargetModelType::HIP;
88 }
89
// Target model used for the per-instance CppJitModule created in compile().
// NOTE(review): the bodies of the CUDA and HIP branches are missing from this
// listing; presumably each returns a host-side variant, since instantiated GPU
// kernels compile through the host+device path (see useHostLaunchEntry()).
// Confirm against the original source. Other models are returned unchanged.
90 TargetModelType getInstanceTargetModel() const {
91 if (TargetModel == TargetModelType::CUDA) {
93 }
94
95 if (TargetModel == TargetModelType::HIP) {
97 }
98
99 return TargetModel;
100 }
101
102 CppJitCompilerBackend getInstanceCompilerBackend() const {
103 return CompilerBackend;
104 }
105
106 const char *getGPUStreamType() const {
107 if (TargetModel == TargetModelType::CUDA) {
108 return "cudaStream_t";
109 }
110
111 if (TargetModel == TargetModelType::HIP) {
112 return "hipStream_t";
113 }
114
115 reportFatalError("Expected CUDA or HIP target model for host launcher");
116 }
117
118 const char *getGPUGetLastErrorName() const {
119 if (TargetModel == TargetModelType::CUDA) {
120 return "cudaGetLastError";
121 }
122
123 if (TargetModel == TargetModelType::HIP) {
124 return "hipGetLastError";
125 }
126
127 reportFatalError("Expected CUDA or HIP target model for host launcher");
128 }
129
130 const char *getGPUFuncSetAttributeName() const {
131 if (TargetModel == TargetModelType::CUDA) {
132 return "cudaFuncSetAttribute";
133 }
134
135 if (TargetModel == TargetModelType::HIP) {
136 return "hipFuncSetAttribute";
137 }
138
139 reportFatalError("Expected CUDA or HIP target model for host launcher");
140 }
141
142 const char *getGPUMaxDynamicSharedMemoryAttrName() const {
143 if (TargetModel == TargetModelType::CUDA) {
144 return "cudaFuncAttributeMaxDynamicSharedMemorySize";
145 }
146
147 if (TargetModel == TargetModelType::HIP) {
148 return "hipFuncAttributeMaxDynamicSharedMemorySize";
149 }
150
151 reportFatalError("Expected CUDA or HIP target model for host launcher");
152 }
153
/// Compile-time spelling of type T (no RTTI), taken from the compiler's
/// pretty-function string. The result is spliced into generated C++ source.
template <class T> constexpr std::string_view typeName() {
#if defined(__clang__)
  // e.g. "std::string_view typeName() [T = int]"
  std::string_view P = __PRETTY_FUNCTION__;
  constexpr std::string_view Prefix = "[T = ";
  auto B = P.find(Prefix) + Prefix.size();
  auto E = P.rfind(']');
  return P.substr(B, E - B);
#elif defined(__GNUC__)
  // e.g. "constexpr std::string_view typeName() [with T = int; ...]"
  std::string_view P = __PRETTY_FUNCTION__;
  constexpr std::string_view Prefix = "with T = ";
  auto B = P.find(Prefix) + Prefix.size();
  auto E = P.find(';', B);
  return P.substr(B, E - B);
#elif defined(_MSC_VER)
  // e.g. "... __cdecl ...::typeName<int>(void)". Search for this function's
  // real name and take the length as E - B (end minus begin).
  std::string_view P = __FUNCSIG__;
  constexpr std::string_view Prefix = "typeName<";
  auto B = P.find(Prefix) + Prefix.size();
  auto E = P.find(">(void)", B);
  return P.substr(B, E - B);
#else
  reportFatalError("Unsupported compiler");
#endif
}
180
181 template <typename RetT, typename... ArgT, std::size_t... I>
182 std::string buildFunctionEntry(std::index_sequence<I...>) {
183 std::stringstream OS;
184
185 OS << "extern \"C\" " << typeName<RetT>() << " "
186 << ((!isHostTargetModel(TargetModel)) ? "__global__ " : "")
187 << EntryFuncName << "(";
188 ((OS << (I ? ", " : "")
189 << typeName<
190 std::decay_t<std::tuple_element_t<I, std::tuple<ArgT...>>>>()
191 << " Arg" << I),
192 ...);
193 OS << ')';
194
195 std::string ArgList;
196 ((ArgList += (I == 0 ? "" : ", ") + ("Arg" + std::to_string(I))), ...);
197 OS << "{ ";
198 if constexpr (!std::is_void_v<RetT>) {
199 OS << "return ";
200 }
201 OS << InstanceName << "(";
202 ((OS << (I == 0 ? "" : ", ") << "Arg" << std::to_string(I)), ...);
203 OS << "); }";
204
205 return OS.str();
206 }
207
208 template <typename... ArgT, std::size_t... I>
209 std::string buildGPUHostLauncher(std::index_sequence<I...>) {
210 std::stringstream OS;
211
212 OS << "extern \"C\" int " << EntryFuncName
213 << "(unsigned GridX, unsigned GridY, unsigned GridZ, unsigned "
214 "BlockX, unsigned BlockY, unsigned BlockZ, size_t ShmemSize, "
215 "void *Stream";
216 ((OS << ", "
217 << typeName<
218 std::decay_t<std::tuple_element_t<I, std::tuple<ArgT...>>>>()
219 << " Arg" << I),
220 ...);
221 OS << ") { ";
222 if (FuncAttributes.MaxDynamicSharedMemorySize) {
223 OS << "{ auto AttrErr = " << getGPUFuncSetAttributeName()
224 << "((const void *)" << InstanceName << ", "
225 << getGPUMaxDynamicSharedMemoryAttrName() << ", "
226 << *FuncAttributes.MaxDynamicSharedMemorySize
227 << "); if (AttrErr != 0) return static_cast<int>(AttrErr); } ";
228 }
229 // Use a typed launch so implicit C++ conversions still happen at the
230 // call site.
231 OS << InstanceName
232 << "<<<dim3(GridX, GridY, GridZ), dim3(BlockX, BlockY, BlockZ), "
233 "ShmemSize, static_cast<"
234 << getGPUStreamType() << ">(Stream)>>>(";
235 ((OS << (I == 0 ? "" : ", ") << "Arg" << I), ...);
236 OS << "); return static_cast<int>(" << getGPUGetLastErrorName()
237 << "()); }";
238
239 return OS.str();
240 }
241
242 template <typename RetOrSig, typename... ArgT> std::string buildCode() {
243 if constexpr (std::is_void_v<RetOrSig>) {
244 if (useHostLaunchEntry()) {
245 return TemplateCode + buildGPUHostLauncher<ArgT...>(
246 std::index_sequence_for<ArgT...>{});
247 }
248 }
249
250 std::string FunctionCode = buildFunctionEntry<RetOrSig, ArgT...>(
251 std::index_sequence_for<ArgT...>{});
252
253 auto ReplaceAll = [](std::string &S, std::string_view From,
254 std::string_view To) {
255 if (From.empty())
256 return;
257 std::size_t Pos = 0;
258 while ((Pos = S.find(From, Pos)) != std::string::npos) {
259 S.replace(Pos, From.size(), To);
260 // Skip over the just-inserted text.
261 Pos += To.size();
262 }
263 };
264
265 std::string InstanceCode = TemplateCode;
266 // Demote kernels to device function to call the templated instance from
267 // the entry function.
268 ReplaceAll(InstanceCode, "__global__", "__device__");
269 InstanceCode = InstanceCode + FunctionCode;
270
271 return InstanceCode;
272 }
273
274 template <typename RetT, typename... ArgT> void compile() {
275 std::string InstanceCode = buildCode<RetT, ArgT...>();
276 InstanceModule = std::make_unique<CppJitModule>(
277 getInstanceTargetModel(), InstanceCode, ExtraArgs,
278 getInstanceCompilerBackend());
279 InstanceModule->compile();
280
281 FuncPtr = InstanceModule->getFunctionAddress(EntryFuncName);
282 }
283
284 void setFuncAttribute(JitFuncAttribute Attr, int Value) {
285 FuncAttributes.set(Attr, Value);
286 InstanceModule.reset();
287 FuncPtr = nullptr;
288 }
289
290 template <typename... ArgT>
291 auto launch(LaunchDims GridDim, LaunchDims BlockDim, uint64_t ShmemSize,
292 void *Stream, ArgT... Args) {
293 if (!InstanceModule) {
294 compile<void, ArgT...>();
295 }
296
297 void *Ptrs[sizeof...(ArgT)] = {(void *)&Args...};
298
299 if (useHostLaunchEntry()) {
300 using LauncherFunc =
301 int(unsigned, unsigned, unsigned, unsigned, unsigned, unsigned,
302 std::size_t, void *, ArgT...);
303 return DispatchResult{InstanceModule->Dispatch.run<LauncherFunc>(
304 FuncPtr, GridDim.X, GridDim.Y, GridDim.Z, BlockDim.X, BlockDim.Y,
305 BlockDim.Z, static_cast<std::size_t>(ShmemSize), Stream, Args...)};
306 }
307
308 return InstanceModule->launch(FuncPtr, GridDim, BlockDim, Ptrs, ShmemSize,
309 Stream);
310 }
311
312 template <typename RetOrSig, typename... ArgT>
313 RetOrSig run(ArgT &&...Args) {
314 static_assert(!std::is_function_v<RetOrSig>,
315 "Function signature type is not yet supported");
316
317 if (!InstanceModule) {
318 compile<RetOrSig, ArgT...>();
319 }
320
321 if constexpr (std::is_void_v<RetOrSig>)
322 InstanceModule->Dispatch.run<RetOrSig(ArgT...)>(FuncPtr, Args...);
323 else
324 return InstanceModule->Dispatch.run<RetOrSig(ArgT...)>(FuncPtr,
325 Args...);
326 }
327 };
328
// Instances created by instantiate(), keyed by instance name such as
// "foo<int,4>". Values are heap-allocated, so the references instantiate()
// returns stay valid when the map rehashes.
329 std::unordered_map<std::string, std::unique_ptr<CodeInstance>>
330 InstantiationCache;
331
332public:
// Constructors. The first takes the target model directly; the second takes
// the target as a string (presumably parsed into a TargetModelType — TODO
// confirm). Code is the C++ source to JIT-compile.
// NOTE(review): the trailing parameter line(s) of both declarations are
// missing from this listing. CodeInstance::compile() passes a
// CppJitCompilerBackend as the fourth argument of the first constructor.
333 explicit CppJitModule(
334 TargetModelType TargetModel, const std::string &Code,
335 const std::vector<std::string> &ExtraArgs = {},
337 explicit CppJitModule(
338 const std::string &Target, const std::string &Code,
339 const std::vector<std::string> &ExtraArgs = {},
// Compiles Code (defined out of line). The lazy accessors below call it while
// IsCompiled is false.
343 void compile();
344
345 // Expose the target model so higher-level bindings can reject invalid API
346 // combinations before dispatch.
347 TargetModelType getTargetModel() const { return TargetModel; }
348
// Applies a function attribute to a compiled kernel (defined out of line;
// used by KernelHandle::setFuncAttribute()).
349 void setFuncAttribute(void *KernelFunc, JitFuncAttribute Attr, int Value);
// Looks up Name in the compiled module and returns its address (defined out
// of line). Callers in this header compile the module first.
350 void *getFunctionAddress(const std::string &Name);
351 void *getKernelAddress(const std::string &Name) {
352 if (!IsCompiled)
353 compile();
354
355 if (TargetModel == TargetModelType::HOST)
357 "Error: getKernelAddress() applies only to device modules");
358
359 // Kernel symbols are loaded through the same compiled image as host
360 // function symbols once target validation is complete.
361 return getFunctionAddress(Name);
362 }
363
// Launches the kernel at KernelFunc with the given grid/block dimensions,
// dynamic shared memory size and stream (defined out of line). KernelArgs
// holds one pointer per kernel argument, as built by KernelHandle::launch().
364 DispatchResult launch(void *KernelFunc, LaunchDims GridDim,
365 LaunchDims BlockDim, void *KernelArgs[],
366 uint64_t ShmemSize, void *Stream);
367
369 if (!IsCompiled)
370 compile();
371
372 if (!Library)
373 reportFatalError("Expected non-null library after compilation");
374
375 return *Library;
376 }
377
378 template <typename... ArgT>
379 auto &instantiate(const std::string &FuncName, ArgT... Args) {
380 std::string InstanceName = FuncName + "<";
381 bool First = true;
382 ((InstanceName +=
383 (First ? "" : ",") + std::string(std::forward<ArgT>(Args)),
384 First = false),
385 ...);
386
387 InstanceName += ">";
388
389 auto It = InstantiationCache.find(InstanceName);
390 if (It != InstantiationCache.end()) {
391 return *It->second;
392 }
393
394 auto [NewIt, OK] = InstantiationCache.emplace(
395 InstanceName,
396 std::make_unique<CodeInstance>(TargetModel, Code, ExtraArgs,
397 CompilerBackend, InstanceName));
398 return *NewIt->second;
399 }
400
401 template <typename Sig> struct FunctionHandle;
402 template <typename RetT, typename... ArgT>
403 struct FunctionHandle<RetT(ArgT...)> {
405 void *FuncPtr;
406 explicit FunctionHandle(CppJitModule &M, void *FuncPtr)
407 : M(M), FuncPtr(FuncPtr) {}
408
409 RetT run(ArgT... Args) {
410 if constexpr (std::is_void_v<RetT>) {
411 M.Dispatch.template run<RetT(ArgT...)>(FuncPtr,
412 std::forward<ArgT>(Args)...);
413 } else {
414 return M.Dispatch.template run<RetT(ArgT...)>(
415 FuncPtr, std::forward<ArgT>(Args)...);
416 }
417 }
418 };
419
420 template <typename Sig> struct KernelHandle;
421 template <typename RetT, typename... ArgT>
422 struct KernelHandle<RetT(ArgT...)> {
424 void *FuncPtr = nullptr;
425 explicit KernelHandle(CppJitModule &M, void *FuncPtr)
426 : M(M), FuncPtr(FuncPtr) {
427 static_assert(std::is_void_v<RetT>, "Kernel function must return void");
428 }
429
430 void setFuncAttribute(JitFuncAttribute Attr, int Value) {
431 M.setFuncAttribute(FuncPtr, Attr, Value);
432 }
433
434 auto launch(LaunchDims GridDim, LaunchDims BlockDim, uint64_t ShmemSize,
435 void *Stream, ArgT... Args) {
436 void *Ptrs[sizeof...(ArgT)] = {(void *)&Args...};
437
438 return M.launch(FuncPtr, GridDim, BlockDim, Ptrs, ShmemSize, Stream);
439 }
440 };
441 template <typename Sig>
442 FunctionHandle<Sig> getFunction(const std::string &Name) {
443 if (!IsCompiled)
444 compile();
445
446 if (!isHostTargetModel(TargetModel))
447 reportFatalError("Error: getFunction() applies only to host modules");
448
449 void *FuncPtr = getFunctionAddress(Name);
450
451 return FunctionHandle<Sig>(*this, FuncPtr);
452 }
453
454 template <typename Sig> KernelHandle<Sig> getKernel(const std::string &Name) {
455 if (!IsCompiled)
456 compile();
457
458 if (TargetModel == TargetModelType::HOST)
459 reportFatalError("Error: getKernel() applies only to device modules");
460
461 void *FuncPtr = getFunctionAddress(Name);
462
463 return KernelHandle<Sig>(*this, FuncPtr);
464 }
465};
466
467} // namespace proteus
468
469#endif
char int void ** Args
Definition CompilerInterfaceHost.cpp:23
Definition CppJitModule.h:19
auto & instantiate(const std::string &FuncName, ArgT... Args)
Definition CppJitModule.h:379
void compile()
Definition CppJitModule.cpp:36
void * getFunctionAddress(const std::string &Name)
Definition CppJitModule.cpp:85
FunctionHandle< Sig > getFunction(const std::string &Name)
Definition CppJitModule.h:442
TargetModelType getTargetModel() const
Definition CppJitModule.h:347
void setFuncAttribute(void *KernelFunc, JitFuncAttribute Attr, int Value)
Definition CppJitModule.cpp:80
CompiledLibrary & getLibrary()
Definition CppJitModule.h:368
KernelHandle< Sig > getKernel(const std::string &Name)
Definition CppJitModule.h:454
DispatchResult launch(void *KernelFunc, LaunchDims GridDim, LaunchDims BlockDim, void *KernelArgs[], uint64_t ShmemSize, void *Stream)
Definition CppJitModule.cpp:90
void * getKernelAddress(const std::string &Name)
Definition CppJitModule.h:351
Definition Dispatcher.h:75
Definition MemoryCache.h:27
JitFuncAttribute
Definition JitFuncAttribute.h:6
TargetModelType
Definition TargetModel.h:8
static int Pos
Definition JitInterface.h:102
bool isHostTargetModel(TargetModelType TargetModel)
Definition TargetModel.cpp:49
void reportFatalError(const llvm::Twine &Reason, const char *FILE, unsigned Line)
Definition Error.cpp:14
CppJitCompilerBackend
Definition CppJitCompilerBackend.h:6
void setFuncAttribute(TargetModelType TargetModel, void *KernelFunc, JitFuncAttribute Attr, int Value)
Definition JitFuncAttribute.cpp:15
Definition Dispatcher.h:22
unsigned Z
Definition Dispatcher.h:23
unsigned Y
Definition Dispatcher.h:23
unsigned X
Definition Dispatcher.h:23
Definition CompiledLibrary.h:18
FunctionHandle(CppJitModule &M, void *FuncPtr)
Definition CppJitModule.h:406
void * FuncPtr
Definition CppJitModule.h:405
CppJitModule & M
Definition CppJitModule.h:404
RetT run(ArgT... Args)
Definition CppJitModule.h:409
Definition CppJitModule.h:401
auto launch(LaunchDims GridDim, LaunchDims BlockDim, uint64_t ShmemSize, void *Stream, ArgT... Args)
Definition CppJitModule.h:434
KernelHandle(CppJitModule &M, void *FuncPtr)
Definition CppJitModule.h:425
void setFuncAttribute(JitFuncAttribute Attr, int Value)
Definition CppJitModule.h:430
CppJitModule & M
Definition CppJitModule.h:423
Definition CppJitModule.h:420
Definition Dispatcher.h:53