Skip to content

Commit ec3f4c9

Browse files
authored
Add statement splitter utility (#5158)
Signed-off-by: Vladimír Štill <[email protected]>
1 parent 126f849 commit ec3f4c9

File tree

5 files changed

+822
-0
lines changed

5 files changed

+822
-0
lines changed

ir/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ set (IR_SRCS
3030
node.cpp
3131
pass_manager.cpp
3232
pass_utils.cpp
33+
splitter.cpp
3334
type.cpp
3435
visitor.cpp
3536
write_context.cpp
@@ -56,6 +57,7 @@ set (IR_HDRS
5657
nodemap.h
5758
pass_manager.h
5859
pass_utils.h
60+
splitter.h
5961
vector.h
6062
visitor.h
6163
)

ir/splitter.cpp

Lines changed: 280 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,280 @@
1+
/*
2+
Copyright 2025-present Altera Corporation.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
*/
16+
17+
#include "splitter.h"
18+
19+
#include <utility>
20+
#include <vector>
21+
22+
#include "frontends/common/resolveReferences/referenceMap.h"
23+
#include "frontends/common/resolveReferences/resolveReferences.h"
24+
#include "frontends/p4/typeMap.h"
25+
#include "ir/ir-traversal.h"
26+
#include "ir/visitor.h"
27+
28+
namespace P4 {
29+
30+
struct StatementSplitter : Inspector, ResolutionContext {
31+
StatementSplitter(
32+
std::function<bool(const IR::Statement *, const Visitor::Context *)> predicate,
33+
P4::NameGenerator &nameGen, P4::TypeMap *typeMap,
34+
absl::flat_hash_set<P4::cstring, Util::Hash> &neededDecls)
35+
: predicate(predicate), nameGen(nameGen), typeMap(typeMap), neededDecls(neededDecls) {}
36+
37+
bool preorder(const IR::LoopStatement *) override {
38+
BUG("Loops not supported in statement splitter, must be unrolled before");
39+
}
40+
41+
bool preorder(const IR::Statement *stmt) override {
42+
handleStmt(stmt);
43+
return false;
44+
}
45+
46+
bool preorder(const IR::BlockStatement *bs) override {
47+
if (handleStmt(bs)) {
48+
// split on the bs itself
49+
return false;
50+
}
51+
52+
for (size_t i = 0, sz = bs->components.size(); i < sz; i++) {
53+
visit(bs->components[i], "vector");
54+
if (result.after) {
55+
const auto [before, after, _] = result; // copy
56+
auto *copy = bs->clone();
57+
copy->components.erase(copy->components.begin() + i, copy->components.end());
58+
if (before) {
59+
copy->components.push_back(before);
60+
}
61+
result.before = filterDeclarations(copy);
62+
copy = bs->clone();
63+
copy->components.erase(copy->components.begin(), copy->components.begin() + i);
64+
collectNeededDeclarations(copy);
65+
copy->components.replace(copy->components.begin(), after);
66+
result.after = copy;
67+
return false; // stop on first split point
68+
}
69+
}
70+
return false;
71+
}
72+
73+
bool preorder(const IR::IfStatement *ifs) override {
74+
if (handleStmt(ifs)) {
75+
return false; // split on the if itself
76+
}
77+
78+
auto [results, anySplit] = splitBranches({ifs->ifTrue, ifs->ifFalse});
79+
if (!anySplit) {
80+
return false;
81+
}
82+
83+
IR::ID condName{nameGen.newName("cond"), nullptr};
84+
const auto &si = ifs->srcInfo;
85+
const auto *decl = new IR::Declaration_Variable(si, condName, IR::Type::Boolean::get());
86+
result.hoistedDeclarations.push_back(decl);
87+
88+
const auto *condPE = new IR::PathExpression(si, new IR::Path(si, condName));
89+
const auto *asgn = new IR::AssignmentStatement(si, condPE, ifs->condition);
90+
91+
auto *beforeIf = ifs->clone();
92+
beforeIf->condition = condPE->clone();
93+
beforeIf->ifTrue = results[0].before;
94+
beforeIf->ifFalse = results[1].before;
95+
result.before = new IR::BlockStatement(si, {asgn, beforeIf});
96+
97+
auto *afterIf = beforeIf->clone();
98+
afterIf->ifTrue = results[0].after;
99+
afterIf->ifFalse = results[1].after;
100+
result.after = afterIf;
101+
102+
for (auto **trueBranch : {&beforeIf->ifTrue, &afterIf->ifTrue}) {
103+
if (*trueBranch == nullptr) {
104+
*trueBranch = new IR::BlockStatement(ifs->ifTrue->srcInfo);
105+
}
106+
}
107+
return false;
108+
}
109+
110+
bool preorder(const IR::SwitchStatement *sw) override {
111+
if (handleStmt(sw)) {
112+
return false; // split on the switch itself
113+
}
114+
115+
std::vector<const IR::Statement *> branches;
116+
for (const auto *case_ : sw->cases) {
117+
branches.push_back(case_->statement);
118+
}
119+
auto [results, anySplit] = splitBranches(branches);
120+
121+
if (!anySplit) {
122+
return false;
123+
}
124+
125+
IR::ID selName{nameGen.newName("selector"), nullptr};
126+
const auto &si = sw->srcInfo;
127+
const auto *selType = typeMap ? typeMap->getType(sw->expression) : nullptr;
128+
selType = selType ? selType : sw->expression->type;
129+
BUG_CHECK(selType && !selType->is<IR::Type::Unknown>(),
130+
"Cannot split switch statement with unknown selector type %1%", sw->expression);
131+
const auto *decl = new IR::Declaration_Variable(si, selName, selType);
132+
result.hoistedDeclarations.push_back(decl);
133+
134+
const auto *selPE = new IR::PathExpression(si, new IR::Path(si, selName));
135+
const auto *asgn = new IR::AssignmentStatement(si, selPE, sw->expression);
136+
137+
// ensure we don't accidentally create fallthrough
138+
for (size_t i = 0; i < branches.size(); ++i) {
139+
for (const auto **val : {&results[i].before, &results[i].after}) {
140+
if (!*val && branches[i]) {
141+
*val = new IR::BlockStatement(branches[i]->srcInfo);
142+
}
143+
}
144+
}
145+
146+
auto *beforeSw = sw->clone();
147+
beforeSw->expression = selPE;
148+
for (size_t i = 0; i < branches.size(); ++i) {
149+
setCase(beforeSw, i, results[i].before);
150+
}
151+
result.before = new IR::BlockStatement(si, {asgn, beforeSw});
152+
153+
auto *afterSw = beforeSw->clone();
154+
for (size_t i = 0; i < branches.size(); ++i) {
155+
setCase(afterSw, i, results[i].after);
156+
}
157+
result.after = afterSw;
158+
return false;
159+
}
160+
161+
void end_apply(const IR::Node *root) override {
162+
if (!result.before) {
163+
result.before = root->checkedTo<IR::Statement>();
164+
}
165+
}
166+
167+
SplitResult<IR::Statement> result;
168+
169+
private:
170+
bool handleStmt(const IR::Statement *stmt) {
171+
BUG_CHECK(result.before == nullptr && result.after == nullptr,
172+
"More than one leaf statement found: %1% and %2%",
173+
result.before ? result.before : result.after, stmt);
174+
if (predicate(stmt, getChildContext())) {
175+
result.after = stmt;
176+
collectNeededDeclarations(stmt);
177+
return true;
178+
}
179+
return false;
180+
}
181+
182+
void setCase(IR::SwitchStatement *sw, size_t i, const IR::Statement *value) {
183+
// note that we can't go all the way to statement as it can be nullptr
184+
modify(sw, &IR::SwitchStatement::cases, IR::Traversal::Index(i),
185+
[value](IR::SwitchCase *case_) {
186+
case_->statement = value;
187+
return case_;
188+
});
189+
}
190+
191+
void takeHoisted(std::vector<const IR::Declaration *> &decls) {
192+
result.hoistedDeclarations.insert(result.hoistedDeclarations.end(), decls.begin(),
193+
decls.end());
194+
decls.clear();
195+
}
196+
197+
std::pair<std::vector<SplitResult<IR::Statement>>, bool> splitBranches(
198+
std::vector<const IR::Statement *> branches) {
199+
std::vector<SplitResult<IR::Statement>> res;
200+
bool anySplit = false;
201+
res.reserve(branches.size());
202+
203+
for (const auto *branch : branches) {
204+
if (!branch) {
205+
res.emplace_back();
206+
continue;
207+
}
208+
visit(branch, "branch");
209+
anySplit = anySplit || result.after;
210+
if (!result) {
211+
result.before = branch;
212+
}
213+
res.emplace_back(std::move(result));
214+
result.clear();
215+
}
216+
for (auto &[_, __, hoisted] : res) {
217+
takeHoisted(hoisted);
218+
}
219+
return {res, anySplit};
220+
}
221+
222+
void collectNeededDeclarations(const IR::Node *after) {
223+
struct CollectNeededDecls : Inspector, ResolutionContext {
224+
explicit CollectNeededDecls(absl::flat_hash_set<P4::cstring, Util::Hash> &needed)
225+
: needed(needed) {}
226+
227+
void postorder(const IR::PathExpression *pe) override {
228+
// using lower-level resolution to avoid emitting errors for things not found
229+
if (!resolve(pe->path->name, ResolutionType::Any).empty()) {
230+
needed.insert(pe->path->name);
231+
}
232+
}
233+
234+
absl::flat_hash_set<P4::cstring, Util::Hash> &needed;
235+
};
236+
237+
after->apply(CollectNeededDecls(neededDecls), getChildContext());
238+
}
239+
240+
template <typename T>
241+
const T *filterDeclarations(const T *node) {
242+
struct FilterDecls : Transform {
243+
FilterDecls(absl::flat_hash_set<P4::cstring, Util::Hash> &needed,
244+
std::vector<const IR::Declaration *> &hoisted)
245+
: needed(needed), hoisted(hoisted) {}
246+
247+
const IR::Node *preorder(IR::Declaration_Variable *decl) override {
248+
if (needed.contains(decl->name)) {
249+
hoisted.push_back(decl);
250+
return nullptr;
251+
}
252+
return decl;
253+
}
254+
255+
absl::flat_hash_set<P4::cstring, Util::Hash> &needed;
256+
std::vector<const IR::Declaration *> &hoisted;
257+
};
258+
259+
FilterDecls filter(neededDecls, result.hoistedDeclarations);
260+
return node->apply(filter)->template checkedTo<T>();
261+
}
262+
263+
std::function<bool(const IR::Statement *, const Visitor::Context *)> predicate;
264+
P4::NameGenerator &nameGen;
265+
P4::TypeMap *typeMap;
266+
absl::flat_hash_set<P4::cstring, Util::Hash> &neededDecls;
267+
};
268+
269+
/// Splits @p stat at the first statement accepted by @p predicate by running a
/// StatementSplitter over it.
///
/// @param stat       statement to split.
/// @param predicate  called with a statement and the visitor context; returns true for the
///                   statement at which the split should happen.
/// @param nameGen    used to create names of the temporaries the splitter introduces.
/// @param typeMap    passed to the splitter; the splitter accepts nullptr here.
/// @returns the splitter's result (the "before" part, the "after" part and the hoisted
///          declarations).
SplitResult<IR::Statement> splitStatementBefore(
    const IR::Statement *stat,
    std::function<bool(const IR::Statement *, const P4::Visitor_Context *)> predicate,
    P4::NameGenerator &nameGen, P4::TypeMap *typeMap) {
    absl::flat_hash_set<P4::cstring, Util::Hash> neededDecls;
    // the predicate is not used here any more, hand it over instead of copying it
    StatementSplitter split(std::move(predicate), nameGen, typeMap, neededDecls);
    // no incoming context, declaration resolution will work only within the splitter
    stat->apply(split);
    // the splitter dies at the end of this function, so its result can be moved out
    return std::move(split.result);
}
279+
280+
} // namespace P4

0 commit comments

Comments
 (0)