Proteus
Programmable JIT compilation and optimization for C/C++ using LLVM
Loading...
Searching...
No Matches
CppJitModule.hpp
Go to the documentation of this file.
1#ifndef PROTEUS_CPPFRONTEND_HPP
2#define PROTEUS_CPPFRONTEND_HPP
3
4#include <llvm/Support/Debug.h>
5
7
8namespace proteus {
9
private:
  // Target model (e.g. host vs. device) this module compiles for.
  TargetModelType TargetModel;
  // The C/C++ source code to JIT-compile.
  std::string Code;
  // Hash identifying this module — presumably derived from Code for caching;
  // confirm in the constructor (defined out-of-line).
  HashT ModuleHash;

  // Target-specific dispatcher used to compile, run, and launch.
  Dispatcher &Dispatch;
  // Compiled object code; null until compilation has produced it.
  std::unique_ptr<MemoryBuffer> ObjectModule;
  // Guards lazy compilation in run()/launch(); presumably set by compile().
  bool IsCompiled = false;

  // TODO: We don't cache CodeInstances, so if a user re-creates the exact same
  // instantiation it will create a new CodeInstance. This creation cost is
  // mitigated because the dispatcher caches the compiled object, so we pay the
  // overhead of building the instantiation code but not of compilation.
  // Nevertheless, we return the CodeInstance object when created, so the user
  // can avoid any re-creation overhead by using the returned object to run or
  // launch. Re-think caching and dispatchers, and consider more restrictive
  // interfaces.
28 struct CodeInstance {
29 TargetModelType TargetModel;
30 StringRef TemplateCode;
31 std::string InstanceName;
32 std::unique_ptr<CppJitModule> InstanceModule;
33 std::string EntryFuncName;
34
35 CodeInstance(TargetModelType TargetModel, StringRef TemplateCode,
36 StringRef InstanceName)
37 : TargetModel(TargetModel), TemplateCode(TemplateCode),
38 InstanceName(InstanceName) {
39 EntryFuncName = "__jit_instance_" + this->InstanceName;
40 // Replace characters '<', '>', ',' with $ to create a unique for the
41 // entry function.
42 std::replace_if(
43 EntryFuncName.begin(), EntryFuncName.end(),
44 [](char C) { return C == '<' || C == '>' || C == ','; }, '$');
45 }
46
47 // Compile-time type name (no RTTI).
48 template <class T> constexpr std::string_view typeName() {
49 // Apparently we are more interested in clang, but leaving the others for
50 // completeness.
51#if defined(__clang__)
52 // "std::string_view type_name() [T = int]"
53 std::string_view P = __PRETTY_FUNCTION__;
54 auto B = P.find("[T = ") + 5;
55 auto E = P.rfind(']');
56 return P.substr(B, E - B);
57#elif defined(__GNUC__)
58 // "... with T = int; ..."
59 std::string_view P = __PRETTY_FUNCTION__;
60 auto B = P.find("with T = ") + 9;
61 auto E = P.find(';', B);
62 return P.substr(B, E - B);
63#elif defined(_MSC_VER)
64 // "std::string_view __cdecl type_name<int>(void)"
65 std::string_view P = __FUNCSIG__;
66 auto B = P.find("type_name<") + 10;
67 auto E = P.find(">(void)", B);
68 return P.substr(B, B - E);
69#else
70 PROTEUS_FATAL_ERROR("Unsupported compiler");
71#endif
72 }
73
74 template <typename RetT, typename... ArgT, std::size_t... I>
75 std::string buildFunctionEntry(std::index_sequence<I...>) {
76 SmallString<256> FuncS;
77 raw_svector_ostream OS(FuncS);
78
79 OS << "extern \"C\" " << typeName<RetT>() << " "
80 << ((TargetModel != TargetModelType::HOST) ? "__global__ " : "")
81 << EntryFuncName << "(";
82 ((OS << (I ? ", " : "")
83 << typeName<
84 std::decay_t<std::tuple_element_t<I, std::tuple<ArgT...>>>>()
85 << " Arg" << I),
86 ...);
87 OS << ')';
88
89 std::string ArgList;
90 ((ArgList += (I == 0 ? "" : ", ") + ("Arg" + std::to_string(I))), ...);
91 OS << "{ ";
92 if constexpr (!std::is_void_v<RetT>) {
93 OS << "return ";
94 }
95 OS << InstanceName << "(";
96 ((OS << (I == 0 ? "" : ", ") << "Arg" << std::to_string(I)), ...);
97 OS << "); }";
98
99 return std::string(FuncS);
100 }
101
102 template <typename RetOrSig, typename... ArgT> std::string buildCode() {
103 std::string FunctionCode = buildFunctionEntry<RetOrSig, ArgT...>(
104 std::index_sequence_for<ArgT...>{});
105
106 auto ReplaceAll = [](std::string &S, std::string_view From,
107 std::string_view To) {
108 if (From.empty())
109 return;
110 std::size_t Pos = 0;
111 while ((Pos = S.find(From, Pos)) != std::string::npos) {
112 S.replace(Pos, From.size(), To);
113 // Skip over the just-inserted text.
114 Pos += To.size();
115 }
116 };
117
118 std::string InstanceCode = TemplateCode.str();
119 // Demote kernels to device function to call the templated instance from
120 // the entry function.
121 ReplaceAll(InstanceCode, "__global__", "__device__");
122 InstanceCode = InstanceCode + FunctionCode;
123
124 return InstanceCode;
125 }
126
127 template <typename... ArgT>
128 auto launch(LaunchDims GridDim, LaunchDims BlockDim, uint64_t ShmemSize,
129 void *Stream, ArgT... Args) {
130 if (!InstanceModule) {
131 std::string InstanceCode = buildCode<void, ArgT...>();
132 InstanceModule =
133 std::make_unique<CppJitModule>(TargetModel, InstanceCode);
134 InstanceModule->compile();
135 }
136
137 void *Ptrs[sizeof...(ArgT)] = {(void *)&Args...};
138
139 return InstanceModule->Dispatch.launch(
140 EntryFuncName, GridDim, BlockDim, Ptrs, ShmemSize, Stream,
141 InstanceModule->getObjectModuleRef());
142 }
143
144 template <typename RetOrSig, typename... ArgT> auto run(ArgT &&...Args) {
145 static_assert(!std::is_function_v<RetOrSig>,
146 "Function signature type is not yet supported");
147
148 if (!InstanceModule) {
149 std::string InstanceCode = buildCode<RetOrSig, ArgT...>();
150 InstanceModule =
151 std::make_unique<CppJitModule>(TargetModel, InstanceCode);
152 InstanceModule->compile();
153 }
154
155 return InstanceModule->Dispatch.run<RetOrSig, ArgT...>(
156 EntryFuncName, InstanceModule->getObjectModuleRef(),
157 std::forward<ArgT>(Args)...);
158 }
159 };
160
  // Result of compiling the C++ source to LLVM IR: the module together with
  // the context that owns its types and constants.
  struct CompilationResult {
    // Declare Ctx first to ensure it is destroyed after Mod.
    std::unique_ptr<LLVMContext> Ctx;
    // IR module produced from Code; its lifetime must not exceed Ctx's.
    std::unique_ptr<Module> Mod;
  };
166
protected:
  // Compiles the stored C++ source to LLVM IR (defined in CppJitModule.cpp),
  // returning the module together with its owning context.
  CompilationResult compileCppToIR();
169
public:
  // Constructs a JIT module for the given target model from C/C++ source.
  explicit CppJitModule(TargetModelType TargetModel, StringRef Code);
  // Convenience overload: resolves the target model from a target string.
  explicit CppJitModule(StringRef Target, StringRef Code);

  // Compiles Code to an object module (defined in CppJitModule.cpp).
  void compile();
175
176 std::optional<MemoryBufferRef> getObjectModuleRef() const {
177 if (!ObjectModule)
178 return std::nullopt;
179
180 return ObjectModule->getMemBufferRef();
181 }
182
183 template <typename RetOrSig, typename... ArgT>
184 auto run(const char *FuncName, ArgT &&...Args) {
185 if (!IsCompiled)
186 compile();
187
188 // TODO: We could cache the function address if we have a stateful object to
189 // store and return to the user for invocations, similar to DSL
190 // KernelHandle.
191 return Dispatch.run<RetOrSig, ArgT...>(FuncName, getObjectModuleRef(),
192 std::forward<ArgT>(Args)...);
193 }
194
195 template <typename... ArgT>
196 auto launch(StringRef KernelName, LaunchDims GridDim, LaunchDims BlockDim,
197 uint64_t ShmemSize, void *Stream, ArgT... Args) {
198 if (!IsCompiled)
199 compile();
200
201 void *Ptrs[sizeof...(ArgT)] = {(void *)&Args...};
202
203 // TODO: We could cache the function address if we have a stateful object to
204 // store and return to the user for invocations, similar to DSL
205 // KernelHandle.
206 return Dispatch.launch(KernelName, GridDim, BlockDim, Ptrs, ShmemSize,
207 Stream, getObjectModuleRef());
208 }
209
210 template <typename... ArgT>
211 auto instantiate([[maybe_unused]] StringRef FuncName,
212 [[maybe_unused]] ArgT... Args) {
213 std::string InstanceName = FuncName.str() + "<";
214 bool First = true;
215 ((InstanceName +=
216 (First ? "" : ",") + std::string(std::forward<ArgT>(Args)),
217 First = false),
218 ...);
219
220 InstanceName += ">";
221
222 return CodeInstance{TargetModel, Code, InstanceName};
223 }
224};
225
226} // namespace proteus
227
228#endif
void char * KernelName
Definition CompilerInterfaceDevice.cpp:50
char int void ** Args
Definition CompilerInterfaceHost.cpp:21
#define PROTEUS_FATAL_ERROR(x)
Definition Error.h:7
Definition CppJitModule.hpp:10
auto launch(StringRef KernelName, LaunchDims GridDim, LaunchDims BlockDim, uint64_t ShmemSize, void *Stream, ArgT... Args)
Definition CppJitModule.hpp:196
void compile()
Definition CppJitModule.cpp:135
auto run(const char *FuncName, ArgT &&...Args)
Definition CppJitModule.hpp:184
auto instantiate(StringRef FuncName, ArgT... Args)
Definition CppJitModule.hpp:211
std::optional< MemoryBufferRef > getObjectModuleRef() const
Definition CppJitModule.hpp:176
CompilationResult compileCppToIR()
Definition CppJitModule.cpp:33
Definition Dispatcher.hpp:46
virtual DispatchResult launch(StringRef KernelName, LaunchDims GridDim, LaunchDims BlockDim, ArrayRef< void * > KernelArgs, uint64_t ShmemSize, void *Stream, std::optional< MemoryBufferRef > ObjectModule)=0
auto run(StringRef FuncName, std::optional< MemoryBufferRef > ObjectModule, ArgT &&...Args)
Definition Dispatcher.hpp:75
Definition Hashing.hpp:20
Definition CppJitModule.cpp:21
TargetModelType
Definition TargetModel.hpp:14
static int Pos
Definition JitInterface.hpp:105
Definition Dispatcher.hpp:15