Proteus
Programmable JIT compilation and optimization for C/C++ using LLVM
Loading...
Searching...
No Matches
Func.hpp
Go to the documentation of this file.
1#ifndef PROTEUS_FRONTEND_FUNC_HPP
2#define PROTEUS_FRONTEND_FUNC_HPP
3
4#include <initializer_list>
5#include <memory>
6
7#include <llvm/IR/IRBuilder.h>
8#include <llvm/IR/Module.h>
9
11#include "proteus/Error.h"
17
18namespace proteus {
19
20class JitModule;
21template <typename T> class LoopBoundInfo;
22template <typename T, typename... ForLoopBuilders> class LoopNestBuilder;
23template <typename T, typename BodyLambda> class ForLoopBuilder;
24
25// Helper struct to represent the signature of a function.
26// Useful to partially-specialize function templates.
// ArgTypeList is an empty tag type whose only job is to carry a parameter
// pack of argument types.
27template <typename... ArgTs> struct ArgTypeList {};
// The primary template is declared but never defined: only function types,
// handled by the specialization below, are supported.
28template <typename T> struct FnSig;
// Specialization for a function type RetT_(ArgT...). RetT exposes the return
// type; FuncBase uses FnSig<Sig>::RetT to select void vs. non-void overloads.
29template <typename RetT_, typename... ArgT> struct FnSig<RetT_(ArgT...)> {
31 using RetT = RetT_;
32};
33
34using namespace llvm;
35
// No-op callable: invoking it does nothing and emits no IR. Presumably this
// is the EmptyLambda type used as forLoop's default BodyLambda (the struct
// header is not visible here) - confirm.
37 void operator()() const {}
38};
39
// FuncBase: state and helpers shared by every JIT-built function. It tracks
// the current insertion point (IP), the function name and a stack of open
// control-flow scopes, and provides variable declaration (declVar/defVar),
// control-flow emission (begin*/end*), atomics, loop builders and value
// conversion on top of an LLVM IRBuilder (IRB).
40class FuncBase {
41protected:
 // Current insertion point. beginFor/beginWhile move it to the start of the
 // new loop body and restore IRB from it.
45 IRBuilderBase::InsertPoint IP;
46
 // Cached function name; setName() keeps the LLVM Function's name in sync.
47 std::string Name;
48
 // Kinds of scopes that can be open while the function body is built.
49 enum class ScopeKind { FUNCTION, IF, FOR, WHILE };
 // One open scope: the source location (File/Line) where it was begun and
 // ContIP, the insertion point where emission resumes after the scope
 // (beginFor/beginWhile set it to the start of the split continuation block).
50 struct Scope {
51 std::string File;
52 int Line;
54 IRBuilderBase::InsertPoint ContIP;
55
56 explicit Scope(const char *File, int Line, ScopeKind Kind,
57 IRBuilderBase::InsertPoint ContIP)
59 };
 // Stack of open scopes; begin* calls push with emplace_back (presumably
 // the matching end* pops).
60 std::vector<Scope> Scopes;
61
 // Returns the upper-case name of Kind. Out-of-range values hit
 // PROTEUS_FATAL_ERROR; no value is returned on that path, so the macro is
 // presumably non-returning.
62 std::string toString(ScopeKind Kind) {
63 switch (Kind) {
65 return "FUNCTION";
66 case ScopeKind::IF:
67 return "IF";
68 case ScopeKind::FOR:
69 return "FOR";
71 return "WHILE";
72 default:
73 PROTEUS_FATAL_ERROR("Unsupported Kind " +
74 std::to_string(static_cast<int>(Kind)));
75 }
76 }
77
 // Emits an LLVM atomicrmw of kind Op on the memory Addr points to, with
 // operand Val, and returns a Var; defined out of line below.
 // NOTE(review): atomicrmw yields the value previously in memory - confirm
 // the returned Var holds that old value.
78 template <typename T>
79 Var<T> emitAtomic(AtomicRMWInst::BinOp Op, const Var<T *> &Addr,
80 const Var<T> &Val);
81
82public:
84
86
88
91
93
95
 // Declares an uninitialized local of non-array type T, backed by an alloca
 // named Name (emitAlloca). Pointer types get PointerStorage, which is also
 // given PtrElemTy (presumably the pointee type); everything else gets
 // ScalarStorage. Use the array overload for array types.
96 template <typename T> Var<T> declVar(StringRef Name = "var") {
97 static_assert(!std::is_array_v<T>, "Expected non-array type");
98
99 Function *F = getFunction();
100 auto &Ctx = F->getContext();
101 Type *AllocaTy = TypeMap<T>::get(Ctx);
102 auto *Alloca = emitAlloca(AllocaTy, Name);
103
104 if constexpr (std::is_pointer_v<T>) {
106 return Var<T>(std::make_unique<PointerStorage>(Alloca, IRB, PtrElemTy),
107 *this);
108 } else {
109 return Var<T>(std::make_unique<ScalarStorage>(Alloca, IRB), *this);
110 }
111 }
112
 // Array overload: allocates an array type T with NElem elements in address
 // space AS via emitArrayCreate and wraps the base pointer in ArrayStorage.
113 template <typename T>
115 StringRef Name = "array_var") {
116 static_assert(std::is_array_v<T>, "Expected array type");
117
118 Function *F = getFunction();
119 auto *BasePointer =
120 emitArrayCreate(TypeMap<T>::get(F->getContext(), NElem), AS, Name);
121
122 auto *ArrTy = cast<ArrayType>(TypeMap<T>::get(F->getContext(), NElem));
123 return Var<T>(std::make_unique<ArrayStorage>(BasePointer, IRB, ArrTy),
124 *this);
125 }
126
 // Declares a variable and initializes it from the host value Val.
127 template <typename T> Var<T> defVar(const T &Val, StringRef Name = "var") {
129 Var = Val;
130 return Var;
131 }
132
 // Declares a Var<T> named Name and assigns the source variable to it, so
 // any conversion is done by Var<T>'s assignment operator.
133 template <typename T, typename U>
135 auto Res = declVar<T>(Name);
136 Res = Var;
137 return Res;
138 }
139
 // Materializes a value known at code-generation time as a variable.
 // Currently this is defVar with a different default name.
140 template <typename T>
141 Var<T> defRuntimeConst(const T &Val, StringRef Name = "run.const.var") {
142 return defVar<T>(Val, Name);
143 }
144
 // Calls defRuntimeConst on every argument and returns the Vars as a tuple
 // in argument order.
 // NOTE(review): C++ leaves the evaluation order of make_tuple's arguments
 // unspecified, so the order of the emitted allocas/stores can differ
 // between compilers.
145 template <typename... ArgT> auto defRuntimeConsts(ArgT &&...Args) {
146 return std::make_tuple(defRuntimeConst(std::forward<ArgT>(Args))...);
147 }
148
 // Scope open/close API. Each begin* records the caller's file and line
 // (defaulted via __builtin_FILE/__builtin_LINE) and is expected to be
 // paired with the corresponding end*.
149 void beginFunction(const char *File = __builtin_FILE(),
150 int Line = __builtin_LINE());
151 void endFunction();
152
153 void beginIf(const Var<bool> &CondVar, const char *File = __builtin_FILE(),
154 int Line = __builtin_LINE());
155 void endIf();
156
 // for (IterVar = InitVar; IterVar < UpperBound; IterVar = IterVar + IncVar);
 // defined out of line below.
157 template <typename T>
158 void beginFor(Var<T> &IterVar, const Var<T> &InitVar,
159 const Var<T> &UpperBound, const Var<T> &IncVar,
160 const char *File = __builtin_FILE(),
161 int Line = __builtin_LINE());
162 void endFor();
163
 // while loop whose condition is produced by calling Cond(); defined out of
 // line below.
164 template <typename CondLambda>
165 void beginWhile(CondLambda &&Cond, const char *File = __builtin_FILE(),
166 int Line = __builtin_LINE());
167 void endWhile();
168
 // Four overloads keyed on a function signature Sig: two for a non-void and
 // two for a void FnSig<Sig>::RetT, each pair with and without argument Vars.
169 template <typename Sig>
170 std::enable_if_t<!std::is_void_v<typename FnSig<Sig>::RetT>,
173
174 template <typename Sig>
175 std::enable_if_t<std::is_void_v<typename FnSig<Sig>::RetT>, void>
177
178 template <typename Sig, typename... ArgVars>
179 std::enable_if_t<!std::is_void_v<typename FnSig<Sig>::RetT>,
182
183 template <typename Sig, typename... ArgVars>
184 std::enable_if_t<std::is_void_v<typename FnSig<Sig>::RetT>, void>
186
 // Invokes BuiltinFunc with *this and returns its result, or nothing when
 // the callable returns void.
187 template <typename BuiltinFuncT>
189 using RetT = std::invoke_result_t<BuiltinFuncT &, FuncBase &>;
190 if constexpr (std::is_void_v<RetT>) {
191 std::invoke(std::forward<BuiltinFuncT>(BuiltinFunc), *this);
192 } else {
193 return std::invoke(std::forward<BuiltinFuncT>(BuiltinFunc), *this);
194 }
195 }
196
 // Atomic read-modify-write helpers on the memory Addr points to, limited
 // to arithmetic T. Definitions are out of line; presumably they forward to
 // emitAtomic with the matching AtomicRMWInst::BinOp.
197 template <typename T>
198 std::enable_if_t<std::is_arithmetic_v<T>, Var<T>>
199 atomicAdd(const Var<T *> &Addr, const Var<T> &Val);
200 template <typename T>
201 std::enable_if_t<std::is_arithmetic_v<T>, Var<T>>
202 atomicSub(const Var<T *> &Addr, const Var<T> &Val);
203 template <typename T>
204 std::enable_if_t<std::is_arithmetic_v<T>, Var<T>>
205 atomicMax(const Var<T *> &Addr, const Var<T> &Val);
206 template <typename T>
207 std::enable_if_t<std::is_arithmetic_v<T>, Var<T>>
208 atomicMin(const Var<T *> &Addr, const Var<T> &Val);
209
 // Packs four loop bounds into a LoopBoundInfo and builds a ForLoopBuilder
 // with Body (a no-op by default). Presumably the order is {iterator, init,
 // upper bound, increment} as in beginFor - confirm against LoopBoundInfo.
 // NOTE(review): It[0]..It[3] are read without checking Bounds.size() == 4;
 // passing fewer bounds reads past the end of the initializer_list.
210 template <typename T, typename BodyLambda = EmptyLambda>
211 auto forLoop(std::initializer_list<Var<T>> Bounds, BodyLambda &&Body = {}) {
212 auto It = Bounds.begin();
213 LoopBoundInfo<T> BoundsInfo{It[0], It[1], It[2], It[3]};
215 std::forward<BodyLambda>(Body));
216 }
217
 // Combines loop builders into a LoopNestBuilder. The index type is the
 // LoopIndexType of the first builder, so an empty pack does not compile.
218 template <typename... LoopBuilders>
219 auto buildLoopNest(LoopBuilders &&...Loops) {
220 using FirstBuilder = std::remove_reference_t<
221 std::tuple_element_t<0, std::tuple<LoopBuilders...>>>;
222 using T = typename FirstBuilder::LoopIndexType;
224 *this, std::forward<LoopBuilders>(Loops)...);
225 }
226
 // Return emission, with a value or void; defined out of line.
227 template <typename T> void ret(const Var<T> &RetVal);
228
229 void ret();
230
 // Returns the cached function name.
231 StringRef getName() const { return Name; }
232
 // Updates both the cached name and the underlying LLVM Function's name.
234 Name = NewName.str();
235 Function *F = getFunction();
236 F->setName(Name);
237 }
238
239 // Convert the given Var's value to type U and return a new Var holding
240 // the converted value.
 // The result is a fresh Var named "convert."; V itself is not modified.
241 template <typename U, typename T>
242 std::enable_if_t<std::is_convertible_v<T, U>, Var<U>>
243 convert(const Var<T> &V) {
244 auto &IRBRef = getIRBuilder();
245 Var<U> Res = declVar<U>("convert.");
246 Value *Converted = proteus::convert<T, U>(IRBRef, V.loadValue());
247 Res.storeValue(Converted);
248 return Res;
249 }
250};
251
// Func<RetT, ArgT...>: a JIT-built function with a fixed C++ signature.
// declArgs() copies the incoming LLVM arguments into Vars stored in
// ArgumentsT. CompiledFunc holds the native entry point once it has been set
// (nullptr until then).
252template <typename RetT, typename... ArgT> class Func final : public FuncBase {
253private:
 // Held by reference, so the Dispatcher must outlive this Func.
254 Dispatcher &Dispatch;
 // Native entry point; nullptr until set.
255 RetT (*CompiledFunc)(ArgT...) = nullptr;
256 // Optional because Var<ArgT> is not default constructible.
257 std::tuple<std::optional<Var<ArgT>>...> ArgumentsT;
258
259private:
 // Declares a local Var named "arg.<ArgIdx>" and stores LLVM argument
 // ArgIdx into it (storePointer for pointer types, storeValue otherwise).
260 template <typename T, std::size_t ArgIdx> Var<T> createArg() {
261 Function *F = getFunction();
262 auto Var = declVar<T>("arg." + std::to_string(ArgIdx));
263 if constexpr (std::is_pointer_v<T>) {
264 Var.storePointer(F->getArg(ArgIdx));
265 } else {
266 Var.storeValue(F->getArg(ArgIdx));
267 }
268 return Var;
269 }
270
 // Emits the argument copies at the end of the entry block (overwriting the
 // IP member), one per parameter in order (the comma fold runs left to
 // right), then clears IRB's insertion point.
 // NOTE(review): assumes the entry block has no terminator yet - confirm.
271 template <std::size_t... Is> void declArgsImpl(std::index_sequence<Is...>) {
272 Function *F = getFunction();
273 auto &EntryBB = F->getEntryBlock();
274 IP = IRBuilderBase::InsertPoint(&EntryBB, EntryBB.end());
275 IRB.restoreIP(IP);
276
277 (std::get<Is>(ArgumentsT).emplace(createArg<ArgT, Is>()), ...);
278
279 IRB.ClearInsertionPoint();
280 }
281
 // Tuple of references to the argument Vars. Dereferences the optionals, so
 // declArgs() must have run first (undefined behavior otherwise).
282 template <std::size_t... Is> auto getArgsImpl(std::index_sequence<Is...>) {
283 return std::tie(*std::get<Is>(ArgumentsT)...);
284 }
285
286public:
 // Constructor: forwards J and FC to FuncBase and keeps a reference to
 // Dispatch.
288 : FuncBase(J, FC), Dispatch(Dispatch) {}
289
 // Invokes the function with Args; defined out of line.
290 RetT operator()(ArgT... Args);
291
 // Materializes every parameter as a Var; call before getArgs()/getArg().
292 void declArgs() { declArgsImpl(std::index_sequence_for<ArgT...>{}); }
293
 // Tuple of references to all argument Vars (requires declArgs()).
294 auto getArgs() { return getArgsImpl(std::index_sequence_for<ArgT...>{}); }
295
 // Reference to argument Idx (requires declArgs()).
296 template <std::size_t Idx> auto &getArg() {
297 return *std::get<Idx>(ArgumentsT);
298 }
299
 // Native entry point, or nullptr if none has been set.
300 auto getCompiledFunc() const { return CompiledFunc; }
301
 // Records the native entry point.
303 CompiledFunc = CompiledFuncIn;
304 }
305};
306
307// beginFor implementation
// Splits the current block at IP and emits the CFG
//   CurBlock    -> loop.header (IterVar = Init)
//   loop.header -> loop.cond   (IterVar < UpperBound ? loop.body : loop.end)
//   loop.body   -> loop.inc    (IterVar = IterVar + Inc) -> loop.cond
//   loop.end    -> continuation block (the part of CurBlock after IP)
// It pushes a FOR scope whose ContIP is the start of the continuation block,
// and leaves IP/IRB at the start of loop.body, so subsequently emitted code
// becomes the loop body.
// NOTE(review): the condition uses Var's less-than comparison, which emits
// a signed predicate (ICmpSLT) even for unsigned T - confirm for unsigned
// bounds above the signed maximum.
308template <typename T>
309void FuncBase::beginFor(Var<T> &IterVar, const Var<T> &Init,
310 const Var<T> &UpperBound, const Var<T> &Inc,
311 const char *File, int Line) {
312 static_assert(std::is_integral_v<T>,
313 "Loop iterator must be an integral type");
314
315 Function *F = getFunction();
316 // Update the terminator of the current basic block due to the split
317 // control-flow.
318 BasicBlock *CurBlock = IP.getBlock();
320 CurBlock->splitBasicBlock(IP.getPoint(), CurBlock->getName() + ".split");
321
322 auto ContIP = IRBuilderBase::InsertPoint(NextBlock, NextBlock->begin());
323 Scopes.emplace_back(File, Line, ScopeKind::FOR, ContIP);
324
326 BasicBlock::Create(F->getContext(), "loop.header", F, NextBlock);
328 BasicBlock::Create(F->getContext(), "loop.cond", F, NextBlock);
329 BasicBlock *Body =
330 BasicBlock::Create(F->getContext(), "loop.body", F, NextBlock);
332 BasicBlock::Create(F->getContext(), "loop.inc", F, NextBlock);
334 BasicBlock::Create(F->getContext(), "loop.end", F, NextBlock);
335
336 // Erase the old terminator and branch to the header.
 // (splitBasicBlock terminated CurBlock with a branch to the new block.)
337 CurBlock->getTerminator()->eraseFromParent();
338 IRB.SetInsertPoint(CurBlock);
339 { IRB.CreateBr(Header); }
340
341 IRB.SetInsertPoint(Header);
342 {
343 IterVar = Init;
344 IRB.CreateBr(LoopCond);
345 }
346
347 IRB.SetInsertPoint(LoopCond);
348 {
349 auto CondVar = IterVar < UpperBound;
350 Value *Cond = CondVar.loadValue();
351 IRB.CreateCondBr(Cond, Body, LoopExit);
352 }
353
 // The body starts out as just a branch to the latch; user code is later
 // inserted ahead of that branch.
354 IRB.SetInsertPoint(Body);
355 IRB.CreateBr(Latch);
356
357 IRB.SetInsertPoint(Latch);
358 {
359 IterVar = IterVar + Inc;
360 IRB.CreateBr(LoopCond);
361 }
362
363 IRB.SetInsertPoint(LoopExit);
364 { IRB.CreateBr(NextBlock); }
365
 // Resume emission at the top of the body, before its branch to loop.inc.
366 IP = IRBuilderBase::InsertPoint(Body, Body->begin());
367 IRB.restoreIP(IP);
368}
369
// Splits the current block at IP and emits the CFG
//   CurBlock   -> while.cond (Cond() ? while.body : while.end)
//   while.body -> while.cond
//   while.end  -> continuation block (the part of CurBlock after IP)
// It pushes a WHILE scope whose ContIP is the start of the continuation
// block, and leaves IP/IRB at the start of while.body.
// Cond is called exactly once, while IRB points into while.cond. It must
// therefore emit the IR that computes the condition there, and return a
// value with loadValue() yielding the i1 condition.
370template <typename CondLambda>
371void FuncBase::beginWhile(CondLambda &&Cond, const char *File, int Line) {
372 Function *F = getFunction();
373 // Update the terminator of the current basic block due to the split
374 // control-flow.
375 BasicBlock *CurBlock = IP.getBlock();
377 CurBlock->splitBasicBlock(IP.getPoint(), CurBlock->getName() + ".split");
378
379 auto ContIP = IRBuilderBase::InsertPoint(NextBlock, NextBlock->begin());
380 Scopes.emplace_back(File, Line, ScopeKind::WHILE, ContIP);
381
383 BasicBlock::Create(F->getContext(), "while.cond", F, NextBlock);
384 BasicBlock *Body =
385 BasicBlock::Create(F->getContext(), "while.body", F, NextBlock);
387 BasicBlock::Create(F->getContext(), "while.end", F, NextBlock);
388
 // Replace the branch left by splitBasicBlock with a branch to the condition.
389 CurBlock->getTerminator()->eraseFromParent();
390 IRB.SetInsertPoint(CurBlock);
391 { IRB.CreateBr(LoopCond); }
392
393 IRB.SetInsertPoint(LoopCond);
394 {
395 auto CondVar = Cond();
396 Value *CondV = CondVar.loadValue();
397 IRB.CreateCondBr(CondV, Body, LoopExit);
398 }
399
 // Back edge; user code is later inserted ahead of this branch.
400 IRB.SetInsertPoint(Body);
401 IRB.CreateBr(LoopCond);
402
403 IRB.SetInsertPoint(LoopExit);
404 { IRB.CreateBr(NextBlock); }
405
406 IP = IRBuilderBase::InsertPoint(Body, Body->begin());
407 IRB.restoreIP(IP);
408}
409
410// Var member implementations are defined here, after FuncBase is
411// complete, so that FuncBase's members are available to them.
412
413// Helper function for binary operations on Var types
// Combines L and R with IOp when std::common_type_t<T, U> is integral, and
// with FOp otherwise. Both operands are presumably converted to that common
// type first - confirm. The result is stored in a new Var of the common type
// named "res.". Aborts with PROTEUS_FATAL_ERROR if L and R belong to
// different functions.
// NOTE(review): the integral path does not distinguish signedness, so
// callers passing CreateSDiv/CreateSRem get signed semantics for unsigned
// types too.
414template <typename T, typename U, typename IntOp, typename FPOp>
416 FPOp FOp) {
417 FuncBase &Fn = L.Fn;
418 if (&Fn != &R.Fn)
419 PROTEUS_FATAL_ERROR("Variables should belong to the same function");
420
421 auto &IRB = Fn.getIRBuilder();
422
423 Value *LHS = L.loadValue();
424 Value *RHS = R.loadValue();
425
426 using CommonT = std::common_type_t<T, U>;
429
430 Value *Result = nullptr;
431 if constexpr (std::is_integral_v<CommonT>) {
432 Result = IOp(IRB, LHS, RHS);
433 } else {
434 Result = FOp(IRB, LHS, RHS);
435 }
436
437 auto ResultVar = Fn.declVar<std::common_type_t<T, U>>("res.");
438 ResultVar.storeValue(Result);
439
440 return ResultVar;
441}
442
443// Helper function for compound assignment with a constant
// Implements `LHS op= ConstValue`:
//   1. materializes ConstValue as an LLVM constant of type U;
//   2. converts it to T;
//   3. applies IOp (integral T) or FOp (floating-point T) to LHS's current
//      value;
//   4. stores the result back into LHS and returns LHS.
// NOTE(review): the constant is converted to T *before* the operation. An
// integral LHS combined with a floating-point constant therefore has the
// constant truncated first: presumably x = 3; x *= 2.5 yields 6. The
// Var-Var path computes in the common type via binOp and then converts,
// presumably yielding 7.
// NOTE(review): ConstantInt::get(Type *, uint64_t, ...) is also used for
// negative signed constants and relies on truncation to the type's width;
// consider ConstantInt::getSigned for signed U.
444template <typename T, typename U, typename IntOp, typename FPOp>
446compoundAssignConst(Var<T, std::enable_if_t<std::is_arithmetic_v<T>>> &LHS,
447 const U &ConstValue, IntOp IOp, FPOp FOp) {
448 static_assert(std::is_convertible_v<U, T>, "U must be convertible to T");
449 auto &IRB = LHS.Fn.getIRBuilder();
450
451 using CleanU = std::remove_cv_t<std::remove_reference_t<U>>;
452
453 Function *Function = LHS.Fn.getFunction();
454 auto &Ctx = Function->getContext();
455 Type *RHSType = TypeMap<CleanU>::get(Ctx);
456
457 Value *RHS = nullptr;
458 if constexpr (std::is_integral_v<CleanU>) {
459 RHS = ConstantInt::get(RHSType, ConstValue);
460 } else {
461 RHS = ConstantFP::get(RHSType, ConstValue);
462 }
463
464 Value *LHSVal = LHS.loadValue();
465
466 RHS = convert<CleanU, T>(IRB, RHS);
467 Value *Result = nullptr;
468
469 if constexpr (std::is_integral_v<T>) {
470 Result = IOp(IRB, LHSVal, RHS);
471 } else {
472 static_assert(std::is_floating_point_v<T>, "Unsupported type");
473 Result = FOp(IRB, LHSVal, RHS);
474 }
475
476 LHS.storeValue(Result);
477 return LHS;
478}
479
480// Helper function for comparison operations on Var types
481template <typename T, typename U, typename IntOp, typename FPOp>
483 FuncBase &Fn = L.Fn;
484 if (&Fn != &R.Fn)
485 PROTEUS_FATAL_ERROR("Variables should belong to the same function");
486
487 auto &IRB = Fn.getIRBuilder();
488
489 Value *LHS = L.loadValue();
490 Value *RHS = R.loadValue();
491
492 RHS = convert<U, T>(IRB, RHS);
493
494 Value *Result = nullptr;
495 if constexpr (std::is_integral_v<T>) {
496 Result = IOp(IRB, LHS, RHS);
497 } else {
498 static_assert(std::is_floating_point_v<T>, "Unsupported type");
499 Result = FOp(IRB, LHS, RHS);
500 }
501
502 auto ResultVar = Fn.declVar<bool>("res.");
503 ResultVar.storeValue(Result);
504
505 return ResultVar;
506}
507
// Converting construction of a Var<T> from a Var<U>. Allocates fresh scalar
// storage of type T ("conv.var") in the owning function, then assigns V,
// which performs the U -> T conversion.
508template <typename T>
509template <typename U>
512 // Allocate storage for the target type T.
513 Type *TargetTy = TypeMap<T>::get(Fn.getFunction()->getContext());
514 auto *Alloca = Fn.emitAlloca(TargetTy, "conv.var");
515 Storage = std::make_unique<ScalarStorage>(Alloca, Fn.getIRBuilder());
516 *this = V;
517}
518
// Assignment: loads V's current value and stores it into this Var's
// existing storage, then returns *this.
519template <typename T>
522 storeValue(V.loadValue());
523 return *this;
524}
525
// Assignment that tolerates an unbound target. If this Var has no storage
// yet, it takes a clone of V's storage object. Whether the clone aliases V's
// slot depends on VarStorage::clone - confirm. Otherwise it copies V's value
// into the existing storage.
526template <typename T>
529 if (this->Storage == nullptr) {
530 // If we don't have storage, clone it from the source.
531 Storage = V.Storage->clone();
532 } else {
533 // If we have storage, copy the value.
534 storeValue(V.loadValue());
535 }
536 return *this;
537}
538
// Address-of for a scalar Var: returns a Var<T*> whose value is the address
// of this Var's storage slot. The pointer is kept in a new "addr.tmp" alloca
// wrapped in PointerStorage with the slot's allocated type as element type.
539template <typename T>
542 auto &IRB = Fn.getIRBuilder();
543
544 Value *Slot = getSlot();
545 auto *SlotPtrTy = cast<PointerType>(Slot->getType());
546 Type *ElemTy = getAllocatedType();
547
548 unsigned AddrSpace = SlotPtrTy->getAddressSpace();
549 Type *PtrTy = PointerType::get(ElemTy, AddrSpace);
550 Value *PtrVal = Slot;
 // With opaque pointers PtrTy is just `ptr addrspace(AS)`, so this cast is
 // normally a no-op.
551 if (PtrVal->getType() != PtrTy)
552 PtrVal = IRB.CreateBitCast(Slot, PtrTy);
553
554 auto *PtrSlot = Fn.emitAlloca(PtrTy, "addr.tmp");
555 IRB.CreateStore(PtrVal, PtrSlot);
556
557 std::unique_ptr<PointerStorage> ResultStorage =
558 std::make_unique<PointerStorage>(PtrSlot, IRB, ElemTy);
559
560 return Var<std::add_pointer_t<T>>(std::move(ResultStorage), Fn);
561}
562
// Converting assignment from a Var<U>: loads V, converts the value from U to
// T and stores it into this Var.
563template <typename T>
564template <typename U>
567 auto &IRB = Fn.getIRBuilder();
568 auto *Converted = convert<U, T>(IRB, V.loadValue());
569 storeValue(Converted);
570 return *this;
571}
572
// Assignment from an arithmetic host constant. Builds a ConstantInt or
// ConstantFP of this Var's value type and stores it; any other value type is
// a fatal error.
// NOTE(review): for an integer Var, ConstValue goes through
// ConstantInt::get(Type *, uint64_t). A negative floating-point constant is
// therefore converted double -> uint64_t, which is undefined behavior in C++,
// and positive fractions are truncated silently.
573template <typename T>
574template <typename U>
577 const U &ConstValue) {
578 static_assert(std::is_arithmetic_v<U>,
579 "Can only assign arithmetic types to Var");
580
581 Type *LHSType = getValueType();
582
583 if (LHSType->isIntegerTy()) {
584 storeValue(ConstantInt::get(LHSType, ConstValue));
585 } else if (LHSType->isFloatingPointTy()) {
586 storeValue(ConstantFP::get(LHSType, ConstValue));
587 } else {
588 PROTEUS_FATAL_ERROR("Unsupported type");
589 }
590
591 return *this;
592}
593
// Var-Var arithmetic (add, sub, mul, div, rem). Each forwards to binOp with
// an integer and a floating-point IRBuilder callback; the result is a new
// Var of std::common_type_t<T, U>.
// NOTE(review): division and remainder always use CreateSDiv/CreateSRem on
// the integral path, so unsigned operands with the top bit set give wrong
// results (C++ would use udiv/urem).
594template <typename T>
595template <typename U>
598 const Var<U> &Other) const {
599 return binOp(
600 *this, Other,
601 [](IRBuilderBase &B, Value *L, Value *R) { return B.CreateAdd(L, R); },
602 [](IRBuilderBase &B, Value *L, Value *R) { return B.CreateFAdd(L, R); });
603}
604
605template <typename T>
606template <typename U>
609 const Var<U> &Other) const {
610 return binOp(
611 *this, Other,
612 [](IRBuilderBase &B, Value *L, Value *R) { return B.CreateSub(L, R); },
613 [](IRBuilderBase &B, Value *L, Value *R) { return B.CreateFSub(L, R); });
614}
615
616template <typename T>
617template <typename U>
620 const Var<U> &Other) const {
621 return binOp(
622 *this, Other,
623 [](IRBuilderBase &B, Value *L, Value *R) { return B.CreateMul(L, R); },
624 [](IRBuilderBase &B, Value *L, Value *R) { return B.CreateFMul(L, R); });
625}
626
627template <typename T>
628template <typename U>
631 const Var<U> &Other) const {
632 return binOp(
633 *this, Other,
634 [](IRBuilderBase &B, Value *L, Value *R) { return B.CreateSDiv(L, R); },
635 [](IRBuilderBase &B, Value *L, Value *R) { return B.CreateFDiv(L, R); });
636}
637
638template <typename T>
639template <typename U>
642 const Var<U> &Other) const {
643 return binOp(
644 *this, Other,
645 [](IRBuilderBase &B, Value *L, Value *R) { return B.CreateSRem(L, R); },
646 [](IRBuilderBase &B, Value *L, Value *R) { return B.CreateFRem(L, R); });
647}
648
649// Arithmetic operators with ConstValue
// Each operator materializes the host constant as a temporary Var<U> named
// "tmp." and defers to the corresponding Var-Var operator, so the operation
// is carried out in std::common_type_t<T, U>. The static_asserts duplicate
// the enable_if constraint but give a readable diagnostic.
650template <typename T>
651template <typename U>
652std::enable_if_t<std::is_arithmetic_v<U>, Var<std::common_type_t<T, U>>>
654 const U &ConstValue) const {
655 static_assert(std::is_arithmetic_v<U>,
656 "Can only add arithmetic types to Var");
657 Var<U> Tmp = Fn.defVar<U>(ConstValue, "tmp.");
658 return (*this) + Tmp;
659}
660
661template <typename T>
662template <typename U>
663std::enable_if_t<std::is_arithmetic_v<U>, Var<std::common_type_t<T, U>>>
665 const U &ConstValue) const {
666 static_assert(std::is_arithmetic_v<U>,
667 "Can only subtract arithmetic types from Var");
668 Var<U> Tmp = Fn.defVar<U>(ConstValue, "tmp.");
669 return (*this) - Tmp;
670}
671
672template <typename T>
673template <typename U>
674std::enable_if_t<std::is_arithmetic_v<U>, Var<std::common_type_t<T, U>>>
676 const U &ConstValue) const {
677 static_assert(std::is_arithmetic_v<U>,
678 "Can only multiply Var by arithmetic types");
679 Var<U> Tmp = Fn.defVar<U>(ConstValue, "tmp.");
680 return (*this) * Tmp;
681}
682
683template <typename T>
684template <typename U>
685std::enable_if_t<std::is_arithmetic_v<U>, Var<std::common_type_t<T, U>>>
687 const U &ConstValue) const {
688 static_assert(std::is_arithmetic_v<U>,
689 "Can only divide Var by arithmetic types");
690 Var<U> Tmp = Fn.defVar<U>(ConstValue, "tmp.");
691 return (*this) / Tmp;
692}
693
694template <typename T>
695template <typename U>
696std::enable_if_t<std::is_arithmetic_v<U>, Var<std::common_type_t<T, U>>>
698 const U &ConstValue) const {
699 static_assert(std::is_arithmetic_v<U>,
700 "Can only modulo Var by arithmetic types");
701 Var<U> Tmp = Fn.defVar<U>(ConstValue, "tmp.");
702 return (*this) % Tmp;
703}
704
705// Compound assignment operators for Var
// The Var-operand forms compute `*this op Other` in the common type via the
// binary operator, then assign the result back (converting to T). The
// constant-operand forms go through compoundAssignConst, which converts the
// constant to T *before* operating.
// NOTE(review): because of that difference, `x op= c` and `x op= Var(c)` can
// differ for an integral x and a floating-point c (e.g. x *= 2.5).
706template <typename T>
707template <typename U>
710 const Var<U> &Other) {
711 auto Result = (*this) + Other;
712 *this = Result;
713 return *this;
714}
715
716template <typename T>
717template <typename U>
720 const U &ConstValue) {
721 static_assert(std::is_arithmetic_v<U>,
722 "Can only add arithmetic types to Var");
723 return compoundAssignConst(
724 *this, ConstValue,
725 [](IRBuilderBase &B, Value *L, Value *R) { return B.CreateAdd(L, R); },
726 [](IRBuilderBase &B, Value *L, Value *R) { return B.CreateFAdd(L, R); });
727}
728
729template <typename T>
730template <typename U>
733 const Var<U> &Other) {
734 auto Result = (*this) - Other;
735 *this = Result;
736 return *this;
737}
738
739template <typename T>
740template <typename U>
743 const U &ConstValue) {
744 static_assert(std::is_arithmetic_v<U>,
745 "Can only subtract arithmetic types from Var");
746 return compoundAssignConst(
747 *this, ConstValue,
748 [](IRBuilderBase &B, Value *L, Value *R) { return B.CreateSub(L, R); },
749 [](IRBuilderBase &B, Value *L, Value *R) { return B.CreateFSub(L, R); });
750}
751
752template <typename T>
753template <typename U>
756 const Var<U> &Other) {
757 auto Result = (*this) * Other;
758 *this = Result;
759 return *this;
760}
761
762template <typename T>
763template <typename U>
766 const U &ConstValue) {
767 static_assert(std::is_arithmetic_v<U>,
768 "Can only multiply Var by arithmetic types");
769 return compoundAssignConst(
770 *this, ConstValue,
771 [](IRBuilderBase &B, Value *L, Value *R) { return B.CreateMul(L, R); },
772 [](IRBuilderBase &B, Value *L, Value *R) { return B.CreateFMul(L, R); });
773}
774
775template <typename T>
776template <typename U>
779 const Var<U> &Other) {
780 auto Result = (*this) / Other;
781 *this = Result;
782 return *this;
783}
784
 // NOTE(review): signed division (SDiv) is used for unsigned T as well.
785template <typename T>
786template <typename U>
789 const U &ConstValue) {
790 static_assert(std::is_arithmetic_v<U>,
791 "Can only divide Var by arithmetic types");
792 return compoundAssignConst(
793 *this, ConstValue,
794 [](IRBuilderBase &B, Value *L, Value *R) { return B.CreateSDiv(L, R); },
795 [](IRBuilderBase &B, Value *L, Value *R) { return B.CreateFDiv(L, R); });
796}
797
798template <typename T>
799template <typename U>
802 const Var<U> &Other) {
803 auto Result = (*this) % Other;
804 *this = Result;
805 return *this;
806}
807
 // NOTE(review): signed remainder (SRem) is used for unsigned T as well.
808template <typename T>
809template <typename U>
812 const U &ConstValue) {
813 static_assert(std::is_arithmetic_v<U>,
814 "Can only modulo Var by arithmetic types");
815 return compoundAssignConst(
816 *this, ConstValue,
817 [](IRBuilderBase &B, Value *L, Value *R) { return B.CreateSRem(L, R); },
818 [](IRBuilderBase &B, Value *L, Value *R) { return B.CreateFRem(L, R); });
819}
820
// Negation, implemented as (-1) * x through a temporary "minus_one." Var.
// For unsigned T, static_cast<T>(-1) is the maximum value, and the wrapping
// multiply yields the two's-complement negation.
821template <typename T>
824 auto MinusOne = Fn.defVar<T>(static_cast<T>(-1), "minus_one.");
825 return MinusOne * (*this);
826}
827
// Logical not, returned as a new Var<bool> named "not.":
//   - bool:           bitwise NOT of the i1 value;
//   - integral:       x == 0;
//   - floating point: fcmp oeq x, 0.0 (NaN -> false, matching C++ !NaN).
828template <typename T>
830 auto &IRB = Fn.getIRBuilder();
831 Value *V = loadValue();
832 Value *ResV = nullptr;
833 if constexpr (std::is_same_v<T, bool>) {
834 ResV = IRB.CreateNot(V);
835 } else if constexpr (std::is_integral_v<T>) {
836 Value *Zero = ConstantInt::get(V->getType(), 0);
837 ResV = IRB.CreateICmpEQ(V, Zero);
838 } else {
839 static_assert(std::is_floating_point_v<T>,
840 "Unsupported type for operator!");
841 Value *Zero = ConstantFP::get(V->getType(), 0.0);
842 ResV = IRB.CreateFCmpOEQ(V, Zero);
843 }
844 auto Ret = Fn.declVar<bool>("not.");
845 Ret.storeValue(ResV);
846 return Ret;
847}
848
// Array subscript with a host (constant) index. Computes
// `getelementptr inbounds [N x E], base, 0, Index`, stores the element
// address in an "elem.ptr" alloca, and returns a Var of the element type
// backed by PointerStorage, so the result acts as an lvalue. No bounds
// checking is done; an out-of-range inbounds GEP yields poison.
849template <typename T>
852 auto &IRB = Fn.getIRBuilder();
853 auto *ArrayTy = cast<ArrayType>(getAllocatedType());
854 auto *BasePointer = getSlot();
855
856 // GEP into the array aggregate: [0, Index]
857 auto *GEP = IRB.CreateConstInBoundsGEP2_64(ArrayTy, BasePointer, 0, Index);
858 Type *ElemTy = getValueType();
859 auto *BasePtrTy = cast<PointerType>(BasePointer->getType());
860 unsigned AddrSpace = BasePtrTy->getAddressSpace();
861 Type *ElemPtrTy = PointerType::get(ElemTy, AddrSpace);
862
863 auto *PtrSlot = Fn.emitAlloca(ElemPtrTy, "elem.ptr");
864 IRB.CreateStore(GEP, PtrSlot);
865
866 std::unique_ptr<PointerStorage> ResultStorage =
867 std::make_unique<PointerStorage>(PtrSlot, IRB, ElemTy);
868 return Var<std::remove_extent_t<T>>(std::move(ResultStorage), Fn);
869}
870
// Array subscript with a JIT-time (Var) integral index; otherwise the same as
// the constant-index overload.
// NOTE(review): GEP sign-extends indices narrower than the pointer index
// width. An unsigned index Var (e.g. uint8_t >= 128, uint32_t >= 2^31)
// therefore addresses a negative offset; consider zero-extending unsigned
// indices first.
871template <typename T>
872template <typename IdxT>
873std::enable_if_t<std::is_integral_v<IdxT>, Var<std::remove_extent_t<T>>>
875 const Var<IdxT> &Index) {
876 auto &IRB = Fn.getIRBuilder();
877 auto *ArrayTy = cast<ArrayType>(getAllocatedType());
878 auto *BasePointer = getSlot();
879
880 Value *IdxVal = Index.loadValue();
881 Value *Zero = llvm::ConstantInt::get(IdxVal->getType(), 0);
882 auto *GEP = IRB.CreateInBoundsGEP(ArrayTy, BasePointer, {Zero, IdxVal});
883 Type *ElemTy = getValueType();
884 auto *BasePtrTy = cast<PointerType>(BasePointer->getType());
885 unsigned AddrSpace = BasePtrTy->getAddressSpace();
886 Type *ElemPtrTy = PointerType::get(ElemTy, AddrSpace);
887
888 auto *PtrSlot = Fn.emitAlloca(ElemPtrTy, "elem.ptr");
889 IRB.CreateStore(GEP, PtrSlot);
890
891 std::unique_ptr<VarStorage> ResultStorage =
892 std::make_unique<PointerStorage>(PtrSlot, IRB, ElemTy);
893 return Var<std::remove_extent_t<T>>(std::move(ResultStorage), Fn);
894}
895
// Pointer subscript with a host (constant) index. Computes
// `getelementptr inbounds E, ptr, Index` from the pointer's current value,
// stores the element address in an "elem.ptr" alloca, and returns an lvalue
// Var of the pointee type.
896template <typename T>
899 auto &IRB = Fn.getIRBuilder();
900
901 auto *PointerElemTy =
902 TypeMap<std::remove_pointer_t<T>>::get(Fn.getFunction()->getContext());
903 auto *Ptr = loadPointer();
904 auto *GEP = IRB.CreateConstInBoundsGEP1_64(PointerElemTy, Ptr, Index);
905 unsigned AddrSpace = cast<PointerType>(Ptr->getType())->getAddressSpace();
906 Type *ElemPtrTy = PointerType::get(PointerElemTy, AddrSpace);
907
908 // Create a pointer storage to hold the LValue for
909 // the Array[Index].
910 auto *PtrSlot = Fn.emitAlloca(ElemPtrTy, "elem.ptr");
911 IRB.CreateStore(GEP, PtrSlot);
912 std::unique_ptr<PointerStorage> ResultStorage =
913 std::make_unique<PointerStorage>(PtrSlot, IRB, PointerElemTy);
914
915 return Var<std::remove_pointer_t<T>>(std::move(ResultStorage), Fn);
916}
917
// Pointer subscript with a JIT-time (Var) index.
// NOTE(review): the constraint is is_arithmetic, not is_integral as in the
// array overload, so a floating-point index Var would produce an invalid GEP
// index. Narrow unsigned indices are also sign-extended by GEP.
918template <typename T>
919template <typename IdxT>
920std::enable_if_t<std::is_arithmetic_v<IdxT>, Var<std::remove_pointer_t<T>>>
922 const Var<IdxT> &Index) {
923 auto &IRB = Fn.getIRBuilder();
924
925 auto *PointeeType =
926 TypeMap<std::remove_pointer_t<T>>::get(Fn.getFunction()->getContext());
927 auto *Ptr = loadPointer();
928 auto *IdxValue = Index.loadValue();
929 auto *GEP = IRB.CreateInBoundsGEP(PointeeType, Ptr, IdxValue);
930 unsigned AddrSpace = cast<PointerType>(Ptr->getType())->getAddressSpace();
931 Type *ElemPtrTy = PointerType::get(PointeeType, AddrSpace);
932
933 // Create a pointer storage to hold the LValue for
934 // the Array[Index].
935 auto *PtrSlot = Fn.emitAlloca(ElemPtrTy, "elem.ptr");
936 IRB.CreateStore(GEP, PtrSlot);
937 std::unique_ptr<PointerStorage> ResultStorage =
938 std::make_unique<PointerStorage>(PtrSlot, IRB, PointeeType);
939
940 return Var<std::remove_pointer_t<T>>(std::move(ResultStorage), Fn);
941}
942
943// Pointer type operator*
944template <typename T>
949
// Address-of for a pointer-typed Var: returns a Var<T*> built from this
// Var's current pointer value. The value is stored in an "addr.ptr.tmp"
// alloca wrapped in PointerStorage with element type PointeePtrTy.
// NOTE(review): this uses loadPointer() (the pointer's value), whereas the
// scalar address-of uses getSlot() (the variable's own address). As written,
// dereferencing the result reads the pointee as if it were a pointer.
// Confirm whether getSlot() was intended.
// NOTE(review): the alloca is typed TargetPtrTy (getUnqual, i.e. address
// space 0) but stores a PointeePtrTy value in address space AddrSpace; the
// two differ when AddrSpace != 0.
950template <typename T>
953 auto &IRB = Fn.getIRBuilder();
954
955 Value *PtrVal = loadPointer();
956 auto *PtrTy = cast<PointerType>(PtrVal->getType());
957 Type *ElemTy = getValueType();
958
959 unsigned AddrSpace = PtrTy->getAddressSpace();
960 Type *PointeePtrTy = PointerType::get(ElemTy, AddrSpace);
961 Type *TargetPtrTy = PointerType::getUnqual(PointeePtrTy);
962
963 if (PtrVal->getType() != PointeePtrTy) {
964 PtrVal = IRB.CreateBitCast(PtrVal, PointeePtrTy);
965 }
966
967 auto *PtrSlot = Fn.emitAlloca(TargetPtrTy, "addr.ptr.tmp");
968 IRB.CreateStore(PtrVal, PtrSlot);
969
970 std::unique_ptr<PointerStorage> ResultStorage =
971 std::make_unique<PointerStorage>(PtrSlot, IRB, PointeePtrTy);
972
973 return Var<std::add_pointer_t<T>>(std::move(ResultStorage), Fn);
974}
975
// Pointer + offset with a JIT-time (Var) offset. Emits
// `getelementptr inbounds E, ptr, idx` named "ptr.add" and returns a new
// pointer Var whose value lives in a "ptr.add.tmp" alloca. IdxVal is derived
// from OffsetVal.
976template <typename T>
977template <typename OffsetT>
978std::enable_if_t<std::is_arithmetic_v<OffsetT>,
981 const Var<OffsetT> &Offset) const {
982 auto &IRB = Fn.getIRBuilder();
983
984 auto *OffsetVal = Offset.loadValue();
986
987 auto *BasePtr = loadPointer();
988 auto *ElemTy = getValueType();
989
990 auto *GEP = IRB.CreateInBoundsGEP(ElemTy, BasePtr, IdxVal, "ptr.add");
991
992 unsigned AddrSpace = cast<PointerType>(BasePtr->getType())->getAddressSpace();
993 auto *ElemPtrTy = PointerType::get(ElemTy, AddrSpace);
994 auto *PtrSlot = Fn.emitAlloca(ElemPtrTy, "ptr.add.tmp");
995 IRB.CreateStore(GEP, PtrSlot);
996 std::unique_ptr<PointerStorage> ResultStorage =
997 std::make_unique<PointerStorage>(PtrSlot, IRB, ElemTy);
999 std::move(ResultStorage), Fn);
1000}
1001
// Pointer + offset with a host offset, materialized as an i64 constant.
// NOTE(review): OffsetT may be floating point (is_arithmetic). The value is
// then converted to uint64_t for ConstantInt::get, which truncates fractions
// and is undefined behavior for negative values.
1002template <typename T>
1003template <typename OffsetT>
1004std::enable_if_t<std::is_arithmetic_v<OffsetT>,
1007 OffsetT Offset) const {
1008 auto &IRB = Fn.getIRBuilder();
1009 auto *IntTy = IRB.getInt64Ty();
1010 Value *IdxVal = ConstantInt::get(IntTy, Offset);
1011
1012 auto *BasePtr = loadPointer();
1013 auto *ElemTy = getValueType();
1014
1015 auto *GEP = IRB.CreateInBoundsGEP(ElemTy, BasePtr, IdxVal, "ptr.add");
1016
1017 unsigned AddrSpace = cast<PointerType>(BasePtr->getType())->getAddressSpace();
1018 auto *ElemPtrTy = PointerType::get(ElemTy, AddrSpace);
1019 auto *PtrSlot = Fn.emitAlloca(ElemPtrTy, "ptr.add.tmp");
1020 IRB.CreateStore(GEP, PtrSlot);
1021 std::unique_ptr<PointerStorage> ResultStorage =
1022 std::make_unique<PointerStorage>(PtrSlot, IRB, ElemTy);
1024 std::move(ResultStorage), Fn);
1025}
1026
1027// Comparison operators for Var
// Greater, greater-or-equal, less, less-or-equal and equal, each forwarding
// to cmpOp with an integer predicate and an ordered floating-point predicate.
// Ordered predicates return false when either operand is NaN, matching C++
// for these operators.
// NOTE(review): the integer predicates are always signed (SGT/SGE/SLT/SLE),
// so unsigned Vars compare incorrectly once the top bit is set (C++ would
// use UGT/UGE/ULT/ULE).
1028template <typename T>
1029template <typename U>
1030std::enable_if_t<std::is_arithmetic_v<U>, Var<bool>>
1032 const Var<U> &Other) const {
1033 return cmpOp(
1034 *this, Other,
1035 [](IRBuilderBase &B, Value *L, Value *R) {
1036 return B.CreateICmpSGT(L, R);
1037 },
1038 [](IRBuilderBase &B, Value *L, Value *R) {
1039 return B.CreateFCmpOGT(L, R);
1040 });
1041}
1042
1043template <typename T>
1044template <typename U>
1045std::enable_if_t<std::is_arithmetic_v<U>, Var<bool>>
1047 const Var<U> &Other) const {
1048 return cmpOp(
1049 *this, Other,
1050 [](IRBuilderBase &B, Value *L, Value *R) {
1051 return B.CreateICmpSGE(L, R);
1052 },
1053 [](IRBuilderBase &B, Value *L, Value *R) {
1054 return B.CreateFCmpOGE(L, R);
1055 });
1056}
1057
1058template <typename T>
1059template <typename U>
1060std::enable_if_t<std::is_arithmetic_v<U>, Var<bool>>
1062 const Var<U> &Other) const {
1063 return cmpOp(
1064 *this, Other,
1065 [](IRBuilderBase &B, Value *L, Value *R) {
1066 return B.CreateICmpSLT(L, R);
1067 },
1068 [](IRBuilderBase &B, Value *L, Value *R) {
1069 return B.CreateFCmpOLT(L, R);
1070 });
1071}
1072
1073template <typename T>
1074template <typename U>
1075std::enable_if_t<std::is_arithmetic_v<U>, Var<bool>>
1077 const Var<U> &Other) const {
1078 return cmpOp(
1079 *this, Other,
1080 [](IRBuilderBase &B, Value *L, Value *R) {
1081 return B.CreateICmpSLE(L, R);
1082 },
1083 [](IRBuilderBase &B, Value *L, Value *R) {
1084 return B.CreateFCmpOLE(L, R);
1085 });
1086}
1087
1088template <typename T>
1089template <typename U>
1090std::enable_if_t<std::is_arithmetic_v<U>, Var<bool>>
1092 const Var<U> &Other) const {
1093 return cmpOp(
1094 *this, Other,
1095 [](IRBuilderBase &B, Value *L, Value *R) { return B.CreateICmpEQ(L, R); },
1096 [](IRBuilderBase &B, Value *L, Value *R) {
1097 return B.CreateFCmpOEQ(L, R);
1098 });
1099}
1100
// Inequality between two Vars. Integral operands use `icmp ne`. Floating-point
// operands use the *unordered* predicate `fcmp une`, so a NaN operand
// compares unequal, matching C++ (x != x is true for NaN) and what Clang
// emits for `!=`. The ordered `fcmp one` returns false whenever either
// operand is NaN.
1101template <typename T>
1102template <typename U>
1103std::enable_if_t<std::is_arithmetic_v<U>, Var<bool>>
1105 const Var<U> &Other) const {
1106 return cmpOp(
1107 *this, Other,
1108 [](IRBuilderBase &B, Value *L, Value *R) { return B.CreateICmpNE(L, R); },
1109 [](IRBuilderBase &B, Value *L, Value *R) {
1110 return B.CreateFCmpUNE(L, R);
1111 });
1112}
1113
1114template <typename T>
1115template <typename U>
1116std::enable_if_t<std::is_arithmetic_v<U>, Var<bool>>
1118 const U &ConstValue) const {
1119 Var<U> Tmp = Fn.defVar<U>(ConstValue, "cmp.");
1120 return (*this) > Tmp;
1121}
1122
1123template <typename T>
1124template <typename U>
1125std::enable_if_t<std::is_arithmetic_v<U>, Var<bool>>
1127 const U &ConstValue) const {
1128 Var<U> Tmp = Fn.defVar<U>(ConstValue, "cmp.");
1129 return (*this) >= Tmp;
1130}
1131
1132template <typename T>
1133template <typename U>
1134std::enable_if_t<std::is_arithmetic_v<U>, Var<bool>>
1136 const U &ConstValue) const {
1137 Var<U> Tmp = Fn.defVar<U>(ConstValue, "cmp.");
1138 return (*this) < Tmp;
1139}
1140
1141template <typename T>
1142template <typename U>
1143std::enable_if_t<std::is_arithmetic_v<U>, Var<bool>>
1145 const U &ConstValue) const {
1146 auto Tmp = Fn.defVar<U>(ConstValue, "cmp.");
1147 return (*this) <= Tmp;
1148}
1149
// Comparison against an arithmetic constant: ConstValue is first stored into
// a new Var<U> named "cmp." via Fn.defVar, then the Var-vs-Var `==` is used.
// NOTE(review): the line naming the operator (orig. 1153) is missing from
// this listing.
1150template <typename T>
1151template <typename U>
1152std::enable_if_t<std::is_arithmetic_v<U>, Var<bool>>
1154 const U &ConstValue) const {
1155 Var<U> Tmp = Fn.defVar<U>(ConstValue, "cmp.");
1156 return (*this) == Tmp;
1157}
1158
// Comparison against an arithmetic constant: ConstValue is first stored into
// a new Var<U> named "cmp." via Fn.defVar, then the Var-vs-Var `!=` is used,
// so its floating-point NaN semantics are inherited from that operator.
// NOTE(review): the line naming the operator (orig. 1162) is missing from
// this listing.
1159template <typename T>
1160template <typename U>
1161std::enable_if_t<std::is_arithmetic_v<U>, Var<bool>>
1163 const U &ConstValue) const {
1164 auto Tmp = Fn.defVar<U>(ConstValue, "cmp.");
1165 return (*this) != Tmp;
1166}
1167
1168// Non-member arithmetic operators for Var
// `constant + Var`: stores ConstValue into a new Var<T> named "tmp." (in V's
// function) and reuses the Var + Var operator. Unlike operator- below, the
// temporary is typed T rather than std::common_type_t<T, U>.
// NOTE(review): the return-type line (orig. 1171) is missing here; the
// cross-reference index gives the return type as
// Var<std::common_type_t<T, U>>.
1169template <typename T, typename U>
1170std::enable_if_t<std::is_arithmetic_v<T> && std::is_arithmetic_v<U>,
1172operator+(const T &ConstValue, const Var<U> &V) {
1173 Var<T> Tmp = V.Fn.template defVar<T>(ConstValue, "tmp.");
1174 return Tmp + V;
1175}
1176
// `constant - Var`: stores ConstValue into a new Var of the common type of
// T and U (named "tmp.", in V's function) and reuses the Var - Var operator.
// NOTE(review): the return-type line (orig. 1179) is missing here; the
// cross-reference index gives the return type as
// Var<std::common_type_t<T, U>>.
1177template <typename T, typename U>
1178std::enable_if_t<std::is_arithmetic_v<T> && std::is_arithmetic_v<U>,
1180operator-(const T &ConstValue, const Var<U> &V) {
1181 using CommonType = std::common_type_t<T, U>;
1182 Var<CommonType> Tmp = V.Fn.template defVar<CommonType>(ConstValue, "tmp.");
1183 return Tmp - V;
1184}
1185
// `constant * Var`: stores ConstValue into a new Var<T> named "tmp." (in V's
// function) and reuses the Var * Var operator.
// NOTE(review): the return-type line (orig. 1188) is missing here; the
// cross-reference index gives the return type as
// Var<std::common_type_t<T, U>>.
1186template <typename T, typename U>
1187std::enable_if_t<std::is_arithmetic_v<T> && std::is_arithmetic_v<U>,
1189operator*(const T &ConstValue, const Var<U> &V) {
1190 Var<T> Tmp = V.Fn.template defVar<T>(ConstValue, "tmp.");
1191 return Tmp * V;
1192}
1193
// `constant / Var`: stores ConstValue into a new Var<T> named "tmp." (in V's
// function) and reuses the Var / Var operator.
// NOTE(review): the return-type line (orig. 1196) is missing here; the
// cross-reference index gives the return type as
// Var<std::common_type_t<T, U>>.
1194template <typename T, typename U>
1195std::enable_if_t<std::is_arithmetic_v<T> && std::is_arithmetic_v<U>,
1197operator/(const T &ConstValue, const Var<U> &V) {
1198 Var<T> Tmp = V.Fn.template defVar<T>(ConstValue, "tmp.");
1199 return Tmp / V;
1200}
1201
// `constant % Var`: stores ConstValue into a new Var<T> named "tmp." (in V's
// function) and reuses the Var % Var operator.
// NOTE(review): the return-type line (orig. 1204) is missing here; the
// cross-reference index gives the return type as
// Var<std::common_type_t<T, U>>.
1202template <typename T, typename U>
1203std::enable_if_t<std::is_arithmetic_v<T> && std::is_arithmetic_v<U>,
1205operator%(const T &ConstValue, const Var<U> &V) {
1206 Var<T> Tmp = V.Fn.template defVar<T>(ConstValue, "tmp.");
1207 return Tmp % V;
1208}
1209
1210// Atomic operations for Var
1211template <typename T>
1212Var<T> FuncBase::emitAtomic(AtomicRMWInst::BinOp Op, const Var<T *> &Addr,
1213 const Var<T> &Val) {
1214 static_assert(std::is_arithmetic_v<T>, "Atomic ops require arithmetic type");
1215
1216 auto &IRB = getIRBuilder();
1217 auto *Result = IRB.CreateAtomicRMW(
1218 Op, Addr.loadPointer(), Val.loadValue(), MaybeAlign(),
1219 AtomicOrdering::SequentiallyConsistent, SyncScope::SingleThread);
1220
1221 auto Ret = declVar<T>("atomic.rmw.res.");
1222 Ret.storeValue(Result);
1223 return Ret;
1224}
1225
// Atomic add on *Addr: picks `fadd` when TypeMap<T> maps T to an LLVM
// floating-point type, integer `add` otherwise, and forwards to emitAtomic.
// Returns whatever emitAtomic returns (the atomicrmw result).
// NOTE(review): the signature line (orig. 1228) is missing here; per the
// cross-reference index it is
// `atomicAdd(const Var<T *> &Addr, const Var<T> &Val)`.
1226template <typename T>
1227std::enable_if_t<std::is_arithmetic_v<T>, Var<T>>
1229 static_assert(std::is_arithmetic_v<T>, "atomicAdd requires arithmetic type");
1230
1231 Type *ValueType = TypeMap<T>::get(getFunction()->getContext());
1232 auto Op =
1233 ValueType->isFloatingPointTy() ? AtomicRMWInst::FAdd : AtomicRMWInst::Add;
1234 return emitAtomic(Op, Addr, Val);
1235}
1236
// Atomic subtract on *Addr: picks `fsub` when TypeMap<T> maps T to an LLVM
// floating-point type, integer `sub` otherwise, and forwards to emitAtomic.
// NOTE(review): the signature line (orig. 1239) is missing here; per the
// cross-reference index it is
// `atomicSub(const Var<T *> &Addr, const Var<T> &Val)`.
1237template <typename T>
1238std::enable_if_t<std::is_arithmetic_v<T>, Var<T>>
1240 static_assert(std::is_arithmetic_v<T>, "atomicSub requires arithmetic type");
1241
1242 Type *ValueType = TypeMap<T>::get(getFunction()->getContext());
1243 auto Op =
1244 ValueType->isFloatingPointTy() ? AtomicRMWInst::FSub : AtomicRMWInst::Sub;
1245 return emitAtomic(Op, Addr, Val);
1246}
1247
1248template <typename T>
1249std::enable_if_t<std::is_arithmetic_v<T>, Var<T>>
1251 static_assert(std::is_arithmetic_v<T>, "atomicMax requires arithmetic type");
1252
1253 Type *ValueType = TypeMap<T>::get(getFunction()->getContext());
1254 auto Op =
1255 ValueType->isFloatingPointTy() ? AtomicRMWInst::FMax : AtomicRMWInst::Max;
1256 return emitAtomic(Op, Addr, Val);
1257}
1258
1259template <typename T>
1260std::enable_if_t<std::is_arithmetic_v<T>, Var<T>>
1262 static_assert(std::is_arithmetic_v<T>, "atomicMin requires arithmetic type");
1263
1264 Type *ValueType = TypeMap<T>::get(getFunction()->getContext());
1265 auto Op =
1266 ValueType->isFloatingPointTy() ? AtomicRMWInst::FMin : AtomicRMWInst::Min;
1267 return emitAtomic(Op, Addr, Val);
1268}
1269
1270inline void FuncBase::ret() {
1271 auto *CurBB = IP.getBlock();
1272 if (!CurBB->getSingleSuccessor())
1273 PROTEUS_FATAL_ERROR("Expected single successor for current block");
1274 auto *TermI = CurBB->getTerminator();
1275
1276 auto &IRB = getIRBuilder();
1277 IRB.CreateRetVoid();
1278
1279 TermI->eraseFromParent();
1280}
1281
1282template <typename T> void FuncBase::ret(const Var<T> &RetVal) {
1283 auto *CurBB = IP.getBlock();
1284 if (!CurBB->getSingleSuccessor())
1285 PROTEUS_FATAL_ERROR("Expected single successor for current block");
1286 auto *TermI = CurBB->getTerminator();
1287
1288 auto &IRB = getIRBuilder();
1289 Value *RetValue = RetVal.loadValue();
1290 IRB.CreateRet(RetValue);
1291
1292 TermI->eraseFromParent();
1293}
1294
1295// Helper struct to convert Var operands to a target type T.
1296// Used by emitIntrinsic to convert all operands to the intrinsic's result type.
1297// C++17 doesn't support template parameters on lambdas, so we use a struct.
// Function object that loads a Var<U> and converts the loaded value to T
// with convert<U, T>, using the stored IR builder.
// NOTE(review): orig. line 1299 is missing from this listing; the
// cross-reference index lists it as the member `IRBuilderBase & IRB`.
1298template <typename T> struct IntrinsicOperandConverter {
1300
1301 template <typename U> Value *operator()(const Var<U> &Operand) const {
1302 return convert<U, T>(IRB, Operand.loadValue());
1303 }
1304};
1305
1306// Helper for emitting intrinsics with Var
// Call the function named IntrinsicName with every operand in Ops and return
// the result in a new Var<T> named "res.".
//  - All operands must belong to the same FuncBase as the first operand,
//    otherwise this is a fatal error.
//  - The callee is declared (or looked up) with type
//    ResultType(ResultType, ..., ResultType): one ResultType parameter per
//    operand.
//  - Each operand is passed through ConvertOperand before the call.
// NOTE(review): orig. line 1322, which presumably defines ConvertOperand
// (likely an IntrinsicOperandConverter<T> over IRB), is missing from this
// listing. `static` on a header template gives each translation unit its own
// copy; confirm that is intended.
1307template <typename T, typename... Operands>
1308static Var<T> emitIntrinsic(StringRef IntrinsicName, Type *ResultType,
1309 const Operands &...Ops) {
1310 static_assert(sizeof...(Ops) > 0, "Intrinsic requires at least one operand");
1311
1312 auto &Fn = std::get<0>(std::tie(Ops...)).Fn;
1313 auto CheckFn = [&Fn](const auto &Operand) {
1314 if (&Operand.Fn != &Fn)
1315 PROTEUS_FATAL_ERROR("Variables should belong to the same function");
1316 };
1317 (CheckFn(Ops), ...);
1318
1319 auto &IRB = Fn.getIRBuilder();
1320 auto &M = *Fn.getFunction()->getParent();
1321
1323
1324 FunctionCallee Callee = M.getOrInsertFunction(IntrinsicName, ResultType,
1325 ((void)Ops, ResultType)...);
1326 Value *Call = IRB.CreateCall(Callee, {ConvertOperand(Ops)...});
1327
1328 auto ResultVar = Fn.template declVar<T>("res.");
1329 ResultVar.storeValue(Call);
1330 return ResultVar;
1331}
1332
1333// Math intrinsics for Var
// Single-precision power L^R. R is first converted to float via
// R.Fn.convert<float>. The callee is `llvm.pow.f32`, or `__nv_powf` when CUDA
// support is compiled in and L's function targets CUDA.
// NOTE(review): the static_assert only checks convertibility to float, so
// integer R is accepted despite the "floating-point" message.
1334template <typename T> Var<float> powf(const Var<float> &L, const Var<T> &R) {
1335 static_assert(std::is_convertible_v<T, float>,
1336 "powf requires floating-point type");
1337 auto &IRB = L.Fn.getIRBuilder();
1338 auto *ResultType = IRB.getFloatTy();
1339 auto RFloat = R.Fn.template convert<float>(R);
1340 std::string IntrinsicName = "llvm.pow.f32";
1341#if PROTEUS_ENABLE_CUDA
1342 if (L.Fn.getTargetModel() == TargetModelType::CUDA)
1343 IntrinsicName = "__nv_powf";
1344#endif
1345
// NOTE(review): orig. line 1346 (presumably the emitIntrinsic call that
// produces the result) is missing from this listing.
1347}
1348
// Single-precision square root of R, after converting R to float. The callee
// is `llvm.sqrt.f32`, or `__nv_sqrtf` when CUDA support is compiled in and
// R's function targets CUDA.
// NOTE(review): the static_assert only checks convertibility to float.
1349template <typename T> Var<float> sqrtf(const Var<T> &R) {
1350 static_assert(std::is_convertible_v<T, float>,
1351 "sqrtf requires floating-point type");
1352
1353 auto &IRB = R.Fn.getIRBuilder();
1354 auto *ResultType = IRB.getFloatTy();
1355 auto RFloat = R.Fn.template convert<float>(R);
1356 std::string IntrinsicName = "llvm.sqrt.f32";
1357#if PROTEUS_ENABLE_CUDA
1358 if (R.Fn.getTargetModel() == TargetModelType::CUDA)
1359 IntrinsicName = "__nv_sqrtf";
1360#endif
1361
// NOTE(review): orig. line 1362 (presumably the emitIntrinsic call) is
// missing from this listing.
1363}
1364
// Single-precision e^R, after converting R to float. The callee is
// `llvm.exp.f32`, or `__nv_expf` when CUDA support is compiled in and R's
// function targets CUDA.
// NOTE(review): the static_assert only checks convertibility to float.
1365template <typename T> Var<float> expf(const Var<T> &R) {
1366 static_assert(std::is_convertible_v<T, float>,
1367 "expf requires floating-point type");
1368
1369 auto &IRB = R.Fn.getIRBuilder();
1370 auto *ResultType = IRB.getFloatTy();
1371 auto RFloat = R.Fn.template convert<float>(R);
1372 std::string IntrinsicName = "llvm.exp.f32";
1373#if PROTEUS_ENABLE_CUDA
1374 if (R.Fn.getTargetModel() == TargetModelType::CUDA)
1375 IntrinsicName = "__nv_expf";
1376#endif
1377
// NOTE(review): orig. line 1378 (presumably the emitIntrinsic call) is
// missing from this listing.
1379}
1380
// Single-precision sine of R, after converting R to float. The callee is
// `llvm.sin.f32`, or `__nv_sinf` when CUDA support is compiled in and R's
// function targets CUDA.
// NOTE(review): the static_assert only checks convertibility to float.
1381template <typename T> Var<float> sinf(const Var<T> &R) {
1382 static_assert(std::is_convertible_v<T, float>,
1383 "sinf requires floating-point type");
1384
1385 auto &IRB = R.Fn.getIRBuilder();
1386 auto *ResultType = IRB.getFloatTy();
1387 auto RFloat = R.Fn.template convert<float>(R);
1388 std::string IntrinsicName = "llvm.sin.f32";
1389#if PROTEUS_ENABLE_CUDA
1390 if (R.Fn.getTargetModel() == TargetModelType::CUDA)
1391 IntrinsicName = "__nv_sinf";
1392#endif
1393
// NOTE(review): orig. line 1394 (presumably the emitIntrinsic call) is
// missing from this listing.
1395}
1396
// Single-precision cosine of R, after converting R to float. The callee is
// `llvm.cos.f32`, or `__nv_cosf` when CUDA support is compiled in and R's
// function targets CUDA.
// NOTE(review): the static_assert only checks convertibility to float.
1397template <typename T> Var<float> cosf(const Var<T> &R) {
1398 static_assert(std::is_convertible_v<T, float>,
1399 "cosf requires floating-point type");
1400
1401 auto &IRB = R.Fn.getIRBuilder();
1402 auto *ResultType = IRB.getFloatTy();
1403 auto RFloat = R.Fn.template convert<float>(R);
1404 std::string IntrinsicName = "llvm.cos.f32";
1405#if PROTEUS_ENABLE_CUDA
1406 if (R.Fn.getTargetModel() == TargetModelType::CUDA)
1407 IntrinsicName = "__nv_cosf";
1408#endif
1409
// NOTE(review): orig. line 1410 (presumably the emitIntrinsic call) is
// missing from this listing.
1411}
1412
// Single-precision absolute value of R, after converting R to float. The
// callee is `llvm.fabs.f32`, or `__nv_fabsf` when CUDA support is compiled in
// and R's function targets CUDA. Same callee selection as absf below.
// NOTE(review): the static_assert only checks convertibility to float.
1413template <typename T> Var<float> fabs(const Var<T> &R) {
1414 static_assert(std::is_convertible_v<T, float>,
1415 "fabs requires floating-point type");
1416
1417 auto &IRB = R.Fn.getIRBuilder();
1418 auto *ResultType = IRB.getFloatTy();
1419 auto RFloat = R.Fn.template convert<float>(R);
1420 std::string IntrinsicName = "llvm.fabs.f32";
1421#if PROTEUS_ENABLE_CUDA
1422 if (R.Fn.getTargetModel() == TargetModelType::CUDA)
1423 IntrinsicName = "__nv_fabsf";
1424#endif
1425
// NOTE(review): orig. line 1426 (presumably the emitIntrinsic call) is
// missing from this listing.
1427}
1428
// Single-precision truncation (round toward zero) of R, after converting R
// to float. The callee is `llvm.trunc.f32`, or `__nv_truncf` when CUDA
// support is compiled in and R's function targets CUDA.
// NOTE(review): the static_assert only checks convertibility to float.
1429template <typename T> Var<float> truncf(const Var<T> &R) {
1430 static_assert(std::is_convertible_v<T, float>,
1431 "truncf requires floating-point type");
1432
1433 auto &IRB = R.Fn.getIRBuilder();
1434 auto *ResultType = IRB.getFloatTy();
1435 auto RFloat = R.Fn.template convert<float>(R);
1436 std::string IntrinsicName = "llvm.trunc.f32";
1437#if PROTEUS_ENABLE_CUDA
1438 if (R.Fn.getTargetModel() == TargetModelType::CUDA)
1439 IntrinsicName = "__nv_truncf";
1440#endif
1441
// NOTE(review): orig. line 1442 (presumably the emitIntrinsic call) is
// missing from this listing.
1443}
1444
// Single-precision natural logarithm of R, after converting R to float. The
// callee is `llvm.log.f32`, or `__nv_logf` when CUDA support is compiled in
// and R's function targets CUDA.
// NOTE(review): the static_assert only checks convertibility to float.
1445template <typename T> Var<float> logf(const Var<T> &R) {
1446 static_assert(std::is_convertible_v<T, float>,
1447 "logf requires floating-point type");
1448
1449 auto &IRB = R.Fn.getIRBuilder();
1450 auto *ResultType = IRB.getFloatTy();
1451 auto RFloat = R.Fn.template convert<float>(R);
1452 std::string IntrinsicName = "llvm.log.f32";
1453#if PROTEUS_ENABLE_CUDA
1454 if (R.Fn.getTargetModel() == TargetModelType::CUDA)
1455 IntrinsicName = "__nv_logf";
1456#endif
1457
// NOTE(review): orig. line 1458 (presumably the emitIntrinsic call) is
// missing from this listing.
1459}
1460
// Single-precision absolute value of R. The visible body is identical to
// fabs above (`llvm.fabs.f32`, or `__nv_fabsf` for CUDA targets when CUDA
// support is compiled in).
// NOTE(review): duplicate of fabs; consider having one forward to the other.
1461template <typename T> Var<float> absf(const Var<T> &R) {
1462 static_assert(std::is_convertible_v<T, float>,
1463 "absf requires floating-point type");
1464
1465 auto &IRB = R.Fn.getIRBuilder();
1466 auto *ResultType = IRB.getFloatTy();
1467 auto RFloat = R.Fn.template convert<float>(R);
1468 std::string IntrinsicName = "llvm.fabs.f32";
1469#if PROTEUS_ENABLE_CUDA
1470 if (R.Fn.getTargetModel() == TargetModelType::CUDA)
1471 IntrinsicName = "__nv_fabsf";
1472#endif
1473
// NOTE(review): orig. line 1474 (presumably the emitIntrinsic call) is
// missing from this listing.
1475}
1476
1477template <typename T>
1478std::enable_if_t<std::is_arithmetic_v<T>, Var<T>> min(const Var<T> &L,
1479 const Var<T> &R) {
1480 static_assert(std::is_arithmetic_v<T>, "min requires arithmetic type");
1481
1482 FuncBase &Fn = L.Fn;
1483 if (&Fn != &R.Fn)
1484 PROTEUS_FATAL_ERROR("Variables should belong to the same function");
1485
1486 Var<T> ResultVar = Fn.declVar<T>("min_res");
1487 ResultVar = R;
1488 Fn.beginIf(L < R);
1489 { ResultVar = L; }
1490 Fn.endIf();
1491 return ResultVar;
1492}
1493
1494template <typename T>
1495std::enable_if_t<std::is_arithmetic_v<T>, Var<T>> max(const Var<T> &L,
1496 const Var<T> &R) {
1497 static_assert(std::is_arithmetic_v<T>, "max requires arithmetic type");
1498
1499 FuncBase &Fn = L.Fn;
1500 if (&Fn != &R.Fn)
1501 PROTEUS_FATAL_ERROR("Variables should belong to the same function");
1502
1503 Var<T> ResultVar = Fn.declVar<T>("max_res");
1504 ResultVar = R;
1505 Fn.beginIf(L > R);
1506 { ResultVar = L; }
1507 Fn.endIf();
1508 return ResultVar;
1509}
1510
1511} // namespace proteus
1512
1513#endif // PROTEUS_FRONTEND_FUNC_HPP
char int void ** Args
Definition CompilerInterfaceHost.cpp:20
#define PROTEUS_FATAL_ERROR(x)
Definition Error.h:7
Definition Dispatcher.hpp:54
Definition Func.hpp:40
Var< T > emitAtomic(AtomicRMWInst::BinOp Op, const Var< T * > &Addr, const Var< T > &Val)
Definition Func.hpp:1212
std::vector< Scope > Scopes
Definition Func.hpp:60
void setName(StringRef NewName)
Definition Func.hpp:233
Var< T > declVar(StringRef Name="var")
Definition Func.hpp:96
auto forLoop(std::initializer_list< Var< T > > Bounds, BodyLambda &&Body={})
Definition Func.hpp:211
Var< T > defVar(const T &Val, StringRef Name="var")
Definition Func.hpp:127
decltype(auto) callBuiltin(BuiltinFuncT &&BuiltinFunc)
Definition Func.hpp:188
void beginIf(const Var< bool > &CondVar, const char *File=__builtin_FILE(), int Line=__builtin_LINE())
Definition Func.cpp:115
auto defRuntimeConsts(ArgT &&...Args)
Definition Func.hpp:145
std::enable_if_t< std::is_convertible_v< T, U >, Var< U > > convert(const Var< T > &V)
Definition Func.hpp:243
std::enable_if_t< std::is_arithmetic_v< T >, Var< T > > atomicMin(const Var< T * > &Addr, const Var< T > &Val)
Definition Func.hpp:1261
void endFunction()
Definition Func.cpp:53
void beginWhile(CondLambda &&Cond, const char *File=__builtin_FILE(), int Line=__builtin_LINE())
Definition Func.hpp:371
void beginFunction(const char *File=__builtin_FILE(), int Line=__builtin_LINE())
Definition Func.cpp:30
Var< T > declVar(size_t NElem, AddressSpace AS=AddressSpace::DEFAULT, StringRef Name="array_var")
Definition Func.hpp:114
IRBuilder IRB
Definition Func.hpp:44
std::enable_if_t< std::is_arithmetic_v< T >, Var< T > > atomicMax(const Var< T * > &Addr, const Var< T > &Val)
Definition Func.hpp:1250
Function * getFunction()
Definition Func.cpp:66
ScopeKind
Definition Func.hpp:49
void endIf()
Definition Func.cpp:148
std::string Name
Definition Func.hpp:47
FunctionCallee FC
Definition Func.hpp:43
auto buildLoopNest(LoopBuilders &&...Loops)
Definition Func.hpp:219
void ret()
Definition Func.hpp:1270
IRBuilderBase::InsertPoint IP
Definition Func.hpp:45
JitModule & J
Definition Func.hpp:42
void endWhile()
Definition Func.cpp:181
std::enable_if_t< std::is_arithmetic_v< T >, Var< T > > atomicAdd(const Var< T * > &Addr, const Var< T > &Val)
Definition Func.hpp:1228
Value * emitArrayCreate(Type *Ty, AddressSpace AT, StringRef Name)
Definition Func.cpp:85
StringRef getName() const
Definition Func.hpp:231
std::enable_if_t<!std::is_void_v< typename FnSig< Sig >::RetT >, Var< typename FnSig< Sig >::RetT > > call(StringRef Name)
Definition JitFrontend.hpp:272
Var< T > defRuntimeConst(const T &Val, StringRef Name="run.const.var")
Definition Func.hpp:141
void endFor()
Definition Func.cpp:164
std::enable_if_t< std::is_arithmetic_v< T >, Var< T > > atomicSub(const Var< T * > &Addr, const Var< T > &Val)
Definition Func.hpp:1239
TargetModelType getTargetModel() const
Definition Func.cpp:22
std::string toString(ScopeKind Kind)
Definition Func.hpp:62
IRBuilderBase & getIRBuilder()
Definition Func.cpp:24
void beginFor(Var< T > &IterVar, const Var< T > &InitVar, const Var< T > &UpperBound, const Var< T > &IncVar, const char *File=__builtin_FILE(), int Line=__builtin_LINE())
Definition Func.hpp:309
AllocaInst * emitAlloca(Type *Ty, StringRef Name, AddressSpace AS=AddressSpace::DEFAULT)
Definition Func.cpp:73
Var< T > defVar(const Var< U > &Var, StringRef Name="var")
Definition Func.hpp:134
Definition Func.hpp:252
Func(JitModule &J, FunctionCallee FC, Dispatcher &Dispatch)
Definition Func.hpp:287
auto & getArg()
Definition Func.hpp:296
auto getCompiledFunc() const
Definition Func.hpp:300
auto getArgs()
Definition Func.hpp:294
void setCompiledFunc(RetT(*CompiledFuncIn)(ArgT...))
Definition Func.hpp:302
RetT operator()(ArgT... Args)
Definition JitFrontend.hpp:334
void declArgs()
Definition Func.hpp:292
Definition JitFrontend.hpp:29
Definition LoopNest.hpp:12
Definition LoopNest.hpp:47
Definition VarStorage.hpp:10
Definition Helpers.h:138
Definition StorageCache.cpp:24
std::enable_if_t< std::is_arithmetic_v< T > &&std::is_arithmetic_v< U >, Var< std::common_type_t< T, U > > > operator-(const T &ConstValue, const Var< U > &V)
Definition Func.hpp:1180
AddressSpace
Definition AddressSpace.hpp:6
TargetModelType
Definition TargetModel.hpp:14
Var< float > sqrtf(const Var< T > &R)
Definition Func.hpp:1349
std::enable_if_t< std::is_arithmetic_v< T >, Var< T > > max(const Var< T > &L, const Var< T > &R)
Definition Func.hpp:1495
Var< float > fabs(const Var< T > &R)
Definition Func.hpp:1413
std::enable_if_t< std::is_arithmetic_v< T > &&std::is_arithmetic_v< U >, Var< std::common_type_t< T, U > > > operator*(const T &ConstValue, const Var< U > &V)
Definition Func.hpp:1189
Var< float > sinf(const Var< T > &R)
Definition Func.hpp:1381
Var< float > logf(const Var< T > &R)
Definition Func.hpp:1445
T getRuntimeConstantValue(void *Arg)
Definition CompilerInterfaceRuntimeConstantInfo.h:114
std::enable_if_t< std::is_arithmetic_v< T > &&std::is_arithmetic_v< U >, Var< std::common_type_t< T, U > > > operator+(const T &ConstValue, const Var< U > &V)
Definition Func.hpp:1172
Var< float > cosf(const Var< T > &R)
Definition Func.hpp:1397
Var< T, std::enable_if_t< std::is_arithmetic_v< T > > > & compoundAssignConst(Var< T, std::enable_if_t< std::is_arithmetic_v< T > > > &LHS, const U &ConstValue, IntOp IOp, FPOp FOp)
Definition Func.hpp:446
std::enable_if_t< std::is_arithmetic_v< T > &&std::is_arithmetic_v< U >, Var< std::common_type_t< T, U > > > operator/(const T &ConstValue, const Var< U > &V)
Definition Func.hpp:1197
Var< float > expf(const Var< T > &R)
Definition Func.hpp:1365
Var< float > powf(const Var< float > &L, const Var< T > &R)
Definition Func.hpp:1334
Var< float > truncf(const Var< T > &R)
Definition Func.hpp:1429
std::enable_if_t< std::is_arithmetic_v< T >, Var< T > > min(const Var< T > &L, const Var< T > &R)
Definition Func.hpp:1478
Var< float > absf(const Var< T > &R)
Definition Func.hpp:1461
Var< std::common_type_t< T, U > > binOp(const Var< T > &L, const Var< U > &R, IntOp IOp, FPOp FOp)
Definition Func.hpp:415
std::enable_if_t< std::is_arithmetic_v< T > &&std::is_arithmetic_v< U >, Var< std::common_type_t< T, U > > > operator%(const T &ConstValue, const Var< U > &V)
Definition Func.hpp:1205
Var< bool > cmpOp(const Var< T > &L, const Var< U > &R, IntOp IOp, FPOp FOp)
Definition Func.hpp:482
Definition Func.hpp:27
Definition Func.hpp:36
void operator()() const
Definition Func.hpp:37
RetT_ RetT
Definition Func.hpp:31
Definition Func.hpp:28
Definition Func.hpp:50
Scope(const char *File, int Line, ScopeKind Kind, IRBuilderBase::InsertPoint ContIP)
Definition Func.hpp:56
int Line
Definition Func.hpp:52
ScopeKind Kind
Definition Func.hpp:53
std::string File
Definition Func.hpp:51
IRBuilderBase::InsertPoint ContIP
Definition Func.hpp:54
Definition Func.hpp:1298
Value * operator()(const Var< U > &Operand) const
Definition Func.hpp:1301
IRBuilderBase & IRB
Definition Func.hpp:1299
Definition TypeMap.hpp:13
Definition Var.hpp:74
FuncBase & Fn
Definition Var.hpp:75
Definition Var.hpp:94