Optional garbage collection and CheckpointManager._global_deps #187
base: master
Is it equivalent to instead drop zero output |
I believe it is different. I noticed that memory usage kept growing, mainly during the recomputation process, even after I cleared the checkpoint using `block_variable._checkpoint = None`. After some discussions here, my hypothesis is that Python might not be tracking all objects in memory properly. So I am only allowing the user to invoke the garbage collector manually, which appears to help.
I think this could do with a lot more explanation. This is very complicated so adding some substantial comments and expanding docstrings would be extremely helpful.
The code style seems fine.
pyadjoint/checkpointing.py
Outdated
@@ -37,6 +38,11 @@ class CheckpointManager:
Args:
schedule (checkpoint_schedules.schedule): A schedule provided by the `checkpoint_schedules` package.
tape (Tape): A list of blocks :class:`Block` instances.
gc_timestep_frequency (int): The timestep frequency for garbage collection.
This could be clearer. Perhaps "the number of timesteps between garbage collections".
Also it should state that if `None` then no collection is done, or similar.
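The semantics suggested in this comment can be sketched as follows. This is a minimal illustration, not the code in the PR: the helper name `maybe_collect` is hypothetical, and only the parameter name `gc_timestep_frequency` comes from the diff above.

```python
import gc


def maybe_collect(timestep, gc_timestep_frequency=None):
    """Run a full garbage collection every `gc_timestep_frequency` timesteps.

    Hypothetical helper for illustration; not part of pyadjoint.
    Returns True if a collection was triggered at this timestep.
    """
    if gc_timestep_frequency is None:
        # None means "never collect explicitly".
        return False
    if gc_timestep_frequency <= 0:
        raise ValueError("gc_timestep_frequency must be a positive integer or None")
    if timestep % gc_timestep_frequency == 0:
        gc.collect()
        return True
    return False


# With a frequency of 5, collections happen at timesteps 5 and 10.
collected_at = [t for t in range(1, 11) if maybe_collect(t, gc_timestep_frequency=5)]
```

Documenting the `None` case explicitly, as requested, makes it clear that the default leaves Python's automatic collector as the only mechanism.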
pyadjoint/checkpointing.py
Outdated
# The user can manually invoke the garbage collector if Python fails to
# track and clean all checkpoint objects in memory properly.
This is confusing because setting `gc_timestep_frequency` suggests that GC is being run automatically, whereas here you say manually.
pyadjoint/checkpointing.py
Outdated
for deps in self.tape.timesteps[timestep - 1].checkpointable_state:
self._global_deps.add(deps)
else:
deps_to_clear = self._global_deps - self._global_deps.intersection(
I think you might want set.difference
https://docs.python.org/3/library/stdtypes.html#frozenset.difference
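For readers unfamiliar with the method, here is a small sketch of how `set.difference` compares with the `-` operator. The variable names and string values are placeholders chosen for illustration and do not come from pyadjoint.

```python
# Placeholder names standing in for block variables.
global_deps = {"mesh", "viscosity", "u_old"}
previous_step_deps = ["mesh", "viscosity"]  # any iterable, here a list

# set.difference accepts arbitrary iterables.
to_clear = global_deps.difference(previous_step_deps)

# The `-` operator requires both operands to be sets.
to_clear_op = global_deps - set(previous_step_deps)

try:
    global_deps - previous_step_deps
    operator_accepts_list = True
except TypeError:
    operator_accepts_list = False
```

Both forms return a new set and leave `global_deps` unchanged; `difference_update` is the in-place variant.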
@@ -164,11 +175,22 @@ def _(self, cp_action, timestep):
):
for package in self.tape._package_data.values():
package.continue_checkpointing()
if timestep == 1:
I would want a big comment here explaining what is going on, perhaps even with an example. What are `_global_deps`? Why does removing `_checkpoint` do the right thing?
# ``self._global_deps`` stores checkpoint dependencies that remain unchanged across
# timesteps (``self.tape.timesteps``). During the forward taping process, the code
# checks whether a dependency is in ``self._global_deps`` to avoid unnecessary clearing
# and recreation of its checkpoint data.
@connorjward Here I try to explain why `self._global_deps` exists and what it stores.
# Check if the block variables stored at `self._global_deps` are still
# dependencies in the previous timestep. If not, will remove them from the
# global dependencies.
deps_to_clear = self._global_deps.difference(self._global_deps.intersection(
I don't think you need to have the `intersection` here.
But I could be wrong.
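As a quick check of the set identity behind this remark, namely that A \ (A ∩ B) equals A \ B for any sets A and B, the following sketch compares the two forms on randomly generated sets. The function names are hypothetical and the snippet says nothing about the rest of the PR's logic.

```python
import random


def with_intersection(a, b):
    # The form used in the diff: A.difference(A.intersection(B)).
    return a.difference(a.intersection(b))


def without_intersection(a, b):
    # The simplified form suggested in the review: A.difference(B).
    return a.difference(b)


rng = random.Random(0)  # fixed seed so the check is reproducible
mismatches = 0
for _ in range(1000):
    a = {rng.randrange(20) for _ in range(rng.randrange(10))}
    b = {rng.randrange(20) for _ in range(rng.randrange(10))}
    if with_intersection(a, b) != without_intersection(a, b):
        mismatches += 1
```

The two forms agree because removing from A the elements that A shares with B is the same as removing every element of B from A.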
pyadjoint/checkpointing.py
Outdated
# Clear the checkpoint once it is not a global dependency and should be stored
# only in the ``self.tape.timesteps`` checkpoints when needed.
I'm afraid I don't quite understand what this means. Could you rephrase this?
Is this text easier to understand?
For non-global dependencies, checkpoint storage occurs at a `self.tape` timestep only when required by an action from the schedule. Thus, we have to clear the checkpoint of block variables excluded from `self._global_deps`.
Yeah I think that's good. Thanks.
Co-authored-by: Connor Ward <[email protected]>
I think something that could be worth looking at here is the circular reference between pyadjoint/pyadjoint/overloaded_type.py Lines 96 to 98 in c7d7392
In effect, any subclass of |
Thank you. I will investigate that. |
Is there any chance that this could be made a weakref so as to avoid this cycle? |
I think so. I think the BlockVariable should not prolong the lifetime of the Overloaded Type, so |
I already tried weakref for |
I had a hacky go at that before, which worked for the forward run of the tape (with some implementation ugliness). The underlying OverloadedType was deleted at some point before/during the adjoint call, so that might need some care. |
Looks easier to break the cycle on the other side, see #194 for an attempt. |
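To illustrate the general mechanism discussed in this thread, here is a minimal sketch of a two-object reference cycle and of how a weak back-reference avoids it. The classes `Data` and `Variable` are hypothetical stand-ins, not the pyadjoint `OverloadedType` and `BlockVariable` classes, and the behaviour shown relies on CPython's reference counting.

```python
import gc
import weakref


class Variable:
    """Stand-in for an object that refers back to its owner."""

    def __init__(self, owner, weak):
        # A weak reference does not keep the owner alive.
        self.owner = weakref.ref(owner) if weak else owner


class Data:
    """Stand-in for an object that owns a Variable."""

    def __init__(self, weak):
        self.variable = Variable(self, weak)


def freed_without_gc(weak):
    """Return True if a Data object is freed as soon as its last name is deleted."""
    was_enabled = gc.isenabled()
    gc.disable()  # rule out the automatic cycle collector during the check
    try:
        obj = Data(weak)
        probe = weakref.ref(obj)
        del obj
        return probe() is None
    finally:
        if was_enabled:
            gc.enable()


strong_cycle_freed = freed_without_gc(weak=False)  # cycle keeps the pair alive
weak_ref_freed = freed_without_gc(weak=True)       # no cycle, freed immediately
gc.collect()  # reclaims the leftover strong cycle
```

With a strong back-reference, the pair survives until the cycle collector runs, which is consistent with memory growing until `gc.collect()` is called. With a weak back-reference, the owner can disappear while the `Variable` still exists, which matches the caveat above about the underlying object being deleted before or during the adjoint call.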
An update: I have tested this PR merged with PR 194 against PR 194 alone for Burgers' equation, using the following setup: 40,000 DoFs and 1,000 time steps. The chart below uses |
I will also check PR 4033 using the same example and add the results here.
Now using firedrake PR 4033 merged with firedrake PR 4020 that automatically uses Again, I have tested this PR merged with pyadjoint PR 194 against PR 194 alone for Burgers' equation, using the following setup: 40,000 DoFs and 1,000 time steps. The black line represents the results for PR 194 alone, and the blue line represents this PR merged with PR 194 using |
PR Description
This PR introduces optional garbage collection during checkpointing, so that the user can work around cases where Python fails to properly track and clean up checkpoint objects in memory.
Experiment details used to test manually triggered garbage collection
The black curve represents the scenario with garbage collection enabled, while the blue curve shows the case without garbage collection during checkpointing.
![Memory Usage](https://private-user-images.githubusercontent.com/63597005/405187206-73766c10-9bfd-48f5-a041-56de4fca981f.png?jwt=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJnaXRodWIuY29tIiwiYXVkIjoicmF3LmdpdGh1YnVzZXJjb250ZW50LmNvbSIsImtleSI6ImtleTUiLCJleHAiOjE3Mzk2NTgzMzIsIm5iZiI6MTczOTY1ODAzMiwicGF0aCI6Ii82MzU5NzAwNS80MDUxODcyMDYtNzM3NjZjMTAtOWJmZC00OGY1LWEwNDEtNTZkZTRmY2E5ODFmLnBuZz9YLUFtei1BbGdvcml0aG09QVdTNC1ITUFDLVNIQTI1NiZYLUFtei1DcmVkZW50aWFsPUFLSUFWQ09EWUxTQTUzUFFLNFpBJTJGMjAyNTAyMTUlMkZ1cy1lYXN0LTElMkZzMyUyRmF3czRfcmVxdWVzdCZYLUFtei1EYXRlPTIwMjUwMjE1VDIyMjAzMlomWC1BbXotRXhwaXJlcz0zMDAmWC1BbXotU2lnbmF0dXJlPWIzMmMyNDJmOWVmNmNhMDY3NDI5YTJjZjU0ZDFkM2Y1ODk1YWY5MDZjOTM2YmM2ZTczZTVkMWJlZjQ2YjBmZGQmWC1BbXotU2lnbmVkSGVhZGVycz1ob3N0In0.xZnjSw-z_3ZO0GozZ7wk4BWwOvUQercy_yuDUMS-xHQ)
A private attribute, `_global_deps`, is introduced in the `CheckpointManager` class. This attribute stores dependencies that are used at each time step and are not time-dependent. If a dependency is in `_global_deps`, it will not be cleaned during checkpointing. This prevents unnecessary cleanup and re-creation of checkpoints for dependencies that do not change with time.
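The bookkeeping described here can be sketched roughly as below. This is a reconstruction from the diff fragments quoted in the review, not the PR's implementation: the function `update_global_deps` is hypothetical, and strings stand in for the block variables held in each timestep's `checkpointable_state`.

```python
def update_global_deps(global_deps, timestep, checkpointable_state):
    """Update `global_deps` in place and return the entries that were dropped.

    Hypothetical sketch: at the first timestep every dependency is a candidate;
    at later timesteps, candidates that are no longer dependencies are removed.
    """
    if timestep == 1:
        global_deps.update(checkpointable_state)
        return set()
    dropped = global_deps.difference(checkpointable_state)
    global_deps.difference_update(dropped)
    return dropped


# Placeholder dependencies: "mesh" and "nu" never change, the solution does.
states = [
    {"mesh", "nu", "u0"},
    {"mesh", "nu", "u1"},
    {"mesh", "nu", "u2"},
]

global_deps = set()
dropped_per_step = []
for timestep, state in enumerate(states, start=1):
    dropped_per_step.append(update_global_deps(global_deps, timestep, state))
```

In this sketch the dropped entries correspond to the block variables whose `_checkpoint` would be cleared, while the entries that remain in `global_deps` keep their checkpoint data across timesteps.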