Skip to content

Commit 4866746

Browse files
committed
io: fix saving/loading of HDiv/HCurl functions on a high-order mesh
1 parent 9c5ec2f commit 4866746

File tree

2 files changed

+246
-55
lines changed

2 files changed

+246
-55
lines changed

firedrake/checkpointing.py

+169-52
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from firedrake.cython import hdf5interface as h5i
88
from firedrake.cython import dmcommon
99
from firedrake.petsc import PETSc, OptionsManager
10-
from firedrake.mesh import MeshTopology, ExtrudedMeshTopology, DEFAULT_MESH_NAME, make_mesh_from_coordinates, DistributedMeshOverlapType
10+
from firedrake.mesh import MeshGeometry, MeshTopology, ExtrudedMeshTopology, DEFAULT_MESH_NAME, make_mesh_from_coordinates, DistributedMeshOverlapType
1111
from firedrake.functionspace import FunctionSpace
1212
from firedrake import functionspaceimpl as impl
1313
from firedrake.functionspacedata import get_global_numbering, create_element
@@ -20,6 +20,7 @@
2020
import numpy as np
2121
import os
2222
import h5py
23+
from typing import Optional, Union
2324

2425

2526
__all__ = ["DumbCheckpoint", "HDF5File", "FILE_READ", "FILE_CREATE", "FILE_UPDATE", "CheckpointFile"]
@@ -896,25 +897,47 @@ def _save_function_space_topology(self, tV):
896897
topology_dm.setName(base_tmesh_name)
897898

898899
@PETSc.Log.EventDecorator("SaveFunction")
899-
def save_function(self, f, idx=None, name=None, timestepping_info={}):
900-
r"""Save a :class:`~.Function`.
900+
def save_function(
901+
self,
902+
f: Function,
903+
idx: Optional[int] = None,
904+
name: Optional[str] = None,
905+
timestepping_info: Optional[dict] = {},
906+
affine_coordinates: Optional[Union[MeshGeometry, Function]] = None,
907+
affine_quadrature_degree: Optional[int] = None,
908+
) -> None:
909+
"""Save a :class:`~.Function`.
901910
902-
:arg f: the :class:`~.Function` to save.
903-
:kwarg idx: optional timestepping index. A function can
911+
Parameters
912+
----------
913+
f
914+
`Function` to save.
915+
idx
916+
Optional timestepping index. A function can
904917
either be saved in timestepping mode or in normal
905918
mode (non-timestepping); for each function of interest,
906919
this method must always be called with the idx parameter
907920
set or never be called with the idx parameter set.
908-
:kwarg name: optional alternative name to save the function under.
909-
:kwarg timestepping_info: optional (requires idx) additional information
921+
name
922+
Optional alternative name to save the function under.
923+
timestepping_info
924+
Optional (requires idx) additional information
910925
such as time, timestepping that can be stored along a function for
911926
each index.
927+
affine_coordinates
928+
Representation of a fictitious affine mesh onto which
929+
the function is mapped before saving; only significant for
930+
HDiv/HCurl functions defined on high-order mesh.
931+
affine_quadrature_degree
932+
Quadrature degree to be used when mapping onto the affine mesh;
933+
only significant for HDiv/HCurl functions defined on high-order mesh.
934+
912935
"""
913936
V = f.function_space()
914937
mesh = V.mesh()
915938
if name:
916939
g = Function(V, val=f.dat, name=name)
917-
return self.save_function(g, idx=idx, timestepping_info=timestepping_info)
940+
return self.save_function(g, idx=idx, timestepping_info=timestepping_info, affine_coordinates=affine_coordinates, affine_quadrature_degree=affine_quadrature_degree)
918941
# -- Save function space --
919942
self._save_function_space(V)
920943
# -- Save function --
@@ -926,7 +949,7 @@ def save_function(self, f, idx=None, name=None, timestepping_info={}):
926949
path = os.path.join(base_path, str(i))
927950
self.require_group(path)
928951
self.set_attr(path, PREFIX + "_function", fsub.name())
929-
self.save_function(fsub, idx=idx, timestepping_info=timestepping_info)
952+
self.save_function(fsub, idx=idx, timestepping_info=timestepping_info, affine_coordinates=affine_coordinates, affine_quadrature_degree=affine_quadrature_degree)
930953
self._update_mixed_function_name_mixed_function_space_name_map(mesh.name, {f.name(): V_name})
931954
else:
932955
tf = f.topological
@@ -940,10 +963,32 @@ def save_function(self, f, idx=None, name=None, timestepping_info={}):
940963
path = self._path_to_function_embedded(tmesh.name, mesh.name, V_name, f.name())
941964
self.require_group(path)
942965
method = get_embedding_method_for_checkpointing(element)
943-
_V = FunctionSpace(mesh, _element)
966+
if mesh.coordinates.function_space().ufl_element().embedded_subdegree > 1:
967+
# Handle non-affine mesh; this is only relevant when embedding into a DG space.
968+
if affine_coordinates is None:
969+
raise ValueError("Must provide affine_coordinates to save functions on high-order mesh")
970+
if affine_quadrature_degree is None:
971+
raise ValueError("Must provide affine_quadrature_degree to save functions on high-order mesh")
972+
if isinstance(affine_coordinates, MeshGeometry):
973+
affine_coordinates = affine_coordinates.coordinates
974+
else:
975+
if not isinstance(affine_coordinates, Function):
976+
raise ValueError("affine_coordinates must be {MeshGeometry, Function}")
977+
if affine_coordinates.function_space().mesh().topology is not tmesh:
978+
raise ValueError(f"affine_coordinates.function_space().mesh().topology ({affine_coordinates.function_space().mesh().topology}) is not f.mesh().topology ({tmesh})")
979+
if affine_coordinates.function_space().mesh() is not mesh:
980+
affine_coordinate_V = FunctionSpace(mesh, affine_coordinates.function_space().ufl_element())
981+
affine_coordinates = Function(affine_coordinate_V, val=affine_coordinates.topological)
982+
if affine_coordinates.topological.name() == mesh.coordinates.topological.name():
983+
raise ValueError(f"affine_coordinate.name() ({affine_coordinates.name()}) == mesh.coordinates.topological.name() ({mesh.coordinates.topological.name()})")
984+
self._save_ufl_element(path, PREFIX_EMBEDDED + "_affine_coordinate_element", affine_coordinates.topological.function_space().ufl_element())
985+
self.set_attr(path, PREFIX_EMBEDDED + "_affine_coordinates", affine_coordinates.topological.name())
986+
self.set_attr(path, PREFIX_EMBEDDED + "_affine_quadrature_degree", affine_quadrature_degree)
987+
self._save_function_topology(affine_coordinates.topological)
944988
_name = "_".join([PREFIX_EMBEDDED, f.name()])
989+
_V = FunctionSpace(mesh, _element)
945990
_f = Function(_V, name=_name)
946-
self._project_function_for_checkpointing(_f, f, method)
991+
self._project_function_for_checkpointing(_f, f, method, affine_coordinates=affine_coordinates, affine_quadrature_degree=affine_quadrature_degree)
947992
self.save_function(_f, idx=idx, timestepping_info=timestepping_info)
948993
self.set_attr(path, PREFIX_EMBEDDED + "_function", _name)
949994
else:
@@ -1045,35 +1090,41 @@ def load_mesh(self, name=DEFAULT_MESH_NAME, reorder=None, distribution_parameter
10451090
path = self._path_to_topology_extruded(tmesh_name)
10461091
if path in self.h5pyfile:
10471092
# -- Load mesh topology --
1048-
base_tmesh_name = self.get_attr(path, PREFIX_EXTRUDED + "_base_mesh")
1049-
base_tmesh = self._load_mesh_topology(base_tmesh_name, reorder, distribution_parameters)
1050-
base_tmesh.init()
1051-
periodic = self.get_attr(path, PREFIX_EXTRUDED + "_periodic") if self.has_attr(path, PREFIX_EXTRUDED + "_periodic") else False
1052-
variable_layers = self.get_attr(path, PREFIX_EXTRUDED + "_variable_layers")
1053-
if variable_layers:
1054-
cell = base_tmesh.ufl_cell()
1055-
element = finat.ufl.VectorElement("DP" if cell.is_simplex() else "DQ", cell, 0, dim=2)
1056-
_ = self._load_function_space_topology(base_tmesh, element)
1057-
base_tmesh_key = self._generate_mesh_key_from_names(base_tmesh.name,
1058-
base_tmesh._distribution_name,
1059-
base_tmesh._permutation_name)
1060-
sd_key = self._get_shared_data_key_for_checkpointing(base_tmesh, element)
1061-
_, _, lsf = self._function_load_utils[base_tmesh_key + sd_key]
1062-
nroots, _, _ = lsf.getGraph()
1063-
layers_a = np.empty(nroots, dtype=utils.IntType)
1064-
layers_a_iset = PETSc.IS().createGeneral(layers_a, comm=self._comm)
1065-
layers_a_iset.setName("_".join([PREFIX_EXTRUDED, "layers_iset"]))
1066-
self.viewer.pushGroup(path)
1067-
layers_a_iset.load(self.viewer)
1068-
self.viewer.popGroup()
1069-
layers_a = layers_a_iset.getIndices()
1070-
layers = np.empty((base_tmesh.cell_set.total_size, 2), dtype=utils.IntType)
1071-
unit = MPI._typedict[np.dtype(utils.IntType).char]
1072-
lsf.bcastBegin(unit, layers_a, layers, MPI.REPLACE)
1073-
lsf.bcastEnd(unit, layers_a, layers, MPI.REPLACE)
1093+
if topology is None:
1094+
base_tmesh_name = self.get_attr(path, PREFIX_EXTRUDED + "_base_mesh")
1095+
base_tmesh = self._load_mesh_topology(base_tmesh_name, reorder, distribution_parameters)
1096+
base_tmesh.init()
1097+
periodic = self.get_attr(path, PREFIX_EXTRUDED + "_periodic") if self.has_attr(path, PREFIX_EXTRUDED + "_periodic") else False
1098+
variable_layers = self.get_attr(path, PREFIX_EXTRUDED + "_variable_layers")
1099+
if variable_layers:
1100+
cell = base_tmesh.ufl_cell()
1101+
element = finat.ufl.VectorElement("DP" if cell.is_simplex() else "DQ", cell, 0, dim=2)
1102+
_ = self._load_function_space_topology(base_tmesh, element)
1103+
base_tmesh_key = self._generate_mesh_key_from_names(base_tmesh.name,
1104+
base_tmesh._distribution_name,
1105+
base_tmesh._permutation_name)
1106+
sd_key = self._get_shared_data_key_for_checkpointing(base_tmesh, element)
1107+
_, _, lsf = self._function_load_utils[base_tmesh_key + sd_key]
1108+
nroots, _, _ = lsf.getGraph()
1109+
layers_a = np.empty(nroots, dtype=utils.IntType)
1110+
layers_a_iset = PETSc.IS().createGeneral(layers_a, comm=self._comm)
1111+
layers_a_iset.setName("_".join([PREFIX_EXTRUDED, "layers_iset"]))
1112+
self.viewer.pushGroup(path)
1113+
layers_a_iset.load(self.viewer)
1114+
self.viewer.popGroup()
1115+
layers_a = layers_a_iset.getIndices()
1116+
layers = np.empty((base_tmesh.cell_set.total_size, 2), dtype=utils.IntType)
1117+
unit = MPI._typedict[np.dtype(utils.IntType).char]
1118+
lsf.bcastBegin(unit, layers_a, layers, MPI.REPLACE)
1119+
lsf.bcastEnd(unit, layers_a, layers, MPI.REPLACE)
1120+
else:
1121+
layers = self.get_attr(path, PREFIX_EXTRUDED + "_layers")
1122+
tmesh = ExtrudedMeshTopology(base_tmesh, layers, periodic=periodic, name=tmesh_name)
10741123
else:
1075-
layers = self.get_attr(path, PREFIX_EXTRUDED + "_layers")
1076-
tmesh = ExtrudedMeshTopology(base_tmesh, layers, periodic=periodic, name=tmesh_name)
1124+
if topology.name != tmesh_name:
1125+
raise RuntimeError(f"Got wrong mesh topology (f{topology.name}): expecting f{tmesh_name}")
1126+
tmesh = topology
1127+
base_tmesh = topology._base_mesh
10771128
# -- Load mesh --
10781129
path = self._path_to_mesh(tmesh_name, name)
10791130
coord_element = self._load_ufl_element(path, PREFIX + "_coordinate_element")
@@ -1301,14 +1352,29 @@ def _load_function_space_topology(self, tmesh, element):
13011352
return impl.FunctionSpace(tmesh, element)
13021353

13031354
@PETSc.Log.EventDecorator("LoadFunction")
1304-
def load_function(self, mesh, name, idx=None):
1305-
r"""Load a :class:`~.Function` defined on `mesh`.
1355+
def load_function(
1356+
self,
1357+
mesh: MeshGeometry,
1358+
name: str,
1359+
idx: Optional[int] = None
1360+
) -> Function:
1361+
"""Load a :class:`~.Function` defined on ``mesh``.
13061362
1307-
:arg mesh: the mesh on which the function is defined.
1308-
:arg name: the name of the :class:`~.Function` to load.
1309-
:kwarg idx: optional timestepping index. A function can
1363+
Parameters
1364+
----------
1365+
mesh
1366+
mesh on which the function is defined.
1367+
name
1368+
name of the `Function` to load.
1369+
idx
1370+
Optional timestepping index. A function can
13101371
be loaded with idx only when it was saved with idx.
1311-
:returns: the loaded :class:`~.Function`.
1372+
1373+
Returns
1374+
-------
1375+
Function
1376+
Loaded `Function`.
1377+
13121378
"""
13131379
tmesh = mesh.topology
13141380
if name in self._get_mixed_function_name_mixed_function_space_name_map(mesh.name):
@@ -1341,7 +1407,19 @@ def load_function(self, mesh, name, idx=None):
13411407
method = get_embedding_method_for_checkpointing(element)
13421408
assert _element == _f.function_space().ufl_element()
13431409
f = Function(V, name=name)
1344-
self._project_function_for_checkpointing(f, _f, method)
1410+
if mesh.coordinates.function_space().ufl_element().embedded_subdegree > 1 and \
1411+
self.has_attr(path, PREFIX_EMBEDDED + "_affine_coordinates"):
1412+
# Handle non-affine mesh; this is only relevant when embedding into a DG space.
1413+
affine_coord_element = self._load_ufl_element(path, PREFIX_EMBEDDED + "_affine_coordinate_element")
1414+
affine_coord_name = self.get_attr(path, PREFIX_EMBEDDED + "_affine_coordinates")
1415+
affine_quadrature_degree = self.get_attr(path, PREFIX_EMBEDDED + "_affine_quadrature_degree")
1416+
affine_coordinates = self._load_function_topology(tmesh, affine_coord_element, affine_coord_name)
1417+
affine_coordinate_V = FunctionSpace(mesh, affine_coordinates.function_space().ufl_element())
1418+
affine_coordinates = Function(affine_coordinate_V, val=affine_coordinates.topological)
1419+
else:
1420+
affine_coordinates = None
1421+
affine_quadrature_degree = None
1422+
self._project_function_for_checkpointing(f, _f, method, affine_coordinates=affine_coordinates, affine_quadrature_degree=affine_quadrature_degree)
13451423
return f
13461424
else:
13471425
tf_name = self.get_attr(path, PREFIX + "_vec")
@@ -1637,13 +1715,52 @@ def _is_mixed_function_space(self, mesh_name, V_name):
16371715
return True
16381716
return False
16391717

1640-
def _project_function_for_checkpointing(self, f, _f, method):
1641-
if method == "project":
1642-
getattr(f, method)(_f, solver_parameters={"ksp_rtol": 1.e-16})
1643-
elif method == "interpolate":
1644-
getattr(f, method)(_f)
1718+
def _project_function_for_checkpointing(self, target, source, method, affine_coordinates=None, affine_quadrature_degree=None):
1719+
if affine_coordinates:
1720+
if affine_quadrature_degree is None:
1721+
raise ValueError("Need affine_quadrature_degree to save/load HDiv/HCurl functions on high-order mesh")
1722+
# Need to map to/from the representation on a fictitious
1723+
# affine mesh represented by affine_coordinates.
1724+
K = firedrake.grad(affine_coordinates) # K = (\partial X /\partial x) = F^-1.
1725+
from_elem = source.function_space().ufl_element()
1726+
to_elem = target.function_space().ufl_element()
1727+
if to_elem.mapping() == "identity":
1728+
if from_elem.mapping() == "covariant Piola":
1729+
source = firedrake.transpose(firedrake.inv(K)) * source
1730+
elif from_elem.mapping() == "contravariant Piola":
1731+
source = 1. / firedrake.det(K) * K * source
1732+
else:
1733+
raise NotImplementedError(f"Unsupported pair: ({from_elem.mapping()}, {to_elem.mapping()})")
1734+
V = target.function_space()
1735+
u = firedrake.TrialFunction(V)
1736+
v = firedrake.TestFunction(V)
1737+
# Solve projection problem on the fictitious affine mesh.
1738+
a = firedrake.inner(u, v) * firedrake.det(K) * firedrake.dx(degree=affine_quadrature_degree)
1739+
L = firedrake.inner(source, v) * firedrake.det(K) * firedrake.dx(degree=affine_quadrature_degree)
1740+
firedrake.solve(a == L, target, solver_parameters={"ksp_rtol": 1.e-16})
1741+
elif from_elem.mapping() == "identity":
1742+
if to_elem.mapping() == "covariant Piola":
1743+
source = firedrake.transpose(K) * source
1744+
elif to_elem.mapping() == "contravariant Piola":
1745+
source = firedrake.det(K) * firedrake.inv(K) * source
1746+
else:
1747+
raise NotImplementedError(f"Unsupported pair: ({from_elem.mapping()}, {to_elem.mapping()})")
1748+
V = target.function_space()
1749+
u = firedrake.TrialFunction(V)
1750+
v = firedrake.TestFunction(V)
1751+
# Solve projection problem on the high-order mesh.
1752+
a = firedrake.inner(u, v) * firedrake.dx(degree=affine_quadrature_degree)
1753+
L = firedrake.inner(source, v) * firedrake.dx(degree=affine_quadrature_degree)
1754+
firedrake.solve(a == L, target, solver_parameters={"ksp_rtol": 1.e-16})
1755+
else:
1756+
raise NotImplementedError(f"Unsupported pair: ({from_elem.mapping()}, {to_elem.mapping()})")
16451757
else:
1646-
raise ValueError(f"Unknown method for projecting: {method}")
1758+
if method == "project":
1759+
getattr(target, method)(source, solver_parameters={"ksp_rtol": 1.e-16})
1760+
elif method == "interpolate":
1761+
getattr(target, method)(source)
1762+
else:
1763+
raise ValueError(f"Unknown method for projecting: {method}")
16471764

16481765
@property
16491766
def h5pyfile(self):

0 commit comments

Comments
 (0)