1#ifndef PROTEUS_CPPFRONTEND_HPP
2#define PROTEUS_CPPFRONTEND_HPP
4#include <llvm/Support/Debug.h>
16 std::vector<std::string> ExtraArgs;
19 static constexpr const char *FrontendOptLevelFlag =
"-O3";
22 std::unique_ptr<CompiledLibrary> Library =
nullptr;
23 bool IsCompiled =
false;
36 std::string InstanceName;
37 std::unique_ptr<CppJitModule> InstanceModule;
38 std::string EntryFuncName;
39 void *FuncPtr =
nullptr;
43 : TargetModel(TargetModel), TemplateCode(TemplateCode),
44 InstanceName(InstanceName) {
45 EntryFuncName =
"__jit_instance_" + this->InstanceName;
49 EntryFuncName.begin(), EntryFuncName.end(),
50 [](
char C) { return C ==
'<' || C ==
'>' || C ==
','; },
'$');
54 template <
class T>
constexpr std::string_view typeName() {
60 auto B =
P.find(
"[T = ") + 5;
61 auto E =
P.rfind(
']');
62 return P.substr(
B,
E -
B);
63#elif defined(__GNUC__)
66 auto B =
P.find(
"with T = ") + 9;
67 auto E =
P.find(
';',
B);
68 return P.substr(
B,
E -
B);
69#elif defined(_MSC_VER)
72 auto B =
P.find(
"type_name<") + 10;
73 auto E =
P.find(
">(void)",
B);
74 return P.substr(
B,
B -
E);
80 template <
typename RetT,
typename...
ArgT, std::size_t...
I>
81 std::string buildFunctionEntry(std::index_sequence<I...>) {
87 << EntryFuncName <<
"(";
88 ((
OS << (
I ?
", " :
"")
90 std::decay_t<std::tuple_element_t<
I, std::tuple<ArgT...>>>>()
96 ((
ArgList += (
I == 0 ?
"" :
", ") + (
"Arg" + std::to_string(
I))), ...);
98 if constexpr (!std::is_void_v<RetT>) {
101 OS << InstanceName <<
"(";
102 ((
OS << (
I == 0 ?
"" :
", ") <<
"Arg" << std::to_string(
I)), ...);
105 return std::string(
FuncS);
108 template <
typename RetOrSig,
typename...
ArgT> std::string buildCode() {
110 std::index_sequence_for<
ArgT...>{});
113 std::string_view
To) {
117 while ((
Pos =
S.find(
From,
Pos)) != std::string::npos) {
133 template <
typename RetT,
typename...
ArgT>
void compile() {
136 std::make_unique<CppJitModule>(TargetModel,
InstanceCode);
137 InstanceModule->compile();
139 FuncPtr = InstanceModule->Dispatch.getFunctionAddress(
140 EntryFuncName, InstanceModule->ModuleHash,
141 InstanceModule->getLibrary());
144 template <
typename...
ArgT>
147 if (!InstanceModule) {
151 void *
Ptrs[
sizeof...(ArgT)] = {(
void *)&
Args...};
153 return InstanceModule->Dispatch.launch(FuncPtr, GridDim, BlockDim,
Ptrs,
159 static_assert(!std::is_function_v<RetOrSig>,
160 "Function signature type is not yet supported");
162 if (!InstanceModule) {
166 if constexpr (std::is_void_v<RetOrSig>)
169 return InstanceModule->Dispatch.run<
RetOrSig(
ArgT...)>(FuncPtr,
174 struct CompilationResult {
176 std::unique_ptr<LLVMContext> Ctx =
nullptr;
177 std::unique_ptr<Module> Mod =
nullptr;
186 const std::vector<std::string> &ExtraArgs = {});
188 const std::vector<std::string> &ExtraArgs = {});
202 template <
typename...
ArgT>
205 std::string InstanceName =
FuncName.str() +
"<";
208 (
First ?
"" :
",") + std::string(std::forward<ArgT>(
Args)),
214 return CodeInstance{TargetModel, Code, InstanceName};
218 template <
typename RetT,
typename...
ArgT>
223 : M(M), FuncPtr(FuncPtr) {}
226 if constexpr (std::is_void_v<RetT>) {
227 M.Dispatch.template run<RetT(
ArgT...)>(FuncPtr,
228 std::forward<ArgT>(
Args)...);
230 return M.Dispatch.template run<RetT(
ArgT...)>(
231 FuncPtr, std::forward<ArgT>(
Args)...);
237 template <
typename RetT,
typename...
ArgT>
240 void *FuncPtr =
nullptr;
242 : M(M), FuncPtr(FuncPtr) {
243 static_assert(std::is_void_v<RetT>,
"Kernel function must return void");
248 void *
Ptrs[
sizeof...(ArgT)] = {(
void *)&
Args...};
char int void ** Args
Definition CompilerInterfaceHost.cpp:21
#define PROTEUS_FATAL_ERROR(x)
Definition Error.h:7
Definition CppJitModule.hpp:11
FunctionHandle< Sig > getFunction(StringRef Name)
Definition CppJitModule.hpp:254
void compile()
Definition CppJitModule.cpp:226
void compileCppToDynamicLibrary()
Definition CppJitModule.cpp:42
KernelHandle< Sig > getKernel(StringRef Name)
Definition CppJitModule.hpp:266
auto instantiate(StringRef FuncName, ArgT... Args)
Definition CppJitModule.hpp:203
CompiledLibrary & getLibrary()
Definition CppJitModule.hpp:192
CompilationResult compileCppToIR()
Definition CppJitModule.cpp:125
Definition Dispatcher.hpp:54
virtual void * getFunctionAddress(StringRef FunctionName, HashT ModuleHash, CompiledLibrary &Library)=0
virtual DispatchResult launch(void *KernelFunc, LaunchDims GridDim, LaunchDims BlockDim, ArrayRef< void * > KernelArgs, uint64_t ShmemSize, void *Stream)=0
Definition Hashing.hpp:20
Definition BuiltinsCUDA.cpp:4
TargetModelType
Definition TargetModel.hpp:14
static int Pos
Definition JitInterface.hpp:105
bool isHostTargetModel(TargetModelType TargetModel)
Definition TargetModel.hpp:55
T getRuntimeConstantValue(void *Arg)
Definition CompilerInterfaceRuntimeConstantInfo.h:114
Definition Dispatcher.hpp:16
Definition CompiledLibrary.hpp:12
FunctionHandle(CppJitModule &M, void *FuncPtr)
Definition CppJitModule.hpp:222
void * FuncPtr
Definition CppJitModule.hpp:221
CppJitModule & M
Definition CppJitModule.hpp:220
RetT run(ArgT... Args)
Definition CppJitModule.hpp:225
Definition CppJitModule.hpp:217
auto launch(LaunchDims GridDim, LaunchDims BlockDim, uint64_t ShmemSize, void *Stream, ArgT... Args)
Definition CppJitModule.hpp:246
KernelHandle(CppJitModule &M, void *FuncPtr)
Definition CppJitModule.hpp:241
CppJitModule & M
Definition CppJitModule.hpp:239
Definition CppJitModule.hpp:236