Proteus
Programmable JIT compilation and optimization for C/C++ using LLVM
Loading...
Searching...
No Matches
CoreLLVMCUDA.hpp
Go to the documentation of this file.
1#ifndef PROTEUS_CORE_LLVM_CUDA_HPP
2#define PROTEUS_CORE_LLVM_CUDA_HPP
3
4#include <llvm/ADT/SmallVector.h>
5#include <llvm/ADT/StringRef.h>
6#include <llvm/CodeGen/MachineModuleInfo.h>
7#include <llvm/IR/LegacyPassManager.h>
8#include <llvm/IR/Module.h>
9#include <llvm/Support/MemoryBufferRef.h>
10#include <llvm/Support/TargetSelect.h>
11#include <llvm/Target/TargetMachine.h>
12
14#include "proteus/Logger.hpp"
15#include "proteus/UtilsCUDA.h"
16
17namespace proteus {
18
19using namespace llvm;
20
21namespace detail {
22
23inline const SmallVector<StringRef> &gridDimXFnName() {
24 static SmallVector<StringRef> Names = {"llvm.nvvm.read.ptx.sreg.nctaid.x"};
25 return Names;
26}
27
28inline const SmallVector<StringRef> &gridDimYFnName() {
29 static SmallVector<StringRef> Names = {"llvm.nvvm.read.ptx.sreg.nctaid.y"};
30 return Names;
31}
32
33inline const SmallVector<StringRef> &gridDimZFnName() {
34 static SmallVector<StringRef> Names = {"llvm.nvvm.read.ptx.sreg.nctaid.z"};
35 return Names;
36}
37
38inline const SmallVector<StringRef> &blockDimXFnName() {
39 static SmallVector<StringRef> Names = {"llvm.nvvm.read.ptx.sreg.ntid.x"};
40 return Names;
41}
42
43inline const SmallVector<StringRef> &blockDimYFnName() {
44 static SmallVector<StringRef> Names = {"llvm.nvvm.read.ptx.sreg.ntid.y"};
45 return Names;
46}
47
48inline const SmallVector<StringRef> &blockDimZFnName() {
49 static SmallVector<StringRef> Names = {"llvm.nvvm.read.ptx.sreg.ntid.z"};
50 return Names;
51}
52
53inline const SmallVector<StringRef> &blockIdxXFnName() {
54 static SmallVector<StringRef> Names = {"llvm.nvvm.read.ptx.sreg.ctaid.x"};
55 return Names;
56}
57
58inline const SmallVector<StringRef> &blockIdxYFnName() {
59 static SmallVector<StringRef> Names = {"llvm.nvvm.read.ptx.sreg.ctaid.y"};
60 return Names;
61}
62
63inline const SmallVector<StringRef> &blockIdxZFnName() {
64 static SmallVector<StringRef> Names = {"llvm.nvvm.read.ptx.sreg.ctaid.z"};
65 return Names;
66}
67
68inline const SmallVector<StringRef> &threadIdxXFnName() {
69 static SmallVector<StringRef> Names = {"llvm.nvvm.read.ptx.sreg.tid.x"};
70 return Names;
71}
72
73inline const SmallVector<StringRef> &threadIdxYFnName() {
74 static SmallVector<StringRef> Names = {"llvm.nvvm.read.ptx.sreg.tid.y"};
75 return Names;
76}
77
78inline const SmallVector<StringRef> &threadIdxZFnName() {
79 static SmallVector<StringRef> Names = {"llvm.nvvm.read.ptx.sreg.tid.z"};
80 return Names;
81}
82
83} // namespace detail
84
85inline void setLaunchBoundsForKernel(Module &M, Function &F, size_t GridSize,
86 int BlockSize) {
87 NamedMDNode *NvvmAnnotations = M.getNamedMetadata("nvvm.annotations");
88 assert(NvvmAnnotations && "Expected non-null nvvm.annotations metadata");
89 // TODO: fix hardcoded 1024 as the maximum, by reading device
90 // properties.
91 // TODO: set min GridSize.
92 int MaxThreads = std::min(1024, BlockSize);
93 auto *FuncMetadata = ConstantAsMetadata::get(&F);
94 auto *MaxntidxMetadata = MDString::get(M.getContext(), "maxntidx");
95 auto *MaxThreadsMetadata = ConstantAsMetadata::get(
96 ConstantInt::get(Type::getInt32Ty(M.getContext()), MaxThreads));
97
98 // Replace if the metadata exists.
99 for (auto *MetadataNode : NvvmAnnotations->operands()) {
100 // Expecting 3 operands ptr, desc, i32 value.
101 assert(MetadataNode->getNumOperands() == 3);
102
103 auto *PtrMetadata = MetadataNode->getOperand(0).get();
104 auto *DescMetadata = MetadataNode->getOperand(1).get();
105 if (PtrMetadata == FuncMetadata && MaxntidxMetadata == DescMetadata) {
106 MetadataNode->replaceOperandWith(2, MaxThreadsMetadata);
107 return;
108 }
109 }
110
111 // Otherwise create the metadata and insert.
112 Metadata *MDVals[] = {FuncMetadata, MaxntidxMetadata, MaxThreadsMetadata};
113 NvvmAnnotations->addOperand(MDNode::get(M.getContext(), MDVals));
114}
115inline void codegenPTX(Module &M, StringRef DeviceArch,
116 SmallVectorImpl<char> &PTXStr) {
117 // TODO: It is possbile to use PTX directly through the CUDA PTX JIT
118 // interface. Maybe useful if we can re-link globals using the CUDA API.
119 // Check this reference for PTX JIT caching:
120 // https://developer.nvidia.com/blog/cuda-pro-tip-understand-fat-binaries-jit-caching/
121 // Interesting env vars: CUDA_CACHE_DISABLE, CUDA_CACHE_MAXSIZE,
122 // CUDA_CACHE_PATH, CUDA_FORCE_PTX_JIT.
123 auto TMExpected = proteus::detail::createTargetMachine(M, DeviceArch);
124 if (!TMExpected)
125 PROTEUS_FATAL_ERROR(toString(TMExpected.takeError()));
126
127 std::unique_ptr<TargetMachine> TM = std::move(*TMExpected);
128 TargetLibraryInfoImpl TLII(Triple(M.getTargetTriple()));
129
130 legacy::PassManager PM;
131 PM.add(new TargetLibraryInfoWrapperPass(TLII));
132 MachineModuleInfoWrapperPass *MMIWP = new MachineModuleInfoWrapperPass(
133 reinterpret_cast<LLVMTargetMachine *>(TM.get()));
134
135 raw_svector_ostream PTXOS(PTXStr);
136#if LLVM_VERSION_MAJOR >= 18
137 TM->addPassesToEmitFile(PM, PTXOS, nullptr, CodeGenFileType::AssemblyFile,
138 /* DisableVerify */ false, MMIWP);
139#else
140 TM->addPassesToEmitFile(PM, PTXOS, nullptr, CGFT_AssemblyFile,
141 /* DisableVerify */ false, MMIWP);
142#endif
143
144 PM.run(M);
145}
146
147inline std::unique_ptr<MemoryBuffer>
148codegenObject(Module &M, StringRef DeviceArch,
149 SmallPtrSetImpl<void *> &GlobalLinkedBinaries,
150 [[maybe_unused]] bool UseRTC = true) {
151 assert(UseRTC && "Expected RTC compilation true for CUDA");
152 SmallVector<char, 4096> PTXStr;
153 size_t BinSize;
154
155 codegenPTX(M, DeviceArch, PTXStr);
156 PTXStr.push_back('\0');
157
158 nvPTXCompilerHandle PTXCompiler;
160 nvPTXCompilerCreate(&PTXCompiler, PTXStr.size(), PTXStr.data()));
161 std::string ArchOpt = ("--gpu-name=" + DeviceArch).str();
162 std::string RDCOption = "";
163 if (!GlobalLinkedBinaries.empty())
164 RDCOption = "-c";
165#if PROTEUS_ENABLE_DEBUG
166 const char *CompileOptions[] = {ArchOpt.c_str(), "--verbose",
167 RDCOption.c_str()};
168 size_t NumCompileOptions = 2 + (RDCOption.empty() ? 0 : 1);
169#else
170 const char *CompileOptions[] = {ArchOpt.c_str(), RDCOption.c_str()};
171 size_t NumCompileOptions = 1 + (RDCOption.empty() ? 0 : 1);
172#endif
174 nvPTXCompilerCompile(PTXCompiler, NumCompileOptions, CompileOptions));
176 nvPTXCompilerGetCompiledProgramSize(PTXCompiler, &BinSize));
177 auto ObjBuf = WritableMemoryBuffer::getNewUninitMemBuffer(BinSize);
179 nvPTXCompilerGetCompiledProgram(PTXCompiler, ObjBuf->getBufferStart()));
180#if PROTEUS_ENABLE_DEBUG
181 {
182 size_t LogSize;
184 nvPTXCompilerGetInfoLogSize(PTXCompiler, &LogSize));
185 auto Log = std::make_unique<char[]>(LogSize);
187 nvPTXCompilerGetInfoLog(PTXCompiler, Log.get()));
188 Logger::logs("proteus") << "=== nvPTXCompiler Log\n" << Log.get() << "\n";
189 }
190#endif
191 proteusNvPTXCompilerErrCheck(nvPTXCompilerDestroy(&PTXCompiler));
192
193 std::unique_ptr<MemoryBuffer> FinalObjBuf;
194 if (!GlobalLinkedBinaries.empty()) {
195 // Create CUDA context if needed. This is required by threaded async
196 // compilation.
197 CUcontext CUCtx;
198 proteusCuErrCheck(cuCtxGetCurrent(&CUCtx));
199 if (!CUCtx) {
200 CUdevice CUDev;
201 CUresult CURes = cuCtxGetDevice(&CUDev);
202 if (CURes == CUDA_ERROR_INVALID_CONTEXT or !CUDev)
203 proteusCuErrCheck(cuDeviceGet(&CUDev, 0));
204
205 proteusCuErrCheck(cuCtxGetCurrent(&CUCtx));
206 proteusCuErrCheck(cuCtxCreate(&CUCtx, 0, CUDev));
207 }
208
209 // TODO: re-implement using the more recent nvJitLink interface.
210 CUlinkState CULinkState;
211 proteusCuErrCheck(cuLinkCreate(0, nullptr, nullptr, &CULinkState));
212 for (auto *Ptr : GlobalLinkedBinaries) {
213 // We do not know the size of the binary but the CUDA API just needs a
214 // non-zero argument.
215 proteusCuErrCheck(cuLinkAddData(CULinkState, CU_JIT_INPUT_FATBINARY, Ptr,
216 1, "", 0, 0, 0));
217 }
218
219 // Again using a non-zero argument, though we can get the size from the ptx
220 // compiler.
221 proteusCuErrCheck(cuLinkAddData(
222 CULinkState, CU_JIT_INPUT_FATBINARY,
223 static_cast<void *>(ObjBuf->getBufferStart()), 1, "", 0, 0, 0));
224
225 void *BinOut;
226 size_t BinSize;
227 proteusCuErrCheck(cuLinkComplete(CULinkState, &BinOut, &BinSize));
228 FinalObjBuf = MemoryBuffer::getMemBufferCopy(
229 StringRef{static_cast<char *>(BinOut), BinSize});
230 } else {
231 FinalObjBuf = std::move(ObjBuf);
232 }
233
234 return FinalObjBuf;
235}
236
237} // namespace proteus
238
239#endif
#define PROTEUS_FATAL_ERROR(x)
Definition Error.h:4
#define proteusNvPTXCompilerErrCheck(CALL)
Definition UtilsCUDA.h:39
#define proteusCuErrCheck(CALL)
Definition UtilsCUDA.h:28
static llvm::raw_ostream & logs(const std::string &Name)
Definition Logger.hpp:18
const SmallVector< StringRef > & threadIdxXFnName()
Definition CoreLLVMCUDA.hpp:68
const SmallVector< StringRef > & gridDimYFnName()
Definition CoreLLVMCUDA.hpp:28
const SmallVector< StringRef > & threadIdxZFnName()
Definition CoreLLVMCUDA.hpp:78
const SmallVector< StringRef > & blockIdxZFnName()
Definition CoreLLVMCUDA.hpp:63
const SmallVector< StringRef > & gridDimZFnName()
Definition CoreLLVMCUDA.hpp:33
const SmallVector< StringRef > & gridDimXFnName()
Definition CoreLLVMCUDA.hpp:23
const SmallVector< StringRef > & blockIdxXFnName()
Definition CoreLLVMCUDA.hpp:53
Expected< std::unique_ptr< TargetMachine > > createTargetMachine(Module &M, StringRef Arch, unsigned OptLevel=3)
Definition CoreLLVM.hpp:48
const SmallVector< StringRef > & threadIdxYFnName()
Definition CoreLLVMCUDA.hpp:73
const SmallVector< StringRef > & blockIdxYFnName()
Definition CoreLLVMCUDA.hpp:58
const SmallVector< StringRef > & blockDimYFnName()
Definition CoreLLVMCUDA.hpp:43
const SmallVector< StringRef > & blockDimZFnName()
Definition CoreLLVMCUDA.hpp:48
const SmallVector< StringRef > & blockDimXFnName()
Definition CoreLLVMCUDA.hpp:38
Definition JitEngine.cpp:20
void setLaunchBoundsForKernel(Module &M, Function &F, size_t GridSize, int BlockSize)
Definition CoreLLVMCUDA.hpp:85
void codegenPTX(Module &M, StringRef DeviceArch, SmallVectorImpl< char > &PTXStr)
Definition CoreLLVMCUDA.hpp:115
std::unique_ptr< MemoryBuffer > codegenObject(Module &M, StringRef DeviceArch, SmallPtrSetImpl< void * > &GlobalLinkedBinaries, bool UseRTC=true)
Definition CoreLLVMCUDA.hpp:148