1#ifndef PROTEUS_FRONTEND_DISPATCHER_HIP_H
2#define PROTEUS_FRONTEND_DISPATCHER_HIP_H
12#include <llvm/Bitcode/BitcodeReader.h>
13#include <llvm/Linker/Linker.h>
14#include <llvm/Support/FileSystem.h>
15#include <llvm/Support/MemoryBuffer.h>
19class DispatcherHIP :
public Dispatcher {
21 static DispatcherHIP &instance() {
22 static DispatcherHIP D;
26 std::unique_ptr<MemoryBuffer> compile(std::unique_ptr<LLVMContext> Ctx,
27 std::unique_ptr<Module> Mod,
28 const HashT &ModuleHash,
29 bool DisableIROpt =
false)
override {
33 auto CtxOwner = std::move(Ctx);
34 auto ModOwner = std::move(Mod);
36 auto LoadBitcode = [&](
const llvm::SmallString<256> &Path) {
37 auto BufferOrErr = llvm::MemoryBuffer::getFile(Path);
38 if (!BufferOrErr || !BufferOrErr.get())
41 auto Parsed = llvm::parseBitcodeFile(
42 BufferOrErr->get()->getMemBufferRef(), ModOwner->getContext());
46 return std::move(Parsed.get());
49 auto AppendBitcodePath =
50 [&](llvm::SmallVectorImpl<llvm::SmallString<256>> &Paths,
51 llvm::StringRef Filename) {
52 llvm::SmallString<256> Path{PROTEUS_ROCM_BITCODE_DIR};
53 llvm::sys::path::append(Path, Filename);
54 Paths.push_back(std::move(Path));
57 auto Exists = [&](llvm::StringRef Filename) ->
bool {
58 llvm::SmallString<256> Path{PROTEUS_ROCM_BITCODE_DIR};
59 llvm::sys::path::append(Path, Filename);
60 return llvm::sys::fs::exists(Path);
63 auto PickFirstExisting =
64 [&](std::initializer_list<llvm::StringRef> Candidates)
66 for (
auto C : Candidates) {
75 llvm::SmallVector<llvm::SmallString<256>, 8> LibsToLink;
76 AppendBitcodePath(LibsToLink,
"ocml.bc");
77 AppendBitcodePath(LibsToLink,
"ockl.bc");
80 if (
auto Abi = PickFirstExisting({
"oclc_abi_version_600.bc",
81 "oclc_abi_version_500.bc",
82 "oclc_abi_version_400.bc"});
84 AppendBitcodePath(LibsToLink, Abi);
87 std::string(
"DispatchHIP: missing oclc ABI bitcode under ") +
88 PROTEUS_ROCM_BITCODE_DIR +
89 " (expected oclc_abi_version_{600,500,400}.bc)");
93 const std::string DeviceArch =
Jit.getDeviceArch().str();
94 if (!llvm::StringRef{DeviceArch}.starts_with(
"gfx"))
97 const llvm::StringRef IsaSuffix = llvm::StringRef{DeviceArch}.drop_front(3);
98 const std::string IsaFile = (
"oclc_isa_version_" + IsaSuffix +
".bc").str();
101 IsaFile +
" under " + PROTEUS_ROCM_BITCODE_DIR +
102 " (DeviceArch=" + DeviceArch +
")");
103 AppendBitcodePath(LibsToLink, IsaFile);
106 AppendBitcodePath(LibsToLink,
"oclc_unsafe_math_off.bc");
107 AppendBitcodePath(LibsToLink,
"oclc_finite_only_off.bc");
108 AppendBitcodePath(LibsToLink,
"oclc_daz_opt_off.bc");
109 AppendBitcodePath(LibsToLink,
"oclc_correctly_rounded_sqrt_on.bc");
112 const bool IsWave32 = llvm::StringRef{DeviceArch}.starts_with(
"gfx10") ||
113 llvm::StringRef{DeviceArch}.starts_with(
"gfx11") ||
114 llvm::StringRef{DeviceArch}.starts_with(
"gfx12");
115 AppendBitcodePath(LibsToLink, IsWave32 ?
"oclc_wavefrontsize64_off.bc"
116 :
"oclc_wavefrontsize64_on.bc");
118 llvm::Linker Linker{*ModOwner};
119 for (
const auto &Path : LibsToLink) {
120 auto LibMod = LoadBitcode(Path);
121 Linker.linkInModule(std::move(LibMod),
122 llvm::Linker::Flags::LinkOnlyNeeded);
125 std::unique_ptr<MemoryBuffer> ObjectModule =
126 Jit.compileOnly(*ModOwner, DisableIROpt);
131 ModuleHash, CacheEntry::staticObject(ObjectModule->getMemBufferRef()));
136 std::unique_ptr<CompiledLibrary>
137 lookupCompiledLibrary(
const HashT &ModuleHash)
override {
138 return ObjectCache->lookup(ModuleHash);
141 DispatchResult launch(
void *KernelFunc,
LaunchDims GridDim,
143 uint64_t ShmemSize,
void *Stream)
override {
145 dim3 HipGridDim = {GridDim.
X, GridDim.
Y, GridDim.
Z};
146 dim3 HipBlockDim = {BlockDim.
X, BlockDim.
Y, BlockDim.
Z};
147 hipStream_t HipStream =
reinterpret_cast<hipStream_t
>(Stream);
150 reinterpret_cast<hipFunction_t
>(KernelFunc), HipGridDim, HipBlockDim,
151 KernelArgs, ShmemSize, HipStream);
/// Returns the target device architecture string by forwarding to the
/// HIP JIT engine (`Jit.getDeviceArch()`); no caching or transformation
/// is done here.
154 StringRef getDeviceArch()
const override {
return Jit.getDeviceArch(); }
157 CodeCache.printStats();
158 CodeCache.printKernelTrace();
159 ObjectCache->printStats();
162 void *getFunctionAddress(
const std::string &
KernelName,
163 const HashT &ModuleHash,
164 CompiledLibrary &Library)
override {
165 TIMESCOPE(DispatcherHIP, getFunctionAddress);
166 auto GetKernelFunc = [&]() {
170 if (
auto KernelFunc = CodeCache.lookup(HashValue))
174 KernelName, Library.ObjectModule->getBufferStart(),
178 CodeCache.insert(HashValue, KernelFunc,
KernelName);
183 auto KernelFunc = GetKernelFunc();
187 void registerDynamicLibrary(
const HashT &,
const std::string &)
override {
191 void registerObject(
const HashT &HashValue,
192 const llvm::MemoryBufferRef &Obj)
override {
193 ObjectCache->store(HashValue, CacheEntry::staticObject(Obj));
197 JitEngineDeviceHIP &
Jit;
200 Jit(JitEngineDeviceHIP::instance()) {}
201 MemoryCache<hipFunction_t> CodeCache{
"DispatcherHIP"};
void char * KernelName
Definition CompilerInterfaceDevice.cpp:55
JitEngineHost & Jit
Definition CompilerInterfaceHost.cpp:26
#define TIMESCOPE(...)
Definition TimeTracing.h:66
Definition MemoryCache.h:27
TargetModelType
Definition TargetModel.h:8
HashT hash(FirstT &&First, RestTs &&...Rest)
Definition Hashing.h:168
void reportFatalError(const llvm::Twine &Reason, const char *FILE, unsigned Line)
Definition Error.cpp:14
cudaError_t launchKernelFunction(CUfunction KernelFunc, dim3 GridDim, dim3 BlockDim, void **KernelArgs, uint64_t ShmemSize, CUstream Stream)
Definition CoreDeviceCUDA.h:78
CUfunction getKernelFunctionFromImage(StringRef KernelName, const void *Image, bool RelinkGlobalsByCopy, const std::unordered_map< std::string, GlobalVarInfo > &VarNameToGlobalInfo)
Definition CoreDeviceCUDA.h:50
Definition Dispatcher.h:22
unsigned Z
Definition Dispatcher.h:23
unsigned Y
Definition Dispatcher.h:23
unsigned X
Definition Dispatcher.h:23