Skip to content

Conversation

@angus-g
Copy link
Contributor

@angus-g angus-g commented Feb 12, 2025

With an associated PETSc Vec, VecAccessMixin deferred its version property to a lambda to avoid allocating the storage until necessary. Unfortunately, this lambda creates a reference cycle to self for all users of the VecAccessMixin. Given that counter accesses should be relatively infrequent, it seems fine to look up the counter type within the method itself.

Description

Related to #4014. To benchmark, I'm using the following script (very similar to the one in the linked issue, but uses 500 timesteps, a timestepper object, and removes explicit GC calls):

from firedrake import *
from firedrake.adjoint import *
# from memory_profiler import profile

def test():
    T_c, rf = rf_generator()
    rf.fwd_call = profile(rf.__call__)
    rf.derivative = profile(rf.derivative)

    for i in range(5):
        rf.fwd_call(T_c)
        rf.derivative()

@profile
def rf_generator(checkpoint_to_disk=True):
    tape = get_working_tape()
    tape.clear_tape()
    continue_annotation()

    mesh = RectangleMesh(100, 100, 1.0, 1.0)

    if checkpoint_to_disk:
        enable_disk_checkpointing()
        mesh = checkpointable_mesh(mesh)

    V = VectorFunctionSpace(mesh, "CG", 2)
    Q = FunctionSpace(mesh, "CG", 1)

    # Define the rotation vector field
    X = SpatialCoordinate(mesh)

    w = Function(V, name="rotation").interpolate(as_vector([-X[1] - 0.5, X[0] - 0.5]))
    T_c = Function(Q, name="control")
    T = Function(Q, name="Temperature")
    T_c.interpolate(0.1 * exp(-0.5 * ((X - as_vector((0.75, 0.5))) / Constant(0.1)) ** 2))
    control = Control(T_c)
    T.assign(T_c)

    # for i in ts:
    for i in tape.timestepper(iter(range(500))):
        T.interpolate(T + inner(grad(T), w) * Constant(0.0001))

    objective = assemble(T**2 * dx)

    pause_annotation()
    return T_c, ReducedFunctional(objective, control)


if __name__ == "__main__":
    test()

I'm also running this on the #4020 branch to automatically enable the SingleDiskStorageSchedule and handle the leak of function within CheckpointFunction. On the pyadjoint side, I am using dolfin-adjoint/pyadjoint#194.

Here's a pretty simple mprof comparison:
image
In black is the base, without this branch. In blue is the base, but with gc.collect() within Block.recompute (very eager, and expensive, also doesn't apply to the derivative). In red is the result with this branch, without any explicit gc. Individual plots follow, but the rescaling means you have to look a bit more closely.

Base plot

image

GC plot

image

This PR

image

I think there is a still a bit left out there in terms of making expensive allocations delete through refcounting, and perhaps there's a more elegant way of implementing the change proposed here.

@github-actions
Copy link

github-actions bot commented Feb 12, 2025

TestsPassed ✅Skipped ⏭️Failed ❌
Firedrake real8238 ran7524 passed714 skipped0 failed

@github-actions
Copy link

github-actions bot commented Feb 12, 2025

TestsPassed ✅Skipped ⏭️Failed ❌
Firedrake complex8317 ran6614 passed1703 skipped0 failed

Copy link
Contributor

@connorjward connorjward left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is an excellent spot!

I have absolutely no idea why this is failing tests though... AFAICT the changes you have made shouldn't impact the rest of the code.

@angus-g
Copy link
Contributor Author

angus-g commented Feb 17, 2025

I had to do a few wider modifications around the inheritance of AbstractDat and VecAccessMixin. Hopefully it passes the tests now.

@angus-g angus-g force-pushed the angus-g/vecaccess-cycle branch 3 times, most recently from afe09ea to 5edefe0 Compare February 19, 2025 04:45
angus-g added 3 commits March 4, 2025 08:43
With an associated PETSc Vec, VecAccessMixin deferred its version
property to a lambda to avoid allocating the storage until necessary.
Unfortunately, this lambda creates a reference cycle to self for all
users of the VecAccessMixin. Given that counter accesses should be
relatively infrequent, it seems fine to look up the counter type within
the method itself.
Doesn't make sense to cache a reference to self, just return self.
The inheritance chain for Dat and Global puts VecAccessMixin
(rightly) behind DataCarrier. This means that by MRO, the
increment_dat_version method provided on DataCarrier will be used,
which is a null operation. I think this makes sense, given that not all
classes use this. However, because we're providing
increment_dat_version as an override through VecAccessMixin, we need to
explicitly refer to it in the inheriting classes.
@angus-g angus-g force-pushed the angus-g/vecaccess-cycle branch from 5edefe0 to bd2f8ad Compare March 3, 2025 21:43
@github-actions
Copy link

github-actions bot commented Mar 3, 2025

TestsPassed ✅Skipped ⏭️Failed ❌
Firedrake default8201 ran7476 passed725 skipped0 failed

Copy link
Contributor

@connorjward connorjward left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry for taking so long to review this. I think that your changes demonstrate that VecAccessMixin is a poor class for this as some things have a dat_version but no Vec.

Otherwise updating a dat through a view doesn't increment the version.

Co-authored-by: Connor Ward <[email protected]>
@dham dham enabled auto-merge (squash) March 5, 2025 17:06
@dham dham merged commit b35aab5 into master Mar 5, 2025
15 checks passed
@dham dham deleted the angus-g/vecaccess-cycle branch March 5, 2025 22:23
connorjward added a commit that referenced this pull request Mar 17, 2025
* Remove reference cycle in VecAccessMixin

With an associated PETSc Vec, VecAccessMixin deferred its version
property to a lambda to avoid allocating the storage until necessary.
Unfortunately, this lambda creates a reference cycle to self for all
users of the VecAccessMixin. Given that counter accesses should be
relatively infrequent, it seems fine to look up the counter type within
the method itself.

* Don't cache topological on CoordinatelessFunction

Doesn't make sense to cache a reference to self, just return self.

* Convert increment_dat_version to abstract on DataCarrier

The inheritance chain for Dat and Global puts VecAccessMixin
(rightly) behind DataCarrier. This means that by MRO, the
increment_dat_version method provided on DataCarrier will be used,
which is a null operation. I think this makes sense, given that not all
classes use this. However, because we're providing
increment_dat_version as an override through VecAccessMixin, we need to
explicitly refer to it in the inheriting classes.

* Update pyop2/types/dat.py

Otherwise updating a dat through a view doesn't increment the version.

Co-authored-by: Connor Ward <[email protected]>

---------

Co-authored-by: David A. Ham <[email protected]>
Co-authored-by: Connor Ward <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants