1#ifndef PROTEUS_FRONTEND_DISPATCHER_HIP_HPP
2#define PROTEUS_FRONTEND_DISPATCHER_HIP_HPP
12class DispatcherHIP :
public Dispatcher {
14 static DispatcherHIP &instance() {
15 static DispatcherHIP D;
19 std::unique_ptr<MemoryBuffer>
20 compile([[maybe_unused]] std::unique_ptr<LLVMContext> Ctx,
21 std::unique_ptr<Module> M, HashT ModuleHash)
override {
22 std::unique_ptr<MemoryBuffer> ObjectModule =
Jit.compileOnly(*M);
26 StorageCache.store(ModuleHash, ObjectModule->getMemBufferRef());
31 std::unique_ptr<MemoryBuffer> lookupObjectModule(HashT ModuleHash)
override {
32 return StorageCache.lookup(ModuleHash);
36 LaunchDims BlockDim, ArrayRef<void *> KernelArgs,
37 uint64_t ShmemSize,
void *Stream,
38 std::optional<MemoryBufferRef> ObjectModule)
override {
39 auto *KernelFunc = getFunctionAddress(
KernelName, ObjectModule);
41 dim3 HipGridDim = {GridDim.
X, GridDim.
Y, GridDim.
Z};
42 dim3 HipBlockDim = {BlockDim.
X, BlockDim.
Y, BlockDim.
Z};
43 hipStream_t HipStream =
reinterpret_cast<hipStream_t
>(Stream);
45 void **KernelArgsPtrs =
const_cast<void **
>(KernelArgs.data());
47 reinterpret_cast<hipFunction_t
>(KernelFunc), HipGridDim, HipBlockDim,
48 KernelArgsPtrs, ShmemSize, HipStream);
51 DispatchResult launch(
void *KernelFunc,
LaunchDims GridDim,
52 LaunchDims BlockDim, ArrayRef<void *> KernelArgs,
53 uint64_t ShmemSize,
void *Stream)
override {
54 dim3 HipGridDim = {GridDim.
X, GridDim.
Y, GridDim.
Z};
55 dim3 HipBlockDim = {BlockDim.
X, BlockDim.
Y, BlockDim.
Z};
56 hipStream_t HipStream =
reinterpret_cast<hipStream_t
>(Stream);
58 void **KernelArgsPtrs =
const_cast<void **
>(KernelArgs.data());
60 reinterpret_cast<hipFunction_t
>(KernelFunc), HipGridDim, HipBlockDim,
61 KernelArgsPtrs, ShmemSize, HipStream);
64 StringRef getTargetArch()
const override {
return Jit.getDeviceArch(); }
67 CodeCache.printStats();
68 StorageCache.printStats();
73 std::optional<MemoryBufferRef> ObjectModule)
override {
74 auto GetKernelFunc = [&]() {
78 if (
auto KernelFunc = CodeCache.lookup(HashValue))
86 CodeCache.insert(HashValue, KernelFunc,
KernelName);
91 auto KernelFunc = GetKernelFunc();
96 JitEngineDeviceHIP &
Jit;
97 DispatcherHIP() :
Jit(JitEngineDeviceHIP::instance()) {
98 TargetModel = TargetModelType::HIP;
100 JitCache<hipFunction_t> CodeCache;
101 JitStorageCache<hipFunction_t> StorageCache;
void char * KernelName
Definition CompilerInterfaceDevice.cpp:50
auto & Jit
Definition CompilerInterfaceDevice.cpp:54
#define PROTEUS_FATAL_ERROR(x)
Definition Error.h:7
Definition CppJitModule.cpp:21
HashT hash(FirstT &&First, RestTs &&...Rest)
Definition Hashing.hpp:126
cudaError_t launchKernelFunction(CUfunction KernelFunc, dim3 GridDim, dim3 BlockDim, void **KernelArgs, uint64_t ShmemSize, CUstream Stream)
Definition CoreDeviceCUDA.hpp:51
CUfunction getKernelFunctionFromImage(StringRef KernelName, const void *Image, bool RelinkGlobalsByCopy, const std::unordered_map< std::string, const void * > &VarNameToDevPtr)
Definition CoreDeviceCUDA.hpp:27
Definition Dispatcher.hpp:15
unsigned Z
Definition Dispatcher.hpp:16
unsigned Y
Definition Dispatcher.hpp:16
unsigned X
Definition Dispatcher.hpp:16