23 std::unique_ptr<HashT> ModuleHash;
24 std::vector<std::string> ExtraArgs;
28 std::unique_ptr<CompiledLibrary> Library;
29 bool IsCompiled =
false;
31 struct FuncAttributeState {
58 const std::string &TemplateCode;
59 std::vector<std::string> ExtraArgs;
61 std::string InstanceName;
62 std::unique_ptr<CppJitModule> InstanceModule;
63 std::string EntryFuncName;
64 void *FuncPtr =
nullptr;
65 FuncAttributeState FuncAttributes;
67 CodeInstance(
TargetModelType TargetModel,
const std::string &TemplateCode,
68 const std::vector<std::string> &ExtraArgs,
70 const std::string &InstanceName,
71 FuncAttributeState FuncAttributes = {})
72 : TargetModel(TargetModel), TemplateCode(TemplateCode),
73 ExtraArgs(ExtraArgs), CompilerBackend(CompilerBackend),
74 InstanceName(InstanceName),
75 FuncAttributes(std::move(FuncAttributes)) {
76 EntryFuncName =
"__proteus_instance_" + this->InstanceName;
80 EntryFuncName.begin(), EntryFuncName.end(),
81 [](
char C) { return C ==
'<' || C ==
'>' || C ==
','; },
'$');
84 bool useHostLaunchEntry()
const {
103 return CompilerBackend;
106 const char *getGPUStreamType()
const {
108 return "cudaStream_t";
112 return "hipStream_t";
118 const char *getGPUGetLastErrorName()
const {
120 return "cudaGetLastError";
124 return "hipGetLastError";
130 const char *getGPUFuncSetAttributeName()
const {
132 return "cudaFuncSetAttribute";
136 return "hipFuncSetAttribute";
142 const char *getGPUMaxDynamicSharedMemoryAttrName()
const {
144 return "cudaFuncAttributeMaxDynamicSharedMemorySize";
148 return "hipFuncAttributeMaxDynamicSharedMemorySize";
155 template <
class T>
constexpr std::string_view typeName() {
158#if defined(__clang__)
160 std::string_view P = __PRETTY_FUNCTION__;
161 auto B = P.find(
"[T = ") + 5;
162 auto E = P.rfind(
']');
163 return P.substr(B, E - B);
164#elif defined(__GNUC__)
166 std::string_view P = __PRETTY_FUNCTION__;
167 auto B = P.find(
"with T = ") + 9;
168 auto E = P.find(
';', B);
169 return P.substr(B, E - B);
170#elif defined(_MSC_VER)
172 std::string_view P = __FUNCSIG__;
173 auto B = P.find(
"type_name<") + 10;
174 auto E = P.find(
">(void)", B);
175 return P.substr(B, B - E);
181 template <
typename RetT,
typename... ArgT, std::size_t... I>
182 std::string buildFunctionEntry(std::index_sequence<I...>) {
183 std::stringstream OS;
185 OS <<
"extern \"C\" " << typeName<RetT>() <<
" "
187 << EntryFuncName <<
"(";
188 ((OS << (I ?
", " :
"")
190 std::decay_t<std::tuple_element_t<I, std::tuple<ArgT...>>>>()
196 ((ArgList += (I == 0 ?
"" :
", ") + (
"Arg" + std::to_string(I))), ...);
198 if constexpr (!std::is_void_v<RetT>) {
201 OS << InstanceName <<
"(";
202 ((OS << (I == 0 ?
"" :
", ") <<
"Arg" << std::to_string(I)), ...);
208 template <
typename... ArgT, std::size_t... I>
209 std::string buildGPUHostLauncher(std::index_sequence<I...>) {
210 std::stringstream OS;
212 OS <<
"extern \"C\" int " << EntryFuncName
213 <<
"(unsigned GridX, unsigned GridY, unsigned GridZ, unsigned "
214 "BlockX, unsigned BlockY, unsigned BlockZ, size_t ShmemSize, "
218 std::decay_t<std::tuple_element_t<I, std::tuple<ArgT...>>>>()
222 if (FuncAttributes.MaxDynamicSharedMemorySize) {
223 OS <<
"{ auto AttrErr = " << getGPUFuncSetAttributeName()
224 <<
"((const void *)" << InstanceName <<
", "
225 << getGPUMaxDynamicSharedMemoryAttrName() <<
", "
226 << *FuncAttributes.MaxDynamicSharedMemorySize
227 <<
"); if (AttrErr != 0) return static_cast<int>(AttrErr); } ";
232 <<
"<<<dim3(GridX, GridY, GridZ), dim3(BlockX, BlockY, BlockZ), "
233 "ShmemSize, static_cast<"
234 << getGPUStreamType() <<
">(Stream)>>>(";
235 ((OS << (I == 0 ?
"" :
", ") <<
"Arg" << I), ...);
236 OS <<
"); return static_cast<int>(" << getGPUGetLastErrorName()
242 template <
typename RetOrSig,
typename... ArgT> std::string buildCode() {
243 if constexpr (std::is_void_v<RetOrSig>) {
244 if (useHostLaunchEntry()) {
245 return TemplateCode + buildGPUHostLauncher<ArgT...>(
246 std::index_sequence_for<ArgT...>{});
250 std::string FunctionCode = buildFunctionEntry<RetOrSig, ArgT...>(
251 std::index_sequence_for<ArgT...>{});
253 auto ReplaceAll = [](std::string &S, std::string_view From,
254 std::string_view To) {
258 while ((
Pos = S.find(From,
Pos)) != std::string::npos) {
259 S.replace(
Pos, From.size(), To);
265 std::string InstanceCode = TemplateCode;
268 ReplaceAll(InstanceCode,
"__global__",
"__device__");
269 InstanceCode = InstanceCode + FunctionCode;
274 template <
typename RetT,
typename... ArgT>
void compile() {
275 std::string InstanceCode = buildCode<RetT, ArgT...>();
276 InstanceModule = std::make_unique<CppJitModule>(
277 getInstanceTargetModel(), InstanceCode, ExtraArgs,
278 getInstanceCompilerBackend());
279 InstanceModule->compile();
281 FuncPtr = InstanceModule->getFunctionAddress(EntryFuncName);
285 FuncAttributes.set(Attr, Value);
286 InstanceModule.reset();
290 template <
typename... ArgT>
292 void *Stream, ArgT...
Args) {
293 if (!InstanceModule) {
294 compile<void, ArgT...>();
297 void *Ptrs[
sizeof...(ArgT)] = {(
void *)&
Args...};
299 if (useHostLaunchEntry()) {
301 int(
unsigned,
unsigned,
unsigned,
unsigned,
unsigned,
unsigned,
302 std::size_t,
void *, ArgT...);
304 FuncPtr, GridDim.
X, GridDim.
Y, GridDim.
Z, BlockDim.
X, BlockDim.
Y,
305 BlockDim.
Z,
static_cast<std::size_t
>(ShmemSize), Stream,
Args...)};
308 return InstanceModule->launch(FuncPtr, GridDim, BlockDim, Ptrs, ShmemSize,
312 template <
typename RetOrSig,
typename... ArgT>
313 RetOrSig run(ArgT &&...
Args) {
314 static_assert(!std::is_function_v<RetOrSig>,
315 "Function signature type is not yet supported");
317 if (!InstanceModule) {
318 compile<RetOrSig, ArgT...>();
321 if constexpr (std::is_void_v<RetOrSig>)
322 InstanceModule->Dispatch.run<RetOrSig(ArgT...)>(FuncPtr,
Args...);
324 return InstanceModule->Dispatch.run<RetOrSig(ArgT...)>(FuncPtr,
329 std::unordered_map<std::string, std::unique_ptr<CodeInstance>>
333 void *getFunctionAddress(
const std::string &Name);
336 uint64_t ShmemSize,
void *Stream);
341 const std::vector<std::string> &ExtraArgs = {},
344 const std::string &Target,
const std::string &Code,
345 const std::vector<std::string> &ExtraArgs = {},
361 template <
typename... ArgT>
363 std::string InstanceName = FuncName +
"<";
366 (First ?
"" :
",") + std::string(std::forward<ArgT>(
Args)),
372 auto It = InstantiationCache.find(InstanceName);
373 if (It != InstantiationCache.end()) {
377 auto [NewIt, OK] = InstantiationCache.emplace(
379 std::make_unique<CodeInstance>(TargetModel, Code, ExtraArgs,
380 CompilerBackend, InstanceName));
381 return *NewIt->second;
385 template <
typename RetT,
typename... ArgT>
390 : M(M), FuncPtr(FuncPtr) {}
393 if constexpr (std::is_void_v<RetT>) {
394 M.Dispatch.template run<RetT(ArgT...)>(FuncPtr,
395 std::forward<ArgT>(
Args)...);
397 return M.Dispatch.template run<RetT(ArgT...)>(
398 FuncPtr, std::forward<ArgT>(
Args)...);
404 template <
typename RetT,
typename... ArgT>
407 void *FuncPtr =
nullptr;
409 : M(M), FuncPtr(FuncPtr) {
410 static_assert(std::is_void_v<RetT>,
"Kernel function must return void");
414 M.setFuncAttribute(FuncPtr, Attr, Value);
418 void *Stream, ArgT...
Args) {
419 void *Ptrs[
sizeof...(ArgT)] = {(
void *)&
Args...};
421 return M.launch(FuncPtr, GridDim, BlockDim, Ptrs, ShmemSize, Stream);
424 template <
typename Sig>
432 void *FuncPtr = getFunctionAddress(Name);
444 void *FuncPtr = getFunctionAddress(Name);