Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

API changes: Field and FieldSet #1709

Merged
merged 18 commits into from
Sep 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/examples/example_nemo_curvilinear.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def test_nemo_3D_samegrid():

fieldset = parcels.FieldSet.from_nemo(filenames, variables, dimensions)

assert fieldset.U.dataFiles is not fieldset.W.dataFiles
assert fieldset.U._dataFiles is not fieldset.W._dataFiles


def main(args=None):
Expand Down
27 changes: 25 additions & 2 deletions parcels/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import datetime
import os
from collections.abc import Callable
from typing import Literal
from typing import Any, Literal, get_args


class ParcelsAST(ast.AST):
Expand All @@ -37,10 +37,33 @@ class ParcelsAST(ast.AST):
Mesh = Literal["spherical", "flat"] # corresponds with `mesh`
VectorType = Literal["3D", "2D"] | None # corresponds with `vector_type`
ChunkMode = Literal["auto", "specific", "failsafe"] # corresponds with `chunk_mode`
GridIndexingType = Literal["pop", "mom5", "mitgcm", "nemo"] # corresponds with `grid_indexing_type`
GridIndexingType = Literal["pop", "mom5", "mitgcm", "nemo"] # corresponds with `gridindexingtype`
UpdateStatus = Literal["not_updated", "first_updated", "updated"] # corresponds with `update_status`
TimePeriodic = float | datetime.timedelta | Literal[False] # corresponds with `update_status`
NetcdfEngine = Literal["netcdf4", "xarray", "scipy"]


KernelFunction = Callable[..., None]


def _validate_against_pure_literal(value, typing_literal):
"""Uses a Literal type alias to validate.

Can't be used with ``Literal[...] | None`` etc. as its not a pure literal.
"""
if value not in get_args(typing_literal):
msg = f"Invalid value {value!r}. Valid options are {get_args(typing_literal)!r}"
raise ValueError(msg)


# Assertion functions to clean user input
def assert_valid_interp_method(value: Any):
_validate_against_pure_literal(value, InterpMethodOption)


def assert_valid_mesh(value: Any):
_validate_against_pure_literal(value, Mesh)


def assert_valid_gridindexingtype(value: Any):
_validate_against_pure_literal(value, GridIndexingType)
4 changes: 2 additions & 2 deletions parcels/application_kernels/advection.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,14 +125,14 @@
direction = 1.0 if particle.dt > 0 else -1.0
withW = True if "W" in [f.name for f in fieldset.get_fields()] else False
withTime = True if len(fieldset.U.grid.time_full) > 1 else False
ti = fieldset.U.time_index(time)[0]
ti = fieldset.U._time_index(time)[0]

Check warning on line 128 in parcels/application_kernels/advection.py

View check run for this annotation

Codecov / codecov/patch

parcels/application_kernels/advection.py#L128

Added line #L128 was not covered by tests
ds_t = particle.dt
if withTime:
tau = (time - fieldset.U.grid.time[ti]) / (fieldset.U.grid.time[ti + 1] - fieldset.U.grid.time[ti])
time_i = np.linspace(0, fieldset.U.grid.time[ti + 1] - fieldset.U.grid.time[ti], I_s)
ds_t = min(ds_t, time_i[np.where(time - fieldset.U.grid.time[ti] < time_i)[0][0]])

xsi, eta, zeta, xi, yi, zi = fieldset.U.search_indices(
xsi, eta, zeta, xi, yi, zi = fieldset.U._search_indices(

Check warning on line 135 in parcels/application_kernels/advection.py

View check run for this annotation

Codecov / codecov/patch

parcels/application_kernels/advection.py#L135

Added line #L135 was not covered by tests
particle.lon, particle.lat, particle.depth, particle=particle
)
if withW:
Expand Down
24 changes: 12 additions & 12 deletions parcels/compilation/codegenerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -819,14 +819,14 @@
self.visit(node.field)
self.visit(node.args)
args = self._check_FieldSamplingArguments(node.args.ccode)
ccode_eval = node.field.obj.ccode_eval(node.var, *args)
ccode_eval = node.field.obj._ccode_eval(node.var, *args)
stmts = [
c.Assign("parcels_interp_state", ccode_eval),
c.Assign("particles->state[pnum]", "max(particles->state[pnum], parcels_interp_state)"),
]

if node.convert:
ccode_conv = node.field.obj.ccode_convert(*args)
ccode_conv = node.field.obj._ccode_convert(*args)
conv_stat = c.Statement(f"{node.var} *= {ccode_conv}")
stmts += [conv_stat]

Expand All @@ -836,17 +836,17 @@
self.visit(node.field)
self.visit(node.args)
args = self._check_FieldSamplingArguments(node.args.ccode)
ccode_eval = node.field.obj.ccode_eval(
ccode_eval = node.field.obj._ccode_eval(
node.var, node.var2, node.var3, node.field.obj.U, node.field.obj.V, node.field.obj.W, *args
)
if node.convert and node.field.obj.U.interp_method != "cgrid_velocity":
ccode_conv1 = node.field.obj.U.ccode_convert(*args)
ccode_conv2 = node.field.obj.V.ccode_convert(*args)
ccode_conv1 = node.field.obj.U._ccode_convert(*args)
ccode_conv2 = node.field.obj.V._ccode_convert(*args)
statements = [c.Statement(f"{node.var} *= {ccode_conv1}"), c.Statement(f"{node.var2} *= {ccode_conv2}")]
else:
statements = []
if node.convert and node.field.obj.vector_type == "3D":
ccode_conv3 = node.field.obj.W.ccode_convert(*args)
ccode_conv3 = node.field.obj.W._ccode_convert(*args)
statements.append(c.Statement(f"{node.var3} *= {ccode_conv3}"))
conv_stat = c.Block(statements)
node.ccode = c.Block(
Expand All @@ -864,8 +864,8 @@
cstat = []
args = self._check_FieldSamplingArguments(node.args.ccode)
for fld in node.fields.obj:
ccode_eval = fld.ccode_eval(node.var, *args)
ccode_conv = fld.ccode_convert(*args)
ccode_eval = fld._ccode_eval(node.var, *args)
ccode_conv = fld._ccode_convert(*args)
conv_stat = c.Statement(f"{node.var} *= {ccode_conv}")
cstat += [
c.Assign("particles->state[pnum]", ccode_eval),
Expand All @@ -884,15 +884,15 @@
cstat = []
args = self._check_FieldSamplingArguments(node.args.ccode)
for fld in node.fields.obj:
ccode_eval = fld.ccode_eval(node.var, node.var2, node.var3, fld.U, fld.V, fld.W, *args)
ccode_eval = fld._ccode_eval(node.var, node.var2, node.var3, fld.U, fld.V, fld.W, *args)
if fld.U.interp_method != "cgrid_velocity":
ccode_conv1 = fld.U.ccode_convert(*args)
ccode_conv2 = fld.V.ccode_convert(*args)
ccode_conv1 = fld.U._ccode_convert(*args)
ccode_conv2 = fld.V._ccode_convert(*args)
statements = [c.Statement(f"{node.var} *= {ccode_conv1}"), c.Statement(f"{node.var2} *= {ccode_conv2}")]
else:
statements = []
if fld.vector_type == "3D":
ccode_conv3 = fld.W.ccode_convert(*args)
ccode_conv3 = fld.W._ccode_convert(*args)

Check warning on line 895 in parcels/compilation/codegenerator.py

View check run for this annotation

Codecov / codecov/patch

parcels/compilation/codegenerator.py#L895

Added line #L895 was not covered by tests
statements.append(c.Statement(f"{node.var3} *= {ccode_conv3}"))
cstat += [
c.Assign("particles->state[pnum]", ccode_eval),
Expand Down
Loading
Loading