diff --git a/docs/examples/example_nemo_curvilinear.py b/docs/examples/example_nemo_curvilinear.py
index d2496943d..d5a5b5c19 100644
--- a/docs/examples/example_nemo_curvilinear.py
+++ b/docs/examples/example_nemo_curvilinear.py
@@ -108,7 +108,7 @@ def test_nemo_3D_samegrid():
 
     fieldset = parcels.FieldSet.from_nemo(filenames, variables, dimensions)
 
-    assert fieldset.U.dataFiles is not fieldset.W.dataFiles
+    assert fieldset.U._dataFiles is not fieldset.W._dataFiles
 
 
 def main(args=None):
diff --git a/parcels/_typing.py b/parcels/_typing.py
index a1e24a6e3..533170acc 100644
--- a/parcels/_typing.py
+++ b/parcels/_typing.py
@@ -10,7 +10,7 @@
 import datetime
 import os
 from collections.abc import Callable
-from typing import Literal
+from typing import Any, Literal, get_args
 
 
 class ParcelsAST(ast.AST):
@@ -37,10 +37,33 @@ class ParcelsAST(ast.AST):
 Mesh = Literal["spherical", "flat"]  # corresponds with `mesh`
 VectorType = Literal["3D", "2D"] | None  # corresponds with `vector_type`
 ChunkMode = Literal["auto", "specific", "failsafe"]  # corresponds with `chunk_mode`
-GridIndexingType = Literal["pop", "mom5", "mitgcm", "nemo"]  # corresponds with `grid_indexing_type`
+GridIndexingType = Literal["pop", "mom5", "mitgcm", "nemo"]  # corresponds with `gridindexingtype`
 UpdateStatus = Literal["not_updated", "first_updated", "updated"]  # corresponds with `update_status`
 TimePeriodic = float | datetime.timedelta | Literal[False]  # corresponds with `update_status`
 NetcdfEngine = Literal["netcdf4", "xarray", "scipy"]
 
 
 KernelFunction = Callable[..., None]
+
+
+def _validate_against_pure_literal(value, typing_literal):
+    """Uses a Literal type alias to validate.
+
+    Can't be used with ``Literal[...] | None`` etc. as it's not a pure literal.
+    """
+    if value not in get_args(typing_literal):
+        msg = f"Invalid value {value!r}. Valid options are {get_args(typing_literal)!r}"
+        raise ValueError(msg)
+
+
+# Assertion functions to clean user input
+def assert_valid_interp_method(value: Any):
+    _validate_against_pure_literal(value, InterpMethodOption)
+
+
+def assert_valid_mesh(value: Any):
+    _validate_against_pure_literal(value, Mesh)
+
+
+def assert_valid_gridindexingtype(value: Any):
+    _validate_against_pure_literal(value, GridIndexingType)
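The new `assert_valid_*` helpers turn a mistyped string option into an immediate `ValueError` instead of a silent misconfiguration. A minimal sketch of the behaviour (illustration only, not part of the patch):

```python
from parcels._typing import assert_valid_mesh

assert_valid_mesh("spherical")  # a valid `Mesh` option: returns None silently

try:
    assert_valid_mesh("sperical")  # note the typo
except ValueError as err:
    # Invalid value 'sperical'. Valid options are ('spherical', 'flat')
    print(err)
```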
diff --git a/parcels/application_kernels/advection.py b/parcels/application_kernels/advection.py
index 9defaaad1..ff848bb02 100644
--- a/parcels/application_kernels/advection.py
+++ b/parcels/application_kernels/advection.py
@@ -125,14 +125,14 @@ def AdvectionAnalytical(particle, fieldset, time):
     direction = 1.0 if particle.dt > 0 else -1.0
     withW = True if "W" in [f.name for f in fieldset.get_fields()] else False
     withTime = True if len(fieldset.U.grid.time_full) > 1 else False
-    ti = fieldset.U.time_index(time)[0]
+    ti = fieldset.U._time_index(time)[0]
     ds_t = particle.dt
     if withTime:
         tau = (time - fieldset.U.grid.time[ti]) / (fieldset.U.grid.time[ti + 1] - fieldset.U.grid.time[ti])
         time_i = np.linspace(0, fieldset.U.grid.time[ti + 1] - fieldset.U.grid.time[ti], I_s)
         ds_t = min(ds_t, time_i[np.where(time - fieldset.U.grid.time[ti] < time_i)[0][0]])
 
-    xsi, eta, zeta, xi, yi, zi = fieldset.U.search_indices(
+    xsi, eta, zeta, xi, yi, zi = fieldset.U._search_indices(
         particle.lon, particle.lat, particle.depth, particle=particle
     )
     if withW:
diff --git a/parcels/compilation/codegenerator.py b/parcels/compilation/codegenerator.py
index e7eef09cc..b71936ddb 100644
--- a/parcels/compilation/codegenerator.py
+++ b/parcels/compilation/codegenerator.py
@@ -819,14 +819,14 @@ def visit_FieldEvalNode(self, node):
         self.visit(node.field)
         self.visit(node.args)
         args = self._check_FieldSamplingArguments(node.args.ccode)
-        ccode_eval = node.field.obj.ccode_eval(node.var, *args)
+        ccode_eval = node.field.obj._ccode_eval(node.var, *args)
         stmts = [
             c.Assign("parcels_interp_state", ccode_eval),
             c.Assign("particles->state[pnum]", "max(particles->state[pnum], parcels_interp_state)"),
         ]
 
         if node.convert:
-            ccode_conv = node.field.obj.ccode_convert(*args)
+            ccode_conv = node.field.obj._ccode_convert(*args)
             conv_stat = c.Statement(f"{node.var} *= {ccode_conv}")
             stmts += [conv_stat]
 
@@ -836,17 +836,17 @@ def visit_VectorFieldEvalNode(self, node):
         self.visit(node.field)
         self.visit(node.args)
         args = self._check_FieldSamplingArguments(node.args.ccode)
-        ccode_eval = node.field.obj.ccode_eval(
+        ccode_eval = node.field.obj._ccode_eval(
             node.var, node.var2, node.var3, node.field.obj.U, node.field.obj.V, node.field.obj.W, *args
         )
         if node.convert and node.field.obj.U.interp_method != "cgrid_velocity":
-            ccode_conv1 = node.field.obj.U.ccode_convert(*args)
-            ccode_conv2 = node.field.obj.V.ccode_convert(*args)
+            ccode_conv1 = node.field.obj.U._ccode_convert(*args)
+            ccode_conv2 = node.field.obj.V._ccode_convert(*args)
             statements = [c.Statement(f"{node.var} *= {ccode_conv1}"), c.Statement(f"{node.var2} *= {ccode_conv2}")]
         else:
             statements = []
         if node.convert and node.field.obj.vector_type == "3D":
-            ccode_conv3 = node.field.obj.W.ccode_convert(*args)
+            ccode_conv3 = node.field.obj.W._ccode_convert(*args)
             statements.append(c.Statement(f"{node.var3} *= {ccode_conv3}"))
         conv_stat = c.Block(statements)
         node.ccode = c.Block(
@@ -864,8 +864,8 @@ def visit_NestedFieldEvalNode(self, node):
         cstat = []
         args = self._check_FieldSamplingArguments(node.args.ccode)
         for fld in node.fields.obj:
-            ccode_eval = fld.ccode_eval(node.var, *args)
-            ccode_conv = fld.ccode_convert(*args)
+            ccode_eval = fld._ccode_eval(node.var, *args)
+            ccode_conv = fld._ccode_convert(*args)
             conv_stat = c.Statement(f"{node.var} *= {ccode_conv}")
             cstat += [
                 c.Assign("particles->state[pnum]", ccode_eval),
@@ -884,15 +884,15 @@ def visit_NestedVectorFieldEvalNode(self, node):
         cstat = []
         args = self._check_FieldSamplingArguments(node.args.ccode)
         for fld in node.fields.obj:
-            ccode_eval = fld.ccode_eval(node.var, node.var2, node.var3, fld.U, fld.V, fld.W, *args)
+            ccode_eval = fld._ccode_eval(node.var, node.var2, node.var3, fld.U, fld.V, fld.W, *args)
             if fld.U.interp_method != "cgrid_velocity":
-                ccode_conv1 = fld.U.ccode_convert(*args)
-                ccode_conv2 = fld.V.ccode_convert(*args)
+                ccode_conv1 = fld.U._ccode_convert(*args)
+                ccode_conv2 = fld.V._ccode_convert(*args)
                 statements = [c.Statement(f"{node.var} *= {ccode_conv1}"), c.Statement(f"{node.var2} *= {ccode_conv2}")]
             else:
                 statements = []
             if fld.vector_type == "3D":
-                ccode_conv3 = fld.W.ccode_convert(*args)
+                ccode_conv3 = fld.W._ccode_convert(*args)
                 statements.append(c.Statement(f"{node.var3} *= {ccode_conv3}"))
             cstat += [
                 c.Assign("particles->state[pnum]", ccode_eval),
diff --git a/parcels/field.py b/parcels/field.py
index 7546a2783..b3a63cc98 100644
--- a/parcels/field.py
+++ b/parcels/field.py
@@ -12,7 +12,16 @@
 import xarray as xr
 
 import parcels.tools.interpolation_utils as i_u
-from parcels._typing import GridIndexingType, InterpMethod, Mesh, TimePeriodic, VectorType
+from parcels._typing import (
+    GridIndexingType,
+    InterpMethod,
+    Mesh,
+    TimePeriodic,
+    VectorType,
+    assert_valid_gridindexingtype,
+    assert_valid_interp_method,
+)
+from parcels.tools._helpers import deprecated_made_private
 from parcels.tools.converters import (
     Geographic,
     GeographicPolar,
@@ -181,21 +190,15 @@ def __init__(
                 raise ValueError(
                     "Cannot combine Grid from defer_loaded Field with np.ndarray data. please specify lon, lat, depth and time dimensions separately"
                 )
-            self.grid = grid
+            self._grid = grid
         else:
             if (time is not None) and isinstance(time[0], np.datetime64):
                 time_origin = TimeConverter(time[0])
                 time = np.array([time_origin.reltime(t) for t in time])
             else:
                 time_origin = TimeConverter(0)
-            self.grid = Grid.create_grid(lon, lat, depth, time, time_origin=time_origin, mesh=mesh)
+            self._grid = Grid.create_grid(lon, lat, depth, time, time_origin=time_origin, mesh=mesh)
         self.igrid = -1
-        # self.lon, self.lat, self.depth and self.time are not used any more in parcels.
-        # self.grid should be used instead.
-        # Those variables are still defined for backwards compatibility with users codes.
-        self.lon = self.grid.lon
-        self.lat = self.grid.lat
-        self.depth = self.grid.depth
         self.fieldtype = self.name if fieldtype is None else fieldtype
         self.to_write = to_write
         if self.grid.mesh == "flat" or (self.fieldtype not in unitconverters_map.keys()):
@@ -212,7 +215,8 @@ def __init__(
                 raise RuntimeError(f"interp_method is a dictionary but {name} is not in it")
         else:
             self.interp_method = interp_method
-        self.gridindexingtype = gridindexingtype
+        assert_valid_gridindexingtype(gridindexingtype)
+        self._gridindexingtype = gridindexingtype
         if self.interp_method in ["bgrid_velocity", "bgrid_w_velocity", "bgrid_tracer"] and self.grid.gtype in [
             GridType.RectilinearSGrid,
             GridType.CurvilinearSGrid,
@@ -253,14 +257,14 @@ def __init__(
 
         self.vmin = vmin
         self.vmax = vmax
-        self.cast_data_dtype = cast_data_dtype
+        self._cast_data_dtype = cast_data_dtype
         if self.cast_data_dtype == "float32":
-            self.cast_data_dtype = np.float32
+            self._cast_data_dtype = np.float32
         elif self.cast_data_dtype == "float64":
-            self.cast_data_dtype = np.float64
+            self._cast_data_dtype = np.float64
 
         if not self.grid.defer_load:
-            self.data = self.reshape(self.data, transpose)
+            self.data = self._reshape(self.data, transpose)
 
             # Hack around the fact that NaN and ridiculously large values
             # propagate in SciPy's interpolators
@@ -277,15 +281,15 @@ def __init__(
         self._scaling_factor = None
 
         # Variable names in JIT code
-        self.dimensions = kwargs.pop("dimensions", None)
+        self._dimensions = kwargs.pop("dimensions", None)
         self.indices = kwargs.pop("indices", None)
-        self.dataFiles = kwargs.pop("dataFiles", None)
-        if self.grid._add_last_periodic_data_timestep and self.dataFiles is not None:
-            self.dataFiles = np.append(self.dataFiles, self.dataFiles[0])
+        self._dataFiles = kwargs.pop("dataFiles", None)
+        if self.grid._add_last_periodic_data_timestep and self._dataFiles is not None:
+            self._dataFiles = np.append(self._dataFiles, self._dataFiles[0])
         self._field_fb_class = kwargs.pop("FieldFileBuffer", None)
-        self.netcdf_engine = kwargs.pop("netcdf_engine", "netcdf4")
-        self.loaded_time_indices: Iterable[int] = []  # type: ignore
-        self.creation_log = kwargs.pop("creation_log", "")
+        self._netcdf_engine = kwargs.pop("netcdf_engine", "netcdf4")
+        self._loaded_time_indices: Iterable[int] = []  # type: ignore
+        self._creation_log = kwargs.pop("creation_log", "")
         self.chunksize = kwargs.pop("chunksize", None)
         self.netcdf_chunkdims_name_map = kwargs.pop("chunkdims_name_map", None)
         self.grid.depth_field = kwargs.pop("depth_field", None)
@@ -299,16 +303,99 @@ def __init__(
         # (data_full_zdim = grid.zdim if no indices are used, for A- and C-grids and for some B-grids). It is used for the B-grid,
         # since some datasets do not provide the deeper level of data (which is ignored by the interpolation).
         self.data_full_zdim = kwargs.pop("data_full_zdim", None)
-        self.data_chunks = []  # type: ignore # the data buffer of the FileBuffer raw loaded data - shall be a list of C-contiguous arrays
-        self.c_data_chunks: list[PointerType | None] = []  # C-pointers to the data_chunks array
+        self._data_chunks = []  # type: ignore # the data buffer of the FileBuffer raw loaded data - shall be a list of C-contiguous arrays
+        self._c_data_chunks: list[PointerType | None] = []  # C-pointers to the data_chunks array
         self.nchunks: tuple[int, ...] = ()
-        self.chunk_set: bool = False
+        self._chunk_set: bool = False
         self.filebuffers = [None] * 2
         if len(kwargs) > 0:
             raise SyntaxError(f'Field received an unexpected keyword argument "{list(kwargs.keys())[0]}"')
 
+    @property
+    @deprecated_made_private  # TODO: Remove 6 months after v3.1.0
+    def dataFiles(self):
+        return self._dataFiles
+
+    @property
+    @deprecated_made_private  # TODO: Remove 6 months after v3.1.0
+    def chunk_set(self):
+        return self._chunk_set
+
+    @property
+    @deprecated_made_private  # TODO: Remove 6 months after v3.1.0
+    def c_data_chunks(self):
+        return self._c_data_chunks
+
+    @property
+    @deprecated_made_private  # TODO: Remove 6 months after v3.1.0
+    def data_chunks(self):
+        return self._data_chunks
+
+    @property
+    @deprecated_made_private  # TODO: Remove 6 months after v3.1.0
+    def creation_log(self):
+        return self._creation_log
+
+    @property
+    @deprecated_made_private  # TODO: Remove 6 months after v3.1.0
+    def loaded_time_indices(self):
+        return self._loaded_time_indices
+
+    @property
+    def dimensions(self):
+        return self._dimensions
+
+    @property
+    def grid(self):
+        return self._grid
+
+    @property
+    def lon(self):
+        """Lon defined on the Grid object"""
+        return self.grid.lon
+
+    @property
+    def lat(self):
+        """Lat defined on the Grid object"""
+        return self.grid.lat
+
+    @property
+    def depth(self):
+        """Depth defined on the Grid object"""
+        return self.grid.depth
+
+    @property
+    def cell_edge_sizes(self):
+        return self.grid.cell_edge_sizes
+
+    @property
+    def interp_method(self):
+        return self._interp_method
+
+    @interp_method.setter
+    def interp_method(self, value):
+        assert_valid_interp_method(value)
+        self._interp_method = value
+
+    @property
+    def gridindexingtype(self):
+        return self._gridindexingtype
+
+    @property
+    def cast_data_dtype(self):
+        return self._cast_data_dtype
+
+    @property
+    def netcdf_engine(self):
+        return self._netcdf_engine
+
     @classmethod
-    def get_dim_filenames(cls, filenames, dim):
+    @deprecated_made_private  # TODO: Remove 6 months after v3.1.0
+    def get_dim_filenames(cls, *args, **kwargs):
+        return cls._get_dim_filenames(*args, **kwargs)
+
+    @classmethod
+    def _get_dim_filenames(cls, filenames, dim):
         if isinstance(filenames, str) or not isinstance(filenames, collections.abc.Iterable):
             return [filenames]
         elif isinstance(filenames, dict):
@@ -322,7 +409,12 @@ def get_dim_filenames(cls, filenames, dim):
         return filenames
 
     @staticmethod
-    def collect_timeslices(
+    @deprecated_made_private  # TODO: Remove 6 months after v3.1.0
+    def collect_timeslices(*args, **kwargs):
+        return Field._collect_timeslices(*args, **kwargs)
+
+    @staticmethod
+    def _collect_timeslices(
         timestamps, data_filenames, _grid_fb_class, dimensions, indices, netcdf_engine, netcdf_decodewarning=None
     ):
         if netcdf_decodewarning is not None:
@@ -471,17 +563,17 @@ def from_netcdf(
             len(variable) == 2
         ), "The variable tuple must have length 2. Use FieldSet.from_netcdf() for multiple variables"
 
-        data_filenames = cls.get_dim_filenames(filenames, "data")
-        lonlat_filename = cls.get_dim_filenames(filenames, "lon")
+        data_filenames = cls._get_dim_filenames(filenames, "data")
+        lonlat_filename = cls._get_dim_filenames(filenames, "lon")
         if isinstance(filenames, dict):
             assert len(lonlat_filename) == 1
-        if lonlat_filename != cls.get_dim_filenames(filenames, "lat"):
+        if lonlat_filename != cls._get_dim_filenames(filenames, "lat"):
             raise NotImplementedError(
                 "longitude and latitude dimensions are currently processed together from one single file"
             )
         lonlat_filename = lonlat_filename[0]
         if "depth" in dimensions:
-            depth_filename = cls.get_dim_filenames(filenames, "depth")
+            depth_filename = cls._get_dim_filenames(filenames, "depth")
             if isinstance(filenames, dict) and len(depth_filename) != 1:
                 raise NotImplementedError("Vertically adaptive meshes not implemented for from_netcdf()")
             depth_filename = depth_filename[0]
@@ -544,7 +636,7 @@ def from_netcdf(
             if grid is None:
                 # Concatenate time variable to determine overall dimension
                 # across multiple files
-                time, time_origin, timeslices, dataFiles = cls.collect_timeslices(
+                time, time_origin, timeslices, dataFiles = cls._collect_timeslices(
                     timestamps, data_filenames, _grid_fb_class, dimensions, indices, netcdf_engine
                 )
                 grid = Grid.create_grid(lon, lat, depth, time, time_origin=time_origin, mesh=mesh)
@@ -553,7 +645,7 @@ def from_netcdf(
         elif grid is not None and ("dataFiles" not in kwargs or kwargs["dataFiles"] is None):
             # ==== means: the field has a shared grid, but may have different data files, so we need to collect the
             # ==== correct file time series again.
-            _, _, _, dataFiles = cls.collect_timeslices(
+            _, _, _, dataFiles = cls._collect_timeslices(
                 timestamps, data_filenames, _grid_fb_class, dimensions, indices, netcdf_engine
             )
             kwargs["dataFiles"] = dataFiles
@@ -709,7 +801,11 @@ def from_xarray(
             **kwargs,
         )
 
-    def reshape(self, data, transpose=False):
+    @deprecated_made_private  # TODO: Remove 6 months after v3.1.0
+    def reshape(self, *args, **kwargs):
+        return self._reshape(*args, **kwargs)
+
+    def _reshape(self, data, transpose=False):
         # Ensure that field data is the right data type
         if not isinstance(data, (np.ndarray, da.core.Array)):
             data = np.array(data)
@@ -807,7 +903,11 @@ def set_depth_from_field(self, field):
         if self.grid != field.grid:
             field.grid.depth_field = field
 
+    @deprecated_made_private  # TODO: Remove 6 months after v3.1.0
     def calc_cell_edge_sizes(self):
+        return self._calc_cell_edge_sizes()
+
+    def _calc_cell_edge_sizes(self):
         """Method to calculate cell sizes based on numpy.gradient method.
 
         Currently only works for Rectilinear Grids
         """
@@ -823,7 +923,6 @@ def calc_cell_edge_sizes(self):
             for x, (lon, dx) in enumerate(zip(self.grid.lon, np.gradient(self.grid.lon), strict=False)):
                 self.grid.cell_edge_sizes["x"][y, x] = x_conv.to_source(dx, lon, lat, self.grid.depth[0])
                 self.grid.cell_edge_sizes["y"][y, x] = y_conv.to_source(dy, lon, lat, self.grid.depth[0])
-            self.cell_edge_sizes = self.grid.cell_edge_sizes
         else:
             raise ValueError(
                 f"Field.cell_edge_sizes() not implemented for {self.grid.gtype} grids. "
@@ -837,10 +936,14 @@ def cell_areas(self):
 
         Currently only works for Rectilinear Grids
         """
         if not self.grid.cell_edge_sizes:
-            self.calc_cell_edge_sizes()
+            self._calc_cell_edge_sizes()
         return self.grid.cell_edge_sizes["x"] * self.grid.cell_edge_sizes["y"]
 
+    @deprecated_made_private  # TODO: Remove 6 months after v3.1.0
     def search_indices_vertical_z(self, z):
+        return self._search_indices_vertical_z(z)
+
+    def _search_indices_vertical_z(self, z):
         grid = self.grid
         z = np.float32(z)
         if grid.depth[-1] > grid.depth[0]:
" @@ -837,10 +936,14 @@ def cell_areas(self): Currently only works for Rectilinear Grids """ if not self.grid.cell_edge_sizes: - self.calc_cell_edge_sizes() + self._calc_cell_edge_sizes() return self.grid.cell_edge_sizes["x"] * self.grid.cell_edge_sizes["y"] + @deprecated_made_private # TODO: Remove 6 months after v3.1.0 def search_indices_vertical_z(self, z): + return self._search_indices_vertical_z(z) + + def _search_indices_vertical_z(self, z): grid = self.grid z = np.float32(z) if grid.depth[-1] > grid.depth[0]: @@ -870,7 +973,11 @@ def search_indices_vertical_z(self, z): zeta = (z - grid.depth[zi]) / (grid.depth[zi + 1] - grid.depth[zi]) return (zi, zeta) - def search_indices_vertical_s( + @deprecated_made_private # TODO: Remove 6 months after v3.1.0 + def search_indices_vertical_s(self, *args, **kwargs): + return self._search_indices_vertical_s(*args, **kwargs) + + def _search_indices_vertical_s( self, x: float, y: float, z: float, xi: int, yi: int, xsi: float, eta: float, ti: int, time: float ): grid = self.grid @@ -929,7 +1036,11 @@ def search_indices_vertical_s( zeta = (z - depth_vector[zi]) / (depth_vector[zi + 1] - depth_vector[zi]) return (zi, zeta) - def reconnect_bnd_indices(self, xi, yi, xdim, ydim, sphere_mesh): + @deprecated_made_private # TODO: Remove 6 months after v3.1.0 + def reconnect_bnd_indices(self, *args, **kwargs): + return self._reconnect_bnd_indices(*args, **kwargs) + + def _reconnect_bnd_indices(self, xi, yi, xdim, ydim, sphere_mesh): if xi < 0: if sphere_mesh: xi = xdim - 2 @@ -948,7 +1059,11 @@ def reconnect_bnd_indices(self, xi, yi, xdim, ydim, sphere_mesh): xi = xdim - xi return xi, yi - def search_indices_rectilinear(self, x: float, y: float, z: float, ti=-1, time=-1, particle=None, search2D=False): + @deprecated_made_private # TODO: Remove 6 months after v3.1.0 + def search_indices_rectilinear(self, *args, **kwargs): + return self._search_indices_rectilinear(*args, **kwargs) + + def _search_indices_rectilinear(self, x: float, y: float, z: float, ti=-1, time=-1, particle=None, search2D=False): grid = self.grid if grid.xdim > 1 and (not grid.zonal_periodic): @@ -1015,13 +1130,13 @@ def search_indices_rectilinear(self, x: float, y: float, z: float, ti=-1, time=- if grid.gtype == GridType.RectilinearZGrid: # Never passes here, because in this case, we work with scipy try: - (zi, zeta) = self.search_indices_vertical_z(z) + (zi, zeta) = self._search_indices_vertical_z(z) except FieldOutOfBoundError: raise FieldOutOfBoundError(x, y, z, field=self) except FieldOutOfBoundSurfaceError: raise FieldOutOfBoundSurfaceError(x, y, z, field=self) elif grid.gtype == GridType.RectilinearSGrid: - (zi, zeta) = self.search_indices_vertical_s(x, y, z, xi, yi, xsi, eta, ti, time) + (zi, zeta) = self._search_indices_vertical_s(x, y, z, xi, yi, xsi, eta, ti, time) else: zi, zeta = -1, 0 @@ -1035,7 +1150,11 @@ def search_indices_rectilinear(self, x: float, y: float, z: float, ti=-1, time=- return (xsi, eta, zeta, xi, yi, zi) - def search_indices_curvilinear(self, x, y, z, ti=-1, time=-1, particle=None, search2D=False): + @deprecated_made_private # TODO: Remove 6 months after v3.1.0 + def search_indices_curvilinear(self, *args, **kwargs): + return self._search_indices_curvilinear(*args, **kwargs) + + def _search_indices_curvilinear(self, x, y, z, ti=-1, time=-1, particle=None, search2D=False): if particle: xi = particle.xi[self.igrid] yi = particle.yi[self.igrid] @@ -1094,7 +1213,7 @@ def search_indices_curvilinear(self, x, y, z, ti=-1, time=-1, particle=None, sea yi -= 1 elif eta 
@@ -1015,13 +1130,13 @@ def search_indices_rectilinear(self, x: float, y: float, z: float, ti=-1, time=-
         if grid.gtype == GridType.RectilinearZGrid:
             # Never passes here, because in this case, we work with scipy
             try:
-                (zi, zeta) = self.search_indices_vertical_z(z)
+                (zi, zeta) = self._search_indices_vertical_z(z)
             except FieldOutOfBoundError:
                 raise FieldOutOfBoundError(x, y, z, field=self)
             except FieldOutOfBoundSurfaceError:
                 raise FieldOutOfBoundSurfaceError(x, y, z, field=self)
         elif grid.gtype == GridType.RectilinearSGrid:
-            (zi, zeta) = self.search_indices_vertical_s(x, y, z, xi, yi, xsi, eta, ti, time)
+            (zi, zeta) = self._search_indices_vertical_s(x, y, z, xi, yi, xsi, eta, ti, time)
         else:
             zi, zeta = -1, 0
@@ -1035,7 +1150,11 @@ def search_indices_rectilinear(self, x: float, y: float, z: float, ti=-1, time=-
 
         return (xsi, eta, zeta, xi, yi, zi)
 
-    def search_indices_curvilinear(self, x, y, z, ti=-1, time=-1, particle=None, search2D=False):
+    @deprecated_made_private  # TODO: Remove 6 months after v3.1.0
+    def search_indices_curvilinear(self, *args, **kwargs):
+        return self._search_indices_curvilinear(*args, **kwargs)
+
+    def _search_indices_curvilinear(self, x, y, z, ti=-1, time=-1, particle=None, search2D=False):
         if particle:
             xi = particle.xi[self.igrid]
             yi = particle.yi[self.igrid]
@@ -1094,7 +1213,7 @@ def search_indices_curvilinear(self, x, y, z, ti=-1, time=-1, particle=None, sea
                 yi -= 1
             elif eta > 1 + tol:
                 yi += 1
-            (xi, yi) = self.reconnect_bnd_indices(xi, yi, grid.xdim, grid.ydim, grid.mesh)
+            (xi, yi) = self._reconnect_bnd_indices(xi, yi, grid.xdim, grid.ydim, grid.mesh)
             it += 1
             if it > maxIterSearch:
                 print("Correct cell not found after %d iterations" % maxIterSearch)
@@ -1107,11 +1226,11 @@ def search_indices_curvilinear(self, x, y, z, ti=-1, time=-1, particle=None, sea
         if grid.zdim > 1 and not search2D:
             if grid.gtype == GridType.CurvilinearZGrid:
                 try:
-                    (zi, zeta) = self.search_indices_vertical_z(z)
+                    (zi, zeta) = self._search_indices_vertical_z(z)
                 except FieldOutOfBoundError:
                     raise FieldOutOfBoundError(x, y, z, field=self)
             elif grid.gtype == GridType.CurvilinearSGrid:
-                (zi, zeta) = self.search_indices_vertical_s(x, y, z, xi, yi, xsi, eta, ti, time)
+                (zi, zeta) = self._search_indices_vertical_s(x, y, z, xi, yi, xsi, eta, ti, time)
         else:
             zi = -1
             zeta = 0
"""Interpolate horizontal field values using a SciPy interpolator.""" if self.grid.zdim == 1: - val = self.interpolator2D(ti, z, y, x, particle=particle) + val = self._interpolator2D(ti, z, y, x, particle=particle) else: - val = self.interpolator3D(ti, z, y, x, time, particle=particle) + val = self._interpolator3D(ti, z, y, x, time, particle=particle) if np.isnan(val): # Detect Out-of-bounds sampling and raise exception raise FieldOutOfBoundError(x, y, z, field=self) @@ -1307,7 +1442,11 @@ def spatial_interpolation(self, ti, z, y, x, time, particle=None): val = val.compute() return val - def time_index(self, time): + @deprecated_made_private # TODO: Remove 6 months after v3.1.0 + def time_index(self, *args, **kwargs): + return self._time_index(*args, **kwargs) + + def _time_index(self, time): """Find the index in the time array associated with a given time. Note that we normalize to either the first or the last index @@ -1370,11 +1509,11 @@ def eval(self, time, z, y, x, particle=None, applyConversion=True): conversion to the result. Note that we defer to scipy.interpolate to perform spatial interpolation. """ - (ti, periods) = self.time_index(time) + (ti, periods) = self._time_index(time) time -= periods * (self.grid.time_full[-1] - self.grid.time_full[0]) if ti < self.grid.tdim - 1 and time > self.grid.time[ti]: - f0 = self.spatial_interpolation(ti, z, y, x, time, particle=particle) - f1 = self.spatial_interpolation(ti + 1, z, y, x, time, particle=particle) + f0 = self._spatial_interpolation(ti, z, y, x, time, particle=particle) + f1 = self._spatial_interpolation(ti + 1, z, y, x, time, particle=particle) t0 = self.grid.time[ti] t1 = self.grid.time[ti + 1] value = f0 + (f1 - f0) * ((time - t0) / (t1 - t0)) @@ -1382,28 +1521,48 @@ def eval(self, time, z, y, x, particle=None, applyConversion=True): # Skip temporal interpolation if time is outside # of the defined time range or if we have hit an # exact value in the time array. 
@@ -1293,12 +1424,16 @@ def temporal_interpolate_fullfield(self, ti, time):
             f1 = self.data[ti + 1, :]
         return f0 + (f1 - f0) * ((time - t0) / (t1 - t0))
 
-    def spatial_interpolation(self, ti, z, y, x, time, particle=None):
+    @deprecated_made_private  # TODO: Remove 6 months after v3.1.0
+    def spatial_interpolation(self, *args, **kwargs):
+        return self._spatial_interpolation(*args, **kwargs)
+
+    def _spatial_interpolation(self, ti, z, y, x, time, particle=None):
        """Interpolate horizontal field values using a SciPy interpolator."""
         if self.grid.zdim == 1:
-            val = self.interpolator2D(ti, z, y, x, particle=particle)
+            val = self._interpolator2D(ti, z, y, x, particle=particle)
         else:
-            val = self.interpolator3D(ti, z, y, x, time, particle=particle)
+            val = self._interpolator3D(ti, z, y, x, time, particle=particle)
         if np.isnan(val):
             # Detect Out-of-bounds sampling and raise exception
             raise FieldOutOfBoundError(x, y, z, field=self)
@@ -1307,7 +1442,11 @@ def spatial_interpolation(self, ti, z, y, x, time, particle=None):
             val = val.compute()
         return val
 
-    def time_index(self, time):
+    @deprecated_made_private  # TODO: Remove 6 months after v3.1.0
+    def time_index(self, *args, **kwargs):
+        return self._time_index(*args, **kwargs)
+
+    def _time_index(self, time):
         """Find the index in the time array associated with a given time.
 
         Note that we normalize to either the first or the last index
@@ -1370,11 +1509,11 @@ def eval(self, time, z, y, x, particle=None, applyConversion=True):
         conversion to the result. Note that we defer to scipy.interpolate
         to perform spatial interpolation.
         """
-        (ti, periods) = self.time_index(time)
+        (ti, periods) = self._time_index(time)
         time -= periods * (self.grid.time_full[-1] - self.grid.time_full[0])
         if ti < self.grid.tdim - 1 and time > self.grid.time[ti]:
-            f0 = self.spatial_interpolation(ti, z, y, x, time, particle=particle)
-            f1 = self.spatial_interpolation(ti + 1, z, y, x, time, particle=particle)
+            f0 = self._spatial_interpolation(ti, z, y, x, time, particle=particle)
+            f1 = self._spatial_interpolation(ti + 1, z, y, x, time, particle=particle)
             t0 = self.grid.time[ti]
             t1 = self.grid.time[ti + 1]
             value = f0 + (f1 - f0) * ((time - t0) / (t1 - t0))
@@ -1382,28 +1521,48 @@ def eval(self, time, z, y, x, particle=None, applyConversion=True):
             # Skip temporal interpolation if time is outside
             # of the defined time range or if we have hit an
             # exact value in the time array.
-            value = self.spatial_interpolation(ti, z, y, x, self.grid.time[ti], particle=particle)
+            value = self._spatial_interpolation(ti, z, y, x, self.grid.time[ti], particle=particle)
 
         if applyConversion:
             return self.units.to_target(value, x, y, z)
         else:
             return value
 
-    def ccode_eval(self, var, t, z, y, x):
+    @deprecated_made_private  # TODO: Remove 6 months after v3.1.0
+    def ccode_eval(self, *args, **kwargs):
+        return self._ccode_eval(*args, **kwargs)
+
+    def _ccode_eval(self, var, t, z, y, x):
         self._check_velocitysampling()
         ccode_str = f"temporal_interpolation({x}, {y}, {z}, {t}, {self.ccode_name}, &particles->xi[pnum*ngrid], &particles->yi[pnum*ngrid], &particles->zi[pnum*ngrid], &particles->ti[pnum*ngrid], &{var}, {self.interp_method.upper()}, {self.gridindexingtype.upper()})"
         return ccode_str
 
-    def ccode_convert(self, _, z, y, x):
+    @deprecated_made_private  # TODO: Remove 6 months after v3.1.0
+    def ccode_convert(self, *args, **kwargs):
+        return self._ccode_convert(*args, **kwargs)
+
+    def _ccode_convert(self, _, z, y, x):
         return self.units.ccode_to_target(x, y, z)
 
-    def get_block_id(self, block):
+    @deprecated_made_private  # TODO: Remove 6 months after v3.1.0
+    def get_block_id(self, *args, **kwargs):
+        return self._get_block_id(*args, **kwargs)
+
+    def _get_block_id(self, block):
         return np.ravel_multi_index(block, self.nchunks)
 
-    def get_block(self, bid):
+    @deprecated_made_private  # TODO: Remove 6 months after v3.1.0
+    def get_block(self, *args, **kwargs):
+        return self._get_block(*args, **kwargs)
+
+    def _get_block(self, bid):
         return np.unravel_index(bid, self.nchunks[1:])
 
-    def chunk_setup(self):
+    @deprecated_made_private  # TODO: Remove 6 months after v3.1.0
+    def chunk_setup(self, *args, **kwargs):
+        return self._chunk_setup(*args, **kwargs)
+
+    def _chunk_setup(self):
         if isinstance(self.data, da.core.Array):
             chunks = self.data.chunks
             self.nchunks = self.data.numblocks
@@ -1420,8 +1579,8 @@ def chunk_setup(self):
         else:
             return
 
-        self.data_chunks = [None] * npartitions
-        self.c_data_chunks = [None] * npartitions
+        self._data_chunks = [None] * npartitions
+        self._c_data_chunks = [None] * npartitions
         self.grid.load_chunk = np.zeros(npartitions, dtype=c_int, order="C")
         # self.grid.chunk_info format: number of dimensions (without tdim); number of chunks per dimensions;
         #      chunksizes (the 0th dim sizes for all chunk of dim[0], then so on for next dims
@@ -1431,35 +1590,41 @@ def chunk_setup(self):
             sum(list(list(ci) for ci in chunks[1:]), []),  # noqa: RUF017 # TODO: Perhaps avoid quadratic list summation here
         ]
         self.grid.chunk_info = sum(self.grid.chunk_info, [])  # noqa: RUF017
-        self.chunk_set = True
+        self._chunk_set = True
 
-    def chunk_data(self):
-        if not self.chunk_set:
-            self.chunk_setup()
+    @deprecated_made_private  # TODO: Remove 6 months after v3.1.0
+    def chunk_data(self, *args, **kwargs):
+        return self._chunk_data(*args, **kwargs)
+
+    def _chunk_data(self):
+        if not self._chunk_set:
+            self._chunk_setup()
         g = self.grid
         if isinstance(self.data, da.core.Array):
             for block_id in range(len(self.grid.load_chunk)):
                 if (
                     g.load_chunk[block_id] == g.chunk_loading_requested
                     or g.load_chunk[block_id] in g.chunk_loaded
-                    and self.data_chunks[block_id] is None
+                    and self._data_chunks[block_id] is None
                 ):
-                    block = self.get_block(block_id)
-                    self.data_chunks[block_id] = np.array(self.data.blocks[(slice(self.grid.tdim),) + block], order="C")
+                    block = self._get_block(block_id)
+                    self._data_chunks[block_id] = np.array(
+                        self.data.blocks[(slice(self.grid.tdim),) + block], order="C"
+                    )
                 elif g.load_chunk[block_id] == g.chunk_not_loaded:
-                    if isinstance(self.data_chunks, list):
-                        self.data_chunks[block_id] = None
+                    if isinstance(self._data_chunks, list):
+                        self._data_chunks[block_id] = None
                     else:
-                        self.data_chunks[block_id, :] = None
-                    self.c_data_chunks[block_id] = None
+                        self._data_chunks[block_id, :] = None
+                    self._c_data_chunks[block_id] = None
         else:
-            if isinstance(self.data_chunks, list):
-                self.data_chunks[0] = None
+            if isinstance(self._data_chunks, list):
+                self._data_chunks[0] = None
             else:
-                self.data_chunks[0, :] = None
-            self.c_data_chunks[0] = None
+                self._data_chunks[0, :] = None
+            self._c_data_chunks[0] = None
             self.grid.load_chunk[0] = g.chunk_loaded_touched
-            self.data_chunks[0] = np.array(self.data, order="C")
+            self._data_chunks[0] = np.array(self.data, order="C")
 
     @property
     def ctypes_struct(self):
@@ -1488,11 +1653,11 @@ class CField(Structure):
                     "data_chunks should have been loaded by now if requested. grid.load_chunk[bid] cannot be 1"
                 )
             if self.grid.load_chunk[i] in self.grid.chunk_loaded:
-                if not self.data_chunks[i].flags["C_CONTIGUOUS"]:
-                    self.data_chunks[i] = np.array(self.data_chunks[i], order="C")
-                self.c_data_chunks[i] = self.data_chunks[i].ctypes.data_as(POINTER(POINTER(c_float)))
+                if not self._data_chunks[i].flags["C_CONTIGUOUS"]:
+                    self._data_chunks[i] = np.array(self._data_chunks[i], order="C")
+                self._c_data_chunks[i] = self._data_chunks[i].ctypes.data_as(POINTER(POINTER(c_float)))
             else:
-                self.c_data_chunks[i] = None
+                self._c_data_chunks[i] = None
 
         cstruct = CField(
             self.grid.xdim,
@@ -1502,7 +1667,7 @@ class CField(Structure):
             self.igrid,
             allow_time_extrapolation,
             time_periodic,
-            (POINTER(POINTER(c_float)) * len(self.c_data_chunks))(*self.c_data_chunks),
+            (POINTER(POINTER(c_float)) * len(self._c_data_chunks))(*self._c_data_chunks),
             pointer(self.grid.ctypes_struct),
         )
         return cstruct
@@ -1542,8 +1707,6 @@ def add_periodic_halo(self, zonal, meridional, halosize=5, data=None):
                     (data[:, :, :, -halosize:], data, data[:, :, :, 0:halosize]), axis=len(data.shape) - 1
                 )
                 assert data.shape[3] == self.grid.xdim, "Fourth dim must be x."
-            self.lon = self.grid.lon
-            self.lat = self.grid.lat
         if meridional:
             if len(data.shape) == 3:
                 data = lib.concatenate((data[:, -halosize:, :], data, data[:, 0:halosize, :]), axis=len(data.shape) - 2)
@@ -1553,7 +1716,6 @@ def add_periodic_halo(self, zonal, meridional, halosize=5, data=None):
                     (data[:, :, -halosize:, :], data, data[:, :, 0:halosize, :]), axis=len(data.shape) - 2
                 )
                 assert data.shape[2] == self.grid.ydim, "Third dim must be y."
-            self.lat = self.grid.lat
         if dataNone:
             self.data = data
         else:
@@ -1606,7 +1768,11 @@ def write(self, filename, varname=None):
         )
         dset.to_netcdf(filepath, unlimited_dims="time_counter")
 
-    def rescale_and_set_minmax(self, data):
+    @deprecated_made_private  # TODO: Remove 6 months after v3.1.0
+    def rescale_and_set_minmax(self, *args, **kwargs):
+        return self._rescale_and_set_minmax(*args, **kwargs)
+
+    def _rescale_and_set_minmax(self, data):
         data[np.isnan(data)] = 0
         if self._scaling_factor:
             data *= self._scaling_factor
@@ -1616,7 +1782,11 @@ def rescale_and_set_minmax(self, data):
             data[data > self.vmax] = 0
         return data
 
-    def data_concatenate(self, data, data_to_concat, tindex):
+    @deprecated_made_private  # TODO: Remove 6 months after v3.1.0
+    def data_concatenate(self, *args, **kwargs):
+        return self._data_concatenate(*args, **kwargs)
+
+    def _data_concatenate(self, data, data_to_concat, tindex):
         if data[tindex] is not None:
             if isinstance(data, np.ndarray):
                 data[tindex] = None
@@ -1642,9 +1812,9 @@ def computeTimeChunk(self, data, tindex):
                 ti = g.ti + tindex
                 timestamp = self.timestamps[np.where(ti < summedlen)[0][0]]
 
-        rechunk_callback_fields = self.chunk_setup if isinstance(tindex, list) else None
+        rechunk_callback_fields = self._chunk_setup if isinstance(tindex, list) else None
         filebuffer = self._field_fb_class(
-            self.dataFiles[g.ti + tindex],
+            self._dataFiles[g.ti + tindex],
             self.dimensions,
             self.indices,
             netcdf_engine=self.netcdf_engine,
@@ -1682,7 +1852,7 @@ def computeTimeChunk(self, data, tindex):
                     (),
                 ),
             )
-        data = self.data_concatenate(data, buffer_data, tindex)
+        data = self._data_concatenate(data, buffer_data, tindex)
         self.filebuffers[tindex] = filebuffer
         return data
@@ -1747,7 +1917,7 @@ def jacobian(self, xsi: float, eta: float, px: np.ndarray, py: np.ndarray):
 
     def spatial_c_grid_interpolation2D(self, ti, z, y, x, time, particle=None, applyConversion=True):
         grid = self.U.grid
-        (xsi, eta, zeta, xi, yi, zi) = self.U.search_indices(x, y, z, ti, time, particle=particle)
+        (xsi, eta, zeta, xi, yi, zi) = self.U._search_indices(x, y, z, ti, time, particle=particle)
 
         if grid.gtype in [GridType.RectilinearSGrid, GridType.RectilinearZGrid]:
             px = np.array([grid.lon[xi], grid.lon[xi + 1], grid.lon[xi + 1], grid.lon[xi]])
@@ -1819,7 +1989,7 @@ def spatial_c_grid_interpolation2D(self, ti, z, y, x, time, particle=None, apply
 
     def spatial_c_grid_interpolation3D_full(self, ti, z, y, x, time, particle=None):
         grid = self.U.grid
-        (xsi, eta, zet, xi, yi, zi) = self.U.search_indices(x, y, z, ti, time, particle=particle)
+        (xsi, eta, zet, xi, yi, zi) = self.U._search_indices(x, y, z, ti, time, particle=particle)
 
         if grid.gtype in [GridType.RectilinearSGrid, GridType.RectilinearZGrid]:
             px = np.array([grid.lon[xi], grid.lon[xi + 1], grid.lon[xi + 1], grid.lon[xi]])
@@ -2060,7 +2230,7 @@ def _is_land2D(self, di, yi, xi):
         return True
 
     def spatial_slip_interpolation(self, ti, z, y, x, time, particle=None, applyConversion=True):
-        (xsi, eta, zeta, xi, yi, zi) = self.U.search_indices(x, y, z, ti, time, particle=particle)
+        (xsi, eta, zeta, xi, yi, zi) = self.U._search_indices(x, y, z, ti, time, particle=particle)
         di = ti if self.U.grid.zdim == 1 else zi  # general third dimension
 
         f_u, f_v, f_w = 1, 1, 1
@@ -2184,7 +2354,7 @@ def eval(self, time, z, y, x, particle=None, applyConversion=True):
             "freeslip": {"2D": self.spatial_slip_interpolation, "3D": self.spatial_slip_interpolation},
         }
         grid = self.U.grid
-        (ti, periods) = self.U.time_index(time)
+        (ti, periods) = self.U._time_index(time)
         time -= periods * (grid.time_full[-1] - grid.time_full[0])
         if ti < grid.tdim - 1 and time > grid.time[ti]:
             t0 = grid.time[ti]
@@ -2232,7 +2402,11 @@ def __getitem__(self, key):
         except tuple(AllParcelsErrorCodes.keys()) as error:
             return _deal_with_errors(error, key, vector_type=self.vector_type)
 
-    def ccode_eval(self, varU, varV, varW, U, V, W, t, z, y, x):
+    @deprecated_made_private  # TODO: Remove 6 months after v3.1.0
+    def ccode_eval(self, *args, **kwargs):
+        return self._ccode_eval(*args, **kwargs)
+
+    def _ccode_eval(self, varU, varV, varW, U, V, W, t, z, y, x):
         ccode_str = ""
         if self.vector_type == "3D":
             ccode_str = (
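Every rename in `field.py` above follows the same recipe: the attribute or method gets a leading underscore, and a public wrapper decorated with `@deprecated_made_private` (removal flagged for ~6 months after v3.1.0) forwards to it. Stripped to its essentials, the shim behaves roughly like this — a simplified sketch with a hypothetical `Example` class, not the actual implementation:

```python
import warnings


class Example:
    def __init__(self):
        self._creation_log = "from_data"  # internal attribute; no warning on access

    @property
    def creation_log(self):
        # Public alias kept for backwards compatibility: warn, then forward.
        warnings.warn(
            "`Example.creation_log` is deprecated and will be removed in a future "
            "release of Parcels. It has moved to the internal API ...",
            category=DeprecationWarning,
            stacklevel=2,
        )
        return self._creation_log


obj = Example()
obj._creation_log  # quiet: internal code uses the private name
obj.creation_log   # emits a DeprecationWarning, but still returns the value
```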
diff --git a/parcels/fieldset.py b/parcels/fieldset.py
index c21c1f79a..db99185b5 100644
--- a/parcels/fieldset.py
+++ b/parcels/fieldset.py
@@ -14,6 +14,7 @@
 from parcels.grid import Grid
 from parcels.gridset import GridSet
 from parcels.particlefile import ParticleFile
+from parcels.tools._helpers import deprecated_made_private
 from parcels.tools.converters import TimeConverter, convert_xarray_time_units
 from parcels.tools.loggers import logger
 from parcels.tools.statuscodes import TimeExtrapolationError
@@ -38,8 +39,8 @@ class FieldSet:
 
     def __init__(self, U: Field | NestedField | None, V: Field | NestedField | None, fields=None):
         self.gridset = GridSet()
-        self.completed: bool = False
-        self.particlefile: ParticleFile | None = None
+        self._completed: bool = False
+        self._particlefile: ParticleFile | None = None
         if U:
             self.add_field(U, "U")  # see #1663 for type-ignore reason
@@ -53,7 +54,16 @@ def __init__(self, U: Field | NestedField | None, V: Field | NestedField | None,
                 self.add_field(field, name)
 
         self.compute_on_defer = None
-        self.add_UVfield()
+        self._add_UVfield()
+
+    @property
+    def particlefile(self):
+        return self._particlefile
+
+    @property
+    @deprecated_made_private  # TODO: Remove 6 months after v3.1.0
+    def completed(self):
+        return self._completed
 
     @staticmethod
     def checkvaliddimensionsdict(dims):
@@ -180,7 +190,7 @@ def add_field(self, field: Field | NestedField, name: str | None = None):
             * `Unit converters <../examples/tutorial_unitconverters.ipynb>`__ (Default value = None)
 
         """
-        if self.completed:
+        if self._completed:
             raise RuntimeError(
                 "FieldSet has already been completed. Are you trying to add a Field after you've created the ParticleSet?"
             )
@@ -235,7 +245,11 @@ def add_vector_field(self, vfield):
         for f in vfield:
             f.fieldset = self
 
-    def add_UVfield(self):
+    @deprecated_made_private  # TODO: Remove 6 months after v3.1.0
+    def add_UVfield(self, *args, **kwargs):
+        return self._add_UVfield(*args, **kwargs)
+
+    def _add_UVfield(self):
         if not hasattr(self, "UV") and hasattr(self, "U") and hasattr(self, "V"):
             if isinstance(self.U, NestedField):
                 self.add_vector_field(NestedField("UV", self.U, self.V))
@@ -247,7 +261,11 @@ def add_UVfield(self):
             else:
                 self.add_vector_field(VectorField("UVW", self.U, self.V, self.W))
 
+    @deprecated_made_private  # TODO: Remove 6 months after v3.1.0
     def check_complete(self):
+        return self._check_complete()
+
+    def _check_complete(self):
         assert self.U, 'FieldSet does not have a Field named "U"'
         assert self.V, 'FieldSet does not have a Field named "V"'
         for attr, value in vars(self).items():
@@ -299,7 +317,7 @@ def check_velocityfields(U, V, W):
             if g.defer_load:
                 g.time_full = g.time_full + self.time_origin.reltime(g.time_origin)
                 g.time_origin = self.time_origin
-        self.add_UVfield()
+        self._add_UVfield()
 
         ccode_fieldnames = []
         counter = 1
@@ -312,7 +330,7 @@ def check_velocityfields(U, V, W):
                 ccode_fieldnames.append(fld.ccode_name)
 
         for f in self.get_fields():
-            if isinstance(f, (VectorField, NestedField)) or f.dataFiles is None:
+            if isinstance(f, (VectorField, NestedField)) or f._dataFiles is None:
                 continue
             if f.grid.depth_field is not None:
                 if f.grid.depth_field == "not_yet_set":
@@ -322,10 +340,15 @@ def check_velocityfields(U, V, W):
             if not f.grid.defer_load:
                 depth_data = f.grid.depth_field.data
                 f.grid.depth = depth_data if isinstance(depth_data, np.ndarray) else np.array(depth_data)
-        self.completed = True
+        self._completed = True
+
+    @classmethod
+    @deprecated_made_private  # TODO: Remove 6 months after v3.1.0
+    def parse_wildcards(self, *args, **kwargs):
+        return self._parse_wildcards(*args, **kwargs)
 
     @classmethod
-    def parse_wildcards(cls, paths, filenames, var):
+    def _parse_wildcards(cls, paths, filenames, var):
         if not isinstance(paths, list):
             paths = sorted(glob(str(paths)))
         if len(paths) == 0:
@@ -451,10 +474,10 @@ def from_netcdf(
             # Resolve all matching paths for the current variable
             paths = filenames[var] if type(filenames) is dict and var in filenames else filenames
             if type(paths) is not dict:
-                paths = cls.parse_wildcards(paths, filenames, var)
+                paths = cls._parse_wildcards(paths, filenames, var)
             else:
                 for dim, p in paths.items():
-                    paths[dim] = cls.parse_wildcards(p, filenames, var)
+                    paths[dim] = cls._parse_wildcards(p, filenames, var)
 
             # Use dimensions[var] and indices[var] if either of them is a dict of dicts
             dims = dimensions[var] if var in dimensions else dimensions
@@ -497,7 +520,7 @@ def from_netcdf(
                 if processedGrid:
                     grid = fields[procvar].grid
                     if procpaths == nowpaths:
-                        dFiles = fields[procvar].dataFiles
+                        dFiles = fields[procvar]._dataFiles
                         break
             fields[var] = Field.from_netcdf(
                 paths,
@@ -1332,7 +1355,7 @@ def from_modulefile(cls, filename, modulename="create_fieldset", **kwargs):
             raise OSError(f"Module {filename}.{modulename} does not return a FieldSet object")
         return fieldset
 
-    def get_fields(self):
+    def get_fields(self) -> list[Field | VectorField]:
         """Returns a list of all the :class:`parcels.field.Field` and :class:`parcels.field.VectorField`
         objects associated with this FieldSet.
         """
""" @@ -1442,7 +1465,7 @@ def computeTimeChunk(self, time=0.0, dt=1): nextTime = min(nextTime, nextTime_loc) if signdt >= 0 else max(nextTime, nextTime_loc) for f in self.get_fields(): - if isinstance(f, (VectorField, NestedField)) or not f.grid.defer_load or f.dataFiles is None: + if isinstance(f, (VectorField, NestedField)) or not f.grid.defer_load or f._dataFiles is None: continue g = f.grid if g.update_status == "first_updated": # First load of data @@ -1461,20 +1484,20 @@ def computeTimeChunk(self, time=0.0, dt=1): data = lib.empty( (g.tdim, zd, g.ydim - 2 * g.meridional_halo, g.xdim - 2 * g.zonal_halo), dtype=np.float32 ) - f.loaded_time_indices = range(2) - for tind in f.loaded_time_indices: + f._loaded_time_indices = range(2) + for tind in f._loaded_time_indices: for fb in f.filebuffers: if fb is not None: fb.close() fb = None data = f.computeTimeChunk(data, tind) - data = f.rescale_and_set_minmax(data) + data = f._rescale_and_set_minmax(data) if isinstance(f.data, DeferredArray): f.data = DeferredArray() - f.data = f.reshape(data) - if not f.chunk_set: - f.chunk_setup() + f.data = f._reshape(data) + if not f._chunk_set: + f._chunk_setup() if len(g.load_chunk) > g.chunk_not_loaded: g.load_chunk = np.where( g.load_chunk == g.chunk_loaded_touched, g.chunk_loading_requested, g.load_chunk @@ -1491,22 +1514,22 @@ def computeTimeChunk(self, time=0.0, dt=1): (g.tdim, zd, g.ydim - 2 * g.meridional_halo, g.xdim - 2 * g.zonal_halo), dtype=np.float32 ) if signdt >= 0: - f.loaded_time_indices = [1] + f._loaded_time_indices = [1] if f.filebuffers[0] is not None: f.filebuffers[0].close() f.filebuffers[0] = None f.filebuffers[0] = f.filebuffers[1] data = f.computeTimeChunk(data, 1) else: - f.loaded_time_indices = [0] + f._loaded_time_indices = [0] if f.filebuffers[1] is not None: f.filebuffers[1].close() f.filebuffers[1] = None f.filebuffers[1] = f.filebuffers[0] data = f.computeTimeChunk(data, 0) - data = f.rescale_and_set_minmax(data) + data = f._rescale_and_set_minmax(data) if signdt >= 0: - data = f.reshape(data)[1, :] + data = f._reshape(data)[1, :] if lib is da: f.data = lib.stack([f.data[1, :], data], axis=0) else: @@ -1518,7 +1541,7 @@ def computeTimeChunk(self, time=0.0, dt=1): f.data[0, :] = f.data[1, :] f.data[1, :] = data else: - data = f.reshape(data)[0, :] + data = f._reshape(data)[0, :] if lib is da: f.data = lib.stack([data, f.data[0, :]], axis=0) else: @@ -1535,30 +1558,30 @@ def computeTimeChunk(self, time=0.0, dt=1): if signdt >= 0: for block_id in range(len(g.load_chunk)): if g.load_chunk[block_id] == g.chunk_loaded_touched: - if f.data_chunks[block_id] is None: + if f._data_chunks[block_id] is None: # file chunks were never loaded. # happens when field not called by kernel, but shares a grid with another field called by kernel break block = f.get_block(block_id) - f.data_chunks[block_id][0] = None - f.data_chunks[block_id][1] = np.array(f.data.blocks[(slice(2),) + block][1]) + f._data_chunks[block_id][0] = None + f._data_chunks[block_id][1] = np.array(f.data.blocks[(slice(2),) + block][1]) else: for block_id in range(len(g.load_chunk)): if g.load_chunk[block_id] == g.chunk_loaded_touched: - if f.data_chunks[block_id] is None: + if f._data_chunks[block_id] is None: # file chunks were never loaded. 
diff --git a/parcels/grid.py b/parcels/grid.py
index 70a277be9..955abd9f4 100644
--- a/parcels/grid.py
+++ b/parcels/grid.py
@@ -6,7 +6,7 @@
 import numpy as np
 import numpy.typing as npt
 
-from parcels._typing import Mesh
+from parcels._typing import Mesh, UpdateStatus, assert_valid_mesh
 from parcels.tools.converters import TimeConverter
 from parcels.tools.warnings import FieldSetWarning
 
@@ -54,6 +54,7 @@ def __init__(
         self.zi = None
         self.ti = -1
         self.lon = lon
+        self.update_status: UpdateStatus | None = None
         if not self.lon.flags["C_CONTIGUOUS"]:
             self.lon = np.array(self.lon, order="C")
         self.lat = lat
@@ -74,6 +75,7 @@ def __init__(
         self.time_full = self.time  # needed for deferred_loaded Fields
         self.time_origin = TimeConverter() if time_origin is None else time_origin
         assert isinstance(self.time_origin, TimeConverter), "time_origin needs to be a TimeConverter object"
+        assert_valid_mesh(mesh)
         self.mesh = mesh
         self.cstruct = None
         self.cell_edge_sizes: dict[str, npt.NDArray] = {}
@@ -276,7 +278,7 @@ def computeTimeChunk(self, f, time, signdt):
                 self.update_status = "updated"
         if self.ti == -1:
             self.time = self.time_full
-            self.ti, _ = f.time_index(time)
+            self.ti, _ = f._time_index(time)
             periods = self.periods.value if isinstance(self.periods, c_int) else self.periods
             if (
                 signdt == -1
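With `assert_valid_mesh` now called in `Grid.__init__`, a bad `mesh` string fails at grid construction instead of surfacing later inside interpolation. A rough illustration — the `RectilinearZGrid` constructor signature here is assumed, not taken from this patch:

```python
import numpy as np
from parcels.grid import RectilinearZGrid

lon = np.linspace(0.0, 1.0, 5, dtype=np.float32)
lat = np.linspace(0.0, 1.0, 4, dtype=np.float32)

RectilinearZGrid(lon, lat, mesh="flat")       # fine
RectilinearZGrid(lon, lat, mesh="cartesian")  # raises ValueError listing ('spherical', 'flat')
```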
diff --git a/parcels/gridset.py b/parcels/gridset.py
index baf3e2dca..e3b619110 100644
--- a/parcels/gridset.py
+++ b/parcels/gridset.py
@@ -36,7 +36,7 @@ def add_grid(self, field):
 
             if sameGrid:
                 existing_grid = True
-                field.grid = g
+                field._grid = g  # TODO: Is this even necessary?
                 break
 
         if not existing_grid:
diff --git a/parcels/kernel.py b/parcels/kernel.py
index c20bcc264..442e6e4d7 100644
--- a/parcels/kernel.py
+++ b/parcels/kernel.py
@@ -344,14 +344,14 @@ def check_fieldsets_in_kernels(self, pyfunc):
             warning = False
             if (
                 isinstance(self._fieldset.W, Field)
-                and self._fieldset.W.creation_log != "from_nemo"
+                and self._fieldset.W._creation_log != "from_nemo"
                 and self._fieldset.W._scaling_factor is not None
                 and self._fieldset.W._scaling_factor > 0
             ):
                 warning = True
             if isinstance(self._fieldset.W, NestedField):
                 for f in self._fieldset.W:
-                    if f.creation_log != "from_nemo" and f._scaling_factor is not None and f._scaling_factor > 0:
+                    if f._creation_log != "from_nemo" and f._scaling_factor is not None and f._scaling_factor > 0:
                         warning = True
             if warning:
                 warnings.warn(
@@ -585,11 +585,11 @@ def load_fieldset_jit(self, pset):
             if f.data.dtype != np.float32:
                 raise RuntimeError(f"Field {f.name} data needs to be float32 in JIT mode")
             if f in self.field_args.values():
-                f.chunk_data()
+                f._chunk_data()
             else:
-                for block_id in range(len(f.data_chunks)):
-                    f.data_chunks[block_id] = None
-                    f.c_data_chunks[block_id] = None
+                for block_id in range(len(f._data_chunks)):
+                    f._data_chunks[block_id] = None
+                    f._c_data_chunks[block_id] = None
 
         for g in pset.fieldset.gridset.grids:
             g.load_chunk = np.where(g.load_chunk == g.chunk_loading_requested, g.chunk_loaded_touched, g.load_chunk)
diff --git a/parcels/particlefile.py b/parcels/particlefile.py
index db24bab1c..24bb178a8 100644
--- a/parcels/particlefile.py
+++ b/parcels/particlefile.py
@@ -69,7 +69,7 @@ def __init__(self, name, particleset, outputdt=np.inf, chunks=None, create_new_z
             if var.to_write:
                 self.vars_to_write[var.name] = var.dtype
         self.mpi_rank = MPI.COMM_WORLD.Get_rank() if MPI else 0
-        self.particleset.fieldset.particlefile = self
+        self.particleset.fieldset._particlefile = self
         self.analytical = False  # Flag to indicate if ParticleFile is used for analytical trajectories
 
         # Reset obs_written of each particle, in case new ParticleFile created for a ParticleSet
diff --git a/parcels/particleset.py b/parcels/particleset.py
index 11007c93d..a7dc0602d 100644
--- a/parcels/particleset.py
+++ b/parcels/particleset.py
@@ -108,7 +108,7 @@ def __init__(
         self.interaction_kernel = None
 
         self.fieldset = fieldset
-        self.fieldset.check_complete()
+        self.fieldset._check_complete()
         self.time_origin = fieldset.time_origin
 
         # ==== first: create a new subclass of the pclass that includes the required variables ==== #
diff --git a/parcels/tools/_helpers.py b/parcels/tools/_helpers.py
index 00893c040..7f698b719 100644
--- a/parcels/tools/_helpers.py
+++ b/parcels/tools/_helpers.py
@@ -37,7 +37,7 @@ def wrapper(*args, **kwargs):
             f"`{func.__qualname__}` is deprecated and will be removed in a future release of {PACKAGE}.{msg}"
         )
 
-        warnings.warn(msg_formatted, category=DeprecationWarning, stacklevel=2)
+        warnings.warn(msg_formatted, category=DeprecationWarning, stacklevel=3)
        return func(*args, **kwargs)
 
     return wrapper
@@ -49,5 +49,5 @@ def deprecated_made_private(func: Callable) -> Callable:
     return deprecated(
         "It has moved to the internal API as it is not expected to be directly used by "
         "the end-user. If you feel that you use this code directly in your scripts, please "
-        "comment on our tracking issue at <>."  # TODO: Add tracking issue
+        "comment on our tracking issue at https://github.com/OceanParcels/Parcels/issues/1695.",
     )(func)
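Two details in `_helpers.py` are easy to miss: the placeholder `<>` URL is replaced with the real tracking issue, and `stacklevel` is bumped from 2 to 3, which moves the reported warning location one frame further up the call stack, past the wrapper layer, toward the end-user's call site. A condensed sketch of the decorator-factory shape involved (simplified from `_helpers.py`; message text abbreviated):

```python
import functools
import warnings


def deprecated(msg: str = ""):
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            warnings.warn(
                f"`{func.__qualname__}` is deprecated and will be removed "
                f"in a future release of Parcels. {msg}",
                category=DeprecationWarning,
                stacklevel=3,  # attribute the warning to the caller, not the shim
            )
            return func(*args, **kwargs)

        return wrapper

    return decorator
```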
diff --git a/pyproject.toml b/pyproject.toml
index 0d5941351..0487fe167 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -52,6 +52,9 @@ local_scheme = "no-local-version"
 
 [tool.pytest.ini_options]
 python_files = ["test_*.py", "example_*.py", "*tutorial*"]
+filterwarnings = [
+    "error:.*removed in a future release of Parcels.*:DeprecationWarning", # Have Parcels DeprecationWarnings fail CI (prevents deprecated items being used in internal code)
+    ]
 
 [tool.ruff]
 line-length = 120
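This `filterwarnings` entry is the enforcement half of the PR: any warning matching "removed in a future release of Parcels" is escalated to an error during the test run, so internal code that still touches a deprecated name fails CI immediately. The entry uses pytest's standard `action:message-regex:category` filter syntax and is conceptually equivalent to:

```python
import warnings

warnings.filterwarnings(
    "error",
    message=".*removed in a future release of Parcels.*",
    category=DeprecationWarning,
)
```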
+ """ + assert private_method.startswith("_") + public_method_str = private_method.lstrip("_") + private_method_str = private_method + + public_method = getattr(type_, public_method_str) + private_method = getattr(type_, private_method_str) + + assert callable(public_method) + assert callable(private_method) + + assert f".{private_method_str}(" in inspect.getsource(public_method) + + +@pytest.mark.parametrize("private_attribute", FieldPrivate.attributes) +def test_private_attribute_field(private_attribute): + assert_private_public_attribute_equiv(field, private_attribute) + + +@pytest.mark.parametrize("private_attribute", FieldSetPrivate.attributes) +def test_private_attribute_fieldset(private_attribute): + assert_private_public_attribute_equiv(fieldset, private_attribute) + + +@pytest.mark.parametrize("private_method", FieldPrivate.methods) +def test_private_method_field(private_method): + assert_public_method_calls_private(Field, private_method) + + +@pytest.mark.parametrize("private_method", FieldSetPrivate.methods) +def test_private_method_fieldset(private_method): + assert_public_method_calls_private(FieldSet, private_method) diff --git a/tests/test_fieldset.py b/tests/test_fieldset.py index b5e9c1408..5c7a6e073 100644 --- a/tests/test_fieldset.py +++ b/tests/test_fieldset.py @@ -203,7 +203,7 @@ def test_fieldset_from_modulefile(): nemo_error_fname = str(TEST_DATA / "fieldset_nemo_error.py") fieldset = FieldSet.from_modulefile(nemo_fname) - assert fieldset.U.creation_log == "from_nemo" + assert fieldset.U._creation_log == "from_nemo" indices = {"lon": range(6, 10)} fieldset = FieldSet.from_modulefile(nemo_fname, indices=indices) @@ -379,7 +379,7 @@ def test_add_duplicate_field(dupobject): def test_add_field_after_pset(fieldtype): data, dimensions = generate_fieldset_data(100, 100) fieldset = FieldSet.from_data(data, dimensions) - pset = ParticleSet(fieldset, ScipyParticle, lon=0, lat=0) # noqa ; to trigger fieldset.check_complete + pset = ParticleSet(fieldset, ScipyParticle, lon=0, lat=0) # noqa ; to trigger fieldset._check_complete field1 = Field("field1", fieldset.U.data, lon=fieldset.U.lon, lat=fieldset.U.lat) field2 = Field("field2", fieldset.U.data, lon=fieldset.U.lon, lat=fieldset.U.lat) vfield = VectorField("vfield", field1, field2) @@ -425,7 +425,7 @@ def test_fieldset_dimlength1_cgrid(gridtype): fieldset.U.interp_method = "cgrid_velocity" fieldset.V.interp_method = "cgrid_velocity" try: - fieldset.check_complete() + fieldset._check_complete() success = True if gridtype == "A" else False except NotImplementedError: success = True if gridtype == "C" else False @@ -507,7 +507,7 @@ def test_fieldset_celledgesizes(mesh): data, dimensions = generate_fieldset_data(10, 7) fieldset = FieldSet.from_data(data, dimensions, mesh=mesh) - fieldset.U.calc_cell_edge_sizes() + fieldset.U._calc_cell_edge_sizes() D_meridional = fieldset.U.cell_edge_sizes["y"] D_zonal = fieldset.U.cell_edge_sizes["x"] @@ -542,7 +542,7 @@ def test_fieldset_write_curvilinear(tmpdir): variables = {"dx": "e1u"} dimensions = {"lon": "glamu", "lat": "gphiu"} fieldset = FieldSet.from_nemo(filenames, variables, dimensions) - assert fieldset.dx.creation_log == "from_nemo" + assert fieldset.dx._creation_log == "from_nemo" newfile = tmpdir.join("curv_field") fieldset.write(newfile) @@ -552,7 +552,7 @@ def test_fieldset_write_curvilinear(tmpdir): variables={"dx": "dx"}, dimensions={"time": "time_counter", "depth": "depthdx", "lon": "nav_lon", "lat": "nav_lat"}, ) - assert fieldset2.dx.creation_log == "from_netcdf" + assert 
fieldset2.dx._creation_log == "from_netcdf" for var in ["lon", "lat", "data"]: assert np.allclose(getattr(fieldset2.dx, var), getattr(fieldset.dx, var)) @@ -949,7 +949,7 @@ def test_fieldset_defer_loading_with_diff_time_origin(tmpdir, fail): fieldset_out.add_field(fieldW) fieldset_out.write(filepath) fieldset = FieldSet.from_parcels(filepath, extra_fields={"W": "W"}) - assert fieldset.U.creation_log == "from_parcels" + assert fieldset.U._creation_log == "from_parcels" pset = ParticleSet.from_list( fieldset, pclass=JITParticle, lon=[0.5], lat=[0.5], depth=[0.5], time=[datetime.datetime(2018, 4, 20, 1)] ) @@ -984,11 +984,12 @@ def test_fieldset_defer_loading_function(zdim, scale_fac, tmpdir): def compute(fieldset): # Calculating vertical weighted average + f: Field for f in [fieldset.U, fieldset.V]: - for tind in f.loaded_time_indices: + for tind in f._loaded_time_indices: data = da.sum(f.data[tind, :] * DZ, axis=0) / sum(dz) data = da.broadcast_to(data, (1, f.grid.zdim, f.grid.ydim, f.grid.xdim)) - f.data = f.data_concatenate(f.data, data, tind) + f.data = f._data_concatenate(f.data, data, tind) fieldset.compute_on_defer = compute fieldset.computeTimeChunk(1, 1) @@ -1070,7 +1071,7 @@ def generate_dataset(xdim, ydim, zdim=1, tdim=1): else: dimensions = {"lat": "lat", "lon": "lon", "depth": "depth"} fieldset = FieldSet.from_xarray_dataset(ds, variables, dimensions, mesh="flat") - assert fieldset.U.creation_log == "from_xarray_dataset" + assert fieldset.U._creation_log == "from_xarray_dataset" pset = ParticleSet(fieldset, JITParticle, 0, 0, depth=20) diff --git a/tests/test_fieldset_sampling.py b/tests/test_fieldset_sampling.py index 9262b7ccb..e95b89c92 100644 --- a/tests/test_fieldset_sampling.py +++ b/tests/test_fieldset_sampling.py @@ -261,7 +261,7 @@ def test_inversedistance_nearland(mode, withDepth, arrtype): success = False try: fieldset.U.interp_method = "linear_invdist_land_tracer" - fieldset.check_complete() + fieldset._check_complete() except NotImplementedError: success = True assert success @@ -784,7 +784,7 @@ def test_multiple_grid_addlater_error(): ) fieldset = FieldSet(U, V) - pset = ParticleSet(fieldset, pclass=pclass("jit"), lon=[0.8], lat=[0.9]) # noqa ; to trigger fieldset.check_complete + pset = ParticleSet(fieldset, pclass=pclass("jit"), lon=[0.8], lat=[0.9]) # noqa ; to trigger fieldset._check_complete P = Field( "P", diff --git a/tests/test_typing.py b/tests/test_typing.py new file mode 100644 index 000000000..f4efba3ef --- /dev/null +++ b/tests/test_typing.py @@ -0,0 +1,31 @@ +import pytest + +from parcels._typing import ( + assert_valid_gridindexingtype, + assert_valid_interp_method, + assert_valid_mesh, +) + +validators = ( + assert_valid_interp_method, + assert_valid_mesh, + assert_valid_gridindexingtype, +) + + +@pytest.mark.parametrize("validator", validators) +def test_invalid_option(validator): + with pytest.raises(ValueError): + validator("invalid option") + + +validation_mapping = [ + (assert_valid_interp_method, "nearest"), + (assert_valid_mesh, "spherical"), + (assert_valid_gridindexingtype, "pop"), +] + + +@pytest.mark.parametrize("validator, value", validation_mapping) +def test_valid_option(validator, value): + validator(value)