Skip to content
This repository has been archived by the owner on Jan 24, 2024. It is now read-only.

Commit

Permalink
op unittest for repeat/arange/reverse/elementwise_add_grad/flip (#1514)
Browse files Browse the repository at this point in the history
* op unittest for repeat op

* add repeat frontend

* op unittest for repeat

* op unittest for arange

* op unittest for reverse

* format & remove old add op test

* op unittest for flip && remove redundant flip implementation

* remove test_add_op_new.py

* update reverse
  • Loading branch information
zzk0 authored Jun 12, 2023
1 parent 4a8536a commit a659bb2
Show file tree
Hide file tree
Showing 16 changed files with 983 additions and 645 deletions.
6 changes: 1 addition & 5 deletions cinn/frontend/net_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -816,11 +816,7 @@ Variable NetBuilder::Arange(const float start, const float stop, const float ste
}

// Flip reverses the entries of `operand` along each axis listed in `axes`.
// It is implemented by delegating to the existing "reverse" op rather than a
// dedicated "flip" kernel: negative axis values are first normalized to the
// range [0, rank) via GetPositiveAxes, then a single "reverse" instruction is
// emitted. The first (and only) output of that instruction is the result.
Variable NetBuilder::Flip(const Variable& operand, const std::vector<int>& axes) {
  return CustomInstr("reverse", {operand}, {{"axis", utils::GetPositiveAxes(axes, operand->shape.size())}}).front();
}

Variable NetBuilder::Matmul(const Variable& x, const Variable& y, bool trans_x, bool trans_y, float alpha) {
Expand Down
5 changes: 4 additions & 1 deletion cinn/frontend/net_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -901,7 +901,10 @@ class NetBuilder {
const std::string& padding_algorithm = "EXPLICIT");

/**
* This API flipes the Variable x along the given axis.
* @brief This API reverses the Variable x along the given axis.
* @param x An N-D variable.
* @param axis Specify the axis to operate on the input reverse.
* @return A reversed variable with the same data type as x.
*/
Variable Flip(const Variable& operand, const std::vector<int>& axes);

Expand Down
70 changes: 0 additions & 70 deletions cinn/frontend/net_builder_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -984,76 +984,6 @@ TEST(net_build, program_execute_arange_int) {
}
}

// End-to-end test of NetBuilder::Flip: builds a program that flips a small
// [C, H, W] float tensor along axis 0, runs it through the CINN graph
// compiler, and verifies every output element against the expected
// flipped index computed on the host.
TEST(net_build, program_execute_flip) {
// Tiny fixed dimensions so the whole tensor can be checked element-wise.
const int C = 2;
const int H = 2;
const int W = 2;
// Flip only along axis 0 (the channel dimension).
const std::vector<int> axes{0};

// Build a one-op program: output = Flip(input, axes).
NetBuilder builder("net_builder");
Placeholder input = builder.CreateInput(Float(32), {C, H, W}, "Img");
Variable output = builder.Flip(input, axes);
auto program = builder.Build();

// Pick the execution target the binary was built for.
#ifdef CINN_WITH_CUDA
Target target = common::DefaultNVGPUTarget();
#else
Target target = common::DefaultHostTarget();
#endif
// No explicit fetch ids: let the optimizer decide what to keep.
std::unordered_set<std::string> fetch_ids;
auto graph = Optimize(&program, fetch_ids, target);

// Compile the optimized graph into a runnable program.
auto scope = BuildScope(target, graph);
hlir::framework::GraphCompiler gc(target, scope, graph);
auto runtime_program = gc.Build();

// Register the input/output tensors in the scope before execution.
scope->Var<hlir::framework::Tensor>(std::string(input.id()));
scope->Var<hlir::framework::Tensor>(std::string(output->id));

// Fill the input with random data and keep a host copy for verification.
auto input_tensor = scope->GetTensor(std::string(input.id()));
SetRandData<float>(input_tensor, target);
std::vector<float> input_data = GetTensorData<float>(input_tensor, target);

runtime_program->Execute();
// The flip must preserve shape and dtype exactly.
auto output_tensor = scope->GetTensor(std::string(output->id));
const std::vector<int>& output_shape = output_tensor->shape().data();
EXPECT_EQ(output_tensor->type(), Float(32));
EXPECT_EQ(output_shape.size(), 3UL);
EXPECT_EQ(output_shape[0], C);
EXPECT_EQ(output_shape[1], H);
EXPECT_EQ(output_shape[2], W);

std::vector<float> output_data = GetTensorData<float>(output_tensor, target);
// Debug dump of the input tensor (row-major linear index: c*H*W + h*W + w).
VLOG(6) << "Visualize flip input_data";
for (int c = 0; c < C; c++) {
for (int h = 0; h < H; h++) {
std::string line;
for (int w = 0; w < W; w++) {
int index = c * (H * W) + h * W + w;
line += (std::to_string(index) + ": " + std::to_string(input_data[index]) + ", ");
}
VLOG(6) << line;
}
}

// Verify: for each dimension, the coordinate is mirrored (dim - i - 1) only
// when that dimension appears in `axes`; otherwise it is left unchanged.
// Then input[index] must equal output[flipped index].
VLOG(6) << "Visualize flip output_data";
for (int c = 0; c < C; c++) {
int flip_c = std::find(axes.begin(), axes.end(), 0) == axes.end() ? c : C - c - 1;
for (int h = 0; h < H; h++) {
std::string line;
int flip_h = std::find(axes.begin(), axes.end(), 1) == axes.end() ? h : H - h - 1;
for (int w = 0; w < W; w++) {
int flip_w = std::find(axes.begin(), axes.end(), 2) == axes.end() ? w : W - w - 1;
int flip_index = flip_c * H * W + flip_h * W + flip_w;
int index = c * (H * W) + h * W + w;
line += (std::to_string(index) + ": " + std::to_string(output_data[index]) + ", ");
EXPECT_EQ(input_data[index], output_data[flip_index]);
}
VLOG(6) << line;
}
}
}

TEST(net_build, program_argmax_case1) {
const int N = 4;
const int IN_C = 3;
Expand Down
2 changes: 0 additions & 2 deletions cinn/hlir/op/contrib/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ core_gather_headers()

gather_srcs(cinnapi_src SRCS
gather_nd.cc
flip.cc
sort.cc
argmin.cc
argmax.cc
Expand All @@ -24,7 +23,6 @@ cc_test(test_gather_nd SRCS gather_nd_test.cc DEPS cinncore)
cc_test(test_sort SRCS sort_test.cc DEPS cinncore)
cc_test(test_argmin SRCS argmin_test.cc DEPS cinncore)
cc_test(test_argmax SRCS argmax_test.cc DEPS cinncore)
cc_test(test_flip SRCS flip_test.cc DEPS cinncore)
cc_test(test_repeat SRCS repeat_test.cc DEPS cinncore)
cc_test(test_one_hot SRCS one_hot_test.cc DEPS cinncore)
cc_test(test_lookup_table SRCS lookup_table_test.cc DEPS cinncore)
Expand Down
118 changes: 0 additions & 118 deletions cinn/hlir/op/contrib/flip.cc

This file was deleted.

32 changes: 0 additions & 32 deletions cinn/hlir/op/contrib/flip.h

This file was deleted.

67 changes: 0 additions & 67 deletions cinn/hlir/op/contrib/flip_test.cc

This file was deleted.

9 changes: 0 additions & 9 deletions cinn/hlir/op/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -831,7 +831,6 @@ std::shared_ptr<OpStrategy> StrategyForReverse(const framework::NodeAttr &attrs,
std::vector<int> axis;
if (attrs.attr_store.find("axis") != attrs.attr_store.end()) {
axis = absl::get<std::vector<int>>(attrs.attr_store.at("axis"));
CHECK(!axis.empty()) << "axis is empty! Please check setting.\n";
for (auto &e : axis) {
if (e >= static_cast<int>(output_shapes[0].size()) || e < -1 * static_cast<int>(output_shapes[0].size())) {
LOG(FATAL) << "axis is not in [0, n_dim), Please check.";
Expand All @@ -840,8 +839,6 @@ std::shared_ptr<OpStrategy> StrategyForReverse(const framework::NodeAttr &attrs,
e += output_shapes[0].size();
}
}
} else {
LOG(FATAL) << "axis is not be set! Please check.";
}

framework::CINNCompute reverse_compute([=](lang::Args args, lang::RetValue *ret) {
Expand Down Expand Up @@ -875,7 +872,6 @@ std::vector<framework::shape_t> InferShapeForReverse(const std::vector<framework
std::vector<framework::shape_t> res{inputs_shape[0]};
if (attrs.find("axis") != attrs.end()) {
auto axis = absl::get<std::vector<int>>(attrs.at("axis"));
CHECK(!axis.empty()) << "axis is empty! Please check setting.\n";
for (auto &e : axis) {
if (e >= static_cast<int>(inputs_shape[0].size()) || e < -1 * static_cast<int>(inputs_shape[0].size())) {
LOG(FATAL) << "axis is not in [-n_dim, n_dim), Please check.";
Expand All @@ -884,8 +880,6 @@ std::vector<framework::shape_t> InferShapeForReverse(const std::vector<framework
e += inputs_shape[0].size();
}
}
} else {
LOG(FATAL) << "axis is not be set! Please check.";
}
return res;
}
Expand All @@ -896,14 +890,11 @@ std::vector<std::vector<std::string>> InferLayoutForReverse(const std::vector<fr
const Target &target) {
if (attrs.attr_store.find("axis") != attrs.attr_store.end()) {
auto axis = absl::get<std::vector<int>>(attrs.attr_store.at("axis"));
CHECK(!axis.empty()) << "axis is empty! Please check setting.\n";
for (auto &e : axis) {
if (e >= static_cast<int>(input_shapes[0].size()) || e < -1 * static_cast<int>(input_shapes[0].size())) {
LOG(FATAL) << "axis is not in [-n_dim, n_dim), Please check.";
}
}
} else {
LOG(FATAL) << "axis is not be set! Please check.";
}
CHECK_EQ(input_layouts.size(), 1U) << "The input's layout size is not 1! Please check again.";
return {input_layouts, input_layouts};
Expand Down
1 change: 0 additions & 1 deletion cinn/hlir/op/use_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ CINN_USE_REGISTER(argmin_ops)
CINN_USE_REGISTER(argmax_ops)
CINN_USE_REGISTER(reduce_ops)
CINN_USE_REGISTER(custom_call_op)
CINN_USE_REGISTER(flip_ops)
CINN_USE_REGISTER(repeat_ops)
CINN_USE_REGISTER(one_hot_ops)
CINN_USE_REGISTER(lookup_table_ops)
Expand Down
Loading

0 comments on commit a659bb2

Please sign in to comment.