Remove reference cycle in VecAccessMixin #4033

Open (2 commits, base: master)

angus-g (Contributor) commented Feb 12, 2025

With an associated PETSc Vec, VecAccessMixin deferred its version property to a lambda to avoid allocating the storage until necessary. Unfortunately, this lambda creates a reference cycle to self for all users of the VecAccessMixin. Given that counter accesses should be relatively infrequent, it seems fine to look up the counter type within the method itself.
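
To make the cycle concrete, here is a minimal sketch in plain Python. The two classes are hypothetical stand-ins, not the actual VecAccessMixin code: one stores a lambda that closes over self, the other looks the value up inside the property. The check assumes CPython, where objects outside a cycle are freed as soon as their reference count drops to zero.

```python
import gc
import weakref


class CounterFromLambda:
    """Hypothetical stand-in for the old pattern: a lambda stored on the instance."""

    def __init__(self):
        self._state = 0
        # The lambda closes over `self`, and `self` holds the lambda:
        # self -> lambda -> closure cell -> self.
        self._version = lambda: self._state

    @property
    def version(self):
        return self._version()


class CounterFromMethod:
    """Hypothetical stand-in for the new pattern: look the value up on access."""

    def __init__(self):
        self._state = 0

    @property
    def version(self):
        return self._state


def freed_without_gc(cls):
    """Return True if an instance of `cls` is reclaimed by refcounting alone."""
    was_enabled = gc.isenabled()
    gc.disable()  # make sure the cycle collector cannot step in
    try:
        obj = cls()
        ref = weakref.ref(obj)
        del obj  # drop the only external reference
        return ref() is None
    finally:
        if was_enabled:
            gc.enable()
        gc.collect()  # clean up any cycle left behind by the check


print(freed_without_gc(CounterFromLambda))  # False: only the cycle collector can free it
print(freed_without_gc(CounterFromMethod))  # True: freed as soon as the last reference goes
```

Anything large hanging off the first kind of object stays alive until the next cycle collection, which is consistent with the memory growth in the plots below.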

Description

Related to #4014. To benchmark, I'm using the following script (very similar to the one in the linked issue, but uses 500 timesteps, a timestepper object, and removes explicit GC calls):

from firedrake import *
from firedrake.adjoint import *
# from memory_profiler import profile

def test():
    T_c, rf = rf_generator()
    rf.fwd_call = profile(rf.__call__)
    rf.derivative = profile(rf.derivative)

    for i in range(5):
        rf.fwd_call(T_c)
        rf.derivative()

@profile
def rf_generator(checkpoint_to_disk=True):
    tape = get_working_tape()
    tape.clear_tape()
    continue_annotation()

    mesh = RectangleMesh(100, 100, 1.0, 1.0)

    if checkpoint_to_disk:
        enable_disk_checkpointing()
        mesh = checkpointable_mesh(mesh)

    V = VectorFunctionSpace(mesh, "CG", 2)
    Q = FunctionSpace(mesh, "CG", 1)

    # Define the rotation vector field
    X = SpatialCoordinate(mesh)

    w = Function(V, name="rotation").interpolate(as_vector([-X[1] - 0.5, X[0] - 0.5]))
    T_c = Function(Q, name="control")
    T = Function(Q, name="Temperature")
    T_c.interpolate(0.1 * exp(-0.5 * ((X - as_vector((0.75, 0.5))) / Constant(0.1)) ** 2))
    control = Control(T_c)
    T.assign(T_c)

    # for i in ts:
    for i in tape.timestepper(iter(range(500))):
        T.interpolate(T + inner(grad(T), w) * Constant(0.0001))

    objective = assemble(T**2 * dx)

    pause_annotation()
    return T_c, ReducedFunctional(objective, control)


if __name__ == "__main__":
    test()

I'm also running this on the #4020 branch to automatically enable the SingleDiskStorageSchedule and handle the leak of function within CheckpointFunction. On the pyadjoint side, I am using dolfin-adjoint/pyadjoint#194.

Here's a pretty simple mprof comparison:
[image: mprof comparison]
Black is the base, without this branch. Blue is the base with a gc.collect() call inside Block.recompute (very eager and expensive, and it doesn't apply to the derivative). Red is the result with this branch, without any explicit GC. Individual plots follow, but because each one is rescaled you have to look a bit more closely to compare them.

Base plot

image

GC plot

image

This PR

image

I think there is still a bit left to do in terms of making expensive allocations get freed through refcounting, and perhaps there's a more elegant way of implementing the change proposed here.
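
One way to look for the remaining cases is to ask the collector what it had to clean up. The sketch below uses a hypothetical helper, `cyclic_garbage_types`, which is not part of Firedrake or pyadjoint. It relies on the standard `gc.DEBUG_SAVEALL` flag to keep unreachable objects in `gc.garbage` so they can be counted by type. `SelfCaching` is an illustrative class that caches a reference to itself.

```python
import gc
from collections import Counter


def cyclic_garbage_types(fn):
    """Run `fn` and count, by type name, the objects only the cycle collector could free."""
    gc.collect()  # start from a clean slate
    was_enabled = gc.isenabled()
    old_flags = gc.get_debug()
    gc.disable()  # no automatic collections while `fn` runs
    start = len(gc.garbage)
    try:
        fn()
        # With DEBUG_SAVEALL, unreachable objects are appended to gc.garbage
        # instead of being freed, so they can be inspected.
        gc.set_debug(gc.DEBUG_SAVEALL)
        gc.collect()
        counts = Counter(type(o).__name__ for o in gc.garbage[start:])
    finally:
        gc.set_debug(old_flags)
        del gc.garbage[start:]  # release the saved objects again
        if was_enabled:
            gc.enable()
    return counts


class SelfCaching:
    """Hypothetical object that caches a reference to itself."""

    def __init__(self):
        self._cached = self  # self -> self: a one-object cycle


def make_garbage():
    SelfCaching()  # dropped immediately, but the cycle keeps it alive


counts = cyclic_garbage_types(make_garbage)
print(counts["SelfCaching"])
```

Wrapping a single forward or derivative evaluation in a helper like this would list the types that survive until a collection, which are the candidates for the same kind of fix.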


github-actions bot commented Feb 12, 2025

Firedrake real: 8124 ran, 7374 passed ✅, 716 skipped ⏭️, 34 failed ❌

github-actions bot commented Feb 12, 2025

Firedrake complex: 8154 ran, 6574 passed ✅, 1556 skipped ⏭️, 24 failed ❌

Doesn't make sense to cache a reference to self, just return self.

connorjward (Contributor) left a comment

This is an excellent spot!

I have absolutely no idea why this is failing tests though... AFAICT the changes you have made shouldn't impact the rest of the code.

Comment on lines +88 to +90
# we want to avoid setting dat_version = self._vec.stateGet
# in __init__ to not allocate underlying storage until necessary
# (which happens when _vec is accessed).

This comment probably isn't necessary. It adds context for the review but won't be helpful down the line.
