From fb0378d09cebb74da6ca253f6b41241a26bab43e Mon Sep 17 00:00:00 2001
From: Christopher Bate <[email protected]>
Date: Wed, 27 Nov 2024 00:10:11 +0000
Subject: [PATCH] Fix a couple missing checks for static shapes in
 `stablehlo-aggressive-folder`

---
 .../stablehlo_aggressive_folder.mlir         | 27 +++++++++++++------
 .../transforms/StablehloAggressiveFolder.cpp |  9 +++++++
 2 files changed, 28 insertions(+), 8 deletions(-)

diff --git a/stablehlo/tests/transforms/stablehlo_aggressive_folder.mlir b/stablehlo/tests/transforms/stablehlo_aggressive_folder.mlir
index 5b21a10d..c90c89c6 100644
--- a/stablehlo/tests/transforms/stablehlo_aggressive_folder.mlir
+++ b/stablehlo/tests/transforms/stablehlo_aggressive_folder.mlir
@@ -4,14 +4,17 @@
 // AddOp
 
 // CHECK-LABEL: @add_fold_cst
-func.func @add_fold_cst() -> (tensor<i32>, tensor<f32>) {
+func.func @add_fold_cst() -> (tensor<i32>, tensor<f32>, tensor<?xf32>) {
   %cst = stablehlo.constant dense<1> : tensor<i32>
   %cst_1 = stablehlo.constant dense<1.0> : tensor<f32>
+  %cst_2 = stablehlo.constant dense<2.0> : tensor<1xf32>
   // CHECK: stablehlo.constant dense<2> : tensor<i32>
   // CHECK: stablehlo.constant dense<2.0{{.*}}> : tensor<f32>
+  // CHECK: stablehlo.add
   %0 = stablehlo.add %cst, %cst : tensor<i32>
   %1 = stablehlo.add %cst_1, %cst_1 : tensor<f32>
-  return %0, %1 : tensor<i32>, tensor<f32>
+  %2 = stablehlo.add %cst_2, %cst_2 : (tensor<1xf32>, tensor<1xf32>) -> tensor<?xf32>
+  return %0, %1, %2 : tensor<i32>, tensor<f32>, tensor<?xf32>
 }
 
 // -----
@@ -106,14 +109,17 @@ func.func @concatenate_fold() -> (tensor<6xi32>, tensor<3xi32>, tensor<3x3xi32>,
 // MulOp
 
 // CHECK-LABEL: @mul_fold_cst
-func.func @mul_fold_cst() -> (tensor<i32>, tensor<f32>) {
+func.func @mul_fold_cst() -> (tensor<i32>, tensor<f32>, tensor<?xf32>) {
   %cst = stablehlo.constant dense<2> : tensor<i32>
   %cst_1 = stablehlo.constant dense<2.0> : tensor<f32>
+  %cst_2 = stablehlo.constant dense<2.0> : tensor<1xf32>
   // CHECK: stablehlo.constant dense<4> : tensor<i32>
   // CHECK: stablehlo.constant dense<4.0{{.*}}> : tensor<f32>
+  // CHECK: stablehlo.multiply
   %0 = stablehlo.multiply %cst, %cst : tensor<i32>
   %1 = stablehlo.multiply %cst_1, %cst_1 : tensor<f32>
-  return %0, %1 : tensor<i32>, tensor<f32>
+  %2 = stablehlo.multiply %cst_2, %cst_2 : (tensor<1xf32>, tensor<1xf32>) -> tensor<?xf32>
+  return %0, %1, %2 : tensor<i32>, tensor<f32>, tensor<?xf32>
 }
 
 // -----
@@ -122,16 +128,21 @@ func.func @mul_fold_cst() -> (tensor<i32>, tensor<f32>) {
 // SubtractOp
 
 // CHECK-LABEL: @subtract_fold_cst
-func.func @subtract_fold_cst() -> (tensor<i32>, tensor<f32>) {
+func.func @subtract_fold_cst() -> (tensor<i32>, tensor<f32>, tensor<?xf32>) {
   %cst = stablehlo.constant dense<1> : tensor<i32>
   %cst_1 = stablehlo.constant dense<3> : tensor<i32>
   %cst_2 = stablehlo.constant dense<1.0> : tensor<f32>
   %cst_3 = stablehlo.constant dense<3.0> : tensor<f32>
-  // CHECK: stablehlo.constant dense<2> : tensor<i32>
-  // CHECK: stablehlo.constant dense<2.0{{.*}}> : tensor<f32>
+  %cst_4 = stablehlo.constant dense<4.0> : tensor<1xf32>
+  %cst_5 = stablehlo.constant dense<5.0> : tensor<1xf32>
+  // CHECK: %[[V1:.+]] = stablehlo.constant dense<2> : tensor<i32>
+  // CHECK: %[[V2:.+]] = stablehlo.constant dense<2.0{{.*}}> : tensor<f32>
+  // CHECK: %[[V3:.+]] = stablehlo.subtract
+  // CHECK: return %[[V1]], %[[V2]], %[[V3]]
   %0 = stablehlo.subtract %cst_1, %cst : tensor<i32>
   %1 = stablehlo.subtract %cst_3, %cst_2 : tensor<f32>
-  return %0, %1 : tensor<i32>, tensor<f32>
+  %2 = stablehlo.subtract %cst_4, %cst_5 : (tensor<1xf32>, tensor<1xf32>) -> tensor<?xf32>
+  return %0, %1, %2 : tensor<i32>, tensor<f32>, tensor<?xf32>
 }
 
 // -----
diff --git a/stablehlo/transforms/StablehloAggressiveFolder.cpp b/stablehlo/transforms/StablehloAggressiveFolder.cpp
index a9107514..dadc14fb 100644
--- a/stablehlo/transforms/StablehloAggressiveFolder.cpp
+++ b/stablehlo/transforms/StablehloAggressiveFolder.cpp
@@ -257,6 +257,9 @@ struct FoldAddOpPattern final : OpRewritePattern<mlir::stablehlo::AddOp> {
 
   LogicalResult matchAndRewrite(mlir::stablehlo::AddOp op,
                                 PatternRewriter& rewriter) const override {
+    if (failed(validateResultTypeForEval(rewriter, op, op.getType())))
+      return failure();
+
     Value lhs = op.getLhs();
     Value rhs = op.getRhs();
 
@@ -548,6 +551,9 @@ struct FoldMulOpPattern final : OpRewritePattern<mlir::stablehlo::MulOp> {
 
   LogicalResult matchAndRewrite(mlir::stablehlo::MulOp op,
                                 PatternRewriter& rewriter) const override {
+    if (failed(validateResultTypeForEval(rewriter, op, op.getType())))
+      return failure();
+
     auto elemType = op.getType().getElementType();
     Value lhs = op.getLhs();
     Value rhs = op.getRhs();
@@ -747,6 +753,9 @@ struct FoldSubtractOpPattern final
 
   LogicalResult matchAndRewrite(mlir::stablehlo::SubtractOp op,
                                 PatternRewriter& rewriter) const override {
+    if (failed(validateResultTypeForEval(rewriter, op, op.getType())))
+      return failure();
+
     Value lhs = op.getLhs();
     Value rhs = op.getRhs();
 
-- 
2.47.0

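Note: the patch only adds call sites for validateResultTypeForEval; its body is not
part of this diff. As a rough sketch (an assumption about its behavior, not the actual
StableHLO implementation), a guard of this kind rejects results that are not statically
shaped, since folding has to materialize a dense constant of the result type. That is
why the new tensor<?xf32> cases in the tests are expected to stay as stablehlo.add,
stablehlo.multiply, and stablehlo.subtract rather than fold. The helper name below is
hypothetical:

    // Illustrative sketch only; the real validateResultTypeForEval in
    // StablehloAggressiveFolder.cpp may differ. A folded result must be
    // representable as a dense constant, which requires a static shape.
    LogicalResult validateResultTypeForEvalSketch(PatternRewriter& rewriter,
                                                  Operation* op,
                                                  ShapedType resultType) {
      if (!resultType.hasStaticShape())
        return rewriter.notifyMatchFailure(
            op, "expected result type to have a static shape");
      return success();
    }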