feat: more passes for scatter (#325)
* fix: only replace with constants if indices are unique

* feat: gather indices are always unique

* feat: scatter indices are unique pass

* test: unique indices

* chore: run formatter

* feat: reordering associative ops

* feat: more cases + tests

* fix: simplify op generation
avik-pal authored Feb 10, 2025
1 parent cb17aac commit f6d0292
Showing 7 changed files with 285 additions and 7 deletions.
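
Why the uniqueness guard matters (an explanatory sketch, not part of the diff): a scatter with duplicate indices applies its update region once per matching update row, so writes to a repeated destination compound. Constant-propagating through the update region is therefore only sound when unique_indices holds. A minimal StableHLO illustration in the style of the lit tests below; the function and value names here are made up:

func.func @dup_updates(%init: tensor<4x3xf32>, %upd: tensor<2x3xf32>) -> tensor<4x3xf32> {
  // Both update rows target row 0 of %init. With an add region, if row 0 of
  // %init is [10, 10, 10] and the update rows are [1, 1, 1] and [2, 2, 2],
  // row 0 becomes [13, 13, 13]: the region runs twice, so folding it down to
  // a single constant application would drop a contribution.
  %idx = stablehlo.constant dense<[[0], [0]]> : tensor<2x1xi32>
  %0 = "stablehlo.scatter"(%init, %idx, %upd) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [1], inserted_window_dims = [0], scatter_dims_to_operand_dims = [0], index_vector_dim = 1>, unique_indices = false}> ({
  ^bb0(%old: tensor<f32>, %new: tensor<f32>):
    %1 = stablehlo.add %old, %new : tensor<f32>
    stablehlo.return %1 : tensor<f32>
  }) : (tensor<4x3xf32>, tensor<2x1xi32>, tensor<2x3xf32>) -> tensor<4x3xf32>
  return %0 : tensor<4x3xf32>
}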
147 changes: 145 additions & 2 deletions src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp
@@ -7455,6 +7455,9 @@ struct ScatterUpdateComputationConstProp

LogicalResult matchAndRewrite(stablehlo::ScatterOp op,
PatternRewriter &rewriter) const final {
if (!op.getUniqueIndices())
return failure();

auto &region = op.getUpdateComputation();
auto &block = region.front();

@@ -7536,6 +7539,139 @@ struct ScatterUpdateComputationConstProp
};
};

struct ScatterIndicesAreUnique : public OpRewritePattern<stablehlo::ScatterOp> {
using OpRewritePattern<stablehlo::ScatterOp>::OpRewritePattern;

LogicalResult matchAndRewrite(stablehlo::ScatterOp op,
PatternRewriter &rewriter) const final {
if (op.getUniqueIndices())
return failure(); // already unique, no need to do anything

auto scatterIndices = op.getScatterIndices();
Attribute scatterIndicesAttr;
if (matchPattern(scatterIndices, m_Constant(&scatterIndicesAttr))) {
auto denseAttr = scatterIndicesAttr.dyn_cast<DenseIntElementsAttr>();
// Guard: dyn_cast returns null if the constant is not dense integers.
if (!denseAttr)
return failure();

auto shape = scatterIndices.getType().cast<ShapedType>().getShape();
if (shape.empty())
return failure();

int64_t numTuples = 1;
for (int64_t i = 0; i < shape.size() - 1; ++i) {
numTuples *= shape[i];
}
int64_t tupleSize = shape.back();

// Iterate over the scatter indices tensor to extract tuples
SmallVector<SmallVector<int64_t>> indexTuples;
auto values = denseAttr.getValues<APInt>();
auto it = values.begin();
for (int64_t i = 0; i < numTuples; ++i) {
SmallVector<int64_t> indexTuple;
for (int64_t j = 0; j < tupleSize; ++j) {
if (it == values.end()) {
return failure(); // Unexpected end of values
}
indexTuple.push_back((*it).getSExtValue());
++it;
}
indexTuples.push_back(indexTuple);
}

if (areIndexTuplesUnique(indexTuples)) {
auto newOp = rewriter.create<stablehlo::ScatterOp>(
op.getLoc(), op.getResultTypes(), op.getInputs(),
op.getScatterIndices(), op.getUpdates(),
op.getScatterDimensionNumbers(), op.getIndicesAreSortedAttr(),
rewriter.getBoolAttr(true));
newOp.getUpdateComputation().takeBody(op.getUpdateComputation());
rewriter.replaceOp(op, newOp);
return success();
}
}

return failure();
}

private:
bool areIndexTuplesUnique(
const SmallVector<SmallVector<int64_t>> &indexTuples) const {
for (size_t i = 0; i < indexTuples.size(); ++i) {
for (size_t j = i + 1; j < indexTuples.size(); ++j) {
if (indexTuples[i] == indexTuples[j])
return false;
}
}
return true;
}
};

// Reorders chains of a binary op that is both associative and commutative so
// that the common subexpression is computed only once:
// Case 1: (op x (op (op y x) y)) -> (op (op y x) (op y x))
// Case 2: (op x (op (op x y) y)) -> (op (op x y) (op x y))
// Case 3: (op x (op y (op x y))) -> (op (op x y) (op x y))
// Case 4: (op x (op y (op y x))) -> (op (op y x) (op y x))
template <typename Op>
struct AssociativeBinaryOpReordering : public OpRewritePattern<Op> {
using OpRewritePattern<Op>::OpRewritePattern;

LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final {
auto lhs = op.getLhs();
auto rhsOp = op.getRhs().template getDefiningOp<Op>();
if (!rhsOp)
return failure();

auto rhslhs = rhsOp.getLhs();
auto rhsrhs = rhsOp.getRhs();

auto rhslhsOp = rhslhs.template getDefiningOp<Op>();
if (rhslhsOp) {
auto rhslhslhs = rhslhsOp.getLhs();
auto rhslhsrhs = rhslhsOp.getRhs();

// Case 1
if (lhs == rhslhsrhs && rhslhslhs == rhsrhs) {
rewriter.replaceOpWithNewOp<Op>(op, rhslhsOp.getResult(),
rhslhsOp.getResult());
return success();
}

// Case 2
if (lhs == rhslhslhs && rhslhsrhs == rhsrhs) {
rewriter.replaceOpWithNewOp<Op>(op, rhslhsOp.getResult(),
rhslhsOp.getResult());
return success();
}
}

auto rhsrhsOp = rhsrhs.template getDefiningOp<Op>();
if (rhsrhsOp) {
auto rhsrhslhs = rhsrhsOp.getLhs();
auto rhsrhsrhs = rhsrhsOp.getRhs();

// Case 3
if (lhs == rhsrhslhs && rhslhs == rhsrhsrhs) {
rewriter.replaceOpWithNewOp<Op>(op, rhsrhsOp.getResult(),
rhsrhsOp.getResult());
return success();
}

// Case 4
if (lhs == rhsrhsrhs && rhslhs == rhsrhslhs) {
rewriter.replaceOpWithNewOp<Op>(op, rhsrhsOp.getResult(),
rhsrhsOp.getResult());
return success();
}
}

return failure();
}
};

/////////////// End Imported from stablehlo

#include "src/enzyme_ad/jax/Passes/EnzymeHLOPatterns.cpp.inc"
@@ -7645,7 +7781,13 @@ struct EnzymeHLOOptPass
TransposeUnaryTransposeSimplify<stablehlo::SignOp>,
TransposeUnaryTransposeSimplify<stablehlo::SineOp>,
TransposeUnaryTransposeSimplify<stablehlo::SqrtOp>,
-TransposeUnaryTransposeSimplify<stablehlo::TanhOp>>(context);
+TransposeUnaryTransposeSimplify<stablehlo::TanhOp>,
+AssociativeBinaryOpReordering<stablehlo::AddOp>,
+AssociativeBinaryOpReordering<stablehlo::MulOp>,
+AssociativeBinaryOpReordering<stablehlo::MinOp>,
+AssociativeBinaryOpReordering<stablehlo::MaxOp>,
+AssociativeBinaryOpReordering<stablehlo::AndOp>,
+AssociativeBinaryOpReordering<stablehlo::OrOp>>(context);

patterns.add<BinopPadToConcat<stablehlo::AddOp>,
BinopPadToConcat<stablehlo::MulOp>, ConcatPad>(context);
@@ -7767,7 +7909,8 @@ struct EnzymeHLOOptPass
CompareSelectSimplify,
NotSelectSimplify,
CommonCompareExpressionRewrite,
-ScatterUpdateComputationConstProp
+ScatterUpdateComputationConstProp,
+ScatterIndicesAreUnique
>(context);
// clang-format on
patterns.add<SelectOpCanon>(max_constant_expansion, context,
15 changes: 15 additions & 0 deletions src/enzyme_ad/jax/TransformOps/TransformOps.td
@@ -430,6 +430,16 @@ def ApplyBinaryOpTransposeSimplifyRemPatterns : EnzymeHLOPatternOp<
let patterns = ["BinaryOpTransposeSimplify<stablehlo::RemOp>"];
}

def ApplyAssociativeBinaryOpReorderingPatterns : EnzymeHLOPatternOp<
"associative_binary_op_reordering"> {
let patterns = ["AssociativeBinaryOpReordering<stablehlo::AddOp>",
"AssociativeBinaryOpReordering<stablehlo::MulOp>",
"AssociativeBinaryOpReordering<stablehlo::MinOp>",
"AssociativeBinaryOpReordering<stablehlo::MaxOp>",
"AssociativeBinaryOpReordering<stablehlo::AndOp>",
"AssociativeBinaryOpReordering<stablehlo::OrOp>"];
}

def ApplyTransposeUnaryTransposeAbsPatterns : EnzymeHLOPatternOp<
"transpose_unary_transpose_abs"> {
let patterns = ["TransposeUnaryTransposeSimplify<stablehlo::AbsOp>"];
@@ -863,6 +873,11 @@ def ApplyScatterUpdateComputationConstPropPatterns : EnzymeHLOPatternOp<
let patterns = ["ScatterUpdateComputationConstProp"];
}

def ApplyScatterIndicesAreUniquePatterns : EnzymeHLOPatternOp<
"scatter_indices_are_unique"> {
let patterns = ["ScatterIndicesAreUnique"];
}

// TODO: better naming for parameters requires a static interface for
// constructing them in search.

3 changes: 3 additions & 0 deletions src/enzyme_ad/jax/primitives.py
@@ -216,11 +216,14 @@ def hlo_opts():
binary_op_transpose_simplify_or<1>;
binary_op_transpose_simplify_xor<1>;
binary_op_transpose_simplify_rem<1>;
associative_binary_op_reordering<1>;
binop_const_simplify<1>;
compare_select_simplify;
common_compare_expression_rewrite;
not_select_simplify;
scatter_update_computation_const_prop;
scatter_indices_are_unique;
transpose_unary_transpose_abs<1>;
transpose_unary_transpose_neg<1>;
3 changes: 2 additions & 1 deletion test/lit_tests/diffrules/stablehlo/gather.mlir
@@ -19,7 +19,8 @@ module {
// REVERSE-NEXT: %c = stablehlo.constant dense<0> : tensor<45x1xi32>
// REVERSE-NEXT: %0 = "stablehlo.scatter"(%cst, %arg1, %arg2) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [1], inserted_window_dims = [0], scatter_dims_to_operand_dims = [0], index_vector_dim = 1>, unique_indices = false}> ({
// REVERSE-NEXT: ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
-// REVERSE-NEXT: stablehlo.return %arg4 : tensor<f32>
+// REVERSE-NEXT: %1 = stablehlo.add %arg3, %arg4 : tensor<f32>
+// REVERSE-NEXT: stablehlo.return %1 : tensor<f32>
// REVERSE-NEXT: }) : (tensor<64x3xf32>, tensor<45x1xi32>, tensor<45x3xf32>) -> tensor<64x3xf32>
// REVERSE-NEXT: return %0, %c : tensor<64x3xf32>, tensor<45x1xi32>
// REVERSE-NEXT: }
53 changes: 53 additions & 0 deletions test/lit_tests/reorderassociative.mlir
@@ -0,0 +1,53 @@
// RUN: enzymexlamlir-opt %s --enzyme-hlo-opt | FileCheck %s

func.func @main1(%arg0: tensor<10xf64>, %arg1: tensor<10xf64>) -> tensor<10xf64> {
%0 = stablehlo.add %arg0, %arg1 : tensor<10xf64>
%1 = stablehlo.add %arg1, %0 : tensor<10xf64>
%2 = stablehlo.add %arg0, %1 : tensor<10xf64>
return %2 : tensor<10xf64>
}

// CHECK: func.func @main1(%arg0: tensor<10xf64>, %arg1: tensor<10xf64>) -> tensor<10xf64> {
// CHECK-NEXT: %0 = stablehlo.add %arg0, %arg1 : tensor<10xf64>
// CHECK-NEXT: %1 = stablehlo.add %0, %0 : tensor<10xf64>
// CHECK-NEXT: return %1 : tensor<10xf64>
// CHECK-NEXT: }

func.func @main2(%arg0: tensor<10xf64>, %arg1: tensor<10xf64>) -> tensor<10xf64> {
%0 = stablehlo.add %arg1, %arg0 : tensor<10xf64>
%1 = stablehlo.add %0, %arg1 : tensor<10xf64>
%2 = stablehlo.add %arg0, %1 : tensor<10xf64>
return %2 : tensor<10xf64>
}

// CHECK: func.func @main2(%arg0: tensor<10xf64>, %arg1: tensor<10xf64>) -> tensor<10xf64> {
// CHECK-NEXT: %0 = stablehlo.add %arg1, %arg0 : tensor<10xf64>
// CHECK-NEXT: %1 = stablehlo.add %0, %0 : tensor<10xf64>
// CHECK-NEXT: return %1 : tensor<10xf64>
// CHECK-NEXT: }

func.func @main3(%arg0: tensor<10xf64>, %arg1: tensor<10xf64>) -> tensor<10xf64> {
%0 = stablehlo.multiply %arg1, %arg0 : tensor<10xf64>
%1 = stablehlo.multiply %0, %arg1 : tensor<10xf64>
%2 = stablehlo.multiply %arg0, %1 : tensor<10xf64>
return %2 : tensor<10xf64>
}

// CHECK: func.func @main3(%arg0: tensor<10xf64>, %arg1: tensor<10xf64>) -> tensor<10xf64> {
// CHECK-NEXT: %0 = stablehlo.multiply %arg1, %arg0 : tensor<10xf64>
// CHECK-NEXT: %1 = stablehlo.multiply %0, %0 : tensor<10xf64>
// CHECK-NEXT: return %1 : tensor<10xf64>
// CHECK-NEXT: }

func.func @main4(%arg0: tensor<10xf64>, %arg1: tensor<10xf64>) -> tensor<10xf64> {
%0 = stablehlo.multiply %arg1, %arg0 : tensor<10xf64>
%1 = stablehlo.multiply %arg0, %0 : tensor<10xf64>
%2 = stablehlo.multiply %arg1, %1 : tensor<10xf64>
return %2 : tensor<10xf64>
}

// CHECK: func.func @main4(%arg0: tensor<10xf64>, %arg1: tensor<10xf64>) -> tensor<10xf64> {
// CHECK-NEXT: %0 = stablehlo.multiply %arg1, %arg0 : tensor<10xf64>
// CHECK-NEXT: %1 = stablehlo.multiply %0, %0 : tensor<10xf64>
// CHECK-NEXT: return %1 : tensor<10xf64>
// CHECK-NEXT: }
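
The four functions above exercise Cases 1 and 3 of the pattern; hypothetical companions for Cases 2 and 4, in the same style and under the same RUN line (a sketch, not part of the commit), would look like:

func.func @main5(%arg0: tensor<10xf64>, %arg1: tensor<10xf64>) -> tensor<10xf64> {
  // Case 2: (op x (op (op x y) y))
  %0 = stablehlo.add %arg0, %arg1 : tensor<10xf64>
  %1 = stablehlo.add %0, %arg1 : tensor<10xf64>
  %2 = stablehlo.add %arg0, %1 : tensor<10xf64>
  return %2 : tensor<10xf64>
}

// CHECK: func.func @main5(%arg0: tensor<10xf64>, %arg1: tensor<10xf64>) -> tensor<10xf64> {
// CHECK-NEXT: %0 = stablehlo.add %arg0, %arg1 : tensor<10xf64>
// CHECK-NEXT: %1 = stablehlo.add %0, %0 : tensor<10xf64>
// CHECK-NEXT: return %1 : tensor<10xf64>
// CHECK-NEXT: }

func.func @main6(%arg0: tensor<10xf64>, %arg1: tensor<10xf64>) -> tensor<10xf64> {
  // Case 4: (op x (op y (op y x)))
  %0 = stablehlo.add %arg1, %arg0 : tensor<10xf64>
  %1 = stablehlo.add %arg1, %0 : tensor<10xf64>
  %2 = stablehlo.add %arg0, %1 : tensor<10xf64>
  return %2 : tensor<10xf64>
}

// CHECK: func.func @main6(%arg0: tensor<10xf64>, %arg1: tensor<10xf64>) -> tensor<10xf64> {
// CHECK-NEXT: %0 = stablehlo.add %arg1, %arg0 : tensor<10xf64>
// CHECK-NEXT: %1 = stablehlo.add %0, %0 : tensor<10xf64>
// CHECK-NEXT: return %1 : tensor<10xf64>
// CHECK-NEXT: }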
39 changes: 39 additions & 0 deletions test/lit_tests/scatteruniqueindices.mlir
@@ -0,0 +1,39 @@
// RUN: enzymexlamlir-opt %s --enzyme-hlo-opt | FileCheck %s

// CHECK-LABEL: func.func @test_scatter_duplicate
func.func @test_scatter_duplicate(%arg0: tensor<4x3xf32>, %arg2: tensor<3x3xf32>) -> tensor<4x3xf32> {
%indices = stablehlo.constant dense<[[0], [2], [0]]> : tensor<3x1xi32>
// CHECK: %{{.+}} = "stablehlo.scatter"(%{{.+}}, %{{.+}}, %{{.+}}) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [1], inserted_window_dims = [0], scatter_dims_to_operand_dims = [0], index_vector_dim = 1>, unique_indices = false}> ({
%0 = "stablehlo.scatter"(%arg0, %indices, %arg2) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [1], inserted_window_dims = [0], scatter_dims_to_operand_dims = [0], index_vector_dim = 1>, unique_indices = false}> ({
^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
%1 = stablehlo.multiply %arg3, %arg4 : tensor<f32>
stablehlo.return %1 : tensor<f32>
}) : (tensor<4x3xf32>, tensor<3x1xi32>, tensor<3x3xf32>) -> tensor<4x3xf32>
return %0 : tensor<4x3xf32>
}

// CHECK-LABEL: func.func @test_scatter_unique
func.func @test_scatter_unique(%arg0: tensor<3x3xf32>, %arg2: tensor<2x3xf32>) -> tensor<3x3xf32> {
%indices = stablehlo.constant dense<[[0], [1]]> : tensor<2x1xi32>
// CHECK: %{{.+}} = "stablehlo.scatter"(%{{.+}}, %{{.+}}, %{{.+}}) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [1], inserted_window_dims = [0], scatter_dims_to_operand_dims = [0], index_vector_dim = 1>, unique_indices = true}> ({
%0 = "stablehlo.scatter"(%arg0, %indices, %arg2) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [1], inserted_window_dims = [0], scatter_dims_to_operand_dims = [0], index_vector_dim = 1>, unique_indices = false}> ({
^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
%1 = stablehlo.multiply %arg3, %arg4 : tensor<f32>
stablehlo.return %1 : tensor<f32>
}) : (tensor<3x3xf32>, tensor<2x1xi32>, tensor<2x3xf32>) -> tensor<3x3xf32>
return %0 : tensor<3x3xf32>
}

// CHECK-LABEL: func.func @test_scatter_single
func.func @test_scatter_single(%arg0: tensor<2x3xf32>, %arg2: tensor<1x3xf32>) -> tensor<2x3xf32> {
%indices = stablehlo.constant dense<[[0]]> : tensor<1x1xi32>
%update = stablehlo.constant dense<1.0e+00> : tensor<1x3xf32>
// CHECK: %{{.+}} = "stablehlo.scatter"(%{{.+}}, %{{.+}}, %{{.+}}) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [1], inserted_window_dims = [0], scatter_dims_to_operand_dims = [0], index_vector_dim = 1>, unique_indices = true}> ({
%0 = "stablehlo.scatter"(%arg0, %indices, %update) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [1], inserted_window_dims = [0], scatter_dims_to_operand_dims = [0], index_vector_dim = 1>, unique_indices = false}> ({
^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
// CHECK-NOT: stablehlo.multiply
%1 = stablehlo.multiply %arg3, %arg4 : tensor<f32>
stablehlo.return %1 : tensor<f32>
}) : (tensor<2x3xf32>, tensor<1x1xi32>, tensor<1x3xf32>) -> tensor<2x3xf32>
return %0 : tensor<2x3xf32>
}
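
One case the new file does not exercise (a hypothetical negative test, not part of the commit): scatter indices that are not a compile-time constant. The pattern matches the indices with m_Constant, so with dynamic indices it cannot prove uniqueness and must leave unique_indices = false:

// CHECK-LABEL: func.func @test_scatter_dynamic
func.func @test_scatter_dynamic(%arg0: tensor<4x3xf32>, %arg1: tensor<3x1xi32>, %arg2: tensor<3x3xf32>) -> tensor<4x3xf32> {
  // CHECK: unique_indices = false
  %0 = "stablehlo.scatter"(%arg0, %arg1, %arg2) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [1], inserted_window_dims = [0], scatter_dims_to_operand_dims = [0], index_vector_dim = 1>, unique_indices = false}> ({
  ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
    %1 = stablehlo.multiply %arg3, %arg4 : tensor<f32>
    stablehlo.return %1 : tensor<f32>
  }) : (tensor<4x3xf32>, tensor<3x1xi32>, tensor<3x3xf32>) -> tensor<4x3xf32>
  return %0 : tensor<4x3xf32>
}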
