1#ifndef PROTEUS_FRONTEND_FUNC_H
2#define PROTEUS_FRONTEND_FUNC_H
18template <
typename T>
class LoopBoundInfo;
19template <
typename T,
typename... ForLoopBuilders>
class LoopNestBuilder;
20template <
typename T,
typename BodyLambda>
class ForLoopBuilder;
25template <
typename T>
struct FnSig;
26template <
typename RetT_,
typename... ArgT>
struct FnSig<RetT_(ArgT...)> {
42template <
typename T,
typename BodyLambda>
61 const std::vector<IRType> &ArgTys);
67#if PROTEUS_ENABLE_CUDA || PROTEUS_ENABLE_HIP
74 static_assert(!std::is_array_v<T>,
"Expected non-array type");
75 static_assert(!std::is_reference_v<T>,
76 "declVar does not support reference types");
78 if constexpr (std::is_pointer_v<T>) {
89 const std::string &
Name =
"array_var") {
90 static_assert(std::is_array_v<T>,
"Expected array type");
98 return std::make_tuple(declVar<Ts>()...);
101 template <
typename... Ts,
typename... NameTs>
103 static_assert(
sizeof...(Ts) ==
sizeof...(NameTs),
104 "Number of types must match number of names");
105 return std::make_tuple(declVar<Ts>(std::forward<NameTs>(Names))...);
108 template <
typename T>
110 using RawT = std::remove_const_t<T>;
116 template <
typename T,
typename U>
118 using RawT = std::remove_const_t<T>;
124 template <
typename U>
132 typename T,
typename NameT,
133 typename = std::enable_if_t<std::is_convertible_v<NameT, std::string>>>
135 return defVar(P.first, std::string(P.second));
139 return std::make_tuple(
defVar(std::forward<ArgT>(
Args))...);
142 template <
typename T>
144 const std::string &
Name =
"run.const.var") {
149 typename T,
typename NameT,
150 typename = std::enable_if_t<std::is_convertible_v<NameT, std::string>>>
160 int Line = __builtin_LINE());
163 template <
typename BodyLambda>
164 void function(BodyLambda &&Body,
const char *File = __builtin_FILE(),
165 int Line = __builtin_LINE()) {
167 std::forward<BodyLambda>(Body)();
172 int Line = __builtin_LINE());
175 template <
typename BodyLambda>
177 const char *File = __builtin_FILE(),
178 int Line = __builtin_LINE()) {
180 std::forward<BodyLambda>(Body)();
184 template <
typename IterT,
typename InitT,
typename UpperT,
typename IncT>
187 const char *File = __builtin_FILE(),
188 int Line = __builtin_LINE(),
LoopHints Hints = {});
191 template <
typename CondLambda>
192 void beginWhile(CondLambda &&Cond,
const char *File = __builtin_FILE(),
193 int Line = __builtin_LINE());
196 template <
typename CondLambda,
typename BodyLambda>
198 const char *File = __builtin_FILE(),
199 int Line = __builtin_LINE()) {
200 beginWhile(std::forward<CondLambda>(Cond), File, Line);
201 std::forward<BodyLambda>(Body)();
205 template <
typename Sig>
206 std::enable_if_t<!std::is_void_v<typename FnSig<Sig>::RetT>,
210 template <
typename Sig>
211 std::enable_if_t<std::is_void_v<typename FnSig<Sig>::RetT>,
void>
214 template <
typename Sig,
typename... ArgVars>
215 std::enable_if_t<!std::is_void_v<typename FnSig<Sig>::RetT>,
217 call(
const std::string &
Name, ArgVars &&...ArgsVars);
219 template <
typename Sig,
typename... ArgVars>
220 std::enable_if_t<std::is_void_v<typename FnSig<Sig>::RetT>,
void>
221 call(
const std::string &
Name, ArgVars &&...ArgsVars);
223 template <
typename BuiltinFuncT>
225 using RetT = std::invoke_result_t<BuiltinFuncT &, FuncBase &>;
226 if constexpr (std::is_void_v<RetT>) {
227 std::invoke(std::forward<BuiltinFuncT>(BuiltinFunc), *
this);
229 return std::invoke(std::forward<BuiltinFuncT>(BuiltinFunc), *
this);
233 template <
typename T>
234 std::enable_if_t<is_arithmetic_unref_v<T>,
Var<T>>
236 template <
typename T>
237 std::enable_if_t<is_arithmetic_unref_v<T>,
Var<T>>
239 template <
typename T>
240 std::enable_if_t<is_arithmetic_unref_v<T>,
Var<T>>
242 template <
typename T>
243 std::enable_if_t<is_arithmetic_unref_v<T>,
Var<T>>
247 typename InitT,
typename UpperT,
typename IncT,
251 BodyLambda &&Body = {}) {
252 static_assert(is_mutable_v<IterT>,
"Loop iterator must be mutable");
255 std::forward<BodyLambda>(Body)();
259 LoopBoundInfo<IterT> BoundsInfo{Iter, Init, Upper, Inc};
260 return ForLoopBuilder<IterT, BodyLambda>(BoundsInfo, *
this,
261 std::forward<BodyLambda>(Body));
265 template <
typename... LoopBuilders>
267 using FirstBuilder = std::remove_reference_t<
268 std::tuple_element_t<0, std::tuple<LoopBuilders...>>>;
269 using T =
typename FirstBuilder::LoopIndexType;
271 *
this, std::forward<LoopBuilders>(Loops)...);
274 template <
typename T>
void ret(
const Var<T> &RetVal);
280 void setName(
const std::string &NewName);
286 return V.template convert<U>();
290template <
typename RetT,
typename... ArgT>
class Func final :
public FuncBase {
293 RetT (*CompiledFunc)(ArgT...) =
nullptr;
295 std::tuple<std::optional<Var<ArgT>>...> ArgumentsT;
298 template <
typename T, std::
size_t ArgIdx>
Var<T> createArg() {
299 auto Var = declVar<T>(
"arg." + std::to_string(ArgIdx));
300 if constexpr (std::is_pointer_v<T>) {
308 template <std::size_t... Is>
void declArgsImpl(std::index_sequence<Is...>) {
311 (std::get<Is>(ArgumentsT).emplace(createArg<ArgT, Is>()), ...);
316 template <std::size_t... Is>
auto getArgsImpl(std::index_sequence<Is...>) {
317 return std::tie(*std::get<Is>(ArgumentsT)...);
324 Dispatch(Dispatch) {}
328 void declArgs() { declArgsImpl(std::index_sequence_for<ArgT...>{}); }
330 auto getArgs() {
return getArgsImpl(std::index_sequence_for<ArgT...>{}); }
332 template <std::
size_t Idx>
auto &
getArg() {
333 return *std::get<Idx>(ArgumentsT);
339 CompiledFunc = CompiledFuncIn;
344template <
typename IterT,
typename InitT,
typename UpperT,
typename IncT>
347 const char *File,
int Line,
LoopHints Hints) {
348 static_assert(std::is_integral_v<std::remove_const_t<IterT>>,
349 "Loop iterator must be an integral type");
350 static_assert(is_mutable_v<IterT>,
"Loop iterator must be mutable");
352 CB->
beginFor(IterVar.getSlot(), IterVar.getValueType(), Init.loadValue(),
353 UpperBound.loadValue(), Inc.loadValue(),
354 std::is_signed_v<std::remove_const_t<IterT>>, File, Line, Hints);
357template <
typename CondLambda>
360 Cond)]() ->
IRValue * {
return Cond().loadValue(); },
364template <
typename Sig,
typename... ArgVars>
365std::enable_if_t<std::is_void_v<typename FnSig<Sig>::RetT>,
void>
370 auto GetArgVal = [](
auto &&Arg) {
371 using ArgVarT = std::decay_t<
decltype(Arg)>;
372 if constexpr (std::is_pointer_v<typename ArgVarT::ValueType>)
373 return Arg.loadAddress();
375 return Arg.loadValue();
380 {GetArgVal(ArgsVars)...});
383template <
typename Sig>
384std::enable_if_t<!std::is_void_v<typename FnSig<Sig>::RetT>,
391 Ret.storeValue(Call);
395template <
typename Sig>
396std::enable_if_t<std::is_void_v<typename FnSig<Sig>::RetT>,
void>
402template <
typename... Ts>
407template <
typename Sig,
typename... ArgVars>
408std::enable_if_t<!std::is_void_v<typename FnSig<Sig>::RetT>,
409 Var<typename FnSig<Sig>::RetT>>
414 auto GetArgVal = [](
auto &&Arg) {
415 using ArgVarT = std::decay_t<
decltype(Arg)>;
416 if constexpr (std::is_pointer_v<typename ArgVarT::ValueType>)
417 return Arg.loadAddress();
419 return Arg.loadValue();
423 std::vector<IRValue *> ArgVals = {GetArgVal(ArgsVars)...};
429 Ret.storeValue(Call);
434std::enable_if_t<is_arithmetic_unref_v<T>,
Var<T>>
436 static_assert(std::is_arithmetic_v<T>,
"atomicAdd requires arithmetic type");
439 auto Ret = declVar<T>(
"atomic.add.res.");
440 Ret.storeValue(Result);
445std::enable_if_t<is_arithmetic_unref_v<T>,
Var<T>>
447 static_assert(std::is_arithmetic_v<T>,
"atomicSub requires arithmetic type");
450 auto Ret = declVar<T>(
"atomic.sub.res.");
451 Ret.storeValue(Result);
456std::enable_if_t<is_arithmetic_unref_v<T>,
Var<T>>
458 static_assert(std::is_arithmetic_v<T>,
"atomicMax requires arithmetic type");
461 auto Ret = declVar<T>(
"atomic.max.res.");
462 Ret.storeValue(Result);
467std::enable_if_t<is_arithmetic_unref_v<T>,
Var<T>>
469 static_assert(std::is_arithmetic_v<T>,
"atomicMin requires arithmetic type");
472 auto Ret = declVar<T>(
"atomic.min.res.");
473 Ret.storeValue(Result);
480 IRValue *RetValue = RetVal.loadValue();
char int void ** Args
Definition CompilerInterfaceHost.cpp:22
Definition CodeBuilder.h:66
virtual VarAlloc allocPointer(const std::string &Name, IRType ElemTy, unsigned AddrSpace=0)=0
virtual void clearInsertPoint()=0
virtual IRValue * createCall(const std::string &FName, IRType RetTy, const std::vector< IRType > &ArgTys, const std::vector< IRValue * > &Args)=0
virtual void setInsertPointAtEntry()=0
virtual IRValue * createAtomicAdd(IRValue *Addr, IRValue *Val)=0
virtual void beginFor(IRValue *IterSlot, IRType IterTy, IRValue *InitVal, IRValue *UpperBoundVal, IRValue *IncVal, bool IsSigned, const char *File, int Line, LoopHints Hints={})=0
virtual VarAlloc allocScalar(const std::string &Name, IRType ValueTy)=0
virtual IRValue * createAtomicMax(IRValue *Addr, IRValue *Val)=0
virtual void beginWhile(std::function< IRValue *()> CondFn, const char *File, int Line)=0
virtual IRValue * createAtomicSub(IRValue *Addr, IRValue *Val)=0
virtual IRValue * createAtomicMin(IRValue *Addr, IRValue *Val)=0
virtual VarAlloc allocArray(const std::string &Name, AddressSpace AS, IRType ElemTy, size_t NElem)=0
virtual void createRetVoid()=0
virtual void createRet(IRValue *V)=0
Definition Dispatcher.h:74
Var< T > declVar(size_t NElem, AddressSpace AS=AddressSpace::DEFAULT, const std::string &Name="array_var")
Definition Func.h:88
std::enable_if_t< is_arithmetic_unref_v< T >, Var< T > > atomicSub(const Var< T * > &Addr, const Var< T > &Val)
Definition Func.h:446
const std::string & getName() const
Definition Func.h:278
decltype(auto) callBuiltin(BuiltinFuncT &&BuiltinFunc)
Definition Func.h:224
void ifThen(const Var< bool > &CondVar, BodyLambda &&Body, const char *File=__builtin_FILE(), int Line=__builtin_LINE())
Definition Func.h:176
Var< T > defVar(const T &Val, const std::string &Name="var")
Definition Func.h:109
void beginIf(const Var< bool > &CondVar, const char *File=__builtin_FILE(), int Line=__builtin_LINE())
Definition Func.cpp:41
Var< T > defVar(const Var< U > &Val, const std::string &Name="var")
Definition Func.h:117
auto defVar(std::pair< T, NameT > P)
Definition Func.h:134
auto defRuntimeConsts(ArgT &&...Args)
Definition Func.h:155
Var< U > defVar(const Var< U > &Val, const std::string &Name="var")
Definition Func.h:125
CodeBuilder & getCodeBuilder()
Get the underlying CodeBuilder for direct IR generation.
Definition Func.cpp:31
auto declVars(NameTs &&...Names)
Definition Func.h:102
auto defVars(ArgT &&...Args)
Definition Func.h:138
void endFunction()
Definition Func.cpp:37
void beginWhile(CondLambda &&Cond, const char *File=__builtin_FILE(), int Line=__builtin_LINE())
Definition Func.h:358
void beginFunction(const char *File=__builtin_FILE(), int Line=__builtin_LINE())
Definition Func.cpp:33
IRFunction * Func
Definition Func.h:57
IRFunction * getFunction()
Definition Func.cpp:39
void whileLoop(CondLambda &&Cond, BodyLambda &&Body, const char *File=__builtin_FILE(), int Line=__builtin_LINE())
Definition Func.h:197
void endIf()
Definition Func.cpp:46
CodeBuilder * CB
Definition Func.h:56
auto forLoop(Var< IterT > &Iter, const Var< InitT > &Init, const Var< UpperT > &Upper, const Var< IncT > &Inc, BodyLambda &&Body={})
Definition Func.h:249
std::string Name
Definition Func.h:55
auto buildLoopNest(LoopBuilders &&...Loops)
Definition Func.h:266
void ret()
Definition Func.h:477
JitModule & J
Definition Func.h:53
void endWhile()
Definition Func.cpp:50
JitModule & getJitModule()
Definition Func.h:47
Var< const T > defRuntimeConst(std::pair< T, NameT > P)
Definition Func.h:151
auto declVars()
Definition Func.h:97
IRValue * getArg(size_t Idx)
Definition Func.cpp:19
void beginFor(Var< IterT > &IterVar, const Var< InitT > &InitVar, const Var< UpperT > &UpperBound, const Var< IncT > &IncVar, const char *File=__builtin_FILE(), int Line=__builtin_LINE(), LoopHints Hints={})
Definition Func.h:345
Var< const T > defRuntimeConst(const T &Val, const std::string &Name="run.const.var")
Definition Func.h:143
void function(BodyLambda &&Body, const char *File=__builtin_FILE(), int Line=__builtin_LINE())
Definition Func.h:164
void endFor()
Definition Func.cpp:48
auto convert(const Var< T > &V)
Definition Func.h:285
std::enable_if_t<!std::is_void_v< typename FnSig< Sig >::RetT >, Var< typename FnSig< Sig >::RetT > > call(const std::string &Name)
Definition Func.h:386
Var< T > declVar(const std::string &Name="var")
Definition Func.h:73
std::enable_if_t< is_arithmetic_unref_v< T >, Var< T > > atomicAdd(const Var< T * > &Addr, const Var< T > &Val)
Definition Func.h:435
std::enable_if_t< is_arithmetic_unref_v< T >, Var< T > > atomicMax(const Var< T * > &Addr, const Var< T > &Val)
Definition Func.h:457
std::enable_if_t< is_arithmetic_unref_v< T >, Var< T > > atomicMin(const Var< T * > &Addr, const Var< T > &Val)
Definition Func.h:468
void setName(const std::string &NewName)
Definition Func.cpp:14
auto & getArg()
Definition Func.h:332
auto getCompiledFunc() const
Definition Func.h:336
auto getArgs()
Definition Func.h:330
void setCompiledFunc(RetT(*CompiledFuncIn)(ArgT...))
Definition Func.h:338
RetT operator()(ArgT... Args)
Definition JitFrontend.h:178
Func(JitModule &J, CodeBuilder &CB, const std::string &Name, Dispatcher &Dispatch)
Definition Func.h:321
void declArgs()
Definition Func.h:328
Definition IRFunction.h:9
Definition JitFrontend.h:18
Definition MemoryCache.h:26
AddressSpace
Definition AddressSpace.h:6
std::vector< IRType > unpackArgTypes(ArgTypeList< Ts... >)
Definition Func.h:403
void setLaunchBoundsForKernel(Function &F, int MaxThreadsPerSM, int MinBlocksPerSM=0)
Definition CoreLLVMCUDA.h:87
EmissionPolicy
Definition Func.h:35
void operator()() const
Definition Func.h:32
RetT_ RetT
Definition Func.h:28
bool Signed
Signedness of the type (meaningful for integer kinds and pointer-to-int).
Definition IRType.h:38
std::size_t NElem
Number of array elements; only meaningful when Kind == Array.
Definition IRType.h:41
IRTypeKind ElemKind
Element type kind; meaningful when Kind == Pointer or Kind == Array.
Definition IRType.h:44
Definition CodeBuilder.h:21