The pyadjoint tape references backend variables. This means that any memory allocated for the forward variables during the forward calculation remains referenced by the tape. This can prevent checkpointing from reducing memory usage.
from firedrake import *
from firedrake.adjoint import *
from checkpoint_schedules import MultistageCheckpointSchedule
# Reproducer: the pyadjoint tape keeps references to the forward Functions,
# so enabling checkpointing does not reduce how many of them stay alive.

# Number of timesteps.
N = 100

mesh = UnitIntervalMesh(1)
space = FunctionSpace(mesh, "Lagrange", 1)

tape = get_working_tape()
# NOTE(review): arguments presumably are (max_n, snapshots_in_ram,
# snapshots_on_disk), i.e. at most 3 checkpoints held in memory -- confirm
# against the checkpoint_schedules documentation.
tape.enable_checkpointing(MultistageCheckpointSchedule(N, 3, 0))

u = Function(space, name="u").interpolate(Constant(2.0))

continue_annotation()
for _ in tape.timestepper(iter(range(N))):
    # Each step writes u + u into a freshly allocated Function.
    u_ = Function(space, name="u")
    assemble(Interpolate(u + u, space), tensor=u_)
    u = u_
    # Drop the script's own extra reference; only `u` (and the tape) remain.
    del u_
pause_annotation()
# Drop the last script-level reference, so any Functions still reachable
# are reachable through the tape.
del u

# Collect the distinct Function ids (`count()`) that the tape's blocks still
# reference, separately for block dependencies and block outputs.
deps = set()
outputs = set()
for block in tape._blocks:
    for dep in block.get_dependencies():
        if isinstance(dep.output, Function):
            deps.add(dep.output.count())
    for dep in block.get_outputs():
        if isinstance(dep.output, Function):
            outputs.add(dep.output.count())

# Observed output: len(deps)=100 and len(outputs)=100, i.e. one retained
# Function per timestep despite the checkpoint schedule.
print(f"{len(deps)=}")
print(f"{len(outputs)=}")
len(deps)=100
len(outputs)=100
The pyadjoint tape references backend variables. This means that any memory allocated for the forward variables during the forward calculation remains referenced by the tape. This can prevent checkpointing from reducing memory usage.
Example:
the script above leads to the output shown after it.