Proteus
Programmable JIT compilation and optimization for C/C++ using LLVM
Loading...
Searching...
No Matches
CppJitModule.hpp
Go to the documentation of this file.
1#ifndef PROTEUS_CPPFRONTEND_HPP
2#define PROTEUS_CPPFRONTEND_HPP
3
4#include <llvm/Support/Debug.h>
5
8
9namespace proteus {
10
12private:
13 TargetModelType TargetModel;
14 std::string Code;
15 HashT ModuleHash;
16 std::vector<std::string> ExtraArgs;
17
18 // Optimization level used when emitting IR.
19 static constexpr const char *FrontendOptLevelFlag = "-O3";
20
21 Dispatcher &Dispatch;
22 std::unique_ptr<CompiledLibrary> Library = nullptr;
23 bool IsCompiled = false;
24
// TODO: We don't cache CodeInstances, so if a user re-creates the exact same
// instantiation, a new CodeInstance is created. This cost is mitigated because
// the dispatcher caches the compiled object: we pay the overhead of building
// the instantiation code, but not of compiling it again. Since the
// CodeInstance object is returned on creation, the user can also avoid the
// re-creation overhead entirely by reusing that object to run or launch.
// Rethink caching and dispatchers, possibly with more restrictive
// interfaces.
33 struct CodeInstance {
34 TargetModelType TargetModel;
35 StringRef TemplateCode;
36 std::string InstanceName;
37 std::unique_ptr<CppJitModule> InstanceModule;
38 std::string EntryFuncName;
39 void *FuncPtr = nullptr;
40
41 CodeInstance(TargetModelType TargetModel, StringRef TemplateCode,
42 StringRef InstanceName)
43 : TargetModel(TargetModel), TemplateCode(TemplateCode),
44 InstanceName(InstanceName) {
45 EntryFuncName = "__jit_instance_" + this->InstanceName;
46 // Replace characters '<', '>', ',' with $ to create a unique for the
47 // entry function.
48 std::replace_if(
49 EntryFuncName.begin(), EntryFuncName.end(),
50 [](char C) { return C == '<' || C == '>' || C == ','; }, '$');
51 }
52
// Compile-time type name (no RTTI): extracts the spelling of T from the
// compiler-generated signature string of this function template. The
// spelling is compiler-specific (e.g. MSVC prefixes class types with
// "class ").
template <class T> static constexpr std::string_view typeName() {
  // Apparently we are more interested in clang, but leaving the others for
  // completeness.
#if defined(__clang__)
  // e.g. "std::string_view ...::typeName() [T = int]"
  std::string_view P = __PRETTY_FUNCTION__;
  constexpr std::string_view Prefix = "[T = ";
  auto B = P.find(Prefix) + Prefix.size();
  auto E = P.rfind(']');
  return P.substr(B, E - B);
#elif defined(__GNUC__)
  // e.g. "... typeName() [with T = int; std::string_view = ...]"
  std::string_view P = __PRETTY_FUNCTION__;
  constexpr std::string_view Prefix = "with T = ";
  auto B = P.find(Prefix) + Prefix.size();
  auto E = P.find(';', B);
  // Without additional typedef expansions there is no ';' and the
  // parameter list simply ends at the closing bracket.
  if (E == std::string_view::npos)
    E = P.rfind(']');
  return P.substr(B, E - B);
#elif defined(_MSC_VER)
  // e.g. "class std::basic_string_view<...> __cdecl ...::typeName<int>(void)"
  std::string_view P = __FUNCSIG__;
  // Search for this function's actual name; the length of the substring is
  // End - Begin.
  constexpr std::string_view Prefix = "typeName<";
  auto B = P.find(Prefix) + Prefix.size();
  auto E = P.rfind(">(void)");
  return P.substr(B, E - B);
#else
  PROTEUS_FATAL_ERROR("Unsupported compiler");
#endif
}
79
80 template <typename RetT, typename... ArgT, std::size_t... I>
81 std::string buildFunctionEntry(std::index_sequence<I...>) {
84
85 OS << "extern \"C\" " << typeName<RetT>() << " "
86 << ((!isHostTargetModel(TargetModel)) ? "__global__ " : "")
87 << EntryFuncName << "(";
88 ((OS << (I ? ", " : "")
89 << typeName<
90 std::decay_t<std::tuple_element_t<I, std::tuple<ArgT...>>>>()
91 << " Arg" << I),
92 ...);
93 OS << ')';
94
95 std::string ArgList;
96 ((ArgList += (I == 0 ? "" : ", ") + ("Arg" + std::to_string(I))), ...);
97 OS << "{ ";
98 if constexpr (!std::is_void_v<RetT>) {
99 OS << "return ";
100 }
101 OS << InstanceName << "(";
102 ((OS << (I == 0 ? "" : ", ") << "Arg" << std::to_string(I)), ...);
103 OS << "); }";
104
105 return std::string(FuncS);
106 }
107
108 template <typename RetOrSig, typename... ArgT> std::string buildCode() {
109 std::string FunctionCode = buildFunctionEntry<RetOrSig, ArgT...>(
110 std::index_sequence_for<ArgT...>{});
111
112 auto ReplaceAll = [](std::string &S, std::string_view From,
113 std::string_view To) {
114 if (From.empty())
115 return;
116 std::size_t Pos = 0;
117 while ((Pos = S.find(From, Pos)) != std::string::npos) {
118 S.replace(Pos, From.size(), To);
119 // Skip over the just-inserted text.
120 Pos += To.size();
121 }
122 };
123
124 std::string InstanceCode = TemplateCode.str();
125 // Demote kernels to device function to call the templated instance from
126 // the entry function.
127 ReplaceAll(InstanceCode, "__global__", "__device__");
129
130 return InstanceCode;
131 }
132
133 template <typename RetT, typename... ArgT> void compile() {
134 std::string InstanceCode = buildCode<RetT, ArgT...>();
135 InstanceModule =
136 std::make_unique<CppJitModule>(TargetModel, InstanceCode);
137 InstanceModule->compile();
138
139 FuncPtr = InstanceModule->Dispatch.getFunctionAddress(
140 EntryFuncName, InstanceModule->ModuleHash,
141 InstanceModule->getLibrary());
142 }
143
144 template <typename... ArgT>
145 auto launch(LaunchDims GridDim, LaunchDims BlockDim, uint64_t ShmemSize,
146 void *Stream, ArgT... Args) {
147 if (!InstanceModule) {
148 compile<void, ArgT...>();
149 }
150
151 void *Ptrs[sizeof...(ArgT)] = {(void *)&Args...};
152
153 return InstanceModule->Dispatch.launch(FuncPtr, GridDim, BlockDim, Ptrs,
155 }
156
157 template <typename RetOrSig, typename... ArgT>
158 RetOrSig run(ArgT &&...Args) {
159 static_assert(!std::is_function_v<RetOrSig>,
160 "Function signature type is not yet supported");
161
162 if (!InstanceModule) {
163 compile<RetOrSig, ArgT...>();
164 }
165
166 if constexpr (std::is_void_v<RetOrSig>)
167 InstanceModule->Dispatch.run<RetOrSig(ArgT...)>(FuncPtr, Args...);
168 else
169 return InstanceModule->Dispatch.run<RetOrSig(ArgT...)>(FuncPtr,
170 Args...);
171 }
172 };
173
174 struct CompilationResult {
175 // Declare Ctx first to ensure it is destroyed after Mod.
176 std::unique_ptr<LLVMContext> Ctx = nullptr;
177 std::unique_ptr<Module> Mod = nullptr;
178 };
179
180protected:
181 CompilationResult compileCppToIR();
183
184public:
185 explicit CppJitModule(TargetModelType TargetModel, StringRef Code,
186 const std::vector<std::string> &ExtraArgs = {});
187 explicit CppJitModule(StringRef Target, StringRef Code,
188 const std::vector<std::string> &ExtraArgs = {});
189
190 void compile();
191
193 if (!IsCompiled)
194 compile();
195
196 if (!Library)
197 PROTEUS_FATAL_ERROR("Expected non-null library after compilation");
198
199 return *Library;
200 }
201
202 template <typename... ArgT>
204 [[maybe_unused]] ArgT... Args) {
205 std::string InstanceName = FuncName.str() + "<";
206 bool First = true;
207 ((InstanceName +=
208 (First ? "" : ",") + std::string(std::forward<ArgT>(Args)),
209 First = false),
210 ...);
211
212 InstanceName += ">";
213
214 return CodeInstance{TargetModel, Code, InstanceName};
215 }
216
217 template <typename Sig> struct FunctionHandle;
218 template <typename RetT, typename... ArgT>
219 struct FunctionHandle<RetT(ArgT...)> {
221 void *FuncPtr;
222 explicit FunctionHandle(CppJitModule &M, void *FuncPtr)
223 : M(M), FuncPtr(FuncPtr) {}
224
225 RetT run(ArgT... Args) {
226 if constexpr (std::is_void_v<RetT>) {
227 M.Dispatch.template run<RetT(ArgT...)>(FuncPtr,
228 std::forward<ArgT>(Args)...);
229 } else {
230 return M.Dispatch.template run<RetT(ArgT...)>(
231 FuncPtr, std::forward<ArgT>(Args)...);
232 }
233 }
234 };
235
236 template <typename Sig> struct KernelHandle;
237 template <typename RetT, typename... ArgT>
238 struct KernelHandle<RetT(ArgT...)> {
240 void *FuncPtr = nullptr;
241 explicit KernelHandle(CppJitModule &M, void *FuncPtr)
242 : M(M), FuncPtr(FuncPtr) {
243 static_assert(std::is_void_v<RetT>, "Kernel function must return void");
244 }
245
247 void *Stream, ArgT... Args) {
248 void *Ptrs[sizeof...(ArgT)] = {(void *)&Args...};
249
250 return M.Dispatch.launch(FuncPtr, GridDim, BlockDim, Ptrs, ShmemSize,
251 Stream);
252 }
253 };
254 template <typename Sig> FunctionHandle<Sig> getFunction(StringRef Name) {
255 if (!IsCompiled)
256 compile();
257
258 if (!isHostTargetModel(TargetModel))
259 PROTEUS_FATAL_ERROR("Error: getFunction() applies only to host modules");
260
261 void *FuncPtr = Dispatch.getFunctionAddress(Name, ModuleHash, getLibrary());
262
263 return FunctionHandle<Sig>(*this, FuncPtr);
264 }
265
266 template <typename Sig> KernelHandle<Sig> getKernel(StringRef Name) {
267 if (!IsCompiled)
268 compile();
269
270 if (TargetModel == TargetModelType::HOST)
271 PROTEUS_FATAL_ERROR("Error: getKernel() applies only to device modules");
272
273 void *FuncPtr = Dispatch.getFunctionAddress(Name, ModuleHash, getLibrary());
274
275 return KernelHandle<Sig>(*this, FuncPtr);
276 }
277};
278
279} // namespace proteus
280
281#endif
char int void ** Args
Definition CompilerInterfaceHost.cpp:21
#define PROTEUS_FATAL_ERROR(x)
Definition Error.h:7
Definition CppJitModule.hpp:11
FunctionHandle< Sig > getFunction(StringRef Name)
Definition CppJitModule.hpp:254
void compile()
Definition CppJitModule.cpp:226
void compileCppToDynamicLibrary()
Definition CppJitModule.cpp:42
KernelHandle< Sig > getKernel(StringRef Name)
Definition CppJitModule.hpp:266
auto instantiate(StringRef FuncName, ArgT... Args)
Definition CppJitModule.hpp:203
CompiledLibrary & getLibrary()
Definition CppJitModule.hpp:192
CompilationResult compileCppToIR()
Definition CppJitModule.cpp:125
Definition Dispatcher.hpp:54
virtual void * getFunctionAddress(StringRef FunctionName, HashT ModuleHash, CompiledLibrary &Library)=0
virtual DispatchResult launch(void *KernelFunc, LaunchDims GridDim, LaunchDims BlockDim, ArrayRef< void * > KernelArgs, uint64_t ShmemSize, void *Stream)=0
Definition Hashing.hpp:20
Definition BuiltinsCUDA.cpp:4
TargetModelType
Definition TargetModel.hpp:14
static int Pos
Definition JitInterface.hpp:105
bool isHostTargetModel(TargetModelType TargetModel)
Definition TargetModel.hpp:55
T getRuntimeConstantValue(void *Arg)
Definition CompilerInterfaceRuntimeConstantInfo.h:114
Definition Dispatcher.hpp:16
Definition CompiledLibrary.hpp:12
FunctionHandle(CppJitModule &M, void *FuncPtr)
Definition CppJitModule.hpp:222
void * FuncPtr
Definition CppJitModule.hpp:221
CppJitModule & M
Definition CppJitModule.hpp:220
RetT run(ArgT... Args)
Definition CppJitModule.hpp:225
Definition CppJitModule.hpp:217
auto launch(LaunchDims GridDim, LaunchDims BlockDim, uint64_t ShmemSize, void *Stream, ArgT... Args)
Definition CppJitModule.hpp:246
KernelHandle(CppJitModule &M, void *FuncPtr)
Definition CppJitModule.hpp:241
CppJitModule & M
Definition CppJitModule.hpp:239
Definition CppJitModule.hpp:236