
Optional garbage collection and CheckpointManager._global_deps #187

Open · wants to merge 21 commits into base: master
Conversation

@Ig-dolci Ig-dolci commented Dec 13, 2024

PR Description

  • Enable the user to apply garbage collection when necessary:
    This PR adds optional garbage-collection support during checkpointing, so the user can work around Python failing to properly track and clean up checkpoint objects in memory.

Experiment details used to test manually invoked garbage collection

  • Degrees of freedom (DoFs): 40,401
  • Test Type: Burgers test
  • Total Steps: 1,000

The black curve represents the scenario with garbage collection enabled, while the blue curve shows the case without garbage collection during checkpointing.
[Figure: memory usage]

  • CheckpointManager._global_deps:
    A private attribute, _global_deps, is introduced in the CheckpointManager class. It stores dependencies that are used at every timestep and are not time-dependent.
    • If a block variable is in _global_deps, its checkpoint is not cleared during checkpointing. This avoids unnecessary cleanup and re-creation of checkpoints for dependencies that do not change in time.
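A minimal sketch of how these two features could fit together. The class and method names here are illustrative, not pyadjoint's actual API; only `gc_timestep_frequency` and `_global_deps` come from the PR description.

```python
import gc

class CheckpointManagerSketch:
    """Illustrative sketch only; not the real pyadjoint CheckpointManager.

    gc_timestep_frequency: number of timesteps between manual
    gc.collect() calls, or None to leave collection entirely to Python.
    """

    def __init__(self, gc_timestep_frequency=None):
        self.gc_timestep_frequency = gc_timestep_frequency
        # Block variables used at every timestep and not time-dependent;
        # their checkpoints are never cleared during checkpointing.
        self._global_deps = set()

    def maybe_collect(self, timestep):
        # Invoke the cycle collector every gc_timestep_frequency timesteps.
        if (self.gc_timestep_frequency is not None
                and timestep % self.gc_timestep_frequency == 0):
            gc.collect()

    def should_clear(self, block_variable):
        # Global dependencies survive checkpoint cleanup unchanged.
        return block_variable not in self._global_deps
```

With `gc_timestep_frequency=None` no manual collection is performed; dependencies added to `_global_deps` are simply skipped when checkpoints are cleared.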

@jrmaddison (Contributor)
Is it equivalent to instead drop zero output Blocks on the tape for garbage collection? Or is this slightly different?

@Ig-dolci (Author)

> Is it equivalent to instead drop zero output Blocks on the tape for garbage collection? Or is this slightly different?

I believe it is different. I noticed that, mainly during recomputation, memory usage kept growing even after I cleared the checkpoint with block_variable._checkpoint = None. After some discussion, my hypothesis is that Python might not be tracking all objects in memory properly. So I only allow the user to invoke the garbage collector manually, which appears to help.
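The hypothesis can be demonstrated in isolation: objects caught in a reference cycle are not freed by reference counting alone, only by an explicit (or scheduled) `gc.collect()`. A toy example with a hypothetical `Holder` class:

```python
import gc

class Holder:
    """Toy object standing in for a checkpoint-like payload."""
    pass

gc.collect()   # start from a clean slate
gc.disable()   # suppress automatic collection so the demo is deterministic

a = Holder()
b = Holder()
a.partner = b
b.partner = a  # a <-> b now form a reference cycle

del a, b                # refcounts never reach zero: the cycle keeps both alive
found = gc.collect()    # only the cycle collector reclaims them
gc.enable()
assert found >= 2       # at least the two Holder instances were unreachable
```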

@Ig-dolci Ig-dolci changed the title Optional garbage collection and more... Optional garbage collection and CheckpointManager._global_deps Feb 5, 2025
@Ig-dolci Ig-dolci marked this pull request as ready for review February 5, 2025 17:59
@connorjward (Contributor) left a comment
I think this could do with a lot more explanation. This is very complicated so adding some substantial comments and expanding docstrings would be extremely helpful.

The code style seems fine.

@@ -37,6 +38,11 @@ class CheckpointManager:
Args:
schedule (checkpoint_schedules.schedule): A schedule provided by the `checkpoint_schedules` package.
tape (Tape): A list of blocks :class:`Block` instances.
gc_timestep_frequency (int): The timestep frequency for garbage collection.
Contributor:
This could be clearer. Perhaps "the number of timesteps between garbage collections"

Also it should state that if None then no collection is done, or similar.

Comment on lines 88 to 89
# The user can manually invoke the garbage collector if Python fails to
# track and clean all checkpoint objects in memory properly.
Contributor:
This is confusing because setting gc_timestep_frequency suggests that GC is being run automatically, whereas here you say manually

for deps in self.tape.timesteps[timestep - 1].checkpointable_state:
self._global_deps.add(deps)
else:
deps_to_clear = self._global_deps - self._global_deps.intersection(
@@ -164,11 +175,22 @@ def _(self, cp_action, timestep):
):
for package in self.tape._package_data.values():
package.continue_checkpointing()
if timestep == 1:
Contributor:

I would want a big comment here explaining what is going on, perhaps even with an example. What are _global_deps? Why does removing _checkpoint do the right thing?

Comment on lines +94 to +97
# ``self._global_deps`` stores checkpoint dependencies that remain unchanged across
# timesteps (``self.tape.timesteps``). During the forward taping process, the code
# checks whether a dependency is in ``self._global_deps`` to avoid unnecessary clearing
# and recreation of its checkpoint data.
@Ig-dolci (Author)
@connorjward Here I try to explain why self._global_deps exists and what it stores.

# Check whether the block variables stored in `self._global_deps` are still
# dependencies of the previous timestep. If not, remove them from the
# global dependencies.
deps_to_clear = self._global_deps.difference(self._global_deps.intersection(
Contributor:
I don't think you need to have the intersection here.

But I could be wrong.
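The identity behind this suggestion: for any sets, `A.difference(A.intersection(B))` equals `A.difference(B)`, so the intersection call is redundant. A quick check with hypothetical dependency names:

```python
# Removing only the shared elements from a set is the same as a plain
# set difference: a - (a & b) == a - b for any sets a and b.
a = {"u", "v", "nu"}   # hypothetical dependency names
b = {"v", "nu", "dt"}

assert a - a.intersection(b) == a - b
assert a.difference(a.intersection(b)) == a.difference(b)
assert a - b == {"u"}  # elements of a that are not in b
```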

Comment on lines 197 to 198
# Clear the checkpoint once it is not a global dependency and should be stored
# only in the ``self.tape.timesteps`` checkpoints when needed.
Contributor:
I'm afraid I don't quite understand what this means. Could you rephrase this?

@Ig-dolci (Author) replied Feb 6, 2025
Is this text easier to understand?

> For non-global dependencies, checkpoint storage occurs at a self.tape timestep only when required by an action from the schedule. Thus, we have to clear the checkpoints of block variables excluded from self._global_deps.

Contributor:
Yeah I think that's good. Thanks.

@angus-g (Contributor) commented Feb 11, 2025

I think something that could be worth looking at here is the circular reference between OverloadedType and BlockVariable:

def create_block_variable(self):
self.block_variable = BlockVariable(self)
return self.block_variable

In effect, any subclass of OverloadedType can only be deleted after garbage collection, not reference counting. I think this means that several data arrays sit around for longer than they should, rather than being deleted when their owner goes out of scope.
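The cycle described above can be reproduced with stripped-down stand-ins for the two classes (the real pyadjoint classes carry more state). A `weakref.finalize` callback shows that the owner is only reclaimed by the cycle collector, not by reference counting:

```python
import gc
import weakref

class BlockVariable:
    def __init__(self, output):
        self.output = output  # strong reference back to the owner

class OverloadedType:
    def create_block_variable(self):
        # Cycle: self -> block_variable -> self
        self.block_variable = BlockVariable(self)
        return self.block_variable

gc.collect()   # start from a clean slate
gc.disable()   # suppress automatic collection so the demo is deterministic

collected = []
obj = OverloadedType()
weakref.finalize(obj, collected.append, "freed")
obj.create_block_variable()

del obj                 # refcounting cannot free it: the cycle persists
assert collected == []  # still alive after the owner went out of scope
gc.collect()            # only the cycle collector reclaims it
gc.enable()
assert collected == ["freed"]
```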

@Ig-dolci (Author)

> I think something that could be worth looking at here is the circular reference between OverloadedType and BlockVariable:
>
> def create_block_variable(self):
>     self.block_variable = BlockVariable(self)
>     return self.block_variable
>
> In effect, any subclass of OverloadedType can only be deleted after garbage collection, not reference counting. I think this means that several data arrays sit around for longer than they should, rather than being deleted when their owner goes out of scope.

Thank you. I will investigate that.

@connorjward (Contributor)

> In effect, any subclass of OverloadedType can only be deleted after garbage collection, not reference counting. I think this means that several data arrays sit around for longer than they should, rather than being deleted when their owner goes out of scope.

Is there any chance that this could be made a weakref so as to avoid this cycle?

@dham (Member) commented Feb 11, 2025

> In effect, any subclass of OverloadedType can only be deleted after garbage collection, not reference counting. I think this means that several data arrays sit around for longer than they should, rather than being deleted when their owner goes out of scope.
>
> Is there any chance that this could be made a weakref so as to avoid this cycle?

I think so. I think the BlockVariable should not prolong the lifetime of the OverloadedType, so BlockVariable.output should be a weakref. We have to be careful that we're not abusing BlockVariable.output anywhere (i.e. relying on it as a source of information after the operation has been taped).
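A sketch of this suggestion, with `BlockVariable.output` backed by `weakref.ref`. Class names follow the thread, but this illustrates the idea rather than the proposed patch:

```python
import weakref

class BlockVariable:
    def __init__(self, output):
        # Weak reference: the BlockVariable no longer extends the
        # lifetime of the OverloadedType it was created from.
        self._output_ref = weakref.ref(output)

    @property
    def output(self):
        out = self._output_ref()
        if out is None:
            # Callers relying on output after the owner is gone would
            # hit this, which is the risk noted above.
            raise ReferenceError("OverloadedType no longer alive")
        return out

class OverloadedType:
    def create_block_variable(self):
        self.block_variable = BlockVariable(self)
        return self.block_variable

obj = OverloadedType()
bv = obj.create_block_variable()
assert bv.output is obj        # resolves while the owner is alive
del obj                        # no cycle now: freed by refcounting alone
assert bv._output_ref() is None
```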

@Ig-dolci (Author)

> In effect, any subclass of OverloadedType can only be deleted after garbage collection, not reference counting. I think this means that several data arrays sit around for longer than they should, rather than being deleted when their owner goes out of scope.
>
> Is there any chance that this could be made a weakref so as to avoid this cycle?

I already tried a weakref for BlockVariable.output, but I hit a number of errors that I do not remember now. It is very delicate work.

@angus-g (Contributor) commented Feb 11, 2025

I had a hacky go at that before, which worked for the forward run of the tape (with some implementation ugliness). The underlying OverloadedType was deleted at some point before/during the adjoint call, so that might need some care.

@jrmaddison (Contributor)

Looks easier to break the cycle on the other side, see #194 for an attempt.

@Ig-dolci (Author) commented Feb 12, 2025

To keep you updated:

I have tested this PR merged into PR 194 against PR 194 alone, for Burgers' equation with the following setup: 40,000 DoFs and 1,000 time steps.

The chart below uses SingleDiskStorageSchedule with the fix from PR 4020. The black line shows the results for PR 194 alone, and the blue line shows this PR merged into PR 194 with gc_timestep_frequency=100.
[Figure: mem_gc]

@Ig-dolci (Author)

I will also check PR 4033 using the same example and add the results here.

@Ig-dolci (Author)

Now using firedrake PR 4033 merged into firedrake PR 4020, which automatically uses SingleDiskStorageSchedule.

Again, I have tested this PR merged into pyadjoint PR 194 against PR 194 alone, for Burgers' equation with the following setup: 40,000 DoFs and 1,000 time steps.

The black line shows the results for PR 194 alone, and the blue line shows this PR merged into PR 194 with gc_timestep_frequency=100.
[Figure: final]
