1#ifndef PROTEUS_FRONTEND_DISPATCHER_HPP
2#define PROTEUS_FRONTEND_DISPATCHER_HPP
4#include <llvm/IR/Module.h>
5#include <llvm/Support/MemoryBuffer.h>
11#if PROTEUS_ENABLE_HIP && __HIP__
12#include <hip/hip_runtime.h>
// Three unsigned components X, Y and Z; each one defaults to 1 when no
// explicit value is supplied.
// NOTE(review): presumably the members of the launch-dimension struct used
// for grid/block sizes — the enclosing declaration is not visible here; confirm.
16 unsigned X = 1,
Y = 1,
Z = 1;
// Implicit conversion to int: yields the stored Ret value unchanged.
// Declared noexcept and const, so it never modifies the object.
31 operator int() const noexcept {
return Ret; }
33#if PROTEUS_ENABLE_HIP && __HIP__
// Implicit conversion to hipError_t: reinterprets the stored integer Ret as a
// HIP error code via static_cast. No range check is performed on Ret.
// Only compiled when the surrounding PROTEUS_ENABLE_HIP && __HIP__ guard holds.
34 operator hipError_t() const noexcept {
return static_cast<hipError_t
>(
Ret); }
37#if PROTEUS_ENABLE_CUDA && defined(__CUDACC__)
38 operator cudaError_t() const noexcept {
39 return static_cast<cudaError_t
>(
Ret);
53 virtual std::unique_ptr<MemoryBuffer>
54 compile(std::unique_ptr<LLVMContext> Ctx, std::unique_ptr<Module> M,
55 HashT ModuleHash) = 0;
57 virtual std::unique_ptr<MemoryBuffer>
62 ArrayRef<void *> KernelArgs, uint64_t ShmemSize,
void *Stream,
63 std::optional<MemoryBufferRef> ObjectModule) = 0;
67 ArrayRef<void *> KernelArgs, uint64_t ShmemSize,
74 template <
typename RetOrSig,
typename... ArgT>
75 auto run(StringRef FuncName, std::optional<MemoryBufferRef> ObjectModule,
79 "Dispatcher run interface is only supported for host");
83 if constexpr (std::is_function_v<RetOrSig>) {
84 auto Fn =
reinterpret_cast<RetOrSig *
>(Addr);
85 using Ret = std::invoke_result_t<RetOrSig, ArgT...>;
87 if constexpr (std::is_void_v<Ret>) {
88 Fn(std::forward<ArgT>(
Args)...);
91 return Fn(std::forward<ArgT>(
Args)...);
93 using FnPtr = RetOrSig (*)(std::decay_t<ArgT>...);
94 auto Fn =
reinterpret_cast<FnPtr
>(Addr);
96 if constexpr (std::is_void_v<RetOrSig>) {
97 Fn(std::forward<ArgT>(
Args)...);
100 return Fn(std::forward<ArgT>(
Args)...);
106 std::optional<MemoryBufferRef> ObjectModule) = 0;
void char * KernelName
Definition CompilerInterfaceDevice.cpp:50
char int void ** Args
Definition CompilerInterfaceHost.cpp:21
#define PROTEUS_FATAL_ERROR(x)
Definition Error.h:7
Definition Dispatcher.hpp:46
TargetModelType TargetModel
Definition Dispatcher.hpp:48
virtual std::unique_ptr< MemoryBuffer > lookupObjectModule(HashT ModuleHash)=0
static Dispatcher & getDispatcher(TargetModelType TargetModel)
Definition Dispatcher.cpp:15
virtual std::unique_ptr< MemoryBuffer > compile(std::unique_ptr< LLVMContext > Ctx, std::unique_ptr< Module > M, HashT ModuleHash)=0
virtual DispatchResult launch(StringRef KernelName, LaunchDims GridDim, LaunchDims BlockDim, ArrayRef< void * > KernelArgs, uint64_t ShmemSize, void *Stream, std::optional< MemoryBufferRef > ObjectModule)=0
auto run(StringRef FuncName, std::optional< MemoryBufferRef > ObjectModule, ArgT &&...Args)
Definition Dispatcher.hpp:75
virtual void * getFunctionAddress(StringRef FunctionName, std::optional< MemoryBufferRef > ObjectModule)=0
virtual StringRef getTargetArch() const =0
virtual DispatchResult launch(void *KernelFunc, LaunchDims GridDim, LaunchDims BlockDim, ArrayRef< void * > KernelArgs, uint64_t ShmemSize, void *Stream)=0
Definition Hashing.hpp:20
Definition CppJitModule.cpp:21
TargetModelType
Definition TargetModel.hpp:14
Definition Dispatcher.hpp:15
unsigned Z
Definition Dispatcher.hpp:16
unsigned Y
Definition Dispatcher.hpp:16
unsigned X
Definition Dispatcher.hpp:16
Definition Dispatcher.hpp:24
constexpr DispatchResult(int Ret=0) noexcept
Definition Dispatcher.hpp:28
int Ret
Definition Dispatcher.hpp:25