diff --git a/tests/output/test_adjoint_disk_checkpointing.py b/tests/output/test_adjoint_disk_checkpointing.py index 66227677b9..ee76f3cec2 100644 --- a/tests/output/test_adjoint_disk_checkpointing.py +++ b/tests/output/test_adjoint_disk_checkpointing.py @@ -40,18 +40,17 @@ def adjoint_example(mesh): # AssembleBlock J = assemble((u - v)**2 * dx) - Jhat = ReducedFunctional(J, Control(m)) with stop_annotating(): m_new = assemble(interpolate(sin(4*pi*x)*cos(4*pi*y), cg_space)) - checkpointer = get_working_tape()._checkpoint_metadata - init_file_timestamp = os.stat(checkpointer.init_checkpoint_file).st_mtime - current_file_timestamp = os.stat(checkpointer.current_checkpoint_file).st_mtime + checkpointer = get_working_tape()._package_data + init_file_timestamp = os.stat(checkpointer["firedrake"].init_checkpoint_file.name).st_mtime + current_file_timestamp = os.stat(checkpointer["firedrake"].current_checkpoint_file.name).st_mtime Jnew = Jhat(m_new) # Check that any new disk checkpoints are written to the correct file. - assert init_file_timestamp == os.stat(checkpointer.init_checkpoint_file).st_mtime - assert current_file_timestamp < os.stat(checkpointer.current_checkpoint_file).st_mtime + assert init_file_timestamp == os.stat(checkpointer["firedrake"].init_checkpoint_file.name).st_mtime + assert current_file_timestamp < os.stat(checkpointer["firedrake"].current_checkpoint_file.name).st_mtime assert np.allclose(J, Jnew) @@ -60,20 +59,16 @@ def adjoint_example(mesh): return Jnew, grad_Jnew -@pytest.mark.broken -# This test pollutes the tape!!! @pytest.mark.skipcomplex -# Waiting on stable parallel decompositions through disk checkpointing. -@pytest.mark.xfail # A serial version of this test is included in the pyadjoint tests. @pytest.mark.parallel(nprocs=3) def test_disk_checkpointing(): from firedrake.adjoint import enable_disk_checkpointing, \ - checkpointable_mesh, pause_disk_checkpointing + checkpointable_mesh, pause_disk_checkpointing, continue_annotation tape = get_working_tape() tape.clear_tape() enable_disk_checkpointing() - + continue_annotation() mesh = checkpointable_mesh(UnitSquareMesh(10, 10, name="mesh")) J_disk, grad_J_disk = adjoint_example(mesh) tape.clear_tape() @@ -109,3 +104,7 @@ def test_disk_checkpointing_successive_writes(): assert np.allclose(J, Jhat(1)) pause_disk_checkpointing() tape.clear_tape() + + +if __name__ == "__main__": + test_disk_checkpointing() \ No newline at end of file