Description
We are debugging a large-scale Stokes optimisation problem that eventually runs out of memory (g-adopt/g-adopt#160). Since we are dealing with millions of degrees of freedom, we rely on checkpointing to disk to manage memory. While testing smaller reproducer cases, we see unexpected memory growth even after the tape is generated and throughout forward and backward passes.
What We Expected vs. What’s Happening
• Expected: Once the tape is populated, and after the first calls to ReducedFunctional.__call__ and ReducedFunctional.derivative, memory usage should stay constant.
• Actual: Memory keeps increasing steadily with every forward and derivative call.
• In our minimal reproducer, checkpointing to disk actually increases memory usage!!!
Minimal Reproducer
Using memory_profiler, I profile repeated calls to ReducedFunctional.__call__ and ReducedFunctional.derivative. Simply run mprof run ...
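As a rough cross-check alongside mprof, the expectation above (flat memory after the first forward and derivative calls) can also be tested numerically. The sketch below is only an illustration: the helper names `peak_rss_per_iteration` and `growth` are hypothetical, it relies on the Unix-only standard-library `resource` module, and `ru_maxrss` is a high-water mark (kilobytes on Linux, bytes on macOS), so it can show growth but not memory being released.

```python
import gc
import resource


def peak_rss_per_iteration(step, n_iter=5):
    """Call `step` n_iter times; return the process peak RSS after each call.

    Hypothetical helper. ru_maxrss is a high-water mark, reported in
    kilobytes on Linux and in bytes on macOS.
    """
    peaks = []
    for _ in range(n_iter):
        gc.collect()  # drop collectable Python garbage before each call
        step()
        peaks.append(resource.getrusage(resource.RUSAGE_SELF).ru_maxrss)
    return peaks


def growth(peaks):
    """Differences between consecutive peaks; all zeros means no new peak."""
    return [b - a for a, b in zip(peaks, peaks[1:])]
```

With the reproducer's `rf` and `T_c`, `step` could be something like `lambda: (rf(T_c), rf.derivative())`; non-zero entries in `growth(...)` after the first iteration would indicate that each pass sets a new memory peak.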
Code for Reproduction
from firedrake import *
from firedrake.adjoint import *
import gc


def test():
    T_c, rf = rf_generator()
    for i in range(5):
        gc.collect()
        rf.__call__(T_c)
        gc.collect()
        rf.derivative()


def rf_generator():
    tape = get_working_tape()
    tape.clear_tape()
    continue_annotation()
    enable_disk_checkpointing()
    mesh = RectangleMesh(100, 100, 1.0, 1.0)
    mesh = checkpointable_mesh(mesh)
    V = VectorFunctionSpace(mesh, "CG", 2)
    Q = FunctionSpace(mesh, "CG", 1)
    X = SpatialCoordinate(mesh)
    w = Function(V, name="rotation").interpolate(as_vector([-X[1] - 0.5, X[0] - 0.5]))
    T_c = Function(Q, name="control")
    T = Function(Q, name="Temperature")
    T_c.interpolate(0.1 * exp(-0.5 * ((X - as_vector((0.75, 0.5))) / Constant(0.1)) ** 2))
    control = Control(T_c)
    T.assign(T_c)
    for i in range(20):
        T.interpolate(T + inner(grad(T), w) * Constant(0.0001))
    objective = assemble(T**2 * dx)
    pause_annotation()
    return T_c, ReducedFunctional(objective, control)


if __name__ == "__main__":
    test()

With checkpointing to disk:
Am I missing something here, or is there an actual leak, especially when checkpointing to disk?

