Skip to content

Commit

Permalink
More pattern matching in SimplifyingIrBuilder::addExpr (#3538)
Browse files Browse the repository at this point in the history
- If the simplification result is 0 or 1, return the container's cached
zeroVal/oneVal IR node instead of creating a new constant Val.
- Simplify `x + y - y` into `x`
  • Loading branch information
zasdfgbnm authored Dec 6, 2024
1 parent 76483fe commit 9346c8f
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 45 deletions.
26 changes: 20 additions & 6 deletions csrc/ir/builder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -340,19 +340,33 @@ Val* SimplifyingIrBuilder::addExpr(
rhs_dtype = getDataType(rhs);
}
if (lhs == nullptr) {
return IrBuilder::IrBuilder::create<Val>(rhs, rhs_dtype);
return IrBuilder::create<Val>(rhs, rhs_dtype);
}
// Simplify (x + c) + rhs to x when c is a constant equal to -rhs, i.e. x + y - y => x
if (lhs->definition() != nullptr && lhs->definition()->isA<BinaryOp>()) {
auto binary_op = lhs->definition()->as<BinaryOp>();
if (binary_op->getBinaryOpType() == BinaryOpType::Add) {
if (binary_op->rhs()->isConst() &&
(bool)(binary_op->rhs()->value() == -rhs)) {
return binary_op->lhs();
}
}
}
auto target_dtype = promoteType(lhs->dtype(), rhs_dtype);
if (rhs == 0) {
return maybeCastExpr(target_dtype, lhs);
} else if (lhs->isConst()) {
return IrBuilder::IrBuilder::create<Val>(lhs->value() + rhs, target_dtype);
auto result = lhs->value() + rhs;
if (result == 0) {
return lhs->container()->zeroVal(target_dtype);
} else if (result == 1) {
return lhs->container()->oneVal(target_dtype);
}
return IrBuilder::create<Val>(lhs->value() + rhs, target_dtype);
} else if (rhs > 0) {
return IrBuilder::addExpr(
lhs, IrBuilder::IrBuilder::create<Val>(rhs, rhs_dtype));
return IrBuilder::addExpr(lhs, IrBuilder::create<Val>(rhs, rhs_dtype));
} else {
return IrBuilder::subExpr(
lhs, IrBuilder::IrBuilder::create<Val>(-rhs, rhs_dtype));
return IrBuilder::subExpr(lhs, IrBuilder::create<Val>(-rhs, rhs_dtype));
}
}

Expand Down
62 changes: 23 additions & 39 deletions tests/cpp/test_indexing.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1382,13 +1382,11 @@ TEST_F(IndexingTest, SimpleVectorize) {
// vectorized domain with zero, which doesn't go through the
// simplification of SimplifyingIrBuilder. We could use simplifyExpr,
// but for the sake of testing, just use IrBuilder::addExpr.
return IrBuilder::addExpr(
mulExpr(
addExpr(
mulExpr(loop_indices.at(0), tv->axis(1)->extent()),
loop_indices.at(1)),
tv->axis(2)->extent()),
tv->fusion()->zeroVal());
return mulExpr(
addExpr(
mulExpr(loop_indices.at(0), tv->axis(1)->extent()),
loop_indices.at(1)),
tv->axis(2)->extent());
case 1:
return tv->fusion()->zeroVal();
default:
Expand Down Expand Up @@ -2896,13 +2894,11 @@ TEST_F(PredicateIndexingTest, SimpleVectorize) {
Val* getInlinePredicate(TensorView* tv) const override {
std::vector<Val*> loop_indices = getLoopIndices(tv, indexer_, for_loops_);

auto start_idx = IrBuilder::addExpr(
mulExpr(
addExpr(
mulExpr(loop_indices.at(0), tv->axis(1)->extent()),
loop_indices.at(1)),
tv->axis(2)->extent()),
tv->fusion()->zeroVal());
auto start_idx = mulExpr(
addExpr(
mulExpr(loop_indices.at(0), tv->axis(1)->extent()),
loop_indices.at(1)),
tv->axis(2)->extent());
auto stop_idx = addExpr(
mulExpr(
addExpr(
Expand All @@ -2911,7 +2907,7 @@ TEST_F(PredicateIndexingTest, SimpleVectorize) {
tv->axis(2)->extent()),
subExpr(tv->axis(2)->extent(), createInt(1)));

// ( ( ( ( ( blockIdx.x * 128 ) + threadIdx.x ) * 4 ) + 0 )>= 0 ) &&
// ( ( ( ( blockIdx.x * 128 ) + threadIdx.x ) * 4 ) >= 0 ) &&
// ( ( ( ( ( blockIdx.x * 128 ) + threadIdx.x ) * 4 ) + 3 ) < ( (( ((
// getMetaData(T0) )).logical_size ))[0] ) )
return andExpr(
Expand Down Expand Up @@ -2982,8 +2978,7 @@ TEST_F(PredicateIndexingTest, NonInnermostVectorize) {
loop_indices.at(1)),
tv->axis(3)->extent()),
loop_indices.at(3));
auto start_idx = IrBuilder::addExpr(
mulExpr(common_idx, tv->axis(2)->extent()), tv->fusion()->zeroVal());
auto start_idx = mulExpr(common_idx, tv->axis(2)->extent());
auto stop_idx = addExpr(
mulExpr(common_idx, tv->axis(2)->extent()),
subExpr(tv->axis(2)->extent(), createInt(1)));
Expand Down Expand Up @@ -3406,20 +3401,15 @@ TEST_F(PredicateIndexingTest, UnswitchedCircularBuffering1) {
Val* getOuterPredicate(TensorView* tv) const override {
std::vector<Val*> loop_indices = getLoopIndices(tv, indexer_, for_loops_);

auto zero = tv->fusion()->zeroVal();

// The base index is:
//
// i0 * 4 + i2
//
// where i2 is the circular buffer index. The index of iUS10 is
// not included as its extent is 1.

// NOTE: Expression Simplification is disabled in PredicateIndexValidator,
// so trivial addition appears in the expression.
// Start index: i0 * 4 + 0
Val* start_idx = IrBuilder::addExpr(
IrBuilder::mulExpr(loop_indices.at(0), createInt(4)), createInt(0));
// Start index: i0 * 4
Val* start_idx = IrBuilder::mulExpr(loop_indices.at(0), createInt(4));

// Stop index: i0 * 4 + 4
// Note that it isn't "i0 * 4 + 3" since i2 is circular buffered
Expand All @@ -3429,7 +3419,7 @@ TEST_F(PredicateIndexingTest, UnswitchedCircularBuffering1) {
IrBuilder::mulExpr(loop_indices.at(0), createInt(4)), createInt(4));

return andExpr(
geExpr(start_idx, zero),
geExpr(start_idx, tv->fusion()->zeroVal()),
ltExpr(stop_idx, tv->getLogicalDomain().at(0)->extent()));
}
};
Expand Down Expand Up @@ -3502,13 +3492,10 @@ TEST_F(PredicateIndexingTest, UnswitchedCircularBuffering2) {
// to the vectorization. Since it's vectorized, the predicate
// uses 0 for start and (vec_factor - 1) for stop

// Start index: (i0 * 128 + 0) * 4 + 0
Val* start_idx = IrBuilder::addExpr(
mulExpr(
IrBuilder::addExpr(
mulExpr(loop_indices.at(0), createInt(128)), zero),
createInt(4)),
zero);
// Start index: (i0 * 128 + 0) * 4
Val* start_idx = mulExpr(
IrBuilder::addExpr(mulExpr(loop_indices.at(0), createInt(128)), zero),
createInt(4));
// Stop index: (i0 * 128 + 129) * 4 + 3
Val* stop_idx = addExpr(
mulExpr(
Expand Down Expand Up @@ -3607,13 +3594,10 @@ TEST_P(PredicateIndexingTest, UnswitchedCircularBuffering3) {
// to the vectorization. Since it's vectorized, the predicate
// uses 0 for start and (vec_factor - 1) for stop

// Start index: (i0 * 128 + 0) * 4 + 0
Val* start_idx = IrBuilder::addExpr(
mulExpr(
IrBuilder::addExpr(
mulExpr(loop_indices.at(0), createInt(128)), zero),
createInt(4)),
zero);
// Start index: (i0 * 128 + 0) * 4
Val* start_idx = mulExpr(
IrBuilder::addExpr(mulExpr(loop_indices.at(0), createInt(128)), zero),
createInt(4));
// Stop index: (i0 * 128 + 129) * 4 + 3
Val* stop_idx = addExpr(
mulExpr(
Expand Down

0 comments on commit 9346c8f

Please sign in to comment.