1#ifndef PROTEUS_FRONTEND_DISPATCHER_CUDA_HPP
2#define PROTEUS_FRONTEND_DISPATCHER_CUDA_HPP
// Dispatcher subclass for CUDA. The members visible in this file compile a
// module through the JitEngineDeviceCUDA instance, launch kernels, and keep
// two caches (CodeCache and StorageCache).
11class DispatcherCUDA :
public Dispatcher {
// Accessor yielding a DispatcherCUDA reference. D is a function-local
// static, so it is constructed the first time this function runs.
// NOTE(review): the return statement is not visible here; presumably it
// returns D -- confirm.
13 static DispatcherCUDA &instance() {
14 static DispatcherCUDA D;
// Compiles module M into an object buffer and stores that buffer in
// StorageCache under ModuleHash. Before compiling, the bitcode file at
// LIBDEVICE_BC_PATH is parsed into M's LLVMContext and linked into M.
// Ctx is marked [[maybe_unused]]; none of the statements shown here read it.
18 std::unique_ptr<MemoryBuffer>
19 compile([[maybe_unused]] std::unique_ptr<LLVMContext> Ctx,
20 std::unique_ptr<Module> M, HashT ModuleHash)
override {
// NOTE(review): the getFile result is dereferenced in the parse call right
// below with no error check in between -- confirm that a missing or
// unreadable libdevice file is handled.
23 auto LibDeviceBuffer = llvm::MemoryBuffer::getFile(LIBDEVICE_BC_PATH);
24 auto LibDeviceModule = llvm::parseBitcodeFile(
25 LibDeviceBuffer->get()->getMemBufferRef(), M->getContext());
// Link the parsed libdevice module into M.
// NOTE(review): LibDeviceModule.get() is consumed directly and the value
// returned by linkInModule is not used in this statement -- confirm that
// parse and link failures are handled.
27 llvm::Linker linker(*M);
28 linker.linkInModule(std::move(LibDeviceModule.get()));
// Compile the linked module; compileOnly receives M by reference.
30 std::unique_ptr<MemoryBuffer> ObjectModule =
Jit.compileOnly(*M);
// Record the compiled object in StorageCache, keyed by ModuleHash.
34 StorageCache.store(ModuleHash, ObjectModule->getMemBufferRef());
// Returns the result of StorageCache.lookup for ModuleHash.
// NOTE(review): what lookup yields for an unknown hash is not visible here;
// presumably a null pointer -- confirm against JitStorageCache.
39 std::unique_ptr<MemoryBuffer> lookupObjectModule(HashT ModuleHash)
override {
40 return StorageCache.lookup(ModuleHash);
44 LaunchDims BlockDim, ArrayRef<void *> KernelArgs,
45 uint64_t ShmemSize,
void *Stream,
46 std::optional<MemoryBufferRef> ObjectModule)
override {
// Launch overload that first resolves the kernel: KernelName and
// ObjectModule are handed to getFunctionAddress to obtain the function
// handle used for the launch.
47 auto *KernelFunc = getFunctionAddress(
KernelName, ObjectModule);
// Copy the X/Y/Z fields of the LaunchDims values into CUDA dim3 values.
49 dim3 CudaGridDim = {GridDim.
X, GridDim.
Y, GridDim.
Z};
50 dim3 CudaBlockDim = {BlockDim.
X, BlockDim.
Y, BlockDim.
Z};
// Stream arrives as an untyped pointer and is reinterpreted as cudaStream_t.
51 cudaStream_t CudaStream =
reinterpret_cast<cudaStream_t
>(Stream);
// const_cast turns the pointer returned by KernelArgs.data() into the
// void ** form passed to the launch call below.
53 void **KernelArgsPtrs =
const_cast<void **
>(KernelArgs.data());
// Arguments of the launch call: function handle cast to cudaFunction_t,
// grid and block dimensions, argument pointers, ShmemSize, and the stream.
// NOTE(review): the callee name is on a line not shown here; presumably
// launchKernelFunction -- confirm.
55 reinterpret_cast<cudaFunction_t
>(KernelFunc), CudaGridDim, CudaBlockDim,
56 KernelArgsPtrs, ShmemSize, CudaStream);
// Launch overload taking an already-resolved function handle (KernelFunc)
// as an untyped pointer, together with grid/block dimensions, kernel
// arguments, ShmemSize, and an untyped stream pointer.
59 DispatchResult launch(
void *KernelFunc,
LaunchDims GridDim,
60 LaunchDims BlockDim, ArrayRef<void *> KernelArgs,
61 uint64_t ShmemSize,
void *Stream)
override {
// Copy the X/Y/Z fields of the LaunchDims values into CUDA dim3 values.
62 dim3 CudaGridDim = {GridDim.
X, GridDim.
Y, GridDim.
Z};
63 dim3 CudaBlockDim = {BlockDim.
X, BlockDim.
Y, BlockDim.
Z};
// Stream arrives as an untyped pointer and is reinterpreted as cudaStream_t.
64 cudaStream_t CudaStream =
reinterpret_cast<cudaStream_t
>(Stream);
// const_cast turns the pointer returned by KernelArgs.data() into the
// void ** form passed to the launch call below.
66 void **KernelArgsPtrs =
const_cast<void **
>(KernelArgs.data());
// Arguments of the launch call: KernelFunc cast to cudaFunction_t, grid and
// block dimensions, argument pointers, ShmemSize, and the stream.
// NOTE(review): the callee name is on a line not shown here; presumably
// launchKernelFunction -- confirm.
68 reinterpret_cast<cudaFunction_t
>(KernelFunc), CudaGridDim, CudaBlockDim,
69 KernelArgsPtrs, ShmemSize, CudaStream);
// Returns the device architecture string reported by Jit.getDeviceArch().
72 StringRef getTargetArch()
const override {
return Jit.getDeviceArch(); }
76 std::optional<MemoryBufferRef> ObjectModule)
override {
// NOTE(review): the opening of this signature is not shown here; judging
// by the call made in launch this is presumably getFunctionAddress --
// confirm.
// GetKernelFunc is a lambda capturing by reference. The visible part of
// its body consults CodeCache for HashValue and, on another path, inserts
// KernelFunc under HashValue together with KernelName.
// NOTE(review): how HashValue is computed, what the cache-hit branch does,
// and how KernelFunc is produced on a miss are on lines not shown here.
77 auto GetKernelFunc = [&]() {
// The condition holds when the lookup result converts to true.
81 if (
auto KernelFunc = CodeCache.lookup(HashValue))
89 CodeCache.insert(HashValue, KernelFunc,
KernelName);
// Run the lambda to obtain the kernel function handle.
94 auto KernelFunc = GetKernelFunc();
// Print statistics for both caches: CodeCache first, then StorageCache.
99 CodeCache.printStats();
100 StorageCache.printStats();
// Reference to the JitEngineDeviceCUDA used for compilation (compileOnly)
// and for querying the device architecture (getDeviceArch).
104 JitEngineDeviceCUDA &
Jit;
// Constructor: binds Jit to JitEngineDeviceCUDA::instance() and sets
// TargetModel to TargetModelType::CUDA.
// NOTE(review): TargetModel is not declared in the visible part of this
// class; presumably it is inherited from Dispatcher -- confirm.
105 DispatcherCUDA() :
Jit(JitEngineDeviceCUDA::instance()) {
106 TargetModel = TargetModelType::CUDA;
// CodeCache: JitCache of CUfunction; in this file it is used via lookup,
// insert, and printStats, keyed by HashValue.
108 JitCache<CUfunction> CodeCache;
// StorageCache: in this file it is used via store (hash plus a
// MemoryBufferRef), lookup (returned as std::unique_ptr<MemoryBuffer>), and
// printStats.
109 JitStorageCache<CUfunction> StorageCache;
void char * KernelName
Definition CompilerInterfaceDevice.cpp:50
auto & Jit
Definition CompilerInterfaceDevice.cpp:54
#define PROTEUS_FATAL_ERROR(x)
Definition Error.h:7
Definition CppJitModule.cpp:21
HashT hash(FirstT &&First, RestTs &&...Rest)
Definition Hashing.hpp:126
cudaError_t launchKernelFunction(CUfunction KernelFunc, dim3 GridDim, dim3 BlockDim, void **KernelArgs, uint64_t ShmemSize, CUstream Stream)
Definition CoreDeviceCUDA.hpp:51
CUfunction getKernelFunctionFromImage(StringRef KernelName, const void *Image, bool RelinkGlobalsByCopy, const std::unordered_map< std::string, const void * > &VarNameToDevPtr)
Definition CoreDeviceCUDA.hpp:27
Definition Dispatcher.hpp:15
unsigned Z
Definition Dispatcher.hpp:16
unsigned Y
Definition Dispatcher.hpp:16
unsigned X
Definition Dispatcher.hpp:16