From 77450ccadadc94f7e4e7f85352311416f9a04186 Mon Sep 17 00:00:00 2001 From: jiangcheng Date: Thu, 16 Feb 2023 18:58:04 +0800 Subject: [PATCH] fix remove identity should remove output first bug (#1210) * fix remove identity should remove output first bug --- cinn/frontend/pass/remove_identity.cc | 8 ++++---- cinn/frontend/pass/remove_identity_test.cc | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/cinn/frontend/pass/remove_identity.cc b/cinn/frontend/pass/remove_identity.cc index 7aaddd0c20..42e6799e5e 100644 --- a/cinn/frontend/pass/remove_identity.cc +++ b/cinn/frontend/pass/remove_identity.cc @@ -216,12 +216,12 @@ class RemoveIdentityPass : public ProgramPass { bool can_output_var_removed = !fetch_ids.count(output_var->id); if (can_input_var_removed || can_output_var_removed) { bool updated = false; - if (can_input_var_removed) { - updated = UpdateOrigin2New(input_var, output_var); - } - if (!updated && can_output_var_removed) { + if (can_output_var_removed) { updated = UpdateOrigin2New(output_var, input_var); } + if (!updated && can_input_var_removed) { + updated = UpdateOrigin2New(input_var, output_var); + } if (updated) { VLOG(3) << "Remove the " << i << "-th instruction: " << instr; remove_idxs_.insert(i); diff --git a/cinn/frontend/pass/remove_identity_test.cc b/cinn/frontend/pass/remove_identity_test.cc index 3c2d8d07cd..87833f4052 100644 --- a/cinn/frontend/pass/remove_identity_test.cc +++ b/cinn/frontend/pass/remove_identity_test.cc @@ -117,7 +117,7 @@ TEST(RemoveIdentity, cannot_remove_fetch) { std::vector output_names = {identity_2->id, mul_1->id}; std::vector program_passes = {"RemoveIdentity"}; int num_removed_ops = tester.RunAndCheck(builder, program_passes, input_names, output_names); - ASSERT_EQ(num_removed_ops, 2); + ASSERT_EQ(num_removed_ops, 1); } } // namespace cinn::frontend