78 static_assert(
sizeof...(LoopBuilders) > 0,
79 "LoopNestBuilder requires at least one loop builder");
81 "LoopNestBuilder requires ForLoopBuilder inputs");
82 static_assert((std::is_same_v<T, typename LoopBuilders::LoopIndexType> &&
84 "All loops in a loop nest must use the same loop index type");
86 std::tuple<LoopBuilders...> Loops;
87 std::array<std::unique_ptr<LoopBoundInfo<T>>,
88 std::tuple_size_v<
decltype(Loops)>>
92 template <std::size_t... Is>
93 void setupTiledLoops(std::index_sequence<Is...>) {
98 auto &Loop = std::get<Is>(Loops);
99 if (Loop.TileSize.has_value()) {
100 auto &Bounds = std::get<Is>(Loops).Bounds;
102 auto TileIter = Fn.declVar<T>(
"tile_iter_" + std::to_string(Is));
103 auto TileStep = Fn.declVar<T>(
"tile_step_" + std::to_string(Is));
105 TileStep = Loop.TileSize.value();
106 TiledLoopBounds[Is] = std::make_unique<LoopBoundInfo<T>>(
107 TileIter, Bounds.Init, Bounds.UpperBound, TileStep);
113 template <std::size_t... Is>
114 void beginTiledLoops(std::index_sequence<Is...>) {
118 auto &Loop = std::get<Is>(Loops);
119 if (Loop.TileSize.has_value()) {
120 auto &Bounds = *TiledLoopBounds[Is];
121 Fn.beginFor(Bounds.IterVar, Bounds.Init, Bounds.UpperBound,
128 template <std::size_t... Is>
void emitInnerLoops(std::index_sequence<Is...>) {
132 auto &Loop = std::get<Is>(Loops);
133 if (Loop.TileSize.has_value()) {
134 auto &TiledBounds = *TiledLoopBounds[Is];
135 auto EndCandidate = TiledBounds.IterVar + TiledBounds.Inc;
137 EndCandidate =
min(EndCandidate, TiledBounds.UpperBound);
138 Fn.beginFor(Loop.Bounds.IterVar, TiledBounds.IterVar, EndCandidate,
141 Fn.beginFor(Loop.Bounds.IterVar, Loop.Bounds.Init,
142 Loop.Bounds.UpperBound, Loop.Bounds.Inc);
150 (void)std::get<Is>(Loops);
156 template <std::size_t... Is>
void endTiledLoops(std::index_sequence<Is...>) {
160 auto &Loop = std::get<
sizeof...(Is) - 1U - Is>(Loops);
161 if (Loop.TileSize.has_value()) {
168 template <std::size_t... Is>
169 void tileImpl(
int Tile, std::index_sequence<Is...>) {
170 (std::get<Is>(Loops).tile(Tile), ...);
173 template <std::size_t... Is>
174 void validateNoUnroll(std::index_sequence<Is...>)
const {
175 bool AnyUnrolled = (std::get<Is>(Loops).UnrollEnabled || ...);
178 "Cannot tile a loop nest containing loops marked for unrolling");
183 : Loops(
std::move(Loops)...), Fn(Fn) {}
186 auto IdxSeq = std::index_sequence_for<LoopBuilders...>{};
187 validateNoUnroll(IdxSeq);
188 tileImpl(Tile, IdxSeq);
192 auto IdxSeq = std::index_sequence_for<LoopBuilders...>{};
193 setupTiledLoops(IdxSeq);
194 beginTiledLoops(IdxSeq);
195 emitInnerLoops(IdxSeq);
196 endTiledLoops(IdxSeq);
void beginFor(Var< IterT > &IterVar, const Var< InitT > &InitVar, const Var< UpperT > &UpperBound, const Var< IncT > &IncVar, const char *File=__builtin_FILE(), int Line=__builtin_LINE(), LoopHints Hints={})
Definition Func.h:345