Optional garbage collection and CheckpointManager._global_deps#187
Is it equivalent to instead drop zero output |
I believe it is different. I noticed that memory usage kept growing, mainly during recomputation, even after I cleared the checkpoint with block_variable._checkpoint = None. After some discussions here, my hypothesis is that Python might not be tracking all objects in memory properly. So I am only giving the user the option to invoke the garbage collector, which appears to help.
connorjward
left a comment
I think this could do with a lot more explanation. This is very complicated so adding some substantial comments and expanding docstrings would be extremely helpful.
The code style seems fine.
pyadjoint/checkpointing.py
Outdated
| Args: | ||
| schedule (checkpoint_schedules.schedule): A schedule provided by the `checkpoint_schedules` package. | ||
| tape (Tape): A list of blocks :class:`Block` instances. | ||
| gc_timestep_frequency (int): The timestep frequency for garbage collection. |
This could be clearer. Perhaps "the number of timesteps between garbage collections"
Also it should state that if None then no collection is done, or similar.
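A minimal sketch of the semantics suggested above, assuming the parameter counts timesteps and that `None` disables collection. The helper name `maybe_collect` is hypothetical and is not taken from the PR; only `gc.collect()` is a real standard-library call.

```python
import gc


def maybe_collect(timestep, gc_timestep_frequency=None):
    """Run the garbage collector every `gc_timestep_frequency` timesteps.

    Args:
        timestep (int): Index of the timestep that has just been processed.
        gc_timestep_frequency (int or None): The number of timesteps between
            garbage collections. If None, no collection is triggered.

    Returns:
        bool: True if a collection was triggered for this timestep.
    """
    if gc_timestep_frequency is None:
        # Collection is opt-in: do nothing unless a frequency was given.
        return False
    if timestep % gc_timestep_frequency == 0:
        gc.collect()
        return True
    return False
```

Written this way, the docstring makes explicit that collection is scheduled by the frequency the user passes in, and that it stays off by default.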
pyadjoint/checkpointing.py
Outdated
| # The user can manually invoke the garbage collector if Python fails to | ||
| # track and clean all checkpoint objects in memory properly. |
This is confusing: setting gc_timestep_frequency suggests that GC runs automatically, whereas this comment says the user invokes it manually.
pyadjoint/checkpointing.py
Outdated
| for deps in self.tape.timesteps[timestep - 1].checkpointable_state: | ||
| self._global_deps.add(deps) | ||
| else: | ||
| deps_to_clear = self._global_deps - self._global_deps.intersection( |
I think you might want set.difference https://docs.python.org/3/library/stdtypes.html#frozenset.difference
pyadjoint/checkpointing.py
Outdated
| # Check if the block variables stored at `self._global_deps` are still | ||
| # dependencies in the previous timestep. If not, will remove them from the | ||
| # global dependencies. | ||
| deps_to_clear = self._global_deps.difference(self._global_deps.intersection( |
I don't think you need to have the intersection here.
I will check this.
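For reference, a small check of the set identity behind this suggestion: subtracting `A ∩ B` from `A` gives the same result as subtracting `B` from `A`. The set contents below are made-up names used only for illustration.

```python
# Illustrative stand-ins for block variables; the names are hypothetical.
global_deps = {"mesh", "dt", "u_old"}
previous_deps = {"mesh", "dt", "u_new"}

# Form used in the diff under review.
with_intersection = global_deps.difference(global_deps.intersection(previous_deps))

# Shorter form: the intersection is redundant.
without_intersection = global_deps.difference(previous_deps)

assert with_intersection == without_intersection
```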
pyadjoint/checkpointing.py
Outdated
| # Clear the checkpoint once it is not a global dependency and should be stored | ||
| # only in the ``self.tape.timesteps`` checkpoints when needed. |
I'm afraid I don't quite understand what this means. Could you rephrase this?
Is this version easier to understand?
Block variables that are not global dependencies have their checkpoints stored at a self.tape timestep only when an action from the schedule requires it. We therefore have to clear the checkpoints of block variables that are not in self._global_deps.
Yeah I think that's good. Thanks.
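A toy sketch of the rule described in the rephrased comment, under the assumption that clearing a checkpoint means setting it to `None`. `ToyBlockVariable` and `clear_non_global_checkpoints` are hypothetical stand-ins and not pyadjoint classes or functions.

```python
class ToyBlockVariable:
    """Hypothetical stand-in for a block variable that holds a checkpoint."""

    def __init__(self, name, checkpoint):
        self.name = name
        self.checkpoint = checkpoint


def clear_non_global_checkpoints(block_variables, global_deps):
    """Clear checkpoints of variables that are not global dependencies.

    Returns the names of the variables whose checkpoints were cleared.
    """
    cleared = []
    for var in block_variables:
        if var not in global_deps and var.checkpoint is not None:
            # Not needed at every timestep, so drop the stored value; it can
            # be stored again when a schedule action requires it.
            var.checkpoint = None
            cleared.append(var.name)
    return cleared
```

Variables in `global_deps` keep their checkpoints, so they are not cleared and re-created at every timestep.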
Co-authored-by: Connor Ward <c.ward20@imperial.ac.uk>
|
I think something that could be worth looking at here is the circular reference between pyadjoint/pyadjoint/overloaded_type.py Lines 96 to 98 in c7d7392 In effect, any subclass of |
Thank you. I will investigate that. |
Is there any chance that this could be made a weakref so as to avoid this cycle? |
I think so. I think the BlockVariable should not prolong the lifetime of the Overloaded Type, so |
I already tried weakref for |
|
I had a hacky go at that before, which worked for the forward run of the tape (with some implementation ugliness). The underlying OverloadedType was deleted at some point before/during the adjoint call, so that might need some care. |
|
Looks easier to break the cycle on the other side, see #194 for an attempt. |
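To illustrate the general idea discussed in this thread, here is a self-contained sketch of how a weak back-reference avoids a reference cycle. The classes `Owner`, `StrongVariable`, and `WeakVariable` are hypothetical; this is not the pyadjoint code and not necessarily the approach taken in #194. The behaviour shown relies on CPython's reference counting.

```python
import gc
import weakref


class StrongVariable:
    """Holds a strong back-reference to its owner, which creates a cycle."""

    def __init__(self, owner):
        self.owner = owner


class WeakVariable:
    """Holds only a weak back-reference, so no cycle is created."""

    def __init__(self, owner):
        self._owner_ref = weakref.ref(owner)

    @property
    def owner(self):
        # Returns None once the owner has been freed.
        return self._owner_ref()


class Owner:
    def __init__(self, variable_cls):
        self.variable = variable_cls(self)


def freed_without_gc(variable_cls):
    """Return True if an Owner is freed by reference counting alone."""
    was_enabled = gc.isenabled()
    gc.disable()  # make sure the cyclic collector cannot interfere
    try:
        owner = Owner(variable_cls)
        probe = weakref.ref(owner)
        del owner
        return probe() is None
    finally:
        if was_enabled:
            gc.enable()
        gc.collect()  # clean up any cycle left behind
```

With the strong back-reference the object survives until the cyclic garbage collector runs, which is consistent with memory growing until `gc.collect()` is called; with the weak one it is released as soon as the last external reference is dropped.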
|
To keep you updated: I have tested this PR merged into PR 194 against PR 194 alone for Burgers' equation, using the following setup: 40,000 DoFs and 1,000 time steps. The chart below uses |
|
I will also test PR 4033 with the same example and add the results here. |
|
Now using firedrake PR 4033 merged into firedrake PR 4020 that automatically uses Again, I have tested this PR merged into pyadjoint PR 194 against PR 194 alone for Burgers' equation, using the following setup: 40,000 DoFs and 1,000 time steps. The black line shows the results for PR 194 alone, and the blue line shows this PR merged into PR 194 using |
connorjward
left a comment
Looks great. Very readable now.


PR Description
This PR introduces optional garbage collection during checkpointing, so that users can work around cases where Python fails to track and clean up checkpoint objects in memory.
Details of the experiment used to test manually triggered garbage collection
The black curve represents the scenario with garbage collection enabled, while the blue curve shows the case without garbage collection during checkpointing.

A private attribute, _global_deps, is introduced in the CheckpointManager class. This attribute stores dependencies that are used at each time step and are not time-dependent. A dependency stored in _global_deps will not be cleaned during checkpointing. This prevents unnecessary cleanup and re-creation of checkpoints for dependencies that do not change with time.
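One way to picture dependencies that are "used at each time step" is as the intersection of the per-timestep dependency sets. The sketch below assumes that reading; `time_independent_deps` and the example names are hypothetical and only loosely mirror the logic in the diff under review.

```python
def time_independent_deps(deps_per_timestep):
    """Return the dependencies that appear in every timestep.

    Args:
        deps_per_timestep: An iterable of sets, one per timestep.
    """
    global_deps = set()
    for index, deps in enumerate(deps_per_timestep):
        if index == 0:
            # Start from everything the first timestep depends on.
            global_deps.update(deps)
        else:
            # Drop anything that a later timestep no longer depends on.
            global_deps.intersection_update(deps)
    return global_deps
```

For example, if every timestep depends on a mesh and a timestep size but each has its own solution variable, only the mesh and the timestep size remain in the result.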