Proteus
Programmable JIT compilation and optimization for C/C++ using LLVM
Loading...
Searching...
No Matches
LoopNest.h
Go to the documentation of this file.
1#ifndef PROTEUS_FRONTEND_LOOP_NEST_H
2#define PROTEUS_FRONTEND_LOOP_NEST_H
3
4#include "proteus/Error.h"
7
8#include <memory>
9#include <optional>
10namespace proteus {
11
/// Per-loop bundle of the values that the loop builders hand to
/// FuncBase::beginFor: the iteration variable, its initial value, the upper
/// bound and the increment (see the Bounds.IterVar / Bounds.Init /
/// Bounds.UpperBound / Bounds.Inc uses further down in this listing).
//
// NOTE(review): listing lines 14-17 and 19-21 are absent from this rendering,
// so the class body appears empty here. The page's cross-reference index
// lists four Var<T> members (IterVar, Init, UpperBound, Inc) and a
// constructor taking them as const references in that order -- confirm
// against the real header.
12template <typename T> class LoopBoundInfo {
13public:
18
22};
23
/// Chainable builder for a single for loop.
///
/// Stores the loop bounds, the callable that emits the loop body, and the
/// optional tile / unroll settings. emit() generates the loop through
/// Fn.beginFor(...), Body(), Fn.endFor().
///
/// \tparam T          Loop index type (exposed as LoopIndexType).
/// \tparam BodyLambda Type of the nullary callable invoked to emit the body.
//
// The class is marked [[nodiscard]], so a builder returned by value and then
// dropped draws a compiler warning -- presumably to catch a forgotten emit().
//
// NOTE(review): listing lines 26, 32, 34, 40, 46, 48 and 56 are absent from
// this rendering. According to the page's cross-reference index they hold:
// `LoopBoundInfo<T> Bounds` (26), `FuncBase &Fn` (32), the first line of the
// constructor `ForLoopBuilder(const LoopBoundInfo<T> &Bounds, FuncBase &Fn,`
// (34) and the signature `ForLoopBuilder &unroll()` (46). Lines 40, 48 and 56
// are presumably the calls that receive the error strings below (the index
// references reportFatalError) -- confirm against the real header.
24template <typename T, typename BodyLambda> class [[nodiscard]] ForLoopBuilder {
25public:
// Index type of this loop; LoopNestBuilder static_asserts that all loops in a
// nest agree on it.
27 using LoopIndexType = T;
// Tile size requested via tile(); empty when the loop is not tiled. Read by
// LoopNestBuilder; emit() in this class does not consult it.
28 std::optional<int> TileSize;
// Set by either unroll() overload; forwarded into LoopHints by emit().
29 bool UnrollEnabled = false;
// Unroll count, set only by unroll(int); stays empty for plain unroll().
30 std::optional<int> UnrollCount;
// Callable invoked between beginFor and endFor to emit the loop body.
31 BodyLambda Body;
33
// Constructor tail: copies Bounds, moves the body callable in, and keeps a
// reference to Fn.
35 BodyLambda &&Body)
36 : Bounds(Bounds), Body(std::move(Body)), Fn(Fn) {}
37
/// Requests tiling of this loop with tile size \p Tile.
/// Tiling and unrolling are mutually exclusive: if unrolling was already
/// requested, the error message below is reported.
/// \returns *this, to allow call chaining.
// NOTE(review): Tile is stored as given; a zero or negative value is not
// rejected here.
38 ForLoopBuilder &tile(int Tile) {
39 if (UnrollEnabled)
41 "Cannot tile a loop that is already marked for unrolling");
42 TileSize = Tile;
43 return *this;
44 }
45
// Body of the zero-argument unroll() overload (its signature line is missing
// from this rendering). Enables unrolling without fixing a count; reports the
// error below if the loop was already marked for tiling. Returns *this.
47 if (TileSize.has_value())
49 "Cannot unroll a loop that is already marked for tiling");
50 UnrollEnabled = true;
51 return *this;
52 }
53
/// Requests unrolling of this loop with an explicit unroll count \p Count.
/// Reports the error below if the loop was already marked for tiling.
/// \returns *this, to allow call chaining.
// NOTE(review): Count is stored as given, without range checking.
54 ForLoopBuilder &unroll(int Count) {
55 if (TileSize.has_value())
57 "Cannot unroll a loop that is already marked for tiling");
58 UnrollEnabled = true;
59 UnrollCount = Count;
60 return *this;
61 }
62
/// Emits this loop on its own: opens it with Fn.beginFor, invokes Body(),
/// and closes it with Fn.endFor().
// Unroll settings are translated into LoopHints; TileSize is not read here.
// NOTE(review): __builtin_FILE() / __builtin_LINE() are evaluated at this
// call expression, so the location handed to beginFor should be this header
// rather than the user's call site -- confirm that this is intended.
63 void emit() {
64 LoopHints Hints;
65 if (UnrollEnabled) {
66 Hints.Unroll = true;
67 Hints.UnrollCount = UnrollCount;
68 }
69 Fn.beginFor(Bounds.IterVar, Bounds.Init, Bounds.UpperBound, Bounds.Inc,
70 __builtin_FILE(), __builtin_LINE(), Hints);
71 Body();
72 Fn.endFor();
73 }
74};
75
/// Combines several ForLoopBuilder objects into one loop nest and emits them
/// together, optionally tiled.
///
/// \tparam T            Loop index type shared by every loop in the nest.
/// \tparam LoopBuilders ForLoopBuilder instantiations, in the order given to
///                      the constructor.
//
// NOTE(review): listing lines 177 and 185 are absent from this rendering.
// Per the page's cross-reference index, line 185 is the signature
// `LoopNestBuilder &tile(int Tile)`; line 177 is presumably the call that
// receives the error string in validateNoUnroll -- confirm against the real
// header.
76template <typename T, typename... LoopBuilders> class LoopNestBuilder {
77private:
// Compile-time checks: at least one loop, every element is a ForLoopBuilder,
// and all loops use T as their index type.
78 static_assert(sizeof...(LoopBuilders) > 0,
79 "LoopNestBuilder requires at least one loop builder");
80 static_assert((IsForLoopBuilder<LoopBuilders>::value && ...),
81 "LoopNestBuilder requires ForLoopBuilder inputs");
82 static_assert((std::is_same_v<T, typename LoopBuilders::LoopIndexType> &&
83 ...),
84 "All loops in a loop nest must use the same loop index type");
85
// The member loops, stored by value.
86 std::tuple<LoopBuilders...> Loops;
// One slot per loop; a slot is filled by setupTiledLoops only when that loop
// has a TileSize, and is dereferenced only under the same condition.
87 std::array<std::unique_ptr<LoopBoundInfo<T>>,
88 std::tuple_size_v<decltype(Loops)>>
89 TiledLoopBounds;
// Function object through which variables are declared and loops are emitted.
90 FuncBase &Fn;
91
// Phase 1 of emit(). For every loop that has a TileSize, declares the
// variables "tile_iter_<I>" and "tile_step_<I>", assigns the tile size to the
// step variable, and records bounds for the tile loop (tile iterator, the
// original Init and UpperBound, tile step as increment) in TiledLoopBounds[I].
// The comma fold runs the lambda once per index, in index order.
92 template <std::size_t... Is>
93 void setupTiledLoops(std::index_sequence<Is...>) {
94 (
95 // Declare the tile iter and step variables for each tiled loop,
96 // storing them in the TiledLoopBounds array.
97 [&]() {
98 auto &Loop = std::get<Is>(Loops);
99 if (Loop.TileSize.has_value()) {
100 auto &Bounds = std::get<Is>(Loops).Bounds;
101
102 auto TileIter = Fn.declVar<T>("tile_iter_" + std::to_string(Is));
103 auto TileStep = Fn.declVar<T>("tile_step_" + std::to_string(Is));
104
105 TileStep = Loop.TileSize.value();
106 TiledLoopBounds[Is] = std::make_unique<LoopBoundInfo<T>>(
107 TileIter, Bounds.Init, Bounds.UpperBound, TileStep);
108 }
109 }(),
110 ...);
111 }
112
// Phase 2 of emit(). Opens one tile loop per tiled member loop, in index
// order, using the bounds recorded by setupTiledLoops. Untiled loops are
// skipped. No LoopHints argument is passed, so beginFor's default applies.
113 template <std::size_t... Is>
114 void beginTiledLoops(std::index_sequence<Is...>) {
115 (
116 // Begin the tiled loops, using the computed tile bounds.
117 [&]() {
118 auto &Loop = std::get<Is>(Loops);
119 if (Loop.TileSize.has_value()) {
120 auto &Bounds = *TiledLoopBounds[Is];
121 Fn.beginFor(Bounds.IterVar, Bounds.Init, Bounds.UpperBound,
122 Bounds.Inc);
123 }
124 }(),
125 ...);
126 }
127
// Phase 3 of emit(). Opens every member loop in index order and invokes its
// Body() right after opening it, i.e. before the next (inner) loop is opened.
// A tiled loop runs from the current tile iterator to
// min(tile iterator + tile step, original upper bound) with the loop's own
// increment; an untiled loop uses its original bounds. A second fold then
// emits one Fn.endFor() per member loop.
// NOTE(review): neither beginFor call below passes a LoopHints argument,
// whereas ForLoopBuilder::emit does. Unroll settings on member loops
// therefore do not appear to reach beginFor on this path -- confirm whether
// that is intended.
128 template <std::size_t... Is> void emitInnerLoops(std::index_sequence<Is...>) {
129 (
130 // Emit the inner loops, using this tile's iter var as init.
131 [&]() {
132 auto &Loop = std::get<Is>(Loops);
133 if (Loop.TileSize.has_value()) {
134 auto &TiledBounds = *TiledLoopBounds[Is];
135 auto EndCandidate = TiledBounds.IterVar + TiledBounds.Inc;
136 // Clamp to handle partial tiles.
137 EndCandidate = min(EndCandidate, TiledBounds.UpperBound);
138 Fn.beginFor(Loop.Bounds.IterVar, TiledBounds.IterVar, EndCandidate,
139 Loop.Bounds.Inc);
140 } else {
141 Fn.beginFor(Loop.Bounds.IterVar, Loop.Bounds.Init,
142 Loop.Bounds.UpperBound, Loop.Bounds.Inc);
143 }
144 Loop.Body();
145 }(),
146 ...);
147 (
148 [&]() {
149 // Force unpacking so we emit enough endFors.
150 (void)std::get<Is>(Loops);
151 Fn.endFor();
152 }(),
153 ...);
154 }
155
// Phase 4 of emit(). Emits one Fn.endFor() for every tiled member loop,
// matching the loops opened by beginTiledLoops. The loops are inspected from
// last to first.
156 template <std::size_t... Is> void endTiledLoops(std::index_sequence<Is...>) {
157 (
158 [&]() {
159 // Close tiled loops in reverse order to properly handle nesting.
160 auto &Loop = std::get<sizeof...(Is) - 1U - Is>(Loops);
161 if (Loop.TileSize.has_value()) {
162 Fn.endFor();
163 }
164 }(),
165 ...);
166 }
167
// Calls tile(Tile) on every member loop, so all loops get the same tile size.
168 template <std::size_t... Is>
169 void tileImpl(int Tile, std::index_sequence<Is...>) {
170 (std::get<Is>(Loops).tile(Tile), ...);
171 }
172
// Checks whether any member loop has UnrollEnabled set and, if so, reports
// the error message below (the line that receives it is missing from this
// rendering).
173 template <std::size_t... Is>
174 void validateNoUnroll(std::index_sequence<Is...>) const {
175 bool AnyUnrolled = (std::get<Is>(Loops).UnrollEnabled || ...);
176 if (AnyUnrolled)
178 "Cannot tile a loop nest containing loops marked for unrolling");
179 }
180
181public:
/// Builds a nest from the given loop builders.
/// \param Fn    Function object used for variable declaration and emission;
///              held by reference.
/// \param Loops Loop builders, taken by value and moved into the nest.
182 LoopNestBuilder(FuncBase &Fn, LoopBuilders... Loops)
183 : Loops(std::move(Loops)...), Fn(Fn) {}
184
// Body of tile(int Tile) (its signature line is missing from this rendering).
// First verifies that no member loop is marked for unrolling, then applies
// the same tile size to every member loop. Returns *this for chaining.
186 auto IdxSeq = std::index_sequence_for<LoopBuilders...>{};
187 validateNoUnroll(IdxSeq);
188 tileImpl(Tile, IdxSeq);
189 return *this;
190 }
/// Emits the whole nest in four phases: set up tile variables, open the tile
/// loops, emit the member loops with their bodies, close the tile loops.
191 void emit() {
192 auto IdxSeq = std::index_sequence_for<LoopBuilders...>{};
193 setupTiledLoops(IdxSeq);
194 beginTiledLoops(IdxSeq);
195 emitInnerLoops(IdxSeq);
196 endTiledLoops(IdxSeq);
197 }
198};
199} // namespace proteus
200
201#endif
Definition LoopNest.h:24
ForLoopBuilder & unroll(int Count)
Definition LoopNest.h:54
ForLoopBuilder(const LoopBoundInfo< T > &Bounds, FuncBase &Fn, BodyLambda &&Body)
Definition LoopNest.h:34
FuncBase & Fn
Definition LoopNest.h:32
ForLoopBuilder & tile(int Tile)
Definition LoopNest.h:38
void emit()
Definition LoopNest.h:63
BodyLambda Body
Definition LoopNest.h:31
std::optional< int > UnrollCount
Definition LoopNest.h:30
ForLoopBuilder & unroll()
Definition LoopNest.h:46
std::optional< int > TileSize
Definition LoopNest.h:28
LoopBoundInfo< T > Bounds
Definition LoopNest.h:26
T LoopIndexType
Definition LoopNest.h:27
Definition Func.h:45
void beginFor(Var< IterT > &IterVar, const Var< InitT > &InitVar, const Var< UpperT > &UpperBound, const Var< IncT > &IncVar, const char *File=__builtin_FILE(), int Line=__builtin_LINE(), LoopHints Hints={})
Definition Func.h:345
void endFor()
Definition Func.cpp:48
Definition LoopNest.h:12
Var< T > UpperBound
Definition LoopNest.h:16
Var< T > Init
Definition LoopNest.h:15
Var< T > Inc
Definition LoopNest.h:17
Var< T > IterVar
Definition LoopNest.h:14
LoopBoundInfo(const Var< T > &IterVar, const Var< T > &Init, const Var< T > &UpperBound, const Var< T > &Inc)
Definition LoopNest.h:19
Definition LoopNest.h:76
void emit()
Definition LoopNest.h:191
LoopNestBuilder(FuncBase &Fn, LoopBuilders... Loops)
Definition LoopNest.h:182
LoopNestBuilder & tile(int Tile)
Definition LoopNest.h:185
Definition MemoryCache.h:26
std::enable_if_t< is_arithmetic_unref_v< T >, Var< remove_cvref_t< T > > > min(const Var< T > &L, const Var< T > &R)
Definition Var.h:1149
void reportFatalError(const llvm::Twine &Reason, const char *FILE, unsigned Line)
Definition Error.cpp:14
Definition Hashing.h:158
Definition Func.h:41
Definition CodeBuilder.h:21
bool Unroll
Definition CodeBuilder.h:22
std::optional< int > UnrollCount
Definition CodeBuilder.h:23
Definition Var.h:16