Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 42 additions & 13 deletions frontends/p4/reassociation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,21 +16,50 @@ limitations under the License.

#include "reassociation.h"

#include "ir/pattern.h"

namespace P4 {

const IR::Node *Reassociation::reassociate(IR::Operation_Binary *root) {
const auto *right = root->right->to<IR::Constant>();
if (!right) return root;
auto leftBin = root->left->to<IR::Operation_Binary>();
if (!leftBin) return root;
if (leftBin->getStringOp() != root->getStringOp()) return root;
if (!leftBin->right->is<IR::Constant>()) return root;
auto *c = root->clone();
c->left = leftBin->right;
c->right = root->right;
root->left = leftBin->left;
root->right = c;
return root;
void Reassociation::reassociate(IR::Operation_Binary *root) {
LOG3("Trying to reassociate " << root);
// Canonicalize constant to rhs
if (root->left->is<IR::Constant>()) {
std::swap(root->left, root->right);
LOG3("Canonicalized constant to rhs: " << root);
}

/* Match the following tree
* op
* / \
* / \
* op c2
* / \
* / \
* e c1
*
* (note that we're doing postorder visit and we already canonicalized
* constants to rhs)
* Rewrite to:
* op
* / \
* / \
* e op
* / \
* c1 c2
*/
const IR::Operation_Binary *lhs;
const IR::Constant *c1, *c2;
const IR::Expression *e;
if (match(root,
m_BinOp(m_AllOf(m_BinOp(lhs), m_BinOp(m_Expr(e), m_Constant(c1))), m_Constant(c2))) &&
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems far inferior to the IR::Pattern matching, which is much simpler and clearer

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will you please be more specific? How I can e.g. match if operand has particular type in IR::Pattern? Or check if two operands are the same? etc.

lhs->getStringOp() == root->getStringOp()) {
auto *newRight = root->clone();
newRight->left = c1;
newRight->right = c2;
root->left = e;
root->right = newRight;
LOG3("Reassociated constants together: " << root);
}
}

} // namespace P4
27 changes: 14 additions & 13 deletions frontends/p4/reassociation.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,29 +24,30 @@ namespace P4 {

using namespace literals;

/** Implements a pass that reorders associative operations when beneficial.
* For example, (a + c0) + c1 is rewritten as a + (c0 + c1) when cs are constants.
/** Implements a pass that reorders associative operations when beneficial. For
* example, (a + c0) + c1 is rewritten as a + (c0 + c1) when cs are constants.
* The pass performs only local reassociation transformation, it does not (yet)
* implement "push to leaves" optimization.
*/
class Reassociation final : public Transform {
class Reassociation final : public Modifier {
public:
Reassociation() {
visitDagOnce = true;
setName("Reassociation");
}
using Transform::postorder;
using Modifier::postorder;

const IR::Node *reassociate(IR::Operation_Binary *root);
void reassociate(IR::Operation_Binary *root);

const IR::Node *postorder(IR::Add *expr) override { return reassociate(expr); }
const IR::Node *postorder(IR::Mul *expr) override { return reassociate(expr); }
const IR::Node *postorder(IR::BOr *expr) override { return reassociate(expr); }
const IR::Node *postorder(IR::BAnd *expr) override { return reassociate(expr); }
const IR::Node *postorder(IR::BXor *expr) override { return reassociate(expr); }
const IR::BlockStatement *preorder(IR::BlockStatement *bs) override {
void postorder(IR::Add *expr) override { reassociate(expr); }
void postorder(IR::Mul *expr) override { reassociate(expr); }
void postorder(IR::BOr *expr) override { reassociate(expr); }
void postorder(IR::BAnd *expr) override { reassociate(expr); }
void postorder(IR::BXor *expr) override { reassociate(expr); }
bool preorder(IR::BlockStatement *bs) override {
// FIXME: Do we need to check for expression, so we'd be able to fine tune, e.g.
// @disable_optimization("reassociation")
if (bs->hasAnnotation(IR::Annotation::disableOptimizationAnnotation)) prune();
return bs;
return !bs->hasAnnotation(IR::Annotation::disableOptimizationAnnotation);
}
};

Expand Down
94 changes: 69 additions & 25 deletions frontends/p4/strengthReduction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ limitations under the License.

#include "strengthReduction.h"

#include "ir/pattern.h"

namespace P4 {

/// @section Helper methods
Expand Down Expand Up @@ -78,9 +80,8 @@ const IR::Node *DoStrengthReduction::postorder(IR::UPlus *expr) { return expr->e
const IR::Node *DoStrengthReduction::postorder(IR::BAnd *expr) {
if (isAllOnes(expr->left)) return expr->right;
if (isAllOnes(expr->right)) return expr->left;
auto l = expr->left->to<IR::Cmpl>();
auto r = expr->right->to<IR::Cmpl>();
if (l && r)
const IR::Cmpl *l, *r;
if (match(expr, m_BinOp(m_Cmpl(l), m_Cmpl(r))))
return new IR::Cmpl(expr->srcInfo, expr->type,
new IR::BOr(expr->srcInfo, expr->type, l->expr, r->expr));

Expand All @@ -94,11 +95,11 @@ const IR::Node *DoStrengthReduction::postorder(IR::BAnd *expr) {
const IR::Node *DoStrengthReduction::postorder(IR::BOr *expr) {
if (isZero(expr->left)) return expr->right;
if (isZero(expr->right)) return expr->left;
auto l = expr->left->to<IR::Cmpl>();
auto r = expr->right->to<IR::Cmpl>();
if (l && r)
const IR::Cmpl *l, *r;
if (match(expr, m_BinOp(m_Cmpl(l), m_Cmpl(r))))
return new IR::Cmpl(expr->srcInfo, expr->type,
new IR::BAnd(expr->srcInfo, expr->type, l->expr, r->expr));

if (hasSideEffects(expr)) return expr;
if (expr->left->equiv(*expr->right)) return expr->left;
return expr;
Expand Down Expand Up @@ -182,10 +183,11 @@ const IR::Node *DoStrengthReduction::postorder(IR::Sub *expr) {
if (isZero(expr->right)) return expr->left;
if (isZero(expr->left)) return new IR::Neg(expr->srcInfo, expr->type, expr->right);
// Replace `a - constant` with `a + (-constant)`
if (enableSubConstToAddTransform && expr->right->is<IR::Constant>()) {
auto cst = expr->right->to<IR::Constant>();
auto neg = new IR::Constant(cst->srcInfo, cst->type, -cst->value, cst->base, true);
auto result = new IR::Add(expr->srcInfo, expr->type, expr->left, neg);
const IR::Constant *cst;
if (enableSubConstToAddTransform && match(expr->right, m_Constant(cst))) {
auto result =
new IR::Add(expr->srcInfo, expr->type, expr->left,
new IR::Constant(cst->srcInfo, cst->type, -cst->value, cst->base, true));
return result;
}
if (hasSideEffects(expr)) return expr;
Expand All @@ -203,12 +205,38 @@ const IR::Node *DoStrengthReduction::postorder(IR::Add *expr) {

const IR::Node *DoStrengthReduction::postorder(IR::Shl *expr) {
if (isZero(expr->right)) return expr->left;
if (const auto *sh2 = expr->left->to<IR::Shl>()) {
if (sh2->right->type->is<IR::Type_InfInt>() && expr->right->type->is<IR::Type_InfInt>()) {
// (a << b) << c is a << (b + c)

{
// (a << b) << c is a << (b + c)
const IR::Expression *a, *b, *c;
if (match(expr, m_BinOp(m_Shl(m_Expr(a), m_AllOf(m_Expr(b), m_TypeInfInt())),
m_AllOf(m_Expr(c), m_TypeInfInt())))) {
auto *result =
new IR::Shl(expr->srcInfo, sh2->left->type, sh2->left,
new IR::Add(expr->srcInfo, sh2->right->type, sh2->right, expr->right));
new IR::Shl(expr->srcInfo, a->type, a, new IR::Add(expr->srcInfo, b->type, b, c));
LOG3("Replace " << expr << " with " << result);
return result;
}
}

{
// (a >> b) << b could be transformed into a & (-1 << b)
const IR::Expression *a, *b;
const IR::Type_Bits *type;
if (match(expr, m_AllOf(m_BinOp(m_Shr(m_Expr(a), m_Expr(b)), m_DeferredEq(b)),
m_TypeBits(type)))) {
big_int mask;
if (const auto *constRhs = b->to<IR::Constant>()) {
// Explicitly calculate mask to silence `value does not fit` warning
int maskBits = type->width_bits() - constRhs->asInt();
mask = Util::mask(maskBits > 0 ? maskBits : 0);
} else {
mask = Util::mask(type->width_bits());
}

auto *result =
new IR::BAnd(expr->srcInfo, a->type, a,
new IR::Shl(expr->srcInfo, a->type,
new IR::Constant(expr->srcInfo, a->type, mask), b));
LOG3("Replace " << expr << " with " << result);
return result;
}
Expand All @@ -220,16 +248,35 @@ const IR::Node *DoStrengthReduction::postorder(IR::Shl *expr) {

const IR::Node *DoStrengthReduction::postorder(IR::Shr *expr) {
if (isZero(expr->right)) return expr->left;
if (auto sh2 = expr->left->to<IR::Shr>()) {
if (sh2->right->type->is<IR::Type_InfInt>() && expr->right->type->is<IR::Type_InfInt>()) {
// (a >> b) >> c is a >> (b + c)

{ // (a << b) << c is a << (b + c)
const IR::Expression *a, *b, *c;
if (match(expr, m_BinOp(m_Shr(m_Expr(a), m_AllOf(m_Expr(b), m_TypeInfInt())),
m_AllOf(m_Expr(c), m_TypeInfInt())))) {
auto *result =
new IR::Shr(expr->srcInfo, sh2->left->type, sh2->left,
new IR::Add(expr->srcInfo, sh2->right->type, sh2->right, expr->right));
new IR::Shr(expr->srcInfo, a->type, a, new IR::Add(expr->srcInfo, b->type, b, c));
LOG3("Replace " << expr << " with " << result);
return result;
}
}

{
// (a << b) >> b could be transformed into a & (-1 >> b) if the shift is logical
const IR::Expression *a, *b;
const IR::Type_Bits *type;
if (match(expr, m_AllOf(m_BinOp(m_Shl(m_Expr(a), m_Expr(b)), m_DeferredEq(b)),
m_TypeBits(type))) &&
!type->isSigned) {
auto *result = new IR::BAnd(
expr->srcInfo, a->type, a,
new IR::Shr(
expr->srcInfo, a->type,
new IR::Constant(expr->srcInfo, a->type, Util::mask(type->width_bits())), b));
LOG3("Replace " << expr << " with " << result);
return result;
}
}

if (!hasSideEffects(expr->right) && isZero(expr->left)) return expr->left;
return expr;
}
Expand Down Expand Up @@ -291,11 +338,8 @@ const IR::Node *DoStrengthReduction::postorder(IR::Mod *expr) {

const IR::Node *DoStrengthReduction::postorder(IR::Range *range) {
// Range a..a is the same as a
if (auto c0 = range->left->to<IR::Constant>()) {
if (auto c1 = range->right->to<IR::Constant>()) {
if (c0->value == c1->value) return c0;
}
}
const IR::Constant *c0, *c1;
if (match(range, m_BinOp(m_Constant(c0), m_Constant(c1))) && c0->value == c1->value) return c0;
return range;
}

Expand Down
Loading
Loading