1#ifndef PROTEUS_JIT_DEV_HPP
2#define PROTEUS_JIT_DEV_HPP
4#include <llvm/ADT/StringRef.h>
5#include <llvm/Bitcode/BitcodeWriter.h>
6#include <llvm/IR/IRBuilder.h>
7#include <llvm/IR/Module.h>
8#include <llvm/IR/Verifier.h>
9#include <llvm/Support/Debug.h>
10#include <llvm/Support/MemoryBuffer.h>
11#include <llvm/TargetParser/Host.h>
12#include <llvm/TargetParser/Triple.h>
30 std::unique_ptr<LLVMContext> Ctx;
31 std::unique_ptr<Module> Mod;
32 std::unique_ptr<CompiledLibrary> Library;
34 std::deque<std::unique_ptr<FuncBase>> Functions;
36 std::string TargetTriple;
40 bool IsCompiled =
false;
42 template <
typename...
ArgT>
struct KernelHandle;
44 template <
typename RetT,
typename...
ArgT>
47 auto TypedFn = std::make_unique<
Func<RetT,
ArgT...>>(*
this, FC, Dispatch);
49 std::unique_ptr<FuncBase> &Fn = Functions.emplace_back(std::move(
TypedFn));
54 template <
typename...
ArgT>
59 std::unique_ptr<FuncBase> &Fn = Functions.emplace_back(std::move(
TypedFn));
67 template <
typename...
ArgT>
struct KernelHandle {
73 if (!M.isDeviceModule())
79#if PROTEUS_ENABLE_CUDA || PROTEUS_ENABLE_HIP
94 void *
Ptrs[
sizeof...(ArgT)] = {(
void *)&
Args...};
122 FuncBase *operator->() {
return &F; }
125 bool isDeviceModule() {
131 switch (TargetModel) {
133 NamedMDNode *
MD = Mod->getOrInsertNamedMetadata(
"nvvm.annotations");
137 MDString::get(*Ctx,
"kernel"),
138 ConstantAsMetadata::get(ConstantInt::get(Type::getInt32Ty(*Ctx), 1))};
140 MD->addOperand(MDNode::get(*Ctx,
MDVals));
143 F.
getFunction()->addFnAttr(Attribute::get(*Ctx,
"kernel"));
147 F.
getFunction()->setCallingConv(CallingConv::AMDGPU_KERNEL);
176 "The module is compiled, no further code can be added");
178 Mod->setTargetTriple(TargetTriple);
194 static_assert(std::is_void_v<RetT>,
"Kernels must have void return type");
199 "The module is compiled, no further code can be added");
201 if (!isDeviceModule())
204 Mod->setTargetTriple(TargetTriple);
210 return buildKernelFromArgsList(FC,
ArgT{});
230 for (
auto &
JitF : Functions) {
239 Library = std::make_unique<CompiledLibrary>(
240 Dispatch.
compile(std::move(Ctx), std::move(Mod), ModuleHash));
260 template <
typename RetT,
typename...
ArgT>
266 void print() { Mod->print(outs(),
nullptr); }
269template <
typename Sig>
270std::enable_if_t<!std::is_void_v<typename FnSig<Sig>::RetT>,
Var &>
285template <
typename Sig>
286std::enable_if_t<std::is_void_v<typename FnSig<Sig>::RetT>,
void>
297std::enable_if_t<!std::is_void_v<typename FnSig<Sig>::RetT>,
Var &>
313std::enable_if_t<std::is_void_v<typename FnSig<Sig>::RetT>,
void>
322template <
typename RetT,
typename...
ArgT>
328 CompiledFunc =
reinterpret_cast<decltype(CompiledFunc)
>(
329 J.getDispatcher().getFunctionAddress(getName(), J.getModuleHash(),
335 "Target is a GPU model, cannot directly run functions, use launch()");
337 if constexpr (std::is_void_v<RetT>)
338 Dispatch.run<RetT(
ArgT...)>(
reinterpret_cast<void *
>(CompiledFunc),
341 return Dispatch.run<RetT(
ArgT...)>(
reinterpret_cast<void *
>(CompiledFunc),
char int void ** Args
Definition CompilerInterfaceHost.cpp:21
#define PROTEUS_FATAL_ERROR(x)
Definition Error.h:7
Definition Dispatcher.hpp:54
virtual std::unique_ptr< CompiledLibrary > lookupCompiledLibrary(HashT ModuleHash)=0
virtual std::unique_ptr< MemoryBuffer > compile(std::unique_ptr< LLVMContext > Ctx, std::unique_ptr< Module > M, HashT ModuleHash, bool DisableIROpt=false)=0
virtual void * getFunctionAddress(StringRef FunctionName, HashT ModuleHash, CompiledLibrary &Library)=0
virtual DispatchResult launch(void *KernelFunc, LaunchDims GridDim, LaunchDims BlockDim, ArrayRef< void * > KernelArgs, uint64_t ShmemSize, void *Stream)=0
void declArgs()
Definition Func.hpp:147
Var & declVarInternal(StringRef Name, Type *Ty, Type *PointerElemType=nullptr)
Definition Func.cpp:28
std::enable_if_t<!std::is_void_v< typename FnSig< Sig >::RetT >, Var & > call(StringRef Name)
Definition JitFrontend.hpp:271
IRBuilder IRB
Definition Func.hpp:43
Function * getFunction()
Definition Func.cpp:78
std::string Name
Definition Func.hpp:50
JitModule & J
Definition Func.hpp:41
RetT operator()(ArgT... Args)
Definition JitFrontend.hpp:323
Definition Hashing.hpp:20
std::string toString() const
Definition Hashing.hpp:28
Definition JitFrontend.hpp:28
JitModule & operator=(JitModule &&)=delete
JitModule(const JitModule &)=delete
bool isCompiled() const
Definition JitFrontend.hpp:188
JitModule(StringRef Target="host")
Definition JitFrontend.hpp:157
TargetModelType getTargetModel() const
Definition JitFrontend.hpp:248
FunctionCallee getFunctionCallee(StringRef Name, ArgTypeList< ArgT... >)
Definition JitFrontend.hpp:261
const Module & getModule() const
Definition JitFrontend.hpp:190
HashT getModuleHash() const
Definition JitFrontend.hpp:244
JitModule(JitModule &&)=delete
JitModule & operator=(const JitModule &)=delete
CompiledLibrary & getLibrary()
Definition JitFrontend.hpp:250
Dispatcher & getDispatcher() const
Definition JitFrontend.hpp:246
auto & addFunction(StringRef Name)
Definition JitFrontend.hpp:170
auto addKernel(StringRef Name)
Definition JitFrontend.hpp:192
void print()
Definition JitFrontend.hpp:266
void compile(bool Verify=false)
Definition JitFrontend.hpp:213
Definition BuiltinsCUDA.cpp:4
TargetModelType
Definition TargetModel.hpp:14
void setLaunchBoundsForKernel(Function &F, int MaxThreadsPerSM, int MinBlocksPerSM=0)
Definition CoreLLVMCUDA.hpp:87
HashT hash(FirstT &&First, RestTs &&...Rest)
Definition Hashing.hpp:126
T getRuntimeConstantValue(void *Arg)
Definition CompilerInterfaceRuntimeConstantInfo.h:114
TargetModelType parseTargetModel(StringRef Target)
Definition TargetModel.hpp:16
std::string getTargetTriple(TargetModelType Model)
Definition TargetModel.hpp:40
Definition Hashing.hpp:147
Definition Dispatcher.hpp:16
Definition CompiledLibrary.hpp:12
Definition TypeMap.hpp:13
virtual void storeValue(Value *Val)=0