Proteus
Programmable JIT compilation and optimization for C/C++ using LLVM
Loading...
Searching...
No Matches
MLIRJitModule.h
Go to the documentation of this file.
1#ifndef PROTEUS_MLIRJITMODULE_H
2#define PROTEUS_MLIRJITMODULE_H
3
7
8#include <memory>
9#include <string>
10#include <type_traits>
11#include <utility>
12
13namespace proteus {
14
15struct CompiledLibrary;
16class HashT;
17
19private:
20 TargetModelType TargetModel;
21 std::string Code;
22 Dispatcher &Dispatch;
23 std::unique_ptr<HashT> ModuleHash;
24 std::unique_ptr<CompiledLibrary> Library;
25 bool IsCompiled = false;
26
27public:
28 explicit MLIRJitModule(TargetModelType TargetModel, const std::string &Code);
29 explicit MLIRJitModule(const std::string &Target, const std::string &Code);
31
32 void compile(bool Verify = false);
33
34 // Expose the target model so higher-level bindings can reject invalid API
35 // combinations before dispatch.
36 TargetModelType getTargetModel() const { return TargetModel; }
37
38 void setFuncAttribute(void *KernelFunc, JitFuncAttribute Attr, int Value);
39 void *getFunctionAddress(const std::string &Name);
40 void *getKernelAddress(const std::string &Name) {
41 if (!IsCompiled)
42 compile();
43
44 if (TargetModel == TargetModelType::HOST)
46 "Error: getKernelAddress() applies only to device modules");
47
48 // Kernel symbols are loaded through the same compiled image as host
49 // function symbols once target validation is complete.
50 return getFunctionAddress(Name);
51 }
52
53 DispatchResult launch(void *KernelFunc, LaunchDims GridDim,
54 LaunchDims BlockDim, void *KernelArgs[],
55 uint64_t ShmemSize, void *Stream);
56
58 if (!IsCompiled)
59 compile();
60
61 if (!Library)
62 reportFatalError("Expected non-null library after compilation");
63
64 return *Library;
65 }
66
67 template <typename Sig> struct FunctionHandle;
68 template <typename RetT, typename... ArgT>
69 struct FunctionHandle<RetT(ArgT...)> {
71 void *FuncPtr;
72
73 explicit FunctionHandle(MLIRJitModule &M, void *FuncPtr)
74 : M(M), FuncPtr(FuncPtr) {}
75
76 RetT run(ArgT... Args) {
77 if constexpr (std::is_void_v<RetT>) {
78 M.Dispatch.template run<RetT(ArgT...)>(FuncPtr,
79 std::forward<ArgT>(Args)...);
80 } else {
81 return M.Dispatch.template run<RetT(ArgT...)>(
82 FuncPtr, std::forward<ArgT>(Args)...);
83 }
84 }
85 };
86
87 template <typename Sig> struct KernelHandle;
88 template <typename RetT, typename... ArgT>
89 struct KernelHandle<RetT(ArgT...)> {
91 void *FuncPtr = nullptr;
92
93 explicit KernelHandle(MLIRJitModule &M, void *FuncPtr)
94 : M(M), FuncPtr(FuncPtr) {
95 static_assert(std::is_void_v<RetT>, "Kernel function must return void");
96 }
97
98 void setFuncAttribute(JitFuncAttribute Attr, int Value) {
99 M.setFuncAttribute(FuncPtr, Attr, Value);
100 }
101
102 auto launch(LaunchDims GridDim, LaunchDims BlockDim, uint64_t ShmemSize,
103 void *Stream, ArgT... Args) {
104 void *Ptrs[sizeof...(ArgT)] = {(void *)&Args...};
105 return M.launch(FuncPtr, GridDim, BlockDim, Ptrs, ShmemSize, Stream);
106 }
107 };
108
109 template <typename Sig>
110 FunctionHandle<Sig> getFunction(const std::string &Name) {
111 if (!IsCompiled)
112 compile();
113
114 if (!isHostTargetModel(TargetModel))
115 reportFatalError("Error: getFunction() applies only to host modules");
116
117 void *FuncPtr = getFunctionAddress(Name);
118 return FunctionHandle<Sig>(*this, FuncPtr);
119 }
120
121 template <typename Sig> KernelHandle<Sig> getKernel(const std::string &Name) {
122 if (!IsCompiled)
123 compile();
124
125 if (TargetModel == TargetModelType::HOST)
126 reportFatalError("Error: getKernel() applies only to device modules");
127
128 void *FuncPtr = getFunctionAddress(Name);
129 return KernelHandle<Sig>(*this, FuncPtr);
130 }
131};
132
133} // namespace proteus
134
135#endif // PROTEUS_MLIRJITMODULE_H
char int void ** Args
Definition CompilerInterfaceHost.cpp:23
Definition Dispatcher.h:75
Definition MLIRJitModule.h:18
void setFuncAttribute(void *KernelFunc, JitFuncAttribute Attr, int Value)
Definition MLIRJitModule.cpp:71
KernelHandle< Sig > getKernel(const std::string &Name)
Definition MLIRJitModule.h:121
void compile(bool Verify=false)
Definition MLIRJitModule.cpp:34
FunctionHandle< Sig > getFunction(const std::string &Name)
Definition MLIRJitModule.h:110
DispatchResult launch(void *KernelFunc, LaunchDims GridDim, LaunchDims BlockDim, void *KernelArgs[], uint64_t ShmemSize, void *Stream)
Definition MLIRJitModule.cpp:83
CompiledLibrary & getLibrary()
Definition MLIRJitModule.h:57
TargetModelType getTargetModel() const
Definition MLIRJitModule.h:36
void * getKernelAddress(const std::string &Name)
Definition MLIRJitModule.h:40
void * getFunctionAddress(const std::string &Name)
Definition MLIRJitModule.cpp:76
Definition MemoryCache.h:27
JitFuncAttribute
Definition JitFuncAttribute.h:6
TargetModelType
Definition TargetModel.h:8
bool isHostTargetModel(TargetModelType TargetModel)
Definition TargetModel.cpp:49
void reportFatalError(const llvm::Twine &Reason, const char *FILE, unsigned Line)
Definition Error.cpp:14
Definition Dispatcher.h:22
Definition CompiledLibrary.h:18
Definition Dispatcher.h:53
RetT run(ArgT... Args)
Definition MLIRJitModule.h:76
void * FuncPtr
Definition MLIRJitModule.h:71
MLIRJitModule & M
Definition MLIRJitModule.h:70
FunctionHandle(MLIRJitModule &M, void *FuncPtr)
Definition MLIRJitModule.h:73
Definition MLIRJitModule.h:67
auto launch(LaunchDims GridDim, LaunchDims BlockDim, uint64_t ShmemSize, void *Stream, ArgT... Args)
Definition MLIRJitModule.h:102
MLIRJitModule & M
Definition MLIRJitModule.h:90
void setFuncAttribute(JitFuncAttribute Attr, int Value)
Definition MLIRJitModule.h:98
KernelHandle(MLIRJitModule &M, void *FuncPtr)
Definition MLIRJitModule.h:93
Definition MLIRJitModule.h:87