1#ifndef PROTEUS_JIT_DEV_HPP
2#define PROTEUS_JIT_DEV_HPP
4#include <llvm/ADT/StringRef.h>
5#include <llvm/Bitcode/BitcodeWriter.h>
6#include <llvm/IR/IRBuilder.h>
7#include <llvm/IR/Module.h>
8#include <llvm/IR/Verifier.h>
9#include <llvm/Support/Debug.h>
10#include <llvm/Support/MemoryBuffer.h>
11#include <llvm/TargetParser/Host.h>
12#include <llvm/TargetParser/Triple.h>
28 std::unique_ptr<LLVMContext> Ctx;
29 std::unique_ptr<Module> Mod;
30 std::unique_ptr<MemoryBuffer> ObjectModule;
32 std::deque<std::unique_ptr<FuncBase>> Functions;
34 std::string TargetTriple;
38 bool IsCompiled =
false;
40 template <
typename... ArgT>
struct KernelHandle {
41 Func<void, ArgT...> &F;
46 uint64_t ShmemBytes,
void *Stream, ArgT...
Args) {
48 void *Ptrs[
sizeof...(ArgT)] = {(
void *)&
Args...};
53 auto GetKernelFunc = [&]() {
56 if (
auto KernelFunc = F.getCompiledFunc()) {
63 auto KernelFunc =
reinterpret_cast<decltype(F.getCompiledFunc())
>(
66 F.setCompiledFunc(KernelFunc);
71 return M.Dispatch.
launch(
reinterpret_cast<void *
>(GetKernelFunc()), Grid,
72 Block, Ptrs, ShmemBytes, Stream);
75 FuncBase *operator->() {
return &F; }
78 bool isDeviceModule() {
84 switch (TargetModel) {
86 NamedMDNode *MD = Mod->getOrInsertNamedMetadata(
"nvvm.annotations");
88 Metadata *MDVals[] = {
90 MDString::get(*Ctx,
"kernel"),
91 ConstantAsMetadata::get(ConstantInt::get(Type::getInt32Ty(*Ctx), 1))};
93 MD->addOperand(MDNode::get(*Ctx, MDVals));
96 F.
getFunction()->addFnAttr(Attribute::get(*Ctx,
"kernel"));
100 F.
getFunction()->setCallingConv(CallingConv::AMDGPU_KERNEL);
111 : Ctx{
std::make_unique<LLVMContext>()},
112 Mod{
std::make_unique<Module>(
"JitModule", *Ctx)},
123 template <
typename RetT,
typename... ArgT>
127 "The module is compiled, no further code can be added");
129 Mod->setTargetTriple(TargetTriple);
133 Function *F = dyn_cast<Function>(FC.getCallee());
136 auto TypedFn = std::make_unique<
Func<RetT, ArgT...>>(*
this, FC, Dispatch);
137 Func<RetT, ArgT...> &TypedFnRef = *TypedFn;
138 std::unique_ptr<FuncBase> &Fn = Functions.emplace_back(std::move(TypedFn));
148 template <
typename... ArgT> KernelHandle<ArgT...>
addKernel(StringRef Name) {
151 "The module is compiled, no further code can be added");
153 if (!isDeviceModule())
156 Mod->setTargetTriple(TargetTriple);
160 Function *F = dyn_cast<Function>(FC.getCallee());
163 auto TypedFn = std::make_unique<
Func<void, ArgT...>>(*
this, FC, Dispatch);
164 Func<void, ArgT...> &TypedFnRef = *TypedFn;
165 std::unique_ptr<FuncBase> &Fn = Functions.emplace_back(std::move(TypedFn));
170 return KernelHandle<ArgT...>{TypedFnRef, *
this};
178 if (verifyModule(*Mod, &errs())) {
182 SmallVector<char, 0> Buffer;
183 raw_svector_ostream OS(Buffer);
184 WriteBitcodeToFile(*Mod, OS);
191 ModuleHash =
hash(StringRef{Buffer.data(), Buffer.size()});
192 for (
auto &JitF : Functions) {
193 JitF->setName(JitF->getName().str() +
"$" + ModuleHash.
toString());
201 ObjectModule = Dispatch.
compile(std::move(Ctx), std::move(Mod), ModuleHash);
213 return ObjectModule->getMemBufferRef();
220 void print() { Mod->print(outs(),
nullptr); }
223template <
typename RetT,
typename... ArgT>
226 Module &M = *F->getParent();
227 LLVMContext &Ctx = F->getContext();
231 auto *Call =
IRB.CreateCall(Callee);
237template <
typename RetT,
typename... ArgT>
240 Module &M = *F->getParent();
241 LLVMContext &Ctx = F->getContext();
244 IRB.CreateCall(Callee);
247template <
typename RetT,
typename... ArgT>
254 "Target is a GPU model, cannot directly run functions, use launch()");
256 return Dispatch.
run<RetT>(getName(), J.getObjectModuleRef(),
Args...);
char int void ** Args
Definition CompilerInterfaceHost.cpp:21
#define PROTEUS_FATAL_ERROR(x)
Definition Error.h:7
Definition Dispatcher.hpp:46
virtual std::unique_ptr< MemoryBuffer > lookupObjectModule(HashT ModuleHash)=0
virtual std::unique_ptr< MemoryBuffer > compile(std::unique_ptr< LLVMContext > Ctx, std::unique_ptr< Module > M, HashT ModuleHash)=0
virtual DispatchResult launch(StringRef KernelName, LaunchDims GridDim, LaunchDims BlockDim, ArrayRef< void * > KernelArgs, uint64_t ShmemSize, void *Stream, std::optional< MemoryBufferRef > ObjectModule)=0
auto run(StringRef FuncName, std::optional< MemoryBufferRef > ObjectModule, ArgT &&...Args)
Definition Dispatcher.hpp:75
virtual void * getFunctionAddress(StringRef FunctionName, std::optional< MemoryBufferRef > ObjectModule)=0
void declArgs()
Definition Func.hpp:109
std::enable_if_t<!std::is_void_v< RetT >, Var & > call(StringRef Name)
Definition JitFrontend.hpp:224
Var & declVarInternal(StringRef Name, Type *Ty, Type *PointerElemType=nullptr)
Definition Func.cpp:18
IRBuilder IRB
Definition Func.hpp:26
Function * getFunction()
Definition Func.cpp:60
std::string Name
Definition Func.hpp:31
RetT operator()(ArgT... Args)
Definition JitFrontend.hpp:248
Definition Hashing.hpp:20
std::string toString() const
Definition Hashing.hpp:28
Definition JitFrontend.hpp:26
const Dispatcher & getDispatcher() const
Definition JitFrontend.hpp:216
JitModule & operator=(JitModule &&)=delete
JitModule(const JitModule &)=delete
bool isCompiled() const
Definition JitFrontend.hpp:144
JitModule(StringRef Target="host")
Definition JitFrontend.hpp:110
Func< RetT, ArgT... > & addFunction(StringRef Name)
Definition JitFrontend.hpp:124
TargetModelType getTargetModel() const
Definition JitFrontend.hpp:218
std::optional< MemoryBufferRef > getObjectModuleRef() const
Definition JitFrontend.hpp:207
const Module & getModule() const
Definition JitFrontend.hpp:146
HashT getModuleHash() const
Definition JitFrontend.hpp:205
JitModule(JitModule &&)=delete
JitModule & operator=(const JitModule &)=delete
KernelHandle< ArgT... > addKernel(StringRef Name)
Definition JitFrontend.hpp:148
void print()
Definition JitFrontend.hpp:220
void compile(bool Verify=false)
Definition JitFrontend.hpp:173
Definition CppJitModule.cpp:21
TargetModelType
Definition TargetModel.hpp:14
HashT hash(FirstT &&First, RestTs &&...Rest)
Definition Hashing.hpp:126
TargetModelType parseTargetModel(StringRef Target)
Definition TargetModel.hpp:16
std::string getTargetTriple(TargetModelType Model)
Definition TargetModel.hpp:32
Definition Hashing.hpp:147
Definition Dispatcher.hpp:15
Definition TypeMap.hpp:13
void storeValue(Value *Val)
Definition Var.cpp:102