From 416712a692964bbf00a257743c06c2c4d35071c7 Mon Sep 17 00:00:00 2001
From: BiynXu <62832681+BiynXu@users.noreply.github.com>
Date: Mon, 19 Jun 2023 14:03:25 +0800
Subject: [PATCH] Update release v03 (#1528)

* op unittest for cbrt/ceil/cholesky/concat/constant/fill_constant (#1495)
* op unittest for cbrt
* op unittest for ceil
* op unittest for cholesky
* op unittest for concat
* op unittest for constant
* add 4d test for constant op
* fix ci
* op unittest for fill_constant
* op unittest for fill_constant
* refine
* fix(schedule): fix SimpleComputeAt primitive (#1504)
* Fix reduce cast schedule bug (#1512)
* fix(fuse): fix reduce cast schedule bug
* test(fuse): add unittest for reduce_cast subgroup
* Refactor some op tests and fix bugs (#1515)
* Add depthwise_conv2d op test
* Refactor log op test
* Refactor round op test and fix bugs
* Only test depthwise_conv2d in cuda_cudnn
* op unittest for repeat/arange/reverse/elementwise_add_grad/flip (#1514)
* op unittest for repeat op
* add repeat frontend
* op unittest for repeat
* op unittest for arange
* op unittest for reverse
* format & remove old add op test
* op unittest for flip && remove redundant flip implementation
* remove test_add_op_new.py
* update reverse
* Refactor some op tests and fix bugs (#1513)
* Refactor op isclose test
* Refactor op logical_right_shift and add more dtypes support
* Refactor pow op test and fix bugs
* Refactor lookup_table op test
* Add logical_right_shift host function proto
* Improve isclose test case
* Fixed jitify commit to prevent header file conflicts (#1522)
* Fixed jitify commit to prevent header file conflicts
* Set random seed for debug floor_divide
* Avoid oom error
* Just for debug ci
* Fix floor_divide error when input dtype is int
* Fix bugs and add more tests for floor_divide
* Experimental PR for the first OP to clean old schedule (#1524)

---------

Co-authored-by: zzk0 <30856589+zzk0@users.noreply.github.com>
Co-authored-by: Fisher
Co-authored-by: Huihuang Zheng
---
 cinn/frontend/net_builder.cc | 17 +-
 cinn/frontend/net_builder.h | 23 +-
 cinn/frontend/net_builder_test.cc | 70 ----
 cinn/hlir/framework/op_lowering_util.cc | 63 +++-
 cinn/hlir/op/contrib/CMakeLists.txt | 2 -
 cinn/hlir/op/contrib/argmin.cc | 63 ++--
 cinn/hlir/op/contrib/flip.cc | 118 ------
 cinn/hlir/op/contrib/flip.h | 32 --
 cinn/hlir/op/contrib/flip_test.cc | 67 ----
 cinn/hlir/op/contrib/repeat_test.cc | 4 +-
 cinn/hlir/op/elementwise.cc | 25 +-
 cinn/hlir/op/transform.cc | 9 -
 cinn/hlir/op/use_ops.h | 1 -
 cinn/ir/ir_schedule.cc | 32 +-
 cinn/lang/builtin.cc | 12 +-
 cinn/pybind/frontend.cc | 38 +-
 cinn/runtime/cpu/host_intrinsics.cc | 20 +-
 cinn/runtime/cpu/host_intrinsics.h | 4 +
 .../runtime/cuda/cinn_cuda_runtime_source.cuh | 21 +-
 cinn/runtime/cuda/cuda_intrinsics.cc | 6 +-
 cmake/external/jitify.cmake | 2 +-
 python/tests/fusion/test_reduce_cast.py | 39 ++
 python/tests/ops/op_test_helper.py | 13 +-
 python/tests/ops/test_add_op.py | 263 ++++++++++----
 python/tests/ops/test_add_op_new.py | 271 --------------
 python/tests/ops/test_arange_op.py | 191 ++++++++++
 python/tests/ops/test_cbrt_op.py | 137 ++++---
 python/tests/ops/test_ceil_op.py | 109 +++++-
 python/tests/ops/test_cholesky_op.py | 220 ++++++++++--
 python/tests/ops/test_concat_op.py | 336 +++++++++++++++---
 python/tests/ops/test_constant_op.py | 191 ++++++----
 python/tests/ops/test_depthwise_conv2d_op.py | 192 ++++++++++
 python/tests/ops/test_fill_constant_op.py | 331 +++++++++++------
 python/tests/ops/test_floor_divide_op.py | 182 +++-------
python/tests/ops/test_isclose_op.py | 232 ++++++++---- python/tests/ops/test_log_op.py | 145 ++++++++ .../tests/ops/test_logical_right_shift_op.py | 126 ++++--- python/tests/ops/test_lookup_table_op.py | 99 ++++-- python/tests/ops/test_pow_op.py | 158 +++++--- python/tests/ops/test_repeat_op.py | 267 ++++++++++++++ python/tests/ops/test_reverse_op.py | 311 ++++++++++++++++ python/tests/ops/test_round_op.py | 112 ++++++ python/tests/ops/test_sign_op.py | 4 +- 43 files changed, 3150 insertions(+), 1408 deletions(-) delete mode 100644 cinn/hlir/op/contrib/flip.cc delete mode 100644 cinn/hlir/op/contrib/flip.h delete mode 100644 cinn/hlir/op/contrib/flip_test.cc create mode 100644 python/tests/fusion/test_reduce_cast.py delete mode 100644 python/tests/ops/test_add_op_new.py create mode 100644 python/tests/ops/test_arange_op.py create mode 100644 python/tests/ops/test_depthwise_conv2d_op.py create mode 100644 python/tests/ops/test_log_op.py create mode 100644 python/tests/ops/test_repeat_op.py create mode 100755 python/tests/ops/test_reverse_op.py create mode 100644 python/tests/ops/test_round_op.py diff --git a/cinn/frontend/net_builder.cc b/cinn/frontend/net_builder.cc index 07238acadc..0d04897d1d 100644 --- a/cinn/frontend/net_builder.cc +++ b/cinn/frontend/net_builder.cc @@ -246,17 +246,6 @@ Placeholder NetBuilder::CreateInput(const Variable& var) { return Placeholder(var); } -Variable NetBuilder::FillConstant( - const std::vector& shape, float value, const std::string& name, const std::string& dtype, bool force_cpu) { - auto out = - CustomInstr("fill_constant", {}, {{"shape", shape}, {"value", value}, {"dtype", dtype}, {"force_cpu", force_cpu}}) - .front(); - if (!name.empty()) { - out.set_id(cinn::utils::TransValidVarName(name)); - } - return out; -} - Variable NetBuilder::FillConstant(const std::vector& shape, const std::string& str_value, const std::string& name, @@ -827,11 +816,7 @@ Variable NetBuilder::Arange(const float start, const float stop, const float ste } Variable NetBuilder::Flip(const Variable& operand, const std::vector& axes) { - Instruction instr("flip", {operand}); - instr.SetAttr("axes", axes); - InferShape(instr); - AppendInstruction(instr); - return instr.GetOutput(0); + return CustomInstr("reverse", {operand}, {{"axis", utils::GetPositiveAxes(axes, operand->shape.size())}}).front(); } Variable NetBuilder::Matmul(const Variable& x, const Variable& y, bool trans_x, bool trans_y, float alpha) { diff --git a/cinn/frontend/net_builder.h b/cinn/frontend/net_builder.h index b798a7af60..b16b9a91a4 100644 --- a/cinn/frontend/net_builder.h +++ b/cinn/frontend/net_builder.h @@ -350,7 +350,7 @@ class NetBuilder { const std::string& id_hint = ""); /** - * @brief Create constant tensor with the specific value/vector and type, the type is infered from value. + * @brief Create constant tensor with the specific value/vector and type * @param value The constant value to be set. * @param name The name of output variable. * @return The result variable. @@ -408,11 +408,21 @@ class NetBuilder { * @param force_cpu Whether the variable should force placed in cpu, default in device memory. Default is false. * @return The result variable. 
*/ + template Variable FillConstant(const cinn::utils::ShapeType& shape, - float value, + T value, const std::string& name, const std::string& dtype, - bool force_cpu = false); + bool force_cpu = false) { + auto out = + CustomInstr( + "fill_constant", {}, {{"shape", shape}, {"value", value}, {"dtype", dtype}, {"force_cpu", force_cpu}}) + .front(); + if (!name.empty()) { + out.set_id(cinn::utils::TransValidVarName(name)); + } + return out; + } /** * @brief The op return a variable with the specific string value, shape and type. @@ -442,7 +452,7 @@ class NetBuilder { T value, const std::string& name = "", bool force_cpu = false) { - return FillConstant(shape, static_cast(value), name, common::Type2Str(common::type_of()), force_cpu); + return FillConstant(shape, value, name, common::Type2Str(common::type_of()), force_cpu); } /** @@ -891,7 +901,10 @@ class NetBuilder { const std::string& padding_algorithm = "EXPLICIT"); /** - * This API flipes the Variable x along the given axis. + * @brief This API reverse the Variable x along the given axis. + * @param x An N-D variable. + * @param axis Specify the axis to operate on the input reverse. + * @return A reversed variable with the same data type as x. */ Variable Flip(const Variable& operand, const std::vector& axes); diff --git a/cinn/frontend/net_builder_test.cc b/cinn/frontend/net_builder_test.cc index 1fb87e6a95..e57ec7a241 100644 --- a/cinn/frontend/net_builder_test.cc +++ b/cinn/frontend/net_builder_test.cc @@ -984,76 +984,6 @@ TEST(net_build, program_execute_arange_int) { } } -TEST(net_build, program_execute_flip) { - const int C = 2; - const int H = 2; - const int W = 2; - const std::vector axes{0}; - - NetBuilder builder("net_builder"); - Placeholder input = builder.CreateInput(Float(32), {C, H, W}, "Img"); - Variable output = builder.Flip(input, axes); - auto program = builder.Build(); - -#ifdef CINN_WITH_CUDA - Target target = common::DefaultNVGPUTarget(); -#else - Target target = common::DefaultHostTarget(); -#endif - std::unordered_set fetch_ids; - auto graph = Optimize(&program, fetch_ids, target); - - auto scope = BuildScope(target, graph); - hlir::framework::GraphCompiler gc(target, scope, graph); - auto runtime_program = gc.Build(); - - scope->Var(std::string(input.id())); - scope->Var(std::string(output->id)); - - auto input_tensor = scope->GetTensor(std::string(input.id())); - SetRandData(input_tensor, target); - std::vector input_data = GetTensorData(input_tensor, target); - - runtime_program->Execute(); - auto output_tensor = scope->GetTensor(std::string(output->id)); - const std::vector& output_shape = output_tensor->shape().data(); - EXPECT_EQ(output_tensor->type(), Float(32)); - EXPECT_EQ(output_shape.size(), 3UL); - EXPECT_EQ(output_shape[0], C); - EXPECT_EQ(output_shape[1], H); - EXPECT_EQ(output_shape[2], W); - - std::vector output_data = GetTensorData(output_tensor, target); - VLOG(6) << "Visualize flip input_data"; - for (int c = 0; c < C; c++) { - for (int h = 0; h < H; h++) { - std::string line; - for (int w = 0; w < W; w++) { - int index = c * (H * W) + h * W + w; - line += (std::to_string(index) + ": " + std::to_string(input_data[index]) + ", "); - } - VLOG(6) << line; - } - } - - VLOG(6) << "Visualize flip output_data"; - for (int c = 0; c < C; c++) { - int flip_c = std::find(axes.begin(), axes.end(), 0) == axes.end() ? c : C - c - 1; - for (int h = 0; h < H; h++) { - std::string line; - int flip_h = std::find(axes.begin(), axes.end(), 1) == axes.end() ? 
h : H - h - 1; - for (int w = 0; w < W; w++) { - int flip_w = std::find(axes.begin(), axes.end(), 2) == axes.end() ? w : W - w - 1; - int flip_index = flip_c * H * W + flip_h * W + flip_w; - int index = c * (H * W) + h * W + w; - line += (std::to_string(index) + ": " + std::to_string(output_data[index]) + ", "); - EXPECT_EQ(input_data[index], output_data[flip_index]); - } - VLOG(6) << line; - } - } -} - TEST(net_build, program_argmax_case1) { const int N = 4; const int IN_C = 3; diff --git a/cinn/hlir/framework/op_lowering_util.cc b/cinn/hlir/framework/op_lowering_util.cc index ee4327451a..8220c58600 100644 --- a/cinn/hlir/framework/op_lowering_util.cc +++ b/cinn/hlir/framework/op_lowering_util.cc @@ -935,10 +935,10 @@ void LoopAssignReduce(ir::IRSchedule& ir_sch, }; auto node_shape = shape_dict.at(node_data->id()); - // node output is same shape with reduce output. + // The output shape of node is different from that of reduce node if (std::accumulate(shape.begin(), shape.end(), 1, std::multiplies()) != std::accumulate(node_shape.begin(), node_shape.end(), 1, std::multiplies())) { - // split loop to assign master loop + // get loop factors of reduce node int extend = 1; std::vector factors; loops = ir_sch.GetLoops(node_data->id()); @@ -953,8 +953,63 @@ void LoopAssignReduce(ir::IRSchedule& ir_sch, factors.push_back(loop.As()->extent.as_int32()); } - ir_sch.Split(loops.back(), factors); - loops = ir_sch.GetLoops(node_data->id()); + // If there are IfThenElse stmt in loop, we need to find out the indices in condition, + // and special treatment should be applied to loops with these indices. + // We apply two step split on loop of src node to align the loop of reduce node. + std::unordered_set loop_index_in_if; + auto first_reduce_loop = rloops.front(); + // collect if + auto if_checker = [](const Expr* x) { return x->As(); }; + auto if_set = ir::CollectIRNodesWithoutTensor(first_reduce_loop.As()->body, if_checker); + std::string reduce_block_name = reducer_data->id(); + for (auto if_expr : if_set) { + auto checker = [reduce_block_name](const Expr* x) { + return x->As() && + x->As()->schedule_block.As()->name == reduce_block_name; + }; + auto blocks_in_if = ir::CollectIRNodesWithoutTensor(if_expr, checker); + if (!blocks_in_if.empty()) { + ir::Expr condition = if_expr.As()->condition; + auto indices_in_if = + ir::CollectIRNodesWithoutTensor(condition, [](const Expr* x) { return x->As(); }); + for (int i = 0; i < rloops.size(); ++i) { + std::string var_name = rloops[i].As()->loop_var->name; + auto find_var_iter = std::find_if(indices_in_if.begin(), indices_in_if.end(), [&var_name](const ir::Expr& x) { + return x.As()->name == var_name; + }); + if (find_var_iter != indices_in_if.end()) { + loop_index_in_if.insert(i); + } + } + break; + } + } + + // prepare factors of two step split + std::vector first_step_factors; + std::vector second_step_factors; + int second_start_loop_index; + for (int i = 0; i < factors.size(); ++i) { + if (loop_index_in_if.count(i) == 0) { + first_step_factors.push_back(factors[i]); + } else if (loop_index_in_if.count(i) != 0 && second_step_factors.empty()) { + first_step_factors.push_back(-1); + second_step_factors.push_back(factors[i]); + second_start_loop_index = i; + } else if (loop_index_in_if.count(i) != 0 && !second_step_factors.empty()) { + second_step_factors.push_back(factors[i]); + } + } + // do two step split + if (!first_step_factors.empty()) { + ir_sch.Split(loops.back(), first_step_factors); + loops = ir_sch.GetLoops(node_data->id()); + } + if 
(!second_step_factors.empty()) { + ir_sch.Split(loops.at(second_start_loop_index), second_step_factors); + loops = ir_sch.GetLoops(node_data->id()); + } + // copy loop info form rloops. copy_loop_info(loops, rloops); return; diff --git a/cinn/hlir/op/contrib/CMakeLists.txt b/cinn/hlir/op/contrib/CMakeLists.txt index 48565a4edb..d8237fb503 100644 --- a/cinn/hlir/op/contrib/CMakeLists.txt +++ b/cinn/hlir/op/contrib/CMakeLists.txt @@ -2,7 +2,6 @@ core_gather_headers() gather_srcs(cinnapi_src SRCS gather_nd.cc - flip.cc sort.cc argmin.cc argmax.cc @@ -24,7 +23,6 @@ cc_test(test_gather_nd SRCS gather_nd_test.cc DEPS cinncore) cc_test(test_sort SRCS sort_test.cc DEPS cinncore) cc_test(test_argmin SRCS argmin_test.cc DEPS cinncore) cc_test(test_argmax SRCS argmax_test.cc DEPS cinncore) -cc_test(test_flip SRCS flip_test.cc DEPS cinncore) cc_test(test_repeat SRCS repeat_test.cc DEPS cinncore) cc_test(test_one_hot SRCS one_hot_test.cc DEPS cinncore) cc_test(test_lookup_table SRCS lookup_table_test.cc DEPS cinncore) diff --git a/cinn/hlir/op/contrib/argmin.cc b/cinn/hlir/op/contrib/argmin.cc index 51214f30eb..7ad8b3e76d 100644 --- a/cinn/hlir/op/contrib/argmin.cc +++ b/cinn/hlir/op/contrib/argmin.cc @@ -113,18 +113,15 @@ std::shared_ptr StrategyForArgmin(const framework::NodeAt framework::CINNCompute argmin_compute([=](lang::Args args, lang::RetValue *ret) { CHECK(!args.empty()) << "The input argument of argmin compute is empty! Please check."; common::CINNValuePack pack_args = args[0]; - std::string tensor_name = UniqName("Argmin_out"); CHECK_GE(pack_args.size(), 1U) << "There should be 1 input args for argmax compute"; Expr in_expr = pack_args[0]; CHECK(in_expr.as_tensor()); Tensor in_tensor = in_expr.as_tensor_ref(); auto stages = CreateStages({in_tensor}); - if (FLAGS_cinn_ir_schedule) { - CHECK_EQ(pack_args.size(), 2U); - CHECK(pack_args[1].is_string()); - tensor_name = pack_args[1].operator std::string(); - } - auto out_tensor = Argmin(in_tensor, target, stages, axis, keep_dims, tensor_name); + CHECK_EQ(pack_args.size(), 2U); + CHECK(pack_args[1].is_string()); + std::string tensor_name = pack_args[1].operator std::string(); + auto out_tensor = Argmin(in_tensor, target, stages, axis, keep_dims, tensor_name); stages->InsertLazily(out_tensor[0]); std::vector cinn_values{ @@ -133,38 +130,30 @@ std::shared_ptr StrategyForArgmin(const framework::NodeAt }); framework::CINNSchedule argmin_schedule([=](lang::Args args, lang::RetValue *ret) { - if (FLAGS_cinn_ir_schedule) { - CHECK(!args.empty()) << "The input argument of arange_schedule is empty! Please check.\n"; - common::CINNValuePack arg_pack = args[0]; - std::vector vec_ast; - for (int i = 0; i < arg_pack.size(); i++) { - if (arg_pack[i].is_expr()) { - Expr temp = arg_pack[i]; - vec_ast.emplace_back(temp); - } - } - CHECK(!vec_ast.empty()); - ir::ModuleExpr mod_expr(vec_ast); - ir::IRSchedule ir_sch(mod_expr); - ir_sch.MergeExprs(); - auto blocks = ir_sch.GetAllBlocks(); - // TODO: It needs to be rewritten according to the reduction_min operator to improve performance. - // Do not use local variables, because the size will exceed the limit. - ir_sch.SetBuffer(blocks[0], "local"); - ir_sch.SetBuffer(blocks[1], "local"); - long prod_size = std::accumulate(output_shapes[0].begin(), output_shapes[0].end(), 1, std::multiplies()); - if (prod_size > 1 && target.arch == Target::Arch::X86) { - pe::IRScheduleInjectiveCPU(ir_sch, output_shapes.front(), target, true); + CHECK(!args.empty()) << "The input argument of arange_schedule is empty! 
Please check.\n"; + common::CINNValuePack arg_pack = args[0]; + std::vector vec_ast; + for (int i = 0; i < arg_pack.size(); i++) { + if (arg_pack[i].is_expr()) { + Expr temp = arg_pack[i]; + vec_ast.emplace_back(temp); } - std::vector res{common::CINNValue(ir_sch.GetModule().GetExprs().at(0))}; - *ret = common::CINNValuePack{res}; - } else { - CHECK(!args.empty()) << "The input argument of arange_schedule is empty! Please check.\n"; - common::CINNValuePack arg_pack = args[0]; - Expr out = arg_pack[0]; - CHECK(out.as_tensor()); - *ret = arg_pack; } + CHECK(!vec_ast.empty()); + ir::ModuleExpr mod_expr(vec_ast); + ir::IRSchedule ir_sch(mod_expr); + ir_sch.MergeExprs(); + auto blocks = ir_sch.GetAllBlocks(); + // TODO: It needs to be rewritten according to the reduction_min operator to improve performance. + // Do not use local variables, because the size will exceed the limit. + ir_sch.SetBuffer(blocks[0], "local"); + ir_sch.SetBuffer(blocks[1], "local"); + long prod_size = std::accumulate(output_shapes[0].begin(), output_shapes[0].end(), 1, std::multiplies()); + if (prod_size > 1 && target.arch == Target::Arch::X86) { + pe::IRScheduleInjectiveCPU(ir_sch, output_shapes.front(), target, true); + } + std::vector res{common::CINNValue(ir_sch.GetModule().GetExprs().at(0))}; + *ret = common::CINNValuePack{res}; }); auto strategy = std::make_shared(); diff --git a/cinn/hlir/op/contrib/flip.cc b/cinn/hlir/op/contrib/flip.cc deleted file mode 100644 index 8157266ff5..0000000000 --- a/cinn/hlir/op/contrib/flip.cc +++ /dev/null @@ -1,118 +0,0 @@ -// Copyright (c) 2022 CINN Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "cinn/hlir/op/contrib/flip.h" - -#include - -#include -#include -#include -#include - -#include "cinn/common/cas.h" -#include "cinn/common/common.h" -#include "cinn/common/context.h" -#include "cinn/common/macros.h" -#include "cinn/hlir/framework/node.h" -#include "cinn/hlir/framework/op.h" -#include "cinn/hlir/framework/op_strategy.h" -#include "cinn/hlir/op/op_util.h" -#include "cinn/hlir/pe/elementwise.h" -#include "cinn/hlir/pe/ir_schedule_pe.h" -#include "cinn/hlir/pe/transform.h" -#include "cinn/ir/ir.h" -#include "cinn/ir/ir_base.h" -#include "cinn/ir/ir_schedule.h" -#include "cinn/ir/tensor.h" -#include "cinn/lang/builtin.h" -#include "cinn/lang/compute.h" - -DECLARE_bool(cinn_ir_schedule); - -namespace cinn { -namespace hlir { -namespace op { - -using common::CINNValue; -using common::CINNValuePack; -using framework::shape_t; - -ir::Tensor Flip(const ir::Tensor &input, const std::vector &axes, const std::string &name) { - return cinn::hlir::pe::Reverse(input, axes, name); -} - -std::shared_ptr StrategyForFlip(const framework::NodeAttr &attrs, - const std::vector &inputs, - const std::vector &out_type, - const std::vector> &output_shapes, - const Target &target) { - CHECK(attrs.attr_store.count("axes")) << "find no attr of axes"; - std::vector axes = absl::get>(attrs.attr_store.at("axes")); - std::string op_name("flip"); - - framework::CINNCompute flip_compute([=](lang::Args args, lang::RetValue *ret) { - CHECK(!args.empty()) << "The input argument of " << op_name << " compute is empty! Please check."; - CINNValuePack pack_args = args[0]; - CHECK_GE(pack_args.size(), 1U) << "1 input tensor for " << op_name << " compute"; - std::string tensor_name = UniqName(op_name + "_Out"); - if (FLAGS_cinn_ir_schedule) { - CHECK_EQ(pack_args.size(), 2U); - tensor_name = pack_args[1].operator std::string(); - } - Expr A_expr = pack_args[0]; - CHECK(A_expr.as_tensor()); - ir::Tensor A = A_expr.as_tensor_ref(); - auto out = Flip(A, axes, tensor_name); - auto stages = CreateStages({A}); - std::vector res; - stages->InsertLazily(out); - res.push_back(CINNValue(out)); - res.push_back(CINNValue(stages)); - *ret = CINNValuePack{res}; - }); - - auto strategy = std::make_shared(); - strategy->AddImpl(flip_compute, GetInjectiveScheduleFunc(output_shapes, target), "strategy.flip.x86", 1); - return strategy; -} - -std::vector InferShapeForFlip(const std::vector &inputs_shape, const framework::AttrMapType &attrs) { - CHECK_EQ(inputs_shape.size(), 1U) << "The input's shape size should be 1! Please check again."; - std::vector res{inputs_shape[0]}; - return res; -} - -std::vector InferDtypeForFlip(const std::vector &inputs_type, const framework::AttrMapType &attrs) { - CHECK(!inputs_type.empty()) << "The input's type size is 0! Please check again."; - std::vector res{inputs_type[0]}; - return res; -} - -} // namespace op -} // namespace hlir -} // namespace cinn - -CINN_REGISTER_HELPER(flip_ops) { - CINN_REGISTER_OP(flip) - .describe("Flip.") - .set_num_inputs(1) - .set_num_outputs(1) - .set_attr("CINNStrategy", cinn::hlir::op::StrategyForFlip) - .set_attr("infershape", MakeOpFunction(cinn::hlir::op::InferShapeForFlip)) - .set_attr("inferdtype", MakeOpFunction(cinn::hlir::op::InferDtypeForFlip)) - .set_support_level(4); - - return true; -} diff --git a/cinn/hlir/op/contrib/flip.h b/cinn/hlir/op/contrib/flip.h deleted file mode 100644 index 8c52cf0449..0000000000 --- a/cinn/hlir/op/contrib/flip.h +++ /dev/null @@ -1,32 +0,0 @@ -// Copyright (c) 2022 CINN Authors. All Rights Reserved. 
-// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include -#include - -#include "cinn/ir/ir.h" -#include "cinn/ir/ir_base.h" -#include "cinn/ir/tensor.h" - -namespace cinn { -namespace hlir { -namespace op { - -ir::Tensor Flip(const ir::Tensor& input, const std::vector& axis, const std::string& name); - -} // namespace op -} // namespace hlir -} // namespace cinn diff --git a/cinn/hlir/op/contrib/flip_test.cc b/cinn/hlir/op/contrib/flip_test.cc deleted file mode 100644 index e86c9e29ed..0000000000 --- a/cinn/hlir/op/contrib/flip_test.cc +++ /dev/null @@ -1,67 +0,0 @@ -// Copyright (c) 2022 CINN Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "cinn/hlir/op/contrib/flip.h" - -#include -#include - -#include -#include - -#include "cinn/backends/codegen_c.h" -#include "cinn/backends/codegen_c_x86.h" -#include "cinn/backends/codegen_cuda_dev.h" -#include "cinn/common/context.h" -#include "cinn/lang/lower.h" -#include "cinn/lang/placeholder.h" -#include "cinn/poly/stage.h" - -namespace cinn { -namespace hlir { -namespace op { - -TEST(GenerateCode_Cpu, Flip) { - common::Context::Global().ResetNameId(); - - common::Target target = common::DefaultHostTarget(); - - ir::Expr n(4); - ir::Expr h(28); - - lang::Placeholder in("in", {n, h}); - ir::Tensor res = Flip(in, {1}, "test_flip"); - - poly::StageMap stages = poly::CreateStages({res}); - std::vector funcs = - lang::LowerVec("TestGenerateCodeCpu_Flip", stages, {res}, {}, {}, nullptr, target, true); - - VLOG(6) << "Expr before CPU codegen:"; - VLOG(6) << funcs[0]->body; - - ir::Module::Builder builder("Flip_Module", target); - for (auto& f : funcs) { - builder.AddFunction(f); - } - - backends::CodeGenCX86 codegen(target, backends::CodeGenCX86::Feature::AVX512); - codegen.SetInlineBuiltinCodes(false); - std::string code = codegen.Compile(builder.Build(), backends::CodeGenC::OutputKind::CImpl); - VLOG(6) << "Cpu Codegen result:"; - VLOG(6) << code << std::endl; -} - -} // namespace op -} // namespace hlir -} // namespace cinn diff --git a/cinn/hlir/op/contrib/repeat_test.cc b/cinn/hlir/op/contrib/repeat_test.cc index 71aebe22e2..02977ea19f 100755 --- a/cinn/hlir/op/contrib/repeat_test.cc +++ b/cinn/hlir/op/contrib/repeat_test.cc @@ -65,7 +65,7 @@ function TestGenerateCodeCpu_Repeat (_test_repeat) ScheduleBlock(test_repeat) { i0, i1 = axis.bind(i, j) - test_repeat[i0, i1] = in[(i0 / 2), i1] + test_repeat[i0, i1] = in[select((((i0 > 0) and (2 > 0)) or ((i0 < 0) and (2 < 0))), (i0 
/ 2), select(((i0 % 2) == 0), (i0 / 2), ((i0 / 2) - 1))), i1] } } } @@ -100,7 +100,7 @@ void TestGenerateCodeCpu_Repeat(void* _args, int32_t num_args) int32_t* test_repeat = ((int32_t*)(_test_repeat->memory)); for (int32_t i = 0; i < 8; i += 1) { for (int32_t j = 0; j < 4; j += 1) { - test_repeat[((4 * i) + j)] = in[(((i / 2) * 4) + j)]; + test_repeat[((4 * i) + j)] = in[((4 * (((((i > 0) && (2 > 0)) || ((i < 0) && (2 < 0)))) ? (i / 2) : ((((i & 1) == 0)) ? (i / 2) : ((i / 2) + -1)))) + j)]; }; }; cinn_buffer_free((void*)(0), _in); diff --git a/cinn/hlir/op/elementwise.cc b/cinn/hlir/op/elementwise.cc index 1eafbe5ec3..598311837a 100644 --- a/cinn/hlir/op/elementwise.cc +++ b/cinn/hlir/op/elementwise.cc @@ -201,6 +201,7 @@ std::shared_ptr StrategyForConstScalar(const framework::NodeAttr &at framework::CINNCompute const_scalar_compute([=](lang::Args args, lang::RetValue *ret) { CHECK(!args.empty()) << "The input argument of const_float compute is empty! Please check."; auto scalar = GetScalarExpr(attrs.attr_store.at("value")); + auto scalar_type = out_type.at(0); CINNValuePack pack_args = args[0]; std::string tensor_name = UniqName("const_scalar_Out"); if (FLAGS_cinn_ir_schedule) { @@ -210,7 +211,12 @@ std::shared_ptr StrategyForConstScalar(const framework::NodeAttr &at } auto out = lang::Compute( - {Expr(1)}, [=](const std::vector &indice) { return scalar; }, tensor_name); + {Expr(1)}, + [=](const std::vector &indice) { + auto res = (scalar_type == scalar->type()) ? scalar : ir::Cast::Make(scalar_type, scalar); + return res; + }, + tensor_name); CHECK(out.defined()) << "can't create const scalar with the given type " << out_type[0]; auto stages = CreateStages({out}); *ret = CINNValuePack{{CINNValue(out), CINNValue(stages)}}; @@ -229,9 +235,16 @@ std::vector InferShapeForConstScalar(const std::vector &inputs } std::vector InferDtypeForConstScalar(const std::vector &inputs_type, const framework::AttrMapType &attrs) { - CHECK(attrs.count("value")); - auto scalar = GetScalarExpr(attrs.at("value")); - auto out_type = scalar->type(); + Type out_type; + if (attrs.find("dtype") != attrs.end()) { + auto dtype_str = absl::get(attrs.at("dtype")); + if (!dtype_str.empty()) { + out_type = common::Str2Type(dtype_str); + } + } else { + auto scalar = GetScalarExpr(attrs.at("value")); + out_type = scalar->type(); + } VLOG(3) << "scalar type: " << out_type; return {out_type}; } @@ -356,10 +369,10 @@ std::vector> InferLayoutForFillConstant(const std::vect #define EXPAND_ATTR_TYPE(MACRO) \ MACRO(bool) \ - MACRO(float) \ MACRO(int) \ MACRO(int64_t) \ - MACRO(double) + MACRO(double) \ + MACRO(float) std::shared_ptr StrategyForAssignValue(const framework::NodeAttr &attrs, const std::vector &inputs, diff --git a/cinn/hlir/op/transform.cc b/cinn/hlir/op/transform.cc index 31f3fdc3ee..24951be324 100644 --- a/cinn/hlir/op/transform.cc +++ b/cinn/hlir/op/transform.cc @@ -831,7 +831,6 @@ std::shared_ptr StrategyForReverse(const framework::NodeAttr &attrs, std::vector axis; if (attrs.attr_store.find("axis") != attrs.attr_store.end()) { axis = absl::get>(attrs.attr_store.at("axis")); - CHECK(!axis.empty()) << "axis is empty! Please check setting.\n"; for (auto &e : axis) { if (e >= static_cast(output_shapes[0].size()) || e < -1 * static_cast(output_shapes[0].size())) { LOG(FATAL) << "axis is not in [0, n_dim), Please check."; @@ -840,8 +839,6 @@ std::shared_ptr StrategyForReverse(const framework::NodeAttr &attrs, e += output_shapes[0].size(); } } - } else { - LOG(FATAL) << "axis is not be set! 
Please check."; } framework::CINNCompute reverse_compute([=](lang::Args args, lang::RetValue *ret) { @@ -875,7 +872,6 @@ std::vector InferShapeForReverse(const std::vector res{inputs_shape[0]}; if (attrs.find("axis") != attrs.end()) { auto axis = absl::get>(attrs.at("axis")); - CHECK(!axis.empty()) << "axis is empty! Please check setting.\n"; for (auto &e : axis) { if (e >= static_cast(inputs_shape[0].size()) || e < -1 * static_cast(inputs_shape[0].size())) { LOG(FATAL) << "axis is not in [-n_dim, n_dim), Please check."; @@ -884,8 +880,6 @@ std::vector InferShapeForReverse(const std::vector> InferLayoutForReverse(const std::vector>(attrs.attr_store.at("axis")); - CHECK(!axis.empty()) << "axis is empty! Please check setting.\n"; for (auto &e : axis) { if (e >= static_cast(input_shapes[0].size()) || e < -1 * static_cast(input_shapes[0].size())) { LOG(FATAL) << "axis is not in [-n_dim, n_dim), Please check."; } } - } else { - LOG(FATAL) << "axis is not be set! Please check."; } CHECK_EQ(input_layouts.size(), 1U) << "The input's layout size is not 1! Please check again."; return {input_layouts, input_layouts}; diff --git a/cinn/hlir/op/use_ops.h b/cinn/hlir/op/use_ops.h index e29941efd7..9589bb96b0 100644 --- a/cinn/hlir/op/use_ops.h +++ b/cinn/hlir/op/use_ops.h @@ -28,7 +28,6 @@ CINN_USE_REGISTER(argmin_ops) CINN_USE_REGISTER(argmax_ops) CINN_USE_REGISTER(reduce_ops) CINN_USE_REGISTER(custom_call_op) -CINN_USE_REGISTER(flip_ops) CINN_USE_REGISTER(repeat_ops) CINN_USE_REGISTER(one_hot_ops) CINN_USE_REGISTER(lookup_table_ops) diff --git a/cinn/ir/ir_schedule.cc b/cinn/ir/ir_schedule.cc index 14ed5b4a47..eb2d934e0f 100644 --- a/cinn/ir/ir_schedule.cc +++ b/cinn/ir/ir_schedule.cc @@ -1210,20 +1210,19 @@ void ScheduleImpl::SimpleComputeAt(const Expr& block, const Expr& loop) { loops.size() < block_loops.size() ? 
optim::IRCopy(block_loops[loops.size()]) : optim::IRCopy(this_block); Expr new_loop = optim::IRCopy(this_loop); - if (loops.size() >= block_loops.size()) { - auto body = block_loops.back().As()->body; - // collect if - auto if_checker = [](const Expr* x) { return x->As(); }; - auto if_set = ir::CollectIRNodesWithoutTensor(body, if_checker); - for (auto if_expr : if_set) { - auto checker = [block_name](const Expr* x) { - return x->As() && - x->As()->schedule_block.As()->name == block_name; - }; - if (ir::CollectIRNodesWithoutTensor(if_expr, checker, true).size() > 0) { - result = IfThenElse::Make(if_expr.As()->condition, result); - break; - } + // Get the body of block_loop under the same loops + auto body = block_loops.at(loops.size() - 1).As()->body; + // collect if + auto if_checker = [](const Expr* x) { return x->As(); }; + auto if_set = ir::CollectIRNodesWithoutTensor(body, if_checker); + for (auto if_expr : if_set) { + auto checker = [block_name](const Expr* x) { + return x->As() && + x->As()->schedule_block.As()->name == block_name; + }; + if (ir::CollectIRNodesWithoutTensor(if_expr, checker, true).size() > 0) { + result = IfThenElse::Make(if_expr.As()->condition, result); + break; } } @@ -1238,10 +1237,9 @@ void ScheduleImpl::SimpleComputeAt(const Expr& block, const Expr& loop) { ir::Block::Make({result.As()->true_case, new_loop.As()->body.As()->stmts[0].As()->true_case}); } else { - new_loop.As()->body.As()->stmts[0].As()->true_case = ir::Block::Make( - {result, new_loop.As()->body.As()->stmts[0].As()->true_case}); + std::vector::iterator pos = new_loop.As()->body.As()->stmts.begin(); + new_loop.As()->body.As()->stmts.insert(pos, result); } - } else { new_loop.As()->body = ir::Block::Make({result, new_loop.As()->body}); } diff --git a/cinn/lang/builtin.cc b/cinn/lang/builtin.cc index 0abf8dc986..266f704a76 100644 --- a/cinn/lang/builtin.cc +++ b/cinn/lang/builtin.cc @@ -107,7 +107,17 @@ Expr One(const Type& type) { return ir::One(type); } Expr FloorDivide(Expr a, Expr b) { CHECK_EQ(a.type(), b.type()) << "FloorDivide's inputs type not equal, where a:" << a.type() << " but b:" << b.type(); - return a.type().is_float() ? 
Floor(a / b) : a / b; + if (a.type().is_float()) { + return Floor(a / b); + } else if (a.type().is_uint()) { + return a / b; + } else { + auto div = a / b; + auto mod = a % b; + auto ret = ir::Select::Make( + ir::EQ::Make(mod, common::make_const(a.type(), 0)), div, div - common::make_const(a.type(), 1)); + return ir::Select::Make((a > 0 && b > 0) || (a < 0 && b < 0), div, ret); + } } Expr min_value(const Type& type) { diff --git a/cinn/pybind/frontend.cc b/cinn/pybind/frontend.cc index 2e84b094cb..73de15adab 100644 --- a/cinn/pybind/frontend.cc +++ b/cinn/pybind/frontend.cc @@ -65,8 +65,6 @@ static const char *SnakeName(const char *name) { #define EXPAND_CINN_SUPPORT_TYPE(EXPAND_MACRO) \ EXPAND_MACRO(bool) \ - EXPAND_MACRO(float) \ - EXPAND_MACRO(int) \ EXPAND_MACRO(int64_t) \ EXPAND_MACRO(double) @@ -405,14 +403,23 @@ void BindFrontend(pybind11::module *m) { #undef EXPAND_QUINTIC_VECTOR #undef EXPAND_SEXTIC_VECTOR #undef PY_REGISTER_CONSTANT_OP -#define PY_REGISTER_FILLCONSTANT_OP(TYPE__) \ - .def("fill_constant", \ - static_cast &, TYPE__, const std::string &, bool)>( \ - &NetBuilder::template FillConstant), \ - py::arg("shape"), \ - py::arg("value"), \ - py::arg("name") = "", \ +#define PY_REGISTER_FILLCONSTANT_OP(TYPE__) \ + .def("fill_constant", \ + static_cast &, TYPE__, const std::string &, const std::string &, bool)>( \ + &NetBuilder::FillConstant), \ + py::arg("shape"), \ + py::arg("value"), \ + py::arg("name") = "", \ + py::arg("dtype"), \ + py::arg("force_cpu") = false) \ + .def("fill_constant", \ + static_cast &, TYPE__, const std::string &, bool)>( \ + &NetBuilder::template FillConstant), \ + py::arg("shape"), \ + py::arg("value"), \ + py::arg("name") = "", \ py::arg("force_cpu") = false) EXPAND_CINN_SUPPORT_TYPE(PY_REGISTER_FILLCONSTANT_OP) #undef PY_REGISTER_FILLCONSTANT_OP @@ -462,15 +469,6 @@ void BindFrontend(pybind11::module *m) { .def("name", &NetBuilder::name) .def("__str__", [](NetBuilder &self) { return self.name(); }) .def("append_instruction", &NetBuilder::AppendInstruction, py::arg("instr")) - .def("fill_constant", - static_cast &, float, const std::string &, const std::string &, bool)>( - &NetBuilder::FillConstant), - py::arg("shape"), - py::arg("value"), - py::arg("name") = "", - py::arg("dtype"), - py::arg("force_cpu") = false) .def("fill_constant", static_cast &, const std::string &, const std::string &, const std::string &, bool)>( @@ -703,6 +701,8 @@ void BindFrontend(pybind11::module *m) { py::arg("max") = 0, py::arg("seed") = 0, py::arg("dtype") = "int64") + .def("repeat", &NetBuilder::Repeat, py::arg("x"), py::arg("repeats"), py::arg("axis")) + .def("flip", &NetBuilder::Flip, py::arg("x"), py::arg("axis")) .def("cholesky", &NetBuilder::Cholesky, py::arg("x"), py::arg("upper") = false) .def("triangular_solve", &NetBuilder::TriangularSolve, diff --git a/cinn/runtime/cpu/host_intrinsics.cc b/cinn/runtime/cpu/host_intrinsics.cc index a9d1cd0e38..f6ce5ca108 100644 --- a/cinn/runtime/cpu/host_intrinsics.cc +++ b/cinn/runtime/cpu/host_intrinsics.cc @@ -232,11 +232,10 @@ inline double FN_FP64(pow)(double x, double y) { return pow(x, y); } #define FN_INT32(func) cinn_host_##func##_int32 inline int FN_INT32(pow)(int x, int y) { - int res = 1; - for (int i = 0; i < y; ++i) { - res *= x; + if (x == 0 && y < 0) { + return -1; } - return res; + return pow(x, y); } inline int FN_INT32(clz)(int x) { return __builtin_clz(x); } @@ -253,6 +252,10 @@ inline int64_t FN_INT64(clz)(int64_t x) { return __builtin_clzll(x); } inline int64_t FN_INT64(popc)(int64_t x) { return 
__builtin_popcountll(x); } +inline int64_t FN_INT64(pow)(int64_t x, int64_t y) { return pow(x, y); } + +inline int64_t FN_INT64(logical_right_shift)(int64_t x, int64_t y) { return ((uint64_t)x >> y); } + #undef FN_INT64 } // extern "C" @@ -323,6 +326,15 @@ CINN_REGISTER_HELPER(host_intrinsics) { #undef REGISTER_EXTERN_FUNC_2_IN_1_INT32 +#define REGISTER_EXTERN_FUNC_2_IN_1_INT64(func__) \ + REGISTER_EXTERN_FUNC_2_IN_1_OUT(cinn_host_##func__##_int64, host_target, int64_t, int64_t, int64_t); + + REGISTER_EXTERN_FUNC_2_IN_1_INT64(pow) + + REGISTER_EXTERN_FUNC_2_IN_1_INT64(logical_right_shift) + +#undef REGISTER_EXTERN_FUNC_2_IN_1_INT64 + REGISTER_EXTERN_FUNC_1_IN_1_OUT(cinn_host_clz_int32, host_target, int, int); REGISTER_EXTERN_FUNC_1_IN_1_OUT(cinn_host_clz_int64, host_target, int64_t, int64_t); diff --git a/cinn/runtime/cpu/host_intrinsics.h b/cinn/runtime/cpu/host_intrinsics.h index ca2e89aabb..385d9fec1c 100644 --- a/cinn/runtime/cpu/host_intrinsics.h +++ b/cinn/runtime/cpu/host_intrinsics.h @@ -95,6 +95,10 @@ inline int64_t FN_INT64(clz)(int64_t x); inline int64_t FN_INT64(popc)(int64_t x); +inline int64_t FN_INT64(pow)(int64_t x, int64_t y); + +inline int64_t FN_INT64(logical_right_shift)(int64_t x, int64_t y); + #undef FN_INT64 #define FN_FP32(func) cinn_host_##func##_fp32 diff --git a/cinn/runtime/cuda/cinn_cuda_runtime_source.cuh b/cinn/runtime/cuda/cinn_cuda_runtime_source.cuh index fab39b1b76..0525e7f4b1 100644 --- a/cinn/runtime/cuda/cinn_cuda_runtime_source.cuh +++ b/cinn/runtime/cuda/cinn_cuda_runtime_source.cuh @@ -22,6 +22,7 @@ __device__ inline uint8_t FN_UINT8(bitwise_and)(uint8_t a, uint8_t b) { return a __device__ inline uint8_t FN_UINT8(bitwise_or)(uint8_t a, uint8_t b) { return a | b; } __device__ inline uint8_t FN_UINT8(bitwise_xor)(uint8_t a, uint8_t b) { return a ^ b; } __device__ inline uint8_t FN_UINT8(bitwise_not)(uint8_t a) { return ~a; } +__device__ inline uint8_t FN_UINT8(logical_right_shift)(uint8_t a, uint8_t b) { return ((uint8_t)a >> b); } // *************************************************************** // // int8 unary and binary operator @@ -30,6 +31,7 @@ __device__ inline int8_t FN_INT8(bitwise_and)(int8_t a, int8_t b) { return a & b __device__ inline int8_t FN_INT8(bitwise_or)(int8_t a, int8_t b) { return a | b; } __device__ inline int8_t FN_INT8(bitwise_xor)(int8_t a, int8_t b) { return a ^ b; } __device__ inline int8_t FN_INT8(bitwise_not)(int8_t a) { return ~a; } +__device__ inline int8_t FN_INT8(logical_right_shift)(int8_t a, int8_t b) { return ((uint8_t)a >> b); } // *************************************************************** // // int16 unary and binary operator @@ -38,6 +40,7 @@ __device__ inline int16_t FN_INT16(bitwise_and)(int16_t a, int16_t b) { return a __device__ inline int16_t FN_INT16(bitwise_or)(int16_t a, int16_t b) { return a | b; } __device__ inline int16_t FN_INT16(bitwise_xor)(int16_t a, int16_t b) { return a ^ b; } __device__ inline int16_t FN_INT16(bitwise_not)(int16_t a) { return ~a; } +__device__ inline int16_t FN_INT16(logical_right_shift)(int16_t a, int16_t b) { return ((uint16_t)a >> b); } // *************************************************************** // // float32 unary and binary operator @@ -133,11 +136,11 @@ __device__ inline double FN_FP64(mod)(double a, double b) { #define FN_INT32(func) cinn_nvgpu_##func##_int32 __device__ inline int FN_INT32(pow)(int a, int b) { - int res = 1; - for (int i = 0; i < b; ++i) { - res *= a; + if (a == 0 && b < 0) { + return -1; } - return res; + float res = pow(__int2float_rd(a), 
__int2float_rd(b)); + return __float2int_rn(res); } __device__ inline int FN_INT32(left_shift)(int a, int b) { return a << b; } @@ -171,6 +174,7 @@ __device__ inline long long int FN_INT64(bitwise_xor)(long long int a, long long __device__ inline long long int FN_INT64(bitwise_not)(long long int a) { return ~a; } __device__ inline long long int FN_INT64(clz)(long long int a) { return __clzll(a); } __device__ inline long long int FN_INT64(popc)(long long int a) { return __popcll(a); } +__device__ inline long long int FN_INT64(logical_right_shift)(long long int a, long long int b) { return ((unsigned long long int)a >> b); } __device__ inline long long int FN_INT64(trunc)(long long int a) { return a; } __device__ inline long long int FN_INT64(mod)(long long int a, long long int b) { long long int res = a % b; @@ -179,11 +183,8 @@ __device__ inline long long int FN_INT64(mod)(long long int a, long long int b) } __device__ inline long long int FN_INT64(pow)(long long int a, long long int b) { - long long int res = 1; - for (int i = 0; i < b; ++i) { - res *= a; - } - return res; + double res = pow(__ll2double_rd(a), __ll2double_rd(b)); + return __double2ll_rn(res); } // *************************************************************** // @@ -320,7 +321,7 @@ __device__ inline bfloat16 FN_BF16(pow)(bfloat16 a, bfloat16 b) { __device__ inline float16 FN_FP16(ceil)(float16 x) { return float16(hceil(x.to_half())); } __device__ inline float16 FN_FP16(floor)(float16 x) { return float16(hfloor(x.to_half())); } -__device__ inline float16 FN_FP16(round)(float16 x) { return float16(hrint(x.to_half())); } +__device__ inline float16 FN_FP16(round)(float16 x) { return float16(FN_FP32(round)(static_cast(x))); } __device__ inline float16 FN_FP16(trunc)(float16 x) { return float16(htrunc(x.to_half())); } __device__ inline float16 FN_FP16(sin)(float16 x) { return float16(hsin(x.to_half())); } diff --git a/cinn/runtime/cuda/cuda_intrinsics.cc b/cinn/runtime/cuda/cuda_intrinsics.cc index 753d66d831..88e48973b3 100644 --- a/cinn/runtime/cuda/cuda_intrinsics.cc +++ b/cinn/runtime/cuda/cuda_intrinsics.cc @@ -56,6 +56,7 @@ CINN_REGISTER_HELPER(cuda_intrinsics) { REGISTER_EXTERN_FUNC_2_IN_1_OUT_UINT8(bitwise_and); REGISTER_EXTERN_FUNC_2_IN_1_OUT_UINT8(bitwise_or); REGISTER_EXTERN_FUNC_2_IN_1_OUT_UINT8(bitwise_xor); + REGISTER_EXTERN_FUNC_2_IN_1_OUT_UINT8(logical_right_shift); #undef REGISTER_EXTERN_FUNC_2_IN_1_OUT_UINT8 @@ -74,6 +75,7 @@ CINN_REGISTER_HELPER(cuda_intrinsics) { REGISTER_EXTERN_FUNC_2_IN_1_OUT_INT8(bitwise_and); REGISTER_EXTERN_FUNC_2_IN_1_OUT_INT8(bitwise_or); REGISTER_EXTERN_FUNC_2_IN_1_OUT_INT8(bitwise_xor); + REGISTER_EXTERN_FUNC_2_IN_1_OUT_INT8(logical_right_shift); #undef REGISTER_EXTERN_FUNC_2_IN_1_OUT_INT8 @@ -92,6 +94,7 @@ CINN_REGISTER_HELPER(cuda_intrinsics) { REGISTER_EXTERN_FUNC_2_IN_1_OUT_INT16(bitwise_and); REGISTER_EXTERN_FUNC_2_IN_1_OUT_INT16(bitwise_or); REGISTER_EXTERN_FUNC_2_IN_1_OUT_INT16(bitwise_xor); + REGISTER_EXTERN_FUNC_2_IN_1_OUT_INT16(logical_right_shift); #undef REGISTER_EXTERN_FUNC_2_IN_1_OUT_INT16 @@ -227,7 +230,6 @@ CINN_REGISTER_HELPER(cuda_intrinsics) { REGISTER_EXTERN_FUNC_2_IN_1_INT32(bitwise_and) REGISTER_EXTERN_FUNC_2_IN_1_INT32(bitwise_or) REGISTER_EXTERN_FUNC_2_IN_1_INT32(bitwise_xor) - REGISTER_EXTERN_FUNC_2_IN_1_INT32(floor_divide) REGISTER_EXTERN_FUNC_2_IN_1_INT32(logical_right_shift) REGISTER_EXTERN_FUNC_2_IN_1_INT32(mod) @@ -236,10 +238,12 @@ CINN_REGISTER_HELPER(cuda_intrinsics) { #define REGISTER_EXTERN_FUNC_2_IN_1_INT64(func__) \ 
REGISTER_EXTERN_SOURCE_FUNC_2_IN_1_OUT(cinn_nvgpu_##func__##_int64, target, int64_t, int64_t, int64_t); + REGISTER_EXTERN_FUNC_2_IN_1_INT64(pow) REGISTER_EXTERN_FUNC_2_IN_1_INT64(bitwise_and) REGISTER_EXTERN_FUNC_2_IN_1_INT64(bitwise_or) REGISTER_EXTERN_FUNC_2_IN_1_INT64(bitwise_xor) REGISTER_EXTERN_FUNC_2_IN_1_INT64(mod) + REGISTER_EXTERN_FUNC_2_IN_1_INT64(logical_right_shift) #undef REGISTER_EXTERN_FUNC_2_IN_1_INT64 diff --git a/cmake/external/jitify.cmake b/cmake/external/jitify.cmake index 5868d5e14a..080b8b93ee 100644 --- a/cmake/external/jitify.cmake +++ b/cmake/external/jitify.cmake @@ -11,7 +11,7 @@ ExternalProject_Add( external_jitify ${EXTERNAL_PROJECT_LOG_ARGS} GIT_REPOSITORY "https://github.com/NVIDIA/jitify.git" - GIT_TAG master + GIT_TAG 57de649139c866eb83acacfe50c92ad7c6278776 PREFIX ${THIRD_PARTY_PATH}/jitify SOURCE_DIR ${JITIFY_SOURCE_PATH} CONFIGURE_COMMAND "" diff --git a/python/tests/fusion/test_reduce_cast.py b/python/tests/fusion/test_reduce_cast.py new file mode 100644 index 0000000000..c5c94abe44 --- /dev/null +++ b/python/tests/fusion/test_reduce_cast.py @@ -0,0 +1,39 @@ +# Copyright (c) 2023 CINN Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from fusion_test import FusionTest + + +class TestGroup1(FusionTest): + def init_input_data(self): + self.feed_data = {} + + def build_program(self, builder, target): + x = builder.fill_constant( + dtype="float32", shape=[4, 5, 20, 20], value=1.00000000) + y = builder.cast( + builder.reduce_sum(x, dim=[2], keep_dim=False), "float16") + + feed_list = [] + fetch_list = [y] + + return feed_list, fetch_list + + def test_check_results(self): + self.check_fusion_outputs(group_size=1) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/tests/ops/op_test_helper.py b/python/tests/ops/op_test_helper.py index eb8c99889a..d5fd1935c8 100644 --- a/python/tests/ops/op_test_helper.py +++ b/python/tests/ops/op_test_helper.py @@ -35,6 +35,9 @@ class TestCaseHelper(): Helper class for constructing test cases. 
""" + def __init__(self): + self.custom_attrs_list = [] + def init_attrs(self): """ Initialize attributes for op @@ -51,6 +54,12 @@ def _flatten_tuple(self, cur_tuple): new_dict.append((k, v)) return dict(new_dict) + def _register_custom_attrs(self, custom_attrs): + """ + register custom attribute + """ + self.custom_attrs_list.append(custom_attrs) + def _init_cases(self): """ Generate all test cases @@ -59,7 +68,9 @@ def _init_cases(self): assert isinstance(self.dtypes, list) assert isinstance(self.attrs, list) self.all_cases = [] - all_lists = [self.inputs, self.dtypes, self.attrs] + all_lists = [ + self.inputs, self.dtypes, self.attrs, *self.custom_attrs_list + ] filtered_lists = filter(lambda x: len(x) > 0, all_lists) for case in itertools.product(*filtered_lists): self.all_cases.append(self._flatten_tuple(case)) diff --git a/python/tests/ops/test_add_op.py b/python/tests/ops/test_add_op.py index 03136c35a8..80ea1f0863 100644 --- a/python/tests/ops/test_add_op.py +++ b/python/tests/ops/test_add_op.py @@ -1,6 +1,4 @@ -#!/usr/bin/env python3 - -# Copyright (c) 2021 CINN Authors. All Rights Reserved. +# Copyright (c) 2023 CINN Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,32 +12,38 @@ # See the License for the specific language governing permissions and # limitations under the License. -import unittest import numpy as np -from op_test import OpTest, OpTestTool import paddle -import cinn -from cinn.frontend import * from cinn.common import * +from cinn.frontend import * +from op_test import OpTest, OpTestTool +from op_test_helper import TestCaseHelper @OpTestTool.skip_if(not is_compiled_with_cuda(), "x86 test will be skipped due to timeout.") class TestElementwiseAddOp(OpTest): def setUp(self): - self.init_case() + print(f"\nRunning {self.__class__.__name__}: {self.case}") + self.prepare_inputs() - def init_case(self): - self.inputs = { - "x": np.random.random([32, 64]).astype("float32"), - "y": np.random.random([32, 64]).astype("float32"), - "dout": np.random.random((32, 64)).astype("float32") - } - self.axis = -1 + def prepare_inputs(self): + self.x_np = self.random( + shape=self.case["x_shape"], + dtype=self.case["x_dtype"], + low=-10, + high=10) + self.y_np = self.random( + shape=self.case["y_shape"], + dtype=self.case["y_dtype"], + low=-10, + high=10) + self.dout_np = self.random( + self.case["dout_shape"], dtype=self.case["dout_dtype"]) def build_paddle_program(self, target): - x = paddle.to_tensor(self.inputs["x"], stop_gradient=False) - y = paddle.to_tensor(self.inputs["y"], stop_gradient=False) + x = paddle.to_tensor(self.x_np, stop_gradient=False) + y = paddle.to_tensor(self.y_np, stop_gradient=False) def get_unsqueeze_axis(x_rank, y_rank, axis): self.assertTrue( @@ -48,83 +52,206 @@ def get_unsqueeze_axis(x_rank, y_rank, axis): axis = axis if axis >= 0 else x_rank - y_rank unsqueeze_axis = np.arange(0, axis).tolist() + np.arange( axis + y_rank, x_rank).tolist() - return unsqueeze_axis unsqueeze_axis = get_unsqueeze_axis( - len(self.inputs["x"].shape), len(self.inputs["y"].shape), - self.axis) + len(x.shape), len(y.shape), self.case["axis"]) y_t = paddle.unsqueeze( y, axis=unsqueeze_axis) if len(unsqueeze_axis) > 0 else y out = paddle.add(x, y_t) self.paddle_outputs = [out] self.paddle_grads = self.get_paddle_grads([out], [x, y], - [self.inputs["dout"]]) + [self.dout_np]) def build_cinn_program(self, target): builder = NetBuilder("add") - x = 
builder.create_input(Float(32), self.inputs["x"].shape, "x") - y = builder.create_input(Float(32), self.inputs["y"].shape, "y") - out = builder.add(x, y, axis=self.axis) + x = builder.create_input( + self.nptype2cinntype(self.case["x_dtype"]), self.case["x_shape"], + "x") + y = builder.create_input( + self.nptype2cinntype(self.case["y_dtype"]), self.case["y_shape"], + "y") + out = builder.add(x, y, axis=self.case["axis"]) dout = builder.create_input( - Float(32), self.inputs["dout"].shape, "dout") + self.nptype2cinntype(self.case["dout_dtype"]), + self.case["dout_shape"], "dout") x_grad, y_grad = builder.elementwise_add_grad( - dout, x, y, axis=self.axis) + dout, x, y, axis=self.case["axis"]) prog = builder.build() - res = self.get_cinn_output( - prog, target, [x, y, dout], - [self.inputs["x"], self.inputs["y"], self.inputs["dout"]], - [out, x_grad, y_grad]) + res = self.get_cinn_output(prog, target, [x, y, dout], + [self.x_np, self.y_np, self.dout_np], + [out, x_grad, y_grad]) self.cinn_outputs = [res[0]] self.cinn_grads = [res[1], res[2]] def test_check_results(self): - self.check_outputs_and_grads() - - -class TestAddCase1(TestElementwiseAddOp): - def init_case(self): - self.inputs = { - "x": np.random.random([8, 16, 32, 32]).astype("float32"), - "y": np.random.random([32, 32]).astype("float32"), - "dout": np.random.random((8, 16, 32, 32)).astype("float32") - } - self.axis = -1 - - -class TestAddCase2(TestElementwiseAddOp): - def init_case(self): - self.inputs = { - "x": np.random.random([8, 1, 32, 32]).astype("float32"), - "y": np.random.random([16, 32]).astype("float32"), - "dout": np.random.random((8, 16, 32, 32)).astype("float32") - } - self.axis = 1 + max_relative_error = self.case[ + "max_relative_error"] if "max_relative_error" in self.case else 1e-5 + self.check_outputs_and_grads(max_relative_error=max_relative_error) -class TestAddCase3(TestElementwiseAddOp): - def init_case(self): - self.inputs = { - "x": np.random.random([4, 16, 8, 32]).astype("float32"), - "y": np.random.random([4, 16]).astype("float32"), - "dout": np.random.random((4, 16, 8, 32)).astype("float32") - } - self.axis = 0 +class TestAddAll(TestCaseHelper): + def init_attrs(self): + self.class_name = "TestElementwiseAddOpCase" + self.cls = TestElementwiseAddOp + self.inputs = [ + { + "x_shape": [1], + "y_shape": [1], + "dout_shape": [1], + "axis": 0, + }, + { + "x_shape": [1024], + "y_shape": [1024], + "dout_shape": [1024], + "axis": -1, + }, + { + "x_shape": [512, 256], + "y_shape": [512, 256], + "dout_shape": [512, 256], + "axis": 0, + }, + { + "x_shape": [128, 64, 32], + "y_shape": [128, 64, 32], + "dout_shape": [128, 64, 32], + "axis": -1, + }, + { + "x_shape": [16, 8, 4, 2], + "y_shape": [16, 8, 4, 2], + "dout_shape": [16, 8, 4, 2], + "axis": 0, + }, + { + "x_shape": [16, 8, 4, 2, 1], + "y_shape": [16, 8, 4, 2, 1], + "dout_shape": [16, 8, 4, 2, 1], + "axis": -1, + }, + ] + self.dtypes = [ + # TODO: paddle 2.3.1 unsupport int16 now, remove after ci paddle updated + # { + # "x_dtype": "int16", + # "y_dtype": "int16", + # "dout_dtype": "int16", + # }, + { + "x_dtype": "int32", + "y_dtype": "int32", + "dout_dtype": "int32", + }, + { + "x_dtype": "int64", + "y_dtype": "int64", + "dout_dtype": "int64", + }, + { + "x_dtype": "float16", + "y_dtype": "float16", + "dout_dtype": "float16", + "max_relative_error": 1e-3, + }, + { + "x_dtype": "float32", + "y_dtype": "float32", + "dout_dtype": "float32", + }, + { + "x_dtype": "float64", + "y_dtype": "float64", + "dout_dtype": "float64", + }, + ] + self.attrs = [] -class 
TestAddCase4(TestElementwiseAddOp): - def init_case(self): - self.inputs = { - "x": np.random.random([4, 16, 8, 32]).astype("float32"), - "y": np.random.random([1]).astype("float32"), - "dout": np.random.random((4, 16, 8, 32)).astype("float32") - } - self.axis = -1 +class TestAddAllWithBroadcast(TestCaseHelper): + def init_attrs(self): + self.class_name = "TestElementwiseAddOpCase" + self.cls = TestElementwiseAddOp + self.inputs = [ + { + "x_shape": [1], + "y_shape": [1], + "dout_shape": [1], + "axis": 0, + }, + { + "x_shape": [1024], + "y_shape": [1], + "dout_shape": [1024], + "axis": -1, + }, + { + "x_shape": [512, 256], + "y_shape": [1, 1], + "dout_shape": [512, 256], + "axis": 0, + }, + { + "x_shape": [128, 64, 32], + "y_shape": [1, 1, 1], + "dout_shape": [128, 64, 32], + "axis": -1, + }, + { + "x_shape": [16, 8, 4, 2], + "y_shape": [1, 1, 1, 1], + "dout_shape": [16, 8, 4, 2], + "axis": 0, + }, + { + "x_shape": [16, 8, 4, 2, 1], + "y_shape": [1, 1, 1, 1, 1], + "dout_shape": [16, 8, 4, 2, 1], + "axis": -1, + }, + ] + self.dtypes = [ + # Todo: Reduce does in support int16 + # { + # "x_dtype": "int16", + # "y_dtype": "int16", + # "dout_dtype": "int16", + # }, + { + "x_dtype": "int32", + "y_dtype": "int32", + "dout_dtype": "int32", + }, + { + "x_dtype": "int64", + "y_dtype": "int64", + "dout_dtype": "int64", + }, + { + "x_dtype": "float16", + "y_dtype": "float16", + "dout_dtype": "float16", + "max_relative_error": 1e-3, + }, + { + "x_dtype": "float32", + "y_dtype": "float32", + "dout_dtype": "float32", + }, + { + "x_dtype": "float64", + "y_dtype": "float64", + "dout_dtype": "float64", + }, + ] + self.attrs = [] if __name__ == "__main__": - unittest.main() + TestAddAll().run() + TestAddAllWithBroadcast().run() diff --git a/python/tests/ops/test_add_op_new.py b/python/tests/ops/test_add_op_new.py deleted file mode 100644 index 75a602d783..0000000000 --- a/python/tests/ops/test_add_op_new.py +++ /dev/null @@ -1,271 +0,0 @@ -# Copyright (c) 2023 CINN Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import unittest -import numpy as np -from op_test import OpTest, OpTestTool -from op_test_helper import TestCaseHelper -import paddle -import cinn -from cinn.frontend import * -from cinn.common import * - - -@OpTestTool.skip_if(not is_compiled_with_cuda(), - "x86 test will be skipped due to timeout.") -class TestElementwiseAddOp(OpTest): - def setUp(self): - print(f"\nRunning {self.__class__.__name__}: {self.case}") - self.prepare_inputs() - - def prepare_inputs(self): - self.x_np = self.random( - shape=self.case["x_shape"], - dtype=self.case["x_dtype"], - low=-10, - high=10) - self.y_np = self.random( - shape=self.case["y_shape"], - dtype=self.case["y_dtype"], - low=-10, - high=10) - self.dout_np = self.random( - self.case["dout_shape"], dtype=self.case["dout_dtype"]) - - def build_paddle_program(self, target): - x = paddle.to_tensor(self.x_np, stop_gradient=False) - y = paddle.to_tensor(self.y_np, stop_gradient=False) - - def get_unsqueeze_axis(x_rank, y_rank, axis): - self.assertTrue( - x_rank >= y_rank, - "The rank of x should be greater or equal to that of y.") - axis = axis if axis >= 0 else x_rank - y_rank - unsqueeze_axis = np.arange(0, axis).tolist() + np.arange( - axis + y_rank, x_rank).tolist() - return unsqueeze_axis - - unsqueeze_axis = get_unsqueeze_axis( - len(x.shape), len(y.shape), self.case["axis"]) - y_t = paddle.unsqueeze( - y, axis=unsqueeze_axis) if len(unsqueeze_axis) > 0 else y - out = paddle.add(x, y_t) - - self.paddle_outputs = [out] - self.paddle_grads = self.get_paddle_grads([out], [x, y], - [self.dout_np]) - - def build_cinn_program(self, target): - builder = NetBuilder("add") - x = builder.create_input( - self.nptype2cinntype(self.case["x_dtype"]), self.case["x_shape"], - "x") - y = builder.create_input( - self.nptype2cinntype(self.case["y_dtype"]), self.case["y_shape"], - "y") - out = builder.add(x, y, axis=self.case["axis"]) - - dout = builder.create_input( - self.nptype2cinntype(self.case["dout_dtype"]), - self.case["dout_shape"], "dout") - x_grad, y_grad = builder.elementwise_add_grad( - dout, x, y, axis=self.case["axis"]) - - prog = builder.build() - res = self.get_cinn_output(prog, target, [x, y, dout], - [self.x_np, self.y_np, self.dout_np], - [out, x_grad, y_grad]) - - self.cinn_outputs = [res[0]] - self.cinn_grads = [res[1], res[2]] - - def test_check_results(self): - max_relative_error = self.case[ - "max_relative_error"] if "max_relative_error" in self.case else 1e-5 - self.check_outputs_and_grads(max_relative_error=max_relative_error) - - -class TestAddAll(TestCaseHelper): - def init_attrs(self): - self.class_name = "TestElementwiseAddOpCase" - self.cls = TestElementwiseAddOp - self.inputs = [ - { - "x_shape": [1], - "y_shape": [1], - "dout_shape": [1], - "axis": 0, - }, - { - "x_shape": [1024], - "y_shape": [1024], - "dout_shape": [1024], - "axis": -1, - }, - { - "x_shape": [512, 256], - "y_shape": [512, 256], - "dout_shape": [512, 256], - "axis": 0, - }, - { - "x_shape": [128, 64, 32], - "y_shape": [128, 64, 32], - "dout_shape": [128, 64, 32], - "axis": -1, - }, - { - "x_shape": [16, 8, 4, 2], - "y_shape": [16, 8, 4, 2], - "dout_shape": [16, 8, 4, 2], - "axis": 0, - }, - { - "x_shape": [16, 8, 4, 2, 1], - "y_shape": [16, 8, 4, 2, 1], - "dout_shape": [16, 8, 4, 2, 1], - "axis": -1, - }, - ] - self.dtypes = [ - # TODO: paddle 2.3.1 unsupport int16 now, remove after ci paddle updated - # { - # "x_dtype": "int16", - # "y_dtype": "int16", - # "dout_dtype": "int16", - # }, - { - "x_dtype": "int32", - "y_dtype": "int32", - "dout_dtype": "int32", - }, 
- { - "x_dtype": "int64", - "y_dtype": "int64", - "dout_dtype": "int64", - }, - { - "x_dtype": "float16", - "y_dtype": "float16", - "dout_dtype": "float16", - "max_relative_error": 1e-3, - }, - { - "x_dtype": "float32", - "y_dtype": "float32", - "dout_dtype": "float32", - }, - { - "x_dtype": "float64", - "y_dtype": "float64", - "dout_dtype": "float64", - }, - { - "x_dtype": "bfloat16", - "y_dtype": "bfloat16", - "dout_dtype": "bfloat16", - "max_relative_error": 1e-2, - }, - ] - self.attrs = [] - - -class TestAddAllWithBroadcast(TestCaseHelper): - def init_attrs(self): - self.class_name = "TestElementwiseAddOpCase" - self.cls = TestElementwiseAddOp - self.inputs = [ - { - "x_shape": [1], - "y_shape": [1], - "dout_shape": [1], - "axis": 0, - }, - { - "x_shape": [1024], - "y_shape": [1], - "dout_shape": [1024], - "axis": -1, - }, - { - "x_shape": [512, 256], - "y_shape": [1, 1], - "dout_shape": [512, 256], - "axis": 0, - }, - { - "x_shape": [128, 64, 32], - "y_shape": [1, 1, 1], - "dout_shape": [128, 64, 32], - "axis": -1, - }, - { - "x_shape": [16, 8, 4, 2], - "y_shape": [1, 1, 1, 1], - "dout_shape": [16, 8, 4, 2], - "axis": 0, - }, - { - "x_shape": [16, 8, 4, 2, 1], - "y_shape": [1, 1, 1, 1, 1], - "dout_shape": [16, 8, 4, 2, 1], - "axis": -1, - }, - ] - self.dtypes = [ - # Todo: Reduce does in support int16 - # { - # "x_dtype": "int16", - # "y_dtype": "int16", - # "dout_dtype": "int16", - # }, - { - "x_dtype": "int32", - "y_dtype": "int32", - "dout_dtype": "int32", - }, - { - "x_dtype": "int64", - "y_dtype": "int64", - "dout_dtype": "int64", - }, - { - "x_dtype": "float16", - "y_dtype": "float16", - "dout_dtype": "float16", - "max_relative_error": 1e-3, - }, - { - "x_dtype": "float32", - "y_dtype": "float32", - "dout_dtype": "float32", - }, - { - "x_dtype": "float64", - "y_dtype": "float64", - "dout_dtype": "float64", - }, - { - "x_dtype": "bfloat16", - "y_dtype": "bfloat16", - "dout_dtype": "bfloat16", - "max_relative_error": 1e-2, - }, - ] - self.attrs = [] - - -if __name__ == "__main__": - TestAddAll().run() - TestAddAllWithBroadcast().run() diff --git a/python/tests/ops/test_arange_op.py b/python/tests/ops/test_arange_op.py new file mode 100644 index 0000000000..2402400bfc --- /dev/null +++ b/python/tests/ops/test_arange_op.py @@ -0,0 +1,191 @@ +#!/usr/bin/env python3 + +# Copyright (c) 2023 CINN Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
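# A rough sketch of the case-expansion pattern the refactored tests above rely
# on: TestCaseHelper appears to combine one entry each from `inputs`, `dtypes`
# and `attrs` into a merged dict that the generated test sees as `self.case`.
# This is only an illustration for orientation; the actual implementation lives
# in python/tests/ops/op_test_helper.py and may differ in detail.
import itertools

def expand_cases(inputs, dtypes, attrs):
    attrs = attrs or [{}]  # an empty attrs list still yields one case per input/dtype pair
    for inp, dt, at in itertools.product(inputs, dtypes, attrs):
        yield {**inp, **dt, **at}

# Example: 2 shapes x 3 dtypes x 1 attr -> 6 generated cases.
cases = list(
    expand_cases(
        inputs=[{"x_shape": [1024]}, {"x_shape": [512, 256]}],
        dtypes=[{"x_dtype": "int32"}, {"x_dtype": "float32"}, {"x_dtype": "float64"}],
        attrs=[{"axis": 0}]))
assert len(cases) == 6 and cases[0] == {"x_shape": [1024], "x_dtype": "int32", "axis": 0}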
+ +import paddle +from cinn.frontend import * +from cinn.common import * +from op_test import OpTest, OpTestTool +from op_test_helper import TestCaseHelper + + +@OpTestTool.skip_if(not is_compiled_with_cuda(), + "x86 test will be skipped due to timeout.") +class TestArangeOp(OpTest): + def setUp(self): + print(f"\nRunning {self.__class__.__name__}: {self.case}") + self.inputs = {} + self.prepare_inputs() + + def prepare_inputs(self): + self.inputs = { + "start": self.case["start"], + "end": self.case["end"], + "step": self.case["step"], + "dtype": self.case["dtype"] + } + + def build_paddle_program(self, target): + out = paddle.arange(self.inputs["start"], self.inputs["end"], + self.inputs["step"], self.inputs["dtype"]) + self.paddle_outputs = [out] + + def build_cinn_program(self, target): + builder = NetBuilder("arange") + out = builder.arange(self.inputs["start"], self.inputs["end"], + self.inputs["step"], self.inputs["dtype"]) + + prog = builder.build() + res = self.get_cinn_output(prog, target, [], [], [out]) + + self.cinn_outputs = res + + def test_check_results(self): + self.check_outputs_and_grads(all_equal=True) + + +class TestArangeOpShapeAndAttr(TestCaseHelper): + def init_attrs(self): + self.class_name = "TestArangeOpShapeAndAttr" + self.cls = TestArangeOp + self.inputs = [ + # basic shape test + { + "start": 0, + "end": 10, + "step": 1, + }, + { + "start": 0, + "end": 1024, + "step": 16, + }, + { + "start": 512, + "end": 2600, + "step": 512, + }, + { + "start": 0, + "end": 65536, + "step": 1024, + }, + { + "start": 0, + "end": 131072, + "step": 2048, + }, + { + "start": 0, + "end": 1, + "step": 2, + }, + { + "start": 0, + "end": 1, + "step": 2, + }, + # step test + { + "start": 1024, + "end": 512, + "step": -2, + }, + { + "start": 2048, + "end": 0, + "step": -64, + }, + # range test + { + "start": -2048, + "end": 2048, + "step": 32, + }, + { + "start": -2048, + "end": -512, + "step": 64, + }, + { + "start": 1024, + "end": 4096, + "step": 512, + }, + { + "start": 1024, + "end": -1024, + "step": -128, + }, + { + "start": -1024, + "end": -2048, + "step": -64, + }, + { + "start": 2048, + "end": 512, + "step": -32, + }, + ] + self.dtypes = [ + { + "dtype": "float32" + }, + ] + self.attrs = [] + + +class TestArangeOpDtype(TestCaseHelper): + def init_attrs(self): + self.class_name = "TestArangeOpDtype" + self.cls = TestArangeOp + self.inputs = [ + { + "start": 5, + "end": 10, + "step": 1, + }, + { + "start": -10, + "end": -100, + "step": -10, + }, + { + "start": -10, + "end": 10, + "step": 1, + }, + ] + self.dtypes = [ + { + "dtype": "int32" + }, + { + "dtype": "int64" + }, + { + "dtype": "float32" + }, + { + "dtype": "float64" + }, + ] + self.attrs = [] + + +if __name__ == "__main__": + TestArangeOpShapeAndAttr().run() + TestArangeOpDtype().run() diff --git a/python/tests/ops/test_cbrt_op.py b/python/tests/ops/test_cbrt_op.py index 14a6385ac9..1ca112cdf9 100644 --- a/python/tests/ops/test_cbrt_op.py +++ b/python/tests/ops/test_cbrt_op.py @@ -14,26 +14,26 @@ # See the License for the specific language governing permissions and # limitations under the License. 
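# For reference when reading the arange cases above: with a half-open interval
# the element count is ceil((end - start) / step), which is what numpy.arange
# produces for integer arguments. A tiny self-check:
import numpy as np

def arange_size(start, end, step):
    return max(0, -((start - end) // step))  # -((start - end) // step) == ceil((end - start) / step)

assert arange_size(0, 10, 1) == len(np.arange(0, 10, 1)) == 10
assert arange_size(1024, 512, -2) == len(np.arange(1024, 512, -2)) == 256
assert arange_size(0, 1, 2) == len(np.arange(0, 1, 2)) == 1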
-import unittest -import numpy as np -from op_test import OpTest, OpTestTool import paddle -import paddle.nn.functional as F -import cinn -from cinn.frontend import * +import numpy as np from cinn.common import * +from cinn.frontend import * +from op_test import OpTest, OpTestTool +from op_test_helper import TestCaseHelper @OpTestTool.skip_if(not is_compiled_with_cuda(), "x86 test will be skipped due to timeout.") class TestCbrtOp(OpTest): def setUp(self): - self.init_case() + print(f"\nRunning {self.__class__.__name__}: {self.case}") + self.inputs = {} + self.prepare_inputs() - def init_case(self): + def prepare_inputs(self): self.inputs = { - "x": np.array([0, 1, 0.01, 27, 1000000, - 0.970299]).astype("float32") + "x": + self.random(self.case["shape"], self.case["dtype"], -100.0, 100.0), } def build_paddle_program(self, target): @@ -43,44 +43,101 @@ def build_paddle_program(self, target): def build_cinn_program(self, target): builder = NetBuilder("cbrt") - x = builder.create_input(Float(32), self.inputs["x"].shape, "x") + x = builder.create_input( + self.nptype2cinntype(self.inputs["x"].dtype), + self.inputs["x"].shape, "x") out = builder.cbrt(x) prog = builder.build() res = self.get_cinn_output(prog, target, [x], [self.inputs["x"]], [out]) - self.cinn_outputs = [res[0]] + self.cinn_outputs = res def test_check_results(self): - self.check_outputs_and_grads() - - -class TestCbrtCase1(TestCbrtOp): - def init_case(self): - self.inputs = { - "x": - np.array([0, 1, 0.01, 27, 1000000, 0.970299, 124483, - 13.7396]).astype("float32") - } - - -class TestCbrtCase2(TestCbrtOp): - def init_case(self): - self.inputs = { - "x": - np.array([[0, 1, 0.01, 27], [1000000, 0.970299, 124483, - 13.7396]]).astype("float32"), - } - - -class TestCbrtCase3(TestCbrtOp): - def init_case(self): - np.random.seed(0) - self.inputs = { - "x": np.random.random((32, 64)).astype("float32"), - } + self.check_outputs_and_grads( + max_relative_error=1e-3, max_absolute_error=1e-3) + + +class TestCbrtOpShape(TestCaseHelper): + def init_attrs(self): + self.class_name = "TestCbrtOpShape" + self.cls = TestCbrtOp + self.inputs = [ + { + "shape": [10], + }, + { + "shape": [8, 5], + }, + { + "shape": [10, 3, 5], + }, + { + "shape": [80, 40, 5, 7], + }, + { + "shape": [80, 1, 5, 7], + }, + { + "shape": [80, 3, 1024, 7], + }, + { + "shape": [10, 5, 1024, 2048], + }, + { + "shape": [1], + }, + { + "shape": [512], + }, + { + "shape": [1024], + }, + { + "shape": [2048], + }, + { + "shape": [1, 1, 1, 1], + }, + ] + self.dtypes = [ + { + "dtype": "float32" + }, + ] + self.attrs = [] + + +class TestCbrtOpDtype(TestCaseHelper): + def init_attrs(self): + self.class_name = "TestCbrtOpDtype" + self.cls = TestCbrtOp + self.inputs = [ + { + "shape": [1], + }, + { + "shape": [5], + }, + { + "shape": [80, 40, 5, 7], + }, + ] + self.dtypes = [ + { + "dtype": "float16" + }, + { + "dtype": "float32" + }, + { + "dtype": "float64" + }, + ] + self.attrs = [] if __name__ == "__main__": - unittest.main() + TestCbrtOpShape().run() + TestCbrtOpDtype().run() diff --git a/python/tests/ops/test_ceil_op.py b/python/tests/ops/test_ceil_op.py index bf5a8189b1..1377849daf 100644 --- a/python/tests/ops/test_ceil_op.py +++ b/python/tests/ops/test_ceil_op.py @@ -14,27 +14,25 @@ # See the License for the specific language governing permissions and # limitations under the License. 
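# The cbrt cases above now sample inputs from [-100.0, 100.0], so negative
# values are part of the test. Worth noting for anyone writing a quick NumPy
# cross-check: np.cbrt handles negatives, while naive exponentiation does not.
import numpy as np

x = np.array([-27.0, -1.0, 0.0, 8.0], dtype="float32")
assert np.allclose(np.cbrt(x), [-3.0, -1.0, 0.0, 2.0])
# x ** (1.0 / 3.0) returns nan for the negative entries; a sign-aware form works:
assert np.allclose(np.sign(x) * np.abs(x) ** (1.0 / 3.0), np.cbrt(x))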
-import unittest -import numpy as np -from op_test import OpTest, OpTestTool import paddle -import cinn -from cinn.frontend import * from cinn.common import * +from cinn.frontend import * +from op_test import OpTest, OpTestTool +from op_test_helper import TestCaseHelper @OpTestTool.skip_if(not is_compiled_with_cuda(), "x86 test will be skipped due to timeout.") class TestCeilOp(OpTest): def setUp(self): - self.init_case() + print(f"\nRunning {self.__class__.__name__}: {self.case}") + self.inputs = {} + self.prepare_inputs() - def init_case(self): + def prepare_inputs(self): self.inputs = { - "x": np.random.random([ - 32, - 64, - ]).astype("float32") * 2 - 1 + "x": + self.random(self.case["shape"], self.case["dtype"], -100.0, 100.0), } def build_paddle_program(self, target): @@ -47,7 +45,9 @@ def build_paddle_program(self, target): # the forward result will be incorrect. def build_cinn_program(self, target): builder = NetBuilder("ceil") - x = builder.create_input(Float(32), self.inputs["x"].shape, "x") + x = builder.create_input( + self.nptype2cinntype(self.inputs["x"].dtype), + self.inputs["x"].shape, "x") out = builder.ceil(x) prog = builder.build() @@ -60,12 +60,85 @@ def test_check_results(self): self.check_outputs_and_grads() -class TestCeilCase1(TestCeilOp): - def init_case(self): - self.inputs = { - "x": np.random.random([10201, 50]).astype("float32") * 3 - 1 - } +class TestCeilOpShape(TestCaseHelper): + def init_attrs(self): + self.class_name = "TestCeilOpShape" + self.cls = TestCeilOp + self.inputs = [ + { + "shape": [10], + }, + { + "shape": [8, 5], + }, + { + "shape": [10, 3, 5], + }, + { + "shape": [80, 40, 5, 7], + }, + { + "shape": [80, 1, 5, 7], + }, + { + "shape": [80, 3, 1024, 7], + }, + { + "shape": [10, 5, 1024, 2048], + }, + { + "shape": [1], + }, + { + "shape": [512], + }, + { + "shape": [1024], + }, + { + "shape": [2048], + }, + { + "shape": [1, 1, 1, 1], + }, + ] + self.dtypes = [ + { + "dtype": "float32" + }, + ] + self.attrs = [] + + +class TestCeilOpDtype(TestCaseHelper): + def init_attrs(self): + self.class_name = "TestCeilOpDtype" + self.cls = TestCeilOp + self.inputs = [ + { + "shape": [1], + }, + { + "shape": [5], + }, + { + "shape": [80, 40, 5, 7], + }, + ] + self.dtypes = [ + { + "dtype": "float16" + }, + { + "dtype": "float32" + }, + { + "dtype": "float64" + }, + ] + self.attrs = [] if __name__ == "__main__": - unittest.main() + TestCeilOpShape().run() + TestCeilOpDtype().run() diff --git a/python/tests/ops/test_cholesky_op.py b/python/tests/ops/test_cholesky_op.py index 5ef0a9df93..d0396e0abb 100644 --- a/python/tests/ops/test_cholesky_op.py +++ b/python/tests/ops/test_cholesky_op.py @@ -14,27 +14,38 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import unittest import numpy as np -from op_test import OpTest, OpTestTool import paddle -import cinn -from cinn.frontend import * from cinn.common import * +from cinn.frontend import * +from op_test import OpTest, OpTestTool +from op_test_helper import TestCaseHelper @OpTestTool.skip_if(not is_compiled_with_cuda(), "x86 test will be skipped due to timeout.") class TestCholeskyOp(OpTest): def setUp(self): - self.init_case() - - def init_case(self): - matrix = self.random([3, 3], "float32") - matrix_t = np.transpose(matrix, [1, 0]) - x = np.dot(matrix, matrix_t) + print(f"\nRunning {self.__class__.__name__}: {self.case}") + self.inputs = {} + self.prepare_inputs() + + def prepare_inputs(self): + if "batch_dim" in self.case and self.case["batch_dim"] > 0: + x = [] + for _ in range(self.case["batch_dim"]): + matrix = self.random(self.case["shape"], self.case["dtype"], + -1.0, 1.0) + matrix_t = np.transpose(matrix, [1, 0]) + x.append(np.dot(matrix, matrix_t)) + x = np.stack(x) + else: + matrix = self.random(self.case["shape"], self.case["dtype"], -1.0, + 1.0) + matrix_t = np.transpose(matrix, [1, 0]) + x = np.dot(matrix, matrix_t) self.inputs = {"x": x} - self.upper = False + self.upper = self.case["upper"] def build_paddle_program(self, target): x = paddle.to_tensor(self.inputs["x"], stop_gradient=False) @@ -56,34 +67,165 @@ def test_check_results(self): self.check_outputs_and_grads() -class TestCholeskyCase1(TestCholeskyOp): - def init_case(self): - matrix = self.random([5, 5], "float64") - matrix_t = np.transpose(matrix, [1, 0]) - x = np.dot(matrix, matrix_t) - self.inputs = {"x": x} - self.upper = True - - -class TestCholeskyCase2(TestCholeskyOp): - def init_case(self): - matrix = self.random([3, 3], "float32") - matrix_t = np.transpose(matrix, [1, 0]) - x = np.dot(matrix, matrix_t) - x = x * np.ones(shape=(3, 3, 3)) - self.inputs = {"x": x} - self.upper = False - - -class TestCholeskyCase3(TestCholeskyOp): - def init_case(self): - matrix = self.random([3, 3], "float64") - matrix_t = np.transpose(matrix, [1, 0]) - x = np.dot(matrix, matrix_t) - x = x * np.ones(shape=(2, 3, 3, 3)) - self.inputs = {"x": x} - self.upper = True +class TestCholeskyOpShape(TestCaseHelper): + def init_attrs(self): + self.class_name = "TestCholeskyOpShape" + self.cls = TestCholeskyOp + self.inputs = [ + { + "shape": [1, 1], + }, + { + "shape": [8, 8], + }, + { + "shape": [10, 10], + }, + ] + self.dtypes = [ + { + "dtype": "float32" + }, + ] + self.attrs = [ + { + "upper": False + }, + ] + + +class TestCholeskyOpLargeShape(TestCaseHelper): + def init_attrs(self): + self.class_name = "TestCholeskyOpLargeShape" + self.cls = TestCholeskyOp + self.inputs = [ + { + "shape": [1024, 1024], + }, + { + "shape": [2048, 2048], + }, + ] + self.dtypes = [ + { + "dtype": "float64" + }, + ] + self.attrs = [ + { + "upper": False, + "batch_dim": 2 + }, + { + "upper": False, + "batch_dim": 4 + }, + { + "upper": True, + "batch_dim": 8 + }, + ] + + +class TestCholeskyOpDtype(TestCaseHelper): + def init_attrs(self): + self.class_name = "TestCholeskyOpDtype" + self.cls = TestCholeskyOp + self.inputs = [ + { + "shape": [1, 1], + }, + { + "shape": [8, 8], + }, + { + "shape": [10, 10], + }, + ] + self.dtypes = [ + { + "dtype": "float32" + }, + { + "dtype": "float64" + }, + ] + self.attrs = [ + { + "upper": False + }, + ] + + +class TestCholeskyOpBatch(TestCaseHelper): + def init_attrs(self): + self.class_name = "TestCholeskyOpBatch" + self.cls = TestCholeskyOp + self.inputs = [ + { + "shape": [1, 1], + }, + { + "shape": [8, 8], + }, + { + 
"shape": [10, 10], + }, + ] + self.dtypes = [ + { + "dtype": "float32" + }, + ] + self.attrs = [ + { + "upper": False, + "batch_dim": 1 + }, + { + "upper": False, + "batch_dim": 4 + }, + { + "upper": False, + "batch_dim": 8 + }, + ] + + +class TestCholeskyOpAttrs(TestCaseHelper): + def init_attrs(self): + self.class_name = "TestCholeskyOpAttrs" + self.cls = TestCholeskyOp + self.inputs = [ + { + "shape": [1, 1], + }, + { + "shape": [8, 8], + }, + { + "shape": [10, 10], + }, + ] + self.dtypes = [ + { + "dtype": "float32" + }, + { + "dtype": "float64" + }, + ] + self.attrs = [ + { + "upper": True, + }, + ] if __name__ == "__main__": - unittest.main() + TestCholeskyOpShape().run() + TestCholeskyOpLargeShape().run() + TestCholeskyOpDtype().run() + TestCholeskyOpBatch().run() + TestCholeskyOpAttrs().run() diff --git a/python/tests/ops/test_concat_op.py b/python/tests/ops/test_concat_op.py index c1d68a45a6..23816e0774 100755 --- a/python/tests/ops/test_concat_op.py +++ b/python/tests/ops/test_concat_op.py @@ -14,27 +14,29 @@ # See the License for the specific language governing permissions and # limitations under the License. -import unittest -import numpy as np -from op_test import OpTest, OpTestTool import paddle -import cinn -from cinn.frontend import * from cinn.common import * +from cinn.frontend import * +from op_test import OpTest, OpTestTool +from op_test_helper import TestCaseHelper @OpTestTool.skip_if(not is_compiled_with_cuda(), "x86 test will be skipped due to timeout.") class TestConcatOp(OpTest): def setUp(self): - self.init_case() + print(f"\nRunning {self.__class__.__name__}: {self.case}") + self.inputs = {} + self.prepare_inputs() - def init_case(self): - self.inputs = { - "x1": np.random.random([10201, 50]).astype("float32"), - "x2": np.random.random((10201, 50)).astype("float32") - } - self.axis = 0 + def prepare_inputs(self): + self.inputs = {} + self.axis = self.case["axis"] + dtype = self.case["dtype"] + shapes = self.case["shapes"] + for i, shape in enumerate(shapes): + name = "x" + str(i) + self.inputs[name] = self.random(shape, dtype) def paddle_inputs(self, inputs): return [ @@ -44,7 +46,8 @@ def paddle_inputs(self, inputs): def cinn_inputs(self, builder, inputs): return [ - builder.create_input(Float(32), data.shape, name) + builder.create_input( + self.nptype2cinntype(data.dtype), data.shape, name) for name, data in inputs.items() ] @@ -73,44 +76,287 @@ def test_check_results(self): self.check_outputs_and_grads(all_equal=True) -class TestConcatCase1(TestConcatOp): - def init_case(self): - self.inputs = { - "x1": np.random.random([4, 3]).astype("float32"), - "x2": np.random.random((8, 3)).astype("float32") - } - self.axis = 0 +class TestConcatOpShape(TestCaseHelper): + def init_attrs(self): + self.class_name = "TestConcatOpShape" + self.cls = TestConcatOp + self.inputs = [ + { + "shapes": [[10], [6]], + }, + { + "shapes": [[8, 5], [8, 5]], + }, + { + "shapes": [[10, 3, 5], [4, 3, 5]], + }, + { + "shapes": [[80, 40, 5, 7], [20, 40, 5, 7]], + }, + { + "shapes": [[80, 1, 5, 7], [8, 1, 5, 7]], + }, + { + "shapes": [[80, 3, 1024, 7], [100, 3, 1024, 7]], + }, + { + "shapes": [[1, 5, 1024, 2048], [2, 5, 1024, 2048]], + }, + { + "shapes": [[1], [1]], + }, + { + "shapes": [[512], [512]], + }, + { + "shapes": [[1024], [512]], + }, + { + "shapes": [[2048], [4096]], + }, + { + "shapes": [[1, 1, 1, 1], [1, 1, 1, 1]], + }, + ] + self.dtypes = [ + { + "dtype": "float32" + }, + ] + self.attrs = [ + { + "axis": 0 + }, + ] -class TestConcatCase2(TestConcatOp): - def init_case(self): - 
self.inputs = { - "x1": np.random.random([2, 4, 8]).astype("float32"), - "x2": np.random.random((2, 4, 4)).astype("float32") - } - self.axis = -1 +class TestConcatOpDtype(TestCaseHelper): + def init_attrs(self): + self.class_name = "TestConcatOpDtype" + self.cls = TestConcatOp + self.inputs = [ + { + "shapes": [[10], [6]], + }, + { + "shapes": [[8, 5], [8, 5]], + }, + { + "shapes": [[10, 3, 5], [4, 3, 5]], + }, + { + "shapes": [[80, 40, 5, 7], [20, 40, 5, 7]], + }, + ] + self.dtypes = [ + { + "dtype": "float16" + }, + { + "dtype": "float32" + }, + { + "dtype": "float64" + }, + { + "dtype": "bool" + }, + { + "dtype": "uint8" + }, + { + "dtype": "int8" + }, + { + "dtype": "int32" + }, + { + "dtype": "int64" + }, + ] + self.attrs = [ + { + "axis": 0 + }, + ] -class TestConcatCase3(TestConcatOp): - def init_case(self): - self.inputs = { - "x1": np.random.random([2, 8, 4]).astype("float32"), - "x2": np.random.random((2, 4, 4)).astype("float32") - } - self.axis = 1 +class TestConcatOpMultipleInputs(TestCaseHelper): + def init_attrs(self): + self.class_name = "TestConcatOpMultipleInputs" + self.cls = TestConcatOp + self.inputs = [ + # 1D tensor with 1~4 inputs + { + "shapes": [[10]], + "axis": 0 + }, + { + "shapes": [[10], [6]], + "axis": 0 + }, + { + "shapes": [[10], [6], [8]], + "axis": 0 + }, + { + "shapes": [[10], [6], [10], [6]], + "axis": 0 + }, + # 2D tensor with 1~4 inputs + { + "shapes": [[8, 5]], + "axis": 1 + }, + { + "shapes": [[8, 5], [8, 8]], + "axis": 1 + }, + { + "shapes": [[8, 5], [8, 5], [16, 5]], + "axis": 0 + }, + { + "shapes": [[8, 5], [8, 5], [8, 5], [8, 5]], + "axis": 0 + }, + # 3D tensor with 1~4 inputs + { + "shapes": [[10, 3, 5]], + "axis": 0 + }, + { + "shapes": [[10, 3, 5], [10, 7, 5]], + "axis": 1 + }, + { + "shapes": [[10, 3, 5], [10, 3, 6], [10, 3, 7]], + "axis": 2 + }, + { + "shapes": [[10, 3, 5], [4, 3, 5], [2, 3, 5]], + "axis": 0 + }, + # 4D tensor with 1~4 inputs + { + "shapes": [[80, 1, 5, 7]], + "axis": 0 + }, + { + "shapes": [[80, 1, 5, 7], [80, 79, 5, 7]], + "axis": 1 + }, + { + "shapes": [[80, 1, 50, 7], [80, 1, 5, 7], [80, 1, 10, 7]], + "axis": 2 + }, + { + "shapes": [[80, 1, 5, 17], [80, 1, 5, 27], [80, 1, 5, 37], + [80, 1, 5, 47]], + "axis": + 3 + }, + ] + self.dtypes = [ + { + "dtype": "float32" + }, + ] + self.attrs = [] -class TestConcatCase5(TestConcatOp): - def init_case(self): - self.inputs = { - "x1": np.random.random([1, 16]).astype("float32"), - "x2": np.random.random([2, 16]).astype("float32"), - "x3": np.random.random([3, 16]).astype("float32"), - "x4": np.random.random([4, 16]).astype("float32"), - "x5": np.random.random([5, 16]).astype("float32") - } - self.axis = 0 +class TestConcatOpAttrs(TestCaseHelper): + def init_attrs(self): + self.class_name = "TestConcatOpAttrs" + self.cls = TestConcatOp + self.inputs = [ + # 1D tensor + { + "shapes": [[10], [8]], + "axis": 0 + }, + { + "shapes": [[10], [6]], + "axis": -1 + }, + # 2D tensor + { + "shapes": [[8, 5], [10, 5]], + "axis": 0 + }, + { + "shapes": [[8, 5], [8, 8]], + "axis": 1 + }, + # 3D tensor + { + "shapes": [[10, 3, 5], [10, 3, 5]], + "axis": 0 + }, + { + "shapes": [[10, 3, 5], [10, 7, 5]], + "axis": 1 + }, + { + "shapes": [[10, 3, 15], [10, 3, 5]], + "axis": 2 + }, + { + "shapes": [[10, 3, 7], [10, 3, 5]], + "axis": -1 + }, + { + "shapes": [[10, 3, 5], [10, 7, 5]], + "axis": -2 + }, + { + "shapes": [[10, 7, 5], [20, 7, 5]], + "axis": -3 + }, + # 4D tensor + { + "shapes": [[80, 1, 5, 7], [80, 1, 5, 7]], + "axis": 0 + }, + { + "shapes": [[80, 1, 5, 7], [80, 79, 5, 7]], + "axis": 1 + }, + 
{ + "shapes": [[80, 1, 5, 7], [80, 1, 10, 7]], + "axis": 2 + }, + { + "shapes": [[80, 1, 5, 7], [80, 1, 5, 7]], + "axis": 3 + }, + { + "shapes": [[80, 1, 5, 7], [80, 1, 5, 13]], + "axis": -1 + }, + { + "shapes": [[80, 1, 5, 7], [80, 1, 5, 7]], + "axis": -2 + }, + { + "shapes": [[80, 15, 5, 7], [80, 5, 5, 7]], + "axis": -3 + }, + { + "shapes": [[80, 1, 5, 7], [20, 1, 5, 7]], + "axis": -4 + }, + ] + self.dtypes = [ + { + "dtype": "float32" + }, + ] + self.attrs = [] if __name__ == "__main__": - unittest.main() + TestConcatOpShape().run() + TestConcatOpDtype().run() + TestConcatOpMultipleInputs().run() + TestConcatOpAttrs().run() diff --git a/python/tests/ops/test_constant_op.py b/python/tests/ops/test_constant_op.py index 6407d58512..424aa7c56a 100644 --- a/python/tests/ops/test_constant_op.py +++ b/python/tests/ops/test_constant_op.py @@ -14,99 +14,150 @@ # See the License for the specific language governing permissions and # limitations under the License. -import unittest -import numpy as np -from op_test import OpTest, OpTestTool import paddle -import paddle.nn.functional as F -import cinn -from cinn.frontend import * from cinn.common import * +from cinn.frontend import * +from op_test import OpTest, OpTestTool +from op_test_helper import TestCaseHelper @OpTestTool.skip_if(not is_compiled_with_cuda(), "x86 test will be skipped due to timeout.") class TestConstantOp(OpTest): def setUp(self): - self.init_case() - - def init_case(self): - self.value = 1.0 - self.name = 'x' - self.dtype = "float32" + print(f"\nRunning {self.__class__.__name__}: {self.case}") + self.inputs = {} + self.prepare_inputs() + + def prepare_inputs(self): + self.name = "x" + dtype = self.case["dtype"] + if "constant_value" in self.case: + if "bool" in dtype: + self.value = bool(self.case["constant_value"]) + elif "int" in dtype: + self.value = int(self.case["constant_value"]) + elif "float" in dtype: + self.value = float(self.case["constant_value"]) + else: + self.value = self.random(self.case["shape"], dtype).tolist() + self.dtype = dtype def build_paddle_program(self, target): x = paddle.to_tensor(self.value, dtype=self.dtype) - self.paddle_outputs = [x] def build_cinn_program(self, target): builder = NetBuilder("constant") x = builder.constant(self.value, self.name, self.dtype) - prog = builder.build() res = self.get_cinn_output(prog, target, [], [], [x]) - self.cinn_outputs = [res[0]] + self.cinn_outputs = res def test_check_results(self): self.check_outputs_and_grads(all_equal=True) -class TestConstantCase1(TestConstantOp): - def init_case(self): - self.value = [1.0] - self.name = 'x' - self.dtype = "float32" - - -class TestConstantCase2(TestConstantOp): - def init_case(self): - self.value = [1.0, 2.0, 3.0, 4.0, 5.0] - self.name = 'x' - self.dtype = "float32" - - -class TestConstantCase3(TestConstantOp): - def init_case(self): - self.value = [[1.0, 2.0], [3.0, 4.0]] - self.name = 'x' - self.dtype = "float32" - - -class TestConstantCase4(TestConstantOp): - def init_case(self): - self.value = [[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]] - self.name = 'x' - self.dtype = "float32" - - -class TestConstantCase5(TestConstantOp): - def init_case(self): - self.value = [[[1.0], [3.0]], [[5.0], [7.0]]] - self.name = 'x' - self.dtype = "float32" - - -class TestConstantCase6(TestConstantOp): - def init_case(self): - self.value = [[[1.0]]] - self.name = 'x' - self.dtype = "float32" - - -class TestConstantCase7(TestConstantOp): - def init_case(self): - self.value = self.random([200], "int32", 1, 1000).tolist() - 
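# A quick sanity check of the shape rule the concat cases above exercise:
# inputs must agree on every dimension except the concat axis, and the output
# size along that axis is the sum of the inputs (negative axes count from the
# end). NumPy is used here purely for illustration:
import numpy as np

a = np.ones((80, 1, 5, 7), dtype="float32")
b = np.ones((80, 79, 5, 7), dtype="float32")
assert np.concatenate([a, b], axis=1).shape == (80, 80, 5, 7)
# For these 4-D inputs, axis=-3 addresses the same dimension as axis=1.
assert np.concatenate([a, b], axis=-3).shape == (80, 80, 5, 7)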
self.name = 'x' - self.dtype = "int32" - - -class TestConstantCase8(TestConstantOp): - def init_case(self): - self.value = self.random([10], "int64", 1, 1000).tolist() - self.name = 'x' - self.dtype = "int64" +class TestConstantOpShape(TestCaseHelper): + def init_attrs(self): + self.class_name = "TestConstantOpShape" + self.cls = TestConstantOp + self.inputs = [ + { + "constant_value": 10, + }, + { + "constant_value": -5, + }, + { + "shape": [10], + }, + { + "shape": [8, 5], + }, + { + "shape": [10, 3, 5], + }, + { + "shape": [1, 2, 4, 8], + }, + # known issue: https://github.com/PaddlePaddle/CINN/pull/1453 + # The compilation time is particularly long for AssignValue op. + # { + # "shape": [16, 4, 8, 32], + # }, + { + "shape": [1], + }, + { + "shape": [512], + }, + { + "shape": [1024], + }, + # very slow for the shape 2048 + { + "shape": [2048], + }, + { + "shape": [1, 1, 1, 1], + }, + ] + self.dtypes = [ + { + "dtype": "float32" + }, + ] + self.attrs = [] + + +class TestConstantOpDtype(TestCaseHelper): + def init_attrs(self): + self.class_name = "TestConstantOpDtype" + self.cls = TestConstantOp + self.inputs = [ + { + "constant_value": 1, + }, + { + "shape": [10], + }, + { + "shape": [8, 5], + }, + { + "shape": [10, 3, 5], + }, + ] + self.dtypes = [ + { + "dtype": "float16" + }, + { + "dtype": "float32" + }, + { + "dtype": "float64" + }, + { + "dtype": "bool" + }, + { + "dtype": "uint8" + }, + { + "dtype": "int8" + }, + { + "dtype": "int32" + }, + { + "dtype": "int64" + }, + ] + self.attrs = [] if __name__ == "__main__": - unittest.main() + TestConstantOpShape().run() + TestConstantOpDtype().run() diff --git a/python/tests/ops/test_depthwise_conv2d_op.py b/python/tests/ops/test_depthwise_conv2d_op.py new file mode 100644 index 0000000000..60a5956dbc --- /dev/null +++ b/python/tests/ops/test_depthwise_conv2d_op.py @@ -0,0 +1,192 @@ +# Copyright (c) 2023 CINN Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
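# The constant cases above pass either a scalar "constant_value" or a nested
# list from self.random(...).tolist(). A standalone sketch of the scalar
# casting rule used in prepare_inputs, assuming every dtype string contains
# one of "bool", "int" or "float" (as the dtype tables above do):
def cast_constant(value, dtype):
    if "bool" in dtype:
        return bool(value)
    if "int" in dtype:  # matches int8/int32/int64 and also uint8
        return int(value)
    if "float" in dtype:
        return float(value)
    raise ValueError(f"unexpected dtype string: {dtype}")

assert cast_constant(10, "float32") == 10.0
assert cast_constant(1, "uint8") == 1
assert cast_constant(0, "bool") is False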
+ +from op_test import OpTest, OpTestTool +from op_test_helper import TestCaseHelper +import paddle +import paddle.nn as nn +import cinn +from cinn.frontend import * +from cinn.common import * + + +@OpTestTool.skip_if(not is_compiled_with_cudnn(), + "x86 test will be skipped due to timeout.") +class TestDepthwiseConv2dOp(OpTest): + def setUp(self): + # print(f"\n{self.__class__.__name__}: {self.case}") + self.prepare_inputs() + + def prepare_inputs(self): + self.x_np = self.random( + shape=self.case["x_shape"], dtype=self.case["dtype"]) + self.w_np = self.random( + shape=self.case["w_shape"], dtype=self.case["dtype"]) + + def build_paddle_program(self, target): + x = paddle.to_tensor(self.x_np, stop_gradient=False) + weight = nn.initializer.Assign(self.w_np) + if self.case["data_format"] == "NCHW": + c_axis = 1 + elif self.case["data_format"] == "NHWC": + c_axis = 3 + else: + raise ValueError("Unknown data_format") + conv = nn.Conv2D( + in_channels=self.case["x_shape"][c_axis], + out_channels=self.case["x_shape"][c_axis], + kernel_size=self.case["kernel_size"], + stride=self.case["stride"], + padding=self.case["padding"], + dilation=self.case["dilation"], + groups=self.case["groups"], + weight_attr=weight, + bias_attr=False, + data_format=self.case["data_format"]) + y = conv(x) + self.paddle_outputs = [y] + + def build_cinn_program(self, target): + builder = NetBuilder("depthwise_conv2d") + x = builder.create_input( + self.nptype2cinntype(self.case["dtype"]), self.case["x_shape"], + "x") + weight = builder.create_input( + self.nptype2cinntype(self.case["dtype"]), self.case["w_shape"], + "weight") + + if self.case["data_format"] == "NCHW": + y = builder.depthwise_conv2d( + x, + weight, + strides=self.case["stride"], + paddings=self.case["padding"], + dilations=self.case["dilation"], + groups=self.case["groups"], + data_format=self.case["data_format"]) + elif self.case["data_format"] == "NHWC": + weight_t = builder.transpose(weight, [0, 2, 3, 1]) + y = builder.depthwise_conv2d( + x, + weight_t, + strides=self.case["stride"], + paddings=self.case["padding"], + dilations=self.case["dilation"], + groups=self.case["groups"], + data_format=self.case["data_format"]) + + prog = builder.build() + res = self.get_cinn_output( + prog, target, [x, weight], [self.x_np, self.w_np], [y], passes=[]) + self.cinn_outputs = res + + def test_check_results(self): + max_relative_error = self.case[ + "max_relative_error"] if "max_relative_error" in self.case else 1e-5 + self.check_outputs_and_grads(max_relative_error=max_relative_error) + + +class TestDepthwiseConv2dOpShape(TestCaseHelper): + def init_attrs(self): + self.class_name = "TestDepthwiseConv2dCase" + self.cls = TestDepthwiseConv2dOp + self.inputs = [ + { + "x_shape": [3, 16, 32, 32], + "w_shape": [16, 1, 3, 3], + "data_format": "NCHW", + "groups": 16, + }, + { + "x_shape": [3, 16, 64, 64], + "w_shape": [16, 1, 3, 3], + "data_format": "NCHW", + "groups": 16, + }, + { + "x_shape": [3, 32, 32, 16], + "w_shape": [16, 1, 3, 3], + "data_format": "NHWC", + "groups": 16, + }, + { + "x_shape": [3, 64, 64, 16], + "w_shape": [16, 1, 3, 3], + "data_format": "NHWC", + "groups": 16, + }, + ] + self.dtypes = [ + { + "dtype": "float32", + }, + ] + self.attrs = [ + { + "kernel_size": [3, 3], + "stride": [1, 1], + "padding": [0, 0], + "dilation": [1, 1], + }, + ] + + +class TestDepthwiseConv2dOpAttr(TestCaseHelper): + def init_attrs(self): + self.class_name = "TestDepthwiseConv2dCase" + self.cls = TestDepthwiseConv2dOp + self.inputs = [ + { + "x_shape": [3, 16, 32, 32], 
+ "w_shape": [16, 1, 3, 3], + "data_format": "NCHW", + "groups": 16, + }, + ] + self.dtypes = [ + { + "dtype": "float32", + }, + ] + self.attrs = [ + { + "kernel_size": [5, 5], + "stride": [1, 1], + "padding": [0, 0], + "dilation": [1, 1], + }, + { + "kernel_size": [3, 3], + "stride": [2, 2], + "padding": [0, 0], + "dilation": [1, 1], + }, + { + "kernel_size": [3, 3], + "stride": [1, 1], + "padding": [1, 1], + "dilation": [1, 1], + }, + { + "kernel_size": [3, 3], + "stride": [1, 1], + "padding": [0, 0], + "dilation": [2, 2], + }, + ] + + +if __name__ == "__main__": + TestDepthwiseConv2dOpShape().run() + TestDepthwiseConv2dOpAttr().run() diff --git a/python/tests/ops/test_fill_constant_op.py b/python/tests/ops/test_fill_constant_op.py index a42fdea4b9..8f8a1bcb4b 100644 --- a/python/tests/ops/test_fill_constant_op.py +++ b/python/tests/ops/test_fill_constant_op.py @@ -14,145 +14,246 @@ # See the License for the specific language governing permissions and # limitations under the License. -import unittest -import numpy as np -from op_test import OpTest, OpTestTool import paddle -import paddle.nn.functional as F -import cinn -from cinn.frontend import * +import numpy as np from cinn.common import * +from cinn.frontend import * +from op_test import OpTest, OpTestTool +from op_test_helper import TestCaseHelper @OpTestTool.skip_if(not is_compiled_with_cuda(), "x86 test will be skipped due to timeout.") class TestFillConstantOp(OpTest): def setUp(self): - self.init_case() - - def init_case(self): - self.shape = [32] - self.value = 1.0 - self.dtype = "float32" + print(f"\nRunning {self.__class__.__name__}: {self.case}") + self.inputs = {} + self.prepare_inputs() + + def prepare_inputs(self): + self.shape = self.case["shape"] + self.value = self.case["value"] + self.dtype = self.case["dtype"] + if isinstance(self.value, str): + dtypes = ["bool", "int", "float"] + for dtype in dtypes: + if dtype in self.dtype: + try: + self.value = eval(f"{dtype}(self.value)") + except: + self.value = eval(f"{dtype}(0)") def build_paddle_program(self, target): - x = paddle.full(self.shape, self.value, dtype=self.dtype) + if self.dtype == None: + x = np.full(self.shape, self.value) + x = paddle.to_tensor(x) + else: + x = paddle.full(self.shape, self.value, dtype=self.dtype) self.paddle_outputs = [x] def build_cinn_program(self, target): builder = NetBuilder("fill_constant") - x = builder.fill_constant(self.shape, self.value, "out", self.dtype) + if self.dtype == None: + x = builder.fill_constant(self.shape, self.value, "out") + else: + x = builder.fill_constant(self.shape, self.value, "out", + self.dtype) prog = builder.build() res = self.get_cinn_output(prog, target, [], [], [x]) - self.cinn_outputs = [res[0]] + self.cinn_outputs = res def test_check_results(self): self.check_outputs_and_grads(all_equal=True) -class TestFillConstantCase1(TestFillConstantOp): - def init_case(self): - self.shape = [10, 32, 4] - self.value = 1.0 - self.dtype = "float32" - - -class TestFillConstantCase2(TestFillConstantOp): - def init_case(self): - self.shape = [32] - self.value = 1 - self.dtype = "int32" - - -class TestFillConstantCase3(TestFillConstantOp): - def init_case(self): - self.shape = [32] - self.value = True - self.dtype = "bool" - - -class TestFillConstantCase4(TestFillConstantOp): - def init_case(self): - self.shape = [32] - self.value = int(1) - self.dtype = "uint8" - - -class TestFillConstantCase5(TestFillConstantOp): - def init_case(self): - self.shape = [32] - self.value = int(1) - self.dtype = "int16" - - -class 
TestFillConstantStringValue(TestFillConstantOp): - def init_case(self): - self.shape = [32] - self.value = "0.12345678987654321" - self.dtype = "float64" - - -class TestFillConstantStringValueCase1(TestFillConstantStringValue): - def init_case(self): - self.shape = [32] - self.value = "0.12345678987654321" - self.dtype = "float16" - - -class TestFillConstantStringValueCase2(TestFillConstantStringValue): - def init_case(self): - self.shape = [32] - self.value = "123456789" - self.dtype = "int64" - - -@OpTestTool.skip_if(not is_compiled_with_cuda(), - "x86 test will be skipped due to timeout.") -class TestFillConstantByValueOp(OpTest): - def setUp(self): - self.init_case() - - def init_case(self): - self.shape = [32] - self.value = float(1.0) - self.dtype = "float32" - - def build_paddle_program(self, target): - x = paddle.full(self.shape, self.value, dtype=self.dtype) - - self.paddle_outputs = [x] - - def build_cinn_program(self, target): - builder = NetBuilder("fill_constant") - x = builder.fill_constant(self.shape, self.value, "out") - - prog = builder.build() - res = self.get_cinn_output(prog, target, [], [], [x]) - - self.cinn_outputs = [res[0]] - - def test_check_results(self): - self.check_outputs_and_grads(all_equal=True) - - -class TestFillConstantByValueCase1(TestFillConstantByValueOp): - def init_case(self): - self.shape = [32] - self.value = int(1) - # only for paddle.full - self.dtype = "int32" - - -class TestFillConstantByValueCase2(TestFillConstantByValueOp): - def init_case(self): - self.shape = [32] - self.value = bool(True) - # only for paddle.full - self.dtype = "bool" +class TestFillConstantOpShape(TestCaseHelper): + def init_attrs(self): + self.class_name = "TestFillConstantOpShape" + self.cls = TestFillConstantOp + self.inputs = [ + { + "shape": [10], + }, + { + "shape": [8, 5], + }, + { + "shape": [10, 3, 5], + }, + { + "shape": [1, 2, 4, 8], + }, + { + "shape": [16, 4, 8, 32], + }, + { + "shape": [1], + }, + { + "shape": [512], + }, + { + "shape": [1024], + }, + { + "shape": [2048], + }, + { + "shape": [1, 1, 1, 1], + }, + ] + self.dtypes = [ + { + "dtype": "float32" + }, + ] + self.attrs = [ + { + "value": 123.456 + }, + ] + + +class TestFillConstantOpDtype(TestCaseHelper): + def init_attrs(self): + self.class_name = "TestFillConstantOpDtype" + self.cls = TestFillConstantOp + self.inputs = [ + { + "shape": [10], + }, + { + "shape": [8, 5], + }, + { + "shape": [10, 3, 5], + }, + { + "shape": [1, 2, 4, 8], + }, + ] + self.dtypes = [ + { + "dtype": "float16" + }, + { + "dtype": "float32" + }, + { + "dtype": "float64" + }, + { + "dtype": "bool" + }, + { + "dtype": "uint8" + }, + { + "dtype": "int32" + }, + { + "dtype": "int64" + }, + ] + self.attrs = [ + { + "value": 123.456 + }, + ] + + +class TestFillConstantOpValue(TestCaseHelper): + def init_attrs(self): + self.class_name = "TestFillConstantOpValue" + self.cls = TestFillConstantOp + self.inputs = [ + { + "shape": [10], + }, + { + "shape": [8, 5], + }, + { + "shape": [10, 3, 5], + }, + { + "shape": [1, 2, 4, 8], + }, + ] + self.dtypes = [ + { + "dtype": None + }, + ] + self.attrs = [ + { + "value": bool(True) + }, + { + "value": int(123) + }, + { + "value": float(123.456) + }, + ] + + +class TestFillConstantOpStrValue(TestCaseHelper): + def init_attrs(self): + self.class_name = "TestFillConstantOpStrValue" + self.cls = TestFillConstantOp + self.inputs = [ + { + "shape": [10], + }, + { + "shape": [8, 5], + }, + { + "shape": [10, 3, 5], + }, + { + "shape": [1, 2, 4, 8], + }, + ] + self.dtypes = [ + { + "dtype": 
"float16" + }, + { + "dtype": "float32" + }, + { + "dtype": "float64" + }, + { + "dtype": "bool" + }, + { + "dtype": "uint8" + }, + { + "dtype": "int32" + }, + { + "dtype": "int64" + }, + ] + self.attrs = [ + { + "value": "1024" + }, + { + "value": "0.12345678987654321" + }, + ] if __name__ == "__main__": - unittest.main() + TestFillConstantOpShape().run() + TestFillConstantOpDtype().run() + TestFillConstantOpValue().run() + TestFillConstantOpStrValue().run() diff --git a/python/tests/ops/test_floor_divide_op.py b/python/tests/ops/test_floor_divide_op.py index b5f89a6580..996262fab4 100644 --- a/python/tests/ops/test_floor_divide_op.py +++ b/python/tests/ops/test_floor_divide_op.py @@ -36,13 +36,13 @@ def init_case(self): self.x_np = self.random( shape=self.case["x_shape"], dtype=self.case["x_dtype"], - low=-10, - high=10) + low=self.case["x_low"], + high=self.case["x_high"]) self.y_np = self.random( shape=self.case["y_shape"], dtype=self.case["y_dtype"], - low=1, - high=10) + low=self.case["y_low"], + high=self.case["y_high"]) def build_paddle_program(self, target): x = paddle.to_tensor(self.x_np, stop_gradient=True) @@ -66,7 +66,7 @@ def build_cinn_program(self, target): res = self.get_cinn_output(prog, target, [x, y], [self.x_np, self.y_np], [out]) - self.cinn_outputs = [res[0]] + self.cinn_outputs = res def test_check_results(self): max_relative_error = self.case[ @@ -74,7 +74,7 @@ def test_check_results(self): self.check_outputs_and_grads(max_relative_error=max_relative_error) -class TestFloorDivideAll(TestCaseHelper): +class TestFloorDivideShape(TestCaseHelper): def init_attrs(self): self.class_name = "TestFloorDivideOpCase" self.cls = TestFloorDivideOp @@ -109,18 +109,26 @@ def init_attrs(self): "x_dtype": "int32", "y_dtype": "int32", }, + ] + self.attrs = [ { - "x_dtype": "int64", - "y_dtype": "int64", + "x_low": -10, + "x_high": 10, + "y_low": -10, + "y_high": -1, + }, + { + "x_low": -10, + "x_high": 10, + "y_low": 1, + "y_high": 10, }, ] - self.attrs = [] -class TestFloorDivideAllWithBroadcast(TestCaseHelper): +class TestFloorDivideBroadcast(TestFloorDivideShape): def init_attrs(self): - self.class_name = "TestFloorDivideOpCase" - self.cls = TestFloorDivideOp + super().init_attrs() self.inputs = [ { "x_shape": [1], @@ -147,97 +155,26 @@ def init_attrs(self): "y_shape": [1, 1, 1, 1, 1], }, ] - self.dtypes = [ - { - "x_dtype": "int32", - "y_dtype": "int32", - }, - { - "x_dtype": "int64", - "y_dtype": "int64", - }, - ] - self.attrs = [] - - -class TestFloorDivideNegOp(OpTest): - def setUp(self): - print(f"\nRunning {self.__class__.__name__}: {self.case}") - self.init_case() - def init_case(self): - self.x_np = self.random( - shape=self.case["x_shape"], - dtype=self.case["x_dtype"], - low=-10, - high=10) - self.y_np = self.random( - shape=self.case["y_shape"], - dtype=self.case["y_dtype"], - low=-10, - high=-1) - - def build_paddle_program(self, target): - x = paddle.to_tensor(self.x_np, stop_gradient=True) - y = paddle.to_tensor(self.y_np, stop_gradient=True) - - out = paddle.floor_divide(x, y) - - self.paddle_outputs = [out] - - def build_cinn_program(self, target): - builder = NetBuilder("pow") - x = builder.create_input( - self.nptype2cinntype(self.case["x_dtype"]), self.case["x_shape"], - "x") - y = builder.create_input( - self.nptype2cinntype(self.case["y_dtype"]), self.case["y_shape"], - "y") - out = builder.floor_divide(x, y) - - prog = builder.build() - res = self.get_cinn_output(prog, target, [x, y], - [self.x_np, self.y_np], [out]) - - self.cinn_outputs = [res[0]] - - def 
test_check_results(self): - max_relative_error = self.case[ - "max_relative_error"] if "max_relative_error" in self.case else 1e-5 - self.check_outputs_and_grads(max_relative_error=max_relative_error) - -class TestFloorDivideNegAll(TestCaseHelper): +class TestFloorDivideDtype(TestFloorDivideShape): def init_attrs(self): - self.class_name = "TestFloorDivideNegOpCase" - self.cls = TestFloorDivideNegOp + super().init_attrs() self.inputs = [ - { - "x_shape": [1], - "y_shape": [1], - }, { "x_shape": [1024], "y_shape": [1024], }, + ] + self.dtypes = [ { - "x_shape": [512, 256], - "y_shape": [512, 256], - }, - { - "x_shape": [128, 64, 32], - "y_shape": [128, 64, 32], - }, - { - "x_shape": [16, 8, 4, 2], - "y_shape": [16, 8, 4, 2], + "x_dtype": "int8", + "y_dtype": "int8", }, { - "x_shape": [16, 8, 4, 2, 1], - "y_shape": [16, 8, 4, 2, 1], + "x_dtype": "int16", + "y_dtype": "int16", }, - ] - self.dtypes = [ { "x_dtype": "int32", "y_dtype": "int32", @@ -246,55 +183,50 @@ def init_attrs(self): "x_dtype": "int64", "y_dtype": "int64", }, - ] - self.attrs = [] - - -class TestFloorDivideNegAllWithBroadcast(TestCaseHelper): - def init_attrs(self): - self.class_name = "TestFloorDivideNegOpCase" - self.cls = TestFloorDivideNegOp - self.inputs = [ - { - "x_shape": [1], - "y_shape": [1], - }, { - "x_shape": [1024], - "y_shape": [1], - }, - { - "x_shape": [512, 256], - "y_shape": [1, 1], + "x_dtype": "float16", + "y_dtype": "float16", + "max_relative_error": 1, }, { - "x_shape": [128, 64, 32], - "y_shape": [1, 1, 1], + "x_dtype": "float32", + "y_dtype": "float32", }, { - "x_shape": [16, 8, 4, 2], - "y_shape": [1, 1, 1, 1], + "x_dtype": "float64", + "y_dtype": "float64", }, + ] + + +class TestFloorDivideUINT(TestCaseHelper): + def init_attrs(self): + self.class_name = "TestFloorDivideOpCase" + self.cls = TestFloorDivideOp + self.inputs = [ { - "x_shape": [16, 8, 4, 2, 1], - "y_shape": [1, 1, 1, 1, 1], + "x_shape": [1024], + "y_shape": [1024], }, ] self.dtypes = [ { - "x_dtype": "int32", - "y_dtype": "int32", + "x_dtype": "uint8", + "y_dtype": "uint8", }, + ] + self.attrs = [ { - "x_dtype": "int64", - "y_dtype": "int64", + "x_low": 1, + "x_high": 10, + "y_low": 1, + "y_high": 10, }, ] - self.attrs = [] if __name__ == "__main__": - TestFloorDivideAll().run() - TestFloorDivideNegAll().run() - TestFloorDivideAllWithBroadcast().run() - TestFloorDivideNegAllWithBroadcast().run() + TestFloorDivideShape().run() + TestFloorDivideBroadcast().run() + TestFloorDivideDtype().run() + TestFloorDivideUINT().run() diff --git a/python/tests/ops/test_isclose_op.py b/python/tests/ops/test_isclose_op.py index 3f8803e9ee..4c2621f03b 100644 --- a/python/tests/ops/test_isclose_op.py +++ b/python/tests/ops/test_isclose_op.py @@ -14,100 +14,192 @@ # See the License for the specific language governing permissions and # limitations under the License. 
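# The floor_divide cases above split positive and negative divisor ranges via
# the new x_low/x_high/y_low/y_high attrs. For orientation, floor semantics
# (shown here with NumPy) round toward negative infinity, which differs from
# C-style truncation once the signs of the operands differ:
import numpy as np

x = np.array([7, -7, 7, -7], dtype="int32")
y = np.array([2, 2, -2, -2], dtype="int32")
assert (np.floor_divide(x, y) == np.array([3, -4, -4, 3])).all()
# Truncating division would instead give [3, -3, -3, 3].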
-import unittest import numpy as np from op_test import OpTest, OpTestTool +from op_test_helper import TestCaseHelper import paddle -import paddle.nn.functional as F import cinn from cinn.frontend import * from cinn.common import * -import sys @OpTestTool.skip_if(not is_compiled_with_cuda(), "x86 test will be skipped due to timeout.") class TestIsCloseOp(OpTest): def setUp(self): - self.init_case() - - def init_case(self): - self.inputs = {"x": self.random([16, 16], "float32")} - self.inputs['y'] = self.inputs["x"] - self.rtol = 1e-05 - self.atol = 1e-08 - self.equal_nan = False + # print(f"\n{self.__class__.__name__}: {self.case}") + self.prepare_inputs() + + def prepare_inputs(self): + if self.case["nan_as_input"]: + self.x_np = np.full(shape=self.case["shape"], fill_value=np.nan) + else: + self.x_np = self.random( + shape=self.case["shape"], dtype=self.case["dtype"]) + self.y_np = self.x_np + self.random( + shape=self.case["shape"], dtype=self.case["dtype"]) def build_paddle_program(self, target): - x = paddle.to_tensor(self.inputs["x"], stop_gradient=False) - y = paddle.to_tensor(self.inputs["y"], stop_gradient=False) - + x = paddle.to_tensor(self.x_np, stop_gradient=False) + y = paddle.to_tensor(self.y_np, stop_gradient=False) shape = paddle.broadcast_shape(x.shape, y.shape) x = paddle.broadcast_to(x, shape) y = paddle.broadcast_to(y, shape) - - out = paddle.isclose(x, y, self.rtol, self.atol, self.equal_nan) + out = paddle.isclose(x, y, self.case["rtol"], self.case["atol"], + self.case["equal_nan"]) self.paddle_outputs = [out] def build_cinn_program(self, target): builder = NetBuilder("isclose") - - x = builder.create_input(Float(32), self.inputs["x"].shape, "x") - y = builder.create_input(Float(32), self.inputs["y"].shape, "y") - out = builder.isclose(x, y, self.rtol, self.atol, self.equal_nan) + x = builder.create_input( + self.nptype2cinntype(self.x_np.dtype), self.x_np.shape, "x") + y = builder.create_input( + self.nptype2cinntype(self.y_np.dtype), self.y_np.shape, "y") + out = builder.isclose(x, y, self.case["rtol"], self.case["atol"], + self.case["equal_nan"]) prog = builder.build() - forward_res = self.get_cinn_output( - prog, target, [x, y], [self.inputs["x"], self.inputs["y"]], [out]) - - self.cinn_outputs = forward_res + res = self.get_cinn_output(prog, target, [x, y], + [self.x_np, self.y_np], [out]) + self.cinn_outputs = res def test_check_results(self): - self.check_outputs_and_grads() - - -class TestIsCloseOpCase1(TestIsCloseOp): - def init_case(self): - self.inputs = { - "x": self.random([16, 16], "float32"), - "y": self.random([16, 16], "float32") - } - self.rtol = 1e-05 - self.atol = 1e-08 - self.equal_nan = False - - -class TestIsCloseOpCase2(TestIsCloseOp): - def init_case(self): - self.inputs = { - "x": np.array([np.nan] * 32).astype("float32"), - "y": self.random([32], "float32") - } - self.rtol = 1e-05 - self.atol = 1e-08 - self.equal_nan = False - - -class TestIsCloseOpCase3(TestIsCloseOp): - def init_case(self): - self.inputs = { - "x": np.array([np.nan] * 32).astype("float32"), - "y": np.array([np.nan] * 32).astype("float32") - } - self.rtol = 1e-05 - self.atol = 1e-08 - self.equal_nan = True - - -class TestIsCloseOpCase4(TestIsCloseOp): - def init_case(self): - self.inputs = { - "x": self.random([16, 16], "float32"), - "y": self.random([1], "float32") - } - self.rtol = 1e-05 - self.atol = 1e-08 - self.equal_nan = False + self.check_outputs_and_grads(all_equal=True) + + +class TestIsCloseShape(TestCaseHelper): + def init_attrs(self): + self.class_name = 
"TestIsCloseOpCase" + self.cls = TestIsCloseOp + self.inputs = [ + { + "shape": [1], + }, + { + "shape": [1024], + }, + { + "shape": [512, 256], + }, + { + "shape": [128, 64, 32], + }, + { + "shape": [16, 8, 4, 2], + }, + { + "shape": [16, 8, 4, 2, 1], + }, + ] + self.dtypes = [ + { + "dtype": "float32", + }, + ] + self.attrs = [ + { + "rtol": 1e-5, + "atol": 1e-8, + "equal_nan": False, + "nan_as_input": False, + }, + ] + + +class TestIsCloseDtype(TestCaseHelper): + def init_attrs(self): + self.class_name = "TestIsCloseOpCase" + self.cls = TestIsCloseOp + self.inputs = [ + { + "shape": [1024], + }, + ] + self.dtypes = [ + { + "dtype": "float32", + }, + { + "dtype": "float64", + }, + ] + self.attrs = [ + { + "rtol": 1e-5, + "atol": 1e-8, + "equal_nan": False, + "nan_as_input": False, + }, + ] + + +class TestIsCloseAttr(TestCaseHelper): + def init_attrs(self): + self.class_name = "TestIsCloseOpCase" + self.cls = TestIsCloseOp + self.inputs = [ + { + "shape": [1024], + }, + ] + self.dtypes = [ + { + "dtype": "float32", + }, + ] + self.attrs = [ + { + "rtol": 1e-3, + "atol": 1e-3, + "equal_nan": False, + "nan_as_input": False, + }, + { + "rtol": 1e-5, + "atol": 1e-5, + "equal_nan": False, + "nan_as_input": False, + }, + { + "rtol": 1e-8, + "atol": 1e-8, + "equal_nan": False, + "nan_as_input": False, + }, + { + "rtol": 1e-5, + "atol": 1e-8, + "equal_nan": True, + "nan_as_input": False, + }, + ] + + +class TestIsCloseNAN(TestCaseHelper): + def init_attrs(self): + self.class_name = "TestIsCloseOpCase" + self.cls = TestIsCloseOp + self.inputs = [ + { + "shape": [1024], + }, + ] + self.dtypes = [ + { + "dtype": "float64", + }, + ] + self.attrs = [ + { + "rtol": 1e-5, + "atol": 1e-8, + "equal_nan": True, + "nan_as_input": True, + }, + ] if __name__ == "__main__": - unittest.main() + TestIsCloseShape().run() + TestIsCloseDtype().run() + TestIsCloseAttr().run() + TestIsCloseNAN().run() diff --git a/python/tests/ops/test_log_op.py b/python/tests/ops/test_log_op.py new file mode 100644 index 0000000000..20d7900543 --- /dev/null +++ b/python/tests/ops/test_log_op.py @@ -0,0 +1,145 @@ +# Copyright (c) 2023 CINN Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from op_test import OpTest +from op_test_helper import TestCaseHelper +import paddle +import cinn +from cinn.frontend import * +from cinn.common import * + + +class TestLogOp(OpTest): + def setUp(self): + # print(f"\n{self.__class__.__name__}: {self.case}") + self.prepare_inputs() + + def prepare_inputs(self): + self.x_np = self.random( + shape=self.case["shape"], dtype=self.case["dtype"]) + self.base = self.case["base"] + + def paddle_op(self, x): + if self.base == "e": + return paddle.log(x) + elif self.base == "2": + return paddle.log2(x) + elif self.base == "10": + return paddle.log10(x) + else: + raise ValueError("Unknown log base") + + def cinn_op(self, builder, x): + if self.base == "e": + return builder.log(x) + elif self.base == "2": + return builder.log2(x) + elif self.base == "10": + return builder.log10(x) + else: + raise ValueError("Unknown log base") + + def build_paddle_program(self, target): + x = paddle.to_tensor(self.x_np, stop_gradient=False) + out = self.paddle_op(x) + self.paddle_outputs = [out] + + def build_cinn_program(self, target): + builder = NetBuilder("add") + x = builder.create_input( + self.nptype2cinntype(self.x_np.dtype), self.x_np.shape, "x") + out = self.cinn_op(builder, x) + prog = builder.build() + res = self.get_cinn_output(prog, target, [x], [self.x_np], [out]) + self.cinn_outputs = res + + def test_check_results(self): + self.check_outputs_and_grads() + + +class TestLogOpShape(TestCaseHelper): + def init_attrs(self): + self.class_name = "TestLogeOpCase" + self.cls = TestLogOp + self.inputs = [ + { + "shape": [1], + }, + { + "shape": [1024], + }, + { + "shape": [512, 256], + }, + { + "shape": [128, 64, 32], + }, + { + "shape": [16, 8, 4, 2], + }, + { + "shape": [16, 8, 4, 2, 1], + }, + ] + self.dtypes = [ + { + "dtype": "float32", + }, + ] + self.attrs = [ + { + "base": "e", + }, + { + "base": "2", + }, + { + "base": "10", + }, + ] + + +class TestLogOpDtype(TestCaseHelper): + def init_attrs(self): + self.class_name = "TestLogeOpCase" + self.cls = TestLogOp + self.inputs = [ + { + "shape": [1024], + }, + ] + self.dtypes = [ + { + "dtype": "float32", + }, + { + "dtype": "float64", + }, + ] + self.attrs = [ + { + "base": "e", + }, + { + "base": "2", + }, + { + "base": "10", + }, + ] + + +if __name__ == "__main__": + TestLogOpShape().run() + TestLogOpDtype().run() diff --git a/python/tests/ops/test_logical_right_shift_op.py b/python/tests/ops/test_logical_right_shift_op.py index 5d88ed34e4..ecec017cf1 100644 --- a/python/tests/ops/test_logical_right_shift_op.py +++ b/python/tests/ops/test_logical_right_shift_op.py @@ -14,11 +14,10 @@ # See the License for the specific language governing permissions and # limitations under the License. 
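# The log cases above dispatch on a string "base" attr ("e", "2", "10"). The
# three variants differ only by a constant factor, which makes cross-checking
# one against another straightforward: log_b(x) = ln(x) / ln(b).
import numpy as np

x = np.array([0.5, 1.0, 2.0, 10.0, 100.0], dtype="float64")
assert np.allclose(np.log2(x), np.log(x) / np.log(2.0))
assert np.allclose(np.log10(x), np.log(x) / np.log(10.0))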
-import unittest import numpy as np from op_test import OpTest, OpTestTool +from op_test_helper import TestCaseHelper import paddle -import paddle.nn.functional as F import cinn from cinn.frontend import * from cinn.common import * @@ -28,59 +27,104 @@ "x86 test will be skipped due to timeout.") class TestLogicalRightShift(OpTest): def setUp(self): - self.init_case() + # print(f"\n{self.__class__.__name__}: {self.case}") + self.prepare_inputs() - def init_case(self): - self.inputs = { - # "x": self.random([1, 24], 'int32', low = -2147483648, high=2147483647) - "x": - np.array([[ - 1690476611, 142184466, -1752569340, 1860589058, -1295695292, - 1912939056, -1416770533, -483282486, 284237925, -2094465968, - -823026780, -1503970769, -535860601, 1515033359, -1212100470, - -2008734407, 704803066, 1861454881, -479224831, 1939718614, - -1903975007, -1197706543, 1327016838, -232019105 - ]]).astype(np.int32), - # "y": self.random([1, 24], 'int32', low = 0, high=32) - "y": - np.array([[ - 20, 3, 12, 3, 0, 31, 0, 2, 6, 16, 1, 7, 6, 2, 19, 16, 7, 17, - 10, 15, 8, 9, 24, 4 - ]]).astype(np.int32) - } - self.outputs = { - "out": - np.array([[ - 1612, 17773058, 620702, 232573632, -1295695292, 0, -1416770533, - 952921202, 4441217, 33576, 1735970258, 21804660, 58736042, - 378758339, 5880, 34885, 5506273, 14201, 3726311, 59195, - 9339813, 6049337, 79, 253934261 - ]]).astype(np.int32) - } + def prepare_inputs(self): + iinfo = np.iinfo(self.case["dtype"]) + self.x_np = self.random( + shape=self.case["shape"], + dtype=self.case["dtype"], + low=0, + high=iinfo.max) + self.y_np = self.random( + shape=self.case["shape"], + dtype=self.case["dtype"], + low=0, + high=iinfo.bits) def build_paddle_program(self, target): - out = paddle.to_tensor(self.outputs["out"], stop_gradient=False) + out_np = np.right_shift(self.x_np, self.y_np) + out = paddle.to_tensor(out_np, stop_gradient=True) self.paddle_outputs = [out] def build_cinn_program(self, target): builder = NetBuilder("logical_right_shift") x = builder.create_input( - self.nptype2cinntype(self.inputs["x"].dtype), - self.inputs["x"].shape, "x") + self.nptype2cinntype(self.x_np.dtype), self.x_np.shape, "x") y = builder.create_input( - self.nptype2cinntype(self.inputs["y"].dtype), - self.inputs["y"].shape, "y") + self.nptype2cinntype(self.y_np.dtype), self.y_np.shape, "y") out = builder.logical_right_shift(x, y) - prog = builder.build() res = self.get_cinn_output(prog, target, [x, y], - [self.inputs["x"], self.inputs["y"]], [out]) - - self.cinn_outputs = [res[0]] + [self.x_np, self.y_np], [out]) + self.cinn_outputs = res def test_check_results(self): - self.check_outputs_and_grads() + self.check_outputs_and_grads(all_equal=True) + + +class TestLogicalRightShiftShape(TestCaseHelper): + def init_attrs(self): + self.class_name = "TestLogicalRightShiftCase" + self.cls = TestLogicalRightShift + self.inputs = [ + { + "shape": [1], + }, + { + "shape": [1024], + }, + { + "shape": [512, 256], + }, + { + "shape": [128, 64, 32], + }, + { + "shape": [16, 8, 4, 2], + }, + { + "shape": [16, 8, 4, 2, 1], + }, + ] + self.dtypes = [ + { + "dtype": "int32", + }, + ] + self.attrs = [] + + +class TestLogicalRightShiftDtype(TestCaseHelper): + def init_attrs(self): + self.class_name = "TestLogicalRightShiftCase" + self.cls = TestLogicalRightShift + self.inputs = [ + { + "shape": [1024], + }, + ] + self.dtypes = [ + { + "dtype": "uint8", + }, + { + "dtype": "int8", + }, + { + "dtype": "int16", + }, + { + "dtype": "int32", + }, + { + "dtype": "int64", + }, + ] + self.attrs = [] if __name__ == 
"__main__": - unittest.main() + TestLogicalRightShiftShape().run() + TestLogicalRightShiftDtype().run() diff --git a/python/tests/ops/test_lookup_table_op.py b/python/tests/ops/test_lookup_table_op.py index c526ef3f89..1ad2200935 100644 --- a/python/tests/ops/test_lookup_table_op.py +++ b/python/tests/ops/test_lookup_table_op.py @@ -14,9 +14,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -import unittest -import numpy as np from op_test import OpTest, OpTestTool +from op_test_helper import TestCaseHelper import paddle import paddle.nn.functional as F import cinn @@ -28,46 +27,86 @@ "x86 test will be skipped due to timeout.") class TestLookupTableOp(OpTest): def setUp(self): - self.init_case() + # print(f"\n{self.__class__.__name__}: {self.case}") + self.prepare_inputs() - def init_case(self): - self.inputs = { - "table": np.random.random([ - 10, - 20, - ]).astype("float32"), - "ids": np.random.random_integers(0, 9, (5, 2)).astype("int64") - } + def prepare_inputs(self): + self.table_np = self.random( + shape=self.case["table_shape"], dtype=self.case["table_dtype"]) + self.ids_np = self.random( + shape=self.case["ids_shape"], + dtype=self.case["ids_dtype"], + low=0, + high=self.case["table_shape"][0]) def build_paddle_program(self, target): - table = paddle.to_tensor(self.inputs["table"], stop_gradient=False) - ids = paddle.to_tensor(self.inputs["ids"], stop_gradient=False) - out = F.embedding(ids, table, 1) - + table = paddle.to_tensor(self.table_np, stop_gradient=False) + ids = paddle.to_tensor(self.ids_np, stop_gradient=False) + out = F.embedding(ids, table, self.case["padding_idx"]) self.paddle_outputs = [out] - # Note: If the forward and backward operators are run in the same program, - # the forward result will be incorrect. 
def build_cinn_program(self, target): builder = NetBuilder("lookup_table") table = builder.create_input( - Float(32), self.inputs["table"].shape, "table") + self.nptype2cinntype(self.table_np.dtype), self.table_np.shape, + "table") ids = builder.create_input( - Int(64), self.inputs["ids"].shape + (1, ), "ids") - out = builder.lookup_table(table, ids, 1) + self.nptype2cinntype(self.ids_np.dtype), self.ids_np.shape + (1, ), + "ids") + out = builder.lookup_table(table, ids, self.case["padding_idx"]) prog = builder.build() - forward_res = self.get_cinn_output( - prog, target, [table, ids], - [self.inputs["table"], self.inputs["ids"]], [out]) - - self.cinn_outputs = forward_res + res = self.get_cinn_output(prog, target, [table, ids], + [self.table_np, self.ids_np], [out]) + self.cinn_outputs = res def test_check_results(self): - self.build_paddle_program(self.target) - self.build_cinn_program(self.target) - self.check_results(self.paddle_outputs, self.cinn_outputs, 1e-5, False, - False) + self.check_outputs_and_grads(all_equal=True) + + +class TestLookupTableOpAll(TestCaseHelper): + def init_attrs(self): + self.class_name = "TestLookupTableOpCase" + self.cls = TestLookupTableOp + self.inputs = [ + { + "table_shape": [128, 8], + "ids_shape": [8], + }, + { + "table_shape": [256, 4], + "ids_shape": [8, 4], + }, + { + "table_shape": [1024, 2], + "ids_shape": [8, 4, 2], + }, + ] + self.dtypes = [ + { + "table_dtype": "float16", + "ids_dtype": "int16", + }, + { + "table_dtype": "float32", + "ids_dtype": "int32", + }, + { + "table_dtype": "float64", + "ids_dtype": "int64", + }, + ] + self.attrs = [ + { + "padding_idx": -1, + }, + { + "padding_idx": 0, + }, + { + "padding_idx": 1, + }, + ] if __name__ == "__main__": - unittest.main() + TestLookupTableOpAll().run() diff --git a/python/tests/ops/test_pow_op.py b/python/tests/ops/test_pow_op.py index 51133053c4..90be49863d 100644 --- a/python/tests/ops/test_pow_op.py +++ b/python/tests/ops/test_pow_op.py @@ -14,11 +14,10 @@ # See the License for the specific language governing permissions and # limitations under the License. 
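The lookup_table cases above boil down to a row gather from the embedding table; the numpy sketch below shows that reference under illustrative shapes, with padding_idx handling deliberately omitted since its exact semantics are defined by the op itself.

import numpy as np

# Gather one table row per id; the trailing (1,) appended to the ids shape on the
# CINN side only reshapes the indices and does not change which rows are picked.
table = np.random.random([128, 8]).astype("float32")
ids = np.random.randint(0, 128, size=[8, 4])
gathered = table[ids]    # shape [8, 4, 8]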
-import unittest
 import numpy as np
 from op_test import OpTest, OpTestTool
+from op_test_helper import TestCaseHelper
 import paddle
-import paddle.nn.functional as F
 import cinn
 from cinn.frontend import *
 from cinn.common import *
@@ -28,69 +27,126 @@ "x86 test will be skipped due to timeout.")
 class TestPowOp(OpTest):
     def setUp(self):
-        self.init_case()
-
-    def init_case(self):
-        self.inputs = {
-            "x": self.random([32, 64], "float32"),
-            "y": self.random([32, 64], "float32", 0.0, 4.0)
-        }
-        self.axis = -1
+        # print(f"\n{self.__class__.__name__}: {self.case}")
+        self.prepare_inputs()
+
+    def prepare_inputs(self):
+        self.x_np = self.random(
+            shape=self.case["x_shape"],
+            dtype=self.case["dtype"],
+            low=self.case["base_low"],
+            high=self.case["base_high"])
+        self.y_np = self.random(
+            shape=self.case["y_shape"],
+            dtype=self.case["dtype"],
+            low=self.case["exp_low"],
+            high=self.case["exp_high"])
+        self.axis = np.random.choice([-1, 0])
 
     def build_paddle_program(self, target):
-        x = paddle.to_tensor(self.inputs["x"], stop_gradient=False)
-        y = paddle.to_tensor(self.inputs["y"], stop_gradient=False)
-
+        x = paddle.to_tensor(self.x_np, stop_gradient=False)
+        y = paddle.to_tensor(self.y_np, stop_gradient=False)
         out = paddle.pow(x, y)
-
         self.paddle_outputs = [out]
 
     def build_cinn_program(self, target):
         builder = NetBuilder("pow")
         x = builder.create_input(
-            self.nptype2cinntype(self.inputs["x"].dtype),
-            self.inputs["x"].shape, "x")
+            self.nptype2cinntype(self.x_np.dtype), self.x_np.shape, "x")
         y = builder.create_input(
-            self.nptype2cinntype(self.inputs["y"].dtype),
-            self.inputs["y"].shape, "y")
+            self.nptype2cinntype(self.y_np.dtype), self.y_np.shape, "y")
         out = builder.pow(x, y, axis=self.axis)
-
         prog = builder.build()
         res = self.get_cinn_output(prog, target, [x, y],
-                                   [self.inputs["x"], self.inputs["y"]], [out])
-
-        self.cinn_outputs = [res[0]]
+                                   [self.x_np, self.y_np], [out])
+        self.cinn_outputs = res
 
     def test_check_results(self):
-        self.check_outputs_and_grads()
-
-
-class TestPowCase1(TestPowOp):
-    def init_case(self):
-        self.inputs = {
-            "x": self.random([8, 16, 32, 32], "float32"),
-            "y": self.random([1], "float32", 0.0, 4.0)
-        }
-        self.axis = 0
-
-
-class TestPowCase2(TestPowOp):
-    def init_case(self):
-        self.inputs = {
-            "x": self.random([8, 16, 32, 32], "int32", 2, 10),
-            "y": self.random([8, 16, 32, 32], "int32", 0, 5)
-        }
-        self.axis = -1
-
-
-class TestPowFP64(TestPowOp):
-    def init_case(self):
-        self.inputs = {
-            "x": self.random([8, 16, 32, 32], "float64", 2, 10),
-            "y": self.random([8, 16, 32, 32], "float64", 0, 5)
-        }
-        self.axis = -1
+        self.check_outputs_and_grads(equal_nan=True)
+
+
+class TestPowOpShape(TestCaseHelper):
+    def init_attrs(self):
+        self.class_name = "TestPowOpCase"
+        self.cls = TestPowOp
+        self.inputs = [
+            {
+                "x_shape": [1],
+                "y_shape": [1],
+            },
+            {
+                "x_shape": [1024],
+                "y_shape": [1024],
+            },
+            {
+                "x_shape": [512, 256],
+                "y_shape": [512, 256],
+            },
+            {
+                "x_shape": [128, 64, 32],
+                "y_shape": [128, 64, 32],
+            },
+            {
+                "x_shape": [16, 8, 4, 2],
+                "y_shape": [16, 8, 4, 2],
+            },
+            {
+                "x_shape": [16, 8, 4, 2, 1],
+                "y_shape": [16, 8, 4, 2, 1],
+            },
+        ]
+        self.dtypes = [
+            {
+                "dtype": "float32",
+            },
+        ]
+        self.attrs = [
+            {
+                "base_low": -10,
+                "base_high": 10,
+                "exp_low": -3,
+                "exp_high": 3,
+            },
+        ]
+
+
+class TestPowOpDtype(TestCaseHelper):
+    def init_attrs(self):
+        self.class_name = "TestPowOpCase"
+        self.cls = TestPowOp
+        self.inputs = [
+            {
+                "x_shape": [1024],
+                "y_shape": [1024],
+            },
+        ]
+        self.dtypes = [
+            {
+
"dtype": "int32", + }, + { + "dtype": "int64", + }, + { + "dtype": "float16", + }, + { + "dtype": "float32", + }, + { + "dtype": "float64", + }, + ] + self.attrs = [ + { + "base_low": -10, + "base_high": 10, + "exp_low": -3, + "exp_high": 3, + }, + ] if __name__ == "__main__": - unittest.main() + TestPowOpShape().run() + TestPowOpDtype().run() diff --git a/python/tests/ops/test_repeat_op.py b/python/tests/ops/test_repeat_op.py new file mode 100644 index 0000000000..3745054e5c --- /dev/null +++ b/python/tests/ops/test_repeat_op.py @@ -0,0 +1,267 @@ +#!/usr/bin/env python3 + +# Copyright (c) 2023 CINN Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +import numpy as np +from cinn.frontend import * +from cinn.common import * +from op_test import OpTest, OpTestTool +from op_test_helper import TestCaseHelper + + +@OpTestTool.skip_if(not is_compiled_with_cuda(), + "x86 test will be skipped due to timeout.") +class TestRepeatOp(OpTest): + def setUp(self): + print(f"\nRunning {self.__class__.__name__}: {self.case}") + self.inputs = {} + self.prepare_inputs() + + def prepare_inputs(self): + shape = self.case["shape"] + dtype = self.case["dtype"] + repeats = self.case["repeats"] + axis = self.case["axis"] + dims = len(shape) + axis = min(axis, dims - 1) + axis = max(axis, -dims) + self.inputs = { + "x": self.random(shape, dtype, -1.0, 1.0), + "repeats": repeats, + "axis": axis + } + + def build_paddle_program(self, target): + x = np.repeat(self.inputs["x"], self.inputs["repeats"], + self.inputs["axis"]) + out = paddle.to_tensor(x, stop_gradient=True) + self.paddle_outputs = [out] + + def build_cinn_program(self, target): + builder = NetBuilder("repeat") + x = builder.create_input( + self.nptype2cinntype(self.inputs["x"].dtype), + self.inputs["x"].shape, "x") + out = builder.repeat(x, self.inputs["repeats"], self.inputs["axis"]) + + prog = builder.build() + res = self.get_cinn_output(prog, target, [x], [self.inputs["x"]], + [out]) + + self.cinn_outputs = res + + def test_check_results(self): + self.check_outputs_and_grads() + + +class TestRepeatOpShape(TestCaseHelper): + def init_attrs(self): + self.class_name = "TestRepeatOpShape" + self.cls = TestRepeatOp + self.inputs = [ + { + "shape": [10], + }, + { + "shape": [8, 5], + }, + { + "shape": [10, 3, 5], + }, + { + "shape": [80, 40, 5, 7], + }, + { + "shape": [80, 1, 5, 7], + }, + { + "shape": [80, 3, 1024, 7], + }, + { + "shape": [10, 5, 1024, 2048], + }, + { + "shape": [1], + }, + { + "shape": [512], + }, + { + "shape": [1024], + }, + { + "shape": [2048], + }, + { + "shape": [1, 1, 1, 1], + }, + ] + self.dtypes = [ + { + "dtype": "float32" + }, + ] + self.attrs = [ + { + "repeats": 2, + "axis": 0 + }, + ] + + +class TestRepeatOpDtype(TestCaseHelper): + def init_attrs(self): + self.class_name = "TestRepeatOpDtype" + self.cls = TestRepeatOp + self.inputs = [ + { + "shape": [1], + }, + { + "shape": [5], + }, + { + "shape": [80, 40, 5, 7], + }, + ] + self.dtypes = [ + { + "dtype": "bool" + }, + { + "dtype": 
"int8" + }, + { + "dtype": "int32" + }, + { + "dtype": "int64" + }, + { + "dtype": "float16" + }, + { + "dtype": "float32" + }, + { + "dtype": "float64" + }, + ] + self.attrs = [ + { + "repeats": 4, + "axis": 0 + }, + ] + + +class TestRepeatOpAttributeRepeats(TestCaseHelper): + def init_attrs(self): + self.class_name = "TestRepeatOpAttributeRepeats" + self.cls = TestRepeatOp + self.inputs = [ + { + "shape": [10], + }, + { + "shape": [8, 5], + }, + { + "shape": [80, 40, 5, 7], + }, + ] + self.dtypes = [ + { + "dtype": "float32" + }, + ] + self.attrs = [ + { + "repeats": 256, + "axis": 0 + }, + { + "repeats": 1024, + "axis": 0 + }, + { + "repeats": 2048, + "axis": 0 + }, + ] + + +class TestRepeatOpAttributeAxis(TestCaseHelper): + def init_attrs(self): + self.class_name = "TestRepeatOpAttributeAxis" + self.cls = TestRepeatOp + self.inputs = [ + { + "shape": [10], + }, + { + "shape": [8, 5], + }, + { + "shape": [80, 40, 5, 7], + }, + ] + self.dtypes = [ + { + "dtype": "float32" + }, + ] + self.attrs = [ + { + "repeats": 128, + "axis": 0 + }, + { + "repeats": 128, + "axis": 1 + }, + { + "repeats": 128, + "axis": 2 + }, + { + "repeats": 128, + "axis": 3 + }, + { + "repeats": 128, + "axis": -1 + }, + { + "repeats": 128, + "axis": -2 + }, + { + "repeats": 128, + "axis": -3 + }, + { + "repeats": 128, + "axis": -4 + }, + ] + + +if __name__ == "__main__": + TestRepeatOpShape().run() + TestRepeatOpDtype().run() + TestRepeatOpAttributeRepeats().run() + TestRepeatOpAttributeAxis().run() diff --git a/python/tests/ops/test_reverse_op.py b/python/tests/ops/test_reverse_op.py new file mode 100755 index 0000000000..3bde72d323 --- /dev/null +++ b/python/tests/ops/test_reverse_op.py @@ -0,0 +1,311 @@ +#!/usr/bin/env python3 + +# Copyright (c) 2023 CINN Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import paddle +from cinn.common import * +from cinn.frontend import * +from op_test import OpTest, OpTestTool +from op_test_helper import TestCaseHelper + + +@OpTestTool.skip_if(not is_compiled_with_cuda(), + "x86 test will be skipped due to timeout.") +class TestReverseOp(OpTest): + def setUp(self): + print(f"\nRunning {self.__class__.__name__}: {self.case}") + self.inputs = {} + self.prepare_inputs() + + def prepare_inputs(self): + dims = len(self.case["shape"]) + axes = self.case["axes"].copy() + for i in range(len(axes)): + axes[i] = min(axes[i], dims - 1) + axes[i] = max(axes[i], -dims) + self.inputs = { + "x": self.random(self.case["shape"], self.case["dtype"]), + "axes": axes + } + self.net_builder_api = self.case["net_builder_api"] + + def build_paddle_program(self, target): + x = paddle.to_tensor(self.inputs["x"], stop_gradient=True) + if self.net_builder_api == "reverse": + out = paddle.reverse(x, self.inputs["axes"]) + elif self.net_builder_api == "flip": + out = paddle.flip(x, self.inputs["axes"]) + else: + raise NotImplementedError + self.paddle_outputs = [out] + + def build_cinn_program(self, target): + builder = NetBuilder("reverse") + x = builder.create_input( + self.nptype2cinntype(self.inputs["x"].dtype), + self.inputs["x"].shape, "x") + if self.net_builder_api == "reverse": + out = builder.reverse(x, self.inputs["axes"]) + elif self.net_builder_api == "flip": + out = builder.flip(x, self.inputs["axes"]) + else: + raise NotImplementedError + + prog = builder.build() + res = self.get_cinn_output(prog, target, [x], [self.inputs["x"]], + [out]) + + self.cinn_outputs = res + + def test_check_results(self): + self.check_outputs_and_grads(all_equal=True) + + +class TestReverseOpShape(TestCaseHelper): + def init_attrs(self): + self.class_name = "TestReverseOpShape" + self.cls = TestReverseOp + self.inputs = [ + { + "shape": [10], + }, + { + "shape": [8, 5], + }, + { + "shape": [10, 3, 5], + }, + { + "shape": [80, 40, 5, 7], + }, + { + "shape": [80, 1, 5, 7], + }, + { + "shape": [80, 3, 1024, 7], + }, + { + "shape": [10, 5, 1024, 2048], + }, + { + "shape": [1], + }, + { + "shape": [512], + }, + { + "shape": [1024], + }, + { + "shape": [2048], + }, + { + "shape": [65536], + }, + { + "shape": [131072], + }, + { + "shape": [1, 1, 1, 1], + }, + ] + self.dtypes = [ + { + "dtype": "float32" + }, + ] + self.attrs = [ + { + "axes": [0] + }, + ] + net_builder_api_attrs = [ + { + "net_builder_api": "reverse", + }, + { + "net_builder_api": "flip", + }, + ] + self._register_custom_attrs(net_builder_api_attrs) + + +class TestReverseOpDtype(TestCaseHelper): + def init_attrs(self): + self.class_name = "TestReverseOpDtype" + self.cls = TestReverseOp + self.inputs = [ + { + "shape": [1], + }, + { + "shape": [5, 10], + }, + { + "shape": [80, 40, 5, 7], + }, + ] + self.dtypes = [ + { + "dtype": "bool" + }, + { + "dtype": "int32" + }, + { + "dtype": "int64" + }, + { + "dtype": "float16" + }, + { + "dtype": "float32" + }, + { + "dtype": "float64" + }, + ] + self.attrs = [ + { + "axes": [0] + }, + ] + net_builder_api_attrs = [ + { + "net_builder_api": "reverse", + }, + { + "net_builder_api": "flip", + }, + ] + self._register_custom_attrs(net_builder_api_attrs) + + +class TestReverseOpAxis(TestCaseHelper): + def init_attrs(self): + self.class_name = "TestReverseOpAxis" + self.cls = TestReverseOp + self.inputs = [ + { + "shape": [8, 4, 2, 16], + }, + { + "shape": [1, 1, 1, 1], + }, + ] + self.dtypes = [ + { + "dtype": "float32" + }, + ] + self.attrs = [ + { + "axes": [0] + }, + { + "axes": [1] + }, + 
{ + "axes": [2] + }, + { + "axes": [3] + }, + { + "axes": [-1] + }, + { + "axes": [-2] + }, + { + "axes": [-3] + }, + { + "axes": [-4] + }, + ] + net_builder_api_attrs = [ + { + "net_builder_api": "reverse", + }, + { + "net_builder_api": "flip", + }, + ] + self._register_custom_attrs(net_builder_api_attrs) + + +class TestReverseOpMultiAxis(TestCaseHelper): + def init_attrs(self): + self.class_name = "TestReverseOpMultiAxis" + self.cls = TestReverseOp + self.inputs = [ + { + "shape": [8, 4, 2, 16], + }, + { + "shape": [1, 1, 1, 1], + }, + ] + self.dtypes = [ + { + "dtype": "float32" + }, + ] + self.attrs = [ + { + "axes": [] + }, + { + "axes": [0] + }, + { + "axes": [1, 2] + }, + { + "axes": [2, -1, 3] + }, + { + "axes": [0, -3, 3, 1] + }, + { + "axes": [-1] + }, + { + "axes": [-2, -1] + }, + { + "axes": [-3, -2, 3] + }, + { + "axes": [0, 3, -3, -2] + }, + ] + net_builder_api_attrs = [ + { + "net_builder_api": "reverse", + }, + { + "net_builder_api": "flip", + }, + ] + self._register_custom_attrs(net_builder_api_attrs) + + +if __name__ == "__main__": + TestReverseOpShape().run() + TestReverseOpDtype().run() + TestReverseOpAxis().run() + TestReverseOpMultiAxis().run() diff --git a/python/tests/ops/test_round_op.py b/python/tests/ops/test_round_op.py new file mode 100644 index 0000000000..7180ad8920 --- /dev/null +++ b/python/tests/ops/test_round_op.py @@ -0,0 +1,112 @@ +# Copyright (c) 2023 CINN Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
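Both builder.reverse and builder.flip in the cases above are checked against paddle.reverse and paddle.flip; np.flip gives an equivalent reference, with negative axes counted from the last dimension. A minimal illustration:

import numpy as np

x = np.arange(8 * 4).reshape(8, 4).astype("float32")
assert np.array_equal(np.flip(x, axis=[0]), x[::-1, :])     # flip along the first axis
assert np.array_equal(np.flip(x, axis=[-1]), x[:, ::-1])    # -1 addresses the last axis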
+
+from op_test import OpTest, OpTestTool
+from op_test_helper import TestCaseHelper
+import paddle
+import cinn
+from cinn.frontend import *
+from cinn.common import *
+
+
+@OpTestTool.skip_if(not is_compiled_with_cuda(),
+                    "x86 test will be skipped due to timeout.")
+class TestRoundOp(OpTest):
+    def setUp(self):
+        # print(f"\n{self.__class__.__name__}: {self.case}")
+        self.prepare_inputs()
+
+    def prepare_inputs(self):
+        self.x_np = self.random(
+            shape=self.case["shape"], dtype=self.case["dtype"])
+
+    def build_paddle_program(self, target):
+        x = paddle.to_tensor(self.x_np, stop_gradient=False)
+        out = paddle.round(x)
+        self.paddle_outputs = [out]
+
+    def build_cinn_program(self, target):
+        builder = NetBuilder("round")
+        x = builder.create_input(
+            self.nptype2cinntype(self.x_np.dtype), self.x_np.shape, "x")
+        out = builder.round(x)
+        prog = builder.build()
+        res = self.get_cinn_output(prog, target, [x], [self.x_np], [out])
+        self.cinn_outputs = res
+
+    def test_check_results(self):
+        self.check_outputs_and_grads()
+
+
+class TestRoundOpShape(TestCaseHelper):
+    def init_attrs(self):
+        self.class_name = "TestRoundOpCase"
+        self.cls = TestRoundOp
+        self.inputs = [
+            {
+                "shape": [1],
+            },
+            {
+                "shape": [1024],
+            },
+            {
+                "shape": [512, 256],
+            },
+            {
+                "shape": [128, 64, 32],
+            },
+            {
+                "shape": [16, 8, 4, 2],
+            },
+            {
+                "shape": [16, 8, 4, 2, 1],
+            },
+        ]
+        self.dtypes = [
+            {
+                "dtype": "float32",
+            },
+        ]
+        self.attrs = []
+
+
+class TestRoundOpDtype(TestCaseHelper):
+    def init_attrs(self):
+        self.class_name = "TestRoundOpCase"
+        self.cls = TestRoundOp
+        self.inputs = [
+            {
+                "shape": [1024],
+            },
+        ]
+        self.dtypes = [
+            {
+                "dtype": "float16",
+            },
+            {
+                "dtype": "bfloat16",
+            },
+            {
+                "dtype": "float32",
+            },
+            {
+                "dtype": "float64",
+            },
+        ]
+        self.attrs = []
+
+
+if __name__ == "__main__":
+    TestRoundOpShape().run()
+    TestRoundOpDtype().run()
diff --git a/python/tests/ops/test_sign_op.py b/python/tests/ops/test_sign_op.py
index 920cda2564..b70faaff2c 100644
--- a/python/tests/ops/test_sign_op.py
+++ b/python/tests/ops/test_sign_op.py
@@ -87,10 +87,10 @@ def init_attrs(self):
                 "shape": [80, 1, 5, 7],
             },
             {
-                "shape": [80, 3, 1024, 7],
+                "shape": [80, 3, 32, 7],
             },
             {
-                "shape": [10, 5, 1024, 2048],
+                "shape": [10, 5, 32, 32],
             },
             {
                 "shape": [1],