diff --git a/.github/actions/install-parcels/action.yml b/.github/actions/install-parcels/action.yml index a7aea0602..66a3bbccc 100644 --- a/.github/actions/install-parcels/action.yml +++ b/.github/actions/install-parcels/action.yml @@ -24,8 +24,6 @@ runs: environment-file: ${{ inputs.environment-file }} python-version: ${{ inputs.python-version }} channels: conda-forge - cache-environment: true - cache-downloads: true - name: MPI support if: ${{ ! (runner.os == 'Windows') }} run: conda install -c conda-forge mpich mpi4py diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 000000000..8ac6b8c49 --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,6 @@ +version: 2 +updates: + - package-ecosystem: "github-actions" + directory: "/" + schedule: + interval: "monthly" diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 6c0a62e3a..de52b7262 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -5,11 +5,13 @@ on: - "master" - "test-me/*" pull_request: - branches: - - "*" schedule: - cron: "0 7 * * 1" # Run every Monday at 7:00 UTC +concurrency: + group: branch-${{ github.head_ref }} + cancel-in-progress: true + defaults: run: shell: bash -el {0} @@ -81,3 +83,25 @@ jobs: with: name: Integration test report path: ${{ matrix.os }}_integration_test_report.html + typechecking: + name: mypy + runs-on: ubuntu-latest + steps: + - name: Checkout + uses: actions/checkout@v3 + - name: Setup Conda and parcels + uses: ./.github/actions/install-parcels + with: + environment-file: environment.yml + - run: conda install lxml # dep for report generation + - name: Typechecking + run: | + mypy --install-types --non-interactive parcels --cobertura-xml-report mypy_report + - name: Upload mypy coverage to Codecov + uses: codecov/codecov-action@v3.1.1 + if: ${{ always() }} # Upload even on error of mypy + with: + token: ${{ secrets.CODECOV_TOKEN }} + file: mypy_report/cobertura.xml + flags: mypy + fail_ci_if_error: false diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index ccc8a6e5f..15a452cc2 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -13,14 +13,27 @@ repos: rev: v0.5.6 hooks: - id: ruff - args: [ --fix ] + args: [--fix, --show-fixes] - id: ruff name: ruff (isort jupyter) args: [--select, I, --fix] - types_or: [ jupyter ] + types_or: [jupyter] - id: ruff-format types_or: [ python, jupyter ] - repo: https://github.com/biomejs/pre-commit rev: v0.4.0 hooks: - id: biome-format + + # Ruff doesn't have full coverage of pydoclint https://github.com/astral-sh/ruff/issues/12434 + - repo: https://github.com/PyCQA/flake8 + rev: 7.0.0 + hooks: + - id: flake8 + name: pydoclint + files: "none" + # files: parcels/fieldset.py # put here instead of in config file due to https://github.com/pre-commit/pre-commit-hooks/issues/112#issuecomment-215613842 + args: + - --select=DOC103 # TODO: Expand coverage to other codes + additional_dependencies: + - pydoclint[flake8] diff --git a/README.md b/README.md index 10e6ae196..6ddac8f81 100644 --- a/README.md +++ b/README.md @@ -1,12 +1,13 @@ ## Parcels -[![Binder](https://mybinder.org/badge_logo.svg)](https://mybinder.org/v2/gh/OceanParcels/parcels/master?labpath=docs%2Fexamples%2Fparcels_tutorial.ipynb) -[![unit-tests](https://github.com/OceanParcels/parcels/actions/workflows/unit-tests.yml/badge.svg)](https://github.com/OceanParcels/parcels/actions/workflows/unit-tests.yml) -[![codecov](https://codecov.io/gh/OceanParcels/parcels/branch/master/graph/badge.svg)](https://codecov.io/gh/OceanParcels/parcels) [![Anaconda-release](https://anaconda.org/conda-forge/parcels/badges/version.svg)](https://anaconda.org/conda-forge/parcels/) [![Anaconda-date](https://anaconda.org/conda-forge/parcels/badges/latest_release_date.svg)](https://anaconda.org/conda-forge/parcels/) [![Zenodo](https://zenodo.org/badge/DOI/10.5281/zenodo.823561.svg)](https://doi.org/10.5281/zenodo.823561) +[![Code style: Ruff](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/astral-sh/ruff/main/assets/badge/format.json)](https://github.com/astral-sh/ruff) +[![unit-tests](https://github.com/OceanParcels/parcels/actions/workflows/ci.yml/badge.svg)](https://github.com/OceanParcels/parcels/actions/workflows/ci.yml) +[![codecov](https://codecov.io/gh/OceanParcels/parcels/branch/master/graph/badge.svg)](https://codecov.io/gh/OceanParcels/parcels) [![CII Best Practices](https://bestpractices.coreinfrastructure.org/projects/5353/badge)](https://bestpractices.coreinfrastructure.org/projects/5353) +[![Binder](https://mybinder.org/badge_logo.svg)](https://mybinder.org/v2/gh/OceanParcels/parcels/master?labpath=docs%2Fexamples%2Fparcels_tutorial.ipynb) **Parcels** (**P**robably **A** **R**eally **C**omputationally **E**fficient **L**agrangian **S**imulator) is a set of Python classes and methods to create customisable particle tracking simulations using output from Ocean Circulation models. Parcels can be used to track passive and active particulates such as water, plankton, [plastic](http://www.topios.org/) and [fish](https://github.com/Jacketless/IKAMOANA). diff --git a/codecov.yml b/codecov.yml index 3ba50f502..c136680e5 100644 --- a/codecov.yml +++ b/codecov.yml @@ -13,6 +13,5 @@ comment: require_base: false require_head: true hide_project_coverage: true - # When modifying this file, please validate using # curl -X POST --data-binary @codecov.yml https://codecov.io/validate diff --git a/environment.yml b/environment.yml index e244acca1..93906ce11 100644 --- a/environment.yml +++ b/environment.yml @@ -33,10 +33,13 @@ dependencies: - pytest-html - coverage + # Typing + - mypy + - types-tqdm + - types-psutil + # Linting - - flake8>=2.1.0 - pre_commit - - pydocstyle # Docs - ipython diff --git a/parcels/_compat.py b/parcels/_compat.py new file mode 100644 index 000000000..6efab15a7 --- /dev/null +++ b/parcels/_compat.py @@ -0,0 +1,19 @@ +"""Import helpers for compatability between installations.""" + +__all__ = ["MPI", "KMeans"] + +from typing import Any + +MPI: Any | None = None +KMeans: Any | None = None + +try: + from mpi4py import MPI # type: ignore[no-redef] +except ModuleNotFoundError: + pass + +# KMeans is used in MPI. sklearn not installed by default +try: + from sklearn.cluster import KMeans # type: ignore[no-redef] +except ModuleNotFoundError: + pass diff --git a/parcels/_typing.py b/parcels/_typing.py new file mode 100644 index 000000000..2e7ace119 --- /dev/null +++ b/parcels/_typing.py @@ -0,0 +1,45 @@ +""" +Typing support for Parcels. + +This module contains type aliases used throughout Parcels as well as functions that are +used for runtime parameter validation (to ensure users are only using the right params). + +""" + +import ast +import datetime +import os +from typing import Callable, Literal + + +class ParcelsAST(ast.AST): + ccode: str + + +InterpMethodOption = Literal[ + "linear", + "nearest", + "freeslip", + "partialslip", + "bgrid_velocity", + "bgrid_w_velocity", + "cgrid_velocity", + "linear_invdist_land_tracer", + "nearest", + "bgrid_tracer", + "cgrid_tracer", +] # corresponds with `tracer_interp_method` +InterpMethod = ( + InterpMethodOption | dict[str, InterpMethodOption] +) # corresponds with `interp_method` (which can also be dict mapping field names to method) +PathLike = str | os.PathLike +Mesh = Literal["spherical", "flat"] # corresponds with `mesh` +VectorType = Literal["3D", "2D"] | None # corresponds with `vector_type` +ChunkMode = Literal["auto", "specific", "failsafe"] # corresponds with `chunk_mode` +GridIndexingType = Literal["pop", "mom5", "mitgcm", "nemo"] # corresponds with `grid_indexing_type` +UpdateStatus = Literal["not_updated", "first_updated", "updated"] # corresponds with `update_status` +TimePeriodic = float | datetime.timedelta | Literal[False] # corresponds with `update_status` +NetcdfEngine = Literal["netcdf4", "xarray", "scipy"] + + +KernelFunction = Callable[..., None] diff --git a/parcels/compilation/codecompiler.py b/parcels/compilation/codecompiler.py index 406daf597..794d4cfd9 100644 --- a/parcels/compilation/codecompiler.py +++ b/parcels/compilation/codecompiler.py @@ -2,10 +2,7 @@ import subprocess from struct import calcsize -try: - from mpi4py import MPI -except ModuleNotFoundError: - MPI = None +from parcels._compat import MPI _tmp_dir = os.getcwd() diff --git a/parcels/compilation/codegenerator.py b/parcels/compilation/codegenerator.py index 2569dcc6c..921b422cb 100644 --- a/parcels/compilation/codegenerator.py +++ b/parcels/compilation/codegenerator.py @@ -410,7 +410,7 @@ class KernelGenerator(ABC, ast.NodeVisitor): # Intrinsic variables that appear as function arguments kernel_vars = ["particle", "fieldset", "time", "output_time", "tol"] - array_vars = [] + array_vars: list[str] = [] def __init__(self, fieldset=None, ptype=JITParticle): self.fieldset = fieldset @@ -419,7 +419,7 @@ def __init__(self, fieldset=None, ptype=JITParticle): self.vector_field_args = collections.OrderedDict() self.const_args = collections.OrderedDict() - def generate(self, py_ast, funcvars): + def generate(self, py_ast, funcvars: list[str]): # Replace occurrences of intrinsic objects in Python AST transformer = IntrinsicTransformer(self.fieldset, self.ptype) py_ast = transformer.visit(py_ast) @@ -434,7 +434,7 @@ def generate(self, py_ast, funcvars): # Insert variable declarations for non-intrinsic variables # Make sure that repeated variables are not declared more than # once. If variables occur in multiple Kernels, give a warning - used_vars = [] + used_vars: list[str] = [] funcvars_copy = copy(funcvars) # editing a list while looping over it is dangerous for kvar in funcvars: if kvar in used_vars + ["particle_dlon", "particle_dlat", "particle_ddepth"]: diff --git a/parcels/field.py b/parcels/field.py index 651c62701..aa36fad46 100644 --- a/parcels/field.py +++ b/parcels/field.py @@ -3,12 +3,14 @@ import math from ctypes import POINTER, Structure, c_float, c_int, pointer from pathlib import Path +from typing import TYPE_CHECKING, Iterable, Type import dask.array as da import numpy as np import xarray as xr import parcels.tools.interpolation_utils as i_u +from parcels._typing import GridIndexingType, InterpMethod, Mesh, TimePeriodic, VectorType from parcels.tools.converters import ( Geographic, GeographicPolar, @@ -33,6 +35,11 @@ ) from .grid import CGrid, Grid, GridType +if TYPE_CHECKING: + from ctypes import _Pointer as PointerType + + from parcels.fieldset import FieldSet + __all__ = ["Field", "VectorField", "NestedField"] @@ -43,7 +50,7 @@ def _isParticle(key): return False -def _deal_with_errors(error, key, vector_type): +def _deal_with_errors(error, key, vector_type: VectorType): if _isParticle(key): key.state = AllParcelsErrorCodes[type(error)] elif _isParticle(key[-1]): @@ -134,14 +141,14 @@ class Field: def __init__( self, - name, + name: str | tuple[str, str], data, lon=None, lat=None, depth=None, time=None, grid=None, - mesh="flat", + mesh: Mesh = "flat", timestamps=None, fieldtype=None, transpose=False, @@ -149,10 +156,10 @@ def __init__( vmax=None, cast_data_dtype="float32", time_origin=None, - interp_method="linear", - allow_time_extrapolation=None, - time_periodic=False, - gridindexingtype="nemo", + interp_method: InterpMethod = "linear", + allow_time_extrapolation: bool | None = None, + time_periodic: TimePeriodic = False, + gridindexingtype: GridIndexingType = "nemo", to_write=False, **kwargs, ): @@ -160,7 +167,8 @@ def __init__( self.name = name self.filebuffername = name else: - self.name, self.filebuffername = name + self.name = name[0] + self.filebuffername = name[1] self.data = data if grid: if grid.defer_load and isinstance(data, np.ndarray): @@ -191,7 +199,7 @@ def __init__( else: raise ValueError("Unsupported mesh type. Choose either: 'spherical' or 'flat'") self.timestamps = timestamps - if type(interp_method) is dict: + if isinstance(interp_method, dict): if self.name in interp_method: self.interp_method = interp_method[self.name] else: @@ -203,11 +211,11 @@ def __init__( GridType.RectilinearSGrid, GridType.CurvilinearSGrid, ]: - logger.warning_once( + logger.warning_once( # type: ignore "General s-levels are not supported in B-grid. RectilinearSGrid and CurvilinearSGrid can still be used to deal with shaved cells, but the levels must be horizontal." ) - self.fieldset = None + self.fieldset: "FieldSet" | None = None if allow_time_extrapolation is None: self.allow_time_extrapolation = True if len(self.grid.time) == 1 else False else: @@ -215,7 +223,7 @@ def __init__( self.time_periodic = time_periodic if self.time_periodic is not False and self.allow_time_extrapolation: - logger.warning_once( + logger.warning_once( # type: ignore "allow_time_extrapolation and time_periodic cannot be used together.\n \ allow_time_extrapolation is set to False" ) @@ -268,8 +276,8 @@ def __init__( self._field_fb_class = kwargs.pop("FieldFileBuffer", None) self.netcdf_engine = kwargs.pop("netcdf_engine", "netcdf4") self.netcdf_decodewarning = kwargs.pop("netcdf_decodewarning", True) - self.loaded_time_indices = [] - self.creation_log = kwargs.pop("creation_log", "") + self.loaded_time_indices: Iterable[int] = [] + self.creation_log: str = kwargs.pop("creation_log", "") self.chunksize = kwargs.pop("chunksize", None) self.netcdf_chunkdims_name_map = kwargs.pop("chunkdims_name_map", None) self.grid.depth_field = kwargs.pop("depth_field", None) @@ -283,10 +291,10 @@ def __init__( # (data_full_zdim = grid.zdim if no indices are used, for A- and C-grids and for some B-grids). It is used for the B-grid, # since some datasets do not provide the deeper level of data (which is ignored by the interpolation). self.data_full_zdim = kwargs.pop("data_full_zdim", None) - self.data_chunks = [] # the data buffer of the FileBuffer raw loaded data - shall be a list of C-contiguous arrays - self.c_data_chunks = [] # C-pointers to the data_chunks array - self.nchunks = [] - self.chunk_set = False + self.data_chunks = [] # type: ignore # the data buffer of the FileBuffer raw loaded data - shall be a list of C-contiguous arrays + self.c_data_chunks: list["PointerType" | None] = [] # C-pointers to the data_chunks array + self.nchunks: tuple[int, ...] = () + self.chunk_set: bool = False self.filebuffers = [None] * 2 if len(kwargs) > 0: raise SyntaxError(f'Field received an unexpected keyword argument "{list(kwargs.keys())[0]}"') @@ -349,13 +357,13 @@ def from_netcdf( dimensions, indices=None, grid=None, - mesh="spherical", + mesh: Mesh = "spherical", timestamps=None, - allow_time_extrapolation=None, - time_periodic=False, - deferred_load=True, + allow_time_extrapolation: bool | None = None, + time_periodic: TimePeriodic = False, + deferred_load: bool = True, **kwargs, - ): + ) -> "Field": """Create field from netCDF file. Parameters @@ -481,7 +489,7 @@ def from_netcdf( "if you would need such feature implemented." ) - interp_method = kwargs.pop("interp_method", "linear") + interp_method: InterpMethod = kwargs.pop("interp_method", "linear") if type(interp_method) is dict: if variable[0] in interp_method: interp_method = interp_method[variable[0]] @@ -542,11 +550,11 @@ def from_netcdf( ) kwargs["dataFiles"] = dataFiles - chunksize = kwargs.get("chunksize", None) + chunksize: bool | None = kwargs.get("chunksize", None) grid.chunksize = chunksize if "time" in indices: - logger.warning_once("time dimension in indices is not necessary anymore. It is then ignored.") + logger.warning_once("time dimension in indices is not necessary anymore. It is then ignored.") # type: ignore if "full_load" in kwargs: # for backward compatibility with Parcels < v2.0.0 deferred_load = not kwargs["full_load"] @@ -554,6 +562,7 @@ def from_netcdf( if grid.time.size <= 2 or deferred_load is False: deferred_load = False + _field_fb_class: Type[DeferredDaskFileBuffer | DaskFileBuffer | DeferredNetcdfFileBuffer | NetcdfFileBuffer] if chunksize not in [False, None]: if deferred_load: _field_fb_class = DeferredDaskFileBuffer @@ -570,7 +579,7 @@ def from_netcdf( data_list = [] ti = 0 for tslice, fname in zip(grid.timeslices, data_filenames, strict=True): - with _field_fb_class( + with _field_fb_class( # type: ignore[operator] fname, dimensions, indices, @@ -637,7 +646,14 @@ def from_netcdf( @classmethod def from_xarray( - cls, da, name, dimensions, mesh="spherical", allow_time_extrapolation=None, time_periodic=False, **kwargs + cls, + da: xr.DataArray, + name: str, + dimensions, + mesh: Mesh = "spherical", + allow_time_extrapolation: bool | None = None, + time_periodic: TimePeriodic = False, + **kwargs, ): """Create field from xarray Variable. @@ -854,7 +870,9 @@ def search_indices_vertical_z(self, z): zeta = (z - grid.depth[zi]) / (grid.depth[zi + 1] - grid.depth[zi]) return (zi, zeta) - def search_indices_vertical_s(self, x, y, z, xi, yi, xsi, eta, ti, time): + def search_indices_vertical_s( + self, x: float, y: float, z: float, xi: int, yi: int, xsi: float, eta: float, ti: int, time: float + ): grid = self.grid if self.interp_method in ["bgrid_velocity", "bgrid_w_velocity", "bgrid_tracer"]: xsi = 1 @@ -886,7 +904,7 @@ def search_indices_vertical_s(self, x, y, z, xi, yi, xsi, eta, ti, time): + xsi * eta * grid.depth[:, yi + 1, xi + 1] + (1 - xsi) * eta * grid.depth[:, yi + 1, xi] ) - z = np.float32(z) + z = np.float32(z) # type: ignore # TODO: remove type ignore once we migrate to float64 if depth_vector[-1] > depth_vector[0]: depth_indices = depth_vector <= z @@ -930,7 +948,7 @@ def reconnect_bnd_indices(self, xi, yi, xdim, ydim, sphere_mesh): xi = xdim - xi return xi, yi - def search_indices_rectilinear(self, x, y, z, ti=-1, time=-1, particle=None, search2D=False): + def search_indices_rectilinear(self, x: float, y: float, z: float, ti=-1, time=-1, particle=None, search2D=False): grid = self.grid if grid.xdim > 1 and (not grid.zonal_periodic): @@ -1684,18 +1702,18 @@ class VectorField: field defining the vertical component (default: None) """ - def __init__(self, name, U, V, W=None): + def __init__(self, name: str, U: Field, V: Field, W: Field | None = None): self.name = name self.U = U self.V = V self.W = W - self.vector_type = "3D" if W else "2D" + self.vector_type: VectorType = "3D" if W else "2D" self.gridindexingtype = U.gridindexingtype if self.U.interp_method == "cgrid_velocity": assert self.V.interp_method == "cgrid_velocity", "Interpolation methods of U and V are not the same." assert self._check_grid_dimensions(U.grid, V.grid), "Dimensions of U and V are not the same." - if self.vector_type == "3D": - assert self.W.interp_method == "cgrid_velocity", "Interpolation methods of U and W are not the same." + if W is not None: + assert W.interp_method == "cgrid_velocity", "Interpolation methods of U and W are not the same." assert self._check_grid_dimensions(U.grid, W.grid), "Dimensions of U and W are not the same." @staticmethod @@ -1707,7 +1725,7 @@ def _check_grid_dimensions(grid1, grid2): and np.allclose(grid1.time_full, grid2.time_full) ) - def dist(self, lon1, lon2, lat1, lat2, mesh, lat): + def dist(self, lon1: float, lon2: float, lat1: float, lat2: float, mesh: Mesh, lat: float): if mesh == "spherical": rad = np.pi / 180.0 deg2m = 1852 * 60.0 @@ -1715,7 +1733,7 @@ def dist(self, lon1, lon2, lat1, lat2, mesh, lat): else: return np.sqrt((lon2 - lon1) ** 2 + (lat2 - lat1) ** 2) - def jacobian(self, xsi, eta, px, py): + def jacobian(self, xsi: float, eta: float, px: np.ndarray, py: np.ndarray): dphidxsi = [eta - 1, 1 - eta, eta, -eta] dphideta = [xsi - 1, -xsi, xsi, 1 - xsi] @@ -2285,7 +2303,7 @@ class NestedField(list): """ - def __init__(self, name, F, V=None, W=None): + def __init__(self, name: str, F, V=None, W=None): if V is None: if isinstance(F[0], VectorField): vector_type = F[0].vector_type diff --git a/parcels/fieldfilebuffer.py b/parcels/fieldfilebuffer.py index dd702270a..c3b45fc51 100644 --- a/parcels/fieldfilebuffer.py +++ b/parcels/fieldfilebuffer.py @@ -9,6 +9,7 @@ from dask import utils as da_utils from netCDF4 import Dataset as ncDataset +from parcels._typing import InterpMethodOption from parcels.tools.converters import convert_xarray_time_units from parcels.tools.loggers import logger from parcels.tools.statuscodes import DaskChunkingError @@ -16,7 +17,14 @@ class _FileBuffer: def __init__( - self, filename, dimensions, indices, timestamp=None, interp_method="linear", data_full_zdim=None, **kwargs + self, + filename, + dimensions, + indices, + timestamp=None, + interp_method: InterpMethodOption = "linear", + data_full_zdim=None, + **kwargs, ): self.filename = filename self.dimensions = dimensions # Dict with dimension keys for file data diff --git a/parcels/fieldset.py b/parcels/fieldset.py index fda0d4701..16b690f5c 100644 --- a/parcels/fieldset.py +++ b/parcels/fieldset.py @@ -7,19 +7,16 @@ import dask.array as da import numpy as np +from parcels._compat import MPI +from parcels._typing import GridIndexingType, InterpMethodOption, Mesh, TimePeriodic from parcels.field import DeferredArray, Field, NestedField, VectorField from parcels.grid import Grid from parcels.gridset import GridSet +from parcels.particlefile import ParticleFile from parcels.tools.converters import TimeConverter, convert_xarray_time_units from parcels.tools.loggers import logger from parcels.tools.statuscodes import TimeExtrapolationError -try: - from mpi4py import MPI -except ModuleNotFoundError: - MPI = None - - __all__ = ["FieldSet"] @@ -37,13 +34,14 @@ class FieldSet: in custom kernels. """ - def __init__(self, U, V, fields=None): + def __init__(self, U: Field | NestedField | None, V: Field | NestedField | None, fields=None): self.gridset = GridSet() - self.completed = False - self.particlefile = None + self.completed: bool = False + self.particlefile: ParticleFile | None = None if U: self.add_field(U, "U") - self.time_origin = self.U.grid.time_origin if isinstance(self.U, Field) else self.U[0].grid.time_origin + # see #1663 for type-ignore reason + self.time_origin = self.U.grid.time_origin if isinstance(self.U, Field) else self.U[0].grid.time_origin # type: ignore if V: self.add_field(V, "V") @@ -67,9 +65,9 @@ def from_data( data, dimensions, transpose=False, - mesh="spherical", - allow_time_extrapolation=None, - time_periodic=False, + mesh: Mesh = "spherical", + allow_time_extrapolation: bool | None = None, + time_periodic: TimePeriodic = False, **kwargs, ): """Initialise FieldSet object from raw data. @@ -136,7 +134,7 @@ def from_data( lat = dims["lat"] depth = np.zeros(1, dtype=np.float32) if "depth" not in dims else dims["depth"] time = np.zeros(1, dtype=np.float64) if "time" not in dims else dims["time"] - time = np.array(time) if not isinstance(time, np.ndarray) else time + time = np.array(time) if isinstance(time[0], np.datetime64): time_origin = TimeConverter(time[0]) time = np.array([time_origin.reltime(t) for t in time]) @@ -159,7 +157,7 @@ def from_data( v = fields.pop("V", None) return cls(u, v, fields=fields) - def add_field(self, field, name=None): + def add_field(self, field: Field | NestedField, name: str | None = None): """Add a :class:`parcels.field.Field` object to the FieldSet. Parameters @@ -167,7 +165,8 @@ def add_field(self, field, name=None): field : parcels.field.Field Field object to be added name : str - Name of the :class:`parcels.field.Field` object to be added + Name of the :class:`parcels.field.Field` object to be added. Defaults + to name in Field object. Examples @@ -184,6 +183,7 @@ def add_field(self, field, name=None): "FieldSet has already been completed. Are you trying to add a Field after you've created the ParticleSet?" ) name = field.name if name is None else name + if hasattr(self, name): # check if Field with same name already exists when adding new Field raise RuntimeError(f"FieldSet already has a Field with name '{name}'") if isinstance(field, NestedField): @@ -196,7 +196,7 @@ def add_field(self, field, name=None): self.gridset.add_grid(field) field.fieldset = self - def add_constant_field(self, name, value, mesh="flat"): + def add_constant_field(self, name: str, value: float, mesh: Mesh = "flat"): """Wrapper function to add a Field that is constant in space, useful e.g. when using constant horizontal diffusivity @@ -342,10 +342,10 @@ def from_netcdf( dimensions, indices=None, fieldtype=None, - mesh="spherical", + mesh: Mesh = "spherical", timestamps=None, - allow_time_extrapolation=None, - time_periodic=False, + allow_time_extrapolation: bool | None = None, + time_periodic: TimePeriodic = False, deferred_load=True, chunksize=None, **kwargs, @@ -415,7 +415,7 @@ def from_netcdf( ``{parcels_varname: {netcdf_dimname : (parcels_dimname, chunksize_as_int)}, ...}``, where ``parcels_dimname`` is one of ('time', 'depth', 'lat', 'lon') netcdf_engine : engine to use for netcdf reading in xarray. Default is 'netcdf', - but in cases where this doesn't work, setting netcdf_engine='scipy' could help + but in cases where this doesn't work, setting netcdf_engine='scipy' could help. Accepted options are the same as the ``engine`` parameter in ``xarray.open_dataset()``. **kwargs : Keyword arguments passed to the :class:`parcels.Field` constructor. @@ -435,10 +435,10 @@ def from_netcdf( """ # Ensure that times are not provided both in netcdf file and in 'timestamps'. if timestamps is not None and "time" in dimensions: - logger.warning_once("Time already provided, defaulting to dimensions['time'] over timestamps.") + logger.warning_once("Time already provided, defaulting to dimensions['time'] over timestamps.") # type: ignore timestamps = None - fields = {} + fields: dict[str, Field] = {} if "creation_log" not in kwargs.keys(): kwargs["creation_log"] = "from_netcdf" for var, name in variables.items(): @@ -521,10 +521,10 @@ def from_nemo( variables, dimensions, indices=None, - mesh="spherical", - allow_time_extrapolation=None, - time_periodic=False, - tracer_interp_method="cgrid_tracer", + mesh: Mesh = "spherical", + allow_time_extrapolation: bool | None = None, + time_periodic: TimePeriodic = False, + tracer_interp_method: InterpMethodOption = "cgrid_tracer", chunksize=None, **kwargs, ): @@ -632,10 +632,10 @@ def from_mitgcm( variables, dimensions, indices=None, - mesh="spherical", - allow_time_extrapolation=None, - time_periodic=False, - tracer_interp_method="cgrid_tracer", + mesh: Mesh = "spherical", + allow_time_extrapolation: bool | None = None, + time_periodic: TimePeriodic = False, + tracer_interp_method: InterpMethodOption = "cgrid_tracer", chunksize=None, **kwargs, ): @@ -682,11 +682,11 @@ def from_c_grid_dataset( variables, dimensions, indices=None, - mesh="spherical", - allow_time_extrapolation=None, - time_periodic=False, - tracer_interp_method="cgrid_tracer", - gridindexingtype="nemo", + mesh: Mesh = "spherical", + allow_time_extrapolation: bool | None = None, + time_periodic: TimePeriodic = False, + tracer_interp_method: InterpMethodOption = "cgrid_tracer", + gridindexingtype: GridIndexingType = "nemo", chunksize=None, **kwargs, ): @@ -764,12 +764,12 @@ def from_c_grid_dataset( if "U" in dimensions and "V" in dimensions and dimensions["U"] != dimensions["V"]: raise ValueError( "On a C-grid, the dimensions of velocities should be the corners (f-points) of the cells, so the same for U and V. " - "See also ../examples/documentation_indexing.ipynb" + "See also https://docs.oceanparcels.org/en/latest/examples/documentation_indexing.html" ) if "U" in dimensions and "W" in dimensions and dimensions["U"] != dimensions["W"]: raise ValueError( "On a C-grid, the dimensions of velocities should be the corners (f-points) of the cells, so the same for U, V and W. " - "See also ../examples/documentation_indexing.ipynb" + "See also https://docs.oceanparcels.org/en/latest/examples/documentation_indexing.html" ) if "interp_method" in kwargs.keys(): raise TypeError("On a C-grid, the interpolation method for velocities should not be overridden") @@ -804,10 +804,10 @@ def from_pop( variables, dimensions, indices=None, - mesh="spherical", - allow_time_extrapolation=None, - time_periodic=False, - tracer_interp_method="bgrid_tracer", + mesh: Mesh = "spherical", + allow_time_extrapolation: bool | None = None, + time_periodic: TimePeriodic = False, + tracer_interp_method: InterpMethodOption = "bgrid_tracer", chunksize=None, depth_units="m", **kwargs, @@ -909,7 +909,7 @@ def from_pop( if hasattr(fieldset, "W"): if depth_units == "m": fieldset.W.set_scaling_factor(-0.01) # cm/s to m/s and change the W direction - logger.warning_once( + logger.warning_once( # type: ignore "Parcels assumes depth in POP output to be in 'm'. Use depth_units='cm' if the output depth is in 'cm'." ) elif depth_units == "cm": @@ -925,10 +925,10 @@ def from_mom5( variables, dimensions, indices=None, - mesh="spherical", - allow_time_extrapolation=None, - time_periodic=False, - tracer_interp_method="bgrid_tracer", + mesh: Mesh = "spherical", + allow_time_extrapolation: bool | None = None, + time_periodic: TimePeriodic = False, + tracer_interp_method: InterpMethodOption = "bgrid_tracer", chunksize=None, **kwargs, ): @@ -1049,10 +1049,10 @@ def from_b_grid_dataset( variables, dimensions, indices=None, - mesh="spherical", - allow_time_extrapolation=None, - time_periodic=False, - tracer_interp_method="bgrid_tracer", + mesh: Mesh = "spherical", + allow_time_extrapolation: bool | None = None, + time_periodic: TimePeriodic = False, + tracer_interp_method: InterpMethodOption = "bgrid_tracer", chunksize=None, **kwargs, ): @@ -1126,12 +1126,12 @@ def from_b_grid_dataset( if "U" in dimensions and "V" in dimensions and dimensions["U"] != dimensions["V"]: raise ValueError( "On a B-grid, the dimensions of velocities should be the (top) corners of the grid cells, so the same for U and V. " - "See also ../examples/documentation_indexing.ipynb" + "See also https://docs.oceanparcels.org/en/latest/examples/documentation_indexing.html" ) if "U" in dimensions and "W" in dimensions and dimensions["U"] != dimensions["W"]: raise ValueError( "On a B-grid, the dimensions of velocities should be the (top) corners of the grid cells, so the same for U, V and W. " - "See also ../examples/documentation_indexing.ipynb" + "See also https://docs.oceanparcels.org/en/latest/examples/documentation_indexing.html" ) interp_method = {} @@ -1166,8 +1166,8 @@ def from_parcels( vvar="vomecrty", indices=None, extra_fields=None, - allow_time_extrapolation=None, - time_periodic=False, + allow_time_extrapolation: bool | None = None, + time_periodic: TimePeriodic = False, deferred_load=True, chunksize=None, **kwargs, diff --git a/parcels/grid.py b/parcels/grid.py index 2ba8073c5..27cb3a285 100644 --- a/parcels/grid.py +++ b/parcels/grid.py @@ -3,7 +3,9 @@ from enum import IntEnum import numpy as np +import numpy.typing as npt +from parcels._typing import Mesh from parcels.tools.converters import TimeConverter from parcels.tools.loggers import logger @@ -38,7 +40,14 @@ class CGrid(Structure): class Grid: """Grid class that defines a (spatial and temporal) grid on which Fields are defined.""" - def __init__(self, lon, lat, time, time_origin, mesh): + def __init__( + self, + lon: npt.NDArray, + lat: npt.NDArray, + time: npt.NDArray | None, + time_origin: TimeConverter | None, + mesh: Mesh, + ): self.xi = None self.yi = None self.zi = None @@ -66,7 +75,7 @@ def __init__(self, lon, lat, time, time_origin, mesh): assert isinstance(self.time_origin, TimeConverter), "time_origin needs to be a TimeConverter object" self.mesh = mesh self.cstruct = None - self.cell_edge_sizes = {} + self.cell_edge_sizes: dict[str, npt.NDArray] = {} self.zonal_periodic = False self.zonal_halo = 0 self.meridional_halo = 0 @@ -76,20 +85,28 @@ def __init__(self, lon, lat, time, time_origin, mesh): [np.nanmin(lon), np.nanmax(lon), np.nanmin(lat), np.nanmax(lat)], dtype=np.float32 ) self.periods = 0 - self.load_chunk = [] + self.load_chunk: npt.NDArray = np.array([]) self.chunk_info = None self.chunksize = None self._add_last_periodic_data_timestep = False self.depth_field = None @staticmethod - def create_grid(lon, lat, depth, time, time_origin, mesh, **kwargs): - if not isinstance(lon, np.ndarray): - lon = np.array(lon) - if not isinstance(lat, np.ndarray): - lat = np.array(lat) - if not (depth is None or isinstance(depth, np.ndarray)): + def create_grid( + lon: npt.ArrayLike, + lat: npt.ArrayLike, + depth, + time, + time_origin, + mesh: Mesh, + **kwargs, + ): + lon = np.array(lon) + lat = np.array(lat) + + if depth is not None: depth = np.array(depth) + if len(lon.shape) <= 1: if depth is None or len(depth.shape) <= 1: return RectilinearZGrid(lon, lat, depth, time, time_origin=time_origin, mesh=mesh, **kwargs) @@ -313,7 +330,7 @@ class RectilinearGrid(Grid): """ - def __init__(self, lon, lat, time, time_origin, mesh): + def __init__(self, lon, lat, time, time_origin, mesh: Mesh): assert isinstance(lon, np.ndarray) and len(lon.shape) <= 1, "lon is not a numpy vector" assert isinstance(lat, np.ndarray) and len(lat.shape) <= 1, "lat is not a numpy vector" assert isinstance(time, np.ndarray) or not time, "time is not a numpy array" @@ -327,7 +344,7 @@ def __init__(self, lon, lat, time, time_origin, mesh): if self.ydim > 1 and self.lat[-1] < self.lat[0]: self.lat = np.flip(self.lat, axis=0) self.lat_flipped = True - logger.warning_once( + logger.warning_once( # type: ignore "Flipping lat data from North-South to South-North. " "Note that this may lead to wrong sign for meridional velocity, so tread very carefully" ) @@ -396,7 +413,7 @@ class RectilinearZGrid(RectilinearGrid): 2. flat: No conversion, lat/lon are assumed to be in m. """ - def __init__(self, lon, lat, depth=None, time=None, time_origin=None, mesh="flat"): + def __init__(self, lon, lat, depth=None, time=None, time_origin=None, mesh: Mesh = "flat"): super().__init__(lon, lat, time, time_origin, mesh) if isinstance(depth, np.ndarray): assert len(depth.shape) <= 1, "depth is not a vector" @@ -442,7 +459,15 @@ class RectilinearSGrid(RectilinearGrid): 2. flat: No conversion, lat/lon are assumed to be in m. """ - def __init__(self, lon, lat, depth, time=None, time_origin=None, mesh="flat"): + def __init__( + self, + lon: npt.NDArray, + lat: npt.NDArray, + depth: npt.NDArray, + time: npt.NDArray | None = None, + time_origin: TimeConverter | None = None, + mesh: Mesh = "flat", + ): super().__init__(lon, lat, time, time_origin, mesh) assert isinstance(depth, np.ndarray) and len(depth.shape) in [3, 4], "depth is not a 3D or 4D numpy array" @@ -477,7 +502,14 @@ def __init__(self, lon, lat, depth, time=None, time_origin=None, mesh="flat"): class CurvilinearGrid(Grid): - def __init__(self, lon, lat, time=None, time_origin=None, mesh="flat"): + def __init__( + self, + lon: npt.NDArray, + lat: npt.NDArray, + time: npt.NDArray | None = None, + time_origin: TimeConverter | None = None, + mesh: Mesh = "flat", + ): assert isinstance(lon, np.ndarray) and len(lon.squeeze().shape) == 2, "lon is not a 2D numpy array" assert isinstance(lat, np.ndarray) and len(lat.squeeze().shape) == 2, "lat is not a 2D numpy array" assert isinstance(time, np.ndarray) or not time, "time is not a numpy array" @@ -574,7 +606,15 @@ class CurvilinearZGrid(CurvilinearGrid): 2. flat: No conversion, lat/lon are assumed to be in m. """ - def __init__(self, lon, lat, depth=None, time=None, time_origin=None, mesh="flat"): + def __init__( + self, + lon: npt.NDArray, + lat: npt.NDArray, + depth: npt.NDArray | None = None, + time: npt.NDArray | None = None, + time_origin: TimeConverter | None = None, + mesh: Mesh = "flat", + ): super().__init__(lon, lat, time, time_origin, mesh) if isinstance(depth, np.ndarray): assert len(depth.shape) == 1, "depth is not a vector" @@ -619,7 +659,15 @@ class CurvilinearSGrid(CurvilinearGrid): 2. flat: No conversion, lat/lon are assumed to be in m. """ - def __init__(self, lon, lat, depth, time=None, time_origin=None, mesh="flat"): + def __init__( + self, + lon: npt.NDArray, + lat: npt.NDArray, + depth: npt.NDArray, + time: npt.NDArray | None = None, + time_origin: TimeConverter | None = None, + mesh: Mesh = "flat", + ): super().__init__(lon, lat, time, time_origin, mesh) assert isinstance(depth, np.ndarray) and len(depth.shape) in [3, 4], "depth is not a 4D numpy array" diff --git a/parcels/interaction/interactionkernel.py b/parcels/interaction/interactionkernel.py index db4c7f897..dce302e21 100644 --- a/parcels/interaction/interactionkernel.py +++ b/parcels/interaction/interactionkernel.py @@ -4,11 +4,7 @@ import numpy as np -try: - from mpi4py import MPI -except ModuleNotFoundError: - MPI = None - +from parcels._compat import MPI from parcels.field import NestedField, VectorField from parcels.kernel import BaseKernel from parcels.tools.loggers import logger @@ -36,7 +32,7 @@ def __init__( py_ast=None, funcvars=None, c_include="", - delete_cfiles=True, + delete_cfiles: bool = True, ): if MPI is not None and MPI.COMM_WORLD.Get_size() > 1: raise NotImplementedError( diff --git a/parcels/kernel.py b/parcels/kernel.py index b144abc72..fff3571d1 100644 --- a/parcels/kernel.py +++ b/parcels/kernel.py @@ -1,4 +1,5 @@ import _ctypes +import abc import ast import functools import hashlib @@ -17,13 +18,9 @@ import numpy.ctypeslib as npct from numpy import ndarray -try: - from mpi4py import MPI -except ModuleNotFoundError: - MPI = None - import parcels.rng as ParcelsRandom # noqa from parcels import rng # noqa +from parcels._compat import MPI from parcels.application_kernels.advection import ( AdvectionAnalytical, AdvectionRK4_3D, @@ -45,7 +42,7 @@ __all__ = ["Kernel", "BaseKernel"] -class BaseKernel: +class BaseKernel(abc.ABC): """Superclass for 'normal' and Interactive Kernels""" def __init__( @@ -139,6 +136,9 @@ def remove_deleted(self, pset): self.fieldset.particlefile.write(pset, None, indices=indices) pset.remove_indices(indices) + @abc.abstractmethod + def get_kernel_compile_files(self): ... + class Kernel(BaseKernel): """Kernel object that encapsulates auto-generated code. diff --git a/parcels/particledata.py b/parcels/particledata.py index 1f32cfa26..2d5964f9c 100644 --- a/parcels/particledata.py +++ b/parcels/particledata.py @@ -3,19 +3,10 @@ import numpy as np +from parcels._compat import MPI, KMeans from parcels.tools.loggers import logger from parcels.tools.statuscodes import StatusCode -try: - from mpi4py import MPI -except ModuleNotFoundError: - MPI = None -if MPI: - try: - from sklearn.cluster import KMeans - except: - KMeans = None - def partitionParticlesMPI_default(coords, mpi_size=1): """This function takes the coordinates of the particle starting diff --git a/parcels/particlefile.py b/parcels/particlefile.py index 6e047130b..588e597d3 100644 --- a/parcels/particlefile.py +++ b/parcels/particlefile.py @@ -8,14 +8,9 @@ import zarr import parcels +from parcels._compat import MPI from parcels.tools.loggers import logger -try: - from mpi4py import MPI -except ModuleNotFoundError: - MPI = None - - __all__ = ["ParticleFile"] diff --git a/parcels/particleset.py b/parcels/particleset.py index c4007cf20..5d9c6a571 100644 --- a/parcels/particleset.py +++ b/parcels/particleset.py @@ -9,12 +9,7 @@ from scipy.spatial import KDTree from tqdm import tqdm -try: - from mpi4py import MPI -except ModuleNotFoundError: - MPI = None - - +from parcels._compat import MPI from parcels.application_kernels.advection import AdvectionRK4 from parcels.compilation.codecompiler import GNUCompiler from parcels.field import NestedField diff --git a/parcels/tools/converters.py b/parcels/tools/converters.py index c8ccf8003..de6705fac 100644 --- a/parcels/tools/converters.py +++ b/parcels/tools/converters.py @@ -5,6 +5,7 @@ import cftime import numpy as np +import numpy.typing as npt import xarray as xr __all__ = [ @@ -20,20 +21,15 @@ ] -def convert_to_flat_array(var): +def convert_to_flat_array(var: npt.ArrayLike) -> npt.NDArray: """Convert lists and single integers/floats to one-dimensional numpy arrays Parameters ---------- - var : np.ndarray, float or array_like + var : Array list or numeric to convert to a one-dimensional numpy array """ - if isinstance(var, np.ndarray): - return var.flatten() - elif isinstance(var, (int, float, np.float32, np.int32)): - return np.array([var]) - else: - return np.array(var) + return np.array(var).flatten() def _get_cftime_datetimes(): @@ -167,8 +163,8 @@ def __le__(self, other): class UnitConverter: """Interface class for spatial unit conversion during field sampling that performs no conversion.""" - source_unit = None - target_unit = None + source_unit: str | None = None + target_unit: str | None = None def to_target(self, value, x, y, z): return value diff --git a/parcels/tools/global_statics.py b/parcels/tools/global_statics.py index 0e97bac0d..896d3195f 100644 --- a/parcels/tools/global_statics.py +++ b/parcels/tools/global_statics.py @@ -8,7 +8,7 @@ from os import getuid except: # Windows does not have getuid(), so define to simply return 'tmp' - def getuid(): + def getuid(): # type: ignore return "tmp" diff --git a/parcels/tools/interpolation_utils.py b/parcels/tools/interpolation_utils.py index 27ab24c7f..273dbbec8 100644 --- a/parcels/tools/interpolation_utils.py +++ b/parcels/tools/interpolation_utils.py @@ -1,17 +1,19 @@ +from typing import Callable, Literal + import numpy as np -__all__ = [] +from parcels._typing import Mesh +__all__ = [] # type: ignore -# fmt: off -def phi1D_lin(xsi): - phi = [1-xsi, - xsi] +def phi1D_lin(xsi: float) -> list[float]: + phi = [1 - xsi, xsi] return phi -def phi1D_quad(xsi): +# fmt: off +def phi1D_quad(xsi: float) -> list[float]: phi = [2*xsi**2-3*xsi+1, -4*xsi**2+4*xsi, 2*xsi**2-xsi] @@ -19,7 +21,8 @@ def phi1D_quad(xsi): return phi -def phi2D_lin(xsi, eta): + +def phi2D_lin(xsi: float, eta: float) -> list[float]: phi = [(1-xsi) * (1-eta), xsi * (1-eta), xsi * eta , @@ -28,7 +31,7 @@ def phi2D_lin(xsi, eta): return phi -def phi3D_lin(xsi, eta, zet): +def phi3D_lin(xsi: float, eta: float, zet: float) -> list[float]: phi = [(1-xsi) * (1-eta) * (1-zet), xsi * (1-eta) * (1-zet), xsi * eta * (1-zet), @@ -41,7 +44,7 @@ def phi3D_lin(xsi, eta, zet): return phi -def dphidxsi3D_lin(xsi, eta, zet): +def dphidxsi3D_lin(xsi: float, eta: float, zet: float) -> tuple[list[float], list[float], list[float]]: dphidxsi = [ - (1-eta) * (1-zet), (1-eta) * (1-zet), ( eta) * (1-zet), @@ -70,7 +73,9 @@ def dphidxsi3D_lin(xsi, eta, zet): return dphidxsi, dphideta, dphidzet -def dxdxsi3D_lin(hexa_x, hexa_y, hexa_z, xsi, eta, zet, mesh): +def dxdxsi3D_lin( + hexa_x: list[float], hexa_y: list[float], hexa_z: list[float], xsi: float, eta: float, zet: float, mesh: Mesh +) -> tuple[float, float, float, float, float, float, float, float, float]: dphidxsi, dphideta, dphidzet = dphidxsi3D_lin(xsi, eta, zet) if mesh == 'spherical': @@ -99,16 +104,29 @@ def dxdxsi3D_lin(hexa_x, hexa_y, hexa_z, xsi, eta, zet, mesh): return dxdxsi, dxdeta, dxdzet, dydxsi, dydeta, dydzet, dzdxsi, dzdeta, dzdzet -def jacobian3D_lin(hexa_x, hexa_y, hexa_z, xsi, eta, zet, mesh): +def jacobian3D_lin( + hexa_x: list[float], hexa_y: list[float], hexa_z: list[float], xsi: float, eta: float, zet: float, mesh: Mesh +) -> float: dxdxsi, dxdeta, dxdzet, dydxsi, dydeta, dydzet, dzdxsi, dzdeta, dzdzet = dxdxsi3D_lin(hexa_x, hexa_y, hexa_z, xsi, eta, zet, mesh) - jac = dxdxsi * (dydeta*dzdzet - dzdeta*dydzet)\ - - dxdeta * (dydxsi*dzdzet - dzdxsi*dydzet)\ - + dxdzet * (dydxsi*dzdeta - dzdxsi*dydeta) + jac = ( + dxdxsi * (dydeta * dzdzet - dzdeta * dydzet) + - dxdeta * (dydxsi * dzdzet - dzdxsi * dydzet) + + dxdzet * (dydxsi * dzdeta - dzdxsi * dydeta) + ) return jac -def jacobian3D_lin_face(hexa_x, hexa_y, hexa_z, xsi, eta, zet, orientation, mesh): +def jacobian3D_lin_face( + hexa_x: list[float], + hexa_y: list[float], + hexa_z: list[float], + xsi: float, + eta: float, + zet: float, + orientation: Literal["zonal", "meridional", "vertical"], + mesh: Mesh, +) -> float: dxdxsi, dxdeta, dxdzet, dydxsi, dydeta, dydzet, dzdxsi, dzdeta, dzdzet = dxdxsi3D_lin(hexa_x, hexa_y, hexa_z, xsi, eta, zet, mesh) if orientation == 'zonal': @@ -128,7 +146,7 @@ def jacobian3D_lin_face(hexa_x, hexa_y, hexa_z, xsi, eta, zet, orientation, mesh return jac -def dphidxsi2D_lin(xsi, eta): +def dphidxsi2D_lin(xsi: float, eta: float) -> tuple[list[float], list[float]]: dphidxsi = [-(1-eta), 1-eta, eta, @@ -141,7 +159,12 @@ def dphidxsi2D_lin(xsi, eta): return dphidxsi, dphideta -def dxdxsi2D_lin(quad_x, quad_y, xsi, eta,): +def dxdxsi2D_lin( + quad_x, + quad_y, + xsi: float, + eta: float, +): dphidxsi, dphideta = dphidxsi2D_lin(xsi, eta) dxdxsi = np.dot(quad_x, dphidxsi) @@ -152,20 +175,21 @@ def dxdxsi2D_lin(quad_x, quad_y, xsi, eta,): return dxdxsi, dxdeta, dydxsi, dydeta -def jacobian2D_lin(quad_x, quad_y, xsi, eta): +def jacobian2D_lin(quad_x, quad_y, xsi: float, eta: float): dxdxsi, dxdeta, dydxsi, dydeta = dxdxsi2D_lin(quad_x, quad_y, xsi, eta) - jac = dxdxsi*dydeta - dxdeta*dydxsi + jac = dxdxsi * dydeta - dxdeta * dydxsi return jac def length2d_lin_edge(quad_x, quad_y, ids): xe = [quad_x[ids[0]], quad_x[ids[1]]] ye = [quad_y[ids[0]], quad_y[ids[1]]] - return np.sqrt((xe[1]-xe[0])**2+(ye[1]-ye[0])**2) + return np.sqrt((xe[1] - xe[0]) ** 2 + (ye[1] - ye[0]) ** 2) -def interpolate(phi, f, xsi): +def interpolate(phi: Callable[[float], list[float]], f: list[float], xsi: float) -> float: return np.dot(phi(xsi), f) + # fmt: on diff --git a/parcels/tools/loggers.py b/parcels/tools/loggers.py index b21dee81a..335bc3a8d 100644 --- a/parcels/tools/loggers.py +++ b/parcels/tools/loggers.py @@ -40,10 +40,10 @@ def info_once(self, message, *args, **kws): logger.addHandler(handler) logging.addLevelName(warning_once_level, "WARNING") -logging.Logger.warning_once = warning_once +logging.Logger.warning_once = warning_once # type: ignore logging.addLevelName(info_once_level, "INFO") -logging.Logger.info_once = info_once +logging.Logger.info_once = info_once # type: ignore dup_filter = DuplicateFilter() logger.addFilter(dup_filter) diff --git a/parcels/tools/timer.py b/parcels/tools/timer.py index 02daf75ab..8896aa42a 100644 --- a/parcels/tools/timer.py +++ b/parcels/tools/timer.py @@ -1,12 +1,9 @@ import datetime import time -try: - from mpi4py import MPI -except ModuleNotFoundError: - MPI = None +from parcels._compat import MPI -__all__ = [] +__all__ = [] # type: ignore class Timer: diff --git a/pyproject.toml b/pyproject.toml index d42401fbe..134a6f8ff 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -54,7 +54,6 @@ local_scheme = "no-local-version" python_files = ["test_*.py", "example_*.py", "*tutorial*"] [tool.ruff] -target-version = "py310" line-length = 120 [tool.ruff.lint] @@ -121,3 +120,27 @@ ignore = [ [tool.ruff.lint.pydocstyle] convention = "numpy" + +[tool.mypy] +files = [ + "parcels/compilation/codegenerator.py", + "parcels/_typing.py", + "parcels/tools/*.py", + "parcels/grid.py", + "parcels/field.py", + "parcels/fieldset.py", +] + +[[tool.mypy.overrides]] +module = [ + "parcels._version_setup", + "mpi4py", + "scipy.spatial", + "sklearn.cluster", + "zarr", + "cftime", + "pykdtree.kdtree", + "netCDF4", + "cgen" +] +ignore_missing_imports = true diff --git a/tests/test_data/create_testfields.py b/tests/test_data/create_testfields.py index 3c7e783fa..acf192158 100644 --- a/tests/test_data/create_testfields.py +++ b/tests/test_data/create_testfields.py @@ -48,7 +48,6 @@ def generate_perlin_testfield(): lon = np.linspace(-180.0, 180.0, img_shape[0], dtype=np.float32) lat = np.linspace(-90.0, 90.0, img_shape[1], dtype=np.float32) time = np.zeros(1, dtype=np.float64) - time = np.array(time) if not isinstance(time, np.ndarray) else time # Define arrays U (zonal), V (meridional), W (vertical) and P (sea # surface height) all on A-grid