[Remat][Bugfix] Enhance the rematerialization pass to handle nested tuples #17
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Description
This PR addresses issue #8 and presents a small fix for the rematerialization pass to handle nested tuples. Specifically, it handles cases like the following:
Without this patch the remat pass would error out when processing
c
because of a simplified assumption we made in the first version of the pass. Notice that this fix does not address the case where a call node directly produces a nested tuple (which does not exist in current RAF as far as I know).ResNet50 with AMP can now be compiled with rematerialization turned on, and the benchmarked latency is ~35ms per batch when the batch size is 16.
I will add a test for nested tuples later. Please review the changes first.
Checklist
cc @awslabs/raf-reviewer @comaniac