// Object code of this module; handed out via ObjectModule->getMemBufferRef().
// NOTE(review): presumably filled in by compile(); the accessor dereferences
// it without a null check, so confirm it is always set before that is called.
17 std::unique_ptr<MemoryBuffer> ObjectModule;
// Whether this module has been compiled yet; starts out false.
18 bool IsCompiled =
false;
// Template source this instance is generated from. StringRef does not own
// its characters: the buffer it views must outlive this CodeInstance.
30 StringRef TemplateCode;
// Name of the instantiated template function; the generated entry wrapper
// calls it as InstanceName(Arg0, Arg1, ...).
31 std::string InstanceName;
// Module holding the instance plus its entry wrapper. Null until the first
// run/launch call, which builds and compiles it.
32 std::unique_ptr<CppJitModule> InstanceModule;
// Name of the generated extern "C" wrapper: "__jit_instance_" + InstanceName
// with '<', '>' and ',' replaced by '$'.
33 std::string EntryFuncName;
// Constructor (its opening line is not visible here): stores the target,
// template source and instance name, then derives EntryFuncName.
36 StringRef InstanceName)
37 : TargetModel(TargetModel), TemplateCode(TemplateCode),
38 InstanceName(InstanceName) {
// `this->` selects the std::string member; the bare name would be the
// StringRef parameter of the same name.
39 EntryFuncName =
"__jit_instance_" + this->InstanceName;
// Characters matching the predicate below ('<', '>', ',') are replaced with
// '$' (presumably via std::replace_if on the line not shown) so the name can
// serve as an identifier.
// NOTE(review): any other character in template arguments (spaces as in
// "unsigned int", ':' as in "ns::T", '*', '&') passes through unchanged and
// would make the identifier invalid - confirm callers never produce them.
// '$' in identifiers is itself a compiler extension.
43 EntryFuncName.begin(), EntryFuncName.end(),
44 [](
char C) { return C ==
'<' || C ==
'>' || C ==
','; },
'$');
48 template <
class T>
constexpr std::string_view typeName() {
53 std::string_view P = __PRETTY_FUNCTION__;
54 auto B = P.find(
"[T = ") + 5;
55 auto E = P.rfind(
']');
56 return P.substr(B, E - B);
57#elif defined(__GNUC__)
59 std::string_view P = __PRETTY_FUNCTION__;
60 auto B = P.find(
"with T = ") + 9;
61 auto E = P.find(
';', B);
62 return P.substr(B, E - B);
63#elif defined(_MSC_VER)
65 std::string_view P = __FUNCSIG__;
66 auto B = P.find(
"type_name<") + 10;
67 auto E = P.find(
">(void)", B);
68 return P.substr(B, B - E);
// Generates C++ source for an extern "C" entry wrapper named EntryFuncName
// that takes parameters Arg0..ArgN-1 and forwards them to InstanceName(...).
// I is the index pack for ArgT; it numbers the parameters and selects each
// parameter's type. Return/parameter types are spelled with typeName<>(),
// i.e. the compiler's own spelling.
// NOTE(review): that spelling is assumed to be valid source in the JIT
// translation unit - confirm for namespaced/aliased types.
74 template <
typename RetT,
typename... ArgT, std::size_t... I>
75 std::string buildFunctionEntry(std::index_sequence<I...>) {
76 SmallString<256> FuncS;
77 raw_svector_ostream OS(FuncS);
// Header: extern "C" <RetT> <EntryFuncName>(
79 OS <<
"extern \"C\" " << typeName<RetT>() <<
" "
81 << EntryFuncName <<
"(";
// Parameter list: one entry per argument, separated by ", ", typed with the
// decayed ArgT so references/cv-qualifiers from deduction are dropped.
82 ((OS << (I ?
", " :
"")
84 std::decay_t<std::tuple_element_t<I, std::tuple<ArgT...>>>>()
// Comma-separated argument names "Arg0, Arg1, ..." (ArgList is declared on
// a line not visible here).
90 ((ArgList += (I == 0 ?
"" :
", ") + (
"Arg" + std::to_string(I))), ...);
// Presumably emits `return` only when RetT is non-void - confirm.
92 if constexpr (!std::is_void_v<RetT>) {
// Body: call the instantiated template with the wrapper's arguments.
95 OS << InstanceName <<
"(";
96 ((OS << (I == 0 ?
"" :
", ") <<
"Arg" << std::to_string(I)), ...);
99 return std::string(FuncS);
// Produces the full source compiled for this instance: the template source
// with every "__global__" rewritten to "__device__", followed by the
// generated entry wrapper for the given return type and argument types.
102 template <
typename RetOrSig,
typename... ArgT> std::string buildCode() {
103 std::string FunctionCode = buildFunctionEntry<RetOrSig, ArgT...>(
104 std::index_sequence_for<ArgT...>{});
// Replaces every occurrence of From in S with To, in place.
// NOTE(review): only part of the loop is visible here - confirm Pos advances
// past the inserted To after each replace and that an empty From is
// rejected; otherwise the loop never terminates.
106 auto ReplaceAll = [](std::string &S, std::string_view From,
107 std::string_view To) {
111 while ((
Pos = S.find(From,
Pos)) != std::string::npos) {
112 S.replace(
Pos, From.size(), To);
118 std::string InstanceCode = TemplateCode.str();
// Plain textual rewrite: it also changes "__global__" inside comments,
// string literals or longer identifiers. Presumably done so the template can
// be called from the generated wrapper - confirm.
121 ReplaceAll(InstanceCode,
"__global__",
"__device__");
122 InstanceCode = InstanceCode + FunctionCode;
// Launches the instance (name line not visible here) with the given grid,
// block, shared-memory size and stream, building and compiling the module on
// first use. The wrapper is generated with a void return type.
// NOTE(review): InstanceModule is built only once, from the ArgT... of the
// first call; later calls with different argument types reuse that compiled
// wrapper. Confirm each CodeInstance is only ever launched with one
// signature.
127 template <
typename... ArgT>
129 void *Stream, ArgT...
Args) {
130 if (!InstanceModule) {
131 std::string InstanceCode = buildCode<void, ArgT...>();
133 std::make_unique<CppJitModule>(TargetModel, InstanceCode);
134 InstanceModule->compile();
// Array of pointers to this call's by-value Args; the pointers refer to this
// stack frame, so they are only valid until launch(...) returns.
// NOTE(review): with sizeof...(ArgT) == 0 this declares a zero-length array,
// which standard C++ rejects (compiler extension only).
137 void *Ptrs[
sizeof...(ArgT)] = {(
void *)&
Args...};
139 return InstanceModule->Dispatch.launch(
140 EntryFuncName, GridDim, BlockDim, Ptrs, ShmemSize, Stream,
141 InstanceModule->getObjectModuleRef());
// Calls the instance through its entry wrapper and returns a RetOrSig value,
// building and compiling the module on first use. RetOrSig must be a
// return type; function-signature types are rejected at compile time.
// NOTE(review): InstanceModule is built only once, from the RetOrSig/ArgT...
// of the first call; later calls with other types reuse that wrapper.
// Confirm each CodeInstance is only run with one signature.
144 template <
typename RetOrSig,
typename... ArgT>
auto run(ArgT &&...
Args) {
145 static_assert(!std::is_function_v<RetOrSig>,
146 "Function signature type is not yet supported");
148 if (!InstanceModule) {
// ArgT may be deduced as reference types here; the wrapper generator decays
// them when spelling parameter types.
149 std::string InstanceCode = buildCode<RetOrSig, ArgT...>();
151 std::make_unique<CppJitModule>(TargetModel, InstanceCode);
152 InstanceModule->compile();
155 return InstanceModule->Dispatch.run<RetOrSig, ArgT...>(
156 EntryFuncName, InstanceModule->getObjectModuleRef(),
157 std::forward<ArgT>(
Args)...);
// Result of compiling source code: an LLVM context and a module.
// Ctx is declared before Mod, so Mod is destroyed first (members are
// destroyed in reverse declaration order). Keep this order: an LLVM Module
// must not outlive the LLVMContext it was created in.
161 struct CompilationResult {
163 std::unique_ptr<LLVMContext> Ctx;
164 std::unique_ptr<Module> Mod;
// Creates a JIT module for target Target from the source text Code.
172 explicit CppJitModule(StringRef Target, StringRef Code);
// Fragment of an accessor whose signature is not visible here: returns a
// reference to the compiled object buffer.
// NOTE(review): ObjectModule is dereferenced without a null check - confirm
// this is only reachable after compilation has stored the object.
180 return ObjectModule->getMemBufferRef();
// Calls function FuncName from this module with Args, perfectly forwarded,
// and returns RetOrSig (most of the body is not visible here).
183 template <
typename RetOrSig,
typename... ArgT>
184 auto run(
const char *FuncName, ArgT &&...
Args) {
192 std::forward<ArgT>(
Args)...);
// Launches a kernel from this module (name and leading parameters are not
// visible here), passing arguments as an array of pointers to this call's
// by-value Args. Those pointers refer to this stack frame.
// NOTE(review): with sizeof...(ArgT) == 0 this declares a zero-length array,
// which standard C++ rejects (compiler extension only).
195 template <
typename... ArgT>
197 uint64_t ShmemSize,
void *Stream, ArgT...
Args) {
201 void *Ptrs[
sizeof...(ArgT)] = {(
void *)&
Args...};
// Creates a CodeInstance for template FuncName instantiated with the given
// template-argument spellings. Its name starts as FuncName + "<" followed by
// the arguments joined with "," (the rest is not visible here).
// Each ArgT must be explicitly convertible to std::string. Args are taken by
// value, so std::forward<ArgT>(Args) here is simply a move.
// NOTE(review): the returned CodeInstance presumably keeps Code only as a
// StringRef view, so this module must outlive it - confirm.
210 template <
typename... ArgT>
212 [[maybe_unused]] ArgT...
Args) {
213 std::string InstanceName = FuncName.str() +
"<";
216 (First ?
"" :
",") + std::string(std::forward<ArgT>(
Args)),
222 return CodeInstance{TargetModel, Code, InstanceName};