Skip to content

Commit

Permalink
Fixture to clean up generated data files
Browse files Browse the repository at this point in the history
Avoids modifying the scripts directly, meaning the scripts stay portable

xref #1764
  • Loading branch information
VeckoTheGecko committed Nov 25, 2024
1 parent 641fbce commit 85467c9
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 7 deletions.
19 changes: 12 additions & 7 deletions docs/examples/example_mitgcm.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from datetime import timedelta
from pathlib import Path
from typing import Literal

import numpy as np
import parcels
Expand All @@ -7,7 +9,7 @@
ptype = {"scipy": parcels.ScipyParticle, "jit": parcels.JITParticle}


def run_mitgcm_zonally_reentrant(mode):
def run_mitgcm_zonally_reentrant(mode: Literal["scipy", "jit"], path: Path):
"""Function that shows how to load MITgcm data in a zonally periodic domain."""
data_folder = parcels.download_example_dataset("MITgcm_example_data")
filenames = {
Expand Down Expand Up @@ -41,7 +43,7 @@ def periodicBC(particle, fieldset, time):
size=10,
)
pfile = parcels.ParticleFile(
"MIT_particles_" + str(mode) + ".zarr",
str(path),
pset,
outputdt=timedelta(days=1),
chunks=(len(pset), 1),
Expand All @@ -52,12 +54,15 @@ def periodicBC(particle, fieldset, time):
)


def test_mitgcm_output_compare():
run_mitgcm_zonally_reentrant("scipy")
run_mitgcm_zonally_reentrant("jit")
def test_mitgcm_output_compare(tmpdir):
def get_path(mode: Literal["scipy", "jit"]) -> Path:
return tmpdir / f"MIT_particles_{mode}.zarr"

ds_jit = xr.open_zarr("MIT_particles_jit.zarr")
ds_scipy = xr.open_zarr("MIT_particles_scipy.zarr")
for mode in ["scipy", "jit"]:
run_mitgcm_zonally_reentrant(mode, get_path(mode))

ds_jit = xr.open_zarr(get_path("jit"))
ds_scipy = xr.open_zarr(get_path("scipy"))

np.testing.assert_allclose(ds_jit.lat.data, ds_scipy.lat.data)
np.testing.assert_allclose(ds_jit.lon.data, ds_scipy.lon.data)
30 changes: 30 additions & 0 deletions tests/test_examples.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import os
import runpy
import shutil
import sys
import time
from pathlib import Path

import pytest
Expand All @@ -8,6 +11,33 @@
example_fnames = [path.name for path in example_folder.glob("*.py")]


@pytest.fixture(autouse=True)
def cleanup_generated_data_files():
"""Clean up generated data files from test run.
Records current folder contents before test, and cleans up any generated `.nc` files
and `.zarr` folders afterwards. For safety this is non-recursive. This function is
only necessary as the scripts being run aren't native pytest tests, so they don't
have access to the `tmpdir` fixture.
"""
folder_contents = os.listdir()
yield
time.sleep(0.1) # Buffer so that files are closed before we try to delete them.
for fname in os.listdir():
if fname in folder_contents:
continue
if not (fname.endswith(".nc") or fname.endswith(".zarr")):
continue

Check warning on line 31 in tests/test_examples.py

View check run for this annotation

Codecov / codecov/patch

tests/test_examples.py#L31

Added line #L31 was not covered by tests

path = Path(fname)
if path.is_dir():
shutil.rmtree(path)
else:
path.unlink()
print(f"Removed {path}")


@pytest.mark.parametrize("example_fname", example_fnames)
def test_example_script(example_fname):
script = str(example_folder / example_fname)
Expand Down

0 comments on commit 85467c9

Please sign in to comment.