Skip to content

Commit

Permalink
Fix disk heckpointing test.
Browse files Browse the repository at this point in the history
  • Loading branch information
Ig-dolci authored and JDBetteridge committed Sep 12, 2024
1 parent 0d77d89 commit 6936ca4
Showing 1 changed file with 11 additions and 12 deletions.
23 changes: 11 additions & 12 deletions tests/output/test_adjoint_disk_checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,18 +40,17 @@ def adjoint_example(mesh):

# AssembleBlock
J = assemble((u - v)**2 * dx)

Jhat = ReducedFunctional(J, Control(m))

with stop_annotating():
m_new = assemble(interpolate(sin(4*pi*x)*cos(4*pi*y), cg_space))
checkpointer = get_working_tape()._checkpoint_metadata
init_file_timestamp = os.stat(checkpointer.init_checkpoint_file).st_mtime
current_file_timestamp = os.stat(checkpointer.current_checkpoint_file).st_mtime
checkpointer = get_working_tape()._package_data
init_file_timestamp = os.stat(checkpointer["firedrake"].init_checkpoint_file.name).st_mtime
current_file_timestamp = os.stat(checkpointer["firedrake"].current_checkpoint_file.name).st_mtime
Jnew = Jhat(m_new)
# Check that any new disk checkpoints are written to the correct file.
assert init_file_timestamp == os.stat(checkpointer.init_checkpoint_file).st_mtime
assert current_file_timestamp < os.stat(checkpointer.current_checkpoint_file).st_mtime
assert init_file_timestamp == os.stat(checkpointer["firedrake"].init_checkpoint_file.name).st_mtime
assert current_file_timestamp < os.stat(checkpointer["firedrake"].current_checkpoint_file.name).st_mtime

assert np.allclose(J, Jnew)

Expand All @@ -60,20 +59,16 @@ def adjoint_example(mesh):
return Jnew, grad_Jnew


@pytest.mark.broken
# This test pollutes the tape!!!
@pytest.mark.skipcomplex
# Waiting on stable parallel decompositions through disk checkpointing.
@pytest.mark.xfail
# A serial version of this test is included in the pyadjoint tests.
@pytest.mark.parallel(nprocs=3)
def test_disk_checkpointing():
from firedrake.adjoint import enable_disk_checkpointing, \
checkpointable_mesh, pause_disk_checkpointing
checkpointable_mesh, pause_disk_checkpointing, continue_annotation
tape = get_working_tape()
tape.clear_tape()
enable_disk_checkpointing()

continue_annotation()
mesh = checkpointable_mesh(UnitSquareMesh(10, 10, name="mesh"))
J_disk, grad_J_disk = adjoint_example(mesh)
tape.clear_tape()
Expand Down Expand Up @@ -109,3 +104,7 @@ def test_disk_checkpointing_successive_writes():
assert np.allclose(J, Jhat(1))
pause_disk_checkpointing()
tape.clear_tape()


if __name__ == "__main__":
test_disk_checkpointing()

0 comments on commit 6936ca4

Please sign in to comment.