Proteus
Programmable JIT compilation and optimization for C/C++ using LLVM
Loading...
Searching...
No Matches
MLIRJitModule.h
Go to the documentation of this file.
1#ifndef PROTEUS_MLIRJITMODULE_H
2#define PROTEUS_MLIRJITMODULE_H
3
7
8#include <memory>
9#include <string>
10#include <type_traits>
11#include <utility>
12
13namespace proteus {
14
15struct CompiledLibrary;
16class HashT;
17
19private:
/// Target model of this module. getFunction() accepts only models for which
/// isHostTargetModel() is true; getKernel() rejects TargetModelType::HOST.
20 TargetModelType TargetModel;
/// Source code supplied to the constructors.
21 std::string Code;
/// Dispatcher that FunctionHandle::run goes through to call host functions.
/// Held by reference: NOTE(review) ownership/lifetime is not visible in this
/// header — the referenced Dispatcher presumably must outlive the module.
22 Dispatcher &Dispatch;
/// NOTE(review): how this hash is computed and used is not visible here;
/// presumably it identifies the module (e.g. for caching) — confirm in the .cpp.
23 std::unique_ptr<HashT> ModuleHash;
/// Compiled artifact; getLibrary() reports a fatal error if it is still null
/// after compilation.
24 std::unique_ptr<CompiledLibrary> Library;
/// Checked by getLibrary()/getFunction()/getKernel(), which call compile()
/// while it is false (presumably compile() sets it — confirm in the .cpp).
25 bool IsCompiled = false;
26
/// Backend for KernelHandle::setFuncAttribute.
27 void setFuncAttribute(void *KernelFunc, JitFuncAttribute Attr, int Value);
/// Resolve a symbol by name; used by getFunction() and getKernel().
28 void *getFunctionAddress(const std::string &Name);
/// Backend for KernelHandle::launch; KernelArgs holds one pointer per kernel
/// argument value.
29 DispatchResult launch(void *KernelFunc, LaunchDims GridDim,
30 LaunchDims BlockDim, void *KernelArgs[],
31 uint64_t ShmemSize, void *Stream);
32
33public:
/// Construct from an explicit TargetModelType and source code.
34 explicit MLIRJitModule(TargetModelType TargetModel, const std::string &Code);
/// Construct from a target given as a string and source code.
35 explicit MLIRJitModule(const std::string &Target, const std::string &Code);
/// NOTE(review): original header line 36 is missing from this rendering.
/// Possibly the destructor, which would need an out-of-line definition because
/// HashT and CompiledLibrary are only forward-declared but held in unique_ptr.
37
/// Compile the module. Verify defaults to false; its exact effect is
/// implemented outside this header (MLIRJitModule.cpp).
38 void compile(bool Verify = false);
39
41 if (!IsCompiled)
42 compile();
43
44 if (!Library)
45 reportFatalError("Expected non-null library after compilation");
46
47 return *Library;
48 }
49
50 template <typename Sig> struct FunctionHandle;
51 template <typename RetT, typename... ArgT>
52 struct FunctionHandle<RetT(ArgT...)> {
54 void *FuncPtr;
55
56 explicit FunctionHandle(MLIRJitModule &M, void *FuncPtr)
57 : M(M), FuncPtr(FuncPtr) {}
58
59 RetT run(ArgT... Args) {
60 if constexpr (std::is_void_v<RetT>) {
61 M.Dispatch.template run<RetT(ArgT...)>(FuncPtr,
62 std::forward<ArgT>(Args)...);
63 } else {
64 return M.Dispatch.template run<RetT(ArgT...)>(
65 FuncPtr, std::forward<ArgT>(Args)...);
66 }
67 }
68 };
69
70 template <typename Sig> struct KernelHandle;
71 template <typename RetT, typename... ArgT>
72 struct KernelHandle<RetT(ArgT...)> {
74 void *FuncPtr = nullptr;
75
76 explicit KernelHandle(MLIRJitModule &M, void *FuncPtr)
77 : M(M), FuncPtr(FuncPtr) {
78 static_assert(std::is_void_v<RetT>, "Kernel function must return void");
79 }
80
81 void setFuncAttribute(JitFuncAttribute Attr, int Value) {
82 M.setFuncAttribute(FuncPtr, Attr, Value);
83 }
84
85 auto launch(LaunchDims GridDim, LaunchDims BlockDim, uint64_t ShmemSize,
86 void *Stream, ArgT... Args) {
87 void *Ptrs[sizeof...(ArgT)] = {(void *)&Args...};
88 return M.launch(FuncPtr, GridDim, BlockDim, Ptrs, ShmemSize, Stream);
89 }
90 };
91
92 template <typename Sig>
93 FunctionHandle<Sig> getFunction(const std::string &Name) {
94 if (!IsCompiled)
95 compile();
96
97 if (!isHostTargetModel(TargetModel))
98 reportFatalError("Error: getFunction() applies only to host modules");
99
100 void *FuncPtr = getFunctionAddress(Name);
101 return FunctionHandle<Sig>(*this, FuncPtr);
102 }
103
104 template <typename Sig> KernelHandle<Sig> getKernel(const std::string &Name) {
105 if (!IsCompiled)
106 compile();
107
108 if (TargetModel == TargetModelType::HOST)
109 reportFatalError("Error: getKernel() applies only to device modules");
110
111 void *FuncPtr = getFunctionAddress(Name);
112 return KernelHandle<Sig>(*this, FuncPtr);
113 }
114};
115
116} // namespace proteus
117
118#endif // PROTEUS_MLIRJITMODULE_H
char int void ** Args
Definition CompilerInterfaceHost.cpp:23
Definition Dispatcher.h:75
Definition MLIRJitModule.h:18
KernelHandle< Sig > getKernel(const std::string &Name)
Definition MLIRJitModule.h:104
void compile(bool Verify=false)
Definition MLIRJitModule.cpp:34
FunctionHandle< Sig > getFunction(const std::string &Name)
Definition MLIRJitModule.h:93
CompiledLibrary & getLibrary()
Definition MLIRJitModule.h:40
Definition MemoryCache.h:27
JitFuncAttribute
Definition JitFuncAttribute.h:6
TargetModelType
Definition TargetModel.h:8
bool isHostTargetModel(TargetModelType TargetModel)
Definition TargetModel.cpp:49
void reportFatalError(const llvm::Twine &Reason, const char *FILE, unsigned Line)
Definition Error.cpp:14
Definition Dispatcher.h:22
Definition CompiledLibrary.h:18
Definition Dispatcher.h:53
RetT run(ArgT... Args)
Definition MLIRJitModule.h:59
void * FuncPtr
Definition MLIRJitModule.h:54
MLIRJitModule & M
Definition MLIRJitModule.h:53
FunctionHandle(MLIRJitModule &M, void *FuncPtr)
Definition MLIRJitModule.h:56
Definition MLIRJitModule.h:50
auto launch(LaunchDims GridDim, LaunchDims BlockDim, uint64_t ShmemSize, void *Stream, ArgT... Args)
Definition MLIRJitModule.h:85
MLIRJitModule & M
Definition MLIRJitModule.h:73
void setFuncAttribute(JitFuncAttribute Attr, int Value)
Definition MLIRJitModule.h:81
KernelHandle(MLIRJitModule &M, void *FuncPtr)
Definition MLIRJitModule.h:76
Definition MLIRJitModule.h:70