Skip to content

Commit

Permalink
Merge pull request #1641 from OceanParcels/croco_3D_velocities
Browse files Browse the repository at this point in the history
Support for CROCO 3D velocities
  • Loading branch information
erikvansebille authored Oct 16, 2024
2 parents 6e72612 + e73b275 commit 0ce0650
Show file tree
Hide file tree
Showing 16 changed files with 666 additions and 45 deletions.
1 change: 1 addition & 0 deletions docs/documentation/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ Parcels has several documentation and tutorial Jupyter notebooks and scripts whi
../examples/documentation_indexing.ipynb
../examples/tutorial_nemo_curvilinear.ipynb
../examples/tutorial_nemo_3D.ipynb
../examples/tutorial_croco_3D.ipynb
../examples/tutorial_NestedFields.ipynb
../examples/tutorial_timevaryingdepthdimensions.ipynb
../examples/tutorial_periodic_boundaries.ipynb
Expand Down
332 changes: 332 additions & 0 deletions docs/examples/tutorial_croco_3D.ipynb

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions parcels/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,9 @@ class ParcelsAST(ast.AST):
) # corresponds with `interp_method` (which can also be dict mapping field names to method)
PathLike = str | os.PathLike
Mesh = Literal["spherical", "flat"] # corresponds with `mesh`
VectorType = Literal["3D", "2D"] | None # corresponds with `vector_type`
VectorType = Literal["3D", "3DSigma", "2D"] | None # corresponds with `vector_type`
ChunkMode = Literal["auto", "specific", "failsafe"] # corresponds with `chunk_mode`
GridIndexingType = Literal["pop", "mom5", "mitgcm", "nemo"] # corresponds with `gridindexingtype`
GridIndexingType = Literal["pop", "mom5", "mitgcm", "nemo", "croco"] # corresponds with `gridindexingtype`
UpdateStatus = Literal["not_updated", "first_updated", "updated"] # corresponds with `_update_status`
TimePeriodic = float | datetime.timedelta | Literal[False] # corresponds with `time_periodic`
NetcdfEngine = Literal["netcdf4", "xarray", "scipy"]
Expand Down
54 changes: 53 additions & 1 deletion parcels/application_kernels/advection.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,14 @@

from parcels.tools.statuscodes import StatusCode

__all__ = ["AdvectionRK4", "AdvectionEE", "AdvectionRK45", "AdvectionRK4_3D", "AdvectionAnalytical"]
__all__ = [
"AdvectionRK4",
"AdvectionEE",
"AdvectionRK45",
"AdvectionRK4_3D",
"AdvectionAnalytical",
"AdvectionRK4_3D_CROCO",
]


def AdvectionRK4(particle, fieldset, time):
Expand Down Expand Up @@ -40,6 +47,51 @@ def AdvectionRK4_3D(particle, fieldset, time):
particle_ddepth += (w1 + 2 * w2 + 2 * w3 + w4) / 6 * particle.dt # noqa


def AdvectionRK4_3D_CROCO(particle, fieldset, time):
    """Advection of particles using fourth-order Runge-Kutta integration including vertical velocity.

    This kernel assumes the vertical velocity is the 'w' field from CROCO output and works on sigma-layers.

    The vertical position is advected as a *relative* (sigma-like) coordinate
    ``depth / H`` — where ``fieldset.H`` is the water-column depth sampled at
    the surface (depth 0) — and converted back to depth in metres before each
    ``fieldset.UVW`` evaluation.

    NOTE(review): this kernel is written to be parsed by the Parcels kernel
    code generator (JIT); ``particle_dlon``/``particle_dlat``/``particle_ddepth``
    are names injected by the kernel loop, hence the ``# noqa`` markers.
    NOTE(review): ``fieldset.H`` is sampled at the initial ``time`` for every
    RK4 stage — presumably H is static bathymetry so the time argument is
    irrelevant; confirm if H can be time-varying.
    """
    # Relative vertical position of the particle (depth as a fraction of the
    # local water-column depth H).
    sig_dep = particle.depth / fieldset.H[time, 0, particle.lat, particle.lon]

    # --- RK4 stage 1: sample at the particle's current position. ---
    (u1, v1, w1) = fieldset.UVW[time, particle.depth, particle.lat, particle.lon, particle]
    # Rescale the CROCO vertical velocity to units of the relative coordinate.
    # NOTE(review): the factor is sig_dep/H rather than the plain 1/H one would
    # expect for a depth->sigma conversion — confirm against CROCO conventions.
    w1 *= sig_dep / fieldset.H[time, 0, particle.lat, particle.lon]
    # Half-step position for stage 2 (standard RK4 midpoint).
    lon1 = particle.lon + u1 * 0.5 * particle.dt
    lat1 = particle.lat + v1 * 0.5 * particle.dt
    sig_dep1 = sig_dep + w1 * 0.5 * particle.dt
    # Convert the relative coordinate back to depth (metres) at the new
    # horizontal position, where H differs.
    dep1 = sig_dep1 * fieldset.H[time, 0, lat1, lon1]

    # --- RK4 stage 2: sample at the first midpoint. ---
    (u2, v2, w2) = fieldset.UVW[time + 0.5 * particle.dt, dep1, lat1, lon1, particle]
    w2 *= sig_dep1 / fieldset.H[time, 0, lat1, lon1]
    lon2 = particle.lon + u2 * 0.5 * particle.dt
    lat2 = particle.lat + v2 * 0.5 * particle.dt
    sig_dep2 = sig_dep + w2 * 0.5 * particle.dt
    dep2 = sig_dep2 * fieldset.H[time, 0, lat2, lon2]

    # --- RK4 stage 3: sample at the second midpoint; step the FULL dt to get
    # the stage-4 evaluation point (same structure as AdvectionRK4). ---
    (u3, v3, w3) = fieldset.UVW[time + 0.5 * particle.dt, dep2, lat2, lon2, particle]
    w3 *= sig_dep2 / fieldset.H[time, 0, lat2, lon2]
    lon3 = particle.lon + u3 * particle.dt
    lat3 = particle.lat + v3 * particle.dt
    sig_dep3 = sig_dep + w3 * particle.dt
    dep3 = sig_dep3 * fieldset.H[time, 0, lat3, lon3]

    # --- RK4 stage 4: sample at the full-step position. ---
    (u4, v4, w4) = fieldset.UVW[time + particle.dt, dep3, lat3, lon3, particle]
    w4 *= sig_dep3 / fieldset.H[time, 0, lat3, lon3]
    lon4 = particle.lon + u4 * particle.dt
    lat4 = particle.lat + v4 * particle.dt
    sig_dep4 = sig_dep + w4 * particle.dt
    dep4 = sig_dep4 * fieldset.H[time, 0, lat4, lon4]

    # Standard RK4 combination for the horizontal displacement.
    particle_dlon += (u1 + 2 * u2 + 2 * u3 + u4) / 6 * particle.dt  # noqa
    particle_dlat += (v1 + 2 * v2 + 2 * v3 + v4) / 6 * particle.dt  # noqa
    # Vertical RK4 combination expressed in depth displacements:
    # dep1 and dep2 came from HALF time-steps, so their displacements are
    # doubled to recover the full-step increments k1 and k2; the result is the
    # usual (k1 + 2*k2 + 2*k3 + k4) / 6.
    particle_ddepth += (  # noqa
        (dep1 - particle.depth) * 2
        + 2 * (dep2 - particle.depth) * 2
        + 2 * (dep3 - particle.depth)
        + dep4
        - particle.depth
    ) / 6


def AdvectionEE(particle, fieldset, time):
"""Advection of particles using Explicit Euler (aka Euler Forward) integration."""
(u1, v1) = fieldset.UV[particle]
Expand Down
52 changes: 40 additions & 12 deletions parcels/compilation/codegenerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,12 +80,13 @@ def __init__(self, field):


class VectorFieldEvalNode(IntrinsicNode):
def __init__(self, field, args, var, var2, var3, convert=True):
def __init__(self, field, args, var, var2, var3, var4, convert=True):
self.field = field
self.args = args
self.var = var # the variable in which the interpolated field is written
self.var2 = var2 # second variable for UV interpolation
self.var3 = var3 # third variable for UVW interpolation
self.var4 = var4 # extra variable for sigma-scaling for croco
self.convert = convert # whether to convert the result (like field.applyConversion)


Expand All @@ -107,12 +108,13 @@ def __getitem__(self, attr):


class NestedVectorFieldEvalNode(IntrinsicNode):
def __init__(self, fields, args, var, var2, var3):
def __init__(self, fields, args, var, var2, var3, var4):
self.fields = fields
self.args = args
self.var = var # the variable in which the interpolated field is written
self.var2 = var2 # second variable for UV interpolation
self.var3 = var3 # third variable for UVW interpolation
self.var4 = var4 # extra variable for sigma-scaling for croco


class GridNode(IntrinsicNode):
Expand Down Expand Up @@ -285,9 +287,10 @@ def visit_Subscript(self, node):
elif isinstance(node.value, VectorFieldNode):
tmp = self.get_tmp()
tmp2 = self.get_tmp()
tmp3 = self.get_tmp() if node.value.obj.vector_type == "3D" else None
tmp3 = self.get_tmp() if "3D" in node.value.obj.vector_type else None
tmp4 = self.get_tmp() if "3DSigma" in node.value.obj.vector_type else None
# Insert placeholder node for field eval ...
self.stmt_stack += [VectorFieldEvalNode(node.value, node.slice, tmp, tmp2, tmp3)]
self.stmt_stack += [VectorFieldEvalNode(node.value, node.slice, tmp, tmp2, tmp3, tmp4)]
# .. and return the name of the temporary that will be populated
if tmp3:
return ast.Tuple([ast.Name(id=tmp), ast.Name(id=tmp2), ast.Name(id=tmp3)], ast.Load())
Expand All @@ -300,8 +303,9 @@ def visit_Subscript(self, node):
elif isinstance(node.value, NestedVectorFieldNode):
tmp = self.get_tmp()
tmp2 = self.get_tmp()
tmp3 = self.get_tmp() if list.__getitem__(node.value.obj, 0).vector_type == "3D" else None
self.stmt_stack += [NestedVectorFieldEvalNode(node.value, node.slice, tmp, tmp2, tmp3)]
tmp3 = self.get_tmp() if "3D" in list.__getitem__(node.value.obj, 0).vector_type else None
tmp4 = self.get_tmp() if "3DSigma" in list.__getitem__(node.value.obj, 0).vector_type else None
self.stmt_stack += [NestedVectorFieldEvalNode(node.value, node.slice, tmp, tmp2, tmp3, tmp4)]
if tmp3:
return ast.Tuple([ast.Name(id=tmp), ast.Name(id=tmp2), ast.Name(id=tmp3)], ast.Load())
else:
Expand Down Expand Up @@ -371,7 +375,8 @@ def visit_Call(self, node):
# get a temporary value to assign result to
tmp1 = self.get_tmp()
tmp2 = self.get_tmp()
tmp3 = self.get_tmp() if node.func.field.obj.vector_type == "3D" else None
tmp3 = self.get_tmp() if "3D" in node.func.field.obj.vector_type else None
tmp4 = self.get_tmp() if "3DSigma" in node.func.field.obj.vector_type else None
# whether to convert
convert = True
if "applyConversion" in node.keywords:
Expand All @@ -382,7 +387,7 @@ def visit_Call(self, node):
# convert args to Index(Tuple(*args))
args = ast.Index(value=ast.Tuple(node.args, ast.Load()))

self.stmt_stack += [VectorFieldEvalNode(node.func.field, args, tmp1, tmp2, tmp3, convert)]
self.stmt_stack += [VectorFieldEvalNode(node.func.field, args, tmp1, tmp2, tmp3, tmp4, convert)]
if tmp3:
return ast.Tuple([ast.Name(id=tmp1), ast.Name(id=tmp2), ast.Name(id=tmp3)], ast.Load())
else:
Expand Down Expand Up @@ -421,6 +426,8 @@ def __init__(self, fieldset=None, ptype=JITParticle):
self.fieldset = fieldset
self.ptype = ptype
self.field_args = collections.OrderedDict()
if isinstance(fieldset.U, Field) and fieldset.U.gridindexingtype == "croco" and hasattr(fieldset, "H"):
self.field_args["H"] = fieldset.H # CROCO requires H field
self.vector_field_args = collections.OrderedDict()
self.const_args = collections.OrderedDict()

Expand Down Expand Up @@ -456,7 +463,7 @@ def generate(self, py_ast, funcvars: list[str]):
for kvar in self.kernel_vars + self.array_vars:
if kvar in funcvars:
funcvars.remove(kvar)
self.ccode.body.insert(0, c.Value("int", "parcels_interp_state"))
self.ccode.body.insert(0, c.Statement("int parcels_interp_state = 0"))
if len(funcvars) > 0:
for f in funcvars:
self.ccode.body.insert(0, c.Statement(f"type_coord {f} = 0"))
Expand Down Expand Up @@ -819,6 +826,16 @@ def visit_FieldEvalNode(self, node):
self.visit(node.field)
self.visit(node.args)
args = self._check_FieldSamplingArguments(node.args.ccode)
statements_croco = []
if "croco" in node.field.obj.gridindexingtype and node.field.obj.name != "H":
statements_croco.append(
c.Assign(
"parcels_interp_state",
f"temporal_interpolation({args[3]}, {args[2]}, 0, time, H, &particles->xi[pnum*ngrid], &particles->yi[pnum*ngrid], &particles->zi[pnum*ngrid], &particles->ti[pnum*ngrid], &{node.var}, LINEAR, {node.field.obj.gridindexingtype.upper()})",
)
)
statements_croco.append(c.Statement(f"{node.var} = {args[1]}/{node.var}"))
args = (args[0], node.var, args[2], args[3])
ccode_eval = node.field.obj._ccode_eval(node.var, *args)
stmts = [
c.Assign("parcels_interp_state", ccode_eval),
Expand All @@ -830,12 +847,22 @@ def visit_FieldEvalNode(self, node):
conv_stat = c.Statement(f"{node.var} *= {ccode_conv}")
stmts += [conv_stat]

node.ccode = c.Block(stmts + [c.Statement("CHECKSTATUS_KERNELLOOP(parcels_interp_state)")])
node.ccode = c.Block(statements_croco + stmts + [c.Statement("CHECKSTATUS_KERNELLOOP(parcels_interp_state)")])

def visit_VectorFieldEvalNode(self, node):
self.visit(node.field)
self.visit(node.args)
args = self._check_FieldSamplingArguments(node.args.ccode)
statements_croco = []
if "3DSigma" in node.field.obj.vector_type:
statements_croco.append(
c.Assign(
"parcels_interp_state",
f"temporal_interpolation({args[3]}, {args[2]}, 0, time, H, &particles->xi[pnum*ngrid], &particles->yi[pnum*ngrid], &particles->zi[pnum*ngrid], &particles->ti[pnum*ngrid], &{node.var}, LINEAR, {node.field.obj.U.gridindexingtype.upper()})",
)
)
statements_croco.append(c.Statement(f"{node.var4} = {args[1]}/{node.var}"))
args = (args[0], node.var4, args[2], args[3])
ccode_eval = node.field.obj._ccode_eval(
node.var, node.var2, node.var3, node.field.obj.U, node.field.obj.V, node.field.obj.W, *args
)
Expand All @@ -845,12 +872,13 @@ def visit_VectorFieldEvalNode(self, node):
statements = [c.Statement(f"{node.var} *= {ccode_conv1}"), c.Statement(f"{node.var2} *= {ccode_conv2}")]
else:
statements = []
if node.convert and node.field.obj.vector_type == "3D":
if node.convert and "3D" in node.field.obj.vector_type:
ccode_conv3 = node.field.obj.W._ccode_convert(*args)
statements.append(c.Statement(f"{node.var3} *= {ccode_conv3}"))
conv_stat = c.Block(statements)
node.ccode = c.Block(
[
c.Block(statements_croco),
c.Assign("parcels_interp_state", ccode_eval),
c.Assign("particles->state[pnum]", "max(particles->state[pnum], parcels_interp_state)"),
conv_stat,
Expand Down Expand Up @@ -891,7 +919,7 @@ def visit_NestedVectorFieldEvalNode(self, node):
statements = [c.Statement(f"{node.var} *= {ccode_conv1}"), c.Statement(f"{node.var2} *= {ccode_conv2}")]
else:
statements = []
if fld.vector_type == "3D":
if "3D" in fld.vector_type:
ccode_conv3 = fld.W._ccode_convert(*args)
statements.append(c.Statement(f"{node.var3} *= {ccode_conv3}"))
cstat += [
Expand Down
Loading

0 comments on commit 0ce0650

Please sign in to comment.