Skip to content

Commit

Permalink
API changes: Field and FieldSet (#1709)
Browse files Browse the repository at this point in the history
Make selected Field and FieldSet attributes and methods private or read-only.
  • Loading branch information
VeckoTheGecko authored Sep 25, 2024
1 parent bd6a09a commit 018be97
Show file tree
Hide file tree
Showing 18 changed files with 544 additions and 175 deletions.
2 changes: 1 addition & 1 deletion docs/examples/example_nemo_curvilinear.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def test_nemo_3D_samegrid():

fieldset = parcels.FieldSet.from_nemo(filenames, variables, dimensions)

assert fieldset.U.dataFiles is not fieldset.W.dataFiles
assert fieldset.U._dataFiles is not fieldset.W._dataFiles


def main(args=None):
Expand Down
27 changes: 25 additions & 2 deletions parcels/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import datetime
import os
from collections.abc import Callable
from typing import Literal
from typing import Any, Literal, get_args


class ParcelsAST(ast.AST):
Expand All @@ -37,10 +37,33 @@ class ParcelsAST(ast.AST):
Mesh = Literal["spherical", "flat"] # corresponds with `mesh`
VectorType = Literal["3D", "2D"] | None # corresponds with `vector_type`
ChunkMode = Literal["auto", "specific", "failsafe"] # corresponds with `chunk_mode`
GridIndexingType = Literal["pop", "mom5", "mitgcm", "nemo"] # corresponds with `grid_indexing_type`
GridIndexingType = Literal["pop", "mom5", "mitgcm", "nemo"] # corresponds with `gridindexingtype`
UpdateStatus = Literal["not_updated", "first_updated", "updated"] # corresponds with `update_status`
TimePeriodic = float | datetime.timedelta | Literal[False] # corresponds with `update_status`
NetcdfEngine = Literal["netcdf4", "xarray", "scipy"]


KernelFunction = Callable[..., None]


def _validate_against_pure_literal(value, typing_literal):
"""Uses a Literal type alias to validate.
Can't be used with ``Literal[...] | None`` etc. as its not a pure literal.
"""
if value not in get_args(typing_literal):
msg = f"Invalid value {value!r}. Valid options are {get_args(typing_literal)!r}"
raise ValueError(msg)


# Assertion functions to clean user input
def assert_valid_interp_method(value: Any):
_validate_against_pure_literal(value, InterpMethodOption)


def assert_valid_mesh(value: Any):
_validate_against_pure_literal(value, Mesh)


def assert_valid_gridindexingtype(value: Any):
_validate_against_pure_literal(value, GridIndexingType)
4 changes: 2 additions & 2 deletions parcels/application_kernels/advection.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,14 +125,14 @@ def AdvectionAnalytical(particle, fieldset, time):
direction = 1.0 if particle.dt > 0 else -1.0
withW = True if "W" in [f.name for f in fieldset.get_fields()] else False
withTime = True if len(fieldset.U.grid.time_full) > 1 else False
ti = fieldset.U.time_index(time)[0]
ti = fieldset.U._time_index(time)[0]
ds_t = particle.dt
if withTime:
tau = (time - fieldset.U.grid.time[ti]) / (fieldset.U.grid.time[ti + 1] - fieldset.U.grid.time[ti])
time_i = np.linspace(0, fieldset.U.grid.time[ti + 1] - fieldset.U.grid.time[ti], I_s)
ds_t = min(ds_t, time_i[np.where(time - fieldset.U.grid.time[ti] < time_i)[0][0]])

xsi, eta, zeta, xi, yi, zi = fieldset.U.search_indices(
xsi, eta, zeta, xi, yi, zi = fieldset.U._search_indices(
particle.lon, particle.lat, particle.depth, particle=particle
)
if withW:
Expand Down
24 changes: 12 additions & 12 deletions parcels/compilation/codegenerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -819,14 +819,14 @@ def visit_FieldEvalNode(self, node):
self.visit(node.field)
self.visit(node.args)
args = self._check_FieldSamplingArguments(node.args.ccode)
ccode_eval = node.field.obj.ccode_eval(node.var, *args)
ccode_eval = node.field.obj._ccode_eval(node.var, *args)
stmts = [
c.Assign("parcels_interp_state", ccode_eval),
c.Assign("particles->state[pnum]", "max(particles->state[pnum], parcels_interp_state)"),
]

if node.convert:
ccode_conv = node.field.obj.ccode_convert(*args)
ccode_conv = node.field.obj._ccode_convert(*args)
conv_stat = c.Statement(f"{node.var} *= {ccode_conv}")
stmts += [conv_stat]

Expand All @@ -836,17 +836,17 @@ def visit_VectorFieldEvalNode(self, node):
self.visit(node.field)
self.visit(node.args)
args = self._check_FieldSamplingArguments(node.args.ccode)
ccode_eval = node.field.obj.ccode_eval(
ccode_eval = node.field.obj._ccode_eval(
node.var, node.var2, node.var3, node.field.obj.U, node.field.obj.V, node.field.obj.W, *args
)
if node.convert and node.field.obj.U.interp_method != "cgrid_velocity":
ccode_conv1 = node.field.obj.U.ccode_convert(*args)
ccode_conv2 = node.field.obj.V.ccode_convert(*args)
ccode_conv1 = node.field.obj.U._ccode_convert(*args)
ccode_conv2 = node.field.obj.V._ccode_convert(*args)
statements = [c.Statement(f"{node.var} *= {ccode_conv1}"), c.Statement(f"{node.var2} *= {ccode_conv2}")]
else:
statements = []
if node.convert and node.field.obj.vector_type == "3D":
ccode_conv3 = node.field.obj.W.ccode_convert(*args)
ccode_conv3 = node.field.obj.W._ccode_convert(*args)
statements.append(c.Statement(f"{node.var3} *= {ccode_conv3}"))
conv_stat = c.Block(statements)
node.ccode = c.Block(
Expand All @@ -864,8 +864,8 @@ def visit_NestedFieldEvalNode(self, node):
cstat = []
args = self._check_FieldSamplingArguments(node.args.ccode)
for fld in node.fields.obj:
ccode_eval = fld.ccode_eval(node.var, *args)
ccode_conv = fld.ccode_convert(*args)
ccode_eval = fld._ccode_eval(node.var, *args)
ccode_conv = fld._ccode_convert(*args)
conv_stat = c.Statement(f"{node.var} *= {ccode_conv}")
cstat += [
c.Assign("particles->state[pnum]", ccode_eval),
Expand All @@ -884,15 +884,15 @@ def visit_NestedVectorFieldEvalNode(self, node):
cstat = []
args = self._check_FieldSamplingArguments(node.args.ccode)
for fld in node.fields.obj:
ccode_eval = fld.ccode_eval(node.var, node.var2, node.var3, fld.U, fld.V, fld.W, *args)
ccode_eval = fld._ccode_eval(node.var, node.var2, node.var3, fld.U, fld.V, fld.W, *args)
if fld.U.interp_method != "cgrid_velocity":
ccode_conv1 = fld.U.ccode_convert(*args)
ccode_conv2 = fld.V.ccode_convert(*args)
ccode_conv1 = fld.U._ccode_convert(*args)
ccode_conv2 = fld.V._ccode_convert(*args)
statements = [c.Statement(f"{node.var} *= {ccode_conv1}"), c.Statement(f"{node.var2} *= {ccode_conv2}")]
else:
statements = []
if fld.vector_type == "3D":
ccode_conv3 = fld.W.ccode_convert(*args)
ccode_conv3 = fld.W._ccode_convert(*args)
statements.append(c.Statement(f"{node.var3} *= {ccode_conv3}"))
cstat += [
c.Assign("particles->state[pnum]", ccode_eval),
Expand Down
Loading

0 comments on commit 018be97

Please sign in to comment.