Skip to content

Commit 338da1f

Browse files
[mlir-tensorrt] Add additional missing StableHLO patch
Added a patch that was missing from the last StableHLO upgrade. The patch addresses some of the issues mentioned in openxla/stablehlo#2634. An additional test has been added to mlir-tensorrt as a regression test.
1 parent b71a868 commit 338da1f

File tree

3 files changed

+140
-1
lines changed

3 files changed

+140
-1
lines changed

mlir-tensorrt/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,7 @@ if(MLIR_TRT_ENABLE_HLO AND NOT TARGET StablehloOps)
186186
GIT_REPOSITORY "https://github.com/openxla/stablehlo.git"
187187
PATCHES
188188
"${CMAKE_CURRENT_LIST_DIR}/build_tools/patches/stablehlo/0001-transforms-Fix-simplification-patterns-for-stablehlo.patch"
189+
"${CMAKE_CURRENT_LIST_DIR}/build_tools/patches/stablehlo/0002-Fix-a-couple-missing-checks-for-static-shapes-in-sta.patch"
189190
OPTIONS
190191
"STABLEHLO_ENABLE_BINDINGS_PYTHON ${MLIR_TRT_ENABLE_PYTHON}"
191192
"STABLEHLO_BUILD_EMBEDDED ON"
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
From fb0378d09cebb74da6ca253f6b41241a26bab43e Mon Sep 17 00:00:00 2001
2+
From: Christopher Bate <[email protected]>
3+
Date: Wed, 27 Nov 2024 00:10:11 +0000
4+
Subject: [PATCH] Fix a couple missing checks for static shapes in
5+
`stablehlo-aggressive-folder`
6+
7+
---
8+
.../stablehlo_aggressive_folder.mlir | 27 +++++++++++++------
9+
.../transforms/StablehloAggressiveFolder.cpp | 9 +++++++
10+
2 files changed, 28 insertions(+), 8 deletions(-)
11+
12+
diff --git a/stablehlo/tests/transforms/stablehlo_aggressive_folder.mlir b/stablehlo/tests/transforms/stablehlo_aggressive_folder.mlir
13+
index 5b21a10d..c90c89c6 100644
14+
--- a/stablehlo/tests/transforms/stablehlo_aggressive_folder.mlir
15+
+++ b/stablehlo/tests/transforms/stablehlo_aggressive_folder.mlir
16+
@@ -4,14 +4,17 @@
17+
// AddOp
18+
19+
// CHECK-LABEL: @add_fold_cst
20+
-func.func @add_fold_cst() -> (tensor<i32>, tensor<f32>) {
21+
+func.func @add_fold_cst() -> (tensor<i32>, tensor<f32>, tensor<?xf32>) {
22+
%cst = stablehlo.constant dense<1> : tensor<i32>
23+
%cst_1 = stablehlo.constant dense<1.0> : tensor<f32>
24+
+ %cst_2 = stablehlo.constant dense<2.0> : tensor<1xf32>
25+
// CHECK: stablehlo.constant dense<2> : tensor<i32>
26+
// CHECK: stablehlo.constant dense<2.0{{.*}}> : tensor<f32>
27+
+ // CHECK: stablehlo.add
28+
%0 = stablehlo.add %cst, %cst : tensor<i32>
29+
%1 = stablehlo.add %cst_1, %cst_1 : tensor<f32>
30+
- return %0, %1 : tensor<i32>, tensor<f32>
31+
+ %2 = stablehlo.add %cst_2, %cst_2 : (tensor<1xf32>, tensor<1xf32>) -> tensor<?xf32>
32+
+ return %0, %1, %2 : tensor<i32>, tensor<f32>, tensor<?xf32>
33+
}
34+
35+
// -----
36+
@@ -106,14 +109,17 @@ func.func @concatenate_fold() -> (tensor<6xi32>, tensor<3xi32>, tensor<3x3xi32>,
37+
// MulOp
38+
39+
// CHECK-LABEL: @mul_fold_cst
40+
-func.func @mul_fold_cst() -> (tensor<i32>, tensor<f32>) {
41+
+func.func @mul_fold_cst() -> (tensor<i32>, tensor<f32>, tensor<?xf32>) {
42+
%cst = stablehlo.constant dense<2> : tensor<i32>
43+
%cst_1 = stablehlo.constant dense<2.0> : tensor<f32>
44+
+ %cst_2 = stablehlo.constant dense<2.0> : tensor<1xf32>
45+
// CHECK: stablehlo.constant dense<4> : tensor<i32>
46+
// CHECK: stablehlo.constant dense<4.0{{.*}}> : tensor<f32>
47+
+ // CHECK: stablehlo.multiply
48+
%0 = stablehlo.multiply %cst, %cst : tensor<i32>
49+
%1 = stablehlo.multiply %cst_1, %cst_1 : tensor<f32>
50+
- return %0, %1 : tensor<i32>, tensor<f32>
51+
+ %2 = stablehlo.multiply %cst_2, %cst_2 : (tensor<1xf32>, tensor<1xf32>) -> tensor<?xf32>
52+
+ return %0, %1, %2 : tensor<i32>, tensor<f32>, tensor<?xf32>
53+
}
54+
55+
// -----
56+
@@ -122,16 +128,21 @@ func.func @mul_fold_cst() -> (tensor<i32>, tensor<f32>) {
57+
// SubtractOp
58+
59+
// CHECK-LABEL: @subtract_fold_cst
60+
-func.func @subtract_fold_cst() -> (tensor<i32>, tensor<f32>) {
61+
+func.func @subtract_fold_cst() -> (tensor<i32>, tensor<f32>, tensor<?xf32>) {
62+
%cst = stablehlo.constant dense<1> : tensor<i32>
63+
%cst_1 = stablehlo.constant dense<3> : tensor<i32>
64+
%cst_2 = stablehlo.constant dense<1.0> : tensor<f32>
65+
%cst_3 = stablehlo.constant dense<3.0> : tensor<f32>
66+
- // CHECK: stablehlo.constant dense<2> : tensor<i32>
67+
- // CHECK: stablehlo.constant dense<2.0{{.*}}> : tensor<f32>
68+
+ %cst_4 = stablehlo.constant dense<4.0> : tensor<1xf32>
69+
+ %cst_5 = stablehlo.constant dense<5.0> : tensor<1xf32>
70+
+ // CHECK: %[[V1:.+]] = stablehlo.constant dense<2> : tensor<i32>
71+
+ // CHECK: %[[V2:.+]] = stablehlo.constant dense<2.0{{.*}}> : tensor<f32>
72+
+ // CHECK: %[[V3:.+]] = stablehlo.subtract
73+
+ // CHECK: return %[[V1]], %[[V2]], %[[V3]]
74+
%0 = stablehlo.subtract %cst_1, %cst : tensor<i32>
75+
%1 = stablehlo.subtract %cst_3, %cst_2 : tensor<f32>
76+
- return %0, %1 : tensor<i32>, tensor<f32>
77+
+ %2 = stablehlo.subtract %cst_4, %cst_5 : (tensor<1xf32>, tensor<1xf32>) -> tensor<?xf32>
78+
+ return %0, %1, %2 : tensor<i32>, tensor<f32>, tensor<?xf32>
79+
}
80+
81+
// -----
82+
diff --git a/stablehlo/transforms/StablehloAggressiveFolder.cpp b/stablehlo/transforms/StablehloAggressiveFolder.cpp
83+
index a9107514..dadc14fb 100644
84+
--- a/stablehlo/transforms/StablehloAggressiveFolder.cpp
85+
+++ b/stablehlo/transforms/StablehloAggressiveFolder.cpp
86+
@@ -257,6 +257,9 @@ struct FoldAddOpPattern final : OpRewritePattern<mlir::stablehlo::AddOp> {
87+
88+
LogicalResult matchAndRewrite(mlir::stablehlo::AddOp op,
89+
PatternRewriter& rewriter) const override {
90+
+ if (failed(validateResultTypeForEval(rewriter, op, op.getType())))
91+
+ return failure();
92+
+
93+
Value lhs = op.getLhs();
94+
Value rhs = op.getRhs();
95+
96+
@@ -548,6 +551,9 @@ struct FoldMulOpPattern final : OpRewritePattern<mlir::stablehlo::MulOp> {
97+
98+
LogicalResult matchAndRewrite(mlir::stablehlo::MulOp op,
99+
PatternRewriter& rewriter) const override {
100+
+ if (failed(validateResultTypeForEval(rewriter, op, op.getType())))
101+
+ return failure();
102+
+
103+
auto elemType = op.getType().getElementType();
104+
Value lhs = op.getLhs();
105+
Value rhs = op.getRhs();
106+
@@ -747,6 +753,9 @@ struct FoldSubtractOpPattern final
107+
108+
LogicalResult matchAndRewrite(mlir::stablehlo::SubtractOp op,
109+
PatternRewriter& rewriter) const override {
110+
+ if (failed(validateResultTypeForEval(rewriter, op, op.getType())))
111+
+ return failure();
112+
+
113+
Value lhs = op.getLhs();
114+
Value rhs = op.getRhs();
115+
116+
--
117+
2.47.0
118+

mlir-tensorrt/test/Dialect/Plan/segmentation-pipeline.mlir

+21-1
Original file line numberDiff line numberDiff line change
@@ -224,4 +224,24 @@ builtin.module @simple_gather_dynamic attributes {
224224
// CHECK-DAG: %[[v1:.+]] = stablehlo.reshape %[[v0]] : (tensor<i32>) -> tensor<1xi32>
225225
// CHECK-DAG: %[[v2:.+]] = stablehlo.concatenate %[[c]], %[[v1]], %[[c_0]], %[[c_0]]
226226
// CHECK-DAG: %[[v3:.+]] = "stablehlo.dynamic_gather"(%[[arg1]], %[[arg0]], %[[v2]])
227-
// CHECK-DAG: return %[[v3]] : tensor<?x?x256x256xi32>
227+
// CHECK-DAG: return %[[v3]] : tensor<?x?x256x256xi32>
228+
229+
// -----
230+
231+
builtin.module attributes {
232+
plan.cluster_kinds = [
233+
#plan.tensorrt_cluster<benefit = 1, disallow_shape_tensor_calculations=true, tensorrt_major_version=10>,
234+
#plan.host_cluster<benefit = 0>
235+
]
236+
} {
237+
func.func @static_type_refinement() -> tensor<?x?xi32>{
238+
%c_0 = stablehlo.constant dense<1> : tensor<1x1xi32>
239+
%c_1 = stablehlo.constant dense<2> : tensor<1x1xi32>
240+
%0 = stablehlo.subtract %c_0, %c_1 : (tensor<1x1xi32>, tensor<1x1xi32>) -> tensor<?x?xi32>
241+
return %0 : tensor<?x?xi32>
242+
}
243+
}
244+
245+
// CHECK-LABEL: func.func @static_type_refinement() -> tensor<1x1xi32>
246+
// CHECK-LABEL: tensorrt.module
247+
// CHECK: stablehlo.subtract {{.*}} : tensor<1x1xi32>

0 commit comments

Comments
 (0)