diff --git a/cinn/frontend/net_builder.cc b/cinn/frontend/net_builder.cc
index f3e0f47385..cb43bd0d9d 100644
--- a/cinn/frontend/net_builder.cc
+++ b/cinn/frontend/net_builder.cc
@@ -246,17 +246,6 @@ Placeholder NetBuilder::CreateInput(const Variable& var) {
   return Placeholder(var);
 }
 
-Variable NetBuilder::FillConstant(
-    const std::vector<int>& shape, float value, const std::string& name, const std::string& dtype, bool force_cpu) {
-  auto out =
-      CustomInstr("fill_constant", {}, {{"shape", shape}, {"value", value}, {"dtype", dtype}, {"force_cpu", force_cpu}})
-          .front();
-  if (!name.empty()) {
-    out.set_id(cinn::utils::TransValidVarName(name));
-  }
-  return out;
-}
-
 Variable NetBuilder::FillConstant(const std::vector<int>& shape,
                                   const std::string& str_value,
                                   const std::string& name,
diff --git a/cinn/frontend/net_builder.h b/cinn/frontend/net_builder.h
index be7f34b95a..6fb8bcdbf4 100644
--- a/cinn/frontend/net_builder.h
+++ b/cinn/frontend/net_builder.h
@@ -350,7 +350,7 @@ class NetBuilder {
                     const std::string& id_hint = "");
 
   /**
-   * @brief Create constant tensor with the specific value/vector and type, the type is infered from value.
+   * @brief Create constant tensor with the specific value/vector and type
    * @param value The constant value to be set.
    * @param name The name of output variable.
    * @return The result variable.
@@ -408,11 +408,21 @@ class NetBuilder {
    * @param force_cpu Whether the variable should force placed in cpu, default in device memory. Default is false.
    * @return The result variable.
    */
+  template <typename T = float>
   Variable FillConstant(const cinn::utils::ShapeType& shape,
-                        float value,
+                        T value,
                         const std::string& name,
                         const std::string& dtype,
-                        bool force_cpu = false);
+                        bool force_cpu = false) {
+    auto out =
+        CustomInstr(
+            "fill_constant", {}, {{"shape", shape}, {"value", value}, {"dtype", dtype}, {"force_cpu", force_cpu}})
+            .front();
+    if (!name.empty()) {
+      out.set_id(cinn::utils::TransValidVarName(name));
+    }
+    return out;
+  }
 
   /**
    * @brief The op return a variable with the specific string value, shape and type.
@@ -442,7 +452,7 @@ class NetBuilder {
                         T value,
                         const std::string& name = "",
                         bool force_cpu = false) {
-    return FillConstant(shape, static_cast<float>(value), name, common::Type2Str(common::type_of<T>()), force_cpu);
+    return FillConstant<T>(shape, value, name, common::Type2Str(common::type_of<T>()), force_cpu);
   }
 
   /**
diff --git a/cinn/hlir/op/elementwise.cc b/cinn/hlir/op/elementwise.cc
index 1eafbe5ec3..598311837a 100644
--- a/cinn/hlir/op/elementwise.cc
+++ b/cinn/hlir/op/elementwise.cc
@@ -201,6 +201,7 @@ std::shared_ptr<OpStrategy> StrategyForConstScalar(const framework::NodeAttr &at
   framework::CINNCompute const_scalar_compute([=](lang::Args args, lang::RetValue *ret) {
     CHECK(!args.empty()) << "The input argument of const_float compute is empty! Please check.";
     auto scalar = GetScalarExpr(attrs.attr_store.at("value"));
+    auto scalar_type = out_type.at(0);
     CINNValuePack pack_args = args[0];
     std::string tensor_name = UniqName("const_scalar_Out");
     if (FLAGS_cinn_ir_schedule) {
@@ -210,7 +211,12 @@ std::shared_ptr<OpStrategy> StrategyForConstScalar(const framework::NodeAttr &at
     }
 
     auto out = lang::Compute(
-        {Expr(1)}, [=](const std::vector<Expr> &indice) { return scalar; }, tensor_name);
+        {Expr(1)},
+        [=](const std::vector<Expr> &indice) {
+          auto res = (scalar_type == scalar->type()) ? scalar : ir::Cast::Make(scalar_type, scalar);
+          return res;
+        },
+        tensor_name);
     CHECK(out.defined()) << "can't create const scalar with the given type " << out_type[0];
     auto stages = CreateStages({out});
     *ret = CINNValuePack{{CINNValue(out), CINNValue(stages)}};
@@ -229,9 +235,16 @@ std::vector<shape_t> InferShapeForConstScalar(const std::vector<shape_t> &inputs
 }
 
 std::vector<Type> InferDtypeForConstScalar(const std::vector<Type> &inputs_type, const framework::AttrMapType &attrs) {
-  CHECK(attrs.count("value"));
-  auto scalar   = GetScalarExpr(attrs.at("value"));
-  auto out_type = scalar->type();
+  Type out_type;
+  if (attrs.find("dtype") != attrs.end()) {
+    auto dtype_str = absl::get<std::string>(attrs.at("dtype"));
+    if (!dtype_str.empty()) {
+      out_type = common::Str2Type(dtype_str);
+    }
+  } else {
+    auto scalar = GetScalarExpr(attrs.at("value"));
+    out_type    = scalar->type();
+  }
   VLOG(3) << "scalar type: " << out_type;
   return {out_type};
 }
@@ -356,10 +369,10 @@ std::vector<std::vector<std::string>> InferLayoutForFillConstant(const std::vect
 
 #define EXPAND_ATTR_TYPE(MACRO) \
   MACRO(bool)                   \
-  MACRO(float)                  \
   MACRO(int)                    \
   MACRO(int64_t)                \
-  MACRO(double)
+  MACRO(double)                 \
+  MACRO(float)
 
 std::shared_ptr<OpStrategy> StrategyForAssignValue(const framework::NodeAttr &attrs,
                                                    const std::vector<ir::Tensor> &inputs,
diff --git a/cinn/pybind/frontend.cc b/cinn/pybind/frontend.cc
index 1cb719538e..ba0ae30c19 100644
--- a/cinn/pybind/frontend.cc
+++ b/cinn/pybind/frontend.cc
@@ -65,8 +65,6 @@ static const char *SnakeName(const char *name) {
 
 #define EXPAND_CINN_SUPPORT_TYPE(EXPAND_MACRO) \
   EXPAND_MACRO(bool)                           \
-  EXPAND_MACRO(float)                          \
-  EXPAND_MACRO(int)                            \
   EXPAND_MACRO(int64_t)                        \
   EXPAND_MACRO(double)
 
@@ -405,14 +403,23 @@ void BindFrontend(pybind11::module *m) {
 #undef EXPAND_QUINTIC_VECTOR
 #undef EXPAND_SEXTIC_VECTOR
 #undef PY_REGISTER_CONSTANT_OP
-#define PY_REGISTER_FILLCONSTANT_OP(TYPE__)                                                                    \
-  .def("fill_constant",                                                                                        \
-       static_cast<Variable (NetBuilder::*)(const std::vector<int> &, TYPE__, const std::string &, bool)>(    \
-           &NetBuilder::template FillConstant<TYPE__>),                                                       \
-       py::arg("shape"),                                                                                      \
-       py::arg("value"),                                                                                      \
-       py::arg("name") = "",                                                                                  \
+#define PY_REGISTER_FILLCONSTANT_OP(TYPE__)                                                                    \
+  .def("fill_constant",                                                                                        \
+       static_cast<Variable (NetBuilder::*)(                                                                   \
+           const std::vector<int> &, TYPE__, const std::string &, const std::string &, bool)>(                 \
+           &NetBuilder::FillConstant<TYPE__>),                                                                 \
+       py::arg("shape"),                                                                                       \
+       py::arg("value"),                                                                                       \
+       py::arg("name") = "",                                                                                   \
+       py::arg("dtype"),                                                                                       \
+       py::arg("force_cpu") = false)                                                                           \
+  .def("fill_constant",                                                                                        \
+       static_cast<Variable (NetBuilder::*)(const std::vector<int> &, TYPE__, const std::string &, bool)>(     \
+           &NetBuilder::template FillConstant<TYPE__>),                                                        \
+       py::arg("shape"),                                                                                       \
+       py::arg("value"),                                                                                       \
+       py::arg("name") = "",                                                                                   \
        py::arg("force_cpu") = false)
 EXPAND_CINN_SUPPORT_TYPE(PY_REGISTER_FILLCONSTANT_OP)
 #undef PY_REGISTER_FILLCONSTANT_OP
 
@@ -462,15 +469,6 @@ void BindFrontend(pybind11::module *m) {
       .def("name", &NetBuilder::name)
      .def("__str__", [](NetBuilder &self) { return self.name(); })
      .def("append_instruction", &NetBuilder::AppendInstruction, py::arg("instr"))
-      .def("fill_constant",
-           static_cast<Variable (NetBuilder::*)(
-               const std::vector<int> &, float, const std::string &, const std::string &, bool)>(
-               &NetBuilder::FillConstant),
-           py::arg("shape"),
-           py::arg("value"),
-           py::arg("name") = "",
-           py::arg("dtype"),
-           py::arg("force_cpu") = false)
       .def("fill_constant",
           static_cast<Variable (NetBuilder::*)(
              const std::vector<int> &, const std::string &, const std::string &, const std::string &, bool)>(
diff --git a/python/tests/ops/test_cbrt_op.py b/python/tests/ops/test_cbrt_op.py
index 14a6385ac9..1ca112cdf9 100644
--- a/python/tests/ops/test_cbrt_op.py
+++ b/python/tests/ops/test_cbrt_op.py
@@ -14,26 +14,26 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import unittest -import numpy as np -from op_test import OpTest, OpTestTool import paddle -import paddle.nn.functional as F -import cinn -from cinn.frontend import * +import numpy as np from cinn.common import * +from cinn.frontend import * +from op_test import OpTest, OpTestTool +from op_test_helper import TestCaseHelper @OpTestTool.skip_if(not is_compiled_with_cuda(), "x86 test will be skipped due to timeout.") class TestCbrtOp(OpTest): def setUp(self): - self.init_case() + print(f"\nRunning {self.__class__.__name__}: {self.case}") + self.inputs = {} + self.prepare_inputs() - def init_case(self): + def prepare_inputs(self): self.inputs = { - "x": np.array([0, 1, 0.01, 27, 1000000, - 0.970299]).astype("float32") + "x": + self.random(self.case["shape"], self.case["dtype"], -100.0, 100.0), } def build_paddle_program(self, target): @@ -43,44 +43,101 @@ def build_paddle_program(self, target): def build_cinn_program(self, target): builder = NetBuilder("cbrt") - x = builder.create_input(Float(32), self.inputs["x"].shape, "x") + x = builder.create_input( + self.nptype2cinntype(self.inputs["x"].dtype), + self.inputs["x"].shape, "x") out = builder.cbrt(x) prog = builder.build() res = self.get_cinn_output(prog, target, [x], [self.inputs["x"]], [out]) - self.cinn_outputs = [res[0]] + self.cinn_outputs = res def test_check_results(self): - self.check_outputs_and_grads() - - -class TestCbrtCase1(TestCbrtOp): - def init_case(self): - self.inputs = { - "x": - np.array([0, 1, 0.01, 27, 1000000, 0.970299, 124483, - 13.7396]).astype("float32") - } - - -class TestCbrtCase2(TestCbrtOp): - def init_case(self): - self.inputs = { - "x": - np.array([[0, 1, 0.01, 27], [1000000, 0.970299, 124483, - 13.7396]]).astype("float32"), - } - - -class TestCbrtCase3(TestCbrtOp): - def init_case(self): - np.random.seed(0) - self.inputs = { - "x": np.random.random((32, 64)).astype("float32"), - } + self.check_outputs_and_grads( + max_relative_error=1e-3, max_absolute_error=1e-3) + + +class TestCbrtOpShape(TestCaseHelper): + def init_attrs(self): + self.class_name = "TestCbrtOpShape" + self.cls = TestCbrtOp + self.inputs = [ + { + "shape": [10], + }, + { + "shape": [8, 5], + }, + { + "shape": [10, 3, 5], + }, + { + "shape": [80, 40, 5, 7], + }, + { + "shape": [80, 1, 5, 7], + }, + { + "shape": [80, 3, 1024, 7], + }, + { + "shape": [10, 5, 1024, 2048], + }, + { + "shape": [1], + }, + { + "shape": [512], + }, + { + "shape": [1024], + }, + { + "shape": [2048], + }, + { + "shape": [1, 1, 1, 1], + }, + ] + self.dtypes = [ + { + "dtype": "float32" + }, + ] + self.attrs = [] + + +class TestCbrtOpDtype(TestCaseHelper): + def init_attrs(self): + self.class_name = "TestCbrtOpDtype" + self.cls = TestCbrtOp + self.inputs = [ + { + "shape": [1], + }, + { + "shape": [5], + }, + { + "shape": [80, 40, 5, 7], + }, + ] + self.dtypes = [ + { + "dtype": "float16" + }, + { + "dtype": "float32" + }, + { + "dtype": "float64" + }, + ] + self.attrs = [] if __name__ == "__main__": - unittest.main() + TestCbrtOpShape().run() + TestCbrtOpDtype().run() diff --git a/python/tests/ops/test_ceil_op.py b/python/tests/ops/test_ceil_op.py index bf5a8189b1..1377849daf 100644 --- a/python/tests/ops/test_ceil_op.py +++ b/python/tests/ops/test_ceil_op.py @@ -14,27 +14,25 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import unittest -import numpy as np -from op_test import OpTest, OpTestTool import paddle -import cinn -from cinn.frontend import * from cinn.common import * +from cinn.frontend import * +from op_test import OpTest, OpTestTool +from op_test_helper import TestCaseHelper @OpTestTool.skip_if(not is_compiled_with_cuda(), "x86 test will be skipped due to timeout.") class TestCeilOp(OpTest): def setUp(self): - self.init_case() + print(f"\nRunning {self.__class__.__name__}: {self.case}") + self.inputs = {} + self.prepare_inputs() - def init_case(self): + def prepare_inputs(self): self.inputs = { - "x": np.random.random([ - 32, - 64, - ]).astype("float32") * 2 - 1 + "x": + self.random(self.case["shape"], self.case["dtype"], -100.0, 100.0), } def build_paddle_program(self, target): @@ -47,7 +45,9 @@ def build_paddle_program(self, target): # the forward result will be incorrect. def build_cinn_program(self, target): builder = NetBuilder("ceil") - x = builder.create_input(Float(32), self.inputs["x"].shape, "x") + x = builder.create_input( + self.nptype2cinntype(self.inputs["x"].dtype), + self.inputs["x"].shape, "x") out = builder.ceil(x) prog = builder.build() @@ -60,12 +60,85 @@ def test_check_results(self): self.check_outputs_and_grads() -class TestCeilCase1(TestCeilOp): - def init_case(self): - self.inputs = { - "x": np.random.random([10201, 50]).astype("float32") * 3 - 1 - } +class TestCeilOpShape(TestCaseHelper): + def init_attrs(self): + self.class_name = "TestCeilOpShape" + self.cls = TestCeilOp + self.inputs = [ + { + "shape": [10], + }, + { + "shape": [8, 5], + }, + { + "shape": [10, 3, 5], + }, + { + "shape": [80, 40, 5, 7], + }, + { + "shape": [80, 1, 5, 7], + }, + { + "shape": [80, 3, 1024, 7], + }, + { + "shape": [10, 5, 1024, 2048], + }, + { + "shape": [1], + }, + { + "shape": [512], + }, + { + "shape": [1024], + }, + { + "shape": [2048], + }, + { + "shape": [1, 1, 1, 1], + }, + ] + self.dtypes = [ + { + "dtype": "float32" + }, + ] + self.attrs = [] + + +class TestCeilOpDtype(TestCaseHelper): + def init_attrs(self): + self.class_name = "TestCeilOpDtype" + self.cls = TestCeilOp + self.inputs = [ + { + "shape": [1], + }, + { + "shape": [5], + }, + { + "shape": [80, 40, 5, 7], + }, + ] + self.dtypes = [ + { + "dtype": "float16" + }, + { + "dtype": "float32" + }, + { + "dtype": "float64" + }, + ] + self.attrs = [] if __name__ == "__main__": - unittest.main() + TestCeilOpShape().run() + TestCeilOpDtype().run() diff --git a/python/tests/ops/test_cholesky_op.py b/python/tests/ops/test_cholesky_op.py index 5ef0a9df93..d0396e0abb 100644 --- a/python/tests/ops/test_cholesky_op.py +++ b/python/tests/ops/test_cholesky_op.py @@ -14,27 +14,38 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import unittest import numpy as np -from op_test import OpTest, OpTestTool import paddle -import cinn -from cinn.frontend import * from cinn.common import * +from cinn.frontend import * +from op_test import OpTest, OpTestTool +from op_test_helper import TestCaseHelper @OpTestTool.skip_if(not is_compiled_with_cuda(), "x86 test will be skipped due to timeout.") class TestCholeskyOp(OpTest): def setUp(self): - self.init_case() - - def init_case(self): - matrix = self.random([3, 3], "float32") - matrix_t = np.transpose(matrix, [1, 0]) - x = np.dot(matrix, matrix_t) + print(f"\nRunning {self.__class__.__name__}: {self.case}") + self.inputs = {} + self.prepare_inputs() + + def prepare_inputs(self): + if "batch_dim" in self.case and self.case["batch_dim"] > 0: + x = [] + for _ in range(self.case["batch_dim"]): + matrix = self.random(self.case["shape"], self.case["dtype"], + -1.0, 1.0) + matrix_t = np.transpose(matrix, [1, 0]) + x.append(np.dot(matrix, matrix_t)) + x = np.stack(x) + else: + matrix = self.random(self.case["shape"], self.case["dtype"], -1.0, + 1.0) + matrix_t = np.transpose(matrix, [1, 0]) + x = np.dot(matrix, matrix_t) self.inputs = {"x": x} - self.upper = False + self.upper = self.case["upper"] def build_paddle_program(self, target): x = paddle.to_tensor(self.inputs["x"], stop_gradient=False) @@ -56,34 +67,165 @@ def test_check_results(self): self.check_outputs_and_grads() -class TestCholeskyCase1(TestCholeskyOp): - def init_case(self): - matrix = self.random([5, 5], "float64") - matrix_t = np.transpose(matrix, [1, 0]) - x = np.dot(matrix, matrix_t) - self.inputs = {"x": x} - self.upper = True - - -class TestCholeskyCase2(TestCholeskyOp): - def init_case(self): - matrix = self.random([3, 3], "float32") - matrix_t = np.transpose(matrix, [1, 0]) - x = np.dot(matrix, matrix_t) - x = x * np.ones(shape=(3, 3, 3)) - self.inputs = {"x": x} - self.upper = False - - -class TestCholeskyCase3(TestCholeskyOp): - def init_case(self): - matrix = self.random([3, 3], "float64") - matrix_t = np.transpose(matrix, [1, 0]) - x = np.dot(matrix, matrix_t) - x = x * np.ones(shape=(2, 3, 3, 3)) - self.inputs = {"x": x} - self.upper = True +class TestCholeskyOpShape(TestCaseHelper): + def init_attrs(self): + self.class_name = "TestCholeskyOpShape" + self.cls = TestCholeskyOp + self.inputs = [ + { + "shape": [1, 1], + }, + { + "shape": [8, 8], + }, + { + "shape": [10, 10], + }, + ] + self.dtypes = [ + { + "dtype": "float32" + }, + ] + self.attrs = [ + { + "upper": False + }, + ] + + +class TestCholeskyOpLargeShape(TestCaseHelper): + def init_attrs(self): + self.class_name = "TestCholeskyOpLargeShape" + self.cls = TestCholeskyOp + self.inputs = [ + { + "shape": [1024, 1024], + }, + { + "shape": [2048, 2048], + }, + ] + self.dtypes = [ + { + "dtype": "float64" + }, + ] + self.attrs = [ + { + "upper": False, + "batch_dim": 2 + }, + { + "upper": False, + "batch_dim": 4 + }, + { + "upper": True, + "batch_dim": 8 + }, + ] + + +class TestCholeskyOpDtype(TestCaseHelper): + def init_attrs(self): + self.class_name = "TestCholeskyOpDtype" + self.cls = TestCholeskyOp + self.inputs = [ + { + "shape": [1, 1], + }, + { + "shape": [8, 8], + }, + { + "shape": [10, 10], + }, + ] + self.dtypes = [ + { + "dtype": "float32" + }, + { + "dtype": "float64" + }, + ] + self.attrs = [ + { + "upper": False + }, + ] + + +class TestCholeskyOpBatch(TestCaseHelper): + def init_attrs(self): + self.class_name = "TestCholeskyOpBatch" + self.cls = TestCholeskyOp + self.inputs = [ + { + "shape": [1, 1], + }, + { + "shape": [8, 8], + }, + { + 
"shape": [10, 10], + }, + ] + self.dtypes = [ + { + "dtype": "float32" + }, + ] + self.attrs = [ + { + "upper": False, + "batch_dim": 1 + }, + { + "upper": False, + "batch_dim": 4 + }, + { + "upper": False, + "batch_dim": 8 + }, + ] + + +class TestCholeskyOpAttrs(TestCaseHelper): + def init_attrs(self): + self.class_name = "TestCholeskyOpAttrs" + self.cls = TestCholeskyOp + self.inputs = [ + { + "shape": [1, 1], + }, + { + "shape": [8, 8], + }, + { + "shape": [10, 10], + }, + ] + self.dtypes = [ + { + "dtype": "float32" + }, + { + "dtype": "float64" + }, + ] + self.attrs = [ + { + "upper": True, + }, + ] if __name__ == "__main__": - unittest.main() + TestCholeskyOpShape().run() + TestCholeskyOpLargeShape().run() + TestCholeskyOpDtype().run() + TestCholeskyOpBatch().run() + TestCholeskyOpAttrs().run() diff --git a/python/tests/ops/test_concat_op.py b/python/tests/ops/test_concat_op.py index c1d68a45a6..23816e0774 100755 --- a/python/tests/ops/test_concat_op.py +++ b/python/tests/ops/test_concat_op.py @@ -14,27 +14,29 @@ # See the License for the specific language governing permissions and # limitations under the License. -import unittest -import numpy as np -from op_test import OpTest, OpTestTool import paddle -import cinn -from cinn.frontend import * from cinn.common import * +from cinn.frontend import * +from op_test import OpTest, OpTestTool +from op_test_helper import TestCaseHelper @OpTestTool.skip_if(not is_compiled_with_cuda(), "x86 test will be skipped due to timeout.") class TestConcatOp(OpTest): def setUp(self): - self.init_case() + print(f"\nRunning {self.__class__.__name__}: {self.case}") + self.inputs = {} + self.prepare_inputs() - def init_case(self): - self.inputs = { - "x1": np.random.random([10201, 50]).astype("float32"), - "x2": np.random.random((10201, 50)).astype("float32") - } - self.axis = 0 + def prepare_inputs(self): + self.inputs = {} + self.axis = self.case["axis"] + dtype = self.case["dtype"] + shapes = self.case["shapes"] + for i, shape in enumerate(shapes): + name = "x" + str(i) + self.inputs[name] = self.random(shape, dtype) def paddle_inputs(self, inputs): return [ @@ -44,7 +46,8 @@ def paddle_inputs(self, inputs): def cinn_inputs(self, builder, inputs): return [ - builder.create_input(Float(32), data.shape, name) + builder.create_input( + self.nptype2cinntype(data.dtype), data.shape, name) for name, data in inputs.items() ] @@ -73,44 +76,287 @@ def test_check_results(self): self.check_outputs_and_grads(all_equal=True) -class TestConcatCase1(TestConcatOp): - def init_case(self): - self.inputs = { - "x1": np.random.random([4, 3]).astype("float32"), - "x2": np.random.random((8, 3)).astype("float32") - } - self.axis = 0 +class TestConcatOpShape(TestCaseHelper): + def init_attrs(self): + self.class_name = "TestConcatOpShape" + self.cls = TestConcatOp + self.inputs = [ + { + "shapes": [[10], [6]], + }, + { + "shapes": [[8, 5], [8, 5]], + }, + { + "shapes": [[10, 3, 5], [4, 3, 5]], + }, + { + "shapes": [[80, 40, 5, 7], [20, 40, 5, 7]], + }, + { + "shapes": [[80, 1, 5, 7], [8, 1, 5, 7]], + }, + { + "shapes": [[80, 3, 1024, 7], [100, 3, 1024, 7]], + }, + { + "shapes": [[1, 5, 1024, 2048], [2, 5, 1024, 2048]], + }, + { + "shapes": [[1], [1]], + }, + { + "shapes": [[512], [512]], + }, + { + "shapes": [[1024], [512]], + }, + { + "shapes": [[2048], [4096]], + }, + { + "shapes": [[1, 1, 1, 1], [1, 1, 1, 1]], + }, + ] + self.dtypes = [ + { + "dtype": "float32" + }, + ] + self.attrs = [ + { + "axis": 0 + }, + ] -class TestConcatCase2(TestConcatOp): - def init_case(self): - 
self.inputs = { - "x1": np.random.random([2, 4, 8]).astype("float32"), - "x2": np.random.random((2, 4, 4)).astype("float32") - } - self.axis = -1 +class TestConcatOpDtype(TestCaseHelper): + def init_attrs(self): + self.class_name = "TestConcatOpDtype" + self.cls = TestConcatOp + self.inputs = [ + { + "shapes": [[10], [6]], + }, + { + "shapes": [[8, 5], [8, 5]], + }, + { + "shapes": [[10, 3, 5], [4, 3, 5]], + }, + { + "shapes": [[80, 40, 5, 7], [20, 40, 5, 7]], + }, + ] + self.dtypes = [ + { + "dtype": "float16" + }, + { + "dtype": "float32" + }, + { + "dtype": "float64" + }, + { + "dtype": "bool" + }, + { + "dtype": "uint8" + }, + { + "dtype": "int8" + }, + { + "dtype": "int32" + }, + { + "dtype": "int64" + }, + ] + self.attrs = [ + { + "axis": 0 + }, + ] -class TestConcatCase3(TestConcatOp): - def init_case(self): - self.inputs = { - "x1": np.random.random([2, 8, 4]).astype("float32"), - "x2": np.random.random((2, 4, 4)).astype("float32") - } - self.axis = 1 +class TestConcatOpMultipleInputs(TestCaseHelper): + def init_attrs(self): + self.class_name = "TestConcatOpMultipleInputs" + self.cls = TestConcatOp + self.inputs = [ + # 1D tensor with 1~4 inputs + { + "shapes": [[10]], + "axis": 0 + }, + { + "shapes": [[10], [6]], + "axis": 0 + }, + { + "shapes": [[10], [6], [8]], + "axis": 0 + }, + { + "shapes": [[10], [6], [10], [6]], + "axis": 0 + }, + # 2D tensor with 1~4 inputs + { + "shapes": [[8, 5]], + "axis": 1 + }, + { + "shapes": [[8, 5], [8, 8]], + "axis": 1 + }, + { + "shapes": [[8, 5], [8, 5], [16, 5]], + "axis": 0 + }, + { + "shapes": [[8, 5], [8, 5], [8, 5], [8, 5]], + "axis": 0 + }, + # 3D tensor with 1~4 inputs + { + "shapes": [[10, 3, 5]], + "axis": 0 + }, + { + "shapes": [[10, 3, 5], [10, 7, 5]], + "axis": 1 + }, + { + "shapes": [[10, 3, 5], [10, 3, 6], [10, 3, 7]], + "axis": 2 + }, + { + "shapes": [[10, 3, 5], [4, 3, 5], [2, 3, 5]], + "axis": 0 + }, + # 4D tensor with 1~4 inputs + { + "shapes": [[80, 1, 5, 7]], + "axis": 0 + }, + { + "shapes": [[80, 1, 5, 7], [80, 79, 5, 7]], + "axis": 1 + }, + { + "shapes": [[80, 1, 50, 7], [80, 1, 5, 7], [80, 1, 10, 7]], + "axis": 2 + }, + { + "shapes": [[80, 1, 5, 17], [80, 1, 5, 27], [80, 1, 5, 37], + [80, 1, 5, 47]], + "axis": + 3 + }, + ] + self.dtypes = [ + { + "dtype": "float32" + }, + ] + self.attrs = [] -class TestConcatCase5(TestConcatOp): - def init_case(self): - self.inputs = { - "x1": np.random.random([1, 16]).astype("float32"), - "x2": np.random.random([2, 16]).astype("float32"), - "x3": np.random.random([3, 16]).astype("float32"), - "x4": np.random.random([4, 16]).astype("float32"), - "x5": np.random.random([5, 16]).astype("float32") - } - self.axis = 0 +class TestConcatOpAttrs(TestCaseHelper): + def init_attrs(self): + self.class_name = "TestConcatOpAttrs" + self.cls = TestConcatOp + self.inputs = [ + # 1D tensor + { + "shapes": [[10], [8]], + "axis": 0 + }, + { + "shapes": [[10], [6]], + "axis": -1 + }, + # 2D tensor + { + "shapes": [[8, 5], [10, 5]], + "axis": 0 + }, + { + "shapes": [[8, 5], [8, 8]], + "axis": 1 + }, + # 3D tensor + { + "shapes": [[10, 3, 5], [10, 3, 5]], + "axis": 0 + }, + { + "shapes": [[10, 3, 5], [10, 7, 5]], + "axis": 1 + }, + { + "shapes": [[10, 3, 15], [10, 3, 5]], + "axis": 2 + }, + { + "shapes": [[10, 3, 7], [10, 3, 5]], + "axis": -1 + }, + { + "shapes": [[10, 3, 5], [10, 7, 5]], + "axis": -2 + }, + { + "shapes": [[10, 7, 5], [20, 7, 5]], + "axis": -3 + }, + # 4D tensor + { + "shapes": [[80, 1, 5, 7], [80, 1, 5, 7]], + "axis": 0 + }, + { + "shapes": [[80, 1, 5, 7], [80, 79, 5, 7]], + "axis": 1 + }, + 
{ + "shapes": [[80, 1, 5, 7], [80, 1, 10, 7]], + "axis": 2 + }, + { + "shapes": [[80, 1, 5, 7], [80, 1, 5, 7]], + "axis": 3 + }, + { + "shapes": [[80, 1, 5, 7], [80, 1, 5, 13]], + "axis": -1 + }, + { + "shapes": [[80, 1, 5, 7], [80, 1, 5, 7]], + "axis": -2 + }, + { + "shapes": [[80, 15, 5, 7], [80, 5, 5, 7]], + "axis": -3 + }, + { + "shapes": [[80, 1, 5, 7], [20, 1, 5, 7]], + "axis": -4 + }, + ] + self.dtypes = [ + { + "dtype": "float32" + }, + ] + self.attrs = [] if __name__ == "__main__": - unittest.main() + TestConcatOpShape().run() + TestConcatOpDtype().run() + TestConcatOpMultipleInputs().run() + TestConcatOpAttrs().run() diff --git a/python/tests/ops/test_constant_op.py b/python/tests/ops/test_constant_op.py index 6407d58512..424aa7c56a 100644 --- a/python/tests/ops/test_constant_op.py +++ b/python/tests/ops/test_constant_op.py @@ -14,99 +14,150 @@ # See the License for the specific language governing permissions and # limitations under the License. -import unittest -import numpy as np -from op_test import OpTest, OpTestTool import paddle -import paddle.nn.functional as F -import cinn -from cinn.frontend import * from cinn.common import * +from cinn.frontend import * +from op_test import OpTest, OpTestTool +from op_test_helper import TestCaseHelper @OpTestTool.skip_if(not is_compiled_with_cuda(), "x86 test will be skipped due to timeout.") class TestConstantOp(OpTest): def setUp(self): - self.init_case() - - def init_case(self): - self.value = 1.0 - self.name = 'x' - self.dtype = "float32" + print(f"\nRunning {self.__class__.__name__}: {self.case}") + self.inputs = {} + self.prepare_inputs() + + def prepare_inputs(self): + self.name = "x" + dtype = self.case["dtype"] + if "constant_value" in self.case: + if "bool" in dtype: + self.value = bool(self.case["constant_value"]) + elif "int" in dtype: + self.value = int(self.case["constant_value"]) + elif "float" in dtype: + self.value = float(self.case["constant_value"]) + else: + self.value = self.random(self.case["shape"], dtype).tolist() + self.dtype = dtype def build_paddle_program(self, target): x = paddle.to_tensor(self.value, dtype=self.dtype) - self.paddle_outputs = [x] def build_cinn_program(self, target): builder = NetBuilder("constant") x = builder.constant(self.value, self.name, self.dtype) - prog = builder.build() res = self.get_cinn_output(prog, target, [], [], [x]) - self.cinn_outputs = [res[0]] + self.cinn_outputs = res def test_check_results(self): self.check_outputs_and_grads(all_equal=True) -class TestConstantCase1(TestConstantOp): - def init_case(self): - self.value = [1.0] - self.name = 'x' - self.dtype = "float32" - - -class TestConstantCase2(TestConstantOp): - def init_case(self): - self.value = [1.0, 2.0, 3.0, 4.0, 5.0] - self.name = 'x' - self.dtype = "float32" - - -class TestConstantCase3(TestConstantOp): - def init_case(self): - self.value = [[1.0, 2.0], [3.0, 4.0]] - self.name = 'x' - self.dtype = "float32" - - -class TestConstantCase4(TestConstantOp): - def init_case(self): - self.value = [[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]] - self.name = 'x' - self.dtype = "float32" - - -class TestConstantCase5(TestConstantOp): - def init_case(self): - self.value = [[[1.0], [3.0]], [[5.0], [7.0]]] - self.name = 'x' - self.dtype = "float32" - - -class TestConstantCase6(TestConstantOp): - def init_case(self): - self.value = [[[1.0]]] - self.name = 'x' - self.dtype = "float32" - - -class TestConstantCase7(TestConstantOp): - def init_case(self): - self.value = self.random([200], "int32", 1, 1000).tolist() - 
self.name = 'x' - self.dtype = "int32" - - -class TestConstantCase8(TestConstantOp): - def init_case(self): - self.value = self.random([10], "int64", 1, 1000).tolist() - self.name = 'x' - self.dtype = "int64" +class TestConstantOpShape(TestCaseHelper): + def init_attrs(self): + self.class_name = "TestConstantOpShape" + self.cls = TestConstantOp + self.inputs = [ + { + "constant_value": 10, + }, + { + "constant_value": -5, + }, + { + "shape": [10], + }, + { + "shape": [8, 5], + }, + { + "shape": [10, 3, 5], + }, + { + "shape": [1, 2, 4, 8], + }, + # known issue: https://github.com/PaddlePaddle/CINN/pull/1453 + # The compilation time is particularly long for AssignValue op. + # { + # "shape": [16, 4, 8, 32], + # }, + { + "shape": [1], + }, + { + "shape": [512], + }, + { + "shape": [1024], + }, + # very slow for the shape 2048 + { + "shape": [2048], + }, + { + "shape": [1, 1, 1, 1], + }, + ] + self.dtypes = [ + { + "dtype": "float32" + }, + ] + self.attrs = [] + + +class TestConstantOpDtype(TestCaseHelper): + def init_attrs(self): + self.class_name = "TestConstantOpDtype" + self.cls = TestConstantOp + self.inputs = [ + { + "constant_value": 1, + }, + { + "shape": [10], + }, + { + "shape": [8, 5], + }, + { + "shape": [10, 3, 5], + }, + ] + self.dtypes = [ + { + "dtype": "float16" + }, + { + "dtype": "float32" + }, + { + "dtype": "float64" + }, + { + "dtype": "bool" + }, + { + "dtype": "uint8" + }, + { + "dtype": "int8" + }, + { + "dtype": "int32" + }, + { + "dtype": "int64" + }, + ] + self.attrs = [] if __name__ == "__main__": - unittest.main() + TestConstantOpShape().run() + TestConstantOpDtype().run() diff --git a/python/tests/ops/test_fill_constant_op.py b/python/tests/ops/test_fill_constant_op.py index a42fdea4b9..8f8a1bcb4b 100644 --- a/python/tests/ops/test_fill_constant_op.py +++ b/python/tests/ops/test_fill_constant_op.py @@ -14,145 +14,246 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import unittest -import numpy as np -from op_test import OpTest, OpTestTool import paddle -import paddle.nn.functional as F -import cinn -from cinn.frontend import * +import numpy as np from cinn.common import * +from cinn.frontend import * +from op_test import OpTest, OpTestTool +from op_test_helper import TestCaseHelper @OpTestTool.skip_if(not is_compiled_with_cuda(), "x86 test will be skipped due to timeout.") class TestFillConstantOp(OpTest): def setUp(self): - self.init_case() - - def init_case(self): - self.shape = [32] - self.value = 1.0 - self.dtype = "float32" + print(f"\nRunning {self.__class__.__name__}: {self.case}") + self.inputs = {} + self.prepare_inputs() + + def prepare_inputs(self): + self.shape = self.case["shape"] + self.value = self.case["value"] + self.dtype = self.case["dtype"] + if isinstance(self.value, str): + dtypes = ["bool", "int", "float"] + for dtype in dtypes: + if dtype in self.dtype: + try: + self.value = eval(f"{dtype}(self.value)") + except: + self.value = eval(f"{dtype}(0)") def build_paddle_program(self, target): - x = paddle.full(self.shape, self.value, dtype=self.dtype) + if self.dtype == None: + x = np.full(self.shape, self.value) + x = paddle.to_tensor(x) + else: + x = paddle.full(self.shape, self.value, dtype=self.dtype) self.paddle_outputs = [x] def build_cinn_program(self, target): builder = NetBuilder("fill_constant") - x = builder.fill_constant(self.shape, self.value, "out", self.dtype) + if self.dtype == None: + x = builder.fill_constant(self.shape, self.value, "out") + else: + x = builder.fill_constant(self.shape, self.value, "out", + self.dtype) prog = builder.build() res = self.get_cinn_output(prog, target, [], [], [x]) - self.cinn_outputs = [res[0]] + self.cinn_outputs = res def test_check_results(self): self.check_outputs_and_grads(all_equal=True) -class TestFillConstantCase1(TestFillConstantOp): - def init_case(self): - self.shape = [10, 32, 4] - self.value = 1.0 - self.dtype = "float32" - - -class TestFillConstantCase2(TestFillConstantOp): - def init_case(self): - self.shape = [32] - self.value = 1 - self.dtype = "int32" - - -class TestFillConstantCase3(TestFillConstantOp): - def init_case(self): - self.shape = [32] - self.value = True - self.dtype = "bool" - - -class TestFillConstantCase4(TestFillConstantOp): - def init_case(self): - self.shape = [32] - self.value = int(1) - self.dtype = "uint8" - - -class TestFillConstantCase5(TestFillConstantOp): - def init_case(self): - self.shape = [32] - self.value = int(1) - self.dtype = "int16" - - -class TestFillConstantStringValue(TestFillConstantOp): - def init_case(self): - self.shape = [32] - self.value = "0.12345678987654321" - self.dtype = "float64" - - -class TestFillConstantStringValueCase1(TestFillConstantStringValue): - def init_case(self): - self.shape = [32] - self.value = "0.12345678987654321" - self.dtype = "float16" - - -class TestFillConstantStringValueCase2(TestFillConstantStringValue): - def init_case(self): - self.shape = [32] - self.value = "123456789" - self.dtype = "int64" - - -@OpTestTool.skip_if(not is_compiled_with_cuda(), - "x86 test will be skipped due to timeout.") -class TestFillConstantByValueOp(OpTest): - def setUp(self): - self.init_case() - - def init_case(self): - self.shape = [32] - self.value = float(1.0) - self.dtype = "float32" - - def build_paddle_program(self, target): - x = paddle.full(self.shape, self.value, dtype=self.dtype) - - self.paddle_outputs = [x] - - def build_cinn_program(self, target): - builder = NetBuilder("fill_constant") - x = 
builder.fill_constant(self.shape, self.value, "out") - - prog = builder.build() - res = self.get_cinn_output(prog, target, [], [], [x]) - - self.cinn_outputs = [res[0]] - - def test_check_results(self): - self.check_outputs_and_grads(all_equal=True) - - -class TestFillConstantByValueCase1(TestFillConstantByValueOp): - def init_case(self): - self.shape = [32] - self.value = int(1) - # only for paddle.full - self.dtype = "int32" - - -class TestFillConstantByValueCase2(TestFillConstantByValueOp): - def init_case(self): - self.shape = [32] - self.value = bool(True) - # only for paddle.full - self.dtype = "bool" +class TestFillConstantOpShape(TestCaseHelper): + def init_attrs(self): + self.class_name = "TestFillConstantOpShape" + self.cls = TestFillConstantOp + self.inputs = [ + { + "shape": [10], + }, + { + "shape": [8, 5], + }, + { + "shape": [10, 3, 5], + }, + { + "shape": [1, 2, 4, 8], + }, + { + "shape": [16, 4, 8, 32], + }, + { + "shape": [1], + }, + { + "shape": [512], + }, + { + "shape": [1024], + }, + { + "shape": [2048], + }, + { + "shape": [1, 1, 1, 1], + }, + ] + self.dtypes = [ + { + "dtype": "float32" + }, + ] + self.attrs = [ + { + "value": 123.456 + }, + ] + + +class TestFillConstantOpDtype(TestCaseHelper): + def init_attrs(self): + self.class_name = "TestFillConstantOpDtype" + self.cls = TestFillConstantOp + self.inputs = [ + { + "shape": [10], + }, + { + "shape": [8, 5], + }, + { + "shape": [10, 3, 5], + }, + { + "shape": [1, 2, 4, 8], + }, + ] + self.dtypes = [ + { + "dtype": "float16" + }, + { + "dtype": "float32" + }, + { + "dtype": "float64" + }, + { + "dtype": "bool" + }, + { + "dtype": "uint8" + }, + { + "dtype": "int32" + }, + { + "dtype": "int64" + }, + ] + self.attrs = [ + { + "value": 123.456 + }, + ] + + +class TestFillConstantOpValue(TestCaseHelper): + def init_attrs(self): + self.class_name = "TestFillConstantOpValue" + self.cls = TestFillConstantOp + self.inputs = [ + { + "shape": [10], + }, + { + "shape": [8, 5], + }, + { + "shape": [10, 3, 5], + }, + { + "shape": [1, 2, 4, 8], + }, + ] + self.dtypes = [ + { + "dtype": None + }, + ] + self.attrs = [ + { + "value": bool(True) + }, + { + "value": int(123) + }, + { + "value": float(123.456) + }, + ] + + +class TestFillConstantOpStrValue(TestCaseHelper): + def init_attrs(self): + self.class_name = "TestFillConstantOpStrValue" + self.cls = TestFillConstantOp + self.inputs = [ + { + "shape": [10], + }, + { + "shape": [8, 5], + }, + { + "shape": [10, 3, 5], + }, + { + "shape": [1, 2, 4, 8], + }, + ] + self.dtypes = [ + { + "dtype": "float16" + }, + { + "dtype": "float32" + }, + { + "dtype": "float64" + }, + { + "dtype": "bool" + }, + { + "dtype": "uint8" + }, + { + "dtype": "int32" + }, + { + "dtype": "int64" + }, + ] + self.attrs = [ + { + "value": "1024" + }, + { + "value": "0.12345678987654321" + }, + ] if __name__ == "__main__": - unittest.main() + TestFillConstantOpShape().run() + TestFillConstantOpDtype().run() + TestFillConstantOpValue().run() + TestFillConstantOpStrValue().run()