Proteus
Programmable JIT compilation and optimization for C/C++ using LLVM
Loading...
Searching...
No Matches
LLVMIRJitModule.h
Go to the documentation of this file.
1#ifndef PROTEUS_LLVMIRJITMODULE_H
2#define PROTEUS_LLVMIRJITMODULE_H
3
7
8#include <memory>
9#include <string>
10#include <type_traits>
11#include <utility>
12
13namespace proteus {
14
15struct CompiledLibrary;
16class HashT;
17
/// Selects how the IR handed to LLVMIRJitModule is interpreted.
/// NOTE(review): the per-enumerator meanings below are inferred from the
/// names only; confirm against the parsing logic in LLVMIRJitModule.cpp.
enum class LLVMIRInputKind : int {
  Auto,    ///< presumably: detect text vs. bitcode from the input itself
  TextIR,  ///< presumably: textual LLVM IR
  Bitcode, ///< presumably: LLVM bitcode
};
23
25private:
26 TargetModelType TargetModel;
27 std::string Code;
28 LLVMIRInputKind InputKind;
29 Dispatcher &Dispatch;
30 std::unique_ptr<HashT> ModuleHash;
31 std::unique_ptr<CompiledLibrary> Library;
32 bool IsCompiled = false;
33
34public:
35 explicit LLVMIRJitModule(TargetModelType TargetModel, const std::string &Code,
37 explicit LLVMIRJitModule(const std::string &Target, const std::string &Code,
40
41 void compile(bool Verify = false);
42
43 // Expose the target model so higher-level bindings can reject invalid API
44 // combinations before dispatch.
45 TargetModelType getTargetModel() const { return TargetModel; }
46
47 void setFuncAttribute(void *KernelFunc, JitFuncAttribute Attr, int Value);
48 void *getFunctionAddress(const std::string &Name);
49 void *getKernelAddress(const std::string &Name) {
50 if (!IsCompiled)
51 compile();
52
53 if (TargetModel == TargetModelType::HOST)
55 "Error: getKernelAddress() applies only to device modules");
56
57 // Kernel symbols are loaded through the same compiled image as host
58 // function symbols once target validation is complete.
59 return getFunctionAddress(Name);
60 }
61
62 DispatchResult launch(void *KernelFunc, LaunchDims GridDim,
63 LaunchDims BlockDim, void *KernelArgs[],
64 uint64_t ShmemSize, void *Stream);
65
67 if (!IsCompiled)
68 compile();
69
70 if (!Library)
71 reportFatalError("Expected non-null library after compilation");
72
73 return *Library;
74 }
75
76 template <typename Sig> struct FunctionHandle;
77 template <typename RetT, typename... ArgT>
78 struct FunctionHandle<RetT(ArgT...)> {
80 void *FuncPtr;
81
82 explicit FunctionHandle(LLVMIRJitModule &M, void *FuncPtr)
83 : M(M), FuncPtr(FuncPtr) {}
84
85 RetT run(ArgT... Args) {
86 if constexpr (std::is_void_v<RetT>) {
87 M.Dispatch.template run<RetT(ArgT...)>(FuncPtr,
88 std::forward<ArgT>(Args)...);
89 } else {
90 return M.Dispatch.template run<RetT(ArgT...)>(
91 FuncPtr, std::forward<ArgT>(Args)...);
92 }
93 }
94 };
95
96 template <typename Sig> struct KernelHandle;
97 template <typename RetT, typename... ArgT>
98 struct KernelHandle<RetT(ArgT...)> {
100 void *FuncPtr = nullptr;
101
102 explicit KernelHandle(LLVMIRJitModule &M, void *FuncPtr)
103 : M(M), FuncPtr(FuncPtr) {
104 static_assert(std::is_void_v<RetT>, "Kernel function must return void");
105 }
106
107 void setFuncAttribute(JitFuncAttribute Attr, int Value) {
108 M.setFuncAttribute(FuncPtr, Attr, Value);
109 }
110
111 auto launch(LaunchDims GridDim, LaunchDims BlockDim, uint64_t ShmemSize,
112 void *Stream, ArgT... Args) {
113 void *Ptrs[sizeof...(ArgT)] = {(void *)&Args...};
114 return M.launch(FuncPtr, GridDim, BlockDim, Ptrs, ShmemSize, Stream);
115 }
116 };
117
118 template <typename Sig>
119 FunctionHandle<Sig> getFunction(const std::string &Name) {
120 if (!IsCompiled)
121 compile();
122
123 if (!isHostTargetModel(TargetModel))
124 reportFatalError("Error: getFunction() applies only to host modules");
125
126 void *FuncPtr = getFunctionAddress(Name);
127 return FunctionHandle<Sig>(*this, FuncPtr);
128 }
129
130 template <typename Sig> KernelHandle<Sig> getKernel(const std::string &Name) {
131 if (!IsCompiled)
132 compile();
133
134 if (TargetModel == TargetModelType::HOST)
135 reportFatalError("Error: getKernel() applies only to device modules");
136
137 void *FuncPtr = getFunctionAddress(Name);
138 return KernelHandle<Sig>(*this, FuncPtr);
139 }
140};
141
142} // namespace proteus
143
144#endif // PROTEUS_LLVMIRJITMODULE_H
char int void ** Args
Definition CompilerInterfaceHost.cpp:23
Definition Dispatcher.h:75
Definition LLVMIRJitModule.h:24
void setFuncAttribute(void *KernelFunc, JitFuncAttribute Attr, int Value)
Definition LLVMIRJitModule.cpp:122
TargetModelType getTargetModel() const
Definition LLVMIRJitModule.h:45
void * getFunctionAddress(const std::string &Name)
Definition LLVMIRJitModule.cpp:127
void * getKernelAddress(const std::string &Name)
Definition LLVMIRJitModule.h:49
void compile(bool Verify=false)
Definition LLVMIRJitModule.cpp:81
DispatchResult launch(void *KernelFunc, LaunchDims GridDim, LaunchDims BlockDim, void *KernelArgs[], uint64_t ShmemSize, void *Stream)
Definition LLVMIRJitModule.cpp:134
KernelHandle< Sig > getKernel(const std::string &Name)
Definition LLVMIRJitModule.h:130
CompiledLibrary & getLibrary()
Definition LLVMIRJitModule.h:66
FunctionHandle< Sig > getFunction(const std::string &Name)
Definition LLVMIRJitModule.h:119
Definition MemoryCache.h:27
JitFuncAttribute
Definition JitFuncAttribute.h:6
TargetModelType
Definition TargetModel.h:8
LLVMIRInputKind
Definition LLVMIRJitModule.h:18
bool isHostTargetModel(TargetModelType TargetModel)
Definition TargetModel.cpp:49
void reportFatalError(const llvm::Twine &Reason, const char *FILE, unsigned Line)
Definition Error.cpp:14
Definition Dispatcher.h:22
Definition CompiledLibrary.h:18
Definition Dispatcher.h:53
LLVMIRJitModule & M
Definition LLVMIRJitModule.h:79
FunctionHandle(LLVMIRJitModule &M, void *FuncPtr)
Definition LLVMIRJitModule.h:82
void * FuncPtr
Definition LLVMIRJitModule.h:80
RetT run(ArgT... Args)
Definition LLVMIRJitModule.h:85
Definition LLVMIRJitModule.h:76
auto launch(LaunchDims GridDim, LaunchDims BlockDim, uint64_t ShmemSize, void *Stream, ArgT... Args)
Definition LLVMIRJitModule.h:111
KernelHandle(LLVMIRJitModule &M, void *FuncPtr)
Definition LLVMIRJitModule.h:102
LLVMIRJitModule & M
Definition LLVMIRJitModule.h:99
void setFuncAttribute(JitFuncAttribute Attr, int Value)
Definition LLVMIRJitModule.h:107
Definition LLVMIRJitModule.h:96