Proteus
Programmable JIT compilation and optimization for C/C++ using LLVM
Loading...
Searching...
No Matches
Dispatcher.h
Go to the documentation of this file.
1#ifndef PROTEUS_FRONTEND_DISPATCHER_H
2#define PROTEUS_FRONTEND_DISPATCHER_H
3
4#include "proteus/Error.h"
6
7#if PROTEUS_ENABLE_HIP && __HIP__
8#include <hip/hip_runtime.h>
9#endif
10
11#include <cstdint>
12#include <memory>
13#include <type_traits>
14
// Forward declarations of the LLVM types this header only names in
// declarations (smart-pointer parameters/returns and pure-virtual signatures).
// StringRef is listed explicitly because Dispatcher::getDeviceArch() returns
// it; without this declaration the header would depend on some other include
// happening to pull in llvm/ADT/StringRef.h.
namespace llvm {
class LLVMContext;
class Module;
class MemoryBuffer;
class StringRef;
} // namespace llvm
20
// Three-axis launch extent passed as the grid and block dimensions of
// Dispatcher::launch. Every axis defaults to 1, so LaunchDims(N) describes a
// one-dimensional extent of N.
struct LaunchDims {
  unsigned X = 1;
  unsigned Y = 1;
  unsigned Z = 1;

  constexpr LaunchDims() = default;

  // Explicit per-axis sizes; trailing axes that are omitted stay at 1.
  constexpr LaunchDims(unsigned DimX, unsigned DimY = 1, unsigned DimZ = 1)
      : X(DimX), Y(DimY), Z(DimZ) {}

  // Implicit conversion from any dim3-like type, i.e. one whose `x`, `y` and
  // `z` members are each convertible to unsigned. Other types are removed from
  // overload resolution by SFINAE.
  template <typename DimT,
            typename = std::enable_if_t<std::conjunction_v<
                std::is_convertible<decltype(std::declval<DimT>().x), unsigned>,
                std::is_convertible<decltype(std::declval<DimT>().y), unsigned>,
                std::is_convertible<decltype(std::declval<DimT>().z),
                                    unsigned>>>>
  constexpr LaunchDims(const DimT &Other)
      : X(Other.x), Y(Other.y), Z(Other.z) {}
};
38
39namespace proteus {
40
41class ObjectCacheChain;
42struct CompiledLibrary;
43class HashT;
44
// Compile-time decomposition of a function type into its return type and its
// argument types (as a std::tuple). Dispatcher::run uses return_type to type
// the call's result. The primary template is intentionally left undefined, so
// anything other than a function type fails to compile.
template <typename T> struct sig_traits;

template <typename R, typename... Args> struct sig_traits<R(Args...)> {
  using return_type = R;
  using argument_types = std::tuple<Args...>;
};

// Since C++17 `noexcept` is part of the function type, so `R(Args...) noexcept`
// does not match the specialization above. Forward it to the same traits so
// that noexcept signatures are accepted as well.
template <typename R, typename... Args>
struct sig_traits<R(Args...) noexcept> : sig_traits<R(Args...)> {};
51
// Status code returned by Dispatcher::launch. It wraps a plain int and
// converts implicitly to and from it. When PROTEUS_ENABLE_HIP and __HIP__
// (resp. PROTEUS_ENABLE_CUDA and __CUDACC__) are set, it also converts to the
// HIP (resp. CUDA) runtime error enum.
struct DispatchResult {
  int Ret;

  // Construct from an integer error code (0 by default).
  constexpr DispatchResult(int Ret = 0) noexcept : Ret(Ret) {}

  // Implicit conversion back to int. constexpr to match the constexpr
  // constructor, so results are also usable in constant expressions.
  constexpr operator int() const noexcept { return Ret; }

#if PROTEUS_ENABLE_HIP && __HIP__
  operator hipError_t() const noexcept { return static_cast<hipError_t>(Ret); }
#endif

#if PROTEUS_ENABLE_CUDA && defined(__CUDACC__)
  operator cudaError_t() const noexcept {
    return static_cast<cudaError_t>(Ret);
  }
#endif
};
73
75protected:
// NOTE(review): source line 76 is missing from this listing; per the symbol
// index it declares `TargetModelType TargetModel`, the member run() checks.
//
// Owned object cache. NOTE(review): ObjectCacheChain is only forward-declared
// in this header while ~Dispatcher() below is defaulted inline. Any TU that
// instantiates the destructor (e.g. through a derived class) must therefore see
// the complete ObjectCacheChain type, or std::unique_ptr's deleter will not
// compile. Consider defining ~Dispatcher() out of line.
77 std::unique_ptr<ObjectCacheChain> ObjectCache;
78
// Protected: only subclasses construct a Dispatcher. The meaning of Name is not
// visible here; TM presumably initializes the TargetModel member -- confirm.
79 Dispatcher(const std::string &Name, TargetModelType TM);
80
81public:
// Returns a reference to the dispatcher for TargetModel; the caller does not
// own it (presumably a per-target-model singleton -- confirm in the .cpp).
82 static Dispatcher &getDispatcher(TargetModelType TargetModel);
83 virtual ~Dispatcher() = default;
84
// Compiles module M, taking ownership of both M and the context Ctx (presumably
// the context M lives in), and returns the result as a memory buffer
// (presumably object code for ModuleHash -- confirm in implementations).
// DisableIROpt defaults to false.
85 virtual std::unique_ptr<llvm::MemoryBuffer>
86 compile(std::unique_ptr<llvm::LLVMContext> Ctx,
87 std::unique_ptr<llvm::Module> M, const HashT &ModuleHash,
88 bool DisableIROpt = false) = 0;
89
// Looks up a previously compiled library by ModuleHash. NOTE(review): the
// behaviour on a miss (presumably nullptr) is up to implementations -- confirm.
90 virtual std::unique_ptr<CompiledLibrary>
91 lookupCompiledLibrary(const HashT &ModuleHash) = 0;
92
// Launches KernelFunc with the given grid/block extents, kernel argument array,
// dynamic shared-memory size and stream (presumably a GPU kernel launch for
// device target models). The runtime status is returned as a DispatchResult.
93 virtual DispatchResult launch(void *KernelFunc, LaunchDims GridDim,
94 LaunchDims BlockDim, void *KernelArgs[],
95 uint64_t ShmemSize, void *Stream) = 0;
96
// Returns the target device architecture name (exact format is defined by the
// implementation -- confirm).
97 virtual llvm::StringRef getDeviceArch() const = 0;
98
// Calls FuncPtr as a function of type Sig, perfectly forwarding Args, and
// returns the call's result typed as Sig's return type (nothing for void).
// FuncPtr is reinterpret_cast to Sig*, so the caller must guarantee it really
// points to a function of that type; a mismatch is undefined behaviour.
// Presumably FuncPtr comes from getFunctionAddress() -- confirm.
99 template <typename Sig, typename... ArgT>
100 typename sig_traits<Sig>::return_type run(void *FuncPtr, ArgT &&...Args) {
// Guard: this entry point is restricted to host target models.
101 if (!isHostTargetModel(TargetModel))
// NOTE(review): source line 102 -- the fatal-error call that takes the message
// below (presumably a reportFatalError-based macro) -- is missing from this
// listing.
103 "Dispatcher run interface is only supported for host derived models");
104
105 auto Fn = reinterpret_cast<Sig *>(FuncPtr);
106 using Ret = typename sig_traits<Sig>::return_type;
107
// Branch on Sig's return type at compile time: void signatures are just
// called, all others return the call's result.
108 if constexpr (std::is_void_v<Ret>) {
109 Fn(std::forward<ArgT>(Args)...);
110 return;
111 } else
112 return Fn(std::forward<ArgT>(Args)...);
113 }
114
// Resolves FunctionName to a raw address inside Library (presumably the
// library compiled for ModuleHash). Presumably that address is what run() or
// launch() consume -- confirm in implementations.
115 virtual void *getFunctionAddress(const std::string &FunctionName,
116 const HashT &ModuleHash,
117 CompiledLibrary &Library) = 0;
118
// Registers the dynamic library at Path under HashValue (semantics defined by
// the implementation -- confirm). NOTE(review): std::string is used in this
// class, but <string> is not among the visible #includes; it is presumably
// reached transitively.
119 virtual void registerDynamicLibrary(const HashT &HashValue,
120 const std::string &Path) = 0;
121};
122
123} // namespace proteus
124
125#endif
char int void ** Args
Definition CompilerInterfaceHost.cpp:22
Definition Dispatcher.h:74
TargetModelType TargetModel
Definition Dispatcher.h:76
virtual void registerDynamicLibrary(const HashT &HashValue, const std::string &Path)=0
virtual std::unique_ptr< llvm::MemoryBuffer > compile(std::unique_ptr< llvm::LLVMContext > Ctx, std::unique_ptr< llvm::Module > M, const HashT &ModuleHash, bool DisableIROpt=false)=0
std::unique_ptr< ObjectCacheChain > ObjectCache
Definition Dispatcher.h:77
virtual ~Dispatcher()=default
virtual std::unique_ptr< CompiledLibrary > lookupCompiledLibrary(const HashT &ModuleHash)=0
virtual DispatchResult launch(void *KernelFunc, LaunchDims GridDim, LaunchDims BlockDim, void *KernelArgs[], uint64_t ShmemSize, void *Stream)=0
sig_traits< Sig >::return_type run(void *FuncPtr, ArgT &&...Args)
Definition Dispatcher.h:100
virtual llvm::StringRef getDeviceArch() const =0
virtual void * getFunctionAddress(const std::string &FunctionName, const HashT &ModuleHash, CompiledLibrary &Library)=0
Definition Hashing.h:21
Definition CompiledLibrary.h:7
Definition MemoryCache.h:26
TargetModelType
Definition TargetModel.h:8
bool isHostTargetModel(TargetModelType TargetModel)
Definition TargetModel.cpp:49
void reportFatalError(const llvm::Twine &Reason, const char *FILE, unsigned Line)
Definition Error.cpp:14
Definition Dispatcher.h:21
constexpr LaunchDims(unsigned X, unsigned Y=1, unsigned Z=1)
Definition Dispatcher.h:26
unsigned Z
Definition Dispatcher.h:22
unsigned Y
Definition Dispatcher.h:22
constexpr LaunchDims()=default
constexpr LaunchDims(const T &Dims)
Definition Dispatcher.h:36
unsigned X
Definition Dispatcher.h:22
Definition CompiledLibrary.h:18
Definition Dispatcher.h:52
constexpr DispatchResult(int Ret=0) noexcept
Definition Dispatcher.h:56
int Ret
Definition Dispatcher.h:53
R return_type
Definition Dispatcher.h:48
std::tuple< Args... > argument_types
Definition Dispatcher.h:49
Definition Dispatcher.h:45