Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
a5126e9
deepcopy name and count function attrinbutes
Ig-dolci Feb 8, 2025
1cc0147
Avoid referencing function.
Ig-dolci Feb 8, 2025
496cfb0
Avoid reference the callable function name; Disk checkpoint using che…
Ig-dolci Feb 10, 2025
d41c820
Copy is not necessary
Ig-dolci Feb 10, 2025
d4faac6
flake8
Ig-dolci Feb 10, 2025
f75a906
Adding firedrake_adjoint tests from pyadjoint
Ig-dolci Feb 10, 2025
46e245d
flake8
Ig-dolci Feb 10, 2025
998db1d
lint yml
Ig-dolci Feb 10, 2025
d997152
Update .github/workflows/build.yml
Ig-dolci Feb 10, 2025
ddeae15
enable other schedules
Ig-dolci Feb 10, 2025
42f728a
lint
Ig-dolci Feb 10, 2025
c8bc486
add pytest fixture; use context manager in test_tao_bounds
Ig-dolci Feb 10, 2025
b94f290
Fix test_shape_derivatives; clear package data in fixture
Ig-dolci Feb 10, 2025
ee97c0a
linting; remove unneeded clear tapes.
Ig-dolci Feb 10, 2025
475c57f
Add checkpoint_schedules in the requirements
Ig-dolci Feb 10, 2025
6f460d6
Merge branch 'dolci/move_pyadjoint_tests_to_firedrake' into dolci/fix…
Ig-dolci Feb 10, 2025
30c533e
Enable disk storage for other shedules
Ig-dolci Feb 11, 2025
806a088
Update docs
Ig-dolci Feb 11, 2025
90f451f
Please, no mutable arguments
Ig-dolci Feb 11, 2025
2d62f68
Testing disk checkpointing
Ig-dolci Feb 11, 2025
780021a
fix conflict
Ig-dolci Feb 12, 2025
93ae83f
Fix docs; move disk checkpoint test from output to adjoint
Ig-dolci Feb 13, 2025
e7a83b3
Fix docs
Ig-dolci Feb 13, 2025
76b28f5
remove output test
Ig-dolci Feb 13, 2025
b3aa347
Test with the fixings
Ig-dolci Feb 13, 2025
56b64e4
Merge branch 'master' into dolci/fix_mem_leak_diskcheckpointing
Ig-dolci Feb 17, 2025
0c9fdfb
skip a disk checkpointing test on disk
Ig-dolci Feb 17, 2025
aef34f3
Revert the API modification for disk checkpointing
Ig-dolci Feb 21, 2025
f3ae1ee
Solver master conflict
Ig-dolci Mar 5, 2025
2321fa1
Keep master burgers test
Ig-dolci Mar 5, 2025
c36f443
Update burgers test
Ig-dolci Mar 5, 2025
0a4a518
Keep master burgers test
Ig-dolci Mar 5, 2025
92ecc25
fixing
Ig-dolci Mar 5, 2025
3712f46
Update tests/firedrake/adjoint/test_disk_checkpointing.py
Ig-dolci Mar 12, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 20 additions & 7 deletions firedrake/adjoint_utils/checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,23 @@ def __exit__(self, *args):


def enable_disk_checkpointing(dirname=None, comm=COMM_WORLD, cleanup=True):
"""Add a DiskCheckpointer to the current tape and switch on
disk checkpointing.
"""Add a DiskCheckpointer to the current tape.

Disk checkpointing is fully enabled by calling::

enable_disk_checkpointing()
tape = get_working_tape()
tape.enable_checkpointing(schedule)

Here, ``schedule`` is a checkpointing schedule from the `checkpoint_schedules
package <https://www.firedrakeproject.org/checkpoint_schedules/>`_. For example,
to checkpoint every timestep on disk, use::

from checkpoint_schedules import SingleDiskStorageSchedule
schedule = SingleDiskStorageSchedule()

`checkpoint_schedules` provides other schedules for checkpointing to memory, disk,
or a combination of both.

Parameters
----------
Expand All @@ -70,8 +85,6 @@ def enable_disk_checkpointing(dirname=None, comm=COMM_WORLD, cleanup=True):
tape = get_working_tape()
if "firedrake" not in tape._package_data:
tape._package_data["firedrake"] = DiskCheckpointer(dirname, comm, cleanup)
if not disk_checkpointing():
continue_disk_checkpointing()


def disk_checkpointing():
Expand Down Expand Up @@ -118,7 +131,7 @@ def __del__(self):


class DiskCheckpointer(TapePackageData):
"""Manger for the disk checkpointing process.
"""Manager for the disk checkpointing process.

Parameters
----------
Expand Down Expand Up @@ -278,7 +291,7 @@ class CheckpointFunction(CheckpointBase, OverloadedType):

def __init__(self, function):
from firedrake.checkpointing import CheckpointFile
self.name = function.name
self.name = function.name()
self.mesh = function.function_space().mesh()
self.file = current_checkpoint_file()

Expand Down Expand Up @@ -310,7 +323,7 @@ def restore(self):
function = infile.load_function(self.mesh, self.stored_name,
idx=self.stored_index)
return type(function)(function.function_space(),
function.dat, name=self.name(), count=self.count)
function.dat, name=self.name, count=self.count)

def _ad_restore_at_checkpoint(self, checkpoint):
return checkpoint.restore()
Expand Down
86 changes: 59 additions & 27 deletions tests/firedrake/adjoint/test_disk_checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
from firedrake.__future__ import *
from firedrake.adjoint import *
from firedrake.adjoint_utils.checkpointing import disk_checkpointing
from checkpoint_schedules import SingleDiskStorageSchedule
import numpy as np
import os
from checkpoint_schedules import SingleDiskStorageSchedule


@pytest.fixture(autouse=True)
Expand All @@ -30,7 +30,7 @@ def handle_annotation():
pause_disk_checkpointing()


def adjoint_example(fine, coarse):
def adjoint_example(fine, coarse=None):
dg_space = FunctionSpace(fine, "DG", 1)
cg_space = FunctionSpace(fine, "CG", 2)
W = dg_space * cg_space
Expand All @@ -46,16 +46,18 @@ def adjoint_example(fine, coarse):
v.assign(m)
# FunctionSplitBlock, GenericSolveBlock
u.project(v)
if coarse:
dg_space_c = FunctionSpace(coarse, "DG", 1)
cg_space_c = FunctionSpace(coarse, "CG", 2)

dg_space_c = FunctionSpace(coarse, "DG", 1)
cg_space_c = FunctionSpace(coarse, "CG", 2)

# SupermeshProjectBlock
u_c = project(u, dg_space_c)
v_c = project(v, cg_space_c)
# SupermeshProjectBlock
u_c = project(u, dg_space_c)
v_c = project(v, cg_space_c)

# AssembleBlock
J = assemble((u_c - v_c)**2 * dx)
# AssembleBlock
J = assemble((u_c - v_c)**2 * dx)
else:
J = assemble((u - v)**2 * dx)

Jhat = ReducedFunctional(J, Control(m))

Expand All @@ -77,36 +79,66 @@ def adjoint_example(fine, coarse):


@pytest.mark.skipcomplex
@pytest.mark.parametrize("checkpoint_schedule", [True, False])
def test_disk_checkpointing(checkpoint_schedule):
def test_disk_checkpointing():
# Use a Firedrake Tape subclass that supports disk checkpointing.
set_working_tape(Tape())
tape = get_working_tape()
tape.clear_tape()
enable_disk_checkpointing()
if checkpoint_schedule:
tape.enable_checkpointing(SingleDiskStorageSchedule())
tape.enable_checkpointing(SingleDiskStorageSchedule())
fine = checkpointable_mesh(UnitSquareMesh(10, 10, name="fine"))
coarse = checkpointable_mesh(UnitSquareMesh(4, 4, name="coarse"))
J_disk, grad_J_disk = adjoint_example(fine, coarse)
coarse = checkpointable_mesh(UnitSquareMesh(6, 6, name="coarse"))
J_disk, grad_J_disk = adjoint_example(fine, coarse=coarse)

assert disk_checkpointing() is False

if checkpoint_schedule:
assert disk_checkpointing() is False
tape.clear_tape()
if not checkpoint_schedule:
pause_disk_checkpointing()

J_mem, grad_J_mem = adjoint_example(fine, coarse)
J_mem, grad_J_mem = adjoint_example(fine, coarse=coarse)

assert np.allclose(J_disk, J_mem)
assert np.allclose(assemble((grad_J_disk - grad_J_mem)**2*dx), 0.0)


@pytest.mark.skipcomplex
@pytest.mark.parallel(nprocs=3)
def test_disk_checkpointing_parallel():
# Use a Firedrake Tape subclass that supports disk checkpointing.
set_working_tape(Tape())
tape = get_working_tape()
tape.clear_tape()
continue_annotation()
enable_disk_checkpointing()
tape.enable_checkpointing(SingleDiskStorageSchedule())
mesh = checkpointable_mesh(UnitSquareMesh(10, 10))
J_disk, grad_J_disk = adjoint_example(mesh)

assert disk_checkpointing() is False
tape.clear_tape()
J_mem, grad_J_mem = adjoint_example(mesh)
assert np.allclose(J_disk, J_mem)
assert np.allclose(assemble((grad_J_disk - grad_J_mem)**2*dx), 0.0)


@pytest.mark.skipcomplex
def test_disk_checkpointing_error():
def test_disk_checkpointing_successive_writes():
from firedrake.adjoint import checkpointable_mesh
tape = get_working_tape()
# check the raise of the exception
with pytest.raises(RuntimeError):
tape.enable_checkpointing(SingleDiskStorageSchedule())
assert disk_checkpointing_callback["firedrake"] == "Please call enable_disk_checkpointing() "\
"before checkpointing on the disk."
tape.clear_tape()
enable_disk_checkpointing()
tape.enable_checkpointing(SingleDiskStorageSchedule())

mesh = checkpointable_mesh(UnitSquareMesh(1, 1))

cg_space = FunctionSpace(mesh, "CG", 1)
u = Function(cg_space, name='u')
v = Function(cg_space, name='v')

u.assign(1.)
v.assign(v + 2.*u)
v.assign(v + 3.*u)

J = assemble(v*dx)
Jhat = ReducedFunctional(J, Control(u))
assert np.allclose(J, Jhat(Function(cg_space).interpolate(1.)))
assert disk_checkpointing() is False
106 changes: 0 additions & 106 deletions tests/firedrake/output/test_adjoint_disk_checkpointing.py

This file was deleted.

Loading