Skip to content
This repository has been archived by the owner on Jan 24, 2024. It is now read-only.

Commit

Permalink
[Add Pass] CommonSubexpressionElimination (#1166)
Browse files Browse the repository at this point in the history
* add hlir CSE pass
  • Loading branch information
zrr1999 authored Feb 16, 2023
1 parent 77450cc commit a9630cc
Show file tree
Hide file tree
Showing 6 changed files with 517 additions and 10 deletions.
19 changes: 9 additions & 10 deletions cinn/frontend/optimize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
DECLARE_bool(cinn_use_fill_constant_folding);
DECLARE_bool(cinn_use_op_fusion);
DECLARE_bool(cinn_use_cublas_gemm);
DECLARE_bool(cinn_use_common_subexpression_elimination);
DECLARE_bool(cinn_check_fusion_accuracy_pass);
DECLARE_bool(cinn_use_custom_call);

Expand Down Expand Up @@ -63,22 +64,21 @@ OptimizeOptions DefaultTrainingOptimizeOptions() {
options.program_passes.emplace_back("DeadCodeEliminate");

options.graph_passes = {"ConstantFolding"};
#ifdef CINN_WITH_CUDA
if (FLAGS_cinn_use_cublas_gemm) {
// options.graph_passes.push_back("DenseMergePass");
options.graph_passes.push_back("TransToCustomCallPass");
}
#endif
// options.graph_passes.push_back("DenseMergePass");

if (FLAGS_cinn_use_custom_call) {
options.graph_passes.emplace_back("TransToCustomCallPass");
}

if (FLAGS_cinn_use_op_fusion) {
options.graph_passes.push_back("OpFusionPass");
options.graph_passes.push_back("FusionMergePass");
options.graph_passes.emplace_back("OpFusionPass");
options.graph_passes.emplace_back("FusionMergePass");
} else {
options.graph_passes.push_back("BuildNonFusedGroupsPass");
options.graph_passes.emplace_back("BuildNonFusedGroupsPass");
}

if (FLAGS_cinn_use_common_subexpression_elimination) {
options.graph_passes.emplace_back("CommonSubexpressionEliminationPass");
}

// WARNING: the pass must be the last pass !!!
Expand All @@ -87,7 +87,6 @@ OptimizeOptions DefaultTrainingOptimizeOptions() {
// error and exited.
options.graph_passes.emplace_back("CheckFusionAccuracyPass");
}

return options;
}

Expand Down
2 changes: 2 additions & 0 deletions cinn/hlir/pass/CMakeLists.txt
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ gather_srcs(cinnapi_src SRCS
dot_merger.cc
check_fusion_accuracy_pass.cc
custom_call_pass.cc
common_subexpression_elimination.cc
constant_folding_pass.cc
dce_pass.cc
dense_merge_pass.cc
Expand All @@ -33,4 +34,5 @@ if (NOT WITH_CUDA)
endif()
cc_test(test_dot_merger SRCS test_dot_merger.cc DEPS cinncore)
cc_test(test_dce_pass SRCS dce_pass_test.cc DEPS cinncore)
cc_test(test_common_subexpression_elimination SRCS common_subexpression_elimination_test.cc DEPS cinncore)
cc_test(test_constant_folding_pass SRCS constant_folding_pass_test.cc DEPS cinncore)
306 changes: 306 additions & 0 deletions cinn/hlir/pass/common_subexpression_elimination.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,306 @@
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <algorithm>
#include <set>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "cinn/hlir/framework/graph.h"
#include "cinn/hlir/framework/node.h"
#include "cinn/hlir/framework/op.h"
#include "cinn/hlir/framework/op_lowering.h"
#include "cinn/hlir/framework/pass.h"
#include "cinn/hlir/pass/use_pass.h"
#include "cinn/utils/string.h"

namespace cinn {
namespace hlir {
namespace pass {

using framework::Graph;
using framework::Node;
using framework::NodeData;

using common::GraphEdge;
using common::GraphNode;

// Maps a node-data id to the set of op nodes that consume it.
using InputToNodeMap = std::unordered_map<std::string, std::unordered_set<Node*>>;
// Maps a node-data id to its inferred shape ("infershape" graph attribute).
using shape_dict_t = absl::flat_hash_map<std::string, framework::shape_t>;

// Ops whose inputs are commutative: input order is ignored when comparing
// two candidate subexpressions in IsSameSubexpression.
std::unordered_set<std::string> unordered_ops = {
    "elementwise_add",
    "elementwise_mul",
    "max",
    "min",
    "logical_and",
    "logical_or",
    "logical_xor",
    "equal",
    "not_equal",
    "bitwise_or",
    "bitwise_xor",
    "bitwise_and",
};

// When all the inputs are the same, these ops just ensure that all the output
// shapes are the same, so equal output shapes imply equal subexpressions.
std::unordered_set<std::string> reshape_ops = {
    "reshape",
    "concat",
};

// These special attrs may differ yet still be equivalent. The mapped int
// selects the comparison strategy in IsSameSubexpression:
//   1 = single axis (negative values normalized by output rank),
//   2 = list of axes (each entry normalized before element-wise compare).
std::unordered_map<std::string, int> special_attrs = {
    // {"axis", 1},  // due to the issue in some ops
    // {"dim", 1},  // due to the issue in some ops
    {"axes", 2},
    {"perm", 2}};

// Returns true when `op1` and `op2` compute the same subexpression: matching
// inputs (order-insensitive for commutative ops), the same output rank, and
// equivalent attributes (negative axis values are normalized before compare).
// The caller guarantees both nodes have the same op type.
// `shape_dict` is taken by non-const reference because operator[] may
// default-insert a missing id.
bool IsSameSubexpression(Node* op1, Node* op2, shape_dict_t& shape_dict) {
  // Get the input edges for op1 and op2 in order.
  auto op1_in_edges = op1->inlinks_in_order(true);
  auto op2_in_edges = op2->inlinks_in_order(true);
  // Different input arity can never be the same expression.
  if (op1_in_edges.size() != op2_in_edges.size()) {
    return false;
  }
  // Different attribute counts can never be the same expression.
  if (op1->attrs.attr_store.size() != op2->attrs.attr_store.size()) {
    return false;
  }
  // Check whether the input nodes match.
  if (unordered_ops.count(op1->op()->name)) {
    // Commutative ops: the inputs must match as multisets. The previous
    // per-input any_of scan wrongly matched e.g. add(a, a) with add(a, b),
    // because both inputs of the former can match the single `a`.
    std::multiset<std::string> op1_input_ids;
    std::multiset<std::string> op2_input_ids;
    for (auto& edge : op1_in_edges) {
      auto* source = edge->source()->safe_as<NodeData>();
      CHECK(source);
      op1_input_ids.insert(source->id());
    }
    for (auto& edge : op2_in_edges) {
      auto* source = edge->source()->safe_as<NodeData>();
      CHECK(source);
      op2_input_ids.insert(source->id());
    }
    if (op1_input_ids != op2_input_ids) {
      return false;
    }
  } else {
    // Ordered ops: the input nodes must match position by position.
    for (size_t i = 0; i < op1_in_edges.size(); ++i) {
      auto* op1_source_node = op1_in_edges[i]->source()->safe_as<NodeData>();
      auto* op2_source_node = op2_in_edges[i]->source()->safe_as<NodeData>();
      CHECK(op1_source_node);
      CHECK(op2_source_node);
      if (op1_source_node->id() != op2_source_node->id()) {
        return false;
      }
    }
  }

  // Both outputs must have the same rank; the shapes themselves may still
  // differ for non-reshape ops (attributes decide equivalence below).
  auto* op1_sink_node = GetNodeData(op1);
  auto* op2_sink_node = GetNodeData(op2);
  if (shape_dict[op1_sink_node->id()].size() != shape_dict[op2_sink_node->id()].size()) {
    return false;
  }
  if (reshape_ops.count(op1->op()->name)) {
    // For reshape-like ops with identical inputs, equality of the output
    // shape is sufficient.
    return shape_dict[op1_sink_node->id()] == shape_dict[op2_sink_node->id()];
  }
  // For other ops, every attribute must be equivalent. `attr` is taken by
  // const reference to avoid copying each attribute pair.
  return std::all_of(op1->attrs.attr_store.begin(), op1->attrs.attr_store.end(), [&](const auto& attr) {
    if (!op2->attrs.attr_store.count(attr.first)) {
      return false;
    }
    auto& attr1 = attr.second;
    auto& attr2 = op2->attrs.attr_store[attr.first];
    auto ndim = static_cast<int>(shape_dict[op1_sink_node->id()].size());
    if (special_attrs.count(attr.first)) {
      switch (special_attrs[attr.first]) {
        case 1: {
          // Single axis: normalize negative indices by rank before comparing.
          auto op1_axis = absl::get<int>(attr1);
          auto op2_axis = absl::get<int>(attr2);
          if (op1_axis < 0) {
            op1_axis += ndim;
          }
          if (op2_axis < 0) {
            op2_axis += ndim;
          }
          return op2_axis == op1_axis;
        }
        case 2: {
          // Axis list: normalize each entry before element-wise comparison.
          auto& op1_axes = absl::get<std::vector<int>>(attr1);
          auto& op2_axes = absl::get<std::vector<int>>(attr2);
          if (op1_axes.size() != op2_axes.size()) {
            return false;
          }
          for (size_t i = 0; i < op1_axes.size(); ++i) {
            int op1_axis = op1_axes[i];
            int op2_axis = op2_axes[i];
            if (op1_axis < 0) {
              op1_axis += ndim;
            }
            if (op2_axis < 0) {
              op2_axis += ndim;
            }
            if (op2_axis != op1_axis) {
              return false;
            }
          }
          return true;
        }
      }
    }
    return attr1 == attr2;
  });
}

// Detach every node in `nodes` from all of its neighbours, then drop it
// from `graph`.
void RemoveNodes(framework::Graph* graph, std::vector<Node*>& nodes) {
  for (auto* victim : nodes) {
    // Snapshot both edge lists first: unlinking mutates the underlying
    // containers while we iterate.
    auto incoming = victim->inlinks();
    auto outgoing = victim->outlinks();
    for (auto& edge : incoming) {
      edge->source()->UnLinkSingleTo(victim);
    }
    for (auto& edge : outgoing) {
      victim->UnLinkSingleTo(edge->sink());
    }
    graph->DropNode(victim);
  }
}

// Drop every node data in `nodes_data` from `graph`, except graph outputs,
// which must survive because they are observable results of the graph.
void RemoveNodes(framework::Graph* graph, std::vector<NodeData*>& nodes_data) {
  for (auto* data : nodes_data) {
    // Skip (not abort): the original `return` here stopped the whole loop at
    // the first preserved output, leaking every node data queued after it.
    if (std::find(graph->outputs.begin(), graph->outputs.end(), data) != graph->outputs.end()) {
      continue;
    }
    graph->DropNode(data);
  }
}

// Rewire the inputs of `trt`: every input edge coming from `src_old` is
// replaced by an edge from `src_new`; all other inputs are re-attached
// unchanged. Unlinking and relinking every input (instead of only the
// replaced one) preserves the ordering reported by inlinks_in_order(true).
// (Removed the unused local `std::vector<NodeData*> in_nodes`.)
void ReplaceNode(NodeData* src_new, NodeData* src_old, Node* trt) {
  for (auto& in_edge : trt->inlinks_in_order(true)) {
    auto* in_node = in_edge->source()->safe_as<NodeData>();
    in_node->UnLinkSingleTo(trt);
    if (in_node->id() == src_old->id()) {
      src_new->LinkTo(trt);
    } else {
      in_node->LinkTo(trt);
    }
  }
}

// Core CSE worklist: for every pair of nodes of the same op type that form
// the same subexpression, redirect consumers of the duplicate onto the
// first-seen candidate, then erase the duplicate and its output node data.
// `store_nodes` is consumed front-first; consumers whose inputs were rewired
// are pushed back to the front so they are re-examined with the new inputs.
// `in2node` maps a node-data id to the op nodes consuming it.
void CommonSubexpressionElimination(Graph* graph, std::vector<GraphNode*> store_nodes, InputToNodeMap in2node) {
  // Candidates seen so far, bucketed by op type; duplicates are only matched
  // against nodes of the same type.
  std::unordered_map<std::string, std::vector<Node*>> candidates_map;
  auto shape_dict = graph->GetAttrs<absl::flat_hash_map<std::string, framework::shape_t>>("infershape");
  std::vector<Node*> remove_nodes;
  std::vector<NodeData*> remove_nodes_data;

  while (!store_nodes.empty()) {
    auto* graph_node = store_nodes.front();
    store_nodes.erase(store_nodes.begin());
    // Demoted from LOG(INFO): per-node progress is debug noise.
    VLOG(4) << "size of store_nodes is " << store_nodes.size();
    auto* node = graph_node->safe_as<Node>();
    if (!node) {
      continue;
    }
    auto& node_type = node->op()->name;
    auto& candidates = candidates_map[node_type];
    bool found = false;
    for (auto* candidate_node : candidates) {
      // Never match a node against itself.
      if (node->id() == candidate_node->id()) continue;
      if (!IsSameSubexpression(node, candidate_node, shape_dict)) continue;
      found = true;
      // Hoisted: outlinks_in_order(true) was previously re-evaluated on
      // every iteration of the inner loop (and once more in the CHECK).
      const auto& node_outs = node->outlinks_in_order(true);
      const auto& candidate_outs = candidate_node->outlinks_in_order(true);
      CHECK(node_outs.size() == candidate_outs.size());
      for (size_t k = 0; k < node_outs.size(); ++k) {
        auto* sink_node = node_outs[k]->sink()->safe_as<NodeData>();
        auto* candidate_sink_node = candidate_outs[k]->sink()->safe_as<NodeData>();
        CHECK(sink_node);
        CHECK(candidate_sink_node);
        remove_nodes_data.push_back(sink_node);
        auto iter_sink_node = std::find(graph->outputs.begin(), graph->outputs.end(), sink_node);
        if (iter_sink_node != graph->outputs.end()) {
          // The sink is a graph output, so its id must survive: replace it
          // with a fresh node data fed by the kept candidate. It must be
          // heap-allocated — the original code pushed the address of a
          // stack-local NodeData into graph->outputs, leaving a dangling
          // pointer as soon as this block ended.
          auto* new_sink_node = new NodeData(
              candidate_node, sink_node->output_index, sink_node->version, sink_node->id(), sink_node->is_const());
          graph->outputs.erase(iter_sink_node);
          graph->outputs.push_back(new_sink_node);
        }
        // Replace sink_node with candidate_sink_node in every consumer.
        auto out_nodes = in2node[sink_node->id()];
        for (auto* out_node : out_nodes) {
          ReplaceNode(candidate_sink_node, sink_node, out_node);
          // Rewired consumers may have become duplicates themselves; queue
          // them for re-examination unless already pending.
          if (std::find(store_nodes.begin(), store_nodes.end(), out_node) == store_nodes.end()) {
            store_nodes.insert(store_nodes.begin(), out_node);
          }
        }
      }
      remove_nodes.push_back(node);
      VLOG(4) << "remove " << node->id() << " node.";
      break;
    }
    if (!found) {
      candidates_map[node_type].push_back(node);
    }
  }
  // Nodes must be deleted before their node data.
  RemoveNodes(graph, remove_nodes);
  RemoveNodes(graph, remove_nodes_data);
}

// Pass entry point: builds the input-id -> consumer-nodes map, then runs CSE
// over the graph's topological order.
// (Removed the unused local `candidates_map`; the real one lives inside
// CommonSubexpressionElimination.)
void CommonSubexpressionEliminationPass(Graph* graph) {
  VLOG(3) << "CommonSubexpressionEliminationPass...!";
  InputToNodeMap in2node;
  auto store_nodes = std::get<0>(graph->topological_order());

  // For every op node, record it as a consumer of each of its inputs so CSE
  // can redirect consumers when a duplicate producer is eliminated.
  for (auto& graph_node : store_nodes) {
    auto* node = graph_node->safe_as<Node>();
    if (node) {
      for (auto& in_edge : node->inlinks_in_order(true)) {
        auto* source_node = in_edge->source()->safe_as<NodeData>();
        // Guard the cast, consistent with the checks elsewhere in this file.
        CHECK(source_node);
        in2node[source_node->id()].insert(node);
      }
    }
  }

  CommonSubexpressionElimination(graph, store_nodes, in2node);
  VLOG(3) << "CommonSubexpressionEliminationPass Finish...!";
}
} // namespace pass
} // namespace hlir
} // namespace cinn

// Register the pass with the CINN pass registry. set_change_structure(true)
// declares that this pass adds/removes graph nodes.
CINN_REGISTER_HELPER(CommonSubexpressionEliminationPass) {
  CINN_REGISTER_PASS(CommonSubexpressionEliminationPass)
      // Fixed ungrammatical description ("remove these same sub-expression").
      .describe("This pass removes common subexpressions, reusing a single computation for each.")
      .set_change_structure(true)
      .set_body(cinn::hlir::pass::CommonSubexpressionEliminationPass);

  return true;
}
Loading

0 comments on commit a9630cc

Please sign in to comment.