Skip to content
This repository has been archived by the owner on Jan 24, 2024. It is now read-only.

Commit

Permalink
fix remove identity should remove output first bug (#1210)
Browse files Browse the repository at this point in the history
* fix remove identity should remove output first bug
  • Loading branch information
thisjiang authored Feb 16, 2023
1 parent ec7f1a8 commit 77450cc
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 5 deletions.
8 changes: 4 additions & 4 deletions cinn/frontend/pass/remove_identity.cc
Original file line number Diff line number Diff line change
Expand Up @@ -216,12 +216,12 @@ class RemoveIdentityPass : public ProgramPass {
bool can_output_var_removed = !fetch_ids.count(output_var->id);
if (can_input_var_removed || can_output_var_removed) {
bool updated = false;
if (can_input_var_removed) {
updated = UpdateOrigin2New(input_var, output_var);
}
if (!updated && can_output_var_removed) {
if (can_output_var_removed) {
updated = UpdateOrigin2New(output_var, input_var);
}
if (!updated && can_input_var_removed) {
updated = UpdateOrigin2New(input_var, output_var);
}
if (updated) {
VLOG(3) << "Remove the " << i << "-th instruction: " << instr;
remove_idxs_.insert(i);
Expand Down
2 changes: 1 addition & 1 deletion cinn/frontend/pass/remove_identity_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ TEST(RemoveIdentity, cannot_remove_fetch) {
std::vector<std::string> output_names = {identity_2->id, mul_1->id};
std::vector<std::string> program_passes = {"RemoveIdentity"};
int num_removed_ops = tester.RunAndCheck(builder, program_passes, input_names, output_names);
ASSERT_EQ(num_removed_ops, 2);
ASSERT_EQ(num_removed_ops, 1);
}

} // namespace cinn::frontend

0 comments on commit 77450cc

Please sign in to comment.