Proteus
Programmable JIT compilation and optimization for C/C++ using LLVM
Loading...
Searching...
No Matches
Dispatcher.hpp
Go to the documentation of this file.
1#ifndef PROTEUS_FRONTEND_DISPATCHER_HPP
2#define PROTEUS_FRONTEND_DISPATCHER_HPP
3
4#include <llvm/IR/Module.h>
5#include <llvm/Support/MemoryBuffer.h>
6#include <memory>
7
9#include "proteus/Hashing.hpp"
10
11#if PROTEUS_ENABLE_HIP && __HIP__
12#include <hip/hip_runtime.h>
13#endif
14
// Three unsigned extents (X, Y, Z), each defaulting to 1. This is the type of
// the GridDim and BlockDim parameters of Dispatcher::launch().
15struct LaunchDims {
16 unsigned X = 1, Y = 1, Z = 1;
17};
18
19namespace proteus {
20
21using namespace llvm;
22
23// Error-code wrapper returned by launch(); defined in Dispatcher.hpp (could move to a new Errors.hpp).
25 int Ret;
26
27 // Construct from an integer error code (defaults to 0).
28 constexpr DispatchResult(int Ret = 0) noexcept : Ret(Ret) {}
29
30 // implicit conversion back to int
31 operator int() const noexcept { return Ret; }
32
33#if PROTEUS_ENABLE_HIP && __HIP__
34 operator hipError_t() const noexcept { return static_cast<hipError_t>(Ret); }
35#endif
36
37#if PROTEUS_ENABLE_CUDA && defined(__CUDACC__)
38 operator cudaError_t() const noexcept {
39 return static_cast<cudaError_t>(Ret);
40 }
41#endif
42};
43
44struct DispatchResult;
45
47protected:
49
50public:
52
53 virtual std::unique_ptr<MemoryBuffer>
54 compile(std::unique_ptr<LLVMContext> Ctx, std::unique_ptr<Module> M,
55 HashT ModuleHash) = 0;
56
57 virtual std::unique_ptr<MemoryBuffer>
58 lookupObjectModule(HashT ModuleHash) = 0;
59
60 virtual DispatchResult
61 launch(StringRef KernelName, LaunchDims GridDim, LaunchDims BlockDim,
62 ArrayRef<void *> KernelArgs, uint64_t ShmemSize, void *Stream,
63 std::optional<MemoryBufferRef> ObjectModule) = 0;
64
65 virtual DispatchResult launch(void *KernelFunc, LaunchDims GridDim,
66 LaunchDims BlockDim,
67 ArrayRef<void *> KernelArgs, uint64_t ShmemSize,
68 void *Stream) = 0;
69
70 virtual StringRef getTargetArch() const = 0;
71
72 // Accepts either a return type or a function signature (needed for C++
73 // reference arguments) and disambiguates at compile time.
 //
 // Resolves FuncName (optionally within ObjectModule) to an address through
 // getFunctionAddress(), casts that address to a function pointer, and calls
 // it with the forwarded Args. The callee's result is returned; when the
 // callee's return type is void, nothing is returned.
74 template <typename RetOrSig, typename... ArgT>
75 auto run(StringRef FuncName, std::optional<MemoryBufferRef> ObjectModule,
76 ArgT &&...Args) {
 // NOTE(review): listing lines 77-78 are absent here; the string literal
 // below is presumably the message of a host-only guard -- confirm against
 // the original header.
79 "Dispatcher run interface is only supported for host");
80
81 void *Addr = getFunctionAddress(FuncName, ObjectModule);
82
83 if constexpr (std::is_function_v<RetOrSig>) {
 // RetOrSig is a function type: call through a pointer to exactly that
 // signature, so reference parameters are kept as declared.
84 auto Fn = reinterpret_cast<RetOrSig *>(Addr);
85 using Ret = std::invoke_result_t<RetOrSig, ArgT...>;
86
87 if constexpr (std::is_void_v<Ret>) {
88 Fn(std::forward<ArgT>(Args)...);
89 return;
90 } else
91 return Fn(std::forward<ArgT>(Args)...);
92 } else {
 // RetOrSig is a return type: the parameter list is built from the decayed
 // types of the supplied arguments.
93 using FnPtr = RetOrSig (*)(std::decay_t<ArgT>...);
94 auto Fn = reinterpret_cast<FnPtr>(Addr);
95
96 if constexpr (std::is_void_v<RetOrSig>) {
97 Fn(std::forward<ArgT>(Args)...);
98 return;
99 } else
100 return Fn(std::forward<ArgT>(Args)...);
101 }
102 }
103
104 virtual void *
105 getFunctionAddress(StringRef FunctionName,
106 std::optional<MemoryBufferRef> ObjectModule) = 0;
107};
108
109} // namespace proteus
110
111#endif
void char * KernelName
Definition CompilerInterfaceDevice.cpp:50
char int void ** Args
Definition CompilerInterfaceHost.cpp:21
#define PROTEUS_FATAL_ERROR(x)
Definition Error.h:7
Definition Dispatcher.hpp:46
TargetModelType TargetModel
Definition Dispatcher.hpp:48
virtual std::unique_ptr< MemoryBuffer > lookupObjectModule(HashT ModuleHash)=0
static Dispatcher & getDispatcher(TargetModelType TargetModel)
Definition Dispatcher.cpp:15
virtual std::unique_ptr< MemoryBuffer > compile(std::unique_ptr< LLVMContext > Ctx, std::unique_ptr< Module > M, HashT ModuleHash)=0
virtual DispatchResult launch(StringRef KernelName, LaunchDims GridDim, LaunchDims BlockDim, ArrayRef< void * > KernelArgs, uint64_t ShmemSize, void *Stream, std::optional< MemoryBufferRef > ObjectModule)=0
auto run(StringRef FuncName, std::optional< MemoryBufferRef > ObjectModule, ArgT &&...Args)
Definition Dispatcher.hpp:75
virtual void * getFunctionAddress(StringRef FunctionName, std::optional< MemoryBufferRef > ObjectModule)=0
virtual StringRef getTargetArch() const =0
virtual DispatchResult launch(void *KernelFunc, LaunchDims GridDim, LaunchDims BlockDim, ArrayRef< void * > KernelArgs, uint64_t ShmemSize, void *Stream)=0
Definition Hashing.hpp:20
Definition Helpers.h:76
Definition CppJitModule.cpp:21
TargetModelType
Definition TargetModel.hpp:14
Definition Dispatcher.hpp:15
unsigned Z
Definition Dispatcher.hpp:16
unsigned Y
Definition Dispatcher.hpp:16
unsigned X
Definition Dispatcher.hpp:16
Definition Dispatcher.hpp:24
constexpr DispatchResult(int Ret=0) noexcept
Definition Dispatcher.hpp:28
int Ret
Definition Dispatcher.hpp:25