// Hash of this module's code. Presumably used as a cache key; TODO confirm.
20 std::unique_ptr<HashT> ModuleHash;
// Additional compiler arguments supplied at construction (the constructors
// take an ExtraArgs vector defaulting to empty). Presumably forwarded to the
// frontend; verify.
21 std::vector<std::string> ExtraArgs;
// Optimization-level flag for frontend compilation; fixed at "-O3".
24 static constexpr const char *FrontendOptLevelFlag =
"-O3";
// Owning handle to the compiled library.
27 std::unique_ptr<CompiledLibrary> Library;
// Compilation-state flag; starts out false.
// NOTE(review): presumably set by compile() once Library is valid. Confirm.
28 bool IsCompiled =
false;
// Source text of the template this instance is generated from.
// NOTE(review): this is a non-owning reference. The referenced string must
// outlive this CodeInstance.
40 const std::string &TemplateCode;
// Extra compiler arguments, passed to the per-instance CppJitModule.
41 std::vector<std::string> ExtraArgs;
// Instantiated function name, e.g. "foo<int,3>".
42 std::string InstanceName;
// Module compiled from the generated instance code. It is null until the first
// run/launch call triggers compile().
43 std::unique_ptr<CppJitModule> InstanceModule;
// Symbol name of the generated extern "C" wrapper that calls InstanceName.
44 std::string EntryFuncName;
// Address of EntryFuncName in InstanceModule; nullptr until compiled.
45 void *FuncPtr =
nullptr;
// Stores the template source (by reference), the extra args and the instance
// name, then derives EntryFuncName as "__jit_instance_" + InstanceName.
47 CodeInstance(
TargetModelType TargetModel,
const std::string &TemplateCode,
48 const std::vector<std::string> &ExtraArgs,
49 const std::string &InstanceName)
50 : TargetModel(TargetModel), TemplateCode(TemplateCode),
51 ExtraArgs(ExtraArgs), InstanceName(InstanceName) {
52 EntryFuncName =
"__jit_instance_" + this->InstanceName;
// Rewrite template punctuation ('<', '>', ',') to '$' so the instance name can
// be embedded in a symbol name.
// NOTE(review): '$' in identifiers is implementation-defined (a compiler
// extension). Template arguments containing spaces, "::", '*', '&' or '-'
// are not rewritten and would still yield an invalid identifier. Confirm.
56 EntryFuncName.begin(), EntryFuncName.end(),
57 [](
char C) { return C ==
'<' || C ==
'>' || C ==
','; },
'$');
/// Returns the spelling of type T as printed by the host compiler.
///
/// The name is sliced out of the pretty-printed signature of this function
/// (__PRETTY_FUNCTION__ on Clang/GCC, __FUNCSIG__ on MSVC). It is used to spell
/// the return and parameter types of the generated extern "C" entry function.
///
/// Returns an empty view when the expected markers are not found. This replaces
/// slicing from a wrapped-around npos offset.
///
/// Spellings are compiler-specific. For example, Clang prints "float *" where
/// GCC prints "float*", and MSVC prefixes class types with "class "/"struct ".
template <class T> constexpr std::string_view typeName() {
#if defined(__clang__)
  // Clang: "... typeName() [T = int]"
  constexpr std::string_view Marker = "[T = ";
  std::string_view P = __PRETTY_FUNCTION__;
  auto B = P.find(Marker);
  // Use the last ']' so array types such as "int[3]" stay intact.
  auto E = P.rfind(']');
  if (B == std::string_view::npos || E == std::string_view::npos)
    return {};
  B += Marker.size();
  if (E < B)
    return {};
  return P.substr(B, E - B);
#elif defined(__GNUC__)
  // GCC: "... typeName() [with T = int; std::string_view = ...]"
  constexpr std::string_view Marker = "with T = ";
  std::string_view P = __PRETTY_FUNCTION__;
  auto B = P.find(Marker);
  if (B == std::string_view::npos)
    return {};
  B += Marker.size();
  // Further alias expansions follow after ';'. If there are none, the bracketed
  // list is closed by ']'.
  auto E = P.find(';', B);
  if (E == std::string_view::npos)
    E = P.rfind(']');
  if (E == std::string_view::npos || E < B)
    return {};
  return P.substr(B, E - B);
#elif defined(_MSC_VER)
  // MSVC: "... __cdecl typeName<int>(void)". The marker must match this
  // function's actual name. The closing ">(void)" is searched after it, so
  // nested template arguments inside T are kept whole.
  constexpr std::string_view Marker = "typeName<";
  std::string_view P = __FUNCSIG__;
  auto B = P.find(Marker);
  if (B == std::string_view::npos)
    return {};
  B += Marker.size();
  auto E = P.find(">(void)", B);
  if (E == std::string_view::npos)
    return {};
  // Length is end minus begin (was B - E, which underflows).
  return P.substr(B, E - B);
#else
#error "typeName: unsupported compiler (requires __PRETTY_FUNCTION__ or __FUNCSIG__)"
#endif
}
// Generates the source text of an extern "C" wrapper function named
// EntryFuncName.
//  - Return type: spelled with typeName<RetT>().
//  - Parameters: one per ArgT, with types spelled from
//    std::decay_t<ArgT> via typeName. Names are presumably Arg0..ArgN-1,
//    matching the call below; confirm.
//  - Body: calls InstanceName(Arg0, Arg1, ...). For non-void RetT the call is
//    emitted inside the `if constexpr` branch, presumably as a return
//    statement.
// @tparam I  indices 0..sizeof...(ArgT)-1, used to name and type each argument.
// NOTE(review): the return type is not decayed while the parameter types are.
// The separator tests are also written two ways (`I ?` vs `I == 0 ?`).
87 template <
typename RetT,
typename... ArgT, std::size_t... I>
88 std::string buildFunctionEntry(std::index_sequence<I...>) {
91 OS <<
"extern \"C\" " << typeName<RetT>() <<
" "
93 << EntryFuncName <<
"(";
94 ((OS << (I ?
", " :
"")
96 std::decay_t<std::tuple_element_t<I, std::tuple<ArgT...>>>>()
102 ((ArgList += (I == 0 ?
"" :
", ") + (
"Arg" + std::to_string(I))), ...);
104 if constexpr (!std::is_void_v<RetT>) {
107 OS << InstanceName <<
"(";
108 ((OS << (I == 0 ?
"" :
", ") <<
"Arg" << std::to_string(I)), ...);
// Builds the complete translation unit for this instance:
//  1. TemplateCode with every "__global__" textually replaced by "__device__",
//     so the template can be called from the generated wrapper;
//  2. the extern "C" wrapper from buildFunctionEntry appended at the end.
// NOTE(review): the replacement is a plain substring search. It also rewrites
// occurrences inside comments, string literals or longer identifiers.
// NOTE(review): ensure Pos advances by To.size() after each replace; otherwise
// a replacement text that contains From would loop forever.
// NOTE(review): `InstanceCode += FunctionCode` would avoid a temporary.
114 template <
typename RetOrSig,
typename... ArgT> std::string buildCode() {
115 std::string FunctionCode = buildFunctionEntry<RetOrSig, ArgT...>(
116 std::index_sequence_for<ArgT...>{});
118 auto ReplaceAll = [](std::string &S, std::string_view From,
119 std::string_view To) {
123 while ((
Pos = S.find(From,
Pos)) != std::string::npos) {
124 S.replace(
Pos, From.size(), To);
130 std::string InstanceCode = TemplateCode;
133 ReplaceAll(InstanceCode,
"__global__",
"__device__");
134 InstanceCode = InstanceCode + FunctionCode;
// Generates the instance source for return type RetT and argument types ArgT.
// It compiles that source into a fresh CppJitModule (presumably assigned to
// InstanceModule) using this instance's TargetModel and ExtraArgs. It then
// caches the address of the generated entry function in FuncPtr.
// NOTE(review): if getFunctionAddress can fail, FuncPtr should be checked
// before it is used for dispatch or launch.
139 template <
typename RetT,
typename... ArgT>
void compile() {
140 std::string InstanceCode = buildCode<RetT, ArgT...>();
142 std::make_unique<CppJitModule>(TargetModel, InstanceCode, ExtraArgs);
143 InstanceModule->compile();
145 FuncPtr = InstanceModule->getFunctionAddress(EntryFuncName);
// Kernel launch path. On first use it compiles the instance with a void return
// type. It then launches FuncPtr through InstanceModule->launch with the given
// grid/block dimensions, shared-memory size and stream. Kernel arguments are
// passed as an array of pointers to the by-value copies in Args.
// NOTE(review): compilation happens only when InstanceModule is null. The
// ArgT of the first call therefore fixes the compiled signature, and later
// calls with different argument types reuse the same FuncPtr.
// NOTE(review): with zero kernel arguments, `void *Ptrs[0]` is a zero-length
// array. That is ill-formed in ISO C++ and accepted only as an extension.
148 template <
typename... ArgT>
150 void *Stream, ArgT...
Args) {
151 if (!InstanceModule) {
152 compile<void, ArgT...>();
155 void *Ptrs[
sizeof...(ArgT)] = {(
void *)&
Args...};
157 return InstanceModule->launch(FuncPtr, GridDim, BlockDim, Ptrs, ShmemSize,
// Host call path. On first use it compiles the instance for return type
// RetOrSig and argument types ArgT. It then invokes FuncPtr through
// InstanceModule->Dispatch.run<RetOrSig(ArgT...)>, returning the result for
// non-void RetOrSig.
// A function type as RetOrSig (e.g. int(int)) is rejected at compile time.
// NOTE(review): ArgT is deduced from forwarding references. Lvalue arguments
// therefore produce reference types (e.g. int&) in the Dispatch signature.
// buildFunctionEntry decays parameter types, but this path does not; confirm
// that Dispatch handles this. Args are also passed without std::forward here.
// NOTE(review): as in launch, the first call's ArgT fixes the compiled
// signature.
161 template <
typename RetOrSig,
typename... ArgT>
162 RetOrSig run(ArgT &&...
Args) {
163 static_assert(!std::is_function_v<RetOrSig>,
164 "Function signature type is not yet supported");
166 if (!InstanceModule) {
167 compile<RetOrSig, ArgT...>();
170 if constexpr (std::is_void_v<RetOrSig>)
171 InstanceModule->Dispatch.run<RetOrSig(ArgT...)>(FuncPtr,
Args...);
173 return InstanceModule->Dispatch.run<RetOrSig(ArgT...)>(FuncPtr,
// Output of compilation: owns the LLVM context and the module created in it.
// The destructor is declared out of line, presumably so the llvm types may be
// incomplete at this point; confirm.
178 struct CompilationResult {
180 std::unique_ptr<llvm::LLVMContext> Ctx;
181 std::unique_ptr<llvm::Module> Mod;
183 ~CompilationResult();
// Returns the address of symbol Name in the compiled code.
186 void *getFunctionAddress(
const std::string &Name);
// Launches a kernel. KernelArgs is an array of pointers to each argument value,
// followed by the shared-memory size and the stream.
188 void *KernelArgs[], uint64_t ShmemSize,
void *Stream);
196 const std::vector<std::string> &ExtraArgs = {});
// Constructs a module for the given target string from C++ source Code.
// ExtraArgs defaults to empty.
197 explicit CppJitModule(
const std::string &Target,
const std::string &Code,
198 const std::vector<std::string> &ExtraArgs = {});
// Builds the instance name "FuncName<A0,A1,...>" by converting each template
// argument in Args to std::string. Returns a CodeInstance over this module's
// Code, TargetModel and ExtraArgs.
// NOTE(review): CodeInstance keeps a reference to Code, so the returned
// instance must not outlive this module.
// NOTE(review): arguments are joined with ',' and no spaces. Any spaces inside
// an argument (e.g. "unsigned int") are carried into the entry symbol name.
213 template <
typename... ArgT>
215 std::string InstanceName = FuncName +
"<";
218 (First ?
"" :
",") + std::string(std::forward<ArgT>(
Args)),
224 return CodeInstance{TargetModel, Code, ExtraArgs, InstanceName};
// Typed handle to a host function in module M at address FuncPtr. Invoking it
// forwards Args to M.Dispatch.run<RetT(ArgT...)>, and returns its result when
// RetT is not void.
228 template <
typename RetT,
typename... ArgT>
233 : M(M), FuncPtr(FuncPtr) {}
236 if constexpr (std::is_void_v<RetT>) {
237 M.Dispatch.template run<RetT(ArgT...)>(FuncPtr,
238 std::forward<ArgT>(
Args)...);
240 return M.Dispatch.template run<RetT(ArgT...)>(
241 FuncPtr, std::forward<ArgT>(
Args)...);
// Typed handle to a kernel in module M at address FuncPtr. RetT must be void;
// this is enforced by a static_assert when the handle is constructed. Launching
// passes pointers to the by-value Args to M.launch together with the grid/block
// dimensions, shared-memory size and stream.
// NOTE(review): with zero kernel arguments, `void *Ptrs[0]` is a zero-length
// array. That is ill-formed in ISO C++ and accepted only as an extension.
247 template <
typename RetT,
typename... ArgT>
250 void *FuncPtr =
nullptr;
252 : M(M), FuncPtr(FuncPtr) {
253 static_assert(std::is_void_v<RetT>,
"Kernel function must return void");
257 void *Stream, ArgT...
Args) {
258 void *Ptrs[
sizeof...(ArgT)] = {(
void *)&
Args...};
260 return M.launch(FuncPtr, GridDim, BlockDim, Ptrs, ShmemSize, Stream);
// Accessors that resolve symbol Name with getFunctionAddress. The template is
// parameterized on the function signature Sig.
// NOTE(review): no null check on FuncPtr is visible here. Confirm that a
// missing symbol is reported rather than wrapped into a handle.
263 template <
typename Sig>
271 void *FuncPtr = getFunctionAddress(Name);
283 void *FuncPtr = getFunctionAddress(Name);