Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: rewrite constants into scatter update_computation #319

Merged
merged 2 commits into from
Feb 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 89 additions & 1 deletion src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7643,6 +7643,93 @@ struct CommonCompareExpressionRewrite
}
};

struct ScatterUpdateComputationConstProp
: public OpRewritePattern<stablehlo::ScatterOp> {
using OpRewritePattern<stablehlo::ScatterOp>::OpRewritePattern;

LogicalResult matchAndRewrite(stablehlo::ScatterOp op,
PatternRewriter &rewriter) const final {
auto &region = op.getUpdateComputation();
auto &block = region.front();

// Check all inputs are constant and splat and their values are the same.
auto [constInput, inputSplatAttr] =
isConstantSplatValueRange(op.getInputs());

// Check all updates are constant and splat and their values are the same.
auto [constUpdate, updateSplatAttr] =
isConstantSplatValueRange(op.getUpdates());

if (constInput || constUpdate) {
bool inputTransformed = false;
bool updateTransformed = false;
auto blockArgInput = block.getArgument(0);
auto blockArgUpdate = block.getArgument(1);

if (constInput && !blockArgInput.getUses().empty()) {
inputTransformed = true;
auto denseAttr = DenseElementsAttr::get(
blockArgInput.getType().cast<ShapedType>(), inputSplatAttr);
auto constInputOp =
rewriter.create<stablehlo::ConstantOp>(op.getLoc(), denseAttr);
blockArgInput.replaceAllUsesWith(constInputOp);
}

if (constUpdate && !blockArgUpdate.getUses().empty()) {
updateTransformed = true;
auto denseAttr = DenseElementsAttr::get(
blockArgUpdate.getType().cast<ShapedType>(), updateSplatAttr);
auto constUpdateOp =
rewriter.create<stablehlo::ConstantOp>(op.getLoc(), denseAttr);
blockArgUpdate.replaceAllUsesWith(constUpdateOp);
}

if (!inputTransformed && !updateTransformed)
return failure();

auto newOp = rewriter.create<stablehlo::ScatterOp>(
op.getLoc(), op.getResultTypes(), op.getInputs(),
op.getScatterIndices(), op.getUpdates(),
op.getScatterDimensionNumbers(), op.getIndicesAreSorted(),
op.getUniqueIndices());
newOp.getUpdateComputation().takeBody(region);
rewriter.replaceOp(op, newOp);

return success();
}

return failure();
}

private:
std::tuple<bool, Attribute>
isConstantSplatValueRange(ValueRange range) const {
Attribute splatAttr = nullptr;
bool isConstant = true;
for (auto val : range) {
DenseElementsAttr attr;
if (matchPattern(val, m_Constant(&attr))) {
if (attr.isSplat()) {
if (!splatAttr) {
splatAttr = attr.getSplatValue<Attribute>();
continue;
} else if (splatAttr != attr.getSplatValue<Attribute>()) {
isConstant = false;
break;
}
} else {
isConstant = false;
break;
}
} else {
isConstant = false;
break;
}
}
return std::make_tuple(isConstant, splatAttr);
};
};

/////////////// End Imported from stablehlo

#include "src/enzyme_ad/jax/Passes/EnzymeHLOPatterns.cpp.inc"
Expand Down Expand Up @@ -7873,7 +7960,8 @@ struct EnzymeHLOOptPass
ZeroExtentTensorCanon,
CompareSelectSimplify,
NotSelectSimplify,
CommonCompareExpressionRewrite
CommonCompareExpressionRewrite,
ScatterUpdateComputationConstProp
>(context);
// clang-format on
patterns.add<SelectOpCanon>(max_constant_expansion, context,
Expand Down
5 changes: 5 additions & 0 deletions src/enzyme_ad/jax/TransformOps/TransformOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -858,6 +858,11 @@ def CommonCompareExpressionRewritePatterns : EnzymeHLOPatternOp<
let patterns = ["CommonCompareExpressionRewrite"];
}

def ApplyScatterUpdateComputationConstPropPatterns : EnzymeHLOPatternOp<
"scatter_update_computation_const_prop"> {
let patterns = ["ScatterUpdateComputationConstProp"];
}

// TODO: better naming for parameters requires a static interface for
// constructing them in search.

Expand Down
1 change: 1 addition & 0 deletions src/enzyme_ad/jax/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,7 @@ def hlo_opts():
compare_select_simplify;
common_compare_expression_rewrite;
not_select_simplify;
scatter_update_computation_const_prop;

transpose_unary_transpose_abs<1>;
transpose_unary_transpose_neg<1>;
Expand Down
9 changes: 3 additions & 6 deletions test/lit_tests/diffrules/stablehlo/gather.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,8 @@ module {
// REVERSE-NEXT: %cst = stablehlo.constant dense<0.000000e+00> : tensor<64x3xf32>
// REVERSE-NEXT: %c = stablehlo.constant dense<0> : tensor<45x1xi32>
// REVERSE-NEXT: %0 = "stablehlo.scatter"(%cst, %arg1, %arg2) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [1], inserted_window_dims = [0], scatter_dims_to_operand_dims = [0], index_vector_dim = 1>, unique_indices = false}> ({
// REVERSE-NEXT: ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
// REVERSE-NEXT: %1 = stablehlo.add %arg3, %arg4 : tensor<f32>
// REVERSE-NEXT: stablehlo.return %1 : tensor<f32>
// REVERSE-NEXT: }) : (tensor<64x3xf32>, tensor<45x1xi32>, tensor<45x3xf32>) -> tensor<64x3xf32>
// REVERSE-NEXT: ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
// REVERSE-NEXT: stablehlo.return %arg4 : tensor<f32>
// REVERSE-NEXT: }) : (tensor<64x3xf32>, tensor<45x1xi32>, tensor<45x3xf32>) -> tensor<64x3xf32>
// REVERSE-NEXT: return %0, %c : tensor<64x3xf32>, tensor<45x1xi32>
// REVERSE-NEXT: }


46 changes: 46 additions & 0 deletions test/lit_tests/scatterupdatecomputationconstprop.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
// RUN: enzymexlamlir-opt --enzyme-hlo-opt %s | FileCheck %s

func.func @main1(%arg0: tensor<64x3xf32>, %arg1: tensor<45x1xi32>, %arg2: tensor<45x3xf32>) -> (tensor<64x3xf32>, tensor<45x1xi32>) {
%cst = stablehlo.constant dense<0.000000e+00> : tensor<64x3xf32>
%c = stablehlo.constant dense<0> : tensor<45x1xi32>
%0 = "stablehlo.scatter"(%cst, %arg1, %arg2) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [1], inserted_window_dims = [0], scatter_dims_to_operand_dims = [0], index_vector_dim = 1>, unique_indices = false}> ({
^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
%1 = stablehlo.add %arg3, %arg4 : tensor<f32>
stablehlo.return %1 : tensor<f32>
}) : (tensor<64x3xf32>, tensor<45x1xi32>, tensor<45x3xf32>) -> tensor<64x3xf32>
return %0, %c : tensor<64x3xf32>, tensor<45x1xi32>
}

// CHECK: func.func @main1(%arg0: tensor<64x3xf32>, %arg1: tensor<45x1xi32>, %arg2: tensor<45x3xf32>) -> (tensor<64x3xf32>, tensor<45x1xi32>) {
// CHECK-NEXT: %cst = stablehlo.constant dense<0.000000e+00> : tensor<64x3xf32>
// CHECK-NEXT: %c = stablehlo.constant dense<0> : tensor<45x1xi32>
// CHECK-NEXT: %0 = "stablehlo.scatter"(%cst, %arg1, %arg2) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [1], inserted_window_dims = [0], scatter_dims_to_operand_dims = [0], index_vector_dim = 1>, unique_indices = false}> ({
// CHECK-NEXT: ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
// CHECK-NEXT: stablehlo.return %arg4 : tensor<f32>
// CHECK-NEXT: }) : (tensor<64x3xf32>, tensor<45x1xi32>, tensor<45x3xf32>) -> tensor<64x3xf32>
// CHECK-NEXT: return %0, %c : tensor<64x3xf32>, tensor<45x1xi32>
// CHECK-NEXT: }

func.func @main2(%arg0: tensor<64x3xf32>, %arg1: tensor<45x1xi32>) -> (tensor<64x3xf32>, tensor<45x1xi32>) {
%cst = stablehlo.constant dense<2.000000e+00> : tensor<64x3xf32>
%c = stablehlo.constant dense<0> : tensor<45x1xi32>
%cst_2 = stablehlo.constant dense<5.000000e+00> : tensor<45x3xf32>
%0 = "stablehlo.scatter"(%cst, %arg1, %cst_2) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [1], inserted_window_dims = [0], scatter_dims_to_operand_dims = [0], index_vector_dim = 1>, unique_indices = false}> ({
^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
%1 = stablehlo.multiply %arg3, %arg4 : tensor<f32>
stablehlo.return %1 : tensor<f32>
}) : (tensor<64x3xf32>, tensor<45x1xi32>, tensor<45x3xf32>) -> tensor<64x3xf32>
return %0, %c : tensor<64x3xf32>, tensor<45x1xi32>
}

// CHECK: func.func @main2(%arg0: tensor<64x3xf32>, %arg1: tensor<45x1xi32>) -> (tensor<64x3xf32>, tensor<45x1xi32>) {
// CHECK-NEXT: %cst = stablehlo.constant dense<1.000000e+01> : tensor<f32>
// CHECK-NEXT: %cst_0 = stablehlo.constant dense<2.000000e+00> : tensor<64x3xf32>
// CHECK-NEXT: %c = stablehlo.constant dense<0> : tensor<45x1xi32>
// CHECK-NEXT: %cst_1 = stablehlo.constant dense<5.000000e+00> : tensor<45x3xf32>
// CHECK-NEXT: %0 = "stablehlo.scatter"(%cst_0, %arg1, %cst_1) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [1], inserted_window_dims = [0], scatter_dims_to_operand_dims = [0], index_vector_dim = 1>, unique_indices = false}> ({
// CHECK-NEXT: ^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>):
// CHECK-NEXT: stablehlo.return %cst : tensor<f32>
// CHECK-NEXT: }) : (tensor<64x3xf32>, tensor<45x1xi32>, tensor<45x3xf32>) -> tensor<64x3xf32>
// CHECK-NEXT: return %0, %c : tensor<64x3xf32>, tensor<45x1xi32>
// CHECK-NEXT: }
Loading