[Prim][PIR] Support dynamic shape for bmm op (#68357)
* support dynamic shape for bmm op

* fix the bug
zeroRains authored Sep 23, 2024
1 parent eb15b78 commit 219796c
Showing 3 changed files with 42 additions and 4 deletions.
3 changes: 2 additions & 1 deletion paddle/fluid/primitive/base/decomp_trans.cc
@@ -49,8 +49,9 @@ std::unordered_set<std::string> decomp_op_contain_none = {
     "pd_op.instance_norm",
 };
 //
 
 std::unordered_set<std::string> dynamic_shape_blacklist = {
-    "pd_op.squeeze", "pd_op.unsqueeze", "pd_op.bmm", "pd_op.flatten"};
+    "pd_op.squeeze", "pd_op.unsqueeze", "pd_op.flatten"};
 
 namespace {
 std::set<std::string> StringSplit(const std::string& str) {
5 changes: 2 additions & 3 deletions paddle/fluid/primitive/composite/composite.h
@@ -330,7 +330,7 @@ Tensor bmm_decomp(const Tensor& x, const Tensor& y) {
   auto x_shape = phi::vectorize(x.dims());
   auto y_shape = phi::vectorize(y.dims());
 
-  if (x_shape[0] != y_shape[0]) {
+  if (x_shape[0] != y_shape[0] && x_shape[0] != -1 && y_shape[0] != -1) {
     PADDLE_THROW(common::errors::InvalidArgument(
         "Input(X) and Input(Y) must have the same batch size in BmmOp, "
         "but received X's batch size: [%s],"
@@ -339,15 +339,14 @@ Tensor bmm_decomp(const Tensor& x, const Tensor& y) {
         y_shape[0]));
   }
 
-  if (x_shape[2] != y_shape[1]) {
+  if (x_shape[2] != y_shape[1] && x_shape[2] != -1 && y_shape[1] != -1) {
     PADDLE_THROW(common::errors::InvalidArgument(
         "Input(X)'s width must be equal with Input(Y)'s height in BmmOp,"
         "but receive X's width: [%s],"
         "Y's height: [%s].",
         x_shape[2],
         y_shape[1]));
   }
-
   return matmul<T>(x, y, false, false);
 }
 
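With pd_op.bmm removed from the dynamic_shape_blacklist above, this decomposition can now run on programs where some dimensions are still symbolic (-1). As a minimal sketch of the relaxed check (plain Python with illustrative names, not part of the Paddle codebase), only dimensions that are statically known on both sides can trigger an error; anything marked -1 is deferred to runtime:

```python
# Illustrative sketch, not Paddle code: mirrors the relaxed shape checks in
# bmm_decomp, where -1 marks a dimension that is unknown until runtime.
def check_bmm_shapes(x_shape, y_shape):
    batch_x, _, width_x = x_shape
    batch_y, height_y, _ = y_shape
    # Reject only when both sides are statically known and disagree.
    if batch_x != batch_y and batch_x != -1 and batch_y != -1:
        raise ValueError(f"batch size mismatch: {batch_x} vs {batch_y}")
    if width_x != height_y and width_x != -1 and height_y != -1:
        raise ValueError(f"X width {width_x} != Y height {height_y}")


check_bmm_shapes([30, 40, 50], [30, 50, 60])  # ok: all dims match
check_bmm_shapes([-1, 40, 50], [30, 50, 60])  # ok: batch deferred to runtime
```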
38 changes: 38 additions & 0 deletions test/prim/pir_prim/test_prim_sub_graph_dynamic_shape.py
@@ -128,6 +128,10 @@ def squared_l2_norm_net(x):
     return paddle._C_ops.squared_l2_norm(x)
 
 
+def bmm_net(x, y):
+    return paddle.bmm(x, y)
+
+
 def elu_net(x):
     return paddle.nn.functional.elu(x, 1.0)
 
@@ -586,6 +590,40 @@ def setUp(self):
         self.tol = 1e-6
 
 
+class TestPrimBmm1(TestPrimTwo):
+    def setUp(self):
+        np.random.seed(2023)
+        self.x_shape = [30, 40, 50]
+        self.y_shape = [30, 50, 60]
+        self.dtype_x = "float32"
+        self.dtype_y = "float32"
+        self.init_x_shape = [None, None, 50]
+        self.init_y_shape = [None, None, 60]
+        self.x = np.random.random(self.x_shape).astype(self.dtype_x)
+        self.y = np.random.random(self.y_shape).astype(self.dtype_y)
+        self.net = bmm_net
+        self.necessary_ops = "pd_op.bmm"
+        self.enable_cinn = False
+        self.tol = 1e-6
+
+
+class TestPrimBmm2(TestPrimTwo):
+    def setUp(self):
+        np.random.seed(2023)
+        self.x_shape = [30, 40, 50]
+        self.y_shape = [30, 50, 60]
+        self.dtype_x = "float32"
+        self.dtype_y = "float32"
+        self.init_x_shape = [30, None, None]
+        self.init_y_shape = [None, None, 60]
+        self.x = np.random.random(self.x_shape).astype(self.dtype_x)
+        self.y = np.random.random(self.y_shape).astype(self.dtype_y)
+        self.net = bmm_net
+        self.necessary_ops = "pd_op.bmm"
+        self.enable_cinn = False
+        self.tol = 1e-6
+
+
 class TestPrimBceLoss(TestPrimTwo):
     def setUp(self):
         np.random.seed(2023)
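For context on what the None entries in init_x_shape / init_y_shape exercise, here is a standalone sketch of converting bmm_net to a static graph with dynamic dimensions via paddle.static.InputSpec. This is an assumption about the setup, not the actual TestPrimTwo code path (the harness lives elsewhere in this test file); the point is that the symbolic dims are what force bmm_decomp to cope with -1 values:

```python
import numpy as np
import paddle
from paddle.static import InputSpec


def bmm_net(x, y):
    return paddle.bmm(x, y)


# None entries play the same role as the None dims in init_x_shape /
# init_y_shape above: those dimensions stay symbolic in the compiled
# program, so the decomposition must tolerate unknown batch/width values.
static_net = paddle.jit.to_static(
    bmm_net,
    input_spec=[
        InputSpec(shape=[None, None, 50], dtype="float32"),
        InputSpec(shape=[None, None, 60], dtype="float32"),
    ],
)

x = paddle.to_tensor(np.random.random([30, 40, 50]).astype("float32"))
y = paddle.to_tensor(np.random.random([30, 50, 60]).astype("float32"))
out = static_net(x, y)
print(out.shape)  # [30, 40, 60]
```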
