1#ifndef PROTEUS_JIT_DEV_H
2#define PROTEUS_JIT_DEV_H
17struct CompiledLibrary;
21 std::unique_ptr<CodeBuilder> CB;
22 std::unique_ptr<CompiledLibrary> Library;
24 std::deque<std::unique_ptr<FuncBase>> Functions;
28 std::unique_ptr<HashT> ModuleHash;
29 bool IsCompiled =
false;
31 template <
typename... ArgT>
struct KernelHandle;
33 template <
typename RetT,
typename... ArgT>
34 Func<RetT, ArgT...> &buildFuncFromArgsList(
const std::string &Name,
37 std::make_unique<
Func<RetT, ArgT...>>(*
this, *CB, Name, Dispatch,
39 Func<RetT, ArgT...> &TypedFnRef = *TypedFn;
40 Functions.emplace_back(std::move(TypedFn));
45 template <
typename... ArgT>
46 KernelHandle<ArgT...> buildKernelFromArgsList(
const std::string &Name,
49 std::make_unique<
Func<void, ArgT...>>(*
this, *CB, Name, Dispatch,
51 Func<void, ArgT...> &TypedFnRef = *TypedFn;
54 Functions.emplace_back(std::move(TypedFn));
55 return KernelHandle<ArgT...>{TypedFnRef, *
this};
58 template <
typename... ArgT>
struct KernelHandle {
59 Func<void, ArgT...> &F;
62 void setLaunchBounds([[maybe_unused]]
int MaxThreadsPerBlock,
63 [[maybe_unused]]
int MinBlocksPerSM = 0) {
64 if (!M.isDeviceModule())
70#if PROTEUS_ENABLE_CUDA || PROTEUS_ENABLE_HIP
71 F.setLaunchBoundsForKernel(MaxThreadsPerBlock, MinBlocksPerSM);
79 uint64_t ShmemBytes,
void *Stream, ArgT...
Args) {
81 void *Ptrs[
sizeof...(ArgT)] = {(
void *)&
Args...};
86 auto GetKernelFunc = [&]() {
89 if (
auto KernelFunc = F.getCompiledFunc()) {
96 auto KernelFunc =
reinterpret_cast<decltype(F.getCompiledFunc())
>(
100 F.setCompiledFunc(KernelFunc);
105 return M.Dispatch.
launch(
reinterpret_cast<void *
>(GetKernelFunc()), Grid,
106 Block, Ptrs, ShmemBytes, Stream);
109 FuncBase *operator->() {
return &F; }
112 bool isDeviceModule() {
118 JitModule(
const std::string &Target =
"host",
119 const std::string &Backend =
"llvm");
129 template <
typename Sig>
auto &
addFunction(
const std::string &Name) {
136 return buildFuncFromArgsList<RetT>(Name, ArgT{});
141 template <
typename Sig>
auto addKernel(
const std::string &Name) {
143 static_assert(std::is_void_v<RetT>,
"Kernels must have void return type");
149 if (!isDeviceModule())
152 return buildKernelFromArgsList(Name, ArgT{});
155 void compile(
bool Verify =
false);
177template <
typename RetT,
typename... ArgT>
183 CompiledFunc =
reinterpret_cast<decltype(CompiledFunc)
>(
184 J.getDispatcher().getFunctionAddress(getName(), J.getModuleHash(),
190 "Target is a GPU model, cannot directly run functions, use launch()");
192 if constexpr (std::is_void_v<RetT>)
193 Dispatch.run<RetT(ArgT...)>(
reinterpret_cast<void *
>(CompiledFunc),
196 return Dispatch.run<RetT(ArgT...)>(
reinterpret_cast<void *
>(CompiledFunc),
char int void ** Args
Definition CompilerInterfaceHost.cpp:23
Definition Dispatcher.h:75
virtual DispatchResult launch(void *KernelFunc, LaunchDims GridDim, LaunchDims BlockDim, void *KernelArgs[], uint64_t ShmemSize, void *Stream)=0
virtual void * getFunctionAddress(const std::string &FunctionName, const HashT &ModuleHash, CompiledLibrary &Library)=0
RetT operator()(ArgT... Args)
Definition JitFrontend.h:178
void declArgs()
Definition Func.h:335
Definition JitFrontend.h:19
void printLLVMIR()
Definition JitFrontend.cpp:141
JitModule & operator=(JitModule &&)=delete
JitModule(const JitModule &)=delete
bool isCompiled() const
Definition JitFrontend.h:139
TargetModelType getTargetModel() const
Definition JitFrontend.h:161
auto addKernel(const std::string &Name)
Definition JitFrontend.h:141
JitModule(JitModule &&)=delete
JitModule & operator=(const JitModule &)=delete
CompiledLibrary & getLibrary()
Definition JitFrontend.h:163
Dispatcher & getDispatcher() const
Definition JitFrontend.h:159
auto & addFunction(const std::string &Name)
Definition JitFrontend.h:129
const HashT & getModuleHash() const
Definition JitFrontend.cpp:163
void print()
Definition JitFrontend.cpp:124
void compile(bool Verify=false)
Definition JitFrontend.cpp:36
Definition MemoryCache.h:27
TargetModelType
Definition TargetModel.h:8
void reportFatalError(const llvm::Twine &Reason, const char *FILE, unsigned Line)
Definition Error.cpp:14
Definition Dispatcher.h:22
Definition CompiledLibrary.h:18