1#ifndef PROTEUS_JIT_DEV_H
2#define PROTEUS_JIT_DEV_H
// Forward declaration only: this header holds CompiledLibrary through a
// std::unique_ptr member and hands it out by reference, so the full
// definition is not required at this point.
16struct CompiledLibrary;
// Data members of the JIT module class (the class header itself is not part
// of this listing).
//
// CB: uniquely-owned code builder; it is dereferenced and passed to every
// Func constructed by the build*FromArgsList helpers.
20 std::unique_ptr<LLVMCodeBuilder> CB;
// Library: uniquely-owned CompiledLibrary (only forward-declared in this file).
21 std::unique_ptr<CompiledLibrary> Library;
// Functions: receives ownership of the Func objects the builder helpers
// create (they emplace_back a moved unique_ptr into it).
23 std::deque<std::unique_ptr<FuncBase>> Functions;
// ModuleHash: uniquely-owned hash value.
// NOTE(review): presumably what getModuleHash() exposes -- confirm in
// JitFrontend.cpp.
27 std::unique_ptr<HashT> ModuleHash;
// IsCompiled: starts out false.
// NOTE(review): presumably set by compile() and reported by isCompiled();
// neither body is visible here -- confirm.
28 bool IsCompiled =
false;
// Forward declaration of the nested kernel handle template so that
// buildKernelFromArgsList can name it as its return type before the full
// definition appears further down.
30 template <
typename... ArgT>
struct KernelHandle;
// Builds a Func<RetT, ArgT...> called Name and registers it in Functions.
// The declared return type is a reference to the typed Func.
//
// NOTE(review): this listing drops several original lines (the embedded
// numbering jumps 33 -> 36 and stops at 38), so the remaining parameter(s),
// the declaration that receives the make_unique result (used below as
// TypedFn) and the return statement are not visible here.
32 template <
typename RetT,
typename... ArgT>
33 Func<RetT, ArgT...> &buildFuncFromArgsList(
const std::string &Name,
// The new Func is constructed with this module, the code builder, the
// function name and the module's Dispatch object.
36 std::make_unique<
Func<RetT, ArgT...>>(*
this, *CB, Name, Dispatch);
// Bind the typed reference before ownership is transferred: after the move
// on the next line TypedFn no longer owns the object.
37 Func<RetT, ArgT...> &TypedFnRef = *TypedFn;
38 Functions.emplace_back(std::move(TypedFn));
// Builds a kernel: a Func with a fixed void return type and parameters
// ArgT..., wrapped in a KernelHandle that also refers to this module.
//
// NOTE(review): original lines are missing from this listing (numbering jumps
// 44 -> 47, 48 -> 51, 52 -> 57), so the rest of the parameter list, the
// declaration of TypedFn, the code that uses Fn, and the end of the #if
// region are not visible here.
43 template <
typename... ArgT>
44 KernelHandle<ArgT...> buildKernelFromArgsList(
const std::string &Name,
47 std::make_unique<
Func<void, ArgT...>>(*
this, *CB, Name, Dispatch);
// Keep a typed reference; ownership of TypedFn is moved away below.
48 Func<void, ArgT...> &TypedFnRef = *TypedFn;
// CUDA/HIP builds: move the Func into Functions and keep a reference to the
// stored slot (relies on emplace_back returning a reference).
51#if PROTEUS_ENABLE_CUDA || PROTEUS_ENABLE_HIP
52 std::unique_ptr<FuncBase> &Fn = Functions.emplace_back(std::move(TypedFn));
// The handle is aggregate-initialised from the typed Func reference and this
// module.
57 return KernelHandle<ArgT...>{TypedFnRef, *
this};
// Handle returned for a kernel: refers to the kernel's Func<void, ArgT...>
// (F) and to the owning module (used below as M).
//
// NOTE(review): this listing omits many original lines of the struct (for
// example the declaration of M, the bodies of the guarded branches, and the
// start of launch's signature). Comments describe only what is visible.
60 template <
typename... ArgT>
struct KernelHandle {
61 Func<void, ArgT...> &F;
// Sets launch bounds for the kernel. Both parameters are [[maybe_unused]]
// because their only visible use sits inside the CUDA/HIP-only region.
// MinBlocksPerSM defaults to 0.
// NOTE(review): the action taken when M is not a device module is on an
// omitted line -- confirm.
64 void setLaunchBounds([[maybe_unused]]
int MaxThreadsPerBlock,
65 [[maybe_unused]]
int MinBlocksPerSM = 0) {
66 if (!M.isDeviceModule())
72#if PROTEUS_ENABLE_CUDA || PROTEUS_ENABLE_HIP
73 F.setLaunchBoundsForKernel(MaxThreadsPerBlock, MinBlocksPerSM);
// Tail of launch's parameter list: kernel arguments are taken by value.
81 uint64_t ShmemBytes,
void *Stream, ArgT...
Args) {
// One pointer per argument, each pointing at the by-value copy above, so
// the pointers are valid for the duration of this call.
// NOTE(review): for an empty ArgT pack this declares a zero-length array,
// which ISO C++ does not allow (GCC/Clang accept it as an extension) --
// confirm kernels without parameters are supported.
83 void *Ptrs[
sizeof...(ArgT)] = {(
void *)&
Args...};
// Resolves the compiled kernel entry point, consulting the value cached in F
// first. On a miss the address is looked up (the lookup call is on omitted
// lines), cast to the type F.getCompiledFunc() yields, and stored back in F.
88 auto GetKernelFunc = [&]() {
91 if (
auto KernelFunc = F.getCompiledFunc()) {
98 auto KernelFunc =
reinterpret_cast<decltype(F.getCompiledFunc())
>(
102 F.setCompiledFunc(KernelFunc);
// Hand the entry point (as void *), launch dimensions, argument pointer
// array, shared-memory size and stream to the module's dispatcher; its result
// is returned to the caller.
107 return M.Dispatch.
launch(
reinterpret_cast<void *
>(GetKernelFunc()), Grid,
108 Block, Ptrs, ShmemBytes, Stream);
// Gives access to the underlying Func through its FuncBase interface.
111 FuncBase *operator->() {
return &F; }
// Predicate used as a guard (`!isDeviceModule()`) by addKernel and, through
// M, by KernelHandle::setLaunchBounds.
// NOTE(review): the body is not part of this listing; presumably true when
// the module targets a GPU model -- confirm.
114 bool isDeviceModule() {
// Constructor declaration. Target is a string that defaults to "host" when
// the caller passes nothing.
// NOTE(review): the accepted target strings are not visible in this listing
// -- confirm against the definition.
120 JitModule(
const std::string &Target =
"host");
// Adds a function called Name whose type is described by the template
// parameter Sig and returns a reference to it (return type is `auto &`).
//
// NOTE(review): the lines that derive RetT and ArgT from Sig, and any checks
// made before building, are omitted from this listing (numbering jumps
// 130 -> 137).
130 template <
typename Sig>
auto &
addFunction(
const std::string &Name) {
// A value-initialised ArgT object is passed as the second argument; RetT is
// given explicitly.
137 return buildFuncFromArgsList<RetT>(Name, ArgT{});
// Adds a kernel called Name whose type is described by the template parameter
// Sig and returns whatever buildKernelFromArgsList produces.
//
// NOTE(review): the lines that derive RetT and ArgT from Sig are omitted from
// this listing, as is the statement executed when the module is not a device
// module (numbering jumps 150 -> 153) -- confirm.
142 template <
typename Sig>
auto addKernel(
const std::string &Name) {
// Compile-time check: a kernel signature must return void.
144 static_assert(std::is_void_v<RetT>,
"Kernels must have void return type");
// Runtime guard: the following (omitted) statement runs only for non-device
// modules.
150 if (!isDeviceModule())
153 return buildKernelFromArgsList(Name, ArgT{});
// Declaration only; per the cross-reference index at the end of this listing
// the definition is at JitFrontend.cpp:23. Verify defaults to false.
// NOTE(review): presumably Verify enables a verification pass before code
// generation -- confirm against the definition.
156 void compile(
bool Verify =
false);
// Out-of-class member template definition. The signature line is omitted
// from this listing; the cross-reference index at the end gives
// `RetT operator()(ArgT... Args)` at JitFrontend.h:178, presumably a member
// of Func<RetT, ArgT...> -- confirm.
//
// NOTE(review): the conditions guarding the statements below (when the
// address is looked up, when the fatal-error message is used) are on omitted
// lines.
177template <
typename RetT,
typename... ArgT>
// Resolve the function's address through the module's dispatcher, keyed by
// this function's name and the module hash, and store it in CompiledFunc
// after casting to CompiledFunc's own type.
183 CompiledFunc =
reinterpret_cast<decltype(CompiledFunc)
>(
184 J.getDispatcher().getFunctionAddress(getName(), J.getModuleHash(),
// Error text for calling a function directly on a GPU target; kernels are
// meant to go through launch().
190 "Target is a GPU model, cannot directly run functions, use launch()");
// void and non-void return types take separate paths at compile time: the
// void path calls Dispatch.run without returning a value, the other path
// returns its result.
192 if constexpr (std::is_void_v<RetT>)
193 Dispatch.run<RetT(ArgT...)>(
reinterpret_cast<void *
>(CompiledFunc),
196 return Dispatch.run<RetT(ArgT...)>(
reinterpret_cast<void *
>(CompiledFunc),
char int void ** Args
Definition CompilerInterfaceHost.cpp:22
Definition Dispatcher.h:74
virtual DispatchResult launch(void *KernelFunc, LaunchDims GridDim, LaunchDims BlockDim, void *KernelArgs[], uint64_t ShmemSize, void *Stream)=0
virtual void * getFunctionAddress(const std::string &FunctionName, const HashT &ModuleHash, CompiledLibrary &Library)=0
RetT operator()(ArgT... Args)
Definition JitFrontend.h:178
void declArgs()
Definition Func.h:328
Definition JitFrontend.h:18
JitModule & operator=(JitModule &&)=delete
JitModule(const JitModule &)=delete
bool isCompiled() const
Definition JitFrontend.h:140
TargetModelType getTargetModel() const
Definition JitFrontend.h:162
auto addKernel(const std::string &Name)
Definition JitFrontend.h:142
JitModule(JitModule &&)=delete
JitModule & operator=(const JitModule &)=delete
CompiledLibrary & getLibrary()
Definition JitFrontend.h:164
Dispatcher & getDispatcher() const
Definition JitFrontend.h:160
auto & addFunction(const std::string &Name)
Definition JitFrontend.h:130
const HashT & getModuleHash() const
Definition JitFrontend.cpp:61
void print()
Definition JitFrontend.cpp:57
void compile(bool Verify=false)
Definition JitFrontend.cpp:23
Definition MemoryCache.h:26
TargetModelType
Definition TargetModel.h:8
void reportFatalError(const llvm::Twine &Reason, const char *FILE, unsigned Line)
Definition Error.cpp:14
Definition Dispatcher.h:21
Definition CompiledLibrary.h:18