Proteus
Programmable JIT compilation and optimization for C/C++ using LLVM
Loading...
Searching...
No Matches
Var.h
Go to the documentation of this file.
1#ifndef PROTEUS_FRONTEND_VAR_H
2#define PROTEUS_FRONTEND_VAR_H
3
4#include "proteus/Error.h"
8
9#include <string>
10#include <type_traits>
11#include <vector>
12
13namespace proteus {
14
// Primary template, declared but never defined here. Usable Var types come
// from the partial specializations selected through the second (SFINAE)
// parameter, e.g. scalar arithmetic, array, and pointer types. A T that no
// specialization matches yields an incomplete-type error at the use site.
template <typename T, typename Enable = void> struct Var;
17
18// Specialization for arithmetic types (including references to arithmetic
19// types).
20template <typename T>
21struct Var<T, std::enable_if_t<is_scalar_arithmetic_v<T>>> {
26 unsigned AddrSpace = 0;
27
28 using ValueType = T;
29 using ElemType = T;
30
32 : CB(CBIn), Slot(A.Slot), ValueTy(A.ValueTy), AllocTy(A.AllocTy),
33 AddrSpace(A.AddrSpace) {}
34
35 // Conversion constructor from Var<U> where U can convert to T.
36 template <typename U,
37 typename = std::enable_if_t<std::is_convertible_v<U, T> &&
38 (!std::is_same_v<U, T>)>>
39 Var(const Var<U> &V);
40
41 // Copy constructor: aliases the same alloca slot. This is effectively a
42 // "shallow copy" that creates another Var handle to the same storage.
43 Var(const Var &V) = default;
44 Var(Var &&) = default;
45
46 Var &operator=(Var &&V);
47
48 // Assignment operators
49 Var &operator=(const Var &V);
50
51 template <typename U> Var &operator=(const Var<U> &V);
52
53 template <typename U> Var &operator=(const U &ConstValue);
54
55 Var<std::add_pointer_t<T>> getAddress();
56
60 template <typename U> auto convert() const;
61
62 // Load / store helpers.
63 IRValue *loadValue() const {
64 if constexpr (std::is_reference_v<T>)
65 return CB.loadFromPointee(Slot, AllocTy, ValueTy);
66 else
67 return CB.loadScalar(Slot, ValueTy);
68 }
69 void storeValue(IRValue *Val) {
70 if constexpr (std::is_reference_v<T>)
71 CB.storeToPointee(Slot, AllocTy, Val);
72 else
73 CB.storeScalar(Slot, Val);
74 }
75 IRValue *getSlot() const { return Slot; }
76 IRType getValueType() const { return ValueTy; }
77 IRType getAllocatedType() const { return AllocTy; }
78
79 // Arithmetic operators
80 template <typename U>
82
83 template <typename U>
84 std::enable_if_t<std::is_arithmetic_v<U>, Var<std::common_type_t<T, U>>>
85 operator+(const U &ConstValue) const;
86
87 template <typename U>
89
90 template <typename U>
91 std::enable_if_t<std::is_arithmetic_v<U>, Var<std::common_type_t<T, U>>>
92 operator-(const U &ConstValue) const;
93
94 template <typename U>
96
97 template <typename U>
98 std::enable_if_t<std::is_arithmetic_v<U>, Var<std::common_type_t<T, U>>>
99 operator*(const U &ConstValue) const;
100
101 template <typename U>
102 Var<std::common_type_t<T, U>> operator/(const Var<U> &Other) const;
103
104 template <typename U>
105 std::enable_if_t<std::is_arithmetic_v<U>, Var<std::common_type_t<T, U>>>
106 operator/(const U &ConstValue) const;
107
108 template <typename U>
109 Var<std::common_type_t<T, U>> operator%(const Var<U> &Other) const;
110
111 template <typename U>
112 std::enable_if_t<std::is_arithmetic_v<U>, Var<std::common_type_t<T, U>>>
113 operator%(const U &ConstValue) const;
114
115 // Unary operators
117 Var<bool> operator!() const;
118
119 // Compound assignment operators
120 template <typename U> Var &operator+=(const Var<U> &Other);
121
122 template <typename U> Var &operator+=(const U &ConstValue);
123
124 template <typename U> Var &operator-=(const Var<U> &Other);
125
126 template <typename U> Var &operator-=(const U &ConstValue);
127
128 template <typename U> Var &operator*=(const Var<U> &Other);
129
130 template <typename U> Var &operator*=(const U &ConstValue);
131
132 template <typename U> Var &operator/=(const Var<U> &Other);
133
134 template <typename U> Var &operator/=(const U &ConstValue);
135
136 template <typename U> Var &operator%=(const Var<U> &Other);
137
138 template <typename U> Var &operator%=(const U &ConstValue);
139
140 // Comparison operators
141 template <typename U>
142 std::enable_if_t<is_arithmetic_unref_v<U>, Var<bool>>
143 operator>(const Var<U> &Other) const;
144
145 template <typename U>
146 std::enable_if_t<is_arithmetic_unref_v<U>, Var<bool>>
147 operator>=(const Var<U> &Other) const;
148
149 template <typename U>
150 std::enable_if_t<is_arithmetic_unref_v<U>, Var<bool>>
151 operator<(const Var<U> &Other) const;
152
153 template <typename U>
154 std::enable_if_t<is_arithmetic_unref_v<U>, Var<bool>>
155 operator<=(const Var<U> &Other) const;
156
157 template <typename U>
158 std::enable_if_t<is_arithmetic_unref_v<U>, Var<bool>>
159 operator==(const Var<U> &Other) const;
160
161 template <typename U>
162 std::enable_if_t<is_arithmetic_unref_v<U>, Var<bool>>
163 operator!=(const Var<U> &Other) const;
164
165 template <typename U>
166 std::enable_if_t<is_arithmetic_unref_v<U>, Var<bool>>
167 operator>(const U &ConstValue) const;
168
169 template <typename U>
170 std::enable_if_t<is_arithmetic_unref_v<U>, Var<bool>>
171 operator>=(const U &ConstValue) const;
172
173 template <typename U>
174 std::enable_if_t<is_arithmetic_unref_v<U>, Var<bool>>
175 operator<(const U &ConstValue) const;
176
177 template <typename U>
178 std::enable_if_t<is_arithmetic_unref_v<U>, Var<bool>>
179 operator<=(const U &ConstValue) const;
180
181 template <typename U>
182 std::enable_if_t<is_arithmetic_unref_v<U>, Var<bool>>
183 operator==(const U &ConstValue) const;
184
185 template <typename U>
186 std::enable_if_t<is_arithmetic_unref_v<U>, Var<bool>>
187 operator!=(const U &ConstValue) const;
188};
189
190// Specialization for array types
191template <typename T> struct Var<T, std::enable_if_t<std::is_array_v<T>>> {
196 unsigned AddrSpace = 0;
197
198 using ValueType = T;
199 using ElemType = std::remove_extent_t<T>;
200
202 : CB(CBIn), Slot(A.Slot), ValueTy(A.ValueTy), AllocTy(A.AllocTy),
203 AddrSpace(A.AddrSpace) {}
204
205 IRValue *getSlot() const { return Slot; }
206 IRType getValueType() const { return ValueTy; }
207 IRType getAllocatedType() const { return AllocTy; }
208
209 // Load / store: loading an entire array is not supported.
211 reportFatalError("Cannot load entire array as a value");
212 return nullptr;
213 }
215 reportFatalError("Cannot store value to entire array");
216 }
217
218 Var<std::add_lvalue_reference_t<ElemType>> operator[](size_t Index);
219
220 template <typename IdxT>
221 std::enable_if_t<std::is_integral_v<IdxT>,
223 operator[](const Var<IdxT> &Index);
224
226};
227
228// Specialization for pointer types (including references to pointers)
229template <typename T> struct Var<T, std::enable_if_t<is_pointer_unref_v<T>>> {
234 unsigned AddrSpace = 0;
235
236 using ValueType = T;
237 using ElemType = std::remove_pointer_t<std::remove_reference_t<T>>;
238
240 : CB(CBIn), Slot(A.Slot), ValueTy(A.ValueTy), AllocTy(A.AllocTy),
241 AddrSpace(A.AddrSpace) {}
242
243 IRValue *getSlot() const { return Slot; }
244 IRType getValueType() const { return ValueTy; }
245 IRType getAllocatedType() const { return AllocTy; }
246
247 // Load / store the pointer value itself from/to the pointer slot.
248 IRValue *loadAddress() const { return CB.loadAddress(Slot, AllocTy); }
249 void storeAddress(IRValue *Ptr) { CB.storeAddress(Slot, Ptr); }
250
251 // Load / store through the pointer (dereference).
253 return CB.loadFromPointee(Slot, AllocTy, ValueTy);
254 }
255 void storeValue(IRValue *Val) { CB.storeToPointee(Slot, AllocTy, Val); }
256
257 Var<std::add_lvalue_reference_t<ElemType>> operator[](size_t Index);
258
259 template <typename IdxT>
260 std::enable_if_t<std::is_arithmetic_v<IdxT>,
262 operator[](const Var<IdxT> &Index);
263
265
267
268 template <typename OffsetT>
269 std::enable_if_t<std::is_arithmetic_v<OffsetT>,
271 operator+(const Var<OffsetT> &Offset) const;
272
273 template <typename OffsetT>
274 std::enable_if_t<std::is_arithmetic_v<OffsetT>,
276 operator+(OffsetT Offset) const;
277
278 template <typename OffsetT>
279 friend std::enable_if_t<std::is_arithmetic_v<OffsetT>,
281 operator+(OffsetT Offset, const Var &Ptr) {
282 return Ptr + Offset;
283 }
284};
285
286// Non-member arithmetic operators for Var
287template <typename T, typename U>
288std::enable_if_t<std::is_arithmetic_v<T> && std::is_arithmetic_v<U>,
289 Var<std::common_type_t<T, U>>>
290operator+(const T &ConstValue, const Var<U> &Var);
291
292template <typename T, typename U>
293std::enable_if_t<std::is_arithmetic_v<T> && std::is_arithmetic_v<U>,
294 Var<std::common_type_t<T, U>>>
295operator-(const T &ConstValue, const Var<U> &V);
296
297template <typename T, typename U>
298std::enable_if_t<std::is_arithmetic_v<T> && std::is_arithmetic_v<U>,
299 Var<std::common_type_t<T, U>>>
300operator*(const T &ConstValue, const Var<U> &V);
301
302template <typename T, typename U>
303std::enable_if_t<std::is_arithmetic_v<T> && std::is_arithmetic_v<U>,
304 Var<std::common_type_t<T, U>>>
305operator/(const T &ConstValue, const Var<U> &V);
306
307template <typename T, typename U>
308std::enable_if_t<std::is_arithmetic_v<T> && std::is_arithmetic_v<U>,
309 Var<std::common_type_t<T, U>>>
310operator%(const T &ConstValue, const Var<U> &V);
311
312// ---------------------------------------------------------------------------
313// Free helper functions (type conversion)
314// ---------------------------------------------------------------------------
315
316// Value-level type conversion — internal implementation detail.
317// Use Var::convert<U>() for user-facing type conversions.
318namespace detail {
319template <typename FromT, typename ToT>
321 using From = remove_cvref_t<FromT>;
322 using To = remove_cvref_t<ToT>;
323 static_assert(std::is_arithmetic_v<From>, "From type must be arithmetic");
324 static_assert(std::is_arithmetic_v<To>, "To type must be arithmetic");
325
326 if constexpr (std::is_same_v<From, To>)
327 return V;
329}
330} // namespace detail
331
332// Allocate a new Var of type T using CB.
333template <typename T>
334Var<T> declVar(CodeBuilder &CB, const std::string &Name = "var") {
335 static_assert(!std::is_array_v<T>, "Expected non-array type");
336 static_assert(!std::is_reference_v<T>,
337 "declVar does not support reference types");
338
339 if constexpr (std::is_pointer_v<T>) {
341 return Var<T>{CB.allocPointer(Name, ElemIRTy), CB};
342 } else {
343 IRType AllocaIRTy = TypeMap<T>::get();
344 return Var<T>{CB.allocScalar(Name, AllocaIRTy), CB};
345 }
346}
347
348// Allocate and initialize a Var of type T.
349template <typename T>
350Var<T> defVar(CodeBuilder &CB, const T &Val, const std::string &Name = "var") {
351 using RawT = std::remove_const_t<T>;
352 Var<RawT> V = declVar<RawT>(CB, Name);
353 V = Val;
354 return Var<T>(V);
355}
356
357// ---------------------------------------------------------------------------
358// Operator implementation helpers
359// ---------------------------------------------------------------------------
360
361template <typename T, typename U>
362Var<std::common_type_t<remove_cvref_t<T>, remove_cvref_t<U>>>
363binOp(const Var<T> &L, const Var<U> &R, ArithOp Op) {
364 using CommonT = std::common_type_t<remove_cvref_t<T>, remove_cvref_t<U>>;
365
366 CodeBuilder &CB = L.CB;
367 if (&CB != &R.CB)
368 reportFatalError("Variables should belong to the same function");
369
370 IRValue *LHS = detail::convert<T, CommonT>(CB, L.loadValue());
371 IRValue *RHS = detail::convert<U, CommonT>(CB, R.loadValue());
372
373 IRValue *Result = CB.createArith(Op, LHS, RHS, TypeMap<CommonT>::get());
374
375 auto ResultVar = declVar<CommonT>(CB, "res.");
376 ResultVar.storeValue(Result);
377
378 return ResultVar;
379}
380
381template <typename T, typename U>
382Var<T, std::enable_if_t<is_scalar_arithmetic_v<T>>> &
383compoundAssignConst(Var<T, std::enable_if_t<is_scalar_arithmetic_v<T>>> &LHS,
384 const U &ConstValue, ArithOp Op) {
385 static_assert(std::is_convertible_v<remove_cvref_t<U>, remove_cvref_t<T>>,
386 "U must be convertible to T");
387
388 IRType RHSType = TypeMap<remove_cvref_t<U>>::get();
389
390 IRValue *RHS = nullptr;
391 if constexpr (std::is_integral_v<remove_cvref_t<U>>) {
392 RHS = LHS.CB.getConstantInt(RHSType, ConstValue);
393 } else {
394 RHS = LHS.CB.getConstantFP(RHSType, ConstValue);
395 }
396
397 IRValue *LHSVal = LHS.loadValue();
398
399 RHS = detail::convert<U, T>(LHS.CB, RHS);
400 IRValue *Result =
401 LHS.CB.createArith(Op, LHSVal, RHS, TypeMap<remove_cvref_t<T>>::get());
402
403 LHS.storeValue(Result);
404 return LHS;
405}
406
407template <typename T, typename U>
408Var<bool> cmpOp(const Var<T> &L, const Var<U> &R, CmpOp Op) {
409 CodeBuilder &CB = L.CB;
410 if (&CB != &R.CB)
411 reportFatalError("Variables should belong to the same function");
412
413 IRValue *LHS = L.loadValue();
414 IRValue *RHS = detail::convert<U, T>(CB, R.loadValue());
415
416 IRValue *Result =
417 CB.createCmp(Op, LHS, RHS, TypeMap<remove_cvref_t<T>>::get());
418
419 auto ResultVar = declVar<bool>(CB, "res.");
420 ResultVar.storeValue(Result);
421
422 return ResultVar;
423}
424
425// ---------------------------------------------------------------------------
426// Var member operator implementations
427// ---------------------------------------------------------------------------
428
429template <typename T>
430template <typename U>
432 using ResultT = std::remove_reference_t<U>;
433 Var<ResultT> Res = declVar<ResultT>(this->CB, "convert.");
434 IRValue *Converted = detail::convert<T, U>(this->CB, this->loadValue());
435 Res.storeValue(Converted);
436 return Res;
437}
438
439template <typename T>
440template <typename U, typename>
442 : Var(V.CB.allocScalar("conv.var", TypeMap<remove_cvref_t<T>>::get()),
443 V.CB) {
444 auto Converted = detail::convert<U, T>(CB, V.loadValue());
445 storeValue(Converted);
446}
447
448template <typename T>
451 static_assert(is_mutable_v<T>, "Cannot assign to Var<const T>");
452 storeValue(V.loadValue());
453 return *this;
454}
455
456template <typename T>
459 static_assert(is_mutable_v<T>, "Cannot assign to Var<const T>");
460 storeValue(V.loadValue());
461 return *this;
462}
463
464template <typename T>
467 if constexpr (std::is_reference_v<T>) {
468 // For a reference Var the slot holds a pointer; load that pointer and
469 // expose it as the address.
470 IRValue *PtrVal = CB.loadAddress(Slot, AllocTy);
471 auto A = CB.allocPointer("addr.ref.tmp", ValueTy, AddrSpace);
472 CB.storeAddress(A.Slot, PtrVal);
473 return Var<std::add_pointer_t<T>>(A, CB);
474 }
475
476 auto A = CB.allocPointer("addr.tmp", AllocTy, AddrSpace);
477 CB.storeAddress(A.Slot, Slot);
478 return Var<std::add_pointer_t<T>>(A, CB);
479}
480
481template <typename T>
482template <typename U>
485 const Var<U> &V) {
486 static_assert(is_mutable_v<T>, "Cannot assign to Var<const T>");
487 auto Converted = detail::convert<U, T>(CB, V.loadValue());
488 storeValue(Converted);
489 return *this;
490}
491
492template <typename T>
493template <typename U>
496 const U &ConstValue) {
497 static_assert(is_mutable_v<T>, "Cannot assign to Var<const T>");
498 static_assert(std::is_arithmetic_v<U>,
499 "Can only assign arithmetic types to Var");
500
501 IRType LHSType = getValueType();
502
503 if (isIntegerKind(LHSType)) {
504 storeValue(CB.getConstantInt(LHSType, ConstValue));
505 } else if (isFloatingPointKind(LHSType)) {
506 storeValue(CB.getConstantFP(LHSType, ConstValue));
507 } else {
508 reportFatalError("Unsupported type");
509 }
510
511 return *this;
512}
513
514template <typename T>
515template <typename U>
518 const Var<U> &Other) const {
519 return binOp(*this, Other, ArithOp::Add);
520}
521
522template <typename T>
523template <typename U>
526 const Var<U> &Other) const {
527 return binOp(*this, Other, ArithOp::Sub);
528}
529
530template <typename T>
531template <typename U>
534 const Var<U> &Other) const {
535 return binOp(*this, Other, ArithOp::Mul);
536}
537
538template <typename T>
539template <typename U>
542 const Var<U> &Other) const {
543 return binOp(*this, Other, ArithOp::Div);
544}
545
546template <typename T>
547template <typename U>
550 const Var<U> &Other) const {
551 return binOp(*this, Other, ArithOp::Rem);
552}
553
554// Arithmetic operators with ConstValue
555template <typename T>
556template <typename U>
557std::enable_if_t<std::is_arithmetic_v<U>, Var<std::common_type_t<T, U>>>
559 const U &ConstValue) const {
560 static_assert(std::is_arithmetic_v<U>,
561 "Can only add arithmetic types to Var");
562 Var<U> Tmp = defVar<U>(CB, ConstValue, "tmp.");
563 return (*this) + Tmp;
564}
565
566template <typename T>
567template <typename U>
568std::enable_if_t<std::is_arithmetic_v<U>, Var<std::common_type_t<T, U>>>
570 const U &ConstValue) const {
571 static_assert(std::is_arithmetic_v<U>,
572 "Can only subtract arithmetic types from Var");
573 Var<U> Tmp = defVar<U>(CB, ConstValue, "tmp.");
574 return (*this) - Tmp;
575}
576
577template <typename T>
578template <typename U>
579std::enable_if_t<std::is_arithmetic_v<U>, Var<std::common_type_t<T, U>>>
581 const U &ConstValue) const {
582 static_assert(std::is_arithmetic_v<U>,
583 "Can only multiply Var by arithmetic types");
584 Var<U> Tmp = defVar<U>(CB, ConstValue, "tmp.");
585 return (*this) * Tmp;
586}
587
588template <typename T>
589template <typename U>
590std::enable_if_t<std::is_arithmetic_v<U>, Var<std::common_type_t<T, U>>>
592 const U &ConstValue) const {
593 static_assert(std::is_arithmetic_v<U>,
594 "Can only divide Var by arithmetic types");
595 Var<U> Tmp = defVar<U>(CB, ConstValue, "tmp.");
596 return (*this) / Tmp;
597}
598
599template <typename T>
600template <typename U>
601std::enable_if_t<std::is_arithmetic_v<U>, Var<std::common_type_t<T, U>>>
603 const U &ConstValue) const {
604 static_assert(std::is_arithmetic_v<U>,
605 "Can only modulo Var by arithmetic types");
606 Var<U> Tmp = defVar<U>(CB, ConstValue, "tmp.");
607 return (*this) % Tmp;
608}
609
610// Compound assignment operators for Var
611template <typename T>
612template <typename U>
615 const Var<U> &Other) {
616 static_assert(is_mutable_v<T>, "Cannot use += on Var<const T>");
617 auto Result = (*this) + Other;
618 *this = Result;
619 return *this;
620}
621
622template <typename T>
623template <typename U>
626 const U &ConstValue) {
627 static_assert(is_mutable_v<T>, "Cannot use += on Var<const T>");
628 static_assert(std::is_arithmetic_v<U>,
629 "Can only add arithmetic types to Var");
630 return compoundAssignConst(*this, ConstValue, ArithOp::Add);
631}
632
633template <typename T>
634template <typename U>
637 const Var<U> &Other) {
638 static_assert(is_mutable_v<T>, "Cannot use -= on Var<const T>");
639 auto Result = (*this) - Other;
640 *this = Result;
641 return *this;
642}
643
644template <typename T>
645template <typename U>
648 const U &ConstValue) {
649 static_assert(is_mutable_v<T>, "Cannot use -= on Var<const T>");
650 static_assert(std::is_arithmetic_v<U>,
651 "Can only subtract arithmetic types from Var");
652 return compoundAssignConst(*this, ConstValue, ArithOp::Sub);
653}
654
655template <typename T>
656template <typename U>
659 const Var<U> &Other) {
660 static_assert(is_mutable_v<T>, "Cannot use *= on Var<const T>");
661 auto Result = (*this) * Other;
662 *this = Result;
663 return *this;
664}
665
666template <typename T>
667template <typename U>
670 const U &ConstValue) {
671 static_assert(is_mutable_v<T>, "Cannot use *= on Var<const T>");
672 static_assert(std::is_arithmetic_v<U>,
673 "Can only multiply Var by arithmetic types");
674 return compoundAssignConst(*this, ConstValue, ArithOp::Mul);
675}
676
677template <typename T>
678template <typename U>
681 const Var<U> &Other) {
682 static_assert(is_mutable_v<T>, "Cannot use /= on Var<const T>");
683 auto Result = (*this) / Other;
684 *this = Result;
685 return *this;
686}
687
688template <typename T>
689template <typename U>
692 const U &ConstValue) {
693 static_assert(is_mutable_v<T>, "Cannot use /= on Var<const T>");
694 static_assert(std::is_arithmetic_v<U>,
695 "Can only divide Var by arithmetic types");
696 return compoundAssignConst(*this, ConstValue, ArithOp::Div);
697}
698
699template <typename T>
700template <typename U>
703 const Var<U> &Other) {
704 static_assert(is_mutable_v<T>, "Cannot use %= on Var<const T>");
705 auto Result = (*this) % Other;
706 *this = Result;
707 return *this;
708}
709
710template <typename T>
711template <typename U>
714 const U &ConstValue) {
715 static_assert(is_mutable_v<T>, "Cannot use %= on Var<const T>");
716 static_assert(std::is_arithmetic_v<U>,
717 "Can only modulo Var by arithmetic types");
718 return compoundAssignConst(*this, ConstValue, ArithOp::Rem);
719}
720
721template <typename T>
724 auto MinusOne = defVar<remove_cvref_t<T>>(
725 CB, static_cast<remove_cvref_t<T>>(-1), "minus_one.");
726 return MinusOne * (*this);
727}
728
729template <typename T>
732 IRValue *V = loadValue();
733 IRValue *ResV = nullptr;
734 if constexpr (std::is_same_v<remove_cvref_t<T>, bool>) {
735 ResV = CB.createNot(V);
736 } else if constexpr (std::is_integral_v<remove_cvref_t<T>>) {
737 IRValue *Zero = CB.getConstantInt(getValueType(), 0);
738 ResV = CB.createCmp(CmpOp::EQ, V, Zero, getValueType());
739 } else {
740 IRValue *Zero = CB.getConstantFP(getValueType(), 0.0);
741 ResV = CB.createCmp(CmpOp::EQ, V, Zero, getValueType());
742 }
743 auto Ret = declVar<bool>(CB, "not.");
744 Ret.storeValue(ResV);
745 return Ret;
746}
747
748template <typename T>
751 auto A = CB.getElementPtr(Slot, AllocTy, Index, ValueTy);
753}
754
755template <typename T>
756template <typename IdxT>
757std::enable_if_t<std::is_integral_v<IdxT>,
760 const Var<IdxT> &Index) {
761 auto A = CB.getElementPtr(Slot, AllocTy, Index.loadValue(), ValueTy);
763}
764
765template <typename T>
766Var<std::add_lvalue_reference_t<
767 std::remove_pointer_t<std::remove_reference_t<T>>>>
769 using ElemT = std::remove_pointer_t<std::remove_reference_t<T>>;
770 IRType ElemIRTy = TypeMap<ElemT>::get();
771 IRValue *Ptr = CB.loadAddress(Slot, AllocTy);
772 auto A = CB.getElementPtr(Ptr, AllocTy, Index, ElemIRTy);
773 return Var<std::add_lvalue_reference_t<
774 std::remove_pointer_t<std::remove_reference_t<T>>>>(A, CB);
775}
776
777template <typename T>
778template <typename IdxT>
779std::enable_if_t<std::is_arithmetic_v<IdxT>,
780 Var<std::add_lvalue_reference_t<
781 std::remove_pointer_t<std::remove_reference_t<T>>>>>
783 const Var<IdxT> &Index) {
784 using ElemT = std::remove_pointer_t<std::remove_reference_t<T>>;
785 IRType ElemIRTy = TypeMap<ElemT>::get();
786 IRValue *Ptr = CB.loadAddress(Slot, AllocTy);
787 auto A = CB.getElementPtr(Ptr, AllocTy, Index.loadValue(), ElemIRTy);
788 return Var<std::add_lvalue_reference_t<
789 std::remove_pointer_t<std::remove_reference_t<T>>>>(A, CB);
790}
791
792template <typename T>
793Var<std::add_lvalue_reference_t<
794 std::remove_pointer_t<std::remove_reference_t<T>>>>
796 return (*this)[0];
797}
798
799template <typename T>
802 IRValue *PtrVal = CB.loadAddress(Slot, AllocTy);
803 IRType PointeePtrIRTy{IRTypeKind::Pointer, ValueTy.Signed, 0, ValueTy.Kind};
804
805 auto A = CB.allocPointer("addr.ptr.tmp", PointeePtrIRTy, 0);
806 CB.storeAddress(A.Slot, PtrVal);
807 return Var<std::add_pointer_t<T>>(A, CB);
808}
809
810template <typename T>
811template <typename OffsetT>
812std::enable_if_t<std::is_arithmetic_v<OffsetT>,
815 const Var<OffsetT> &Offset) const {
816 IRValue *IdxVal = detail::convert<OffsetT, int64_t>(CB, Offset.loadValue());
817 IRValue *BasePtr = CB.loadAddress(Slot, AllocTy);
818 auto A = CB.getElementPtr(BasePtr, AllocTy, IdxVal, ValueTy);
820}
821
822template <typename T>
823template <typename OffsetT>
824std::enable_if_t<std::is_arithmetic_v<OffsetT>,
827 OffsetT Offset) const {
828 IRValue *IdxVal = CB.getConstantInt(IRType{IRTypeKind::Int64},
829 static_cast<uint64_t>(Offset));
830 IRValue *BasePtr = CB.loadAddress(Slot, AllocTy);
831 auto A = CB.getElementPtr(BasePtr, AllocTy, IdxVal, ValueTy);
833}
834
835// Comparison operators for Var
836template <typename T>
837template <typename U>
838std::enable_if_t<is_arithmetic_unref_v<U>, Var<bool>>
840 const Var<U> &Other) const {
841 return cmpOp(*this, Other, CmpOp::GT);
842}
843
844template <typename T>
845template <typename U>
846std::enable_if_t<is_arithmetic_unref_v<U>, Var<bool>>
848 const Var<U> &Other) const {
849 return cmpOp(*this, Other, CmpOp::GE);
850}
851
852template <typename T>
853template <typename U>
854std::enable_if_t<is_arithmetic_unref_v<U>, Var<bool>>
856 const Var<U> &Other) const {
857 return cmpOp(*this, Other, CmpOp::LT);
858}
859
860template <typename T>
861template <typename U>
862std::enable_if_t<is_arithmetic_unref_v<U>, Var<bool>>
863Var<T, std::enable_if_t<is_scalar_arithmetic_v<T>>>::operator<=(
864 const Var<U> &Other) const {
865 return cmpOp(*this, Other, CmpOp::LE);
866}
867
868template <typename T>
869template <typename U>
870std::enable_if_t<is_arithmetic_unref_v<U>, Var<bool>>
872 const Var<U> &Other) const {
873 return cmpOp(*this, Other, CmpOp::EQ);
874}
875
876template <typename T>
877template <typename U>
878std::enable_if_t<is_arithmetic_unref_v<U>, Var<bool>>
880 const Var<U> &Other) const {
881 return cmpOp(*this, Other, CmpOp::NE);
882}
883
884template <typename T>
885template <typename U>
886std::enable_if_t<is_arithmetic_unref_v<U>, Var<bool>>
888 const U &ConstValue) const {
889 Var<U> Tmp = defVar<U>(CB, ConstValue, "cmp.");
890 return (*this) > Tmp;
891}
892
893template <typename T>
894template <typename U>
895std::enable_if_t<is_arithmetic_unref_v<U>, Var<bool>>
897 const U &ConstValue) const {
898 Var<U> Tmp = defVar<U>(CB, ConstValue, "cmp.");
899 return (*this) >= Tmp;
900}
901
902template <typename T>
903template <typename U>
904std::enable_if_t<is_arithmetic_unref_v<U>, Var<bool>>
906 const U &ConstValue) const {
907 Var<U> Tmp = defVar<U>(CB, ConstValue, "cmp.");
908 return (*this) < Tmp;
909}
910
911template <typename T>
912template <typename U>
913std::enable_if_t<is_arithmetic_unref_v<U>, Var<bool>>
914Var<T, std::enable_if_t<is_scalar_arithmetic_v<T>>>::operator<=(
915 const U &ConstValue) const {
916 auto Tmp = defVar<U>(CB, ConstValue, "cmp.");
917 return (*this) <= Tmp;
918}
919
920template <typename T>
921template <typename U>
922std::enable_if_t<is_arithmetic_unref_v<U>, Var<bool>>
924 const U &ConstValue) const {
925 Var<U> Tmp = defVar<U>(CB, ConstValue, "cmp.");
926 return (*this) == Tmp;
927}
928
929template <typename T>
930template <typename U>
931std::enable_if_t<is_arithmetic_unref_v<U>, Var<bool>>
933 const U &ConstValue) const {
934 auto Tmp = defVar<U>(CB, ConstValue, "cmp.");
935 return (*this) != Tmp;
936}
937
938// Non-member arithmetic operators for Var
939template <typename T, typename U>
940std::enable_if_t<std::is_arithmetic_v<T> && std::is_arithmetic_v<U>,
942operator+(const T &ConstValue, const Var<U> &V) {
943 Var<T> Tmp = defVar<T>(V.CB, ConstValue, "tmp.");
944 return Tmp + V;
945}
946
947template <typename T, typename U>
948std::enable_if_t<std::is_arithmetic_v<T> && std::is_arithmetic_v<U>,
949 Var<std::common_type_t<T, U>>>
950operator-(const T &ConstValue, const Var<U> &V) {
951 using CommonType = std::common_type_t<T, U>;
952 Var<CommonType> Tmp = defVar<CommonType>(V.CB, ConstValue, "tmp.");
953 return Tmp - V;
954}
955
956template <typename T, typename U>
957std::enable_if_t<std::is_arithmetic_v<T> && std::is_arithmetic_v<U>,
958 Var<std::common_type_t<T, U>>>
959operator*(const T &ConstValue, const Var<U> &V) {
960 Var<T> Tmp = defVar<T>(V.CB, ConstValue, "tmp.");
961 return Tmp * V;
962}
963
964template <typename T, typename U>
965std::enable_if_t<std::is_arithmetic_v<T> && std::is_arithmetic_v<U>,
966 Var<std::common_type_t<T, U>>>
967operator/(const T &ConstValue, const Var<U> &V) {
968 Var<T> Tmp = defVar<T>(V.CB, ConstValue, "tmp.");
969 return Tmp / V;
970}
971
972template <typename T, typename U>
973std::enable_if_t<std::is_arithmetic_v<T> && std::is_arithmetic_v<U>,
974 Var<std::common_type_t<T, U>>>
975operator%(const T &ConstValue, const Var<U> &V) {
976 Var<T> Tmp = defVar<T>(V.CB, ConstValue, "tmp.");
977 return Tmp % V;
978}
979
980// ---------------------------------------------------------------------------
981// Intrinsic emission helpers
982// ---------------------------------------------------------------------------
983
984// Helper struct to convert Var operands to a target type T.
985template <typename T> struct IntrinsicOperandConverter {
987
988 template <typename U> IRValue *operator()(const Var<U> &Operand) const {
989 return detail::convert<U, T>(CB, Operand.loadValue());
990 }
991};
992
993template <typename T, typename... Operands>
994static Var<T> emitIntrinsic(const std::string &IntrinsicName,
995 const Operands &...Ops) {
996 static_assert(sizeof...(Ops) > 0, "Intrinsic requires at least one operand");
997
998 CodeBuilder &CB = std::get<0>(std::tie(Ops...)).CB;
999 auto CheckFn = [&CB](const auto &Operand) {
1000 if (&Operand.CB != &CB)
1001 reportFatalError("Variables should belong to the same function");
1002 };
1003 (CheckFn(Ops), ...);
1004
1005 IntrinsicOperandConverter<T> ConvertOperand{CB};
1006
1007 IRType ResultIRTy = TypeMap<T>::get();
1008 std::vector<IRType> ArgTys(sizeof...(Ops), ResultIRTy);
1009 IRValue *Call = CB.createCall(IntrinsicName, ResultIRTy, ArgTys,
1010 {ConvertOperand(Ops)...});
1011
1012 auto ResultVar = declVar<T>(CB, "res.");
1013 ResultVar.storeValue(Call);
1014 return ResultVar;
1015}
1016
1017// ---------------------------------------------------------------------------
1018// Math intrinsics for Var
1019// ---------------------------------------------------------------------------
1020
1021template <typename T> Var<float> powf(const Var<float> &L, const Var<T> &R) {
1022 static_assert(std::is_convertible_v<T, float>,
1023 "powf requires floating-point type");
1024
1025 auto RFloat = R.template convert<float>();
1026 std::string IntrinsicName = "llvm.pow.f32";
1027#if PROTEUS_ENABLE_CUDA
1028 if (L.CB.getTargetModel() == TargetModelType::CUDA)
1029 IntrinsicName = "__nv_powf";
1030#endif
1031
1032 return emitIntrinsic<float>(IntrinsicName, L, RFloat);
1033}
1034
1035template <typename T> Var<float> sqrtf(const Var<T> &R) {
1036 static_assert(std::is_convertible_v<T, float>,
1037 "sqrtf requires floating-point type");
1038
1039 auto RFloat = R.template convert<float>();
1040 std::string IntrinsicName = "llvm.sqrt.f32";
1041#if PROTEUS_ENABLE_CUDA
1042 if (R.CB.getTargetModel() == TargetModelType::CUDA)
1043 IntrinsicName = "__nv_sqrtf";
1044#endif
1045
1046 return emitIntrinsic<float>(IntrinsicName, RFloat);
1047}
1048
1049template <typename T> Var<float> expf(const Var<T> &R) {
1050 static_assert(std::is_convertible_v<T, float>,
1051 "expf requires floating-point type");
1052
1053 auto RFloat = R.template convert<float>();
1054 std::string IntrinsicName = "llvm.exp.f32";
1055#if PROTEUS_ENABLE_CUDA
1056 if (R.CB.getTargetModel() == TargetModelType::CUDA)
1057 IntrinsicName = "__nv_expf";
1058#endif
1059
1060 return emitIntrinsic<float>(IntrinsicName, RFloat);
1061}
1062
1063template <typename T> Var<float> sinf(const Var<T> &R) {
1064 static_assert(std::is_convertible_v<T, float>,
1065 "sinf requires floating-point type");
1066
1067 auto RFloat = R.template convert<float>();
1068 std::string IntrinsicName = "llvm.sin.f32";
1069#if PROTEUS_ENABLE_CUDA
1070 if (R.CB.getTargetModel() == TargetModelType::CUDA)
1071 IntrinsicName = "__nv_sinf";
1072#endif
1073
1074 return emitIntrinsic<float>(IntrinsicName, RFloat);
1075}
1076
1077template <typename T> Var<float> cosf(const Var<T> &R) {
1078 static_assert(std::is_convertible_v<T, float>,
1079 "cosf requires floating-point type");
1080
1081 auto RFloat = R.template convert<float>();
1082 std::string IntrinsicName = "llvm.cos.f32";
1083#if PROTEUS_ENABLE_CUDA
1084 if (R.CB.getTargetModel() == TargetModelType::CUDA)
1085 IntrinsicName = "__nv_cosf";
1086#endif
1087
1088 return emitIntrinsic<float>(IntrinsicName, RFloat);
1089}
1090
1091template <typename T> Var<float> fabs(const Var<T> &R) {
1092 static_assert(std::is_convertible_v<T, float>,
1093 "fabs requires floating-point type");
1094
1095 auto RFloat = R.template convert<float>();
1096 std::string IntrinsicName = "llvm.fabs.f32";
1097#if PROTEUS_ENABLE_CUDA
1098 if (R.CB.getTargetModel() == TargetModelType::CUDA)
1099 IntrinsicName = "__nv_fabsf";
1100#endif
1101
1102 return emitIntrinsic<float>(IntrinsicName, RFloat);
1103}
1104
1105template <typename T> Var<float> truncf(const Var<T> &R) {
1106 static_assert(std::is_convertible_v<T, float>,
1107 "truncf requires floating-point type");
1108
1109 auto RFloat = R.template convert<float>();
1110 std::string IntrinsicName = "llvm.trunc.f32";
1111#if PROTEUS_ENABLE_CUDA
1112 if (R.CB.getTargetModel() == TargetModelType::CUDA)
1113 IntrinsicName = "__nv_truncf";
1114#endif
1115
1116 return emitIntrinsic<float>(IntrinsicName, RFloat);
1117}
1118
1119template <typename T> Var<float> logf(const Var<T> &R) {
1120 static_assert(std::is_convertible_v<T, float>,
1121 "logf requires floating-point type");
1122
1123 auto RFloat = R.template convert<float>();
1124 std::string IntrinsicName = "llvm.log.f32";
1125#if PROTEUS_ENABLE_CUDA
1126 if (R.CB.getTargetModel() == TargetModelType::CUDA)
1127 IntrinsicName = "__nv_logf";
1128#endif
1129
1130 return emitIntrinsic<float>(IntrinsicName, RFloat);
1131}
1132
1133template <typename T> Var<float> absf(const Var<T> &R) {
1134 static_assert(std::is_convertible_v<T, float>,
1135 "absf requires floating-point type");
1136
1137 auto RFloat = R.template convert<float>();
1138 std::string IntrinsicName = "llvm.fabs.f32";
1139#if PROTEUS_ENABLE_CUDA
1140 if (R.CB.getTargetModel() == TargetModelType::CUDA)
1141 IntrinsicName = "__nv_fabsf";
1142#endif
1143
1144 return emitIntrinsic<float>(IntrinsicName, RFloat);
1145}
1146
1147template <typename T>
1148std::enable_if_t<is_arithmetic_unref_v<T>, Var<remove_cvref_t<T>>>
1149min(const Var<T> &L, const Var<T> &R) {
1150 CodeBuilder &CB = L.CB;
1151 if (&CB != &R.CB)
1152 reportFatalError("Variables should belong to the same function");
1153
1154 auto ResultVar = declVar<remove_cvref_t<T>>(CB, "min_res");
1155 ResultVar = R;
1156 auto CondVar = L < R;
1157 CB.beginIf(CondVar.loadValue(), __builtin_FILE(), __builtin_LINE());
1158 { ResultVar = L; }
1159 CB.endIf();
1160 return ResultVar;
1161}
1162
1163template <typename T>
1164std::enable_if_t<is_arithmetic_unref_v<T>, Var<remove_cvref_t<T>>>
1165max(const Var<T> &L, const Var<T> &R) {
1166 CodeBuilder &CB = L.CB;
1167 if (&CB != &R.CB)
1168 reportFatalError("Variables should belong to the same function");
1169
1170 auto ResultVar = declVar<remove_cvref_t<T>>(CB, "max_res");
1171 ResultVar = R;
1172 auto CondVar = L > R;
1173 CB.beginIf(CondVar.loadValue(), __builtin_FILE(), __builtin_LINE());
1174 { ResultVar = L; }
1175 CB.endIf();
1176 return ResultVar;
1177}
1178
1179} // namespace proteus
1180
1181#endif // PROTEUS_FRONTEND_VAR_H
Definition CodeBuilder.h:66
virtual VarAlloc allocPointer(const std::string &Name, IRType ElemTy, unsigned AddrSpace=0)=0
virtual IRValue * loadFromPointee(IRValue *Slot, IRType AllocTy, IRType ValueTy)=0
Dereference the pointer stored in Slot, then load the pointee.
virtual IRValue * createArith(ArithOp Op, IRValue *LHS, IRValue *RHS, IRType Ty)=0
virtual void storeAddress(IRValue *Slot, IRValue *Addr)=0
Store Addr into Slot (pointer alloca).
virtual VarAlloc allocScalar(const std::string &Name, IRType ValueTy)=0
virtual void beginIf(IRValue *Cond, const char *File, int Line)=0
virtual void storeScalar(IRValue *Slot, IRValue *Val)=0
Store Val directly into Slot (scalar alloca).
virtual IRValue * createCmp(CmpOp Op, IRValue *LHS, IRValue *RHS, IRType Ty)=0
virtual IRValue * loadAddress(IRValue *Slot, IRType AllocTy)=0
Load the pointer stored in Slot (pointer alloca).
virtual IRValue * loadScalar(IRValue *Slot, IRType ValueTy)=0
Load the value stored directly in Slot (scalar alloca).
virtual IRValue * createCast(IRValue *V, IRType FromTy, IRType ToTy)=0
virtual void storeToPointee(IRValue *Slot, IRType AllocTy, IRValue *Val)=0
Dereference the pointer stored in Slot, then store Val to it.
virtual void endIf()=0
Definition IRValue.h:15
IRValue * convert(CodeBuilder &CB, IRValue *V)
Definition Var.h:320
Definition MemoryCache.h:26
std::remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition TypeTraits.h:11
std::enable_if_t< std::is_arithmetic_v< T > &&std::is_arithmetic_v< U >, Var< std::common_type_t< T, U > > > operator-(const T &ConstValue, const Var< U > &V)
Definition Var.h:950
std::enable_if_t< is_arithmetic_unref_v< T >, Var< remove_cvref_t< T > > > min(const Var< T > &L, const Var< T > &R)
Definition Var.h:1149
Var< T, std::enable_if_t< is_scalar_arithmetic_v< T > > > & compoundAssignConst(Var< T, std::enable_if_t< is_scalar_arithmetic_v< T > > > &LHS, const U &ConstValue, ArithOp Op)
Definition Var.h:383
Var< float > sqrtf(const Var< T > &R)
Definition Var.h:1035
ArithOp
Semantic arithmetic operation selector.
Definition CodeBuilder.h:43
std::enable_if_t< std::is_arithmetic_v< T > &&std::is_arithmetic_v< U >, Var< std::common_type_t< T, U > > > operator+(const T &ConstValue, const Var< U > &Var)
Definition Var.h:942
Var< float > fabs(const Var< T > &R)
Definition Var.h:1091
bool isFloatingPointKind(const IRType &T)
Returns true when T is a floating-point kind (Float or Double).
Definition IRType.h:57
std::enable_if_t< std::is_arithmetic_v< T > &&std::is_arithmetic_v< U >, Var< std::common_type_t< T, U > > > operator*(const T &ConstValue, const Var< U > &V)
Definition Var.h:959
Var< bool > cmpOp(const Var< T > &L, const Var< U > &R, CmpOp Op)
Definition Var.h:408
Var< std::common_type_t< remove_cvref_t< T >, remove_cvref_t< U > > > binOp(const Var< T > &L, const Var< U > &R, ArithOp Op)
Definition Var.h:363
CmpOp
Semantic comparison operation selector.
Definition CodeBuilder.h:46
Var< float > sinf(const Var< T > &R)
Definition Var.h:1063
Var< T > declVar(CodeBuilder &CB, const std::string &Name="var")
Definition Var.h:334
Var< float > logf(const Var< T > &R)
Definition Var.h:1119
void reportFatalError(const llvm::Twine &Reason, const char *FILE, unsigned Line)
Definition Error.cpp:14
static int int Offset
Definition JitInterface.h:102
Var< float > cosf(const Var< T > &R)
Definition Var.h:1077
std::enable_if_t< std::is_arithmetic_v< T > &&std::is_arithmetic_v< U >, Var< std::common_type_t< T, U > > > operator/(const T &ConstValue, const Var< U > &V)
Definition Var.h:967
Var< float > expf(const Var< T > &R)
Definition Var.h:1049
Var< float > powf(const Var< float > &L, const Var< T > &R)
Definition Var.h:1021
Var< float > truncf(const Var< T > &R)
Definition Var.h:1105
Var< T > defVar(CodeBuilder &CB, const T &Val, const std::string &Name="var")
Definition Var.h:350
Var< float > absf(const Var< T > &R)
Definition Var.h:1133
std::enable_if_t< is_arithmetic_unref_v< T >, Var< remove_cvref_t< T > > > max(const Var< T > &L, const Var< T > &R)
Definition Var.h:1165
bool isIntegerKind(const IRType &T)
Returns true when T is an integer kind (Int1, Int16, Int32, or Int64).
Definition IRType.h:51
std::enable_if_t< std::is_arithmetic_v< T > &&std::is_arithmetic_v< U >, Var< std::common_type_t< T, U > > > operator%(const T &ConstValue, const Var< U > &V)
Definition Var.h:975
Definition Hashing.h:158
Definition IRType.h:34
bool Signed
Signedness of the type (meaningful for integer kinds and pointer-to-int).
Definition IRType.h:38
CodeBuilder & CB
Definition Var.h:986
IRValue * operator()(const Var< U > &Operand) const
Definition Var.h:988
Definition TypeMap.h:15
Definition CodeBuilder.h:29
IRType ValueTy
Pointee (element) type.
Definition Var.h:232
void storeAddress(IRValue *Ptr)
Definition Var.h:249
Var(VarAlloc A, CodeBuilder &CBIn)
Definition Var.h:239
friend std::enable_if_t< std::is_arithmetic_v< OffsetT >, Var< T, std::enable_if_t< is_pointer_unref_v< T > > > > operator+(OffsetT Offset, const Var &Ptr)
Definition Var.h:281
std::remove_pointer_t< std::remove_reference_t< T > > ElemType
Definition Var.h:237
std::enable_if_t< std::is_arithmetic_v< IdxT >, Var< std::add_lvalue_reference_t< ElemType > > > operator[](const Var< IdxT > &Index)
IRType AllocTy
Type of the pointer alloca.
Definition Var.h:233
Var(VarAlloc A, CodeBuilder &CBIn)
Definition Var.h:31
Var(VarAlloc A, CodeBuilder &CBIn)
Definition Var.h:201
Var< std::add_pointer_t< ValueType > > getAddress() const =delete
std::enable_if_t< std::is_integral_v< IdxT >, Var< std::add_lvalue_reference_t< ElemType > > > operator[](const Var< IdxT > &Index)
std::remove_extent_t< T > ElemType
Definition Var.h:199
IRType ValueTy
Element type.
Definition Var.h:194
IRValue * loadValue() const
Definition Var.h:210
Definition Var.h:16