1#ifndef PROTEUS_MLIRJITMODULE_H
2#define PROTEUS_MLIRJITMODULE_H
15struct CompiledLibrary;
23 std::unique_ptr<HashT> ModuleHash;
24 std::unique_ptr<CompiledLibrary> Library;
25 bool IsCompiled =
false;
28 void *getFunctionAddress(
const std::string &Name);
31 uint64_t ShmemSize,
void *Stream);
35 explicit MLIRJitModule(
const std::string &Target,
const std::string &Code);
38 void compile(
bool Verify =
false);
51 template <
typename RetT,
typename... ArgT>
57 : M(M), FuncPtr(FuncPtr) {}
60 if constexpr (std::is_void_v<RetT>) {
61 M.Dispatch.template run<RetT(ArgT...)>(FuncPtr,
62 std::forward<ArgT>(
Args)...);
64 return M.Dispatch.template run<RetT(ArgT...)>(
65 FuncPtr, std::forward<ArgT>(
Args)...);
71 template <
typename RetT,
typename... ArgT>
74 void *FuncPtr =
nullptr;
77 : M(M), FuncPtr(FuncPtr) {
78 static_assert(std::is_void_v<RetT>,
"Kernel function must return void");
82 M.setFuncAttribute(FuncPtr, Attr, Value);
86 void *Stream, ArgT...
Args) {
87 void *Ptrs[
sizeof...(ArgT)] = {(
void *)&
Args...};
88 return M.launch(FuncPtr, GridDim, BlockDim, Ptrs, ShmemSize, Stream);
92 template <
typename Sig>
100 void *FuncPtr = getFunctionAddress(Name);
111 void *FuncPtr = getFunctionAddress(Name);
char int void ** Args
Definition CompilerInterfaceHost.cpp:23
Definition Dispatcher.h:75
Definition MLIRJitModule.h:18
KernelHandle< Sig > getKernel(const std::string &Name)
Definition MLIRJitModule.h:104
void compile(bool Verify=false)
Definition MLIRJitModule.cpp:34
FunctionHandle< Sig > getFunction(const std::string &Name)
Definition MLIRJitModule.h:93
CompiledLibrary & getLibrary()
Definition MLIRJitModule.h:40
Definition MemoryCache.h:27
JitFuncAttribute
Definition JitFuncAttribute.h:6
TargetModelType
Definition TargetModel.h:8
bool isHostTargetModel(TargetModelType TargetModel)
Definition TargetModel.cpp:49
void reportFatalError(const llvm::Twine &Reason, const char *FILE, unsigned Line)
Definition Error.cpp:14
Definition Dispatcher.h:22
Definition CompiledLibrary.h:18
Definition Dispatcher.h:53
RetT run(ArgT... Args)
Definition MLIRJitModule.h:59
void * FuncPtr
Definition MLIRJitModule.h:54
MLIRJitModule & M
Definition MLIRJitModule.h:53
FunctionHandle(MLIRJitModule &M, void *FuncPtr)
Definition MLIRJitModule.h:56
Definition MLIRJitModule.h:50
auto launch(LaunchDims GridDim, LaunchDims BlockDim, uint64_t ShmemSize, void *Stream, ArgT... Args)
Definition MLIRJitModule.h:85
MLIRJitModule & M
Definition MLIRJitModule.h:73
void setFuncAttribute(JitFuncAttribute Attr, int Value)
Definition MLIRJitModule.h:81
KernelHandle(MLIRJitModule &M, void *FuncPtr)
Definition MLIRJitModule.h:76
Definition MLIRJitModule.h:70