Proteus — Programmable JIT compilation and optimization for C/C++ using LLVM
JitFrontend.h — generated documentation page (source listing follows).
1#ifndef PROTEUS_JIT_DEV_H
2#define PROTEUS_JIT_DEV_H
3
4#include "proteus/Error.h"
9#include "proteus/Init.h"
10
11#include <deque>
12#include <type_traits>
13
14namespace proteus {
15
16struct CompiledLibrary;
17
18class JitModule {
19private:
20 std::unique_ptr<LLVMCodeBuilder> CB;
21 std::unique_ptr<CompiledLibrary> Library;
22
23 std::deque<std::unique_ptr<FuncBase>> Functions;
24 TargetModelType TargetModel;
25 Dispatcher &Dispatch;
26
27 std::unique_ptr<HashT> ModuleHash;
28 bool IsCompiled = false;
29
30 template <typename... ArgT> struct KernelHandle;
31
32 template <typename RetT, typename... ArgT>
33 Func<RetT, ArgT...> &buildFuncFromArgsList(const std::string &Name,
35 auto TypedFn =
36 std::make_unique<Func<RetT, ArgT...>>(*this, *CB, Name, Dispatch);
37 Func<RetT, ArgT...> &TypedFnRef = *TypedFn;
38 Functions.emplace_back(std::move(TypedFn));
39 TypedFnRef.declArgs();
40 return TypedFnRef;
41 }
42
43 template <typename... ArgT>
44 KernelHandle<ArgT...> buildKernelFromArgsList(const std::string &Name,
46 auto TypedFn =
47 std::make_unique<Func<void, ArgT...>>(*this, *CB, Name, Dispatch);
48 Func<void, ArgT...> &TypedFnRef = *TypedFn;
49 TypedFn->declArgs();
50
51#if PROTEUS_ENABLE_CUDA || PROTEUS_ENABLE_HIP
52 std::unique_ptr<FuncBase> &Fn = Functions.emplace_back(std::move(TypedFn));
53 Fn->setKernel();
54#else
55 reportFatalError("setKernel() is only supported for CUDA/HIP");
56#endif
57 return KernelHandle<ArgT...>{TypedFnRef, *this};
58 }
59
60 template <typename... ArgT> struct KernelHandle {
61 Func<void, ArgT...> &F;
62 JitModule &M;
63
64 void setLaunchBounds([[maybe_unused]] int MaxThreadsPerBlock,
65 [[maybe_unused]] int MinBlocksPerSM = 0) {
66 if (!M.isDeviceModule())
67 reportFatalError("Expected a device module for setLaunchBounds");
68
69 if (M.isCompiled())
70 reportFatalError("setLaunchBounds must be called before compile()");
71
72#if PROTEUS_ENABLE_CUDA || PROTEUS_ENABLE_HIP
73 F.setLaunchBoundsForKernel(MaxThreadsPerBlock, MinBlocksPerSM);
74#else
75 reportFatalError("Unsupported target for setLaunchBounds");
76#endif
77 }
78
79 // Launch with type-safety.
80 [[nodiscard]] auto launch(LaunchDims Grid, LaunchDims Block,
81 uint64_t ShmemBytes, void *Stream, ArgT... Args) {
82 // Pointers to the local parameter copies.
83 void *Ptrs[sizeof...(ArgT)] = {(void *)&Args...};
84
85 if (!M.isCompiled())
86 M.compile();
87
88 auto GetKernelFunc = [&]() {
89 // Get the kernel func pointer directly from the Func object if
90 // available.
91 if (auto KernelFunc = F.getCompiledFunc()) {
92 return KernelFunc;
93 }
94
95 // Get the kernel func pointer from the Dispatch and store it to the
96 // Func object to avoid cache lookups.
97 // TODO: Re-think caching and dispatchers.
98 auto KernelFunc = reinterpret_cast<decltype(F.getCompiledFunc())>(
99 M.Dispatch.getFunctionAddress(F.getName(), M.getModuleHash(),
100 M.getLibrary()));
101
102 F.setCompiledFunc(KernelFunc);
103
104 return KernelFunc;
105 };
106
107 return M.Dispatch.launch(reinterpret_cast<void *>(GetKernelFunc()), Grid,
108 Block, Ptrs, ShmemBytes, Stream);
109 }
110
111 FuncBase *operator->() { return &F; }
112 };
113
114 bool isDeviceModule() {
115 return ((TargetModel == TargetModelType::CUDA) ||
116 (TargetModel == TargetModelType::HIP));
117 }
118
119public:
120 JitModule(const std::string &Target = "host");
121
122 // Disable copy and move constructors.
123 JitModule(const JitModule &) = delete;
124 JitModule &operator=(const JitModule &) = delete;
125 JitModule(JitModule &&) = delete;
127
129
130 template <typename Sig> auto &addFunction(const std::string &Name) {
131 using RetT = typename FnSig<Sig>::RetT;
132 using ArgT = typename FnSig<Sig>::ArgsTList;
133
134 if (IsCompiled)
135 reportFatalError("The module is compiled, no further code can be added");
136
137 return buildFuncFromArgsList<RetT>(Name, ArgT{});
138 }
139
140 bool isCompiled() const { return IsCompiled; }
141
142 template <typename Sig> auto addKernel(const std::string &Name) {
143 using RetT = typename FnSig<Sig>::RetT;
144 static_assert(std::is_void_v<RetT>, "Kernels must have void return type");
145 using ArgT = typename FnSig<Sig>::ArgsTList;
146
147 if (IsCompiled)
148 reportFatalError("The module is compiled, no further code can be added");
149
150 if (!isDeviceModule())
151 reportFatalError("Expected a device module for addKernel");
152
153 return buildKernelFromArgsList(Name, ArgT{});
154 }
155
156 void compile(bool Verify = false);
157
158 const HashT &getModuleHash() const;
159
160 Dispatcher &getDispatcher() const { return Dispatch; }
161
162 TargetModelType getTargetModel() const { return TargetModel; }
163
165 if (!IsCompiled)
166 compile();
167
168 if (!Library)
169 reportFatalError("Expected non-null library after compilation");
170
171 return *Library;
172 }
173
174 void print();
175};
176
177template <typename RetT, typename... ArgT>
179 if (!J.isCompiled())
180 J.compile();
181
182 if (!CompiledFunc) {
183 CompiledFunc = reinterpret_cast<decltype(CompiledFunc)>(
184 J.getDispatcher().getFunctionAddress(getName(), J.getModuleHash(),
185 J.getLibrary()));
186 }
187
188 if (J.getTargetModel() != TargetModelType::HOST)
190 "Target is a GPU model, cannot directly run functions, use launch()");
191
192 if constexpr (std::is_void_v<RetT>)
193 Dispatch.run<RetT(ArgT...)>(reinterpret_cast<void *>(CompiledFunc),
194 Args...);
195 else
196 return Dispatch.run<RetT(ArgT...)>(reinterpret_cast<void *>(CompiledFunc),
197 Args...);
198}
199
200} // namespace proteus
201
202#endif
char int void ** Args
Definition CompilerInterfaceHost.cpp:22
Definition Dispatcher.h:74
virtual DispatchResult launch(void *KernelFunc, LaunchDims GridDim, LaunchDims BlockDim, void *KernelArgs[], uint64_t ShmemSize, void *Stream)=0
virtual void * getFunctionAddress(const std::string &FunctionName, const HashT &ModuleHash, CompiledLibrary &Library)=0
Definition Func.h:45
Definition Func.h:290
RetT operator()(ArgT... Args)
Definition JitFrontend.h:178
void declArgs()
Definition Func.h:328
Definition Hashing.h:21
Definition JitFrontend.h:18
JitModule & operator=(JitModule &&)=delete
JitModule(const JitModule &)=delete
bool isCompiled() const
Definition JitFrontend.h:140
TargetModelType getTargetModel() const
Definition JitFrontend.h:162
auto addKernel(const std::string &Name)
Definition JitFrontend.h:142
JitModule(JitModule &&)=delete
JitModule & operator=(const JitModule &)=delete
CompiledLibrary & getLibrary()
Definition JitFrontend.h:164
Dispatcher & getDispatcher() const
Definition JitFrontend.h:160
auto & addFunction(const std::string &Name)
Definition JitFrontend.h:130
const HashT & getModuleHash() const
Definition JitFrontend.cpp:61
void print()
Definition JitFrontend.cpp:57
void compile(bool Verify=false)
Definition JitFrontend.cpp:23
Definition MemoryCache.h:26
TargetModelType
Definition TargetModel.h:8
void reportFatalError(const llvm::Twine &Reason, const char *FILE, unsigned Line)
Definition Error.cpp:14
Definition Dispatcher.h:21
Definition Func.h:24
Definition CompiledLibrary.h:18
Definition Func.h:25