diff --git a/cinn/frontend/pass/remove_identity.cc b/cinn/frontend/pass/remove_identity.cc index 7aaddd0c20..42e6799e5e 100644 --- a/cinn/frontend/pass/remove_identity.cc +++ b/cinn/frontend/pass/remove_identity.cc @@ -216,12 +216,12 @@ class RemoveIdentityPass : public ProgramPass { bool can_output_var_removed = !fetch_ids.count(output_var->id); if (can_input_var_removed || can_output_var_removed) { bool updated = false; - if (can_input_var_removed) { - updated = UpdateOrigin2New(input_var, output_var); - } - if (!updated && can_output_var_removed) { + if (can_output_var_removed) { updated = UpdateOrigin2New(output_var, input_var); } + if (!updated && can_input_var_removed) { + updated = UpdateOrigin2New(input_var, output_var); + } if (updated) { VLOG(3) << "Remove the " << i << "-th instruction: " << instr; remove_idxs_.insert(i); diff --git a/cinn/frontend/pass/remove_identity_test.cc b/cinn/frontend/pass/remove_identity_test.cc index 3c2d8d07cd..87833f4052 100644 --- a/cinn/frontend/pass/remove_identity_test.cc +++ b/cinn/frontend/pass/remove_identity_test.cc @@ -117,7 +117,7 @@ TEST(RemoveIdentity, cannot_remove_fetch) { std::vector output_names = {identity_2->id, mul_1->id}; std::vector program_passes = {"RemoveIdentity"}; int num_removed_ops = tester.RunAndCheck(builder, program_passes, input_names, output_names); - ASSERT_EQ(num_removed_ops, 2); + ASSERT_EQ(num_removed_ops, 1); } } // namespace cinn::frontend