From 219796cc67acfcaeec3b8a4e0b3108f48b6eee22 Mon Sep 17 00:00:00 2001 From: Zero Rains Date: Mon, 23 Sep 2024 18:28:06 +0800 Subject: [PATCH] [Prim][PIR] Support dynamic shape for bmm op (#68357) * support dynamic shape for bmm op * fix the bug --- paddle/fluid/primitive/base/decomp_trans.cc | 3 +- paddle/fluid/primitive/composite/composite.h | 5 +-- .../test_prim_sub_graph_dynamic_shape.py | 38 +++++++++++++++++++ 3 files changed, 42 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/primitive/base/decomp_trans.cc b/paddle/fluid/primitive/base/decomp_trans.cc index ecb745f56083d..648ad61eb92c0 100644 --- a/paddle/fluid/primitive/base/decomp_trans.cc +++ b/paddle/fluid/primitive/base/decomp_trans.cc @@ -49,8 +49,9 @@ std::unordered_set decomp_op_contain_none = { "pd_op.instance_norm", }; // + std::unordered_set dynamic_shape_blacklist = { - "pd_op.squeeze", "pd_op.unsqueeze", "pd_op.bmm", "pd_op.flatten"}; + "pd_op.squeeze", "pd_op.unsqueeze", "pd_op.flatten"}; namespace { std::set StringSplit(const std::string& str) { diff --git a/paddle/fluid/primitive/composite/composite.h b/paddle/fluid/primitive/composite/composite.h index c8eb4dfa2e4b7..313f7ab50ac2e 100644 --- a/paddle/fluid/primitive/composite/composite.h +++ b/paddle/fluid/primitive/composite/composite.h @@ -330,7 +330,7 @@ Tensor bmm_decomp(const Tensor& x, const Tensor& y) { auto x_shape = phi::vectorize(x.dims()); auto y_shape = phi::vectorize(y.dims()); - if (x_shape[0] != y_shape[0]) { + if (x_shape[0] != y_shape[0] && x_shape[0] != -1 && y_shape[0] != -1) { PADDLE_THROW(common::errors::InvalidArgument( "Input(X) and Input(Y) must have the same batch size in BmmOp, " "but received X's batch size: [%s]," @@ -339,7 +339,7 @@ Tensor bmm_decomp(const Tensor& x, const Tensor& y) { y_shape[0])); } - if (x_shape[2] != y_shape[1]) { + if (x_shape[2] != y_shape[1] && x_shape[2] != -1 && y_shape[1] != -1) { PADDLE_THROW(common::errors::InvalidArgument( "Input(X)'s width must be equal with Input(Y)'s height in BmmOp," "but receive X's width: [%s]," @@ -347,7 +347,6 @@ Tensor bmm_decomp(const Tensor& x, const Tensor& y) { x_shape[2], y_shape[1])); } - return matmul(x, y, false, false); } diff --git a/test/prim/pir_prim/test_prim_sub_graph_dynamic_shape.py b/test/prim/pir_prim/test_prim_sub_graph_dynamic_shape.py index 7c244cda33e41..7319217a16e9c 100644 --- a/test/prim/pir_prim/test_prim_sub_graph_dynamic_shape.py +++ b/test/prim/pir_prim/test_prim_sub_graph_dynamic_shape.py @@ -128,6 +128,10 @@ def squared_l2_norm_net(x): return paddle._C_ops.squared_l2_norm(x) +def bmm_net(x, y): + return paddle.bmm(x, y) + + def elu_net(x): return paddle.nn.functional.elu(x, 1.0) @@ -586,6 +590,40 @@ def setUp(self): self.tol = 1e-6 +class TestPrimBmm1(TestPrimTwo): + def setUp(self): + np.random.seed(2023) + self.x_shape = [30, 40, 50] + self.y_shape = [30, 50, 60] + self.dtype_x = "float32" + self.dtype_y = "float32" + self.init_x_shape = [None, None, 50] + self.init_y_shape = [None, None, 60] + self.x = np.random.random(self.x_shape).astype(self.dtype_x) + self.y = np.random.random(self.y_shape).astype(self.dtype_y) + self.net = bmm_net + self.necessary_ops = "pd_op.bmm" + self.enable_cinn = False + self.tol = 1e-6 + + +class TestPrimBmm2(TestPrimTwo): + def setUp(self): + np.random.seed(2023) + self.x_shape = [30, 40, 50] + self.y_shape = [30, 50, 60] + self.dtype_x = "float32" + self.dtype_y = "float32" + self.init_x_shape = [30, None, None] + self.init_y_shape = [None, None, 60] + self.x = np.random.random(self.x_shape).astype(self.dtype_x) + self.y = np.random.random(self.y_shape).astype(self.dtype_y) + self.net = bmm_net + self.necessary_ops = "pd_op.bmm" + self.enable_cinn = False + self.tol = 1e-6 + + class TestPrimBceLoss(TestPrimTwo): def setUp(self): np.random.seed(2023)