diff --git a/cinn/frontend/optimize.cc b/cinn/frontend/optimize.cc
index ca5de36082..b1ad7b4050 100644
--- a/cinn/frontend/optimize.cc
+++ b/cinn/frontend/optimize.cc
@@ -30,6 +30,7 @@ DECLARE_bool(cinn_use_fill_constant_folding);
 DECLARE_bool(cinn_use_op_fusion);
 DECLARE_bool(cinn_use_cublas_gemm);
+DECLARE_bool(cinn_use_common_subexpression_elimination);
 DECLARE_bool(cinn_check_fusion_accuracy_pass);
 DECLARE_bool(cinn_use_custom_call);
@@ -63,22 +64,21 @@ OptimizeOptions DefaultTrainingOptimizeOptions() {
   options.program_passes.emplace_back("DeadCodeEliminate");
 
   options.graph_passes = {"ConstantFolding"};
-#ifdef CINN_WITH_CUDA
-  if (FLAGS_cinn_use_cublas_gemm) {
-    // options.graph_passes.push_back("DenseMergePass");
-    options.graph_passes.push_back("TransToCustomCallPass");
-  }
-#endif
+  // options.graph_passes.push_back("DenseMergePass");
   if (FLAGS_cinn_use_custom_call) {
     options.graph_passes.emplace_back("TransToCustomCallPass");
   }
 
   if (FLAGS_cinn_use_op_fusion) {
-    options.graph_passes.push_back("OpFusionPass");
-    options.graph_passes.push_back("FusionMergePass");
+    options.graph_passes.emplace_back("OpFusionPass");
+    options.graph_passes.emplace_back("FusionMergePass");
   } else {
-    options.graph_passes.push_back("BuildNonFusedGroupsPass");
+    options.graph_passes.emplace_back("BuildNonFusedGroupsPass");
+  }
+
+  if (FLAGS_cinn_use_common_subexpression_elimination) {
+    options.graph_passes.emplace_back("CommonSubexpressionEliminationPass");
   }
 
   // WARNING: the pass must be the last pass !!!
@@ -87,7 +87,6 @@ OptimizeOptions DefaultTrainingOptimizeOptions() {
     // error and exited.
     options.graph_passes.emplace_back("CheckFusionAccuracyPass");
   }
-
   return options;
 }
diff --git a/cinn/hlir/pass/CMakeLists.txt b/cinn/hlir/pass/CMakeLists.txt
old mode 100755
new mode 100644
index 47b363cef1..41aa667be7
--- a/cinn/hlir/pass/CMakeLists.txt
+++ b/cinn/hlir/pass/CMakeLists.txt
@@ -11,6 +11,7 @@ gather_srcs(cinnapi_src SRCS
   dot_merger.cc
   check_fusion_accuracy_pass.cc
   custom_call_pass.cc
+  common_subexpression_elimination.cc
   constant_folding_pass.cc
   dce_pass.cc
   dense_merge_pass.cc
@@ -33,4 +34,5 @@ if (NOT WITH_CUDA)
 endif()
 cc_test(test_dot_merger SRCS test_dot_merger.cc DEPS cinncore)
 cc_test(test_dce_pass SRCS dce_pass_test.cc DEPS cinncore)
+cc_test(test_common_subexpression_elimination SRCS common_subexpression_elimination_test.cc DEPS cinncore)
 cc_test(test_constant_folding_pass SRCS constant_folding_pass_test.cc DEPS cinncore)
diff --git a/cinn/hlir/pass/common_subexpression_elimination.cc b/cinn/hlir/pass/common_subexpression_elimination.cc
new file mode 100644
index 0000000000..c7b4ea6705
--- /dev/null
+++ b/cinn/hlir/pass/common_subexpression_elimination.cc
@@ -0,0 +1,306 @@
+// Copyright (c) 2022 CINN Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
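+
+// This pass performs common subexpression elimination (CSE): operators of the
+// same type are compared pairwise, and an operator whose inputs and
+// attributes are equivalent to an earlier one is removed, with its outputs
+// rewired to the surviving operator.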
+
+#include <unordered_map>
+#include <unordered_set>
+
+#include "cinn/hlir/framework/graph.h"
+#include "cinn/hlir/framework/node.h"
+#include "cinn/hlir/framework/op.h"
+#include "cinn/hlir/framework/op_lowering.h"
+#include "cinn/hlir/framework/pass.h"
+#include "cinn/hlir/pass/use_pass.h"
+#include "cinn/utils/string.h"
+
+namespace cinn {
+namespace hlir {
+namespace pass {
+
+using framework::Graph;
+using framework::Node;
+using framework::NodeData;
+
+using common::GraphEdge;
+using common::GraphNode;
+
+using InputToNodeMap = std::unordered_map<std::string, std::unordered_set<Node*>>;
+using shape_dict_t   = absl::flat_hash_map<std::string, framework::shape_t>;
+
+// Commutative ops whose operands may match in any order.
+std::unordered_set<std::string> unordered_ops = {
+    "elementwise_add",
+    "elementwise_mul",
+    "max",
+    "min",
+    "logical_and",
+    "logical_or",
+    "logical_xor",
+    "equal",
+    "not_equal",
+    "bitwise_or",
+    "bitwise_xor",
+    "bitwise_and",
+};
+
+// When all inputs are the same, these ops are equivalent as long as their output shapes match.
+std::unordered_set<std::string> reshape_ops = {
+    "reshape",
+    "concat",
+};
+
+// These attrs may be spelled differently yet be equivalent, e.g. a negative
+// axis versus its normalized non-negative form.
+std::unordered_map<std::string, int> special_attrs = {
+    // {"axis", 1},  // due to the issue in some ops
+    // {"dim", 1},   // due to the issue in some ops
+    {"axes", 2},
+    {"perm", 2}};
+
+bool IsSameSubexpression(Node* op1, Node* op2, shape_dict_t& shape_dict) {
+  // Get the input edges of op1 and op2 in order.
+  auto op1_in_edges = op1->inlinks_in_order(true);
+  auto op2_in_edges = op2->inlinks_in_order(true);
+  // If the numbers of input edges differ, the subexpressions differ.
+  auto op1_inputs_size = op1_in_edges.size();
+  auto op2_inputs_size = op2_in_edges.size();
+  if (op1_inputs_size != op2_inputs_size) {
+    return false;
+  }
+  // If the numbers of attributes differ, the subexpressions differ.
+  auto op1_attrs_size = op1->attrs.attr_store.size();
+  auto op2_attrs_size = op2->attrs.attr_store.size();
+  if (op1_attrs_size != op2_attrs_size) {
+    return false;
+  }
+  // Check whether the input nodes match.
+  if (unordered_ops.count(op1->op()->name)) {
+    // For commutative ops, every input node of op1 must match some input node of op2.
+    for (auto& op1_edge : op1_in_edges) {
+      auto* op1_source_node = op1_edge->source()->safe_as<NodeData>();
+      CHECK(op1_source_node);
+      bool op1_equal_op2 = std::any_of(op2_in_edges.begin(), op2_in_edges.end(), [&](common::Shared<GraphEdge>& edge) {
+        auto* op2_source_node = edge->source()->safe_as<NodeData>();
+        CHECK(op2_source_node);
+        return op1_source_node->id() == op2_source_node->id();
+      });
+      if (!op1_equal_op2) {
+        return false;
+      }
+    }
+  } else {
+    // For ordered ops, the input nodes must match one-to-one.
+    for (size_t i = 0; i < op1_inputs_size; ++i) {
+      auto* op1_source_node = op1_in_edges[i]->source()->safe_as<NodeData>();
+      auto* op2_source_node = op2_in_edges[i]->source()->safe_as<NodeData>();
+      CHECK(op1_source_node);
+      CHECK(op2_source_node);
+      if (op1_source_node->id() != op2_source_node->id()) {
+        return false;
+      }
+    }
+  }
+
+  // Check whether the numbers of output dimensions match.
+  auto* op1_sink_node = GetNodeData(op1);
+  auto* op2_sink_node = GetNodeData(op2);
+  if (shape_dict[op1_sink_node->id()].size() != shape_dict[op2_sink_node->id()].size()) {
+    return false;
+  }
+  if (reshape_ops.count(op1->op()->name)) {
+    // For reshape-like ops, matching output shapes imply the same result.
+    return shape_dict[op1_sink_node->id()] == shape_dict[op2_sink_node->id()];
+  } else {
+    // For other ops, check that the attributes are the same.
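+    // Compare attributes key by key. For attrs listed in special_attrs
+    // (kind 1: a single axis, kind 2: a list of axes), negative values are
+    // normalized by adding the output rank before comparison, so axis -1 and
+    // axis ndim-1 compare equal.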
+    return std::all_of(op1->attrs.attr_store.begin(), op1->attrs.attr_store.end(), [&](auto& attr) {
+      if (!op2->attrs.attr_store.count(attr.first)) {
+        return false;
+      }
+      auto& attr1 = attr.second;
+      auto& attr2 = op2->attrs.attr_store[attr.first];
+      auto ndim   = static_cast<int>(shape_dict[op1_sink_node->id()].size());
+      if (special_attrs.count(attr.first)) {
+        switch (special_attrs[attr.first]) {
+          case 1: {
+            auto op1_axis = absl::get<int>(attr1);
+            auto op2_axis = absl::get<int>(attr2);
+            if (op1_axis < 0) {
+              op1_axis += ndim;
+            }
+            if (op2_axis < 0) {
+              op2_axis += ndim;
+            }
+            return op2_axis == op1_axis;
+          }
+          case 2: {
+            auto& op1_axes = absl::get<std::vector<int>>(attr1);
+            auto& op2_axes = absl::get<std::vector<int>>(attr2);
+            if (op1_axes.size() != op2_axes.size()) {
+              return false;
+            }
+            for (size_t i = 0; i < op1_axes.size(); ++i) {
+              int op1_axis = op1_axes[i];
+              int op2_axis = op2_axes[i];
+              if (op1_axis < 0) {
+                op1_axis += ndim;
+              }
+              if (op2_axis < 0) {
+                op2_axis += ndim;
+              }
+              if (op2_axis != op1_axis) {
+                return false;
+              }
+            }
+            return true;
+          }
+        }
+      }
+      return attr1 == attr2;
+    });
+  }
+}
+
+void RemoveNodes(framework::Graph* graph, std::vector<Node*>& nodes) {
+  for (auto* node : nodes) {
+    auto in_edges = node->inlinks();
+    for (auto& edge : in_edges) {
+      auto* in_node = edge->source();
+      in_node->UnLinkSingleTo(node);
+    }
+    auto out_edges = node->outlinks();
+    for (auto& edge : out_edges) {
+      auto* out_node = edge->sink();
+      node->UnLinkSingleTo(out_node);
+    }
+    graph->DropNode(node);
+  }
+}
+
+void RemoveNodes(framework::Graph* graph, std::vector<NodeData*>& nodes_data) {
+  for (auto* data : nodes_data) {
+    // Graph outputs must stay alive; skip them instead of dropping.
+    if (std::find(graph->outputs.begin(), graph->outputs.end(), data) != graph->outputs.end()) {
+      continue;
+    }
+    graph->DropNode(data);
+  }
+}
+
+void ReplaceNode(NodeData* src_new, NodeData* src_old, Node* trt) {
+  for (auto& in_edge : trt->inlinks_in_order(true)) {
+    auto* in_node = in_edge->source()->safe_as<NodeData>();
+    in_node->UnLinkSingleTo(trt);
+    if (in_node->id() == src_old->id()) {
+      src_new->LinkTo(trt);
+    } else {
+      in_node->LinkTo(trt);
+    }
+  }
+}
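+
+// Worklist-driven CSE: nodes are popped in topological order and bucketed by
+// op type. A node matching an earlier candidate has its outputs rewired to
+// the candidate's outputs, and the rewired consumers are pushed back onto the
+// worklist so that chains of duplicates collapse transitively.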
+void CommonSubexpressionElimination(Graph* graph, std::vector<GraphNode*> store_nodes, InputToNodeMap in2node) {
+  std::unordered_map<std::string, std::vector<Node*>> candidates_map;
+  auto shape_dict = graph->GetAttrs<shape_dict_t>("infershape");
+  std::vector<Node*> remove_nodes;
+  std::vector<NodeData*> remove_nodes_data;
+
+  while (!store_nodes.empty()) {
+    auto* graph_node = store_nodes[0];
+    store_nodes.erase(store_nodes.begin());
+    VLOG(4) << "size of store_nodes is " << store_nodes.size();
+    auto node = graph_node->safe_as<Node>();
+    if (node) {
+      auto& node_type  = node->op()->name;
+      auto& candidates = candidates_map[node_type];
+      bool found       = false;
+      for (auto* candidate_node : candidates) {
+        // Skip the candidate if it is the node itself.
+        if (node->id() == candidate_node->id()) continue;
+        // Skip the candidate if it is not the same subexpression.
+        if (!IsSameSubexpression(node, candidate_node, shape_dict)) continue;
+        found                    = true;
+        auto node_out_edges      = node->outlinks_in_order(true);
+        auto candidate_out_edges = candidate_node->outlinks_in_order(true);
+        CHECK_EQ(node_out_edges.size(), candidate_out_edges.size());
+        for (size_t k = 0; k < node_out_edges.size(); ++k) {
+          auto* sink_node           = node_out_edges[k]->sink()->safe_as<NodeData>();
+          auto* candidate_sink_node = candidate_out_edges[k]->sink()->safe_as<NodeData>();
+          CHECK(sink_node);
+          CHECK(candidate_sink_node);
+          remove_nodes_data.push_back(sink_node);
+          auto iter_sink_node = std::find(graph->outputs.begin(), graph->outputs.end(), sink_node);
+          if (iter_sink_node != graph->outputs.end()) {
+            // If the sink node is a graph output, it cannot simply be removed:
+            // replace it with a heap-allocated copy fed by the candidate (a
+            // stack-allocated NodeData here would leave a dangling pointer in
+            // graph->outputs).
+            auto* new_sink_node = new NodeData(
+                candidate_node, sink_node->output_index, sink_node->version, sink_node->id(), sink_node->is_const());
+            graph->outputs.erase(iter_sink_node);
+            graph->outputs.push_back(new_sink_node);
+          }
+          // Replace sink_node with candidate_sink_node in all nodes consuming sink_node.
+          auto out_nodes = in2node[sink_node->id()];
+          for (auto out_node : out_nodes) {
+            ReplaceNode(candidate_sink_node, sink_node, out_node);
+            // Re-enqueue the rewired consumer so it is examined again.
+            if (std::find(store_nodes.begin(), store_nodes.end(), out_node) == store_nodes.end()) {
+              store_nodes.insert(store_nodes.begin(), out_node);
+            }
+          }
+        }
+        remove_nodes.push_back(node);
+        VLOG(4) << "remove node " << node->id();
+        break;
+      }
+      if (!found) {
+        candidates_map[node_type].push_back(node);
+      }
+    }
+  }
+  // Nodes must be deleted before their node data.
+  RemoveNodes(graph, remove_nodes);
+  RemoveNodes(graph, remove_nodes_data);
+}
+
+void CommonSubexpressionEliminationPass(Graph* graph) {
+  VLOG(3) << "CommonSubexpressionEliminationPass...!";
+  InputToNodeMap in2node;
+  auto store_nodes = std::get<0>(graph->topological_order());
+
+  // Build the map from each node data id to the op nodes consuming it.
+  for (auto& graph_node : store_nodes) {
+    auto node = graph_node->safe_as<Node>();
+    if (node) {
+      for (auto& in_edge : node->inlinks_in_order(true)) {
+        auto* source_node = in_edge->source()->safe_as<NodeData>();
+        in2node[source_node->id()].insert(node);
+      }
+    }
+  }
+
+  CommonSubexpressionElimination(graph, store_nodes, in2node);
+  VLOG(3) << "CommonSubexpressionEliminationPass Finish...!";
+}
+}  // namespace pass
+}  // namespace hlir
+}  // namespace cinn
+
+CINN_REGISTER_HELPER(CommonSubexpressionEliminationPass) {
+  CINN_REGISTER_PASS(CommonSubexpressionEliminationPass)
+      .describe("This pass removes duplicated sub-expressions.")
+      .set_change_structure(true)
+      .set_body(cinn::hlir::pass::CommonSubexpressionEliminationPass);
+
+  return true;
+}
diff --git a/cinn/hlir/pass/common_subexpression_elimination_test.cc b/cinn/hlir/pass/common_subexpression_elimination_test.cc
new file mode 100644
index 0000000000..f27cdb483f
--- /dev/null
+++ b/cinn/hlir/pass/common_subexpression_elimination_test.cc
@@ -0,0 +1,194 @@
+// Copyright (c) 2022 CINN Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
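+
+// Each test below builds a small program containing duplicated
+// subexpressions, applies CommonSubexpressionEliminationPass, and checks the
+// instruction count that remains after compilation.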
+
+#include <gtest/gtest.h>
+
+#include <memory>
+
+#include "cinn/frontend/syntax.h"
+#include "cinn/hlir/framework/graph.h"
+#include "cinn/hlir/framework/graph_compiler.h"
+#include "cinn/hlir/framework/pass.h"
+#include "cinn/utils/data_util.h"
+
+DEFINE_string(model_dir, "", "");
+
+namespace cinn {
+namespace frontend {
+
+using hlir::framework::Scope;
+using utils::Join;
+
+TEST(common_subexpression_elimination, common_subexpression_elimination_case1) {
+  Placeholder A(Float(32), {32, 16, 1}, "A");
+  Placeholder B(Float(32), {32, 1, 1}, "B", true);
+
+  Program program;
+  auto add_1  = program.add(A, B);
+  auto add_2  = program.add(B, A);
+  auto add    = program.add(add_1, add_2);
+  auto t_1    = program.transpose(add, {2, 1, 0});
+  auto t_2    = program.transpose(add, {2, 1, 0});
+  auto t_3    = program.transpose(add, {2, 0, 1});
+  auto concat = program.concat({t_1, t_2, t_3});
+  auto max    = program.reduce_max(concat, {0}, true);
+
+  Target target = common::DefaultTarget();
+  program.SetInputs({A, B});
+  program.Validate();
+  LOG(INFO) << "Program:\n" << program;
+  auto graph = std::make_shared<hlir::framework::Graph>(program, target);
+  LOG(INFO) << "graph:\n" << graph->Visualize();
+
+  hlir::framework::ApplyPass(graph.get(), "InferShape");
+  hlir::framework::ApplyPass(graph.get(), "CommonSubexpressionEliminationPass");
+  auto scope = BuildScope(target, graph);
+
+  hlir::framework::GraphCompiler gc(target, scope, graph);
+  auto runtime_program = gc.Build();
+  auto& prerun_instrs  = runtime_program->GetPreRunInstructions();
+  auto& run_instrs     = runtime_program->GetRunInstructions();
+  ASSERT_EQ(prerun_instrs.size(), 0);
+  // add(B, A) folds into add(A, B) by commutativity and t_2 folds into t_1
+  // (same perm): 8 ops compile to 6 instructions.
+  ASSERT_EQ(run_instrs.size(), 6);
+
+  scope->Var<hlir::framework::Tensor>("A");
+  scope->Var<hlir::framework::Tensor>("B");
+
+  auto A1 = scope->GetTensor("A");
+  auto B1 = scope->GetTensor("B");
+  SetRandData<float>(A1, target);
+  SetRandData<float>(B1, target);
+
+  LOG(INFO) << "graph:\n" << graph->Visualize();
+  runtime_program->Execute();
+}
+
+TEST(common_subexpression_elimination, common_subexpression_elimination_case2) {
+  Placeholder A(Float(32), {32, 16}, "A");
+  Placeholder B(Float(32), {32, 1}, "B", true);
+
+  Program program;
+  auto add_1     = program.add(A, A);
+  auto add_2     = program.add(A, A);
+  auto reshape_1 = program.reshape(B, {4, -1});
+  auto reshape_2 = program.reshape(B, {4, 8});
+  auto concat_1  = program.concat({reshape_1, reshape_2});
+  auto concat_2  = program.concat({reshape_1, reshape_2});
+  auto concat_3  = program.concat({reshape_1, reshape_2}, 1);
+
+  Target target = common::DefaultTarget();
+  program.SetInputs({A, B});
+  program.Validate();
+  LOG(INFO) << "Program:\n" << program;
+  auto graph = std::make_shared<hlir::framework::Graph>(program, target);
+  LOG(INFO) << "graph:\n" << graph->Visualize();
+
+  hlir::framework::ApplyPass(graph.get(), "InferShape");
+  hlir::framework::ApplyPass(graph.get(), "CommonSubexpressionEliminationPass");
+  auto scope = BuildScope(target, graph);
+
+  hlir::framework::GraphCompiler gc(target, scope, graph);
+  auto runtime_program = gc.Build();
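+  // reshape_1 ({4, -1}) and reshape_2 ({4, 8}) infer the same output shape,
+  // so one folds into the other, as do add_2/add_1 and concat_2/concat_1;
+  // concat_3 concatenates along a different axis and survives: 7 ops compile
+  // to 4 instructions.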
+  auto& prerun_instrs  = runtime_program->GetPreRunInstructions();
+  auto& run_instrs     = runtime_program->GetRunInstructions();
+  ASSERT_EQ(prerun_instrs.size(), 0);
+  ASSERT_EQ(run_instrs.size(), 4);
+
+  scope->Var<hlir::framework::Tensor>("A");
+  scope->Var<hlir::framework::Tensor>("B");
+
+  auto A1 = scope->GetTensor("A");
+  auto B1 = scope->GetTensor("B");
+  SetRandData<float>(A1, target);
+  SetRandData<float>(B1, target);
+
+  LOG(INFO) << "graph:\n" << graph->Visualize();
+  runtime_program->Execute();
+}
+
+#ifdef CINN_WITH_CUDA
+TEST(common_subexpression_elimination, common_subexpression_elimination_case3) {
+  Placeholder A(Float(32), {1, 3, 224, 224}, "A");
+  Placeholder B(Float(32), {1, 1, 224, 224}, "B", true);
+
+  absl::flat_hash_map<std::string, Program::attr_t> attrs;
+  attrs["stride"]        = std::vector<int>({2, 2});
+  attrs["dilation"]      = std::vector<int>({1, 1});
+  attrs["padding"]       = std::vector<int>({3, 3});
+  std::string src_layout = "NCHW";
+  attrs["data_format"]   = src_layout;
+
+  Program program;
+  auto add_1    = program.add(A, B);
+  auto weight_1 = program.fill_constant({64, 3, 7, 7}, 1.0f, "", false, "w1");
+  auto weight_2 = program.fill_constant({64, 3, 7, 7}, 1.0f, "", false, "w2");
+  auto bias     = program.fill_constant({1, 64, 112, 112}, 2.0f, "", false, "b1");
+  auto conv_1   = program.conv2d(add_1, weight_1, attrs);
+  auto add_2    = program.add(conv_1, bias);
+  auto relu_1   = program.relu(add_2);
+  auto conv_2   = program.conv2d(add_1, weight_2, attrs);
+  auto add_3    = program.add(conv_2, bias);
+  auto relu_2   = program.relu(add_3);
+  auto out1     = program.add(relu_1, add_2);
+  auto out2     = program.add(add_2, relu_2);
+
+  Target target = common::DefaultTarget();
+  program.SetInputs({A, B});
+  program.Validate();
+  LOG(INFO) << "Program:\n" << program;
+  std::unordered_set<std::string> fetch_list;
+  fetch_list.insert(out1->id);
+  fetch_list.insert(out2->id);
+  auto graph = std::make_shared<hlir::framework::Graph>(program, fetch_list, target);
+  LOG(INFO) << "graph:\n" << graph->Visualize();
+
+  hlir::framework::ApplyPass(graph.get(), "InferShape");
+  hlir::framework::ApplyPass(graph.get(), "CommonSubexpressionEliminationPass");
+  auto scope = BuildScope(target, graph);
+
+  hlir::framework::GraphCompiler gc(target, scope, graph);
+  auto runtime_program = gc.Build();
+  auto& prerun_instrs  = runtime_program->GetPreRunInstructions();
+  auto& run_instrs     = runtime_program->GetRunInstructions();
+  ASSERT_EQ(prerun_instrs.size(), 0);
+  // weight_2, conv_2, add_3 and relu_2 fold into weight_1, conv_1, add_2 and
+  // relu_1; out2 then matches out1 since add is commutative: 12 ops compile
+  // to 7 instructions.
+  ASSERT_EQ(run_instrs.size(), 7);
+
+  scope->Var<hlir::framework::Tensor>("A");
+  scope->Var<hlir::framework::Tensor>("B");
+
+  auto A1 = scope->GetTensor("A");
+  auto B1 = scope->GetTensor("B");
+  SetRandData<float>(A1, target);
+  SetRandData<float>(B1, target);
+
+  LOG(INFO) << "graph:\n" << graph->Visualize();
+  runtime_program->Execute();
+}
+#endif
+
+}  // namespace frontend
+}  // namespace cinn
diff --git a/cinn/hlir/pass/use_pass.h b/cinn/hlir/pass/use_pass.h
index 06a568fda7..cf4b333487 100644
--- a/cinn/hlir/pass/use_pass.h
+++ b/cinn/hlir/pass/use_pass.h
@@ -26,6 +26,8 @@ CINN_USE_REGISTER(DotMerger)
 CINN_USE_REGISTER(OpFusionPass)
 CINN_USE_REGISTER(FusionMergePass)
 CINN_USE_REGISTER(CheckFusionAccuracyPass)
+
+CINN_USE_REGISTER(CommonSubexpressionEliminationPass)
 CINN_USE_REGISTER(TransToCustomCallPass)
 CINN_USE_REGISTER(DenseMergePass)
 CINN_USE_REGISTER(ConstantFolding)
diff --git a/cinn/runtime/flags.cc b/cinn/runtime/flags.cc
index 50be0d47f4..9af538115c 100644
--- a/cinn/runtime/flags.cc
+++ b/cinn/runtime/flags.cc
@@ -41,6 +41,10 @@
 DEFINE_bool(cinn_use_op_fusion, BoolFromEnv("FLAGS_cinn_use_op_fusion", true), "Whether to use op fusion pass.");
 
 DEFINE_bool(cinn_use_cublas_gemm, BoolFromEnv("FLAGS_cinn_use_cublas_gemm", true), "Whether to use cublas gemm.");
 
+DEFINE_bool(cinn_use_common_subexpression_elimination,
+            BoolFromEnv("FLAGS_cinn_use_common_subexpression_elimination", false),
+            "Whether to use common subexpression elimination pass.");
+
 DEFINE_string(cinn_custom_call_deny_ops,
               StringFromEnv("FLAGS_cinn_custom_call_deny_ops", ""),
               "a blacklist of op are denied by MarkCustomCallOps pass, separated by ;");