Skip to content

Commit 85467c9

Browse files
committed
Fixture to clean up generated data files
Avoids modifying the scripts directly, meaning the scripts stay portable xref #1764
1 parent 641fbce commit 85467c9

File tree

2 files changed

+42
-7
lines changed

2 files changed

+42
-7
lines changed

docs/examples/example_mitgcm.py

+12-7
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
from datetime import timedelta
2+
from pathlib import Path
3+
from typing import Literal
24

35
import numpy as np
46
import parcels
@@ -7,7 +9,7 @@
79
ptype = {"scipy": parcels.ScipyParticle, "jit": parcels.JITParticle}
810

911

10-
def run_mitgcm_zonally_reentrant(mode):
12+
def run_mitgcm_zonally_reentrant(mode: Literal["scipy", "jit"], path: Path):
1113
"""Function that shows how to load MITgcm data in a zonally periodic domain."""
1214
data_folder = parcels.download_example_dataset("MITgcm_example_data")
1315
filenames = {
@@ -41,7 +43,7 @@ def periodicBC(particle, fieldset, time):
4143
size=10,
4244
)
4345
pfile = parcels.ParticleFile(
44-
"MIT_particles_" + str(mode) + ".zarr",
46+
str(path),
4547
pset,
4648
outputdt=timedelta(days=1),
4749
chunks=(len(pset), 1),
@@ -52,12 +54,15 @@ def periodicBC(particle, fieldset, time):
5254
)
5355

5456

55-
def test_mitgcm_output_compare():
56-
run_mitgcm_zonally_reentrant("scipy")
57-
run_mitgcm_zonally_reentrant("jit")
57+
def test_mitgcm_output_compare(tmpdir):
58+
def get_path(mode: Literal["scipy", "jit"]) -> Path:
59+
return tmpdir / f"MIT_particles_{mode}.zarr"
5860

59-
ds_jit = xr.open_zarr("MIT_particles_jit.zarr")
60-
ds_scipy = xr.open_zarr("MIT_particles_scipy.zarr")
61+
for mode in ["scipy", "jit"]:
62+
run_mitgcm_zonally_reentrant(mode, get_path(mode))
63+
64+
ds_jit = xr.open_zarr(get_path("jit"))
65+
ds_scipy = xr.open_zarr(get_path("scipy"))
6166

6267
np.testing.assert_allclose(ds_jit.lat.data, ds_scipy.lat.data)
6368
np.testing.assert_allclose(ds_jit.lon.data, ds_scipy.lon.data)

tests/test_examples.py

+30
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
1+
import os
12
import runpy
3+
import shutil
24
import sys
5+
import time
36
from pathlib import Path
47

58
import pytest
@@ -8,6 +11,33 @@
811
example_fnames = [path.name for path in example_folder.glob("*.py")]
912

1013

14+
@pytest.fixture(autouse=True)
15+
def cleanup_generated_data_files():
16+
"""Clean up generated data files from test run.
17+
18+
Records current folder contents before test, and cleans up any generated `.nc` files
19+
and `.zarr` folders afterwards. For safety this is non-recursive. This function is
20+
only necessary as the scripts being run aren't native pytest tests, so they don't
21+
have access to the `tmpdir` fixture.
22+
23+
"""
24+
folder_contents = os.listdir()
25+
yield
26+
time.sleep(0.1) # Buffer so that files are closed before we try to delete them.
27+
for fname in os.listdir():
28+
if fname in folder_contents:
29+
continue
30+
if not (fname.endswith(".nc") or fname.endswith(".zarr")):
31+
continue
32+
33+
path = Path(fname)
34+
if path.is_dir():
35+
shutil.rmtree(path)
36+
else:
37+
path.unlink()
38+
print(f"Removed {path}")
39+
40+
1141
@pytest.mark.parametrize("example_fname", example_fnames)
1242
def test_example_script(example_fname):
1343
script = str(example_folder / example_fname)

0 commit comments

Comments
 (0)