From 13dbd2729871f8b05b7c709ec38eb03d29d93918 Mon Sep 17 00:00:00 2001 From: zrr1999 <2742392377@qq.com> Date: Sat, 10 Dec 2022 17:26:58 +0000 Subject: [PATCH 01/27] add hlir CSE pass --- cinn/hlir/pass/CMakeLists.txt | 2 + .../pass/common_subexpression_elimination.cc | 149 ++++++++++++++++++ .../common_subexpression_elimination_test.cc | 134 ++++++++++++++++ cinn/hlir/pass/use_pass.h | 1 + 4 files changed, 286 insertions(+) create mode 100644 cinn/hlir/pass/common_subexpression_elimination.cc create mode 100644 cinn/hlir/pass/common_subexpression_elimination_test.cc diff --git a/cinn/hlir/pass/CMakeLists.txt b/cinn/hlir/pass/CMakeLists.txt index 5c5bf36bdc..c037088f5a 100755 --- a/cinn/hlir/pass/CMakeLists.txt +++ b/cinn/hlir/pass/CMakeLists.txt @@ -11,6 +11,7 @@ gather_srcs(cinnapi_src SRCS dot_merger.cc check_fusion_accuracy_pass.cc custom_call_pass.cc + common_subexpression_elimination.cc dce_pass.cc ) @@ -29,3 +30,4 @@ if (NOT WITH_CUDA) endif() cc_test(test_dot_merger SRCS test_dot_merger.cc DEPS cinncore) cc_test(test_dce_pass SRCS dce_pass_test.cc DEPS cinncore) +cc_test(test_common_subexpression_elimination SRCS common_subexpression_elimination_test.cc DEPS cinncore) diff --git a/cinn/hlir/pass/common_subexpression_elimination.cc b/cinn/hlir/pass/common_subexpression_elimination.cc new file mode 100644 index 0000000000..64d0798dac --- /dev/null +++ b/cinn/hlir/pass/common_subexpression_elimination.cc @@ -0,0 +1,149 @@ +// Copyright (c) 2022 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "cinn/hlir/framework/graph.h" +#include "cinn/hlir/framework/node.h" +#include "cinn/hlir/framework/op.h" +#include "cinn/hlir/framework/pass.h" +#include "cinn/hlir/pass/use_pass.h" +#include "cinn/hlir/pe/schedule.h" +#include "cinn/utils/string.h" + +namespace cinn { +namespace hlir { +namespace pass { + +using framework::Graph; +using framework::Node; +using framework::NodeData; +using framework::OpPatternKind; +using framework::shape_t; + +using common::GraphEdge; +using common::GraphNode; + +using GroupPtr = std::shared_ptr; +using GroupList = std::vector; + +using ShapeDict = absl::flat_hash_map; +using ConditionFunction = std::function; +using OutputToNodeMap = std::unordered_map; +using InputToNodeMap = std::unordered_map>; + +bool is_same_subexpr(Node* op1, Node* op2) { + auto op1_inputs_size = op1->inlinks_in_order().size(); + auto op2_inputs_size = op2->inlinks_in_order().size(); + if (op1_inputs_size != op2_inputs_size) { + return false; + } + auto op1_attrs_size = op1->attrs.attr_store.size(); + auto op2_attrs_size = op2->attrs.attr_store.size(); + if (op1_attrs_size != op2_attrs_size) { + return false; + } + for (int i = 0; i < op1_inputs_size; ++i) { + auto* op1_source_node = op1->inlinks_in_order()[i]->source(); + auto* op2_source_node = op2->inlinks_in_order()[i]->source(); + + if (op1_source_node->id() != op2_source_node->id()) { + return false; + } + } + for (auto& attr : op1->attrs.attr_store) { + if (!op2->attrs.attr_store.count(attr.first) || op2->attrs.attr_store[attr.first] != attr.second) { + return false; + } + } + return true; +} + +void remove_node(framework::Graph* graph, GraphNode* node) { + auto inlinks = node->inlinks(); + for (auto& link : inlinks) { + link->source()->UnLinkSingleTo(link->sink()); + } + auto outlinks = node->outlinks(); + for (auto& link : outlinks) { + link->source()->UnLinkSingleTo(link->sink()); + } + graph->DropNode(node); +} + +void CommonSubexpressionEliminationPass(Graph* graph) { + VLOG(3) << "CommonSubexpressionEliminationPass...!"; + std::unordered_map> expr_map; + std::unordered_map results; + int remove_num = 0; + OutputToNodeMap out2node; + InputToNodeMap in2node; + auto store_nodes = std::get<0>(graph->topological_order()); + + for (auto& graph_node : store_nodes) { + auto node = graph_node->safe_as(); + if (node) { + for (auto& out_edge : node->outlinks_in_order(true)) { + auto* sink_node = out_edge->sink()->safe_as(); + out2node[sink_node->id()] = node; + } + for (auto& in_edge : node->inlinks_in_order(true)) { + auto* source_node = in_edge->source()->safe_as(); + in2node[source_node->id()].insert(node); + } + } + } + for (auto& graph_node : store_nodes) { + auto node = graph_node->safe_as(); + if (node) { + auto& node_type = node->op()->name; + auto& candidates = expr_map[node_type]; + bool found = false; + for (auto* candidate_node : candidates) { + if (!is_same_subexpr(node, candidate_node)) continue; + found = true; + for (int k = 0; k < node->outlinks_in_order(true).size(); ++k) { + auto* sink_node = node->outlinks_in_order(true)[k]->sink()->safe_as(); + auto* candidate_sink_node = candidate_node->outlinks_in_order(true)[k]->sink()->safe_as(); + for (auto out_node : in2node[sink_node->id()]) { + sink_node->UnLinkSingleTo(out_node); + candidate_sink_node->LinkTo(out_node); + } + } + remove_node(graph, node); + remove_num++; + break; + } + if (!found) { + expr_map[node_type].push_back(node); + } + } + } + LOG(INFO) << "Total remove " << remove_num << " node."; + VLOG(3) << "Total remove " << remove_num << " node."; + VLOG(3) << "CommonSubexpressionEliminationPass Finish...!"; +} +} // namespace pass +} // namespace hlir +} // namespace cinn + +CINN_REGISTER_HELPER(CommonSubexpressionEliminationPass) { + CINN_REGISTER_PASS(CommonSubexpressionEliminationPass) + .describe("This pass will remove these same sub-expression.") + .set_change_structure(false) + .set_body(cinn::hlir::pass::CommonSubexpressionEliminationPass); + + return true; +} diff --git a/cinn/hlir/pass/common_subexpression_elimination_test.cc b/cinn/hlir/pass/common_subexpression_elimination_test.cc new file mode 100644 index 0000000000..88b8da15c0 --- /dev/null +++ b/cinn/hlir/pass/common_subexpression_elimination_test.cc @@ -0,0 +1,134 @@ +// Copyright (c) 2022 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Copyright (c) 202 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include + +#include "cinn/cinn.h" +#include "cinn/frontend/syntax.h" +#include "cinn/hlir/framework/graph.h" +#include "cinn/hlir/framework/graph_compiler.h" +#include "cinn/hlir/framework/pass.h" +#include "cinn/hlir/op/use_ops.h" +#include "cinn/hlir/pass/use_pass.h" +#include "cinn/utils/data_util.h" + +DEFINE_string(model_dir, "", ""); + +namespace cinn { +namespace frontend { + +using hlir::framework::Scope; +using utils::Join; + +TEST(common_subexpression_elimination, common_subexpression_elimination_case1) { + Placeholder A(Float(32), {32, 16}, "A"); + Placeholder B(Float(32), {32, 1}, "B", true); + + Program program; + auto add_1 = program.add(A, B); + auto add_2 = program.add(A, B); + auto add = program.add(add_1, add_2); + + Target target = common::DefaultTarget(); + program.SetInputs({A, B}); + program.Validate(); + LOG(INFO) << "Program:\n" << program; + auto graph = std::make_shared(program, target); + LOG(INFO) << "graph:\n" << graph->Visualize(); + + hlir::framework::ApplyPass(graph.get(), "InferShape"); + hlir::framework::ApplyPass(graph.get(), "CommonSubexpressionEliminationPass"); + auto scope = BuildScope(target, graph); + + hlir::framework::GraphCompiler gc(target, scope, graph); + auto runtime_program = gc.Build(); + auto& prerun_instrs = runtime_program->GetPreRunInstructions(); + auto& run_instrs = runtime_program->GetRunInstructions(); + ASSERT_EQ(prerun_instrs.size(), 0); + ASSERT_EQ(run_instrs.size(), 2); + + scope->Var("A"); + scope->Var("B"); + + auto A1 = scope->GetTensor("A"); + auto B1 = scope->GetTensor("B"); + SetRandData(A1, target); + SetRandData(B1, target); + + runtime_program->Execute(); + LOG(INFO) << "Program:\n" << program; + LOG(INFO) << "graph:\n" << graph->Visualize(); +} + +TEST(common_subexpression_elimination, common_subexpression_elimination_case2) { + Placeholder A(Float(32), {32, 16}, "A"); + Placeholder B(Float(32), {32, 1}, "B", true); + + Program program; + auto sub_1 = program.elementwise_sub(A, A); + auto sub_2 = program.elementwise_sub(A, A); + auto add_1 = program.add(B, sub_1); + auto add_2 = program.add(B, sub_2); + auto add = program.add(add_1, add_2); + + Target target = common::DefaultTarget(); + program.SetInputs({A, B}); + program.Validate(); + LOG(INFO) << "Program:\n" << program; + auto graph = std::make_shared(program, target); + LOG(INFO) << "graph:\n" << graph->Visualize(); + + hlir::framework::ApplyPass(graph.get(), "InferShape"); + hlir::framework::ApplyPass(graph.get(), "CommonSubexpressionEliminationPass"); + hlir::framework::ApplyPass(graph.get(), "CommonSubexpressionEliminationPass"); + auto scope = BuildScope(target, graph); + + hlir::framework::GraphCompiler gc(target, scope, graph); + auto runtime_program = gc.Build(); + auto& prerun_instrs = runtime_program->GetPreRunInstructions(); + auto& run_instrs = runtime_program->GetRunInstructions(); + ASSERT_EQ(prerun_instrs.size(), 0); + ASSERT_EQ(run_instrs.size(), 3); + + scope->Var("A"); + scope->Var("B"); + + auto A1 = scope->GetTensor("A"); + auto B1 = scope->GetTensor("B"); + SetRandData(A1, target); + SetRandData(B1, target); + + runtime_program->Execute(); + LOG(INFO) << "Program:\n" << program; + LOG(INFO) << "graph:\n" << graph->Visualize(); +} + +} // namespace frontend +} // namespace cinn diff --git a/cinn/hlir/pass/use_pass.h b/cinn/hlir/pass/use_pass.h index d22b52c78b..5dec555d63 100644 --- a/cinn/hlir/pass/use_pass.h +++ b/cinn/hlir/pass/use_pass.h @@ -27,3 +27,4 @@ CINN_USE_REGISTER(OpFusionPass) CINN_USE_REGISTER(FusionMergePass) CINN_USE_REGISTER(CheckFusionAccuracyPass) CINN_USE_REGISTER(CustomCallPass) +CINN_USE_REGISTER(CommonSubexpressionEliminationPass) From 2d5b3baed8c5f79a9d5034b664e845573e3548bb Mon Sep 17 00:00:00 2001 From: zrr1999 <2742392377@qq.com> Date: Fri, 16 Dec 2022 13:49:26 +0000 Subject: [PATCH 02/27] replace loop by 'std::all_of' and repeat remove --- .../pass/common_subexpression_elimination.cc | 58 ++++++++++--------- 1 file changed, 32 insertions(+), 26 deletions(-) diff --git a/cinn/hlir/pass/common_subexpression_elimination.cc b/cinn/hlir/pass/common_subexpression_elimination.cc index 64d0798dac..b7fefa5a75 100644 --- a/cinn/hlir/pass/common_subexpression_elimination.cc +++ b/cinn/hlir/pass/common_subexpression_elimination.cc @@ -20,7 +20,6 @@ #include "cinn/hlir/framework/op.h" #include "cinn/hlir/framework/pass.h" #include "cinn/hlir/pass/use_pass.h" -#include "cinn/hlir/pe/schedule.h" #include "cinn/utils/string.h" namespace cinn { @@ -41,7 +40,6 @@ using GroupList = std::vector; using ShapeDict = absl::flat_hash_map; using ConditionFunction = std::function; -using OutputToNodeMap = std::unordered_map; using InputToNodeMap = std::unordered_map>; bool is_same_subexpr(Node* op1, Node* op2) { @@ -63,12 +61,11 @@ bool is_same_subexpr(Node* op1, Node* op2) { return false; } } - for (auto& attr : op1->attrs.attr_store) { + return std::all_of(op1->attrs.attr_store.begin(), op1->attrs.attr_store.end(), [&](auto attr) { if (!op2->attrs.attr_store.count(attr.first) || op2->attrs.attr_store[attr.first] != attr.second) { return false; } - } - return true; + }); } void remove_node(framework::Graph* graph, GraphNode* node) { @@ -81,30 +78,12 @@ void remove_node(framework::Graph* graph, GraphNode* node) { link->source()->UnLinkSingleTo(link->sink()); } graph->DropNode(node); + LOG(INFO) << "remove " << node->id() << " node."; } -void CommonSubexpressionEliminationPass(Graph* graph) { - VLOG(3) << "CommonSubexpressionEliminationPass...!"; +int remove_common_subexpression(Graph* graph, std::vector& store_nodes, InputToNodeMap in2node) { std::unordered_map> expr_map; - std::unordered_map results; int remove_num = 0; - OutputToNodeMap out2node; - InputToNodeMap in2node; - auto store_nodes = std::get<0>(graph->topological_order()); - - for (auto& graph_node : store_nodes) { - auto node = graph_node->safe_as(); - if (node) { - for (auto& out_edge : node->outlinks_in_order(true)) { - auto* sink_node = out_edge->sink()->safe_as(); - out2node[sink_node->id()] = node; - } - for (auto& in_edge : node->inlinks_in_order(true)) { - auto* source_node = in_edge->source()->safe_as(); - in2node[source_node->id()].insert(node); - } - } - } for (auto& graph_node : store_nodes) { auto node = graph_node->safe_as(); if (node) { @@ -131,7 +110,34 @@ void CommonSubexpressionEliminationPass(Graph* graph) { } } } - LOG(INFO) << "Total remove " << remove_num << " node."; + return remove_num; +} + +void CommonSubexpressionEliminationPass(Graph* graph) { + VLOG(3) << "CommonSubexpressionEliminationPass...!"; + std::unordered_map> expr_map; + InputToNodeMap in2node; + auto store_nodes = std::get<0>(graph->topological_order()); + + for (auto& graph_node : store_nodes) { + auto node = graph_node->safe_as(); + if (node) { + for (auto& in_edge : node->inlinks_in_order(true)) { + auto* source_node = in_edge->source()->safe_as(); + in2node[source_node->id()].insert(node); + } + } + } + + int remove_num = 0; + int last_remove_num = 0; + while (last_remove_num || !remove_num) { + last_remove_num = remove_common_subexpression(graph, store_nodes, in2node); + if (last_remove_num) { + remove_num += last_remove_num; + store_nodes = std::get<0>(graph->topological_order()); + } + } VLOG(3) << "Total remove " << remove_num << " node."; VLOG(3) << "CommonSubexpressionEliminationPass Finish...!"; } From 824fa0b89fc2789cceba96490e23ecaa7ef7d6e4 Mon Sep 17 00:00:00 2001 From: zrr1999 <2742392377@qq.com> Date: Fri, 16 Dec 2022 14:05:56 +0000 Subject: [PATCH 03/27] optimize is_same_subexpr --- .../hlir/pass/common_subexpression_elimination.cc | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/cinn/hlir/pass/common_subexpression_elimination.cc b/cinn/hlir/pass/common_subexpression_elimination.cc index b7fefa5a75..b3454311c2 100644 --- a/cinn/hlir/pass/common_subexpression_elimination.cc +++ b/cinn/hlir/pass/common_subexpression_elimination.cc @@ -56,7 +56,6 @@ bool is_same_subexpr(Node* op1, Node* op2) { for (int i = 0; i < op1_inputs_size; ++i) { auto* op1_source_node = op1->inlinks_in_order()[i]->source(); auto* op2_source_node = op2->inlinks_in_order()[i]->source(); - if (op1_source_node->id() != op2_source_node->id()) { return false; } @@ -65,6 +64,7 @@ bool is_same_subexpr(Node* op1, Node* op2) { if (!op2->attrs.attr_store.count(attr.first) || op2->attrs.attr_store[attr.first] != attr.second) { return false; } + return true; }); } @@ -129,16 +129,11 @@ void CommonSubexpressionEliminationPass(Graph* graph) { } } - int remove_num = 0; - int last_remove_num = 0; - while (last_remove_num || !remove_num) { - last_remove_num = remove_common_subexpression(graph, store_nodes, in2node); - if (last_remove_num) { - remove_num += last_remove_num; - store_nodes = std::get<0>(graph->topological_order()); - } + int remove_num = remove_common_subexpression(graph, store_nodes, in2node); + while (remove_num) { + store_nodes = std::get<0>(graph->topological_order()); + remove_num = remove_common_subexpression(graph, store_nodes, in2node); } - VLOG(3) << "Total remove " << remove_num << " node."; VLOG(3) << "CommonSubexpressionEliminationPass Finish...!"; } } // namespace pass From cafa37707d5f9686d1e8e928fede6ccd13e6602b Mon Sep 17 00:00:00 2001 From: zrr1999 <2742392377@qq.com> Date: Fri, 16 Dec 2022 14:10:05 +0000 Subject: [PATCH 04/27] add cse pass in optimize.cc --- cinn/frontend/optimize.cc | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/cinn/frontend/optimize.cc b/cinn/frontend/optimize.cc index 3e903c1b4a..ee0b1000af 100644 --- a/cinn/frontend/optimize.cc +++ b/cinn/frontend/optimize.cc @@ -31,6 +31,7 @@ DECLARE_bool(cinn_use_fill_constant_folding); DECLARE_bool(cinn_use_op_fusion); DECLARE_bool(cinn_use_cudnn_conv); DECLARE_bool(cinn_use_cublas_gemm); +DECLARE_bool(cinn_use_common_subexpression_elimination); DECLARE_bool(cinn_check_fusion_accuracy_pass); namespace cinn { @@ -83,6 +84,10 @@ OptimizeOptions DefaultTrainingOptimizeOptions() { options.graph_passes.push_back("BuildNonFusedGroupsPass"); } + if (FLAGS_cinn_use_common_subexpression_elimination) { + options.graph_passes.push_back("CommonSubexpressionEliminationPass"); + } + // WARNING: the pass must be the last pass !!! if (FLAGS_cinn_check_fusion_accuracy_pass) { // Check the correct of fusion kernels, if the results not satisfied 'allclose(rtol=1e-05f, atol=1e-08f)', report From d3eb645aa17b5b80f79bee1f9c6e915af5d79db7 Mon Sep 17 00:00:00 2001 From: zrr1999 <2742392377@qq.com> Date: Fri, 16 Dec 2022 14:12:15 +0000 Subject: [PATCH 05/27] use emplace_back instead of push_back --- cinn/frontend/optimize.cc | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/cinn/frontend/optimize.cc b/cinn/frontend/optimize.cc index ee0b1000af..d0f72af77f 100644 --- a/cinn/frontend/optimize.cc +++ b/cinn/frontend/optimize.cc @@ -65,27 +65,27 @@ OptimizeOptions DefaultTrainingOptimizeOptions() { options.graph_passes = {}; #ifdef CINN_WITH_CUDA if (FLAGS_cinn_use_cublas_gemm) { - options.graph_passes.push_back("MatmulToCublasCustomCallPass"); + options.graph_passes.emplace_back("MatmulToCublasCustomCallPass"); } options.graph_passes.emplace_back("GaussianRandomToCustomCallPass"); options.graph_passes.emplace_back("UniformRandomToCustomCallPass"); options.graph_passes.emplace_back("CholeskyToCustomCallPass"); #ifdef CINN_WITH_CUDNN if (FLAGS_cinn_use_cudnn_conv) { - options.graph_passes.push_back("ConvToCudnnCustomCallPass"); + options.graph_passes.emplace_back("ConvToCudnnCustomCallPass"); } #endif #endif if (FLAGS_cinn_use_op_fusion) { - options.graph_passes.push_back("OpFusionPass"); - options.graph_passes.push_back("FusionMergePass"); + options.graph_passes.emplace_back("OpFusionPass"); + options.graph_passes.emplace_back("FusionMergePass"); } else { - options.graph_passes.push_back("BuildNonFusedGroupsPass"); + options.graph_passes.emplace_back("BuildNonFusedGroupsPass"); } if (FLAGS_cinn_use_common_subexpression_elimination) { - options.graph_passes.push_back("CommonSubexpressionEliminationPass"); + options.graph_passes.emplace_back("CommonSubexpressionEliminationPass"); } // WARNING: the pass must be the last pass !!! From eb8e8954e17f4e1344418d592229b4ef68e0cfc3 Mon Sep 17 00:00:00 2001 From: zrr1999 <2742392377@qq.com> Date: Fri, 16 Dec 2022 16:20:29 +0000 Subject: [PATCH 06/27] fix bug --- .../pass/common_subexpression_elimination.cc | 11 +++-- .../common_subexpression_elimination_test.cc | 44 +++++++++++++++++-- cinn/runtime/flags.cc | 4 ++ 3 files changed, 53 insertions(+), 6 deletions(-) diff --git a/cinn/hlir/pass/common_subexpression_elimination.cc b/cinn/hlir/pass/common_subexpression_elimination.cc index b3454311c2..755638fd93 100644 --- a/cinn/hlir/pass/common_subexpression_elimination.cc +++ b/cinn/hlir/pass/common_subexpression_elimination.cc @@ -53,9 +53,11 @@ bool is_same_subexpr(Node* op1, Node* op2) { if (op1_attrs_size != op2_attrs_size) { return false; } + auto op1_inlinks = op1->inlinks_in_order(true); + auto op2_inlinks = op2->inlinks_in_order(true); for (int i = 0; i < op1_inputs_size; ++i) { - auto* op1_source_node = op1->inlinks_in_order()[i]->source(); - auto* op2_source_node = op2->inlinks_in_order()[i]->source(); + auto* op1_source_node = op1_inlinks[i]->source(); + auto* op2_source_node = op2_inlinks[i]->source(); if (op1_source_node->id() != op2_source_node->id()) { return false; } @@ -96,9 +98,12 @@ int remove_common_subexpression(Graph* graph, std::vector& store_nod for (int k = 0; k < node->outlinks_in_order(true).size(); ++k) { auto* sink_node = node->outlinks_in_order(true)[k]->sink()->safe_as(); auto* candidate_sink_node = candidate_node->outlinks_in_order(true)[k]->sink()->safe_as(); - for (auto out_node : in2node[sink_node->id()]) { + auto out_nodes = in2node[sink_node->id()]; + for (auto out_node : out_nodes) { sink_node->UnLinkSingleTo(out_node); candidate_sink_node->LinkTo(out_node); + out_nodes.erase(node); + out_nodes.insert(candidate_node); } } remove_node(graph, node); diff --git a/cinn/hlir/pass/common_subexpression_elimination_test.cc b/cinn/hlir/pass/common_subexpression_elimination_test.cc index 88b8da15c0..a6351bd357 100644 --- a/cinn/hlir/pass/common_subexpression_elimination_test.cc +++ b/cinn/hlir/pass/common_subexpression_elimination_test.cc @@ -83,8 +83,6 @@ TEST(common_subexpression_elimination, common_subexpression_elimination_case1) { SetRandData(B1, target); runtime_program->Execute(); - LOG(INFO) << "Program:\n" << program; - LOG(INFO) << "graph:\n" << graph->Visualize(); } TEST(common_subexpression_elimination, common_subexpression_elimination_case2) { @@ -107,7 +105,6 @@ TEST(common_subexpression_elimination, common_subexpression_elimination_case2) { hlir::framework::ApplyPass(graph.get(), "InferShape"); hlir::framework::ApplyPass(graph.get(), "CommonSubexpressionEliminationPass"); - hlir::framework::ApplyPass(graph.get(), "CommonSubexpressionEliminationPass"); auto scope = BuildScope(target, graph); hlir::framework::GraphCompiler gc(target, scope, graph); @@ -126,7 +123,48 @@ TEST(common_subexpression_elimination, common_subexpression_elimination_case2) { SetRandData(B1, target); runtime_program->Execute(); +} + +TEST(common_subexpression_elimination, common_subexpression_elimination_case3) { + Placeholder A(Float(32), {32, 16}, "A"); + Placeholder B(Float(32), {32, 1}, "B", true); + + Program program; + auto sub_1 = program.elementwise_sub(A, A); + auto sub_2 = program.elementwise_sub(A, A); + auto const_1 = program.fill_constant({32, 16}, 1.0f, "", false, "const1"); + auto const_2 = program.fill_constant({32, 16}, 1.0f, "", false, "const2"); + auto const_3 = program.fill_constant({32, 16}, 2.0f, "", false, "const3"); + auto out1 = program.add(const_1, const_3); + auto out2 = program.add(const_2, const_3); + + Target target = common::DefaultTarget(); + program.SetInputs({A, B}); + program.Validate(); LOG(INFO) << "Program:\n" << program; + auto graph = std::make_shared(program, target); + LOG(INFO) << "graph:\n" << graph->Visualize(); + + hlir::framework::ApplyPass(graph.get(), "InferShape"); + hlir::framework::ApplyPass(graph.get(), "CommonSubexpressionEliminationPass"); + auto scope = BuildScope(target, graph); + + hlir::framework::GraphCompiler gc(target, scope, graph); + auto runtime_program = gc.Build(); + auto& prerun_instrs = runtime_program->GetPreRunInstructions(); + auto& run_instrs = runtime_program->GetRunInstructions(); + ASSERT_EQ(prerun_instrs.size(), 0); + ASSERT_EQ(run_instrs.size(), 5); + + scope->Var("A"); + scope->Var("B"); + + auto A1 = scope->GetTensor("A"); + auto B1 = scope->GetTensor("B"); + SetRandData(A1, target); + SetRandData(B1, target); + + runtime_program->Execute(); LOG(INFO) << "graph:\n" << graph->Visualize(); } diff --git a/cinn/runtime/flags.cc b/cinn/runtime/flags.cc index 16e40a3149..be93eb76cf 100644 --- a/cinn/runtime/flags.cc +++ b/cinn/runtime/flags.cc @@ -43,6 +43,10 @@ DEFINE_bool(cinn_use_cudnn_conv, BoolFromEnv("FLAGS_cinn_use_cudnn_conv", true), DEFINE_bool(cinn_use_cublas_gemm, BoolFromEnv("FLAGS_cinn_use_cublas_gemm", true), "Whether to use cublas gemm."); +DEFINE_bool(cinn_use_common_subexpression_elimination, + BoolFromEnv("FLAGS_cinn_use_common_subexpression_elimination", true), + "Whether to use common subexpression elimination pass."); + DEFINE_bool(cinn_use_fill_constant_folding, BoolFromEnv("FLAGS_cinn_use_fill_constant_folding", false), "Whether use the FillConstantFolding pass."); From 668d13eee43b47aa885331eb091a9a05943bae96 Mon Sep 17 00:00:00 2001 From: zrr1999 <2742392377@qq.com> Date: Sat, 17 Dec 2022 10:47:31 +0000 Subject: [PATCH 07/27] fix link order --- .../pass/common_subexpression_elimination.cc | 36 ++++++++++++++++--- .../common_subexpression_elimination_test.cc | 2 +- 2 files changed, 33 insertions(+), 5 deletions(-) diff --git a/cinn/hlir/pass/common_subexpression_elimination.cc b/cinn/hlir/pass/common_subexpression_elimination.cc index 755638fd93..eeba9acf7a 100644 --- a/cinn/hlir/pass/common_subexpression_elimination.cc +++ b/cinn/hlir/pass/common_subexpression_elimination.cc @@ -42,7 +42,7 @@ using ShapeDict = absl::flat_hash_map; using ConditionFunction = std::function; using InputToNodeMap = std::unordered_map>; -bool is_same_subexpr(Node* op1, Node* op2) { +bool is_same_subexpression(Node* op1, Node* op2) { auto op1_inputs_size = op1->inlinks_in_order().size(); auto op2_inputs_size = op2->inlinks_in_order().size(); if (op1_inputs_size != op2_inputs_size) { @@ -58,6 +58,10 @@ bool is_same_subexpr(Node* op1, Node* op2) { for (int i = 0; i < op1_inputs_size; ++i) { auto* op1_source_node = op1_inlinks[i]->source(); auto* op2_source_node = op2_inlinks[i]->source(); + LOG(INFO) << op1_source_node->id(); + LOG(INFO) << op2_source_node->id(); + LOG(INFO) << op1_inlinks[1]->source()->id(); + LOG(INFO) << op2_inlinks[1]->source()->id(); if (op1_source_node->id() != op2_source_node->id()) { return false; } @@ -83,6 +87,22 @@ void remove_node(framework::Graph* graph, GraphNode* node) { LOG(INFO) << "remove " << node->id() << " node."; } +void replace_inlinks(NodeData* src_new, NodeData* src_old, Node* trt) { + std::vector in_nodes; + for (auto& in_link : trt->inlinks_in_order(true)) { + auto* in_node = in_link->source()->safe_as(); + if (in_node->id() == src_old->id()) { + in_nodes.emplace_back(src_new); + } else { + in_nodes.emplace_back(in_node); + } + in_node->UnLinkSingleTo(trt); + } + for (auto in_node : in_nodes) { + in_node->LinkTo(trt); + } +} + int remove_common_subexpression(Graph* graph, std::vector& store_nodes, InputToNodeMap in2node) { std::unordered_map> expr_map; int remove_num = 0; @@ -93,15 +113,23 @@ int remove_common_subexpression(Graph* graph, std::vector& store_nod auto& candidates = expr_map[node_type]; bool found = false; for (auto* candidate_node : candidates) { - if (!is_same_subexpr(node, candidate_node)) continue; + if (!is_same_subexpression(node, candidate_node)) continue; found = true; for (int k = 0; k < node->outlinks_in_order(true).size(); ++k) { auto* sink_node = node->outlinks_in_order(true)[k]->sink()->safe_as(); auto* candidate_sink_node = candidate_node->outlinks_in_order(true)[k]->sink()->safe_as(); auto out_nodes = in2node[sink_node->id()]; for (auto out_node : out_nodes) { - sink_node->UnLinkSingleTo(out_node); - candidate_sink_node->LinkTo(out_node); + auto op2_inlinks = out_node->inlinks_in_order(true); + LOG(INFO) << op2_inlinks[0]->source()->id(); + LOG(INFO) << op2_inlinks[1]->source()->id(); + + replace_inlinks(candidate_sink_node, sink_node, out_node); + + op2_inlinks = out_node->inlinks_in_order(true); + LOG(INFO) << op2_inlinks[0]->source()->id(); + LOG(INFO) << op2_inlinks[1]->source()->id(); + out_nodes.erase(node); out_nodes.insert(candidate_node); } diff --git a/cinn/hlir/pass/common_subexpression_elimination_test.cc b/cinn/hlir/pass/common_subexpression_elimination_test.cc index a6351bd357..fd72e63b4b 100644 --- a/cinn/hlir/pass/common_subexpression_elimination_test.cc +++ b/cinn/hlir/pass/common_subexpression_elimination_test.cc @@ -154,7 +154,7 @@ TEST(common_subexpression_elimination, common_subexpression_elimination_case3) { auto& prerun_instrs = runtime_program->GetPreRunInstructions(); auto& run_instrs = runtime_program->GetRunInstructions(); ASSERT_EQ(prerun_instrs.size(), 0); - ASSERT_EQ(run_instrs.size(), 5); + ASSERT_EQ(run_instrs.size(), 4); scope->Var("A"); scope->Var("B"); From fa19f5b3770ea235e910d3aaad0e7959b7d85cdb Mon Sep 17 00:00:00 2001 From: zrr1999 <2742392377@qq.com> Date: Sat, 17 Dec 2022 11:22:15 +0000 Subject: [PATCH 08/27] remove some unused code --- .../pass/common_subexpression_elimination.cc | 30 +++++-------------- .../common_subexpression_elimination_test.cc | 1 - 2 files changed, 7 insertions(+), 24 deletions(-) diff --git a/cinn/hlir/pass/common_subexpression_elimination.cc b/cinn/hlir/pass/common_subexpression_elimination.cc index eeba9acf7a..700c510a2f 100644 --- a/cinn/hlir/pass/common_subexpression_elimination.cc +++ b/cinn/hlir/pass/common_subexpression_elimination.cc @@ -29,22 +29,17 @@ namespace pass { using framework::Graph; using framework::Node; using framework::NodeData; -using framework::OpPatternKind; -using framework::shape_t; using common::GraphEdge; using common::GraphNode; -using GroupPtr = std::shared_ptr; -using GroupList = std::vector; - -using ShapeDict = absl::flat_hash_map; -using ConditionFunction = std::function; -using InputToNodeMap = std::unordered_map>; +using InputToNodeMap = std::unordered_map>; bool is_same_subexpression(Node* op1, Node* op2) { - auto op1_inputs_size = op1->inlinks_in_order().size(); - auto op2_inputs_size = op2->inlinks_in_order().size(); + auto op1_inlinks = op1->inlinks_in_order(true); + auto op2_inlinks = op2->inlinks_in_order(true); + auto op1_inputs_size = op1_inlinks.size(); + auto op2_inputs_size = op2_inlinks.size(); if (op1_inputs_size != op2_inputs_size) { return false; } @@ -53,15 +48,13 @@ bool is_same_subexpression(Node* op1, Node* op2) { if (op1_attrs_size != op2_attrs_size) { return false; } - auto op1_inlinks = op1->inlinks_in_order(true); - auto op2_inlinks = op2->inlinks_in_order(true); for (int i = 0; i < op1_inputs_size; ++i) { + LOG(INFO) << op1_inlinks[1]->source()->id(); + LOG(INFO) << op2_inlinks[1]->source()->id(); auto* op1_source_node = op1_inlinks[i]->source(); auto* op2_source_node = op2_inlinks[i]->source(); LOG(INFO) << op1_source_node->id(); LOG(INFO) << op2_source_node->id(); - LOG(INFO) << op1_inlinks[1]->source()->id(); - LOG(INFO) << op2_inlinks[1]->source()->id(); if (op1_source_node->id() != op2_source_node->id()) { return false; } @@ -120,16 +113,7 @@ int remove_common_subexpression(Graph* graph, std::vector& store_nod auto* candidate_sink_node = candidate_node->outlinks_in_order(true)[k]->sink()->safe_as(); auto out_nodes = in2node[sink_node->id()]; for (auto out_node : out_nodes) { - auto op2_inlinks = out_node->inlinks_in_order(true); - LOG(INFO) << op2_inlinks[0]->source()->id(); - LOG(INFO) << op2_inlinks[1]->source()->id(); - replace_inlinks(candidate_sink_node, sink_node, out_node); - - op2_inlinks = out_node->inlinks_in_order(true); - LOG(INFO) << op2_inlinks[0]->source()->id(); - LOG(INFO) << op2_inlinks[1]->source()->id(); - out_nodes.erase(node); out_nodes.insert(candidate_node); } diff --git a/cinn/hlir/pass/common_subexpression_elimination_test.cc b/cinn/hlir/pass/common_subexpression_elimination_test.cc index fd72e63b4b..f219d8be21 100644 --- a/cinn/hlir/pass/common_subexpression_elimination_test.cc +++ b/cinn/hlir/pass/common_subexpression_elimination_test.cc @@ -155,7 +155,6 @@ TEST(common_subexpression_elimination, common_subexpression_elimination_case3) { auto& run_instrs = runtime_program->GetRunInstructions(); ASSERT_EQ(prerun_instrs.size(), 0); ASSERT_EQ(run_instrs.size(), 4); - scope->Var("A"); scope->Var("B"); From bf9c6fa5a07dc11a26526ff3be60459e8b6102ca Mon Sep 17 00:00:00 2001 From: zrr1999 <2742392377@qq.com> Date: Sat, 17 Dec 2022 12:03:21 +0000 Subject: [PATCH 09/27] fix bug --- cinn/frontend/optimize.cc | 6 +-- .../pass/common_subexpression_elimination.cc | 50 ++++++++++--------- 2 files changed, 30 insertions(+), 26 deletions(-) diff --git a/cinn/frontend/optimize.cc b/cinn/frontend/optimize.cc index d0f72af77f..fc427f8c38 100644 --- a/cinn/frontend/optimize.cc +++ b/cinn/frontend/optimize.cc @@ -84,9 +84,9 @@ OptimizeOptions DefaultTrainingOptimizeOptions() { options.graph_passes.emplace_back("BuildNonFusedGroupsPass"); } - if (FLAGS_cinn_use_common_subexpression_elimination) { - options.graph_passes.emplace_back("CommonSubexpressionEliminationPass"); - } + // if (FLAGS_cinn_use_common_subexpression_elimination) { + // options.graph_passes.emplace_back("CommonSubexpressionEliminationPass"); + // } // WARNING: the pass must be the last pass !!! if (FLAGS_cinn_check_fusion_accuracy_pass) { diff --git a/cinn/hlir/pass/common_subexpression_elimination.cc b/cinn/hlir/pass/common_subexpression_elimination.cc index 700c510a2f..741ccad5ca 100644 --- a/cinn/hlir/pass/common_subexpression_elimination.cc +++ b/cinn/hlir/pass/common_subexpression_elimination.cc @@ -36,10 +36,10 @@ using common::GraphNode; using InputToNodeMap = std::unordered_map>; bool is_same_subexpression(Node* op1, Node* op2) { - auto op1_inlinks = op1->inlinks_in_order(true); - auto op2_inlinks = op2->inlinks_in_order(true); - auto op1_inputs_size = op1_inlinks.size(); - auto op2_inputs_size = op2_inlinks.size(); + auto op1_in_edges = op1->inlinks_in_order(true); + auto op2_in_edges = op2->inlinks_in_order(true); + auto op1_inputs_size = op1_in_edges.size(); + auto op2_inputs_size = op2_in_edges.size(); if (op1_inputs_size != op2_inputs_size) { return false; } @@ -49,12 +49,10 @@ bool is_same_subexpression(Node* op1, Node* op2) { return false; } for (int i = 0; i < op1_inputs_size; ++i) { - LOG(INFO) << op1_inlinks[1]->source()->id(); - LOG(INFO) << op2_inlinks[1]->source()->id(); - auto* op1_source_node = op1_inlinks[i]->source(); - auto* op2_source_node = op2_inlinks[i]->source(); - LOG(INFO) << op1_source_node->id(); - LOG(INFO) << op2_source_node->id(); + auto* op1_source_node = op1_in_edges[i]->source()->safe_as(); + auto* op2_source_node = op2_in_edges[i]->source()->safe_as(); + CHECK(op1_source_node); + CHECK(op2_source_node); if (op1_source_node->id() != op2_source_node->id()) { return false; } @@ -67,23 +65,26 @@ bool is_same_subexpression(Node* op1, Node* op2) { }); } -void remove_node(framework::Graph* graph, GraphNode* node) { - auto inlinks = node->inlinks(); - for (auto& link : inlinks) { - link->source()->UnLinkSingleTo(link->sink()); +void remove_node(framework::Graph* graph, Node* node) { + auto in_edges = node->inlinks(); + for (auto& edge : in_edges) { + auto* in_node = edge->source()->safe_as(); + in_node->UnLinkSingleTo(node); } - auto outlinks = node->outlinks(); - for (auto& link : outlinks) { - link->source()->UnLinkSingleTo(link->sink()); + auto out_edges = node->outlinks(); + for (auto& edge : out_edges) { + auto* out_node = edge->sink()->safe_as(); + CHECK(out_node); + node->UnLinkSingleTo(out_node); } graph->DropNode(node); LOG(INFO) << "remove " << node->id() << " node."; } -void replace_inlinks(NodeData* src_new, NodeData* src_old, Node* trt) { +void replace_node(NodeData* src_new, NodeData* src_old, Node* trt) { std::vector in_nodes; - for (auto& in_link : trt->inlinks_in_order(true)) { - auto* in_node = in_link->source()->safe_as(); + for (auto& in_edge : trt->inlinks_in_order(true)) { + auto* in_node = in_edge->source()->safe_as(); if (in_node->id() == src_old->id()) { in_nodes.emplace_back(src_new); } else { @@ -109,11 +110,14 @@ int remove_common_subexpression(Graph* graph, std::vector& store_nod if (!is_same_subexpression(node, candidate_node)) continue; found = true; for (int k = 0; k < node->outlinks_in_order(true).size(); ++k) { + CHECK(node->outlinks_in_order(true).size() == candidate_node->outlinks_in_order(true).size()); auto* sink_node = node->outlinks_in_order(true)[k]->sink()->safe_as(); auto* candidate_sink_node = candidate_node->outlinks_in_order(true)[k]->sink()->safe_as(); - auto out_nodes = in2node[sink_node->id()]; + CHECK(sink_node); + CHECK(candidate_sink_node); + auto out_nodes = in2node[sink_node->id()]; for (auto out_node : out_nodes) { - replace_inlinks(candidate_sink_node, sink_node, out_node); + replace_node(candidate_sink_node, sink_node, out_node); out_nodes.erase(node); out_nodes.insert(candidate_node); } @@ -129,7 +133,7 @@ int remove_common_subexpression(Graph* graph, std::vector& store_nod } return remove_num; } - +// void CommonSubexpressionEliminationPass(Graph* graph) { VLOG(3) << "CommonSubexpressionEliminationPass...!"; std::unordered_map> expr_map; From 0d630e03efbbba4d6ae3ace02a4693fcf5b85320 Mon Sep 17 00:00:00 2001 From: zrr1999 <2742392377@qq.com> Date: Tue, 20 Dec 2022 13:32:18 +0000 Subject: [PATCH 10/27] unify function name --- cinn/frontend/optimize.cc | 1 - .../pass/common_subexpression_elimination.cc | 72 ++++++++++++++----- .../common_subexpression_elimination_test.cc | 7 +- 3 files changed, 60 insertions(+), 20 deletions(-) diff --git a/cinn/frontend/optimize.cc b/cinn/frontend/optimize.cc index fc427f8c38..442114928f 100644 --- a/cinn/frontend/optimize.cc +++ b/cinn/frontend/optimize.cc @@ -94,7 +94,6 @@ OptimizeOptions DefaultTrainingOptimizeOptions() { // error and exited. options.graph_passes.emplace_back("CheckFusionAccuracyPass"); } - return options; } diff --git a/cinn/hlir/pass/common_subexpression_elimination.cc b/cinn/hlir/pass/common_subexpression_elimination.cc index 741ccad5ca..39424b1fa7 100644 --- a/cinn/hlir/pass/common_subexpression_elimination.cc +++ b/cinn/hlir/pass/common_subexpression_elimination.cc @@ -35,7 +35,26 @@ using common::GraphNode; using InputToNodeMap = std::unordered_map>; -bool is_same_subexpression(Node* op1, Node* op2) { +std::unordered_set unordered_ops = { + "elementwise_add", + "elementwise_mul", + "max", + "min", + "logical_and", + "logical_or", + "logical_xor", + "equal", + "not_equal", + "bitwise_or", + "bitwise_xor", + "bitwise_and", + "reduce_sum", + "reduce_prod", + "reduce_max", + "reduce_min", +}; + +bool IsSameSubexpression(Node* op1, Node* op2, const absl::flat_hash_map& shape_dict) { auto op1_in_edges = op1->inlinks_in_order(true); auto op2_in_edges = op2->inlinks_in_order(true); auto op1_inputs_size = op1_in_edges.size(); @@ -48,13 +67,31 @@ bool is_same_subexpression(Node* op1, Node* op2) { if (op1_attrs_size != op2_attrs_size) { return false; } - for (int i = 0; i < op1_inputs_size; ++i) { - auto* op1_source_node = op1_in_edges[i]->source()->safe_as(); - auto* op2_source_node = op2_in_edges[i]->source()->safe_as(); - CHECK(op1_source_node); - CHECK(op2_source_node); - if (op1_source_node->id() != op2_source_node->id()) { - return false; + if (unordered_ops.count(op1->op()->name)) { + for (auto& op1_edge : op1_in_edges) { + auto* op1_source_node = op1_edge->source()->safe_as(); + CHECK(op1_source_node); + bool op1_equal_op2 = std::any_of(op2_in_edges.begin(), op2_in_edges.end(), [&](common::Shared& edge) { + auto* op2_source_node = edge->source()->safe_as(); + CHECK(op2_source_node); + if (op1_source_node->id() == op2_source_node->id()) { + return true; + } + return false; + }); + if (!op1_equal_op2) { + return false; + } + } + } else { + for (int i = 0; i < op1_inputs_size; ++i) { + auto* op1_source_node = op1_in_edges[i]->source()->safe_as(); + auto* op2_source_node = op2_in_edges[i]->source()->safe_as(); + CHECK(op1_source_node); + CHECK(op2_source_node); + if (op1_source_node->id() != op2_source_node->id()) { + return false; + } } } return std::all_of(op1->attrs.attr_store.begin(), op1->attrs.attr_store.end(), [&](auto attr) { @@ -65,7 +102,7 @@ bool is_same_subexpression(Node* op1, Node* op2) { }); } -void remove_node(framework::Graph* graph, Node* node) { +void RemoveNode(framework::Graph* graph, Node* node) { auto in_edges = node->inlinks(); for (auto& edge : in_edges) { auto* in_node = edge->source()->safe_as(); @@ -81,7 +118,7 @@ void remove_node(framework::Graph* graph, Node* node) { LOG(INFO) << "remove " << node->id() << " node."; } -void replace_node(NodeData* src_new, NodeData* src_old, Node* trt) { +void ReplaceNode(NodeData* src_new, NodeData* src_old, Node* trt) { std::vector in_nodes; for (auto& in_edge : trt->inlinks_in_order(true)) { auto* in_node = in_edge->source()->safe_as(); @@ -97,9 +134,10 @@ void replace_node(NodeData* src_new, NodeData* src_old, Node* trt) { } } -int remove_common_subexpression(Graph* graph, std::vector& store_nodes, InputToNodeMap in2node) { +int CommonSubexpressionElimination(Graph* graph, std::vector& store_nodes, InputToNodeMap in2node) { std::unordered_map> expr_map; - int remove_num = 0; + auto& shape_dict = graph->GetAttrs>("infershape"); + int remove_num = 0; for (auto& graph_node : store_nodes) { auto node = graph_node->safe_as(); if (node) { @@ -107,7 +145,7 @@ int remove_common_subexpression(Graph* graph, std::vector& store_nod auto& candidates = expr_map[node_type]; bool found = false; for (auto* candidate_node : candidates) { - if (!is_same_subexpression(node, candidate_node)) continue; + if (!IsSameSubexpression(node, candidate_node, shape_dict)) continue; found = true; for (int k = 0; k < node->outlinks_in_order(true).size(); ++k) { CHECK(node->outlinks_in_order(true).size() == candidate_node->outlinks_in_order(true).size()); @@ -117,12 +155,12 @@ int remove_common_subexpression(Graph* graph, std::vector& store_nod CHECK(candidate_sink_node); auto out_nodes = in2node[sink_node->id()]; for (auto out_node : out_nodes) { - replace_node(candidate_sink_node, sink_node, out_node); + ReplaceNode(candidate_sink_node, sink_node, out_node); out_nodes.erase(node); out_nodes.insert(candidate_node); } } - remove_node(graph, node); + RemoveNode(graph, node); remove_num++; break; } @@ -150,10 +188,10 @@ void CommonSubexpressionEliminationPass(Graph* graph) { } } - int remove_num = remove_common_subexpression(graph, store_nodes, in2node); + int remove_num = CommonSubexpressionElimination(graph, store_nodes, in2node); while (remove_num) { store_nodes = std::get<0>(graph->topological_order()); - remove_num = remove_common_subexpression(graph, store_nodes, in2node); + remove_num = CommonSubexpressionElimination(graph, store_nodes, in2node); } VLOG(3) << "CommonSubexpressionEliminationPass Finish...!"; } diff --git a/cinn/hlir/pass/common_subexpression_elimination_test.cc b/cinn/hlir/pass/common_subexpression_elimination_test.cc index f219d8be21..133c7d2d7c 100644 --- a/cinn/hlir/pass/common_subexpression_elimination_test.cc +++ b/cinn/hlir/pass/common_subexpression_elimination_test.cc @@ -53,8 +53,11 @@ TEST(common_subexpression_elimination, common_subexpression_elimination_case1) { Program program; auto add_1 = program.add(A, B); - auto add_2 = program.add(A, B); + auto add_2 = program.add(B, A); auto add = program.add(add_1, add_2); + auto max_1 = program.reduce_max(add, {-1}, false); + auto max_2 = program.reduce_max(add, {1}, false); + auto max = program.reduce_max(add, {0}, true); Target target = common::DefaultTarget(); program.SetInputs({A, B}); @@ -93,7 +96,7 @@ TEST(common_subexpression_elimination, common_subexpression_elimination_case2) { auto sub_1 = program.elementwise_sub(A, A); auto sub_2 = program.elementwise_sub(A, A); auto add_1 = program.add(B, sub_1); - auto add_2 = program.add(B, sub_2); + auto add_2 = program.add(sub_2, B); auto add = program.add(add_1, add_2); Target target = common::DefaultTarget(); From b9a7803eac72ddc12ba6d1a6a5ec16df92898db5 Mon Sep 17 00:00:00 2001 From: zrr1999 <2742392377@qq.com> Date: Wed, 21 Dec 2022 05:23:41 +0000 Subject: [PATCH 11/27] fix special arg --- .../pass/common_subexpression_elimination.cc | 37 +++++++++++++++---- .../common_subexpression_elimination_test.cc | 17 ++++----- 2 files changed, 38 insertions(+), 16 deletions(-) diff --git a/cinn/hlir/pass/common_subexpression_elimination.cc b/cinn/hlir/pass/common_subexpression_elimination.cc index 39424b1fa7..ad46c49357 100644 --- a/cinn/hlir/pass/common_subexpression_elimination.cc +++ b/cinn/hlir/pass/common_subexpression_elimination.cc @@ -34,6 +34,7 @@ using common::GraphEdge; using common::GraphNode; using InputToNodeMap = std::unordered_map>; +using shape_dict_t = absl::flat_hash_map; std::unordered_set unordered_ops = { "elementwise_add", @@ -54,7 +55,7 @@ std::unordered_set unordered_ops = { "reduce_min", }; -bool IsSameSubexpression(Node* op1, Node* op2, const absl::flat_hash_map& shape_dict) { +bool IsSameSubexpression(Node* op1, Node* op2, shape_dict_t& shape_dict) { auto op1_in_edges = op1->inlinks_in_order(true); auto op2_in_edges = op2->inlinks_in_order(true); auto op1_inputs_size = op1_in_edges.size(); @@ -94,12 +95,34 @@ bool IsSameSubexpression(Node* op1, Node* op2, const absl::flat_hash_mapattrs.attr_store.begin(), op1->attrs.attr_store.end(), [&](auto attr) { - if (!op2->attrs.attr_store.count(attr.first) || op2->attrs.attr_store[attr.first] != attr.second) { + if (op1->op()->name == "reshape") { + auto* op1_sink_node = op1->outlinks_in_order(true)[0]->sink()->safe_as(); + auto* op2_sink_node = op2->outlinks_in_order(true)[0]->sink()->safe_as(); + return shape_dict[op1_sink_node->id()] == shape_dict[op2_sink_node->id()]; + } else { + auto* op1_sink_node = op1->outlinks_in_order(true)[0]->sink()->safe_as(); + auto* op2_sink_node = op2->outlinks_in_order(true)[0]->sink()->safe_as(); + if (shape_dict[op1_sink_node->id()].size() != shape_dict[op2_sink_node->id()].size()) { return false; } - return true; - }); + return std::all_of(op1->attrs.attr_store.begin(), op1->attrs.attr_store.end(), [&](auto attr) { + if (!op2->attrs.attr_store.count(attr.first) || op2->attrs.attr_store[attr.first] != attr.second) { + if (attr.first == "axis" || attr.first == "dim") { + auto op1_axis = absl::get(attr.second); + auto op2_axis = absl::get(op2->attrs.attr_store[attr.first]); + if (op1_axis < 0) { + op1_axis += shape_dict[op1_sink_node->id()].size(); + } + if (op2_axis < 0) { + op2_axis += shape_dict[op1_sink_node->id()].size(); + } + return op2_axis == op1_axis; + } + return false; + } + return true; + }); + } } void RemoveNode(framework::Graph* graph, Node* node) { @@ -136,8 +159,8 @@ void ReplaceNode(NodeData* src_new, NodeData* src_old, Node* trt) { int CommonSubexpressionElimination(Graph* graph, std::vector& store_nodes, InputToNodeMap in2node) { std::unordered_map> expr_map; - auto& shape_dict = graph->GetAttrs>("infershape"); - int remove_num = 0; + auto shape_dict = graph->GetAttrs>("infershape"); + int remove_num = 0; for (auto& graph_node : store_nodes) { auto node = graph_node->safe_as(); if (node) { diff --git a/cinn/hlir/pass/common_subexpression_elimination_test.cc b/cinn/hlir/pass/common_subexpression_elimination_test.cc index 133c7d2d7c..3bacb869f2 100644 --- a/cinn/hlir/pass/common_subexpression_elimination_test.cc +++ b/cinn/hlir/pass/common_subexpression_elimination_test.cc @@ -55,8 +55,8 @@ TEST(common_subexpression_elimination, common_subexpression_elimination_case1) { auto add_1 = program.add(A, B); auto add_2 = program.add(B, A); auto add = program.add(add_1, add_2); - auto max_1 = program.reduce_max(add, {-1}, false); - auto max_2 = program.reduce_max(add, {1}, false); + auto t_1 = program.transpose(add, {0, 1}); + auto t_2 = program.transpose(add, {0, 1}); auto max = program.reduce_max(add, {0}, true); Target target = common::DefaultTarget(); @@ -64,7 +64,7 @@ TEST(common_subexpression_elimination, common_subexpression_elimination_case1) { program.Validate(); LOG(INFO) << "Program:\n" << program; auto graph = std::make_shared(program, target); - LOG(INFO) << "graph:\n" << graph->Visualize(); + // LOG(INFO) << "graph:\n" << graph->Visualize(); hlir::framework::ApplyPass(graph.get(), "InferShape"); hlir::framework::ApplyPass(graph.get(), "CommonSubexpressionEliminationPass"); @@ -75,7 +75,7 @@ TEST(common_subexpression_elimination, common_subexpression_elimination_case1) { auto& prerun_instrs = runtime_program->GetPreRunInstructions(); auto& run_instrs = runtime_program->GetRunInstructions(); ASSERT_EQ(prerun_instrs.size(), 0); - ASSERT_EQ(run_instrs.size(), 2); + ASSERT_EQ(run_instrs.size(), 4); scope->Var("A"); scope->Var("B"); @@ -95,8 +95,8 @@ TEST(common_subexpression_elimination, common_subexpression_elimination_case2) { Program program; auto sub_1 = program.elementwise_sub(A, A); auto sub_2 = program.elementwise_sub(A, A); - auto add_1 = program.add(B, sub_1); - auto add_2 = program.add(sub_2, B); + auto add_1 = program.reshape(B, {4, -1}); + auto add_2 = program.reshape(B, {4, 8}); auto add = program.add(add_1, add_2); Target target = common::DefaultTarget(); @@ -104,7 +104,7 @@ TEST(common_subexpression_elimination, common_subexpression_elimination_case2) { program.Validate(); LOG(INFO) << "Program:\n" << program; auto graph = std::make_shared(program, target); - LOG(INFO) << "graph:\n" << graph->Visualize(); + // LOG(INFO) << "graph:\n" << graph->Visualize(); hlir::framework::ApplyPass(graph.get(), "InferShape"); hlir::framework::ApplyPass(graph.get(), "CommonSubexpressionEliminationPass"); @@ -146,7 +146,7 @@ TEST(common_subexpression_elimination, common_subexpression_elimination_case3) { program.Validate(); LOG(INFO) << "Program:\n" << program; auto graph = std::make_shared(program, target); - LOG(INFO) << "graph:\n" << graph->Visualize(); + // LOG(INFO) << "graph:\n" << graph->Visualize(); hlir::framework::ApplyPass(graph.get(), "InferShape"); hlir::framework::ApplyPass(graph.get(), "CommonSubexpressionEliminationPass"); @@ -167,7 +167,6 @@ TEST(common_subexpression_elimination, common_subexpression_elimination_case3) { SetRandData(B1, target); runtime_program->Execute(); - LOG(INFO) << "graph:\n" << graph->Visualize(); } } // namespace frontend From ddc984a066732ebb864ff2597292c9e8cd3fbfe7 Mon Sep 17 00:00:00 2001 From: zrr1999 <2742392377@qq.com> Date: Sat, 31 Dec 2022 05:44:47 +0000 Subject: [PATCH 12/27] add complex test case --- .../pass/common_subexpression_elimination.cc | 2 +- .../common_subexpression_elimination_test.cc | 47 ++++++++++++------- 2 files changed, 31 insertions(+), 18 deletions(-) diff --git a/cinn/hlir/pass/common_subexpression_elimination.cc b/cinn/hlir/pass/common_subexpression_elimination.cc index ad46c49357..5f2e5a53cc 100644 --- a/cinn/hlir/pass/common_subexpression_elimination.cc +++ b/cinn/hlir/pass/common_subexpression_elimination.cc @@ -194,7 +194,7 @@ int CommonSubexpressionElimination(Graph* graph, std::vector& store_ } return remove_num; } -// + void CommonSubexpressionEliminationPass(Graph* graph) { VLOG(3) << "CommonSubexpressionEliminationPass...!"; std::unordered_map> expr_map; diff --git a/cinn/hlir/pass/common_subexpression_elimination_test.cc b/cinn/hlir/pass/common_subexpression_elimination_test.cc index 3bacb869f2..aa4c5e2223 100644 --- a/cinn/hlir/pass/common_subexpression_elimination_test.cc +++ b/cinn/hlir/pass/common_subexpression_elimination_test.cc @@ -55,8 +55,8 @@ TEST(common_subexpression_elimination, common_subexpression_elimination_case1) { auto add_1 = program.add(A, B); auto add_2 = program.add(B, A); auto add = program.add(add_1, add_2); - auto t_1 = program.transpose(add, {0, 1}); - auto t_2 = program.transpose(add, {0, 1}); + auto t_1 = program.transpose(add, {1, 0}); + auto t_2 = program.transpose(add, {1, 0}); auto max = program.reduce_max(add, {0}, true); Target target = common::DefaultTarget(); @@ -93,11 +93,11 @@ TEST(common_subexpression_elimination, common_subexpression_elimination_case2) { Placeholder B(Float(32), {32, 1}, "B", true); Program program; - auto sub_1 = program.elementwise_sub(A, A); - auto sub_2 = program.elementwise_sub(A, A); - auto add_1 = program.reshape(B, {4, -1}); - auto add_2 = program.reshape(B, {4, 8}); - auto add = program.add(add_1, add_2); + auto add_1 = program.add(A, A); + auto add_2 = program.add(A, A); + auto reshape_1 = program.reshape(B, {4, -1}); + auto reshape_2 = program.reshape(B, {4, 8}); + auto add = program.add(reshape_1, reshape_2); Target target = common::DefaultTarget(); program.SetInputs({A, B}); @@ -129,17 +129,30 @@ TEST(common_subexpression_elimination, common_subexpression_elimination_case2) { } TEST(common_subexpression_elimination, common_subexpression_elimination_case3) { - Placeholder A(Float(32), {32, 16}, "A"); - Placeholder B(Float(32), {32, 1}, "B", true); + Placeholder A(Float(32), {1, 3, 224, 224}, "A"); + Placeholder B(Float(32), {1, 1, 224, 224}, "B", true); + + absl::flat_hash_map attrs; + attrs["stride"] = std::vector({2, 2}); + attrs["dilation"] = std::vector({1, 1}); + attrs["padding"] = std::vector({3, 3}); + std::string src_layout = "NCHW"; + attrs["data_format"] = src_layout; Program program; - auto sub_1 = program.elementwise_sub(A, A); - auto sub_2 = program.elementwise_sub(A, A); - auto const_1 = program.fill_constant({32, 16}, 1.0f, "", false, "const1"); - auto const_2 = program.fill_constant({32, 16}, 1.0f, "", false, "const2"); - auto const_3 = program.fill_constant({32, 16}, 2.0f, "", false, "const3"); - auto out1 = program.add(const_1, const_3); - auto out2 = program.add(const_2, const_3); + auto add_1 = program.add(A, B); + auto weight_1 = program.fill_constant({64, 3, 7, 7}, 1.0f, "", false, "w1"); + auto weight_2 = program.fill_constant({64, 3, 7, 7}, 1.0f, "", false, "w2"); + auto bias = program.fill_constant({1, 64, 112, 112}, 2.0f, "", false, "b1"); + auto conv_1 = program.conv2d(add_1, weight_1, attrs); + auto add_2 = program.add(conv_1, bias); + auto relu_1 = program.relu(add_2); + auto conv_2 = program.conv2d(add_1, weight_2, attrs); + auto add_3 = program.add(conv_2, bias); + auto relu_2 = program.relu(add_3); + auto out1 = program.add(relu_1, add_2); + auto out2 = program.add(add_2, relu_2); + auto out = program.multiply(out1, out2); Target target = common::DefaultTarget(); program.SetInputs({A, B}); @@ -157,7 +170,7 @@ TEST(common_subexpression_elimination, common_subexpression_elimination_case3) { auto& prerun_instrs = runtime_program->GetPreRunInstructions(); auto& run_instrs = runtime_program->GetRunInstructions(); ASSERT_EQ(prerun_instrs.size(), 0); - ASSERT_EQ(run_instrs.size(), 4); + ASSERT_EQ(run_instrs.size(), 8); scope->Var("A"); scope->Var("B"); From 8287854c90aa32294c70338fa13f144e465cffb1 Mon Sep 17 00:00:00 2001 From: zrr1999 <2742392377@qq.com> Date: Wed, 4 Jan 2023 09:46:04 +0000 Subject: [PATCH 13/27] fix bug --- cinn/hlir/pass/common_subexpression_elimination_test.cc | 2 ++ 1 file changed, 2 insertions(+) diff --git a/cinn/hlir/pass/common_subexpression_elimination_test.cc b/cinn/hlir/pass/common_subexpression_elimination_test.cc index aa4c5e2223..5376847a6e 100644 --- a/cinn/hlir/pass/common_subexpression_elimination_test.cc +++ b/cinn/hlir/pass/common_subexpression_elimination_test.cc @@ -128,6 +128,7 @@ TEST(common_subexpression_elimination, common_subexpression_elimination_case2) { runtime_program->Execute(); } +#ifdef CINN_WITH_CUDA TEST(common_subexpression_elimination, common_subexpression_elimination_case3) { Placeholder A(Float(32), {1, 3, 224, 224}, "A"); Placeholder B(Float(32), {1, 1, 224, 224}, "B", true); @@ -181,6 +182,7 @@ TEST(common_subexpression_elimination, common_subexpression_elimination_case3) { runtime_program->Execute(); } +#endif } // namespace frontend } // namespace cinn From 1eadc1e64e88ff906d60e77a30732c64d7ebb76c Mon Sep 17 00:00:00 2001 From: zrr1999 <2742392377@qq.com> Date: Tue, 10 Jan 2023 11:34:27 +0000 Subject: [PATCH 14/27] set_change_structure true and remove out_node --- cinn/hlir/pass/common_subexpression_elimination.cc | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/cinn/hlir/pass/common_subexpression_elimination.cc b/cinn/hlir/pass/common_subexpression_elimination.cc index 5f2e5a53cc..bf74844179 100644 --- a/cinn/hlir/pass/common_subexpression_elimination.cc +++ b/cinn/hlir/pass/common_subexpression_elimination.cc @@ -49,10 +49,6 @@ std::unordered_set unordered_ops = { "bitwise_or", "bitwise_xor", "bitwise_and", - "reduce_sum", - "reduce_prod", - "reduce_max", - "reduce_min", }; bool IsSameSubexpression(Node* op1, Node* op2, shape_dict_t& shape_dict) { @@ -136,6 +132,7 @@ void RemoveNode(framework::Graph* graph, Node* node) { auto* out_node = edge->sink()->safe_as(); CHECK(out_node); node->UnLinkSingleTo(out_node); + graph->DropNode(out_node); } graph->DropNode(node); LOG(INFO) << "remove " << node->id() << " node."; @@ -225,7 +222,7 @@ void CommonSubexpressionEliminationPass(Graph* graph) { CINN_REGISTER_HELPER(CommonSubexpressionEliminationPass) { CINN_REGISTER_PASS(CommonSubexpressionEliminationPass) .describe("This pass will remove these same sub-expression.") - .set_change_structure(false) + .set_change_structure(true) .set_body(cinn::hlir::pass::CommonSubexpressionEliminationPass); return true; From 63f8dc49aca835bf85b258a2eeaf4a05e1f3210a Mon Sep 17 00:00:00 2001 From: zrr1999 <2742392377@qq.com> Date: Tue, 10 Jan 2023 13:10:15 +0000 Subject: [PATCH 15/27] fix bug in RemoveNode --- .../pass/common_subexpression_elimination.cc | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/cinn/hlir/pass/common_subexpression_elimination.cc b/cinn/hlir/pass/common_subexpression_elimination.cc index bf74844179..3ef297507e 100644 --- a/cinn/hlir/pass/common_subexpression_elimination.cc +++ b/cinn/hlir/pass/common_subexpression_elimination.cc @@ -121,21 +121,20 @@ bool IsSameSubexpression(Node* op1, Node* op2, shape_dict_t& shape_dict) { } } -void RemoveNode(framework::Graph* graph, Node* node) { +void RemoveNode(framework::Graph* graph, GraphNode* node) { auto in_edges = node->inlinks(); for (auto& edge : in_edges) { - auto* in_node = edge->source()->safe_as(); + auto* in_node = edge->source(); in_node->UnLinkSingleTo(node); } auto out_edges = node->outlinks(); for (auto& edge : out_edges) { - auto* out_node = edge->sink()->safe_as(); + auto* out_node = edge->sink(); CHECK(out_node); node->UnLinkSingleTo(out_node); graph->DropNode(out_node); } graph->DropNode(node); - LOG(INFO) << "remove " << node->id() << " node."; } void ReplaceNode(NodeData* src_new, NodeData* src_old, Node* trt) { @@ -157,7 +156,7 @@ void ReplaceNode(NodeData* src_new, NodeData* src_old, Node* trt) { int CommonSubexpressionElimination(Graph* graph, std::vector& store_nodes, InputToNodeMap in2node) { std::unordered_map> expr_map; auto shape_dict = graph->GetAttrs>("infershape"); - int remove_num = 0; + std::vector remove_nodes; for (auto& graph_node : store_nodes) { auto node = graph_node->safe_as(); if (node) { @@ -180,8 +179,8 @@ int CommonSubexpressionElimination(Graph* graph, std::vector& store_ out_nodes.insert(candidate_node); } } - RemoveNode(graph, node); - remove_num++; + remove_nodes.push_back(node); + LOG(INFO) << "remove " << node->id() << " node."; break; } if (!found) { @@ -189,7 +188,10 @@ int CommonSubexpressionElimination(Graph* graph, std::vector& store_ } } } - return remove_num; + for (auto node : remove_nodes) { + RemoveNode(graph, node); + } + return remove_nodes.size(); } void CommonSubexpressionEliminationPass(Graph* graph) { From 1441e5de1a806a1d1747a347d5abd131fa9afdc5 Mon Sep 17 00:00:00 2001 From: zrr1999 <2742392377@qq.com> Date: Wed, 11 Jan 2023 09:43:43 +0000 Subject: [PATCH 16/27] the nodes in outputs will not be removed --- cinn/frontend/optimize.cc | 6 +++--- .../hlir/pass/common_subexpression_elimination.cc | 15 ++++++++++++--- .../pass/common_subexpression_elimination_test.cc | 12 +++++++----- 3 files changed, 22 insertions(+), 11 deletions(-) diff --git a/cinn/frontend/optimize.cc b/cinn/frontend/optimize.cc index 442114928f..68ccd48076 100644 --- a/cinn/frontend/optimize.cc +++ b/cinn/frontend/optimize.cc @@ -84,9 +84,9 @@ OptimizeOptions DefaultTrainingOptimizeOptions() { options.graph_passes.emplace_back("BuildNonFusedGroupsPass"); } - // if (FLAGS_cinn_use_common_subexpression_elimination) { - // options.graph_passes.emplace_back("CommonSubexpressionEliminationPass"); - // } + if (FLAGS_cinn_use_common_subexpression_elimination) { + options.graph_passes.emplace_back("CommonSubexpressionEliminationPass"); + } // WARNING: the pass must be the last pass !!! if (FLAGS_cinn_check_fusion_accuracy_pass) { diff --git a/cinn/hlir/pass/common_subexpression_elimination.cc b/cinn/hlir/pass/common_subexpression_elimination.cc index 3ef297507e..51e33f6aef 100644 --- a/cinn/hlir/pass/common_subexpression_elimination.cc +++ b/cinn/hlir/pass/common_subexpression_elimination.cc @@ -122,6 +122,9 @@ bool IsSameSubexpression(Node* op1, Node* op2, shape_dict_t& shape_dict) { } void RemoveNode(framework::Graph* graph, GraphNode* node) { + if (std::count(graph->outputs.begin(), graph->outputs.end(), node)) { + return; + } auto in_edges = node->inlinks(); for (auto& edge : in_edges) { auto* in_node = edge->source(); @@ -153,7 +156,7 @@ void ReplaceNode(NodeData* src_new, NodeData* src_old, Node* trt) { } } -int CommonSubexpressionElimination(Graph* graph, std::vector& store_nodes, InputToNodeMap in2node) { +size_t CommonSubexpressionElimination(Graph* graph, std::vector& store_nodes, InputToNodeMap in2node) { std::unordered_map> expr_map; auto shape_dict = graph->GetAttrs>("infershape"); std::vector remove_nodes; @@ -174,7 +177,13 @@ int CommonSubexpressionElimination(Graph* graph, std::vector& store_ CHECK(candidate_sink_node); auto out_nodes = in2node[sink_node->id()]; for (auto out_node : out_nodes) { - ReplaceNode(candidate_sink_node, sink_node, out_node); + if (std::count(graph->outputs.begin(), graph->outputs.end(), sink_node)) { + for (const auto& candidate_sink_in_edge : candidate_sink_node->inlinks()) { + candidate_sink_in_edge->sink()->LinkTo(sink_node); + } + } else { + ReplaceNode(candidate_sink_node, sink_node, out_node); + } out_nodes.erase(node); out_nodes.insert(candidate_node); } @@ -210,7 +219,7 @@ void CommonSubexpressionEliminationPass(Graph* graph) { } } - int remove_num = CommonSubexpressionElimination(graph, store_nodes, in2node); + size_t remove_num = CommonSubexpressionElimination(graph, store_nodes, in2node); while (remove_num) { store_nodes = std::get<0>(graph->topological_order()); remove_num = CommonSubexpressionElimination(graph, store_nodes, in2node); diff --git a/cinn/hlir/pass/common_subexpression_elimination_test.cc b/cinn/hlir/pass/common_subexpression_elimination_test.cc index 5376847a6e..2033b73790 100644 --- a/cinn/hlir/pass/common_subexpression_elimination_test.cc +++ b/cinn/hlir/pass/common_subexpression_elimination_test.cc @@ -64,7 +64,7 @@ TEST(common_subexpression_elimination, common_subexpression_elimination_case1) { program.Validate(); LOG(INFO) << "Program:\n" << program; auto graph = std::make_shared(program, target); - // LOG(INFO) << "graph:\n" << graph->Visualize(); + LOG(INFO) << "graph:\n" << graph->Visualize(); hlir::framework::ApplyPass(graph.get(), "InferShape"); hlir::framework::ApplyPass(graph.get(), "CommonSubexpressionEliminationPass"); @@ -85,6 +85,7 @@ TEST(common_subexpression_elimination, common_subexpression_elimination_case1) { SetRandData(A1, target); SetRandData(B1, target); + LOG(INFO) << "graph:\n" << graph->Visualize(); runtime_program->Execute(); } @@ -104,7 +105,7 @@ TEST(common_subexpression_elimination, common_subexpression_elimination_case2) { program.Validate(); LOG(INFO) << "Program:\n" << program; auto graph = std::make_shared(program, target); - // LOG(INFO) << "graph:\n" << graph->Visualize(); + LOG(INFO) << "graph:\n" << graph->Visualize(); hlir::framework::ApplyPass(graph.get(), "InferShape"); hlir::framework::ApplyPass(graph.get(), "CommonSubexpressionEliminationPass"); @@ -125,6 +126,7 @@ TEST(common_subexpression_elimination, common_subexpression_elimination_case2) { SetRandData(A1, target); SetRandData(B1, target); + LOG(INFO) << "graph:\n" << graph->Visualize(); runtime_program->Execute(); } @@ -153,14 +155,13 @@ TEST(common_subexpression_elimination, common_subexpression_elimination_case3) { auto relu_2 = program.relu(add_3); auto out1 = program.add(relu_1, add_2); auto out2 = program.add(add_2, relu_2); - auto out = program.multiply(out1, out2); Target target = common::DefaultTarget(); program.SetInputs({A, B}); program.Validate(); LOG(INFO) << "Program:\n" << program; auto graph = std::make_shared(program, target); - // LOG(INFO) << "graph:\n" << graph->Visualize(); + LOG(INFO) << "graph:\n" << graph->Visualize(); hlir::framework::ApplyPass(graph.get(), "InferShape"); hlir::framework::ApplyPass(graph.get(), "CommonSubexpressionEliminationPass"); @@ -171,7 +172,7 @@ TEST(common_subexpression_elimination, common_subexpression_elimination_case3) { auto& prerun_instrs = runtime_program->GetPreRunInstructions(); auto& run_instrs = runtime_program->GetRunInstructions(); ASSERT_EQ(prerun_instrs.size(), 0); - ASSERT_EQ(run_instrs.size(), 8); + ASSERT_EQ(run_instrs.size(), 7); scope->Var("A"); scope->Var("B"); @@ -180,6 +181,7 @@ TEST(common_subexpression_elimination, common_subexpression_elimination_case3) { SetRandData(A1, target); SetRandData(B1, target); + LOG(INFO) << "graph:\n" << graph->Visualize(); runtime_program->Execute(); } #endif From ea013d54cce877c256f04ae3514db12fdbd1bdcb Mon Sep 17 00:00:00 2001 From: zrr1999 <2742392377@qq.com> Date: Wed, 11 Jan 2023 12:17:40 +0000 Subject: [PATCH 17/27] fix bug and add fetch_list test --- cinn/hlir/pass/common_subexpression_elimination.cc | 9 +++++---- cinn/hlir/pass/common_subexpression_elimination_test.cc | 5 ++++- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/cinn/hlir/pass/common_subexpression_elimination.cc b/cinn/hlir/pass/common_subexpression_elimination.cc index 51e33f6aef..4be2bd3343 100644 --- a/cinn/hlir/pass/common_subexpression_elimination.cc +++ b/cinn/hlir/pass/common_subexpression_elimination.cc @@ -133,9 +133,7 @@ void RemoveNode(framework::Graph* graph, GraphNode* node) { auto out_edges = node->outlinks(); for (auto& edge : out_edges) { auto* out_node = edge->sink(); - CHECK(out_node); node->UnLinkSingleTo(out_node); - graph->DropNode(out_node); } graph->DropNode(node); } @@ -160,7 +158,7 @@ size_t CommonSubexpressionElimination(Graph* graph, std::vector& sto std::unordered_map> expr_map; auto shape_dict = graph->GetAttrs>("infershape"); std::vector remove_nodes; - for (auto& graph_node : store_nodes) { + for (auto* graph_node : store_nodes) { auto node = graph_node->safe_as(); if (node) { auto& node_type = node->op()->name; @@ -183,6 +181,9 @@ size_t CommonSubexpressionElimination(Graph* graph, std::vector& sto } } else { ReplaceNode(candidate_sink_node, sink_node, out_node); + if (!std::count(remove_nodes.begin(), remove_nodes.end(), sink_node)) { + remove_nodes.push_back(sink_node); + } } out_nodes.erase(node); out_nodes.insert(candidate_node); @@ -197,7 +198,7 @@ size_t CommonSubexpressionElimination(Graph* graph, std::vector& sto } } } - for (auto node : remove_nodes) { + for (auto* node : remove_nodes) { RemoveNode(graph, node); } return remove_nodes.size(); diff --git a/cinn/hlir/pass/common_subexpression_elimination_test.cc b/cinn/hlir/pass/common_subexpression_elimination_test.cc index 2033b73790..d6391bd0d4 100644 --- a/cinn/hlir/pass/common_subexpression_elimination_test.cc +++ b/cinn/hlir/pass/common_subexpression_elimination_test.cc @@ -160,7 +160,10 @@ TEST(common_subexpression_elimination, common_subexpression_elimination_case3) { program.SetInputs({A, B}); program.Validate(); LOG(INFO) << "Program:\n" << program; - auto graph = std::make_shared(program, target); + std::unordered_set fetch_list; + fetch_list.insert(out1->id); + fetch_list.insert(out2->id); + auto graph = std::make_shared(program, fetch_list, target); LOG(INFO) << "graph:\n" << graph->Visualize(); hlir::framework::ApplyPass(graph.get(), "InferShape"); From cc40d79e5c053a1dd5216891499d60422ae0c6da Mon Sep 17 00:00:00 2001 From: zrr1999 <2742392377@qq.com> Date: Wed, 11 Jan 2023 16:20:43 +0000 Subject: [PATCH 18/27] fix bug and add code annotation --- .../pass/common_subexpression_elimination.cc | 25 +++++++++++-------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/cinn/hlir/pass/common_subexpression_elimination.cc b/cinn/hlir/pass/common_subexpression_elimination.cc index 4be2bd3343..95453c1b07 100644 --- a/cinn/hlir/pass/common_subexpression_elimination.cc +++ b/cinn/hlir/pass/common_subexpression_elimination.cc @@ -165,25 +165,28 @@ size_t CommonSubexpressionElimination(Graph* graph, std::vector& sto auto& candidates = expr_map[node_type]; bool found = false; for (auto* candidate_node : candidates) { + // If node is different from candidate_node, continue the next. if (!IsSameSubexpression(node, candidate_node, shape_dict)) continue; found = true; for (int k = 0; k < node->outlinks_in_order(true).size(); ++k) { CHECK(node->outlinks_in_order(true).size() == candidate_node->outlinks_in_order(true).size()); - auto* sink_node = node->outlinks_in_order(true)[k]->sink()->safe_as(); - auto* candidate_sink_node = candidate_node->outlinks_in_order(true)[k]->sink()->safe_as(); + auto* sink_node = node->outlinks_in_order()[k]->sink()->safe_as(); + auto* candidate_sink_node = candidate_node->outlinks_in_order()[k]->sink()->safe_as(); CHECK(sink_node); CHECK(candidate_sink_node); + size_t n_sink_in_outputs = std::count(graph->outputs.begin(), graph->outputs.end(), sink_node); + // If sink node in outputs, the node's source_node will be replaced by candidate_sink_node's source_node. + if (n_sink_in_outputs) { + sink_node->source_node = candidate_sink_node->source_node; + } + // Replace sink_node with candidate_sink_node in nodes linked by sink_node. auto out_nodes = in2node[sink_node->id()]; for (auto out_node : out_nodes) { - if (std::count(graph->outputs.begin(), graph->outputs.end(), sink_node)) { - for (const auto& candidate_sink_in_edge : candidate_sink_node->inlinks()) { - candidate_sink_in_edge->sink()->LinkTo(sink_node); - } - } else { - ReplaceNode(candidate_sink_node, sink_node, out_node); - if (!std::count(remove_nodes.begin(), remove_nodes.end(), sink_node)) { - remove_nodes.push_back(sink_node); - } + ReplaceNode(candidate_sink_node, sink_node, out_node); + // If sink node is not in outputs and not in removes, the node will be removed. + size_t n_sink_in_removes = std::count(remove_nodes.begin(), remove_nodes.end(), sink_node); + if (n_sink_in_removes == 0 && n_sink_in_outputs == 0) { + remove_nodes.push_back(sink_node); } out_nodes.erase(node); out_nodes.insert(candidate_node); From 98edf6040047b0ba9abbaa79ab46b805f91bbbe8 Mon Sep 17 00:00:00 2001 From: zrr1999 <2742392377@qq.com> Date: Wed, 11 Jan 2023 17:41:29 +0000 Subject: [PATCH 19/27] fix bug --- .../pass/common_subexpression_elimination.cc | 73 ++++++++++--------- cinn/runtime/flags.cc | 2 +- 2 files changed, 41 insertions(+), 34 deletions(-) diff --git a/cinn/hlir/pass/common_subexpression_elimination.cc b/cinn/hlir/pass/common_subexpression_elimination.cc index 95453c1b07..2a3006264c 100644 --- a/cinn/hlir/pass/common_subexpression_elimination.cc +++ b/cinn/hlir/pass/common_subexpression_elimination.cc @@ -18,6 +18,7 @@ #include "cinn/hlir/framework/graph.h" #include "cinn/hlir/framework/node.h" #include "cinn/hlir/framework/op.h" +#include "cinn/hlir/framework/op_lowering.h" #include "cinn/hlir/framework/pass.h" #include "cinn/hlir/pass/use_pass.h" #include "cinn/utils/string.h" @@ -121,48 +122,54 @@ bool IsSameSubexpression(Node* op1, Node* op2, shape_dict_t& shape_dict) { } } -void RemoveNode(framework::Graph* graph, GraphNode* node) { - if (std::count(graph->outputs.begin(), graph->outputs.end(), node)) { - return; - } - auto in_edges = node->inlinks(); - for (auto& edge : in_edges) { - auto* in_node = edge->source(); - in_node->UnLinkSingleTo(node); +void RemoveNodes(framework::Graph* graph, std::vector& nodes) { + for (auto* node : nodes) { + auto in_edges = node->inlinks(); + for (auto& edge : in_edges) { + auto* in_node = edge->source(); + in_node->UnLinkSingleTo(node); + } + auto out_edges = node->outlinks(); + for (auto& edge : out_edges) { + auto* out_node = edge->sink(); + node->UnLinkSingleTo(out_node); + } + graph->DropNode(node); } - auto out_edges = node->outlinks(); - for (auto& edge : out_edges) { - auto* out_node = edge->sink(); - node->UnLinkSingleTo(out_node); +} + +void RemoveNodes(framework::Graph* graph, std::vector& nodes_data) { + for (auto* data : nodes_data) { + if (std::count(graph->outputs.begin(), graph->outputs.end(), data)) { + return; + } + graph->DropNode(data); } - graph->DropNode(node); } void ReplaceNode(NodeData* src_new, NodeData* src_old, Node* trt) { std::vector in_nodes; for (auto& in_edge : trt->inlinks_in_order(true)) { auto* in_node = in_edge->source()->safe_as(); + in_node->UnLinkSingleTo(trt); if (in_node->id() == src_old->id()) { - in_nodes.emplace_back(src_new); + src_new->LinkTo(trt); } else { - in_nodes.emplace_back(in_node); + in_node->LinkTo(trt); } - in_node->UnLinkSingleTo(trt); - } - for (auto in_node : in_nodes) { - in_node->LinkTo(trt); } } size_t CommonSubexpressionElimination(Graph* graph, std::vector& store_nodes, InputToNodeMap in2node) { - std::unordered_map> expr_map; + std::unordered_map> candidates_map; auto shape_dict = graph->GetAttrs>("infershape"); - std::vector remove_nodes; + std::vector remove_nodes; + std::vector remove_nodes_data; for (auto* graph_node : store_nodes) { auto node = graph_node->safe_as(); if (node) { auto& node_type = node->op()->name; - auto& candidates = expr_map[node_type]; + auto& candidates = candidates_map[node_type]; bool found = false; for (auto* candidate_node : candidates) { // If node is different from candidate_node, continue the next. @@ -175,19 +182,19 @@ size_t CommonSubexpressionElimination(Graph* graph, std::vector& sto CHECK(sink_node); CHECK(candidate_sink_node); size_t n_sink_in_outputs = std::count(graph->outputs.begin(), graph->outputs.end(), sink_node); - // If sink node in outputs, the node's source_node will be replaced by candidate_sink_node's source_node. if (n_sink_in_outputs) { + // If sink node in outputs, the node's source_node will be replaced by candidate_sink_node's source_node. + node->UnLinkSingleTo(sink_node); sink_node->source_node = candidate_sink_node->source_node; + candidate_sink_node->source_node->LinkTo(sink_node); + } else { + // If sink node not in outputs, the node will be removed. + remove_nodes_data.push_back(sink_node); } // Replace sink_node with candidate_sink_node in nodes linked by sink_node. auto out_nodes = in2node[sink_node->id()]; for (auto out_node : out_nodes) { ReplaceNode(candidate_sink_node, sink_node, out_node); - // If sink node is not in outputs and not in removes, the node will be removed. - size_t n_sink_in_removes = std::count(remove_nodes.begin(), remove_nodes.end(), sink_node); - if (n_sink_in_removes == 0 && n_sink_in_outputs == 0) { - remove_nodes.push_back(sink_node); - } out_nodes.erase(node); out_nodes.insert(candidate_node); } @@ -197,19 +204,19 @@ size_t CommonSubexpressionElimination(Graph* graph, std::vector& sto break; } if (!found) { - expr_map[node_type].push_back(node); + candidates_map[node_type].push_back(node); } } } - for (auto* node : remove_nodes) { - RemoveNode(graph, node); - } + // Node should be deleted before node data. + RemoveNodes(graph, remove_nodes); + RemoveNodes(graph, remove_nodes_data); return remove_nodes.size(); } void CommonSubexpressionEliminationPass(Graph* graph) { VLOG(3) << "CommonSubexpressionEliminationPass...!"; - std::unordered_map> expr_map; + std::unordered_map> candidates_map; InputToNodeMap in2node; auto store_nodes = std::get<0>(graph->topological_order()); diff --git a/cinn/runtime/flags.cc b/cinn/runtime/flags.cc index be93eb76cf..2efd0b9a4b 100644 --- a/cinn/runtime/flags.cc +++ b/cinn/runtime/flags.cc @@ -44,7 +44,7 @@ DEFINE_bool(cinn_use_cudnn_conv, BoolFromEnv("FLAGS_cinn_use_cudnn_conv", true), DEFINE_bool(cinn_use_cublas_gemm, BoolFromEnv("FLAGS_cinn_use_cublas_gemm", true), "Whether to use cublas gemm."); DEFINE_bool(cinn_use_common_subexpression_elimination, - BoolFromEnv("FLAGS_cinn_use_common_subexpression_elimination", true), + BoolFromEnv("FLAGS_cinn_use_common_subexpression_elimination", false), "Whether to use common subexpression elimination pass."); DEFINE_bool(cinn_use_fill_constant_folding, From 2b77af07ce25b5a7d94616bdd5a602d1519008d0 Mon Sep 17 00:00:00 2001 From: zrr1999 <2742392377@qq.com> Date: Thu, 12 Jan 2023 08:38:35 +0000 Subject: [PATCH 20/27] optimization --- .../pass/common_subexpression_elimination.cc | 20 ++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/cinn/hlir/pass/common_subexpression_elimination.cc b/cinn/hlir/pass/common_subexpression_elimination.cc index 2a3006264c..d5b601b338 100644 --- a/cinn/hlir/pass/common_subexpression_elimination.cc +++ b/cinn/hlir/pass/common_subexpression_elimination.cc @@ -53,19 +53,26 @@ std::unordered_set unordered_ops = { }; bool IsSameSubexpression(Node* op1, Node* op2, shape_dict_t& shape_dict) { - auto op1_in_edges = op1->inlinks_in_order(true); - auto op2_in_edges = op2->inlinks_in_order(true); + // Get the input edges for op1 and op2 in order. + auto op1_in_edges = op1->inlinks_in_order(true); + auto op2_in_edges = op2->inlinks_in_order(true); + // Get the number of input edges for op1 and op2 auto op1_inputs_size = op1_in_edges.size(); auto op2_inputs_size = op2_in_edges.size(); + // If the number of input edges is not the same, the subexpression is not the same. if (op1_inputs_size != op2_inputs_size) { return false; } + // Get the number of attributes for op1 and op2. auto op1_attrs_size = op1->attrs.attr_store.size(); auto op2_attrs_size = op2->attrs.attr_store.size(); + // If the number of attributes is not the same, the subexpression is not the same. if (op1_attrs_size != op2_attrs_size) { return false; } + // Check if the input nodes match. if (unordered_ops.count(op1->op()->name)) { + // For unordered ops, check if any input node of op2 matches any input node of op1. for (auto& op1_edge : op1_in_edges) { auto* op1_source_node = op1_edge->source()->safe_as(); CHECK(op1_source_node); @@ -82,6 +89,7 @@ bool IsSameSubexpression(Node* op1, Node* op2, shape_dict_t& shape_dict) { } } } else { + // For ordered ops, check if the input nodes match one-to-one. for (int i = 0; i < op1_inputs_size; ++i) { auto* op1_source_node = op1_in_edges[i]->source()->safe_as(); auto* op2_source_node = op2_in_edges[i]->source()->safe_as(); @@ -92,11 +100,14 @@ bool IsSameSubexpression(Node* op1, Node* op2, shape_dict_t& shape_dict) { } } } + if (op1->op()->name == "reshape") { + // For reshape ops, check if the reshaped shape is the same. auto* op1_sink_node = op1->outlinks_in_order(true)[0]->sink()->safe_as(); auto* op2_sink_node = op2->outlinks_in_order(true)[0]->sink()->safe_as(); return shape_dict[op1_sink_node->id()] == shape_dict[op2_sink_node->id()]; } else { + // For non-reshape ops, check if the number of dimensions and attributes. auto* op1_sink_node = op1->outlinks_in_order(true)[0]->sink()->safe_as(); auto* op2_sink_node = op2->outlinks_in_order(true)[0]->sink()->safe_as(); if (shape_dict[op1_sink_node->id()].size() != shape_dict[op2_sink_node->id()].size()) { @@ -140,7 +151,7 @@ void RemoveNodes(framework::Graph* graph, std::vector& nodes) { void RemoveNodes(framework::Graph* graph, std::vector& nodes_data) { for (auto* data : nodes_data) { - if (std::count(graph->outputs.begin(), graph->outputs.end(), data)) { + if (std::find(graph->outputs.begin(), graph->outputs.end(), data) != graph->outputs.end()) { return; } graph->DropNode(data); @@ -181,8 +192,7 @@ size_t CommonSubexpressionElimination(Graph* graph, std::vector& sto auto* candidate_sink_node = candidate_node->outlinks_in_order()[k]->sink()->safe_as(); CHECK(sink_node); CHECK(candidate_sink_node); - size_t n_sink_in_outputs = std::count(graph->outputs.begin(), graph->outputs.end(), sink_node); - if (n_sink_in_outputs) { + if (std::find(graph->outputs.begin(), graph->outputs.end(), sink_node) != graph->outputs.end()) { // If sink node in outputs, the node's source_node will be replaced by candidate_sink_node's source_node. node->UnLinkSingleTo(sink_node); sink_node->source_node = candidate_sink_node->source_node; From 5f77902bee76d270d12ce66d50cc15c542fcd375 Mon Sep 17 00:00:00 2001 From: zrr1999 <2742392377@qq.com> Date: Fri, 13 Jan 2023 06:07:02 +0000 Subject: [PATCH 21/27] optimization --- .../pass/common_subexpression_elimination.cc | 34 ++++++++++--------- 1 file changed, 18 insertions(+), 16 deletions(-) diff --git a/cinn/hlir/pass/common_subexpression_elimination.cc b/cinn/hlir/pass/common_subexpression_elimination.cc index d5b601b338..d779c1900d 100644 --- a/cinn/hlir/pass/common_subexpression_elimination.cc +++ b/cinn/hlir/pass/common_subexpression_elimination.cc @@ -171,12 +171,15 @@ void ReplaceNode(NodeData* src_new, NodeData* src_old, Node* trt) { } } -size_t CommonSubexpressionElimination(Graph* graph, std::vector& store_nodes, InputToNodeMap in2node) { +void CommonSubexpressionElimination(Graph* graph, std::vector store_nodes, InputToNodeMap in2node) { std::unordered_map> candidates_map; auto shape_dict = graph->GetAttrs>("infershape"); std::vector remove_nodes; std::vector remove_nodes_data; - for (auto* graph_node : store_nodes) { + + while (!store_nodes.empty()) { + auto* graph_node = store_nodes[0]; + store_nodes.erase(store_nodes.begin()); auto node = graph_node->safe_as(); if (node) { auto& node_type = node->op()->name; @@ -192,19 +195,23 @@ size_t CommonSubexpressionElimination(Graph* graph, std::vector& sto auto* candidate_sink_node = candidate_node->outlinks_in_order()[k]->sink()->safe_as(); CHECK(sink_node); CHECK(candidate_sink_node); - if (std::find(graph->outputs.begin(), graph->outputs.end(), sink_node) != graph->outputs.end()) { - // If sink node in outputs, the node's source_node will be replaced by candidate_sink_node's source_node. - node->UnLinkSingleTo(sink_node); - sink_node->source_node = candidate_sink_node->source_node; - candidate_sink_node->source_node->LinkTo(sink_node); - } else { - // If sink node not in outputs, the node will be removed. - remove_nodes_data.push_back(sink_node); + remove_nodes_data.push_back(sink_node); + auto iter_sink_node = std::find(graph->outputs.begin(), graph->outputs.end(), sink_node); + if (iter_sink_node != graph->outputs.end()) { + // If sink node in outputs, the node cannot be removed. + NodeData new_sink_node( + candidate_node, sink_node->output_index, sink_node->version, sink_node->id(), sink_node->is_const()); + graph->outputs.erase(iter_sink_node); + graph->outputs.push_back(&new_sink_node); } // Replace sink_node with candidate_sink_node in nodes linked by sink_node. auto out_nodes = in2node[sink_node->id()]; for (auto out_node : out_nodes) { ReplaceNode(candidate_sink_node, sink_node, out_node); + // The changed out node will be detected again. + if (std::find(store_nodes.begin(), store_nodes.end(), out_node) == store_nodes.end()) { + store_nodes.insert(store_nodes.begin(), out_node); + } out_nodes.erase(node); out_nodes.insert(candidate_node); } @@ -221,7 +228,6 @@ size_t CommonSubexpressionElimination(Graph* graph, std::vector& sto // Node should be deleted before node data. RemoveNodes(graph, remove_nodes); RemoveNodes(graph, remove_nodes_data); - return remove_nodes.size(); } void CommonSubexpressionEliminationPass(Graph* graph) { @@ -240,11 +246,7 @@ void CommonSubexpressionEliminationPass(Graph* graph) { } } - size_t remove_num = CommonSubexpressionElimination(graph, store_nodes, in2node); - while (remove_num) { - store_nodes = std::get<0>(graph->topological_order()); - remove_num = CommonSubexpressionElimination(graph, store_nodes, in2node); - } + CommonSubexpressionElimination(graph, store_nodes, in2node); VLOG(3) << "CommonSubexpressionEliminationPass Finish...!"; } } // namespace pass From c10913470a5b05e6be0c19d29e741d491f2f2265 Mon Sep 17 00:00:00 2001 From: zrr1999 <2742392377@qq.com> Date: Mon, 16 Jan 2023 05:32:13 +0000 Subject: [PATCH 22/27] replace the reshape_op with the operator set --- .../pass/common_subexpression_elimination.cc | 23 +++++++++++-------- .../common_subexpression_elimination_test.cc | 2 +- 2 files changed, 15 insertions(+), 10 deletions(-) diff --git a/cinn/hlir/pass/common_subexpression_elimination.cc b/cinn/hlir/pass/common_subexpression_elimination.cc index d779c1900d..5608847bfe 100644 --- a/cinn/hlir/pass/common_subexpression_elimination.cc +++ b/cinn/hlir/pass/common_subexpression_elimination.cc @@ -52,6 +52,12 @@ std::unordered_set unordered_ops = { "bitwise_and", }; +// When all the inputs are the same, those ops just ensure that all the outputs shape is the same. +std::unordered_set reshape_ops = { + "reshape", + "concat", +}; + bool IsSameSubexpression(Node* op1, Node* op2, shape_dict_t& shape_dict) { // Get the input edges for op1 and op2 in order. auto op1_in_edges = op1->inlinks_in_order(true); @@ -101,18 +107,17 @@ bool IsSameSubexpression(Node* op1, Node* op2, shape_dict_t& shape_dict) { } } - if (op1->op()->name == "reshape") { + // Check if the number of dimensions. + auto* op1_sink_node = op1->outlinks_in_order(true)[0]->sink()->safe_as(); + auto* op2_sink_node = op2->outlinks_in_order(true)[0]->sink()->safe_as(); + if (shape_dict[op1_sink_node->id()].size() != shape_dict[op2_sink_node->id()].size()) { + return false; + } + if (reshape_ops.count(op1->op()->name)) { // For reshape ops, check if the reshaped shape is the same. - auto* op1_sink_node = op1->outlinks_in_order(true)[0]->sink()->safe_as(); - auto* op2_sink_node = op2->outlinks_in_order(true)[0]->sink()->safe_as(); return shape_dict[op1_sink_node->id()] == shape_dict[op2_sink_node->id()]; } else { - // For non-reshape ops, check if the number of dimensions and attributes. - auto* op1_sink_node = op1->outlinks_in_order(true)[0]->sink()->safe_as(); - auto* op2_sink_node = op2->outlinks_in_order(true)[0]->sink()->safe_as(); - if (shape_dict[op1_sink_node->id()].size() != shape_dict[op2_sink_node->id()].size()) { - return false; - } + // For non-reshape ops, attributes. return std::all_of(op1->attrs.attr_store.begin(), op1->attrs.attr_store.end(), [&](auto attr) { if (!op2->attrs.attr_store.count(attr.first) || op2->attrs.attr_store[attr.first] != attr.second) { if (attr.first == "axis" || attr.first == "dim") { diff --git a/cinn/hlir/pass/common_subexpression_elimination_test.cc b/cinn/hlir/pass/common_subexpression_elimination_test.cc index d6391bd0d4..a65cc30594 100644 --- a/cinn/hlir/pass/common_subexpression_elimination_test.cc +++ b/cinn/hlir/pass/common_subexpression_elimination_test.cc @@ -98,7 +98,7 @@ TEST(common_subexpression_elimination, common_subexpression_elimination_case2) { auto add_2 = program.add(A, A); auto reshape_1 = program.reshape(B, {4, -1}); auto reshape_2 = program.reshape(B, {4, 8}); - auto add = program.add(reshape_1, reshape_2); + auto add = program.concat({reshape_1, reshape_2}); Target target = common::DefaultTarget(); program.SetInputs({A, B}); From 2c5122a042d8011127f7bb16bb37761e21691e53 Mon Sep 17 00:00:00 2001 From: zrr1999 <2742392377@qq.com> Date: Mon, 16 Jan 2023 05:44:45 +0000 Subject: [PATCH 23/27] use GetNodeData and add tests --- cinn/hlir/pass/common_subexpression_elimination.cc | 4 ++-- cinn/hlir/pass/common_subexpression_elimination_test.cc | 9 ++++++--- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/cinn/hlir/pass/common_subexpression_elimination.cc b/cinn/hlir/pass/common_subexpression_elimination.cc index 5608847bfe..2a5c1ede8e 100644 --- a/cinn/hlir/pass/common_subexpression_elimination.cc +++ b/cinn/hlir/pass/common_subexpression_elimination.cc @@ -108,8 +108,8 @@ bool IsSameSubexpression(Node* op1, Node* op2, shape_dict_t& shape_dict) { } // Check if the number of dimensions. - auto* op1_sink_node = op1->outlinks_in_order(true)[0]->sink()->safe_as(); - auto* op2_sink_node = op2->outlinks_in_order(true)[0]->sink()->safe_as(); + auto* op1_sink_node = GetNodeData(op1); + auto* op2_sink_node = GetNodeData(op2); if (shape_dict[op1_sink_node->id()].size() != shape_dict[op2_sink_node->id()].size()) { return false; } diff --git a/cinn/hlir/pass/common_subexpression_elimination_test.cc b/cinn/hlir/pass/common_subexpression_elimination_test.cc index a65cc30594..8e9ba6f56e 100644 --- a/cinn/hlir/pass/common_subexpression_elimination_test.cc +++ b/cinn/hlir/pass/common_subexpression_elimination_test.cc @@ -57,6 +57,7 @@ TEST(common_subexpression_elimination, common_subexpression_elimination_case1) { auto add = program.add(add_1, add_2); auto t_1 = program.transpose(add, {1, 0}); auto t_2 = program.transpose(add, {1, 0}); + auto t_3 = program.transpose(add, {0, 1}); auto max = program.reduce_max(add, {0}, true); Target target = common::DefaultTarget(); @@ -75,7 +76,7 @@ TEST(common_subexpression_elimination, common_subexpression_elimination_case1) { auto& prerun_instrs = runtime_program->GetPreRunInstructions(); auto& run_instrs = runtime_program->GetRunInstructions(); ASSERT_EQ(prerun_instrs.size(), 0); - ASSERT_EQ(run_instrs.size(), 4); + ASSERT_EQ(run_instrs.size(), 5); scope->Var("A"); scope->Var("B"); @@ -98,7 +99,9 @@ TEST(common_subexpression_elimination, common_subexpression_elimination_case2) { auto add_2 = program.add(A, A); auto reshape_1 = program.reshape(B, {4, -1}); auto reshape_2 = program.reshape(B, {4, 8}); - auto add = program.concat({reshape_1, reshape_2}); + auto concat_1 = program.concat({reshape_1, reshape_2}); + auto concat_2 = program.concat({reshape_1, reshape_2}); + auto concat_3 = program.concat({reshape_1, reshape_2}, 1); Target target = common::DefaultTarget(); program.SetInputs({A, B}); @@ -116,7 +119,7 @@ TEST(common_subexpression_elimination, common_subexpression_elimination_case2) { auto& prerun_instrs = runtime_program->GetPreRunInstructions(); auto& run_instrs = runtime_program->GetRunInstructions(); ASSERT_EQ(prerun_instrs.size(), 0); - ASSERT_EQ(run_instrs.size(), 3); + ASSERT_EQ(run_instrs.size(), 4); scope->Var("A"); scope->Var("B"); From fc1aa47b94a6beb5140b1df89f50b8e8c7b53c82 Mon Sep 17 00:00:00 2001 From: zrr1999 <2742392377@qq.com> Date: Mon, 16 Jan 2023 07:39:30 +0000 Subject: [PATCH 24/27] replace the dim and axis with the special_attrs --- .../pass/common_subexpression_elimination.cc | 38 +++++++++++-------- 1 file changed, 23 insertions(+), 15 deletions(-) diff --git a/cinn/hlir/pass/common_subexpression_elimination.cc b/cinn/hlir/pass/common_subexpression_elimination.cc index 2a5c1ede8e..c60c1882ed 100644 --- a/cinn/hlir/pass/common_subexpression_elimination.cc +++ b/cinn/hlir/pass/common_subexpression_elimination.cc @@ -58,6 +58,11 @@ std::unordered_set reshape_ops = { "concat", }; +// Those special attrs maybe different but equivalent. +std::unordered_set special_attrs = { + "dim" + "axis"}; + bool IsSameSubexpression(Node* op1, Node* op2, shape_dict_t& shape_dict) { // Get the input edges for op1 and op2 in order. auto op1_in_edges = op1->inlinks_in_order(true); @@ -107,7 +112,7 @@ bool IsSameSubexpression(Node* op1, Node* op2, shape_dict_t& shape_dict) { } } - // Check if the number of dimensions. + // Check if the number of dimensions is the same. auto* op1_sink_node = GetNodeData(op1); auto* op2_sink_node = GetNodeData(op2); if (shape_dict[op1_sink_node->id()].size() != shape_dict[op2_sink_node->id()].size()) { @@ -117,23 +122,26 @@ bool IsSameSubexpression(Node* op1, Node* op2, shape_dict_t& shape_dict) { // For reshape ops, check if the reshaped shape is the same. return shape_dict[op1_sink_node->id()] == shape_dict[op2_sink_node->id()]; } else { - // For non-reshape ops, attributes. + // For non-reshape ops, check if the attributes is the same. return std::all_of(op1->attrs.attr_store.begin(), op1->attrs.attr_store.end(), [&](auto attr) { - if (!op2->attrs.attr_store.count(attr.first) || op2->attrs.attr_store[attr.first] != attr.second) { - if (attr.first == "axis" || attr.first == "dim") { - auto op1_axis = absl::get(attr.second); - auto op2_axis = absl::get(op2->attrs.attr_store[attr.first]); - if (op1_axis < 0) { - op1_axis += shape_dict[op1_sink_node->id()].size(); - } - if (op2_axis < 0) { - op2_axis += shape_dict[op1_sink_node->id()].size(); - } - return op2_axis == op1_axis; - } + if (!op2->attrs.attr_store.count(attr.first)) { return false; } - return true; + auto& attr1 = attr.second; + auto& attr2 = op2->attrs.attr_store[attr.first]; + auto ndim = shape_dict[op1_sink_node->id()].size(); + if (special_attrs.count(attr.first)) { + auto op1_axis = absl::get(attr1); + auto op2_axis = absl::get(attr2); + if (op1_axis < 0) { + op1_axis += ndim; + } + if (op2_axis < 0) { + op2_axis += ndim; + } + return op2_axis == op1_axis; + } + return attr1 == attr2; }); } } From c4bf7e823baa33f5fe03a7d94b8468f060757784 Mon Sep 17 00:00:00 2001 From: zrr1999 <2742392377@qq.com> Date: Mon, 16 Jan 2023 07:45:29 +0000 Subject: [PATCH 25/27] add interface --- .../pass/common_subexpression_elimination.cc | 23 ++++++++++--------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/cinn/hlir/pass/common_subexpression_elimination.cc b/cinn/hlir/pass/common_subexpression_elimination.cc index c60c1882ed..bb6711b3d8 100644 --- a/cinn/hlir/pass/common_subexpression_elimination.cc +++ b/cinn/hlir/pass/common_subexpression_elimination.cc @@ -59,9 +59,7 @@ std::unordered_set reshape_ops = { }; // Those special attrs maybe different but equivalent. -std::unordered_set special_attrs = { - "dim" - "axis"}; +std::unordered_map special_attrs = {{"dim", 1}, {"axis", 1}}; bool IsSameSubexpression(Node* op1, Node* op2, shape_dict_t& shape_dict) { // Get the input edges for op1 and op2 in order. @@ -131,15 +129,18 @@ bool IsSameSubexpression(Node* op1, Node* op2, shape_dict_t& shape_dict) { auto& attr2 = op2->attrs.attr_store[attr.first]; auto ndim = shape_dict[op1_sink_node->id()].size(); if (special_attrs.count(attr.first)) { - auto op1_axis = absl::get(attr1); - auto op2_axis = absl::get(attr2); - if (op1_axis < 0) { - op1_axis += ndim; - } - if (op2_axis < 0) { - op2_axis += ndim; + switch (special_attrs[attr.first]) { + case 1: + auto op1_axis = absl::get(attr1); + auto op2_axis = absl::get(attr2); + if (op1_axis < 0) { + op1_axis += ndim; + } + if (op2_axis < 0) { + op2_axis += ndim; + } + return op2_axis == op1_axis; } - return op2_axis == op1_axis; } return attr1 == attr2; }); From 3f8514d06adaf72da39031deba66dc33374185f4 Mon Sep 17 00:00:00 2001 From: zrr1999 <2742392377@qq.com> Date: Tue, 17 Jan 2023 04:16:31 +0000 Subject: [PATCH 26/27] fix bug --- .../pass/common_subexpression_elimination.cc | 34 +++++++++++++++++-- .../common_subexpression_elimination_test.cc | 24 ++++++------- 2 files changed, 42 insertions(+), 16 deletions(-) diff --git a/cinn/hlir/pass/common_subexpression_elimination.cc b/cinn/hlir/pass/common_subexpression_elimination.cc index bb6711b3d8..9c6001e0fc 100644 --- a/cinn/hlir/pass/common_subexpression_elimination.cc +++ b/cinn/hlir/pass/common_subexpression_elimination.cc @@ -59,7 +59,11 @@ std::unordered_set reshape_ops = { }; // Those special attrs maybe different but equivalent. -std::unordered_map special_attrs = {{"dim", 1}, {"axis", 1}}; +std::unordered_map special_attrs = { + // {"axis", 1}, // due to the issue in some ops + // {"dim", 1}, // due to the issue in some ops + {"axes", 2}, + {"perm", 2}}; bool IsSameSubexpression(Node* op1, Node* op2, shape_dict_t& shape_dict) { // Get the input edges for op1 and op2 in order. @@ -127,10 +131,10 @@ bool IsSameSubexpression(Node* op1, Node* op2, shape_dict_t& shape_dict) { } auto& attr1 = attr.second; auto& attr2 = op2->attrs.attr_store[attr.first]; - auto ndim = shape_dict[op1_sink_node->id()].size(); + auto ndim = static_cast(shape_dict[op1_sink_node->id()].size()); if (special_attrs.count(attr.first)) { switch (special_attrs[attr.first]) { - case 1: + case 1: { auto op1_axis = absl::get(attr1); auto op2_axis = absl::get(attr2); if (op1_axis < 0) { @@ -140,6 +144,30 @@ bool IsSameSubexpression(Node* op1, Node* op2, shape_dict_t& shape_dict) { op2_axis += ndim; } return op2_axis == op1_axis; + } + case 2: { + auto& op1_axes = absl::get>(attr1); + auto& op2_axes = absl::get>(attr2); + auto op1_size = op1_axes.size(); + auto op2_size = op2_axes.size(); + if (op1_size != op2_size) { + return false; + } + for (int i = 0; i < op1_axes.size(); ++i) { + int op1_axis = op1_axes[i]; + int op2_axis = op2_axes[i]; + if (op1_axis < 0) { + op1_axis += ndim; + } + if (op2_axis < 0) { + op2_axis += ndim; + } + if (op2_axis != op1_axis) { + return false; + } + } + return true; + } } } return attr1 == attr2; diff --git a/cinn/hlir/pass/common_subexpression_elimination_test.cc b/cinn/hlir/pass/common_subexpression_elimination_test.cc index 8e9ba6f56e..f27cdb483f 100644 --- a/cinn/hlir/pass/common_subexpression_elimination_test.cc +++ b/cinn/hlir/pass/common_subexpression_elimination_test.cc @@ -30,13 +30,10 @@ #include -#include "cinn/cinn.h" #include "cinn/frontend/syntax.h" #include "cinn/hlir/framework/graph.h" #include "cinn/hlir/framework/graph_compiler.h" #include "cinn/hlir/framework/pass.h" -#include "cinn/hlir/op/use_ops.h" -#include "cinn/hlir/pass/use_pass.h" #include "cinn/utils/data_util.h" DEFINE_string(model_dir, "", ""); @@ -48,17 +45,18 @@ using hlir::framework::Scope; using utils::Join; TEST(common_subexpression_elimination, common_subexpression_elimination_case1) { - Placeholder A(Float(32), {32, 16}, "A"); - Placeholder B(Float(32), {32, 1}, "B", true); + Placeholder A(Float(32), {32, 16, 1}, "A"); + Placeholder B(Float(32), {32, 1, 1}, "B", true); Program program; - auto add_1 = program.add(A, B); - auto add_2 = program.add(B, A); - auto add = program.add(add_1, add_2); - auto t_1 = program.transpose(add, {1, 0}); - auto t_2 = program.transpose(add, {1, 0}); - auto t_3 = program.transpose(add, {0, 1}); - auto max = program.reduce_max(add, {0}, true); + auto add_1 = program.add(A, B); + auto add_2 = program.add(B, A); + auto add = program.add(add_1, add_2); + auto t_1 = program.transpose(add, {2, 1, 0}); + auto t_2 = program.transpose(add, {2, 1, 0}); + auto t_3 = program.transpose(add, {2, 0, 1}); + auto concat = program.concat({t_1, t_2, t_3}); + auto max = program.reduce_max(concat, {0}, true); Target target = common::DefaultTarget(); program.SetInputs({A, B}); @@ -76,7 +74,7 @@ TEST(common_subexpression_elimination, common_subexpression_elimination_case1) { auto& prerun_instrs = runtime_program->GetPreRunInstructions(); auto& run_instrs = runtime_program->GetRunInstructions(); ASSERT_EQ(prerun_instrs.size(), 0); - ASSERT_EQ(run_instrs.size(), 5); + ASSERT_EQ(run_instrs.size(), 6); scope->Var("A"); scope->Var("B"); From e24d0b1c4fbfca4dfbbd03e7cdf7e7b0a7661130 Mon Sep 17 00:00:00 2001 From: zrr1999 <2742392377@qq.com> Date: Mon, 30 Jan 2023 05:10:56 +0000 Subject: [PATCH 27/27] fix test case random error --- cinn/hlir/pass/common_subexpression_elimination.cc | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/cinn/hlir/pass/common_subexpression_elimination.cc b/cinn/hlir/pass/common_subexpression_elimination.cc index 9c6001e0fc..c7b4ea6705 100644 --- a/cinn/hlir/pass/common_subexpression_elimination.cc +++ b/cinn/hlir/pass/common_subexpression_elimination.cc @@ -222,12 +222,15 @@ void CommonSubexpressionElimination(Graph* graph, std::vector store_ while (!store_nodes.empty()) { auto* graph_node = store_nodes[0]; store_nodes.erase(store_nodes.begin()); + LOG(INFO) << "size of store_nodes is " << store_nodes.size(); auto node = graph_node->safe_as(); if (node) { auto& node_type = node->op()->name; auto& candidates = candidates_map[node_type]; bool found = false; for (auto* candidate_node : candidates) { + // If node is same with candidate_node, continue the next. + if (node->id() == candidate_node->id()) continue; // If node is different from candidate_node, continue the next. if (!IsSameSubexpression(node, candidate_node, shape_dict)) continue; found = true; @@ -254,8 +257,6 @@ void CommonSubexpressionElimination(Graph* graph, std::vector store_ if (std::find(store_nodes.begin(), store_nodes.end(), out_node) == store_nodes.end()) { store_nodes.insert(store_nodes.begin(), out_node); } - out_nodes.erase(node); - out_nodes.insert(candidate_node); } } remove_nodes.push_back(node);