From e24d0b1c4fbfca4dfbbd03e7cdf7e7b0a7661130 Mon Sep 17 00:00:00 2001 From: zrr1999 <2742392377@qq.com> Date: Mon, 30 Jan 2023 05:10:56 +0000 Subject: [PATCH] fix test case random error --- cinn/hlir/pass/common_subexpression_elimination.cc | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/cinn/hlir/pass/common_subexpression_elimination.cc b/cinn/hlir/pass/common_subexpression_elimination.cc index 9c6001e0fc..c7b4ea6705 100644 --- a/cinn/hlir/pass/common_subexpression_elimination.cc +++ b/cinn/hlir/pass/common_subexpression_elimination.cc @@ -222,12 +222,15 @@ void CommonSubexpressionElimination(Graph* graph, std::vector store_ while (!store_nodes.empty()) { auto* graph_node = store_nodes[0]; store_nodes.erase(store_nodes.begin()); + LOG(INFO) << "size of store_nodes is " << store_nodes.size(); auto node = graph_node->safe_as(); if (node) { auto& node_type = node->op()->name; auto& candidates = candidates_map[node_type]; bool found = false; for (auto* candidate_node : candidates) { + // If node is same with candidate_node, continue the next. + if (node->id() == candidate_node->id()) continue; // If node is different from candidate_node, continue the next. if (!IsSameSubexpression(node, candidate_node, shape_dict)) continue; found = true; @@ -254,8 +257,6 @@ void CommonSubexpressionElimination(Graph* graph, std::vector store_ if (std::find(store_nodes.begin(), store_nodes.end(), out_node) == store_nodes.end()) { store_nodes.insert(store_nodes.begin(), out_node); } - out_nodes.erase(node); - out_nodes.insert(candidate_node); } } remove_nodes.push_back(node);