From 5a5a699cc324ce532d199e8bbfe1d1a8cfa67f24 Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Fri, 23 Aug 2024 13:18:45 +0200 Subject: [PATCH 1/9] Add typechecking workflows and add type annotations --- .github/workflows/ci.yml | 22 ++++++++ .pre-commit-config.yaml | 13 +++++ README.md | 2 +- environment.yml | 8 ++- parcels/_typing.py | 46 +++++++++++++++++ parcels/compilation/codegenerator.py | 6 +-- parcels/interaction/interactionkernel.py | 2 +- parcels/tools/converters.py | 8 +-- parcels/tools/global_statics.py | 2 +- parcels/tools/interpolation_utils.py | 66 ++++++++++++++++-------- parcels/tools/loggers.py | 4 +- parcels/tools/timer.py | 2 +- pyproject.toml | 22 ++++++++ 13 files changed, 168 insertions(+), 35 deletions(-) create mode 100644 parcels/_typing.py diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 6c0a62e3a..41b91bfe8 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -81,3 +81,25 @@ jobs: with: name: Integration test report path: ${{ matrix.os }}_integration_test_report.html + typechecking: + name: mypy + runs-on: ubuntu-latest + steps: + - name: Checkout + uses: actions/checkout@v3 + - name: Setup Conda and parcels + uses: ./.github/actions/install-parcels + with: + environment-file: environment.yml + - run: conda install lxml # dep for report generation + - name: Typechecking + run: | + mypy --install-types --non-interactive parcels --cobertura-xml-report mypy_report + - name: Upload mypy coverage to Codecov + uses: codecov/codecov-action@v3.1.1 + if: ${{ always() }} # Upload even on error of mypy + with: + token: ${{ secrets.CODECOV_TOKEN }} + file: mypy_report/cobertura.xml + flags: mypy + fail_ci_if_error: false diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index ccc8a6e5f..eeae58c68 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -24,3 +24,16 @@ repos: rev: v0.4.0 hooks: - id: biome-format + + # Ruff doesn't have full coverage of pydoclint https://github.com/astral-sh/ruff/issues/12434 + - repo: https://github.com/PyCQA/flake8 + rev: 7.0.0 + hooks: + - id: flake8 + name: pydoclint + files: 'none' + # files: parcels/fieldset.py # put here instead of in config file due to https://github.com/pre-commit/pre-commit-hooks/issues/112#issuecomment-215613842 + args: + - --select=DOC103 # TODO: Expand coverage to other codes + additional_dependencies: + - pydoclint[flake8] diff --git a/README.md b/README.md index 10e6ae196..511f8e5fc 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,7 @@ ## Parcels [![Binder](https://mybinder.org/badge_logo.svg)](https://mybinder.org/v2/gh/OceanParcels/parcels/master?labpath=docs%2Fexamples%2Fparcels_tutorial.ipynb) -[![unit-tests](https://github.com/OceanParcels/parcels/actions/workflows/unit-tests.yml/badge.svg)](https://github.com/OceanParcels/parcels/actions/workflows/unit-tests.yml) +[![unit-tests](https://github.com/OceanParcels/parcels/actions/workflows/ci.yml/badge.svg)](https://github.com/OceanParcels/parcels/actions/workflows/ci.yml) [![codecov](https://codecov.io/gh/OceanParcels/parcels/branch/master/graph/badge.svg)](https://codecov.io/gh/OceanParcels/parcels) [![Anaconda-release](https://anaconda.org/conda-forge/parcels/badges/version.svg)](https://anaconda.org/conda-forge/parcels/) [![Anaconda-date](https://anaconda.org/conda-forge/parcels/badges/latest_release_date.svg)](https://anaconda.org/conda-forge/parcels/) diff --git a/environment.yml b/environment.yml index e244acca1..d096b3c3a 100644 --- a/environment.yml +++ b/environment.yml @@ -33,10 +33,14 @@ dependencies: - pytest-html - coverage + # Typing + - mypy + - types-tqdm + - types-psutil + + # Linting - - flake8>=2.1.0 - pre_commit - - pydocstyle # Docs - ipython diff --git a/parcels/_typing.py b/parcels/_typing.py new file mode 100644 index 000000000..637393a07 --- /dev/null +++ b/parcels/_typing.py @@ -0,0 +1,46 @@ +""" +Typing support for Parcels. + +This module contains type aliases used throughout Parcels as well as functions that are +used for runtime parameter validation (to ensure users are only using the right params). + +""" + +import ast +import datetime +import os +from typing import Any, Callable, Literal, get_args + + +class ParcelsAST(ast.AST): + ccode: str + + +# InterpMethod = InterpMethodOption | dict[str, InterpMethodOption] # (can also be a dict, search for `if type(interp_method) is dict`) +# InterpMethodOption = Literal[ +# "nearest", +# "freeslip", +# "partialslip", +# "bgrid_velocity", +# "bgrid_w_velocity", +# "cgrid_velocity", +# "linear_invdist_land_tracer", +# "nearest", +# "cgrid_tracer", +# ] # mostly corresponds with `interp_method` # TODO: This should be narrowed. Unlikely applies to every context +PathLike = str | os.PathLike +Mesh = Literal["spherical", "flat"] # mostly corresponds with `mesh` +VectorType = Literal["3D", "2D"] | None # mostly corresponds with `vector_type` +ChunkMode = Literal["auto", "specific", "failsafe"] # mostly corresponds with `chunk_mode` +GridIndexingType = Literal["pop", "mom5", "mitgcm", "nemo"] # mostly corresponds with `grid_indexing_type` +UpdateStatus = Literal["not_updated", "first_updated", "updated"] # mostly corresponds with `update_status` +TimePeriodic = float | datetime.timedelta | Literal[False] # mostly corresponds with `update_status` + +KernelFunction = Callable[..., None] + + +def ensure_is_literal_value(value: Any, literal: Any) -> None: + """Ensures that a value is a valid option for the provided Literal type annotation.""" + valid_options = get_args(literal) + if value not in valid_options: + raise ValueError(f"{value!r} is not a valid option. Valid options are {valid_options}") diff --git a/parcels/compilation/codegenerator.py b/parcels/compilation/codegenerator.py index 2569dcc6c..921b422cb 100644 --- a/parcels/compilation/codegenerator.py +++ b/parcels/compilation/codegenerator.py @@ -410,7 +410,7 @@ class KernelGenerator(ABC, ast.NodeVisitor): # Intrinsic variables that appear as function arguments kernel_vars = ["particle", "fieldset", "time", "output_time", "tol"] - array_vars = [] + array_vars: list[str] = [] def __init__(self, fieldset=None, ptype=JITParticle): self.fieldset = fieldset @@ -419,7 +419,7 @@ def __init__(self, fieldset=None, ptype=JITParticle): self.vector_field_args = collections.OrderedDict() self.const_args = collections.OrderedDict() - def generate(self, py_ast, funcvars): + def generate(self, py_ast, funcvars: list[str]): # Replace occurrences of intrinsic objects in Python AST transformer = IntrinsicTransformer(self.fieldset, self.ptype) py_ast = transformer.visit(py_ast) @@ -434,7 +434,7 @@ def generate(self, py_ast, funcvars): # Insert variable declarations for non-intrinsic variables # Make sure that repeated variables are not declared more than # once. If variables occur in multiple Kernels, give a warning - used_vars = [] + used_vars: list[str] = [] funcvars_copy = copy(funcvars) # editing a list while looping over it is dangerous for kvar in funcvars: if kvar in used_vars + ["particle_dlon", "particle_dlat", "particle_ddepth"]: diff --git a/parcels/interaction/interactionkernel.py b/parcels/interaction/interactionkernel.py index db4c7f897..c07af94d6 100644 --- a/parcels/interaction/interactionkernel.py +++ b/parcels/interaction/interactionkernel.py @@ -36,7 +36,7 @@ def __init__( py_ast=None, funcvars=None, c_include="", - delete_cfiles=True, + delete_cfiles: bool = True, ): if MPI is not None and MPI.COMM_WORLD.Get_size() > 1: raise NotImplementedError( diff --git a/parcels/tools/converters.py b/parcels/tools/converters.py index c8ccf8003..17b7a5abf 100644 --- a/parcels/tools/converters.py +++ b/parcels/tools/converters.py @@ -2,10 +2,12 @@ import inspect from datetime import timedelta from math import cos, pi +from typing import Any import cftime import numpy as np import xarray as xr +from numpy.typing import ArrayLike, NDArray __all__ = [ "UnitConverter", @@ -20,7 +22,7 @@ ] -def convert_to_flat_array(var): +def convert_to_flat_array(var: list[float] | float | int | NDArray[Any] | ArrayLike) -> NDArray[Any]: """Convert lists and single integers/floats to one-dimensional numpy arrays Parameters @@ -167,8 +169,8 @@ def __le__(self, other): class UnitConverter: """Interface class for spatial unit conversion during field sampling that performs no conversion.""" - source_unit = None - target_unit = None + source_unit: str | None = None + target_unit: str | None = None def to_target(self, value, x, y, z): return value diff --git a/parcels/tools/global_statics.py b/parcels/tools/global_statics.py index 0e97bac0d..896d3195f 100644 --- a/parcels/tools/global_statics.py +++ b/parcels/tools/global_statics.py @@ -8,7 +8,7 @@ from os import getuid except: # Windows does not have getuid(), so define to simply return 'tmp' - def getuid(): + def getuid(): # type: ignore return "tmp" diff --git a/parcels/tools/interpolation_utils.py b/parcels/tools/interpolation_utils.py index 27ab24c7f..273dbbec8 100644 --- a/parcels/tools/interpolation_utils.py +++ b/parcels/tools/interpolation_utils.py @@ -1,17 +1,19 @@ +from typing import Callable, Literal + import numpy as np -__all__ = [] +from parcels._typing import Mesh +__all__ = [] # type: ignore -# fmt: off -def phi1D_lin(xsi): - phi = [1-xsi, - xsi] +def phi1D_lin(xsi: float) -> list[float]: + phi = [1 - xsi, xsi] return phi -def phi1D_quad(xsi): +# fmt: off +def phi1D_quad(xsi: float) -> list[float]: phi = [2*xsi**2-3*xsi+1, -4*xsi**2+4*xsi, 2*xsi**2-xsi] @@ -19,7 +21,8 @@ def phi1D_quad(xsi): return phi -def phi2D_lin(xsi, eta): + +def phi2D_lin(xsi: float, eta: float) -> list[float]: phi = [(1-xsi) * (1-eta), xsi * (1-eta), xsi * eta , @@ -28,7 +31,7 @@ def phi2D_lin(xsi, eta): return phi -def phi3D_lin(xsi, eta, zet): +def phi3D_lin(xsi: float, eta: float, zet: float) -> list[float]: phi = [(1-xsi) * (1-eta) * (1-zet), xsi * (1-eta) * (1-zet), xsi * eta * (1-zet), @@ -41,7 +44,7 @@ def phi3D_lin(xsi, eta, zet): return phi -def dphidxsi3D_lin(xsi, eta, zet): +def dphidxsi3D_lin(xsi: float, eta: float, zet: float) -> tuple[list[float], list[float], list[float]]: dphidxsi = [ - (1-eta) * (1-zet), (1-eta) * (1-zet), ( eta) * (1-zet), @@ -70,7 +73,9 @@ def dphidxsi3D_lin(xsi, eta, zet): return dphidxsi, dphideta, dphidzet -def dxdxsi3D_lin(hexa_x, hexa_y, hexa_z, xsi, eta, zet, mesh): +def dxdxsi3D_lin( + hexa_x: list[float], hexa_y: list[float], hexa_z: list[float], xsi: float, eta: float, zet: float, mesh: Mesh +) -> tuple[float, float, float, float, float, float, float, float, float]: dphidxsi, dphideta, dphidzet = dphidxsi3D_lin(xsi, eta, zet) if mesh == 'spherical': @@ -99,16 +104,29 @@ def dxdxsi3D_lin(hexa_x, hexa_y, hexa_z, xsi, eta, zet, mesh): return dxdxsi, dxdeta, dxdzet, dydxsi, dydeta, dydzet, dzdxsi, dzdeta, dzdzet -def jacobian3D_lin(hexa_x, hexa_y, hexa_z, xsi, eta, zet, mesh): +def jacobian3D_lin( + hexa_x: list[float], hexa_y: list[float], hexa_z: list[float], xsi: float, eta: float, zet: float, mesh: Mesh +) -> float: dxdxsi, dxdeta, dxdzet, dydxsi, dydeta, dydzet, dzdxsi, dzdeta, dzdzet = dxdxsi3D_lin(hexa_x, hexa_y, hexa_z, xsi, eta, zet, mesh) - jac = dxdxsi * (dydeta*dzdzet - dzdeta*dydzet)\ - - dxdeta * (dydxsi*dzdzet - dzdxsi*dydzet)\ - + dxdzet * (dydxsi*dzdeta - dzdxsi*dydeta) + jac = ( + dxdxsi * (dydeta * dzdzet - dzdeta * dydzet) + - dxdeta * (dydxsi * dzdzet - dzdxsi * dydzet) + + dxdzet * (dydxsi * dzdeta - dzdxsi * dydeta) + ) return jac -def jacobian3D_lin_face(hexa_x, hexa_y, hexa_z, xsi, eta, zet, orientation, mesh): +def jacobian3D_lin_face( + hexa_x: list[float], + hexa_y: list[float], + hexa_z: list[float], + xsi: float, + eta: float, + zet: float, + orientation: Literal["zonal", "meridional", "vertical"], + mesh: Mesh, +) -> float: dxdxsi, dxdeta, dxdzet, dydxsi, dydeta, dydzet, dzdxsi, dzdeta, dzdzet = dxdxsi3D_lin(hexa_x, hexa_y, hexa_z, xsi, eta, zet, mesh) if orientation == 'zonal': @@ -128,7 +146,7 @@ def jacobian3D_lin_face(hexa_x, hexa_y, hexa_z, xsi, eta, zet, orientation, mesh return jac -def dphidxsi2D_lin(xsi, eta): +def dphidxsi2D_lin(xsi: float, eta: float) -> tuple[list[float], list[float]]: dphidxsi = [-(1-eta), 1-eta, eta, @@ -141,7 +159,12 @@ def dphidxsi2D_lin(xsi, eta): return dphidxsi, dphideta -def dxdxsi2D_lin(quad_x, quad_y, xsi, eta,): +def dxdxsi2D_lin( + quad_x, + quad_y, + xsi: float, + eta: float, +): dphidxsi, dphideta = dphidxsi2D_lin(xsi, eta) dxdxsi = np.dot(quad_x, dphidxsi) @@ -152,20 +175,21 @@ def dxdxsi2D_lin(quad_x, quad_y, xsi, eta,): return dxdxsi, dxdeta, dydxsi, dydeta -def jacobian2D_lin(quad_x, quad_y, xsi, eta): +def jacobian2D_lin(quad_x, quad_y, xsi: float, eta: float): dxdxsi, dxdeta, dydxsi, dydeta = dxdxsi2D_lin(quad_x, quad_y, xsi, eta) - jac = dxdxsi*dydeta - dxdeta*dydxsi + jac = dxdxsi * dydeta - dxdeta * dydxsi return jac def length2d_lin_edge(quad_x, quad_y, ids): xe = [quad_x[ids[0]], quad_x[ids[1]]] ye = [quad_y[ids[0]], quad_y[ids[1]]] - return np.sqrt((xe[1]-xe[0])**2+(ye[1]-ye[0])**2) + return np.sqrt((xe[1] - xe[0]) ** 2 + (ye[1] - ye[0]) ** 2) -def interpolate(phi, f, xsi): +def interpolate(phi: Callable[[float], list[float]], f: list[float], xsi: float) -> float: return np.dot(phi(xsi), f) + # fmt: on diff --git a/parcels/tools/loggers.py b/parcels/tools/loggers.py index b21dee81a..335bc3a8d 100644 --- a/parcels/tools/loggers.py +++ b/parcels/tools/loggers.py @@ -40,10 +40,10 @@ def info_once(self, message, *args, **kws): logger.addHandler(handler) logging.addLevelName(warning_once_level, "WARNING") -logging.Logger.warning_once = warning_once +logging.Logger.warning_once = warning_once # type: ignore logging.addLevelName(info_once_level, "INFO") -logging.Logger.info_once = info_once +logging.Logger.info_once = info_once # type: ignore dup_filter = DuplicateFilter() logger.addFilter(dup_filter) diff --git a/parcels/tools/timer.py b/parcels/tools/timer.py index 02daf75ab..6aa7701e3 100644 --- a/parcels/tools/timer.py +++ b/parcels/tools/timer.py @@ -6,7 +6,7 @@ except ModuleNotFoundError: MPI = None -__all__ = [] +__all__ = [] # type: ignore class Timer: diff --git a/pyproject.toml b/pyproject.toml index d42401fbe..613ff9978 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -121,3 +121,25 @@ ignore = [ [tool.ruff.lint.pydocstyle] convention = "numpy" + +[tool.mypy] +files = [ + "parcels/compilation/codegenerator.py", + "parcels/_typing.py", + "parcels/tools/*.py", + "parcels/grid.py", +] + +[[tool.mypy.overrides]] +module = [ + "parcels._version_setup", + "mpi4py", + "scipy.spatial", + "sklearn.cluster", + "zarr", + "cftime", + "pykdtree.kdtree", + "netCDF4", + "cgen" +] +ignore_missing_imports = true From 92fdf05d3c7ed9315a8b8b6d297bcd1fecaa7886 Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Fri, 23 Aug 2024 15:28:51 +0200 Subject: [PATCH 2/9] Cleanup --- parcels/fieldset.py | 2 +- parcels/grid.py | 10 +++++----- parcels/tools/converters.py | 12 +++--------- tests/test_data/create_testfields.py | 1 - 4 files changed, 9 insertions(+), 16 deletions(-) diff --git a/parcels/fieldset.py b/parcels/fieldset.py index fda0d4701..be5578ebc 100644 --- a/parcels/fieldset.py +++ b/parcels/fieldset.py @@ -136,7 +136,7 @@ def from_data( lat = dims["lat"] depth = np.zeros(1, dtype=np.float32) if "depth" not in dims else dims["depth"] time = np.zeros(1, dtype=np.float64) if "time" not in dims else dims["time"] - time = np.array(time) if not isinstance(time, np.ndarray) else time + time = np.array(time) if isinstance(time[0], np.datetime64): time_origin = TimeConverter(time[0]) time = np.array([time_origin.reltime(t) for t in time]) diff --git a/parcels/grid.py b/parcels/grid.py index 2ba8073c5..ecdcc2086 100644 --- a/parcels/grid.py +++ b/parcels/grid.py @@ -84,12 +84,12 @@ def __init__(self, lon, lat, time, time_origin, mesh): @staticmethod def create_grid(lon, lat, depth, time, time_origin, mesh, **kwargs): - if not isinstance(lon, np.ndarray): - lon = np.array(lon) - if not isinstance(lat, np.ndarray): - lat = np.array(lat) - if not (depth is None or isinstance(depth, np.ndarray)): + lon = np.array(lon) + lat = np.array(lat) + + if depth is not None: depth = np.array(depth) + if len(lon.shape) <= 1: if depth is None or len(depth.shape) <= 1: return RectilinearZGrid(lon, lat, depth, time, time_origin=time_origin, mesh=mesh, **kwargs) diff --git a/parcels/tools/converters.py b/parcels/tools/converters.py index 17b7a5abf..f657f593b 100644 --- a/parcels/tools/converters.py +++ b/parcels/tools/converters.py @@ -2,7 +2,6 @@ import inspect from datetime import timedelta from math import cos, pi -from typing import Any import cftime import numpy as np @@ -22,20 +21,15 @@ ] -def convert_to_flat_array(var: list[float] | float | int | NDArray[Any] | ArrayLike) -> NDArray[Any]: +def convert_to_flat_array(var: ArrayLike) -> NDArray: """Convert lists and single integers/floats to one-dimensional numpy arrays Parameters ---------- - var : np.ndarray, float or array_like + var : Array list or numeric to convert to a one-dimensional numpy array """ - if isinstance(var, np.ndarray): - return var.flatten() - elif isinstance(var, (int, float, np.float32, np.int32)): - return np.array([var]) - else: - return np.array(var) + return np.array(var).flatten() def _get_cftime_datetimes(): diff --git a/tests/test_data/create_testfields.py b/tests/test_data/create_testfields.py index 3c7e783fa..acf192158 100644 --- a/tests/test_data/create_testfields.py +++ b/tests/test_data/create_testfields.py @@ -48,7 +48,6 @@ def generate_perlin_testfield(): lon = np.linspace(-180.0, 180.0, img_shape[0], dtype=np.float32) lat = np.linspace(-90.0, 90.0, img_shape[1], dtype=np.float32) time = np.zeros(1, dtype=np.float64) - time = np.array(time) if not isinstance(time, np.ndarray) else time # Define arrays U (zonal), V (meridional), W (vertical) and P (sea # surface height) all on A-grid From 64c2e1636207a46815a28fcaa134fee48b175685 Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Mon, 26 Aug 2024 16:48:12 +0200 Subject: [PATCH 3/9] Add type annotations --- parcels/_typing.py | 4 +++ parcels/field.py | 88 +++++++++++++++++++++++++++------------------ parcels/fieldset.py | 87 +++++++++++++++++++++++--------------------- parcels/grid.py | 73 +++++++++++++++++++++++++++++++------ parcels/kernel.py | 6 +++- pyproject.toml | 2 ++ 6 files changed, 172 insertions(+), 88 deletions(-) diff --git a/parcels/_typing.py b/parcels/_typing.py index 637393a07..f67192bcb 100644 --- a/parcels/_typing.py +++ b/parcels/_typing.py @@ -35,6 +35,10 @@ class ParcelsAST(ast.AST): GridIndexingType = Literal["pop", "mom5", "mitgcm", "nemo"] # mostly corresponds with `grid_indexing_type` UpdateStatus = Literal["not_updated", "first_updated", "updated"] # mostly corresponds with `update_status` TimePeriodic = float | datetime.timedelta | Literal[False] # mostly corresponds with `update_status` +NetcdfEngine = Literal[ + "netcdf4", "xarray" +] # TODO: It seems that "scipy" is also an option (according to a docstring) but can't find mention in code. Investigate. + KernelFunction = Callable[..., None] diff --git a/parcels/field.py b/parcels/field.py index 651c62701..7495ccdd1 100644 --- a/parcels/field.py +++ b/parcels/field.py @@ -3,12 +3,14 @@ import math from ctypes import POINTER, Structure, c_float, c_int, pointer from pathlib import Path +from typing import TYPE_CHECKING, Iterable, Type import dask.array as da import numpy as np import xarray as xr import parcels.tools.interpolation_utils as i_u +from parcels._typing import GridIndexingType, Mesh, TimePeriodic, VectorType from parcels.tools.converters import ( Geographic, GeographicPolar, @@ -33,6 +35,11 @@ ) from .grid import CGrid, Grid, GridType +if TYPE_CHECKING: + from ctypes import _Pointer as PointerType + + from parcels.fieldset import FieldSet + __all__ = ["Field", "VectorField", "NestedField"] @@ -43,7 +50,7 @@ def _isParticle(key): return False -def _deal_with_errors(error, key, vector_type): +def _deal_with_errors(error, key, vector_type: VectorType): if _isParticle(key): key.state = AllParcelsErrorCodes[type(error)] elif _isParticle(key[-1]): @@ -134,14 +141,14 @@ class Field: def __init__( self, - name, + name: str | tuple[str, str], data, lon=None, lat=None, depth=None, time=None, grid=None, - mesh="flat", + mesh: Mesh = "flat", timestamps=None, fieldtype=None, transpose=False, @@ -150,9 +157,9 @@ def __init__( cast_data_dtype="float32", time_origin=None, interp_method="linear", - allow_time_extrapolation=None, - time_periodic=False, - gridindexingtype="nemo", + allow_time_extrapolation: bool | None = None, + time_periodic: TimePeriodic = False, + gridindexingtype: GridIndexingType = "nemo", to_write=False, **kwargs, ): @@ -160,7 +167,8 @@ def __init__( self.name = name self.filebuffername = name else: - self.name, self.filebuffername = name + self.name = name[0] + self.filebuffername = name[1] self.data = data if grid: if grid.defer_load and isinstance(data, np.ndarray): @@ -203,11 +211,11 @@ def __init__( GridType.RectilinearSGrid, GridType.CurvilinearSGrid, ]: - logger.warning_once( + logger.warning_once( # type: ignore "General s-levels are not supported in B-grid. RectilinearSGrid and CurvilinearSGrid can still be used to deal with shaved cells, but the levels must be horizontal." ) - self.fieldset = None + self.fieldset: FieldSet | None = None if allow_time_extrapolation is None: self.allow_time_extrapolation = True if len(self.grid.time) == 1 else False else: @@ -215,7 +223,7 @@ def __init__( self.time_periodic = time_periodic if self.time_periodic is not False and self.allow_time_extrapolation: - logger.warning_once( + logger.warning_once( # type: ignore "allow_time_extrapolation and time_periodic cannot be used together.\n \ allow_time_extrapolation is set to False" ) @@ -268,8 +276,8 @@ def __init__( self._field_fb_class = kwargs.pop("FieldFileBuffer", None) self.netcdf_engine = kwargs.pop("netcdf_engine", "netcdf4") self.netcdf_decodewarning = kwargs.pop("netcdf_decodewarning", True) - self.loaded_time_indices = [] - self.creation_log = kwargs.pop("creation_log", "") + self.loaded_time_indices: Iterable[int] = [] + self.creation_log: str = kwargs.pop("creation_log", "") self.chunksize = kwargs.pop("chunksize", None) self.netcdf_chunkdims_name_map = kwargs.pop("chunkdims_name_map", None) self.grid.depth_field = kwargs.pop("depth_field", None) @@ -283,10 +291,10 @@ def __init__( # (data_full_zdim = grid.zdim if no indices are used, for A- and C-grids and for some B-grids). It is used for the B-grid, # since some datasets do not provide the deeper level of data (which is ignored by the interpolation). self.data_full_zdim = kwargs.pop("data_full_zdim", None) - self.data_chunks = [] # the data buffer of the FileBuffer raw loaded data - shall be a list of C-contiguous arrays - self.c_data_chunks = [] # C-pointers to the data_chunks array - self.nchunks = [] - self.chunk_set = False + self.data_chunks = [] # type: ignore # the data buffer of the FileBuffer raw loaded data - shall be a list of C-contiguous arrays + self.c_data_chunks: list[PointerType | None] = [] # C-pointers to the data_chunks array + self.nchunks: tuple[int, ...] = () + self.chunk_set: bool = False self.filebuffers = [None] * 2 if len(kwargs) > 0: raise SyntaxError(f'Field received an unexpected keyword argument "{list(kwargs.keys())[0]}"') @@ -349,13 +357,13 @@ def from_netcdf( dimensions, indices=None, grid=None, - mesh="spherical", + mesh: Mesh = "spherical", timestamps=None, - allow_time_extrapolation=None, - time_periodic=False, - deferred_load=True, + allow_time_extrapolation: bool | None = None, + time_periodic: TimePeriodic = False, + deferred_load: bool = True, **kwargs, - ): + ) -> "Field": """Create field from netCDF file. Parameters @@ -542,11 +550,11 @@ def from_netcdf( ) kwargs["dataFiles"] = dataFiles - chunksize = kwargs.get("chunksize", None) + chunksize: bool | None = kwargs.get("chunksize", None) grid.chunksize = chunksize if "time" in indices: - logger.warning_once("time dimension in indices is not necessary anymore. It is then ignored.") + logger.warning_once("time dimension in indices is not necessary anymore. It is then ignored.") # type: ignore if "full_load" in kwargs: # for backward compatibility with Parcels < v2.0.0 deferred_load = not kwargs["full_load"] @@ -554,6 +562,7 @@ def from_netcdf( if grid.time.size <= 2 or deferred_load is False: deferred_load = False + _field_fb_class: Type[DeferredDaskFileBuffer | DaskFileBuffer | DeferredNetcdfFileBuffer | NetcdfFileBuffer] if chunksize not in [False, None]: if deferred_load: _field_fb_class = DeferredDaskFileBuffer @@ -570,7 +579,7 @@ def from_netcdf( data_list = [] ti = 0 for tslice, fname in zip(grid.timeslices, data_filenames, strict=True): - with _field_fb_class( + with _field_fb_class( # type: ignore[operator] fname, dimensions, indices, @@ -637,7 +646,14 @@ def from_netcdf( @classmethod def from_xarray( - cls, da, name, dimensions, mesh="spherical", allow_time_extrapolation=None, time_periodic=False, **kwargs + cls, + da: xr.DataArray, + name: str, + dimensions, + mesh: Mesh = "spherical", + allow_time_extrapolation: bool | None = None, + time_periodic: TimePeriodic = False, + **kwargs, ): """Create field from xarray Variable. @@ -854,7 +870,9 @@ def search_indices_vertical_z(self, z): zeta = (z - grid.depth[zi]) / (grid.depth[zi + 1] - grid.depth[zi]) return (zi, zeta) - def search_indices_vertical_s(self, x, y, z, xi, yi, xsi, eta, ti, time): + def search_indices_vertical_s( + self, x: float, y: float, z: float, xi: int, yi: int, xsi: float, eta: float, ti: int, time + ): grid = self.grid if self.interp_method in ["bgrid_velocity", "bgrid_w_velocity", "bgrid_tracer"]: xsi = 1 @@ -886,7 +904,7 @@ def search_indices_vertical_s(self, x, y, z, xi, yi, xsi, eta, ti, time): + xsi * eta * grid.depth[:, yi + 1, xi + 1] + (1 - xsi) * eta * grid.depth[:, yi + 1, xi] ) - z = np.float32(z) + z = np.float32(z) # type: ignore # TODO: remove type ignore once we migrate to float64 if depth_vector[-1] > depth_vector[0]: depth_indices = depth_vector <= z @@ -930,7 +948,7 @@ def reconnect_bnd_indices(self, xi, yi, xdim, ydim, sphere_mesh): xi = xdim - xi return xi, yi - def search_indices_rectilinear(self, x, y, z, ti=-1, time=-1, particle=None, search2D=False): + def search_indices_rectilinear(self, x: float, y: float, z: float, ti=-1, time=-1, particle=None, search2D=False): grid = self.grid if grid.xdim > 1 and (not grid.zonal_periodic): @@ -1684,18 +1702,18 @@ class VectorField: field defining the vertical component (default: None) """ - def __init__(self, name, U, V, W=None): + def __init__(self, name: str, U: Field, V: Field, W: Field | None = None): self.name = name self.U = U self.V = V self.W = W - self.vector_type = "3D" if W else "2D" + self.vector_type: VectorType = "3D" if W else "2D" self.gridindexingtype = U.gridindexingtype if self.U.interp_method == "cgrid_velocity": assert self.V.interp_method == "cgrid_velocity", "Interpolation methods of U and V are not the same." assert self._check_grid_dimensions(U.grid, V.grid), "Dimensions of U and V are not the same." - if self.vector_type == "3D": - assert self.W.interp_method == "cgrid_velocity", "Interpolation methods of U and W are not the same." + if W is not None: + assert W.interp_method == "cgrid_velocity", "Interpolation methods of U and W are not the same." assert self._check_grid_dimensions(U.grid, W.grid), "Dimensions of U and W are not the same." @staticmethod @@ -1707,7 +1725,7 @@ def _check_grid_dimensions(grid1, grid2): and np.allclose(grid1.time_full, grid2.time_full) ) - def dist(self, lon1, lon2, lat1, lat2, mesh, lat): + def dist(self, lon1: float, lon2: float, lat1: float, lat2: float, mesh: Mesh, lat: float): if mesh == "spherical": rad = np.pi / 180.0 deg2m = 1852 * 60.0 @@ -1715,7 +1733,7 @@ def dist(self, lon1, lon2, lat1, lat2, mesh, lat): else: return np.sqrt((lon2 - lon1) ** 2 + (lat2 - lat1) ** 2) - def jacobian(self, xsi, eta, px, py): + def jacobian(self, xsi: float, eta: float, px: np.ndarray, py: np.ndarray): dphidxsi = [eta - 1, 1 - eta, eta, -eta] dphideta = [xsi - 1, -xsi, xsi, 1 - xsi] @@ -2285,7 +2303,7 @@ class NestedField(list): """ - def __init__(self, name, F, V=None, W=None): + def __init__(self, name: str, F, V=None, W=None): if V is None: if isinstance(F[0], VectorField): vector_type = F[0].vector_type diff --git a/parcels/fieldset.py b/parcels/fieldset.py index be5578ebc..58691e7a1 100644 --- a/parcels/fieldset.py +++ b/parcels/fieldset.py @@ -7,15 +7,17 @@ import dask.array as da import numpy as np +from parcels._typing import Mesh, TimePeriodic from parcels.field import DeferredArray, Field, NestedField, VectorField from parcels.grid import Grid from parcels.gridset import GridSet +from parcels.particlefile import ParticleFile from parcels.tools.converters import TimeConverter, convert_xarray_time_units from parcels.tools.loggers import logger from parcels.tools.statuscodes import TimeExtrapolationError try: - from mpi4py import MPI + from mpi4py import MPI # pyright: ignore[reportMissingImports] except ModuleNotFoundError: MPI = None @@ -37,13 +39,14 @@ class FieldSet: in custom kernels. """ - def __init__(self, U, V, fields=None): + def __init__(self, U: Field | NestedField | None, V: Field | NestedField | None, fields=None): self.gridset = GridSet() - self.completed = False - self.particlefile = None + self.completed: bool = False + self.particlefile: ParticleFile | None = None if U: self.add_field(U, "U") - self.time_origin = self.U.grid.time_origin if isinstance(self.U, Field) else self.U[0].grid.time_origin + # see #1663 for type-ignore reason + self.time_origin = self.U.grid.time_origin if isinstance(self.U, Field) else self.U[0].grid.time_origin # type: ignore if V: self.add_field(V, "V") @@ -67,9 +70,9 @@ def from_data( data, dimensions, transpose=False, - mesh="spherical", - allow_time_extrapolation=None, - time_periodic=False, + mesh: Mesh = "spherical", + allow_time_extrapolation: bool | None = None, + time_periodic: TimePeriodic = False, **kwargs, ): """Initialise FieldSet object from raw data. @@ -159,7 +162,7 @@ def from_data( v = fields.pop("V", None) return cls(u, v, fields=fields) - def add_field(self, field, name=None): + def add_field(self, field: Field | NestedField, name: str | None = None): """Add a :class:`parcels.field.Field` object to the FieldSet. Parameters @@ -167,7 +170,8 @@ def add_field(self, field, name=None): field : parcels.field.Field Field object to be added name : str - Name of the :class:`parcels.field.Field` object to be added + Name of the :class:`parcels.field.Field` object to be added. Defaults + to name in Field object. Examples @@ -184,6 +188,7 @@ def add_field(self, field, name=None): "FieldSet has already been completed. Are you trying to add a Field after you've created the ParticleSet?" ) name = field.name if name is None else name + if hasattr(self, name): # check if Field with same name already exists when adding new Field raise RuntimeError(f"FieldSet already has a Field with name '{name}'") if isinstance(field, NestedField): @@ -196,7 +201,7 @@ def add_field(self, field, name=None): self.gridset.add_grid(field) field.fieldset = self - def add_constant_field(self, name, value, mesh="flat"): + def add_constant_field(self, name: str, value: float, mesh: Mesh = "flat"): """Wrapper function to add a Field that is constant in space, useful e.g. when using constant horizontal diffusivity @@ -342,10 +347,10 @@ def from_netcdf( dimensions, indices=None, fieldtype=None, - mesh="spherical", + mesh: Mesh = "spherical", timestamps=None, - allow_time_extrapolation=None, - time_periodic=False, + allow_time_extrapolation: bool | None = None, + time_periodic: TimePeriodic = False, deferred_load=True, chunksize=None, **kwargs, @@ -435,10 +440,10 @@ def from_netcdf( """ # Ensure that times are not provided both in netcdf file and in 'timestamps'. if timestamps is not None and "time" in dimensions: - logger.warning_once("Time already provided, defaulting to dimensions['time'] over timestamps.") + logger.warning_once("Time already provided, defaulting to dimensions['time'] over timestamps.") # type: ignore timestamps = None - fields = {} + fields: dict[str, Field] = {} if "creation_log" not in kwargs.keys(): kwargs["creation_log"] = "from_netcdf" for var, name in variables.items(): @@ -521,9 +526,9 @@ def from_nemo( variables, dimensions, indices=None, - mesh="spherical", - allow_time_extrapolation=None, - time_periodic=False, + mesh: Mesh = "spherical", + allow_time_extrapolation: bool | None = None, + time_periodic: TimePeriodic = False, tracer_interp_method="cgrid_tracer", chunksize=None, **kwargs, @@ -632,9 +637,9 @@ def from_mitgcm( variables, dimensions, indices=None, - mesh="spherical", - allow_time_extrapolation=None, - time_periodic=False, + mesh: Mesh = "spherical", + allow_time_extrapolation: bool | None = None, + time_periodic: TimePeriodic = False, tracer_interp_method="cgrid_tracer", chunksize=None, **kwargs, @@ -682,9 +687,9 @@ def from_c_grid_dataset( variables, dimensions, indices=None, - mesh="spherical", - allow_time_extrapolation=None, - time_periodic=False, + mesh: Mesh = "spherical", + allow_time_extrapolation: bool | None = None, + time_periodic: TimePeriodic = False, tracer_interp_method="cgrid_tracer", gridindexingtype="nemo", chunksize=None, @@ -764,12 +769,12 @@ def from_c_grid_dataset( if "U" in dimensions and "V" in dimensions and dimensions["U"] != dimensions["V"]: raise ValueError( "On a C-grid, the dimensions of velocities should be the corners (f-points) of the cells, so the same for U and V. " - "See also ../examples/documentation_indexing.ipynb" + "See also https://docs.oceanparcels.org/en/latest/examples/documentation_indexing.html" ) if "U" in dimensions and "W" in dimensions and dimensions["U"] != dimensions["W"]: raise ValueError( "On a C-grid, the dimensions of velocities should be the corners (f-points) of the cells, so the same for U, V and W. " - "See also ../examples/documentation_indexing.ipynb" + "See also https://docs.oceanparcels.org/en/latest/examples/documentation_indexing.html" ) if "interp_method" in kwargs.keys(): raise TypeError("On a C-grid, the interpolation method for velocities should not be overridden") @@ -804,9 +809,9 @@ def from_pop( variables, dimensions, indices=None, - mesh="spherical", - allow_time_extrapolation=None, - time_periodic=False, + mesh: Mesh = "spherical", + allow_time_extrapolation: bool | None = None, + time_periodic: TimePeriodic = False, tracer_interp_method="bgrid_tracer", chunksize=None, depth_units="m", @@ -909,7 +914,7 @@ def from_pop( if hasattr(fieldset, "W"): if depth_units == "m": fieldset.W.set_scaling_factor(-0.01) # cm/s to m/s and change the W direction - logger.warning_once( + logger.warning_once( # type: ignore "Parcels assumes depth in POP output to be in 'm'. Use depth_units='cm' if the output depth is in 'cm'." ) elif depth_units == "cm": @@ -925,9 +930,9 @@ def from_mom5( variables, dimensions, indices=None, - mesh="spherical", - allow_time_extrapolation=None, - time_periodic=False, + mesh: Mesh = "spherical", + allow_time_extrapolation: bool | None = None, + time_periodic: TimePeriodic = False, tracer_interp_method="bgrid_tracer", chunksize=None, **kwargs, @@ -1049,9 +1054,9 @@ def from_b_grid_dataset( variables, dimensions, indices=None, - mesh="spherical", - allow_time_extrapolation=None, - time_periodic=False, + mesh: Mesh = "spherical", + allow_time_extrapolation: bool | None = None, + time_periodic: TimePeriodic = False, tracer_interp_method="bgrid_tracer", chunksize=None, **kwargs, @@ -1126,12 +1131,12 @@ def from_b_grid_dataset( if "U" in dimensions and "V" in dimensions and dimensions["U"] != dimensions["V"]: raise ValueError( "On a B-grid, the dimensions of velocities should be the (top) corners of the grid cells, so the same for U and V. " - "See also ../examples/documentation_indexing.ipynb" + "See also https://docs.oceanparcels.org/en/latest/examples/documentation_indexing.html" ) if "U" in dimensions and "W" in dimensions and dimensions["U"] != dimensions["W"]: raise ValueError( "On a B-grid, the dimensions of velocities should be the (top) corners of the grid cells, so the same for U, V and W. " - "See also ../examples/documentation_indexing.ipynb" + "See also https://docs.oceanparcels.org/en/latest/examples/documentation_indexing.html" ) interp_method = {} @@ -1166,8 +1171,8 @@ def from_parcels( vvar="vomecrty", indices=None, extra_fields=None, - allow_time_extrapolation=None, - time_periodic=False, + allow_time_extrapolation: bool | None = None, + time_periodic: TimePeriodic = False, deferred_load=True, chunksize=None, **kwargs, diff --git a/parcels/grid.py b/parcels/grid.py index ecdcc2086..c5630237f 100644 --- a/parcels/grid.py +++ b/parcels/grid.py @@ -1,12 +1,17 @@ import functools from ctypes import POINTER, Structure, c_double, c_float, c_int, c_void_p, cast, pointer from enum import IntEnum +from typing import TYPE_CHECKING import numpy as np +from parcels._typing import Mesh from parcels.tools.converters import TimeConverter from parcels.tools.loggers import logger +if TYPE_CHECKING: + import numpy.typing as npt + __all__ = [ "GridType", "GridCode", @@ -38,7 +43,14 @@ class CGrid(Structure): class Grid: """Grid class that defines a (spatial and temporal) grid on which Fields are defined.""" - def __init__(self, lon, lat, time, time_origin, mesh): + def __init__( + self, + lon: npt.NDArray, + lat: npt.NDArray, + time: npt.NDArray | None, + time_origin: TimeConverter | None, + mesh: Mesh, + ): self.xi = None self.yi = None self.zi = None @@ -66,7 +78,7 @@ def __init__(self, lon, lat, time, time_origin, mesh): assert isinstance(self.time_origin, TimeConverter), "time_origin needs to be a TimeConverter object" self.mesh = mesh self.cstruct = None - self.cell_edge_sizes = {} + self.cell_edge_sizes: dict[str, npt.NDArray] = {} self.zonal_periodic = False self.zonal_halo = 0 self.meridional_halo = 0 @@ -76,14 +88,22 @@ def __init__(self, lon, lat, time, time_origin, mesh): [np.nanmin(lon), np.nanmax(lon), np.nanmin(lat), np.nanmax(lat)], dtype=np.float32 ) self.periods = 0 - self.load_chunk = [] + self.load_chunk: npt.NDArray = np.array([]) self.chunk_info = None self.chunksize = None self._add_last_periodic_data_timestep = False self.depth_field = None @staticmethod - def create_grid(lon, lat, depth, time, time_origin, mesh, **kwargs): + def create_grid( + lon: npt.ArrayLike, + lat: npt.ArrayLike, + depth, + time, + time_origin, + mesh: Mesh, + **kwargs, + ): lon = np.array(lon) lat = np.array(lat) @@ -313,7 +333,7 @@ class RectilinearGrid(Grid): """ - def __init__(self, lon, lat, time, time_origin, mesh): + def __init__(self, lon, lat, time, time_origin, mesh: Mesh): assert isinstance(lon, np.ndarray) and len(lon.shape) <= 1, "lon is not a numpy vector" assert isinstance(lat, np.ndarray) and len(lat.shape) <= 1, "lat is not a numpy vector" assert isinstance(time, np.ndarray) or not time, "time is not a numpy array" @@ -327,7 +347,7 @@ def __init__(self, lon, lat, time, time_origin, mesh): if self.ydim > 1 and self.lat[-1] < self.lat[0]: self.lat = np.flip(self.lat, axis=0) self.lat_flipped = True - logger.warning_once( + logger.warning_once( # type: ignore "Flipping lat data from North-South to South-North. " "Note that this may lead to wrong sign for meridional velocity, so tread very carefully" ) @@ -396,7 +416,7 @@ class RectilinearZGrid(RectilinearGrid): 2. flat: No conversion, lat/lon are assumed to be in m. """ - def __init__(self, lon, lat, depth=None, time=None, time_origin=None, mesh="flat"): + def __init__(self, lon, lat, depth=None, time=None, time_origin=None, mesh: Mesh = "flat"): super().__init__(lon, lat, time, time_origin, mesh) if isinstance(depth, np.ndarray): assert len(depth.shape) <= 1, "depth is not a vector" @@ -442,7 +462,15 @@ class RectilinearSGrid(RectilinearGrid): 2. flat: No conversion, lat/lon are assumed to be in m. """ - def __init__(self, lon, lat, depth, time=None, time_origin=None, mesh="flat"): + def __init__( + self, + lon: npt.NDArray, + lat: npt.NDArray, + depth: npt.NDArray, + time: npt.NDArray | None = None, + time_origin: TimeConverter | None = None, + mesh: Mesh = "flat", + ): super().__init__(lon, lat, time, time_origin, mesh) assert isinstance(depth, np.ndarray) and len(depth.shape) in [3, 4], "depth is not a 3D or 4D numpy array" @@ -477,7 +505,14 @@ def __init__(self, lon, lat, depth, time=None, time_origin=None, mesh="flat"): class CurvilinearGrid(Grid): - def __init__(self, lon, lat, time=None, time_origin=None, mesh="flat"): + def __init__( + self, + lon: npt.NDArray, + lat: npt.NDArray, + time: npt.NDArray | None = None, + time_origin: TimeConverter | None = None, + mesh: Mesh = "flat", + ): assert isinstance(lon, np.ndarray) and len(lon.squeeze().shape) == 2, "lon is not a 2D numpy array" assert isinstance(lat, np.ndarray) and len(lat.squeeze().shape) == 2, "lat is not a 2D numpy array" assert isinstance(time, np.ndarray) or not time, "time is not a numpy array" @@ -574,7 +609,15 @@ class CurvilinearZGrid(CurvilinearGrid): 2. flat: No conversion, lat/lon are assumed to be in m. """ - def __init__(self, lon, lat, depth=None, time=None, time_origin=None, mesh="flat"): + def __init__( + self, + lon: npt.NDArray, + lat: npt.NDArray, + depth: npt.NDArray | None = None, + time: npt.NDArray | None = None, + time_origin: TimeConverter | None = None, + mesh: Mesh = "flat", + ): super().__init__(lon, lat, time, time_origin, mesh) if isinstance(depth, np.ndarray): assert len(depth.shape) == 1, "depth is not a vector" @@ -619,7 +662,15 @@ class CurvilinearSGrid(CurvilinearGrid): 2. flat: No conversion, lat/lon are assumed to be in m. """ - def __init__(self, lon, lat, depth, time=None, time_origin=None, mesh="flat"): + def __init__( + self, + lon: npt.NDArray, + lat: npt.NDArray, + depth: npt.NDArray, + time: npt.NDArray | None = None, + time_origin: TimeConverter | None = None, + mesh: Mesh = "flat", + ): super().__init__(lon, lat, time, time_origin, mesh) assert isinstance(depth, np.ndarray) and len(depth.shape) in [3, 4], "depth is not a 4D numpy array" diff --git a/parcels/kernel.py b/parcels/kernel.py index b144abc72..0adb77808 100644 --- a/parcels/kernel.py +++ b/parcels/kernel.py @@ -1,4 +1,5 @@ import _ctypes +import abc import ast import functools import hashlib @@ -45,7 +46,7 @@ __all__ = ["Kernel", "BaseKernel"] -class BaseKernel: +class BaseKernel(abc.ABC): """Superclass for 'normal' and Interactive Kernels""" def __init__( @@ -139,6 +140,9 @@ def remove_deleted(self, pset): self.fieldset.particlefile.write(pset, None, indices=indices) pset.remove_indices(indices) + @abc.abstractmethod + def get_kernel_compile_files(self): ... + class Kernel(BaseKernel): """Kernel object that encapsulates auto-generated code. diff --git a/pyproject.toml b/pyproject.toml index 613ff9978..11996b56e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -128,6 +128,8 @@ files = [ "parcels/_typing.py", "parcels/tools/*.py", "parcels/grid.py", + "parcels/field.py", + "parcels/fieldset.py", ] [[tool.mypy.overrides]] From 91192c75cbcc0eff1d8450ffe338f318360796b7 Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Mon, 26 Aug 2024 14:03:18 +0200 Subject: [PATCH 4/9] Add ruff badge --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index 511f8e5fc..1e155dcf7 100644 --- a/README.md +++ b/README.md @@ -2,6 +2,7 @@ [![Binder](https://mybinder.org/badge_logo.svg)](https://mybinder.org/v2/gh/OceanParcels/parcels/master?labpath=docs%2Fexamples%2Fparcels_tutorial.ipynb) [![unit-tests](https://github.com/OceanParcels/parcels/actions/workflows/ci.yml/badge.svg)](https://github.com/OceanParcels/parcels/actions/workflows/ci.yml) +[![Code style: Ruff](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/astral-sh/ruff/main/assets/badge/format.json)](https://github.com/astral-sh/ruff) [![codecov](https://codecov.io/gh/OceanParcels/parcels/branch/master/graph/badge.svg)](https://codecov.io/gh/OceanParcels/parcels) [![Anaconda-release](https://anaconda.org/conda-forge/parcels/badges/version.svg)](https://anaconda.org/conda-forge/parcels/) [![Anaconda-date](https://anaconda.org/conda-forge/parcels/badges/latest_release_date.svg)](https://anaconda.org/conda-forge/parcels/) From f112090a5058466180844da72f167b37acf81685 Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Mon, 26 Aug 2024 17:30:45 +0200 Subject: [PATCH 5/9] Add _compat.py file _compat.py file was mentioned in https://learn.scientific-python.org/development/patterns/backports/#placement-in-a-file . Thought it would be useful to consolidate MPI stuff --- parcels/_compat.py | 14 ++++++++++++++ parcels/compilation/codecompiler.py | 5 +---- parcels/fieldset.py | 7 +------ parcels/interaction/interactionkernel.py | 6 +----- parcels/kernel.py | 6 +----- parcels/particledata.py | 11 +---------- parcels/particlefile.py | 7 +------ parcels/particleset.py | 7 +------ parcels/tools/timer.py | 5 +---- 9 files changed, 22 insertions(+), 46 deletions(-) create mode 100644 parcels/_compat.py diff --git a/parcels/_compat.py b/parcels/_compat.py new file mode 100644 index 000000000..981ef30a5 --- /dev/null +++ b/parcels/_compat.py @@ -0,0 +1,14 @@ +"""Import helpers for compatability between installations.""" + +__all__ = ["MPI", "KMeans"] + +try: + from mpi4py import MPI +except ModuleNotFoundError: + MPI = None + +# KMeans is used in MPI. sklearn not installed by default +try: + from sklearn.cluster import KMeans +except ModuleNotFoundError: + KMeans = None diff --git a/parcels/compilation/codecompiler.py b/parcels/compilation/codecompiler.py index 406daf597..794d4cfd9 100644 --- a/parcels/compilation/codecompiler.py +++ b/parcels/compilation/codecompiler.py @@ -2,10 +2,7 @@ import subprocess from struct import calcsize -try: - from mpi4py import MPI -except ModuleNotFoundError: - MPI = None +from parcels._compat import MPI _tmp_dir = os.getcwd() diff --git a/parcels/fieldset.py b/parcels/fieldset.py index 58691e7a1..b47713b54 100644 --- a/parcels/fieldset.py +++ b/parcels/fieldset.py @@ -7,6 +7,7 @@ import dask.array as da import numpy as np +from parcels._compat import MPI from parcels._typing import Mesh, TimePeriodic from parcels.field import DeferredArray, Field, NestedField, VectorField from parcels.grid import Grid @@ -16,12 +17,6 @@ from parcels.tools.loggers import logger from parcels.tools.statuscodes import TimeExtrapolationError -try: - from mpi4py import MPI # pyright: ignore[reportMissingImports] -except ModuleNotFoundError: - MPI = None - - __all__ = ["FieldSet"] diff --git a/parcels/interaction/interactionkernel.py b/parcels/interaction/interactionkernel.py index c07af94d6..dce302e21 100644 --- a/parcels/interaction/interactionkernel.py +++ b/parcels/interaction/interactionkernel.py @@ -4,11 +4,7 @@ import numpy as np -try: - from mpi4py import MPI -except ModuleNotFoundError: - MPI = None - +from parcels._compat import MPI from parcels.field import NestedField, VectorField from parcels.kernel import BaseKernel from parcels.tools.loggers import logger diff --git a/parcels/kernel.py b/parcels/kernel.py index 0adb77808..fff3571d1 100644 --- a/parcels/kernel.py +++ b/parcels/kernel.py @@ -18,13 +18,9 @@ import numpy.ctypeslib as npct from numpy import ndarray -try: - from mpi4py import MPI -except ModuleNotFoundError: - MPI = None - import parcels.rng as ParcelsRandom # noqa from parcels import rng # noqa +from parcels._compat import MPI from parcels.application_kernels.advection import ( AdvectionAnalytical, AdvectionRK4_3D, diff --git a/parcels/particledata.py b/parcels/particledata.py index 1f32cfa26..2d5964f9c 100644 --- a/parcels/particledata.py +++ b/parcels/particledata.py @@ -3,19 +3,10 @@ import numpy as np +from parcels._compat import MPI, KMeans from parcels.tools.loggers import logger from parcels.tools.statuscodes import StatusCode -try: - from mpi4py import MPI -except ModuleNotFoundError: - MPI = None -if MPI: - try: - from sklearn.cluster import KMeans - except: - KMeans = None - def partitionParticlesMPI_default(coords, mpi_size=1): """This function takes the coordinates of the particle starting diff --git a/parcels/particlefile.py b/parcels/particlefile.py index 6e047130b..588e597d3 100644 --- a/parcels/particlefile.py +++ b/parcels/particlefile.py @@ -8,14 +8,9 @@ import zarr import parcels +from parcels._compat import MPI from parcels.tools.loggers import logger -try: - from mpi4py import MPI -except ModuleNotFoundError: - MPI = None - - __all__ = ["ParticleFile"] diff --git a/parcels/particleset.py b/parcels/particleset.py index c4007cf20..5d9c6a571 100644 --- a/parcels/particleset.py +++ b/parcels/particleset.py @@ -9,12 +9,7 @@ from scipy.spatial import KDTree from tqdm import tqdm -try: - from mpi4py import MPI -except ModuleNotFoundError: - MPI = None - - +from parcels._compat import MPI from parcels.application_kernels.advection import AdvectionRK4 from parcels.compilation.codecompiler import GNUCompiler from parcels.field import NestedField diff --git a/parcels/tools/timer.py b/parcels/tools/timer.py index 6aa7701e3..8896aa42a 100644 --- a/parcels/tools/timer.py +++ b/parcels/tools/timer.py @@ -1,10 +1,7 @@ import datetime import time -try: - from mpi4py import MPI -except ModuleNotFoundError: - MPI = None +from parcels._compat import MPI __all__ = [] # type: ignore From 5690b76f16ff3a52358033e98b885bc301050e13 Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Tue, 27 Aug 2024 11:58:29 +0200 Subject: [PATCH 6/9] Remove if TYPE_CHECKING block Need to investigate the use of `TYPE_CHECKING` as it causes runtime `NameErrors` even when only used as annotations. --- parcels/grid.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/parcels/grid.py b/parcels/grid.py index c5630237f..27cb3a285 100644 --- a/parcels/grid.py +++ b/parcels/grid.py @@ -1,17 +1,14 @@ import functools from ctypes import POINTER, Structure, c_double, c_float, c_int, c_void_p, cast, pointer from enum import IntEnum -from typing import TYPE_CHECKING import numpy as np +import numpy.typing as npt from parcels._typing import Mesh from parcels.tools.converters import TimeConverter from parcels.tools.loggers import logger -if TYPE_CHECKING: - import numpy.typing as npt - __all__ = [ "GridType", "GridCode", From 4dcd89eaf3a464c87dff9491cb9c0ea106bb39f2 Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Tue, 27 Aug 2024 13:24:16 +0200 Subject: [PATCH 7/9] Implement some suggestions from repo-review tool - Add concurrency for workflows (so repeated pushes on the same branch cancel previous workflow run) - Add dependabot for GHA version updating - Add prettier as formatter (removing Biome as it didn't format the needed files) - --show-fixes in ruff config Used https://learn.scientific-python.org/development/guides/repo-review/ to highlight suggested changes in project configuration. Some items (upload-artifact version change, nox/tox integration, spell checker, pytest options) weren't included as they add maintainence burden beyond their benefits, or hinder development (codespell). --- .github/dependabot.yml | 6 ++++++ .github/workflows/ci.yml | 6 ++++-- .pre-commit-config.yaml | 8 ++++---- codecov.yml | 1 - environment.yml | 1 - pyproject.toml | 1 - 6 files changed, 14 insertions(+), 9 deletions(-) create mode 100644 .github/dependabot.yml diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 000000000..8ac6b8c49 --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,6 @@ +version: 2 +updates: + - package-ecosystem: "github-actions" + directory: "/" + schedule: + interval: "monthly" diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 41b91bfe8..de52b7262 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -5,11 +5,13 @@ on: - "master" - "test-me/*" pull_request: - branches: - - "*" schedule: - cron: "0 7 * * 1" # Run every Monday at 7:00 UTC +concurrency: + group: branch-${{ github.head_ref }} + cancel-in-progress: true + defaults: run: shell: bash -el {0} diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index eeae58c68..15a452cc2 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -13,11 +13,11 @@ repos: rev: v0.5.6 hooks: - id: ruff - args: [ --fix ] + args: [--fix, --show-fixes] - id: ruff name: ruff (isort jupyter) args: [--select, I, --fix] - types_or: [ jupyter ] + types_or: [jupyter] - id: ruff-format types_or: [ python, jupyter ] - repo: https://github.com/biomejs/pre-commit @@ -31,9 +31,9 @@ repos: hooks: - id: flake8 name: pydoclint - files: 'none' + files: "none" # files: parcels/fieldset.py # put here instead of in config file due to https://github.com/pre-commit/pre-commit-hooks/issues/112#issuecomment-215613842 args: - - --select=DOC103 # TODO: Expand coverage to other codes + - --select=DOC103 # TODO: Expand coverage to other codes additional_dependencies: - pydoclint[flake8] diff --git a/codecov.yml b/codecov.yml index 3ba50f502..c136680e5 100644 --- a/codecov.yml +++ b/codecov.yml @@ -13,6 +13,5 @@ comment: require_base: false require_head: true hide_project_coverage: true - # When modifying this file, please validate using # curl -X POST --data-binary @codecov.yml https://codecov.io/validate diff --git a/environment.yml b/environment.yml index d096b3c3a..93906ce11 100644 --- a/environment.yml +++ b/environment.yml @@ -38,7 +38,6 @@ dependencies: - types-tqdm - types-psutil - # Linting - pre_commit diff --git a/pyproject.toml b/pyproject.toml index 11996b56e..134a6f8ff 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -54,7 +54,6 @@ local_scheme = "no-local-version" python_files = ["test_*.py", "example_*.py", "*tutorial*"] [tool.ruff] -target-version = "py310" line-length = 120 [tool.ruff.lint] From 5f85aa71ec7c6ddc68b846bb3f0331e6a6ebd3f0 Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Tue, 27 Aug 2024 14:19:09 +0200 Subject: [PATCH 8/9] Remove GHA keys (not supported by miniconda, only micromamba) and typing --- .github/actions/install-parcels/action.yml | 2 -- parcels/_compat.py | 13 +++++++++---- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/.github/actions/install-parcels/action.yml b/.github/actions/install-parcels/action.yml index a7aea0602..66a3bbccc 100644 --- a/.github/actions/install-parcels/action.yml +++ b/.github/actions/install-parcels/action.yml @@ -24,8 +24,6 @@ runs: environment-file: ${{ inputs.environment-file }} python-version: ${{ inputs.python-version }} channels: conda-forge - cache-environment: true - cache-downloads: true - name: MPI support if: ${{ ! (runner.os == 'Windows') }} run: conda install -c conda-forge mpich mpi4py diff --git a/parcels/_compat.py b/parcels/_compat.py index 981ef30a5..6efab15a7 100644 --- a/parcels/_compat.py +++ b/parcels/_compat.py @@ -2,13 +2,18 @@ __all__ = ["MPI", "KMeans"] +from typing import Any + +MPI: Any | None = None +KMeans: Any | None = None + try: - from mpi4py import MPI + from mpi4py import MPI # type: ignore[no-redef] except ModuleNotFoundError: - MPI = None + pass # KMeans is used in MPI. sklearn not installed by default try: - from sklearn.cluster import KMeans + from sklearn.cluster import KMeans # type: ignore[no-redef] except ModuleNotFoundError: - KMeans = None + pass From 6fa05152bab595439800a9161e5be89a799978fd Mon Sep 17 00:00:00 2001 From: Vecko <36369090+VeckoTheGecko@users.noreply.github.com> Date: Wed, 28 Aug 2024 10:39:57 +0200 Subject: [PATCH 9/9] Review suggestions --- README.md | 8 +++--- parcels/_typing.py | 53 +++++++++++++++++-------------------- parcels/field.py | 14 +++++----- parcels/fieldfilebuffer.py | 10 ++++++- parcels/fieldset.py | 18 ++++++------- parcels/tools/converters.py | 4 +-- 6 files changed, 55 insertions(+), 52 deletions(-) diff --git a/README.md b/README.md index 1e155dcf7..6ddac8f81 100644 --- a/README.md +++ b/README.md @@ -1,13 +1,13 @@ ## Parcels -[![Binder](https://mybinder.org/badge_logo.svg)](https://mybinder.org/v2/gh/OceanParcels/parcels/master?labpath=docs%2Fexamples%2Fparcels_tutorial.ipynb) -[![unit-tests](https://github.com/OceanParcels/parcels/actions/workflows/ci.yml/badge.svg)](https://github.com/OceanParcels/parcels/actions/workflows/ci.yml) -[![Code style: Ruff](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/astral-sh/ruff/main/assets/badge/format.json)](https://github.com/astral-sh/ruff) -[![codecov](https://codecov.io/gh/OceanParcels/parcels/branch/master/graph/badge.svg)](https://codecov.io/gh/OceanParcels/parcels) [![Anaconda-release](https://anaconda.org/conda-forge/parcels/badges/version.svg)](https://anaconda.org/conda-forge/parcels/) [![Anaconda-date](https://anaconda.org/conda-forge/parcels/badges/latest_release_date.svg)](https://anaconda.org/conda-forge/parcels/) [![Zenodo](https://zenodo.org/badge/DOI/10.5281/zenodo.823561.svg)](https://doi.org/10.5281/zenodo.823561) +[![Code style: Ruff](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/astral-sh/ruff/main/assets/badge/format.json)](https://github.com/astral-sh/ruff) +[![unit-tests](https://github.com/OceanParcels/parcels/actions/workflows/ci.yml/badge.svg)](https://github.com/OceanParcels/parcels/actions/workflows/ci.yml) +[![codecov](https://codecov.io/gh/OceanParcels/parcels/branch/master/graph/badge.svg)](https://codecov.io/gh/OceanParcels/parcels) [![CII Best Practices](https://bestpractices.coreinfrastructure.org/projects/5353/badge)](https://bestpractices.coreinfrastructure.org/projects/5353) +[![Binder](https://mybinder.org/badge_logo.svg)](https://mybinder.org/v2/gh/OceanParcels/parcels/master?labpath=docs%2Fexamples%2Fparcels_tutorial.ipynb) **Parcels** (**P**robably **A** **R**eally **C**omputationally **E**fficient **L**agrangian **S**imulator) is a set of Python classes and methods to create customisable particle tracking simulations using output from Ocean Circulation models. Parcels can be used to track passive and active particulates such as water, plankton, [plastic](http://www.topios.org/) and [fish](https://github.com/Jacketless/IKAMOANA). diff --git a/parcels/_typing.py b/parcels/_typing.py index f67192bcb..2e7ace119 100644 --- a/parcels/_typing.py +++ b/parcels/_typing.py @@ -9,42 +9,37 @@ import ast import datetime import os -from typing import Any, Callable, Literal, get_args +from typing import Callable, Literal class ParcelsAST(ast.AST): ccode: str -# InterpMethod = InterpMethodOption | dict[str, InterpMethodOption] # (can also be a dict, search for `if type(interp_method) is dict`) -# InterpMethodOption = Literal[ -# "nearest", -# "freeslip", -# "partialslip", -# "bgrid_velocity", -# "bgrid_w_velocity", -# "cgrid_velocity", -# "linear_invdist_land_tracer", -# "nearest", -# "cgrid_tracer", -# ] # mostly corresponds with `interp_method` # TODO: This should be narrowed. Unlikely applies to every context +InterpMethodOption = Literal[ + "linear", + "nearest", + "freeslip", + "partialslip", + "bgrid_velocity", + "bgrid_w_velocity", + "cgrid_velocity", + "linear_invdist_land_tracer", + "nearest", + "bgrid_tracer", + "cgrid_tracer", +] # corresponds with `tracer_interp_method` +InterpMethod = ( + InterpMethodOption | dict[str, InterpMethodOption] +) # corresponds with `interp_method` (which can also be dict mapping field names to method) PathLike = str | os.PathLike -Mesh = Literal["spherical", "flat"] # mostly corresponds with `mesh` -VectorType = Literal["3D", "2D"] | None # mostly corresponds with `vector_type` -ChunkMode = Literal["auto", "specific", "failsafe"] # mostly corresponds with `chunk_mode` -GridIndexingType = Literal["pop", "mom5", "mitgcm", "nemo"] # mostly corresponds with `grid_indexing_type` -UpdateStatus = Literal["not_updated", "first_updated", "updated"] # mostly corresponds with `update_status` -TimePeriodic = float | datetime.timedelta | Literal[False] # mostly corresponds with `update_status` -NetcdfEngine = Literal[ - "netcdf4", "xarray" -] # TODO: It seems that "scipy" is also an option (according to a docstring) but can't find mention in code. Investigate. +Mesh = Literal["spherical", "flat"] # corresponds with `mesh` +VectorType = Literal["3D", "2D"] | None # corresponds with `vector_type` +ChunkMode = Literal["auto", "specific", "failsafe"] # corresponds with `chunk_mode` +GridIndexingType = Literal["pop", "mom5", "mitgcm", "nemo"] # corresponds with `grid_indexing_type` +UpdateStatus = Literal["not_updated", "first_updated", "updated"] # corresponds with `update_status` +TimePeriodic = float | datetime.timedelta | Literal[False] # corresponds with `update_status` +NetcdfEngine = Literal["netcdf4", "xarray", "scipy"] KernelFunction = Callable[..., None] - - -def ensure_is_literal_value(value: Any, literal: Any) -> None: - """Ensures that a value is a valid option for the provided Literal type annotation.""" - valid_options = get_args(literal) - if value not in valid_options: - raise ValueError(f"{value!r} is not a valid option. Valid options are {valid_options}") diff --git a/parcels/field.py b/parcels/field.py index 7495ccdd1..aa36fad46 100644 --- a/parcels/field.py +++ b/parcels/field.py @@ -10,7 +10,7 @@ import xarray as xr import parcels.tools.interpolation_utils as i_u -from parcels._typing import GridIndexingType, Mesh, TimePeriodic, VectorType +from parcels._typing import GridIndexingType, InterpMethod, Mesh, TimePeriodic, VectorType from parcels.tools.converters import ( Geographic, GeographicPolar, @@ -156,7 +156,7 @@ def __init__( vmax=None, cast_data_dtype="float32", time_origin=None, - interp_method="linear", + interp_method: InterpMethod = "linear", allow_time_extrapolation: bool | None = None, time_periodic: TimePeriodic = False, gridindexingtype: GridIndexingType = "nemo", @@ -199,7 +199,7 @@ def __init__( else: raise ValueError("Unsupported mesh type. Choose either: 'spherical' or 'flat'") self.timestamps = timestamps - if type(interp_method) is dict: + if isinstance(interp_method, dict): if self.name in interp_method: self.interp_method = interp_method[self.name] else: @@ -215,7 +215,7 @@ def __init__( "General s-levels are not supported in B-grid. RectilinearSGrid and CurvilinearSGrid can still be used to deal with shaved cells, but the levels must be horizontal." ) - self.fieldset: FieldSet | None = None + self.fieldset: "FieldSet" | None = None if allow_time_extrapolation is None: self.allow_time_extrapolation = True if len(self.grid.time) == 1 else False else: @@ -292,7 +292,7 @@ def __init__( # since some datasets do not provide the deeper level of data (which is ignored by the interpolation). self.data_full_zdim = kwargs.pop("data_full_zdim", None) self.data_chunks = [] # type: ignore # the data buffer of the FileBuffer raw loaded data - shall be a list of C-contiguous arrays - self.c_data_chunks: list[PointerType | None] = [] # C-pointers to the data_chunks array + self.c_data_chunks: list["PointerType" | None] = [] # C-pointers to the data_chunks array self.nchunks: tuple[int, ...] = () self.chunk_set: bool = False self.filebuffers = [None] * 2 @@ -489,7 +489,7 @@ def from_netcdf( "if you would need such feature implemented." ) - interp_method = kwargs.pop("interp_method", "linear") + interp_method: InterpMethod = kwargs.pop("interp_method", "linear") if type(interp_method) is dict: if variable[0] in interp_method: interp_method = interp_method[variable[0]] @@ -871,7 +871,7 @@ def search_indices_vertical_z(self, z): return (zi, zeta) def search_indices_vertical_s( - self, x: float, y: float, z: float, xi: int, yi: int, xsi: float, eta: float, ti: int, time + self, x: float, y: float, z: float, xi: int, yi: int, xsi: float, eta: float, ti: int, time: float ): grid = self.grid if self.interp_method in ["bgrid_velocity", "bgrid_w_velocity", "bgrid_tracer"]: diff --git a/parcels/fieldfilebuffer.py b/parcels/fieldfilebuffer.py index dd702270a..c3b45fc51 100644 --- a/parcels/fieldfilebuffer.py +++ b/parcels/fieldfilebuffer.py @@ -9,6 +9,7 @@ from dask import utils as da_utils from netCDF4 import Dataset as ncDataset +from parcels._typing import InterpMethodOption from parcels.tools.converters import convert_xarray_time_units from parcels.tools.loggers import logger from parcels.tools.statuscodes import DaskChunkingError @@ -16,7 +17,14 @@ class _FileBuffer: def __init__( - self, filename, dimensions, indices, timestamp=None, interp_method="linear", data_full_zdim=None, **kwargs + self, + filename, + dimensions, + indices, + timestamp=None, + interp_method: InterpMethodOption = "linear", + data_full_zdim=None, + **kwargs, ): self.filename = filename self.dimensions = dimensions # Dict with dimension keys for file data diff --git a/parcels/fieldset.py b/parcels/fieldset.py index b47713b54..16b690f5c 100644 --- a/parcels/fieldset.py +++ b/parcels/fieldset.py @@ -8,7 +8,7 @@ import numpy as np from parcels._compat import MPI -from parcels._typing import Mesh, TimePeriodic +from parcels._typing import GridIndexingType, InterpMethodOption, Mesh, TimePeriodic from parcels.field import DeferredArray, Field, NestedField, VectorField from parcels.grid import Grid from parcels.gridset import GridSet @@ -415,7 +415,7 @@ def from_netcdf( ``{parcels_varname: {netcdf_dimname : (parcels_dimname, chunksize_as_int)}, ...}``, where ``parcels_dimname`` is one of ('time', 'depth', 'lat', 'lon') netcdf_engine : engine to use for netcdf reading in xarray. Default is 'netcdf', - but in cases where this doesn't work, setting netcdf_engine='scipy' could help + but in cases where this doesn't work, setting netcdf_engine='scipy' could help. Accepted options are the same as the ``engine`` parameter in ``xarray.open_dataset()``. **kwargs : Keyword arguments passed to the :class:`parcels.Field` constructor. @@ -524,7 +524,7 @@ def from_nemo( mesh: Mesh = "spherical", allow_time_extrapolation: bool | None = None, time_periodic: TimePeriodic = False, - tracer_interp_method="cgrid_tracer", + tracer_interp_method: InterpMethodOption = "cgrid_tracer", chunksize=None, **kwargs, ): @@ -635,7 +635,7 @@ def from_mitgcm( mesh: Mesh = "spherical", allow_time_extrapolation: bool | None = None, time_periodic: TimePeriodic = False, - tracer_interp_method="cgrid_tracer", + tracer_interp_method: InterpMethodOption = "cgrid_tracer", chunksize=None, **kwargs, ): @@ -685,8 +685,8 @@ def from_c_grid_dataset( mesh: Mesh = "spherical", allow_time_extrapolation: bool | None = None, time_periodic: TimePeriodic = False, - tracer_interp_method="cgrid_tracer", - gridindexingtype="nemo", + tracer_interp_method: InterpMethodOption = "cgrid_tracer", + gridindexingtype: GridIndexingType = "nemo", chunksize=None, **kwargs, ): @@ -807,7 +807,7 @@ def from_pop( mesh: Mesh = "spherical", allow_time_extrapolation: bool | None = None, time_periodic: TimePeriodic = False, - tracer_interp_method="bgrid_tracer", + tracer_interp_method: InterpMethodOption = "bgrid_tracer", chunksize=None, depth_units="m", **kwargs, @@ -928,7 +928,7 @@ def from_mom5( mesh: Mesh = "spherical", allow_time_extrapolation: bool | None = None, time_periodic: TimePeriodic = False, - tracer_interp_method="bgrid_tracer", + tracer_interp_method: InterpMethodOption = "bgrid_tracer", chunksize=None, **kwargs, ): @@ -1052,7 +1052,7 @@ def from_b_grid_dataset( mesh: Mesh = "spherical", allow_time_extrapolation: bool | None = None, time_periodic: TimePeriodic = False, - tracer_interp_method="bgrid_tracer", + tracer_interp_method: InterpMethodOption = "bgrid_tracer", chunksize=None, **kwargs, ): diff --git a/parcels/tools/converters.py b/parcels/tools/converters.py index f657f593b..de6705fac 100644 --- a/parcels/tools/converters.py +++ b/parcels/tools/converters.py @@ -5,8 +5,8 @@ import cftime import numpy as np +import numpy.typing as npt import xarray as xr -from numpy.typing import ArrayLike, NDArray __all__ = [ "UnitConverter", @@ -21,7 +21,7 @@ ] -def convert_to_flat_array(var: ArrayLike) -> NDArray: +def convert_to_flat_array(var: npt.ArrayLike) -> npt.NDArray: """Convert lists and single integers/floats to one-dimensional numpy arrays Parameters