Proteus
Programmable JIT compilation and optimization for C/C++ using LLVM
Loading...
Searching...
No Matches
LoopNest.hpp
Go to the documentation of this file.
1#ifndef PROTEUS_FRONTEND_LOOP_NEST_HPP
2#define PROTEUS_FRONTEND_LOOP_NEST_HPP
3
4#include <memory>
5#include <optional>
6
9
10namespace proteus {
11
12template <typename T> class LoopBoundInfo {
13public:
18
22};
23
24template <typename T, typename BodyLambda> class ForLoopBuilder {
25public:
28 std::optional<int> TileSize;
31
36 TileSize = Tile;
37 return *this;
38 }
39
40 void emit() {
41 Fn.beginFor(Bounds.IterVar, Bounds.Init, Bounds.UpperBound, Bounds.Inc);
42 Body();
43 Fn.endFor();
44 }
45};
46
47template <typename T, typename... LoopBuilders> class LoopNestBuilder {
48private:
49 std::tuple<LoopBuilders...> Loops;
50 std::array<std::unique_ptr<LoopBoundInfo<T>>,
51 std::tuple_size_v<decltype(Loops)>>
52 TiledLoopBounds;
53 FuncBase &Fn;
54
55 template <std::size_t... Is>
56 void setupTiledLoops(std::index_sequence<Is...>) {
57 (
58 // Declare the tile iter and step variables for each tiled loop,
59 // storing them in the TiledLoopBounds array.
60 [&]() {
61 auto &Loop = std::get<Is>(Loops);
62 if (Loop.TileSize.has_value()) {
63 auto &Bounds = std::get<Is>(Loops).Bounds;
64
65 auto TileIter = Fn.declVar<T>("tile_iter_" + std::to_string(Is));
66 auto TileStep = Fn.declVar<T>("tile_step_" + std::to_string(Is));
67
68 TileStep = Loop.TileSize.value();
69 TiledLoopBounds[Is] = std::make_unique<LoopBoundInfo<T>>(
70 TileIter, Bounds.Init, Bounds.UpperBound, TileStep);
71 }
72 }(),
73 ...);
74 }
75
76 template <std::size_t... Is>
77 void beginTiledLoops(std::index_sequence<Is...>) {
78 (
79 // Begin the tiled loops, using the computed tile bounds.
80 [&]() {
81 auto &Loop = std::get<Is>(Loops);
82 if (Loop.TileSize.has_value()) {
83 auto &Bounds = *TiledLoopBounds[Is];
84 Fn.beginFor(Bounds.IterVar, Bounds.Init, Bounds.UpperBound,
85 Bounds.Inc);
86 }
87 }(),
88 ...);
89 }
90
91 template <std::size_t... Is> void emitInnerLoops(std::index_sequence<Is...>) {
92 (
93 // Emit the inner loops, using this tile's iter var as init.
94 [&]() {
95 auto &Loop = std::get<Is>(Loops);
96 if (Loop.TileSize.has_value()) {
97 auto &TiledBounds = *TiledLoopBounds[Is];
98 auto EndCandidate = TiledBounds.IterVar + TiledBounds.Inc;
99 // Clamp to handle partial tiles.
101 Fn.beginFor(Loop.Bounds.IterVar, TiledBounds.IterVar, EndCandidate,
102 Loop.Bounds.Inc);
103 } else {
104 Fn.beginFor(Loop.Bounds.IterVar, Loop.Bounds.Init,
105 Loop.Bounds.UpperBound, Loop.Bounds.Inc);
106 }
107 Loop.Body();
108 }(),
109 ...);
110 (
111 [&]() {
112 // Force unpacking so we emit enough endFors.
113 (void)std::get<Is>(Loops);
114 Fn.endFor();
115 }(),
116 ...);
117 }
118
119 template <std::size_t... Is> void endTiledLoops(std::index_sequence<Is...>) {
120 (
121 [&]() {
122 // Close tiled loops in reverse order to properly handle nesting.
123 auto &Loop = std::get<sizeof...(Is) - 1U - Is>(Loops);
124 if (Loop.TileSize.has_value()) {
125 Fn.endFor();
126 }
127 }(),
128 ...);
129 }
130
131 template <std::size_t... Is>
132 void tileImpl(int Tile, std::index_sequence<Is...>) {
133 (std::get<Is>(Loops).tile(Tile), ...);
134 }
135
136public:
138 : Loops(std::move(Loops)...), Fn(Fn) {}
139
141 auto IdxSeq = std::index_sequence_for<LoopBuilders...>{};
142 tileImpl(Tile, IdxSeq);
143 return *this;
144 }
145 void emit() {
146 auto IdxSeq = std::index_sequence_for<LoopBuilders...>{};
147 setupTiledLoops(IdxSeq);
148 beginTiledLoops(IdxSeq);
149 emitInnerLoops(IdxSeq);
150 endTiledLoops(IdxSeq);
151 }
152};
153} // namespace proteus
154
155#endif
Definition LoopNest.hpp:24
ForLoopBuilder(const LoopBoundInfo< T > &Bounds, FuncBase &Fn, BodyLambda &&Body)
Definition LoopNest.hpp:32
FuncBase & Fn
Definition LoopNest.hpp:30
ForLoopBuilder & tile(int Tile)
Definition LoopNest.hpp:35
void emit()
Definition LoopNest.hpp:40
BodyLambda Body
Definition LoopNest.hpp:29
std::optional< int > TileSize
Definition LoopNest.hpp:28
LoopBoundInfo< T > Bounds
Definition LoopNest.hpp:26
T LoopIndexType
Definition LoopNest.hpp:27
Definition Func.hpp:40
void endFor()
Definition Func.cpp:164
void beginFor(Var< T > &IterVar, const Var< T > &InitVar, const Var< T > &UpperBound, const Var< T > &IncVar, const char *File=__builtin_FILE(), int Line=__builtin_LINE())
Definition Func.hpp:309
Definition LoopNest.hpp:12
Var< T > UpperBound
Definition LoopNest.hpp:16
Var< T > Init
Definition LoopNest.hpp:15
Var< T > Inc
Definition LoopNest.hpp:17
Var< T > IterVar
Definition LoopNest.hpp:14
LoopBoundInfo(const Var< T > &IterVar, const Var< T > &Init, const Var< T > &UpperBound, const Var< T > &Inc)
Definition LoopNest.hpp:19
Definition LoopNest.hpp:47
void emit()
Definition LoopNest.hpp:145
LoopNestBuilder(FuncBase &Fn, LoopBuilders... Loops)
Definition LoopNest.hpp:137
LoopNestBuilder & tile(int Tile)
Definition LoopNest.hpp:140
Definition StorageCache.cpp:24
T getRuntimeConstantValue(void *Arg)
Definition CompilerInterfaceRuntimeConstantInfo.h:114
std::enable_if_t< std::is_arithmetic_v< T >, Var< T > > min(const Var< T > &L, const Var< T > &R)
Definition Func.hpp:1478
Definition Hashing.hpp:147
Definition Var.hpp:94