1#ifndef PROTEUS_FRONTEND_FUNC_H
2#define PROTEUS_FRONTEND_FUNC_H
18template <
typename T>
class LoopBoundInfo;
19template <
typename T,
typename... ForLoopBuilders>
class LoopNestBuilder;
20template <
typename T,
typename BodyLambda>
class ForLoopBuilder;
25template <
typename T>
struct FnSig;
26template <
typename RetT_,
typename... ArgT>
struct FnSig<RetT_(ArgT...)> {
42template <
typename T,
typename BodyLambda>
64 const std::vector<IRType> &ArgTys,
bool IsKernel =
false);
72#if PROTEUS_ENABLE_CUDA || PROTEUS_ENABLE_HIP
77 static_assert(!std::is_array_v<T>,
"Expected non-array type");
78 static_assert(!std::is_reference_v<T>,
79 "declVar does not support reference types");
81 if constexpr (std::is_pointer_v<T>) {
92 const std::string &
Name =
"array_var") {
93 static_assert(std::is_array_v<T>,
"Expected array type");
101 return std::make_tuple(declVar<Ts>()...);
104 template <
typename... Ts,
typename... NameTs>
106 static_assert(
sizeof...(Ts) ==
sizeof...(NameTs),
107 "Number of types must match number of names");
108 return std::make_tuple(declVar<Ts>(std::forward<NameTs>(Names))...);
111 template <
typename T>
113 using RawT = std::remove_const_t<T>;
119 template <
typename T,
typename U>
121 using RawT = std::remove_const_t<T>;
127 template <
typename U>
135 typename T,
typename NameT,
136 typename = std::enable_if_t<std::is_convertible_v<NameT, std::string>>>
138 return defVar(P.first, std::string(P.second));
142 return std::make_tuple(
defVar(std::forward<ArgT>(
Args))...);
145 template <
typename T>
147 const std::string &
Name =
"run.const.var") {
152 typename T,
typename NameT,
153 typename = std::enable_if_t<std::is_convertible_v<NameT, std::string>>>
163 int Line = __builtin_LINE());
166 template <
typename BodyLambda>
167 void function(BodyLambda &&Body,
const char *File = __builtin_FILE(),
168 int Line = __builtin_LINE()) {
170 std::forward<BodyLambda>(Body)();
175 int Line = __builtin_LINE());
178 template <
typename BodyLambda>
180 const char *File = __builtin_FILE(),
181 int Line = __builtin_LINE()) {
183 std::forward<BodyLambda>(Body)();
187 template <
typename IterT,
typename InitT,
typename UpperT,
typename IncT>
190 const char *File = __builtin_FILE(),
191 int Line = __builtin_LINE(),
LoopHints Hints = {});
194 template <
typename CondLambda>
195 void beginWhile(CondLambda &&Cond,
const char *File = __builtin_FILE(),
196 int Line = __builtin_LINE());
199 template <
typename CondLambda,
typename BodyLambda>
201 const char *File = __builtin_FILE(),
202 int Line = __builtin_LINE()) {
203 beginWhile(std::forward<CondLambda>(Cond), File, Line);
204 std::forward<BodyLambda>(Body)();
208 template <
typename Sig>
209 std::enable_if_t<!std::is_void_v<typename FnSig<Sig>::RetT>,
213 template <
typename Sig>
214 std::enable_if_t<std::is_void_v<typename FnSig<Sig>::RetT>,
void>
217 template <
typename Sig,
typename... ArgVars>
218 std::enable_if_t<!std::is_void_v<typename FnSig<Sig>::RetT>,
220 call(
const std::string &
Name, ArgVars &&...ArgsVars);
222 template <
typename Sig,
typename... ArgVars>
223 std::enable_if_t<std::is_void_v<typename FnSig<Sig>::RetT>,
void>
224 call(
const std::string &
Name, ArgVars &&...ArgsVars);
226 template <
typename BuiltinFuncT>
228 using RetT = std::invoke_result_t<BuiltinFuncT &, FuncBase &>;
229 if constexpr (std::is_void_v<RetT>) {
230 std::invoke(std::forward<BuiltinFuncT>(BuiltinFunc), *
this);
232 return std::invoke(std::forward<BuiltinFuncT>(BuiltinFunc), *
this);
236 template <
typename T>
237 std::enable_if_t<is_arithmetic_unref_v<T>,
Var<T>>
239 template <
typename T>
240 std::enable_if_t<is_arithmetic_unref_v<T>,
Var<T>>
242 template <
typename T>
243 std::enable_if_t<is_arithmetic_unref_v<T>,
Var<T>>
245 template <
typename T>
246 std::enable_if_t<is_arithmetic_unref_v<T>,
Var<T>>
250 typename InitT,
typename UpperT,
typename IncT,
254 BodyLambda &&Body = {}) {
255 static_assert(is_mutable_v<IterT>,
"Loop iterator must be mutable");
258 std::forward<BodyLambda>(Body)();
262 LoopBoundInfo<IterT> BoundsInfo{Iter, Init, Upper, Inc};
263 return ForLoopBuilder<IterT, BodyLambda>(BoundsInfo, *
this,
264 std::forward<BodyLambda>(Body));
268 template <
typename... LoopBuilders>
270 using FirstBuilder = std::remove_reference_t<
271 std::tuple_element_t<0, std::tuple<LoopBuilders...>>>;
272 using T =
typename FirstBuilder::LoopIndexType;
274 *
this, std::forward<LoopBuilders>(Loops)...);
277 template <
typename T>
void ret(
const Var<T> &RetVal);
283 void setName(
const std::string &NewName);
292 return V.template convert<U>();
296template <
typename RetT,
typename... ArgT>
class Func final :
public FuncBase {
299 RetT (*CompiledFunc)(ArgT...) =
nullptr;
301 std::tuple<std::optional<Var<ArgT>>...> ArgumentsT;
304 template <
typename T, std::
size_t ArgIdx>
Var<T> createArg() {
305 auto Var = declVar<T>(
"arg." + std::to_string(ArgIdx));
306 if constexpr (std::is_pointer_v<T>) {
314 template <std::size_t... Is>
void declArgsImpl(std::index_sequence<Is...>) {
317 (std::get<Is>(ArgumentsT).emplace(createArg<ArgT, Is>()), ...);
322 template <std::size_t... Is>
auto getArgsImpl(std::index_sequence<Is...>) {
323 return std::tie(*std::get<Is>(ArgumentsT)...);
331 Dispatch(Dispatch) {}
335 void declArgs() { declArgsImpl(std::index_sequence_for<ArgT...>{}); }
337 auto getArgs() {
return getArgsImpl(std::index_sequence_for<ArgT...>{}); }
339 template <std::
size_t Idx>
auto &
getArg() {
340 return *std::get<Idx>(ArgumentsT);
346 CompiledFunc = CompiledFuncIn;
351template <
typename IterT,
typename InitT,
typename UpperT,
typename IncT>
354 const char *File,
int Line,
LoopHints Hints) {
355 static_assert(std::is_integral_v<std::remove_const_t<IterT>>,
356 "Loop iterator must be an integral type");
357 static_assert(is_mutable_v<IterT>,
"Loop iterator must be mutable");
359 CB->
beginFor(IterVar.getSlot(), IterVar.getValueType(), Init.loadValue(),
360 UpperBound.loadValue(), Inc.loadValue(),
361 std::is_signed_v<std::remove_const_t<IterT>>, File, Line, Hints);
364template <
typename CondLambda>
367 Cond)]() ->
IRValue * {
return Cond().loadValue(); },
371template <
typename Sig,
typename... ArgVars>
372std::enable_if_t<std::is_void_v<typename FnSig<Sig>::RetT>,
void>
377 auto GetArgVal = [](
auto &&Arg) {
378 using ArgVarT = std::decay_t<
decltype(Arg)>;
379 if constexpr (std::is_pointer_v<typename ArgVarT::ValueType>)
380 return Arg.loadAddress();
382 return Arg.loadValue();
387 {GetArgVal(ArgsVars)...});
390template <
typename Sig>
391std::enable_if_t<!std::is_void_v<typename FnSig<Sig>::RetT>,
398 Ret.storeValue(Call);
402template <
typename Sig>
403std::enable_if_t<std::is_void_v<typename FnSig<Sig>::RetT>,
void>
409template <
typename... Ts>
414template <
typename Sig,
typename... ArgVars>
415std::enable_if_t<!std::is_void_v<typename FnSig<Sig>::RetT>,
416 Var<typename FnSig<Sig>::RetT>>
421 auto GetArgVal = [](
auto &&Arg) {
422 using ArgVarT = std::decay_t<
decltype(Arg)>;
423 if constexpr (std::is_pointer_v<typename ArgVarT::ValueType>)
424 return Arg.loadAddress();
426 return Arg.loadValue();
430 std::vector<IRValue *> ArgVals = {GetArgVal(ArgsVars)...};
436 Ret.storeValue(Call);
441std::enable_if_t<is_arithmetic_unref_v<T>,
Var<T>>
443 static_assert(std::is_arithmetic_v<T>,
"atomicAdd requires arithmetic type");
446 auto Ret = declVar<T>(
"atomic.add.res.");
447 Ret.storeValue(Result);
452std::enable_if_t<is_arithmetic_unref_v<T>,
Var<T>>
454 static_assert(std::is_arithmetic_v<T>,
"atomicSub requires arithmetic type");
457 auto Ret = declVar<T>(
"atomic.sub.res.");
458 Ret.storeValue(Result);
463std::enable_if_t<is_arithmetic_unref_v<T>,
Var<T>>
465 static_assert(std::is_arithmetic_v<T>,
"atomicMax requires arithmetic type");
468 auto Ret = declVar<T>(
"atomic.max.res.");
469 Ret.storeValue(Result);
474std::enable_if_t<is_arithmetic_unref_v<T>,
Var<T>>
476 static_assert(std::is_arithmetic_v<T>,
"atomicMin requires arithmetic type");
479 auto Ret = declVar<T>(
"atomic.min.res.");
480 Ret.storeValue(Result);
487 IRValue *RetValue = RetVal.loadValue();
char int void ** Args
Definition CompilerInterfaceHost.cpp:23
Definition CodeBuilder.h:70
virtual VarAlloc allocPointer(const std::string &Name, IRType ElemTy, unsigned AddrSpace=0)=0
virtual void clearInsertPoint()=0
virtual IRValue * createCall(const std::string &FName, IRType RetTy, const std::vector< IRType > &ArgTys, const std::vector< IRValue * > &Args)=0
virtual void setInsertPointAtEntry()=0
virtual IRValue * createAtomicAdd(IRValue *Addr, IRValue *Val)=0
virtual void beginFor(IRValue *IterSlot, IRType IterTy, IRValue *InitVal, IRValue *UpperBoundVal, IRValue *IncVal, bool IsSigned, const char *File, int Line, LoopHints Hints={})=0
virtual VarAlloc allocScalar(const std::string &Name, IRType ValueTy)=0
virtual IRValue * createAtomicMax(IRValue *Addr, IRValue *Val)=0
virtual void beginWhile(std::function< IRValue *()> CondFn, const char *File, int Line)=0
virtual IRValue * createAtomicSub(IRValue *Addr, IRValue *Val)=0
virtual IRValue * createAtomicMin(IRValue *Addr, IRValue *Val)=0
virtual VarAlloc allocArray(const std::string &Name, AddressSpace AS, IRType ElemTy, size_t NElem)=0
virtual void createRetVoid()=0
virtual void createRet(IRValue *V)=0
Definition Dispatcher.h:75
Var< T > declVar(size_t NElem, AddressSpace AS=AddressSpace::DEFAULT, const std::string &Name="array_var")
Definition Func.h:91
std::enable_if_t< is_arithmetic_unref_v< T >, Var< T > > atomicSub(const Var< T * > &Addr, const Var< T > &Val)
Definition Func.h:453
const std::string & getName() const
Definition Func.h:281
bool IsKernel
Definition Func.h:60
decltype(auto) callBuiltin(BuiltinFuncT &&BuiltinFunc)
Definition Func.h:227
void ifThen(const Var< bool > &CondVar, BodyLambda &&Body, const char *File=__builtin_FILE(), int Line=__builtin_LINE())
Definition Func.h:179
Var< T > defVar(const T &Val, const std::string &Name="var")
Definition Func.h:112
void beginIf(const Var< bool > &CondVar, const char *File=__builtin_FILE(), int Line=__builtin_LINE())
Definition Func.cpp:40
Var< T > defVar(const Var< U > &Val, const std::string &Name="var")
Definition Func.h:120
auto defVar(std::pair< T, NameT > P)
Definition Func.h:137
auto defRuntimeConsts(ArgT &&...Args)
Definition Func.h:158
bool isKernel() const
Definition Func.h:70
Var< U > defVar(const Var< U > &Val, const std::string &Name="var")
Definition Func.h:128
CodeBuilder & getCodeBuilder()
Get the underlying CodeBuilder for direct IR generation.
Definition Func.cpp:30
auto declVars(NameTs &&...Names)
Definition Func.h:105
auto defVars(ArgT &&...Args)
Definition Func.h:141
void endFunction()
Definition Func.cpp:36
void beginWhile(CondLambda &&Cond, const char *File=__builtin_FILE(), int Line=__builtin_LINE())
Definition Func.h:365
void beginFunction(const char *File=__builtin_FILE(), int Line=__builtin_LINE())
Definition Func.cpp:32
IRFunction * Func
Definition Func.h:57
IRFunction * getFunction()
Definition Func.cpp:38
void whileLoop(CondLambda &&Cond, BodyLambda &&Body, const char *File=__builtin_FILE(), int Line=__builtin_LINE())
Definition Func.h:200
void endIf()
Definition Func.cpp:45
CodeBuilder * CB
Definition Func.h:56
auto forLoop(Var< IterT > &Iter, const Var< InitT > &Init, const Var< UpperT > &Upper, const Var< IncT > &Inc, BodyLambda &&Body={})
Definition Func.h:252
std::string Name
Definition Func.h:55
auto buildLoopNest(LoopBuilders &&...Loops)
Definition Func.h:269
void ret()
Definition Func.h:484
JitModule & J
Definition Func.h:53
void endWhile()
Definition Func.cpp:49
JitModule & getJitModule()
Definition Func.h:47
Var< const T > defRuntimeConst(std::pair< T, NameT > P)
Definition Func.h:154
auto declVars()
Definition Func.h:100
IRValue * getArg(size_t Idx)
Definition Func.cpp:20
void beginFor(Var< IterT > &IterVar, const Var< InitT > &InitVar, const Var< UpperT > &UpperBound, const Var< IncT > &IncVar, const char *File=__builtin_FILE(), int Line=__builtin_LINE(), LoopHints Hints={})
Definition Func.h:352
Var< const T > defRuntimeConst(const T &Val, const std::string &Name="run.const.var")
Definition Func.h:146
void function(BodyLambda &&Body, const char *File=__builtin_FILE(), int Line=__builtin_LINE())
Definition Func.h:167
void setFrontendName(const std::string &NewName)
Definition Func.h:286
void endFor()
Definition Func.cpp:47
auto convert(const Var< T > &V)
Definition Func.h:291
std::enable_if_t<!std::is_void_v< typename FnSig< Sig >::RetT >, Var< typename FnSig< Sig >::RetT > > call(const std::string &Name)
Definition Func.h:393
Var< T > declVar(const std::string &Name="var")
Definition Func.h:76
std::enable_if_t< is_arithmetic_unref_v< T >, Var< T > > atomicAdd(const Var< T * > &Addr, const Var< T > &Val)
Definition Func.h:442
std::enable_if_t< is_arithmetic_unref_v< T >, Var< T > > atomicMax(const Var< T * > &Addr, const Var< T > &Val)
Definition Func.h:464
std::enable_if_t< is_arithmetic_unref_v< T >, Var< T > > atomicMin(const Var< T * > &Addr, const Var< T > &Val)
Definition Func.h:475
void setName(const std::string &NewName)
Definition Func.cpp:15
auto & getArg()
Definition Func.h:339
auto getCompiledFunc() const
Definition Func.h:343
Func(JitModule &J, CodeBuilder &CB, const std::string &Name, Dispatcher &Dispatch, bool IsKernel=false)
Definition Func.h:327
auto getArgs()
Definition Func.h:337
void setCompiledFunc(RetT(*CompiledFuncIn)(ArgT...))
Definition Func.h:345
RetT operator()(ArgT... Args)
Definition JitFrontend.h:178
void declArgs()
Definition Func.h:335
Definition IRFunction.h:9
Definition JitFrontend.h:19
Definition MemoryCache.h:27
AddressSpace
Definition AddressSpace.h:6
std::vector< IRType > unpackArgTypes(ArgTypeList< Ts... >)
Definition Func.h:410
void setLaunchBoundsForKernel(Function &F, int MaxThreadsPerSM, int MinBlocksPerSM=0)
Definition CoreLLVMCUDA.h:87
EmissionPolicy
Definition Func.h:35
void operator()() const
Definition Func.h:32
RetT_ RetT
Definition Func.h:28
bool Signed
Signedness of the type (meaningful for integer kinds and pointer-to-int).
Definition IRType.h:38
std::size_t NElem
Number of array elements; only meaningful when Kind == Array.
Definition IRType.h:41
IRTypeKind ElemKind
Element type kind; meaningful when Kind == Pointer or Kind == Array.
Definition IRType.h:44
Definition CodeBuilder.h:21