Proteus
Programmable JIT compilation and optimization for C/C++ using LLVM
Loading...
Searching...
No Matches
JitFrontend.h
Go to the documentation of this file.
1#ifndef PROTEUS_JIT_DEV_H
2#define PROTEUS_JIT_DEV_H
3
4#include "proteus/Error.h"
10#include "proteus/Init.h"
11
12#include <deque>
13#include <type_traits>
14
15namespace proteus {
16
17struct CompiledLibrary;
18
19class JitModule {
20private:
21 std::unique_ptr<CodeBuilder> CB;
22 std::unique_ptr<CompiledLibrary> Library;
23
24 std::deque<std::unique_ptr<FuncBase>> Functions;
25 TargetModelType TargetModel;
26 Dispatcher &Dispatch;
27
28 std::unique_ptr<HashT> ModuleHash;
29 bool IsCompiled = false;
30
31 template <typename... ArgT> struct KernelHandle;
32
33 template <typename RetT, typename... ArgT>
34 Func<RetT, ArgT...> &buildFuncFromArgsList(const std::string &Name,
36 auto TypedFn =
37 std::make_unique<Func<RetT, ArgT...>>(*this, *CB, Name, Dispatch,
38 /*IsKernel=*/false);
39 Func<RetT, ArgT...> &TypedFnRef = *TypedFn;
40 Functions.emplace_back(std::move(TypedFn));
41 TypedFnRef.declArgs();
42 return TypedFnRef;
43 }
44
45 template <typename... ArgT>
46 KernelHandle<ArgT...> buildKernelFromArgsList(const std::string &Name,
48 auto TypedFn =
49 std::make_unique<Func<void, ArgT...>>(*this, *CB, Name, Dispatch,
50 /*IsKernel=*/true);
51 Func<void, ArgT...> &TypedFnRef = *TypedFn;
52 TypedFn->declArgs();
53
54 Functions.emplace_back(std::move(TypedFn));
55 return KernelHandle<ArgT...>{TypedFnRef, *this};
56 }
57
58 template <typename... ArgT> struct KernelHandle {
59 Func<void, ArgT...> &F;
60 JitModule &M;
61
62 void setLaunchBounds([[maybe_unused]] int MaxThreadsPerBlock,
63 [[maybe_unused]] int MinBlocksPerSM = 0) {
64 if (!M.isDeviceModule())
65 reportFatalError("Expected a device module for setLaunchBounds");
66
67 if (M.isCompiled())
68 reportFatalError("setLaunchBounds must be called before compile()");
69
70#if PROTEUS_ENABLE_CUDA || PROTEUS_ENABLE_HIP
71 F.setLaunchBoundsForKernel(MaxThreadsPerBlock, MinBlocksPerSM);
72#else
73 reportFatalError("Unsupported target for setLaunchBounds");
74#endif
75 }
76
77 // Launch with type-safety.
78 [[nodiscard]] auto launch(LaunchDims Grid, LaunchDims Block,
79 uint64_t ShmemBytes, void *Stream, ArgT... Args) {
80 // Pointers to the local parameter copies.
81 void *Ptrs[sizeof...(ArgT)] = {(void *)&Args...};
82
83 if (!M.isCompiled())
84 M.compile();
85
86 auto GetKernelFunc = [&]() {
87 // Get the kernel func pointer directly from the Func object if
88 // available.
89 if (auto KernelFunc = F.getCompiledFunc()) {
90 return KernelFunc;
91 }
92
93 // Get the kernel func pointer from the Dispatch and store it to the
94 // Func object to avoid cache lookups.
95 // TODO: Re-think caching and dispatchers.
96 auto KernelFunc = reinterpret_cast<decltype(F.getCompiledFunc())>(
97 M.Dispatch.getFunctionAddress(F.getName(), M.getModuleHash(),
98 M.getLibrary()));
99
100 F.setCompiledFunc(KernelFunc);
101
102 return KernelFunc;
103 };
104
105 return M.Dispatch.launch(reinterpret_cast<void *>(GetKernelFunc()), Grid,
106 Block, Ptrs, ShmemBytes, Stream);
107 }
108
109 FuncBase *operator->() { return &F; }
110 };
111
112 bool isDeviceModule() {
113 return ((TargetModel == TargetModelType::CUDA) ||
114 (TargetModel == TargetModelType::HIP));
115 }
116
117public:
118 JitModule(const std::string &Target = "host",
119 const std::string &Backend = "llvm");
120
121 // Disable copy and move constructors.
122 JitModule(const JitModule &) = delete;
123 JitModule &operator=(const JitModule &) = delete;
124 JitModule(JitModule &&) = delete;
126
128
129 template <typename Sig> auto &addFunction(const std::string &Name) {
130 using RetT = typename FnSig<Sig>::RetT;
131 using ArgT = typename FnSig<Sig>::ArgsTList;
132
133 if (IsCompiled)
134 reportFatalError("The module is compiled, no further code can be added");
135
136 return buildFuncFromArgsList<RetT>(Name, ArgT{});
137 }
138
139 bool isCompiled() const { return IsCompiled; }
140
141 template <typename Sig> auto addKernel(const std::string &Name) {
142 using RetT = typename FnSig<Sig>::RetT;
143 static_assert(std::is_void_v<RetT>, "Kernels must have void return type");
144 using ArgT = typename FnSig<Sig>::ArgsTList;
145
146 if (IsCompiled)
147 reportFatalError("The module is compiled, no further code can be added");
148
149 if (!isDeviceModule())
150 reportFatalError("Expected a device module for addKernel");
151
152 return buildKernelFromArgsList(Name, ArgT{});
153 }
154
155 void compile(bool Verify = false);
156
157 const HashT &getModuleHash() const;
158
159 Dispatcher &getDispatcher() const { return Dispatch; }
160
161 TargetModelType getTargetModel() const { return TargetModel; }
162
164 if (!IsCompiled)
165 compile();
166
167 if (!Library)
168 reportFatalError("Expected non-null library after compilation");
169
170 return *Library;
171 }
172
173 void print();
174 void printLLVMIR();
175};
176
177template <typename RetT, typename... ArgT>
179 if (!J.isCompiled())
180 J.compile();
181
182 if (!CompiledFunc) {
183 CompiledFunc = reinterpret_cast<decltype(CompiledFunc)>(
184 J.getDispatcher().getFunctionAddress(getName(), J.getModuleHash(),
185 J.getLibrary()));
186 }
187
188 if (J.getTargetModel() != TargetModelType::HOST)
190 "Target is a GPU model, cannot directly run functions, use launch()");
191
192 if constexpr (std::is_void_v<RetT>)
193 Dispatch.run<RetT(ArgT...)>(reinterpret_cast<void *>(CompiledFunc),
194 Args...);
195 else
196 return Dispatch.run<RetT(ArgT...)>(reinterpret_cast<void *>(CompiledFunc),
197 Args...);
198}
199
200} // namespace proteus
201
202#endif
char int void ** Args
Definition CompilerInterfaceHost.cpp:23
Definition Dispatcher.h:75
virtual DispatchResult launch(void *KernelFunc, LaunchDims GridDim, LaunchDims BlockDim, void *KernelArgs[], uint64_t ShmemSize, void *Stream)=0
virtual void * getFunctionAddress(const std::string &FunctionName, const HashT &ModuleHash, CompiledLibrary &Library)=0
Definition Func.h:45
Definition Func.h:296
RetT operator()(ArgT... Args)
Definition JitFrontend.h:178
void declArgs()
Definition Func.h:335
Definition Hashing.h:22
Definition JitFrontend.h:19
void printLLVMIR()
Definition JitFrontend.cpp:141
JitModule & operator=(JitModule &&)=delete
JitModule(const JitModule &)=delete
bool isCompiled() const
Definition JitFrontend.h:139
TargetModelType getTargetModel() const
Definition JitFrontend.h:161
auto addKernel(const std::string &Name)
Definition JitFrontend.h:141
JitModule(JitModule &&)=delete
JitModule & operator=(const JitModule &)=delete
CompiledLibrary & getLibrary()
Definition JitFrontend.h:163
Dispatcher & getDispatcher() const
Definition JitFrontend.h:159
auto & addFunction(const std::string &Name)
Definition JitFrontend.h:129
const HashT & getModuleHash() const
Definition JitFrontend.cpp:163
void print()
Definition JitFrontend.cpp:124
void compile(bool Verify=false)
Definition JitFrontend.cpp:36
Definition MemoryCache.h:27
TargetModelType
Definition TargetModel.h:8
void reportFatalError(const llvm::Twine &Reason, const char *FILE, unsigned Line)
Definition Error.cpp:14
Definition Dispatcher.h:22
Definition Func.h:24
Definition CompiledLibrary.h:18
Definition Func.h:25