diff --git a/firedrake/adjoint_utils/checkpointing.py b/firedrake/adjoint_utils/checkpointing.py index 0ab337e17b..cd377afa49 100644 --- a/firedrake/adjoint_utils/checkpointing.py +++ b/firedrake/adjoint_utils/checkpointing.py @@ -50,8 +50,23 @@ def __exit__(self, *args): def enable_disk_checkpointing(dirname=None, comm=COMM_WORLD, cleanup=True): - """Add a DiskCheckpointer to the current tape and switch on - disk checkpointing. + """Add a DiskCheckpointer to the current tape. + + Disk checkpointing is fully enabled by calling:: + + enable_disk_checkpointing() + tape = get_working_tape() + tape.enable_checkpointing(schedule) + + Here, ``schedule`` is a checkpointing schedule from the `checkpoint_schedules + package `_. For example, + to checkpoint every timestep on disk, use:: + + from checkpoint_schedules import SingleDiskStorageSchedule + schedule = SingleDiskStorageSchedule() + + `checkpoint_schedules` provides other schedules for checkpointing to memory, disk, + or a combination of both. Parameters ---------- @@ -70,8 +85,6 @@ def enable_disk_checkpointing(dirname=None, comm=COMM_WORLD, cleanup=True): tape = get_working_tape() if "firedrake" not in tape._package_data: tape._package_data["firedrake"] = DiskCheckpointer(dirname, comm, cleanup) - if not disk_checkpointing(): - continue_disk_checkpointing() def disk_checkpointing(): @@ -118,7 +131,7 @@ def __del__(self): class DiskCheckpointer(TapePackageData): - """Manger for the disk checkpointing process. + """Manager for the disk checkpointing process. Parameters ---------- @@ -278,7 +291,7 @@ class CheckpointFunction(CheckpointBase, OverloadedType): def __init__(self, function): from firedrake.checkpointing import CheckpointFile - self.name = function.name + self.name = function.name() self.mesh = function.function_space().mesh() self.file = current_checkpoint_file() @@ -310,7 +323,7 @@ def restore(self): function = infile.load_function(self.mesh, self.stored_name, idx=self.stored_index) return type(function)(function.function_space(), - function.dat, name=self.name(), count=self.count) + function.dat, name=self.name, count=self.count) def _ad_restore_at_checkpoint(self, checkpoint): return checkpoint.restore() diff --git a/tests/firedrake/adjoint/test_disk_checkpointing.py b/tests/firedrake/adjoint/test_disk_checkpointing.py index fc59801d83..3edcf904d5 100644 --- a/tests/firedrake/adjoint/test_disk_checkpointing.py +++ b/tests/firedrake/adjoint/test_disk_checkpointing.py @@ -4,9 +4,9 @@ from firedrake.__future__ import * from firedrake.adjoint import * from firedrake.adjoint_utils.checkpointing import disk_checkpointing +from checkpoint_schedules import SingleDiskStorageSchedule import numpy as np import os -from checkpoint_schedules import SingleDiskStorageSchedule @pytest.fixture(autouse=True) @@ -30,7 +30,7 @@ def handle_annotation(): pause_disk_checkpointing() -def adjoint_example(fine, coarse): +def adjoint_example(fine, coarse=None): dg_space = FunctionSpace(fine, "DG", 1) cg_space = FunctionSpace(fine, "CG", 2) W = dg_space * cg_space @@ -46,16 +46,18 @@ def adjoint_example(fine, coarse): v.assign(m) # FunctionSplitBlock, GenericSolveBlock u.project(v) + if coarse: + dg_space_c = FunctionSpace(coarse, "DG", 1) + cg_space_c = FunctionSpace(coarse, "CG", 2) - dg_space_c = FunctionSpace(coarse, "DG", 1) - cg_space_c = FunctionSpace(coarse, "CG", 2) - - # SupermeshProjectBlock - u_c = project(u, dg_space_c) - v_c = project(v, cg_space_c) + # SupermeshProjectBlock + u_c = project(u, dg_space_c) + v_c = project(v, cg_space_c) - # AssembleBlock - J = assemble((u_c - v_c)**2 * dx) + # AssembleBlock + J = assemble((u_c - v_c)**2 * dx) + else: + J = assemble((u - v)**2 * dx) Jhat = ReducedFunctional(J, Control(m)) @@ -77,36 +79,66 @@ def adjoint_example(fine, coarse): @pytest.mark.skipcomplex -@pytest.mark.parametrize("checkpoint_schedule", [True, False]) -def test_disk_checkpointing(checkpoint_schedule): +def test_disk_checkpointing(): # Use a Firedrake Tape subclass that supports disk checkpointing. set_working_tape(Tape()) tape = get_working_tape() tape.clear_tape() enable_disk_checkpointing() - if checkpoint_schedule: - tape.enable_checkpointing(SingleDiskStorageSchedule()) + tape.enable_checkpointing(SingleDiskStorageSchedule()) fine = checkpointable_mesh(UnitSquareMesh(10, 10, name="fine")) - coarse = checkpointable_mesh(UnitSquareMesh(4, 4, name="coarse")) - J_disk, grad_J_disk = adjoint_example(fine, coarse) + coarse = checkpointable_mesh(UnitSquareMesh(6, 6, name="coarse")) + J_disk, grad_J_disk = adjoint_example(fine, coarse=coarse) + + assert disk_checkpointing() is False - if checkpoint_schedule: - assert disk_checkpointing() is False tape.clear_tape() - if not checkpoint_schedule: - pause_disk_checkpointing() - J_mem, grad_J_mem = adjoint_example(fine, coarse) + J_mem, grad_J_mem = adjoint_example(fine, coarse=coarse) + + assert np.allclose(J_disk, J_mem) + assert np.allclose(assemble((grad_J_disk - grad_J_mem)**2*dx), 0.0) + + +@pytest.mark.skipcomplex +@pytest.mark.parallel(nprocs=3) +def test_disk_checkpointing_parallel(): + # Use a Firedrake Tape subclass that supports disk checkpointing. + set_working_tape(Tape()) + tape = get_working_tape() + tape.clear_tape() + continue_annotation() + enable_disk_checkpointing() + tape.enable_checkpointing(SingleDiskStorageSchedule()) + mesh = checkpointable_mesh(UnitSquareMesh(10, 10)) + J_disk, grad_J_disk = adjoint_example(mesh) + assert disk_checkpointing() is False + tape.clear_tape() + J_mem, grad_J_mem = adjoint_example(mesh) assert np.allclose(J_disk, J_mem) assert np.allclose(assemble((grad_J_disk - grad_J_mem)**2*dx), 0.0) @pytest.mark.skipcomplex -def test_disk_checkpointing_error(): +def test_disk_checkpointing_successive_writes(): + from firedrake.adjoint import checkpointable_mesh tape = get_working_tape() - # check the raise of the exception - with pytest.raises(RuntimeError): - tape.enable_checkpointing(SingleDiskStorageSchedule()) - assert disk_checkpointing_callback["firedrake"] == "Please call enable_disk_checkpointing() "\ - "before checkpointing on the disk." + tape.clear_tape() + enable_disk_checkpointing() + tape.enable_checkpointing(SingleDiskStorageSchedule()) + + mesh = checkpointable_mesh(UnitSquareMesh(1, 1)) + + cg_space = FunctionSpace(mesh, "CG", 1) + u = Function(cg_space, name='u') + v = Function(cg_space, name='v') + + u.assign(1.) + v.assign(v + 2.*u) + v.assign(v + 3.*u) + + J = assemble(v*dx) + Jhat = ReducedFunctional(J, Control(u)) + assert np.allclose(J, Jhat(Function(cg_space).interpolate(1.))) + assert disk_checkpointing() is False diff --git a/tests/firedrake/output/test_adjoint_disk_checkpointing.py b/tests/firedrake/output/test_adjoint_disk_checkpointing.py deleted file mode 100644 index a492f568e9..0000000000 --- a/tests/firedrake/output/test_adjoint_disk_checkpointing.py +++ /dev/null @@ -1,106 +0,0 @@ -from firedrake import * -from firedrake.__future__ import * -from pyadjoint import (ReducedFunctional, get_working_tape, stop_annotating, - pause_annotation, Control) -import numpy as np -import os -import pytest - - -@pytest.fixture(autouse=True, scope="module") -def handle_annotation(): - from firedrake.adjoint import annotate_tape, continue_annotation - if not annotate_tape(): - continue_annotation() - yield - # Ensure annotations are paused at the end of the module. - annotate = annotate_tape() - if annotate: - pause_annotation() - - -def adjoint_example(mesh): - # This example is designed to exercise all the block types for which - # the disk checkpointer does something. - dg_space = FunctionSpace(mesh, "DG", 1) - cg_space = FunctionSpace(mesh, "CG", 2) - W = dg_space * cg_space - - w = Function(W) - - x, y = SpatialCoordinate(mesh) - # AssembleBlock - m = assemble(interpolate(sin(4*pi*x)*cos(4*pi*y), cg_space)) - - u, v = w.subfunctions - # FunctionAssignBlock, FunctionMergeBlock - v.assign(m) - # SubfunctionBlock, GenericSolveBlock - u.project(v) - - # AssembleBlock - J = assemble((u - v)**2 * dx) - Jhat = ReducedFunctional(J, Control(m)) - - with stop_annotating(): - m_new = assemble(interpolate(sin(4*pi*x)*cos(4*pi*y), cg_space)) - checkpointer = get_working_tape()._package_data - init_file_timestamp = os.stat(checkpointer["firedrake"].init_checkpoint_file.name).st_mtime - current_file_timestamp = os.stat(checkpointer["firedrake"].current_checkpoint_file.name).st_mtime - Jnew = Jhat(m_new) - # Check that any new disk checkpoints are written to the correct file. - assert init_file_timestamp == os.stat(checkpointer["firedrake"].init_checkpoint_file.name).st_mtime - assert current_file_timestamp < os.stat(checkpointer["firedrake"].current_checkpoint_file.name).st_mtime - - assert np.allclose(J, Jnew) - - grad_Jnew = Jhat.derivative() - - return Jnew, grad_Jnew - - -@pytest.mark.skipcomplex -# A serial version of this test is included in the pyadjoint tests. -@pytest.mark.parallel(nprocs=3) -def test_disk_checkpointing(): - from firedrake.adjoint import enable_disk_checkpointing, \ - checkpointable_mesh, pause_disk_checkpointing, continue_annotation - tape = get_working_tape() - tape.clear_tape() - enable_disk_checkpointing() - continue_annotation() - mesh = checkpointable_mesh(UnitSquareMesh(10, 10, name="mesh")) - J_disk, grad_J_disk = adjoint_example(mesh) - tape.clear_tape() - pause_disk_checkpointing() - - J_mem, grad_J_mem = adjoint_example(mesh) - - assert np.allclose(J_disk, J_mem) - assert np.allclose(assemble((grad_J_disk - grad_J_mem)**2*dx), 0.0) - tape.clear_tape() - - -@pytest.mark.skipcomplex -def test_disk_checkpointing_successive_writes(): - from firedrake.adjoint import enable_disk_checkpointing, \ - checkpointable_mesh, pause_disk_checkpointing - tape = get_working_tape() - tape.clear_tape() - enable_disk_checkpointing() - - mesh = checkpointable_mesh(UnitSquareMesh(1, 1)) - - cg_space = FunctionSpace(mesh, "CG", 1) - u = Function(cg_space, name='u') - v = Function(cg_space, name='v') - - u.assign(1.) - v.assign(v + 2.*u) - v.assign(v + 3.*u) - - J = assemble(v*dx) - Jhat = ReducedFunctional(J, Control(u)) - assert np.allclose(J, Jhat(Function(cg_space).interpolate(1.))) - pause_disk_checkpointing() - tape.clear_tape()