Skip to content

Commit

Permalink
feat: mul sign simplify
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Feb 11, 2025
1 parent d1517cc commit 75a3bab
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 3 deletions.
57 changes: 54 additions & 3 deletions src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7964,8 +7964,58 @@ struct MultiplyNegateSimplify : public OpRewritePattern<stablehlo::MulOp> {
}
};

// (mul (sign x) (add (abs x) y)) -> (add x (mul y (sign x)))
// (mul (sign x) (add y (abs x))) -> (add (mul y (sign x)) x)
// This pattern only does partially the following. We rely on transforming the op to a
// pattern which further uses the above pattern.
// (mul (sign x) (add (abs x) (abs x))) -> (mul x x)
// TODO: We can simplify for cases where only one of the add operands is abs.
struct MultiplySignAddSimplify : public OpRewritePattern<stablehlo::MulOp> {
using OpRewritePattern<stablehlo::MulOp>::OpRewritePattern;

LogicalResult matchAndRewrite(stablehlo::MulOp op,
PatternRewriter &rewriter) const override {
auto lhs = op.getOperand(0);
auto rhs = op.getOperand(1);

stablehlo::SignOp signOp = nullptr;
stablehlo::AddOp addOp = nullptr;
if (lhs.getDefiningOp<stablehlo::SignOp>()) {
signOp = lhs.getDefiningOp<stablehlo::SignOp>();
if (rhs.getDefiningOp<stablehlo::AddOp>()) {
addOp = rhs.getDefiningOp<stablehlo::AddOp>();
} else {
return failure();
}
} else if (rhs.getDefiningOp<stablehlo::SignOp>()) {
signOp = rhs.getDefiningOp<stablehlo::SignOp>();
if (lhs.getDefiningOp<stablehlo::AddOp>()) {
addOp = lhs.getDefiningOp<stablehlo::AddOp>();
} else {
return failure();
}
} else {
return failure();
}

auto signOperand = signOp.getOperand();

auto lhsAddOp = addOp.getOperand(0);
auto rhsAddOp = addOp.getOperand(1);

if (lhsAddOp != rhsAddOp)
return failure(); // TODO: Can support more cases.

auto lhsAddAbsOp = lhsAddOp.getDefiningOp<stablehlo::AbsOp>();
auto rhsAddAbsOp = rhsAddOp.getDefiningOp<stablehlo::AbsOp>();
if (!lhsAddAbsOp || !rhsAddAbsOp)
return failure();

if (signOperand != lhsAddAbsOp.getOperand() || signOperand != rhsAddAbsOp.getOperand())
return failure();

rewriter.replaceOpWithNewOp<stablehlo::MulOp>(op, signOperand, signOperand);
return success();
}
};

/////////////// End Imported from stablehlo

Expand Down Expand Up @@ -8210,7 +8260,8 @@ struct EnzymeHLOOptPass
TransposeReduceSimplify,
SignAbsSimplify,
PositiveNegativeSelectSimplify,
MultiplyNegateSimplify
MultiplyNegateSimplify,
MultiplySignAddSimplify
>(context);
// clang-format on
patterns.add<SelectOpCanon>(max_constant_expansion, context,
Expand Down
5 changes: 5 additions & 0 deletions src/enzyme_ad/jax/TransformOps/TransformOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -903,6 +903,11 @@ def ApplyMultiplyNegateSimplifyPatterns : EnzymeHLOPatternOp<
let patterns = ["MultiplyNegateSimplify"];
}

def ApplyMultiplySignAddSimplifyPatterns : EnzymeHLOPatternOp<
"multiply_sign_add_simplify"> {
let patterns = ["MultiplySignAddSimplify"];
}

// TODO: better naming for parameters requires a static interface for
// constructing them in search.

Expand Down
1 change: 1 addition & 0 deletions src/enzyme_ad/jax/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,7 @@ def hlo_opts():
sign_abs_simplify;
positive_negative_select_simplify;
multiply_negate_simplify;
multiply_sign_add_simplify;
transpose_unary_transpose_abs<1>;
transpose_unary_transpose_neg<1>;
Expand Down

0 comments on commit 75a3bab

Please sign in to comment.