1#ifndef PROTEUS_JIT_DEV_HPP
2#define PROTEUS_JIT_DEV_HPP
4#include <llvm/ADT/StringRef.h>
5#include <llvm/IR/IRBuilder.h>
6#include <llvm/IR/Module.h>
7#include <llvm/IR/Verifier.h>
8#include <llvm/Support/Debug.h>
9#include <llvm/Support/MemoryBuffer.h>
10#include <llvm/TargetParser/Host.h>
11#include <llvm/TargetParser/Triple.h>
27 std::unique_ptr<LLVMContext> Ctx;
28 std::unique_ptr<Module> Mod;
29 std::unique_ptr<MemoryBuffer> CompiledObject;
31 std::deque<Func> Functions;
33 std::string TargetTriple;
36 bool IsCompiled =
false;
38 template <
typename... ArgT>
struct KernelHandle {
44 uint64_t ShmemBytes,
void *Stream, ArgT...
Args) {
46 auto Tup = std::make_tuple(
static_cast<ArgT
>(
Args)...);
49 std::array<
void *,
sizeof...(ArgT)> Ptrs;
53 ((Ptrs[I++] = (
void *)&Elts), ...);
59 return M.
launch(F, Grid, Block, Ptrs, ShmemBytes, Stream);
62 Func *operator->() {
return &F; }
66 if (Target ==
"host" || Target ==
"native") {
70 if (Target ==
"cuda") {
74 if (Target ==
"hip") {
84 return sys::getProcessTriple();
86 return "nvptx64-nvidia-cuda";
88 return "amdgcn-amd-amdhsa";
94 bool isDeviceModule() {
99 void setKernel(
Func &F) {
100 switch (TargetModel) {
102 NamedMDNode *MD = Mod->getOrInsertNamedMetadata(
"nvvm.annotations");
104 Metadata *MDVals[] = {
106 MDString::get(*Ctx,
"kernel"),
107 ConstantAsMetadata::get(ConstantInt::get(Type::getInt32Ty(*Ctx), 1))};
109 MD->addOperand(MDNode::get(*Ctx, MDVals));
112 F.
getFunction()->addFnAttr(Attribute::get(*Ctx,
"kernel"));
116 F.
getFunction()->setCallingConv(CallingConv::AMDGPU_KERNEL);
127 : Ctx{
std::make_unique<LLVMContext>()},
128 Mod{
std::make_unique<Module>(
"JitModule", *Ctx)},
129 TargetModel{getTargetModel(Target)},
130 TargetTriple(getTargetTriple(TargetModel)),
131 Dispatch(
Dispatcher::getDispatcher(TargetModel)) {}
140 Mod->setTargetTriple(TargetTriple);
144 Function *F = dyn_cast<Function>(FC.getCallee());
147 auto &Fn = Functions.emplace_back(FC);
149 Fn.declArgs<ArgT...>();
155 template <
typename... ArgT> KernelHandle<ArgT...>
addKernel(StringRef Name) {
156 if (!isDeviceModule())
159 Mod->setTargetTriple(TargetTriple);
163 Function *F = dyn_cast<Function>(FC.getCallee());
166 auto &Fn = Functions.emplace_back(FC);
168 Fn.declArgs<ArgT...>();
171 return KernelHandle<ArgT...>{Fn, *
this};
176 if (verifyModule(*Mod, &errs())) {
180 Dispatch.
compile(std::move(Mod));
184 template <
typename Ret,
typename... ArgT> Ret
run(
Func &F, ArgT...
Args) {
190 template <
typename... ArgT>
193 for (
auto &Fn : Functions) {
194 if (Fn.getName() == Name)
203 ArrayRef<void *> KernelArgs, uint64_t ShmemSize,
void *Stream) {
206 return Dispatch.
launch(F.
getName(), GridDim, BlockDim, KernelArgs,
211 ArrayRef<void *> KernelArgs, uint64_t ShmemSize,
void *Stream) {
215 return Dispatch.
launch(
KernelName, GridDim, BlockDim, KernelArgs, ShmemSize,
219 void print() { Mod->print(outs(),
nullptr); }
222template <
typename RetT,
typename... ArgT>
void Func::call(StringRef Name) {
224 Module &M = *F->getParent();
225 LLVMContext &Ctx = F->getContext();
228 IRB.CreateCall(Callee);
void char * KernelName
Definition CompilerInterfaceDevice.cpp:50
char int void ** Args
Definition CompilerInterfaceHost.cpp:20
TargetModelType
Definition Dispatcher.hpp:12
#define PROTEUS_FATAL_ERROR(x)
Definition Error.h:4
Definition Dispatcher.hpp:45
Ret run(StringRef FuncName, ArgT... Args)
Definition Dispatcher.hpp:57
virtual void compile(std::unique_ptr< Module > M)=0
virtual DispatchResult launch(StringRef KernelName, LaunchDims GridDim, LaunchDims BlockDim, ArrayRef< void * > KernelArgs, uint64_t ShmemSize, void *Stream)=0
Function * getFunction()
Definition Func.cpp:58
StringRef getName()
Definition Func.hpp:117
void call(StringRef Name)
Definition JitFrontend.hpp:222
Definition JitFrontend.hpp:25
Ret run(Func &F, ArgT... Args)
Definition JitFrontend.hpp:184
auto launch(Func &F, LaunchDims GridDim, LaunchDims BlockDim, ArrayRef< void * > KernelArgs, uint64_t ShmemSize, void *Stream)
Definition JitFrontend.hpp:202
JitModule & operator=(JitModule &&)=delete
JitModule(const JitModule &)=delete
JitModule(StringRef Target="host")
Definition JitFrontend.hpp:126
Func & addFunction(StringRef Name)
Definition JitFrontend.hpp:139
const Module & getModule() const
Definition JitFrontend.hpp:153
JitModule(JitModule &&)=delete
JitModule & operator=(const JitModule &)=delete
auto launch(StringRef KernelName, LaunchDims GridDim, LaunchDims BlockDim, ArrayRef< void * > KernelArgs, uint64_t ShmemSize, void *Stream)
Definition JitFrontend.hpp:210
KernelHandle< ArgT... > addKernel(StringRef Name)
Definition JitFrontend.hpp:155
void print()
Definition JitFrontend.hpp:219
void compile(bool Verify=false)
Definition JitFrontend.hpp:174
KernelHandle< ArgT... > getKernelHandle(StringRef Name)
Definition JitFrontend.hpp:191
Definition Dispatcher.cpp:14
Definition Hashing.hpp:94
Definition Dispatcher.hpp:14
Definition TypeMap.hpp:13