Skip to content

Optional garbage collection and CheckpointManager._global_deps#187

Merged
Ig-dolci merged 24 commits intomasterfrom
dolci/optional_gccollect
Feb 18, 2025
Merged

Optional garbage collection and CheckpointManager._global_deps#187
Ig-dolci merged 24 commits intomasterfrom
dolci/optional_gccollect

Conversation

@Ig-dolci
Copy link
Contributor

@Ig-dolci Ig-dolci commented Dec 13, 2024

PR Description

  • Enable the user to apply the garbage collection if necessary:
    This PR introduces garbage collection optional support during checkpointing to enable the user to handle the lack of Python to properly track and clean up checkpoint objects in memory.

Experiment details used to test the garbage collection imposed manually

  • Degree of Freedom (DoFs): 40,401
  • Test Type: Burgers test
  • Total Steps: 1,000

The black curve represents the scenario with garbage collection enabled, while the blue curve shows the case without garbage collection during checkpointing.
Memory Usage

  • Checkpoint Manager _global_deps:
    A private attribute, _global_deps, is introduced in the CheckpointManager class. This attribute stores dependencies that are used at each time step and are not time-dependent.
    • If a block variable is included in _global_deps, it will not be cleaned during checkpointing. This prevents unnecessary cleanup and re-creation of checkpoints for dependencies that do not change with time.

@jrmaddison
Copy link
Contributor

Is it equivalent to instead drop zero output Blocks on the tape for garbage collection? Or is this slightly different?

@Ig-dolci
Copy link
Contributor Author

Is it equivalent to instead drop zero output Blocks on the tape for garbage collection? Or is this slightly different?

I believe it is different. I noticed that mainly during the recomputation process, memory usage kept growing, even after I cleared the checkpoint using block_variable._checkpoint = None. After some discussions here, my hypothesis is that Python might not be tracking all objects in memory properly. So, I am only allowing the user to employ the garbage collector manually, which looks like is helping.

@Ig-dolci Ig-dolci changed the title Optional garbage collection and more... Optional garbage collection and CheckpointManager._global_deps Feb 5, 2025
@Ig-dolci Ig-dolci marked this pull request as ready for review February 5, 2025 17:59
Copy link
Contributor

@connorjward connorjward left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this could do with a lot more explanation. This is very complicated so adding some substantial comments and expanding docstrings would be extremely helpful.

The code style seems fine.

Args:
schedule (checkpoint_schedules.schedule): A schedule provided by the `checkpoint_schedules` package.
tape (Tape): A list of blocks :class:`Block` instances.
gc_timestep_frequency (int): The timestep frequency for garbage collection.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This could be clearer. Perhaps "the number of timesteps between garbage collections"

Also it should state that if None then no collection is done, or similar.

Comment on lines 88 to 89
# The user can manually invoke the garbage collector if Python fails to
# track and clean all checkpoint objects in memory properly.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is confusing because setting gc_timestep_frequency suggests that GC is being run automatically, whereas here you say manually

for deps in self.tape.timesteps[timestep - 1].checkpointable_state:
self._global_deps.add(deps)
else:
deps_to_clear = self._global_deps - self._global_deps.intersection(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

# Check if the block variables stored at `self._global_deps` are still
# dependencies in the previous timestep. If not, will remove them from the
# global dependencies.
deps_to_clear = self._global_deps.difference(self._global_deps.intersection(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think you need to have the intersection here.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But I could be wrong.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will check this.

Comment on lines 197 to 198
# Clear the checkpoint once it is not a global dependency and should be stored
# only in the ``self.tape.timesteps`` checkpoints when needed.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm afraid I don't quite understand what this means. Could you rephrase this?

Copy link
Contributor Author

@Ig-dolci Ig-dolci Feb 6, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Better to understand this text:

For no global dependencies, checkpoint storage occurs at a self.tape timestep only when required by an action from the schedule. Thus, we have to clear the checkpoint of block variables excluded from the self._global_deps.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah I think that's good. Thanks.

Ig-dolci and others added 2 commits February 6, 2025 17:29
Co-authored-by: Connor Ward <c.ward20@imperial.ac.uk>
@angus-g
Copy link
Contributor

angus-g commented Feb 11, 2025

I think something that could be worth looking at here is the circular reference between OverloadedType and BlockVariable:

def create_block_variable(self):
self.block_variable = BlockVariable(self)
return self.block_variable

In effect, any subclass of OverloadedType can only be deleted after garbage collection, not reference counting. I think this means that several data arrays sit around for longer than they should, rather than being deleted when their owner goes out of scope.

@Ig-dolci
Copy link
Contributor Author

I think something that could be worth looking at here is the circular reference between OverloadedType and BlockVariable:

def create_block_variable(self):
self.block_variable = BlockVariable(self)
return self.block_variable

In effect, any subclass of OverloadedType can only be deleted after garbage collection, not reference counting. I think this means that several data arrays sit around for longer than they should, rather than being deleted when their owner goes out of scope.

Thank you. I will investigate that.

@connorjward
Copy link
Contributor

In effect, any subclass of OverloadedType can only be deleted after garbage collection, not reference counting. I think this means that several data arrays sit around for longer than they should, rather than being deleted when their owner goes out of scope.

Is there any chance that this could be made a weakref so as to avoid this cycle?

@dham
Copy link
Member

dham commented Feb 11, 2025

In effect, any subclass of OverloadedType can only be deleted after garbage collection, not reference counting. I think this means that several data arrays sit around for longer than they should, rather than being deleted when their owner goes out of scope.

Is there any chance that this could be made a weakref so as to avoid this cycle?

I think so. I think the BlockVariable should not prolong the lifetime of the Overloaded Type, so BlockVariable.output should be a weakref. We have to be careful that we're not abusing Blockvariable.output anywhere (i.e. relying on it as a source of information after the operation has been taped).

@Ig-dolci
Copy link
Contributor Author

In effect, any subclass of OverloadedType can only be deleted after garbage collection, not reference counting. I think this means that several data arrays sit around for longer than they should, rather than being deleted when their owner goes out of scope.

Is there any chance that this could be made a weakref so as to avoid this cycle?

I already tried weakref for BlockVariables.output, but I hit a number of errors I do not remember now. It is a very careful work to do.

@angus-g
Copy link
Contributor

angus-g commented Feb 11, 2025

I had a hacky go at that before, which worked for the forward run of the tape (with some implementation ugliness). The underlying OverloadedType was deleted at some point before/during the adjoint call, so that might need some care.

@jrmaddison
Copy link
Contributor

Looks easier to break the cycle on the other side, see #194 for an attempt.

@Ig-dolci
Copy link
Contributor Author

Ig-dolci commented Feb 12, 2025

To make you updated:

I have tested this PR merged to the PR 194 against the PR 194 (only) for Burgers' equation using the following setup: 40,000 DoFs and 1,000 time steps.

The chart below uses SingleDiskStorageSchedule with the fixing from PR 4020 . The black line represents the results related to the PR 194 (only), and the blue line represents this PR merged to the PR 194 using gc_timestep_frequency=100.
mem_gc

@Ig-dolci
Copy link
Contributor Author

I will also check the PR 4033 using the same example and add it here.

@Ig-dolci
Copy link
Contributor Author

Now using firedrake PR 4033 merged to firedrake PR 4020 that automatically uses SingleDiskStorageSchedule.

Again, I have tested this PR merged to the pyadjoint PR 194 against the PR 194 (only) for Burgers' equation using the following setup: 40,000 DoFs and 1,000 time steps.

The black line represents the results related to the PR 194 (only), and the blue line represents this PR merged to the PR 194 using gc_timestep_frequency=100.
final

# Check if the block variables stored at `self._global_deps` are still
# dependencies in the previous timestep. If not, will remove them from the
# global dependencies.
deps_to_clear = self._global_deps.difference(self._global_deps.intersection(
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will check this.

@Ig-dolci Ig-dolci requested a review from connorjward February 18, 2025 11:41
Copy link
Contributor

@connorjward connorjward left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks great. Very readable now.

@Ig-dolci Ig-dolci merged commit c7bf2ec into master Feb 18, 2025
1 of 2 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants