7
7
from firedrake .cython import hdf5interface as h5i
8
8
from firedrake .cython import dmcommon
9
9
from firedrake .petsc import PETSc , OptionsManager
10
- from firedrake .mesh import MeshTopology , ExtrudedMeshTopology , DEFAULT_MESH_NAME , make_mesh_from_coordinates , DistributedMeshOverlapType
10
+ from firedrake .mesh import MeshGeometry , MeshTopology , ExtrudedMeshTopology , DEFAULT_MESH_NAME , make_mesh_from_coordinates , DistributedMeshOverlapType
11
11
from firedrake .functionspace import FunctionSpace
12
12
from firedrake import functionspaceimpl as impl
13
13
from firedrake .functionspacedata import get_global_numbering , create_element
20
20
import numpy as np
21
21
import os
22
22
import h5py
23
+ from typing import Optional , Union
23
24
24
25
25
26
__all__ = ["DumbCheckpoint" , "HDF5File" , "FILE_READ" , "FILE_CREATE" , "FILE_UPDATE" , "CheckpointFile" ]
@@ -896,25 +897,47 @@ def _save_function_space_topology(self, tV):
896
897
topology_dm .setName (base_tmesh_name )
897
898
898
899
@PETSc .Log .EventDecorator ("SaveFunction" )
899
- def save_function (self , f , idx = None , name = None , timestepping_info = {}):
900
- r"""Save a :class:`~.Function`.
900
+ def save_function (
901
+ self ,
902
+ f : Function ,
903
+ idx : Optional [int ] = None ,
904
+ name : Optional [str ] = None ,
905
+ timestepping_info : Optional [dict ] = {},
906
+ affine_coordinates : Optional [Union [MeshGeometry , Function ]] = None ,
907
+ affine_quadrature_degree : Optional [int ] = None ,
908
+ ) -> None :
909
+ """Save a :class:`~.Function`.
901
910
902
- :arg f: the :class:`~.Function` to save.
903
- :kwarg idx: optional timestepping index. A function can
911
+ Parameters
912
+ ----------
913
+ f
914
+ `Function` to save.
915
+ idx
916
+ Optional timestepping index. A function can
904
917
either be saved in timestepping mode or in normal
905
918
mode (non-timestepping); for each function of interest,
906
919
this method must always be called with the idx parameter
907
920
set or never be called with the idx parameter set.
908
- :kwarg name: optional alternative name to save the function under.
909
- :kwarg timestepping_info: optional (requires idx) additional information
921
+ name
922
+ Optional alternative name to save the function under.
923
+ timestepping_info
924
+ Optional (requires idx) additional information
910
925
such as time, timestepping that can be stored along a function for
911
926
each index.
927
+ affine_coordinates
928
+ Representation of a fictitious affine mesh onto which
929
+ the function is mapped before saving; only significant for
930
+ HDiv/HCurl functions defined on high-order mesh.
931
+ affine_quadrature_degree
932
+ Quadrature degree to be used when mapping onto the affine mesh;
933
+ only significant for HDiv/HCurl functions defined on high-order mesh.
934
+
912
935
"""
913
936
V = f .function_space ()
914
937
mesh = V .mesh ()
915
938
if name :
916
939
g = Function (V , val = f .dat , name = name )
917
- return self .save_function (g , idx = idx , timestepping_info = timestepping_info )
940
+ return self .save_function (g , idx = idx , timestepping_info = timestepping_info , affine_coordinates = affine_coordinates , affine_quadrature_degree = affine_quadrature_degree )
918
941
# -- Save function space --
919
942
self ._save_function_space (V )
920
943
# -- Save function --
@@ -926,7 +949,7 @@ def save_function(self, f, idx=None, name=None, timestepping_info={}):
926
949
path = os .path .join (base_path , str (i ))
927
950
self .require_group (path )
928
951
self .set_attr (path , PREFIX + "_function" , fsub .name ())
929
- self .save_function (fsub , idx = idx , timestepping_info = timestepping_info )
952
+ self .save_function (fsub , idx = idx , timestepping_info = timestepping_info , affine_coordinates = affine_coordinates , affine_quadrature_degree = affine_quadrature_degree )
930
953
self ._update_mixed_function_name_mixed_function_space_name_map (mesh .name , {f .name (): V_name })
931
954
else :
932
955
tf = f .topological
@@ -940,10 +963,32 @@ def save_function(self, f, idx=None, name=None, timestepping_info={}):
940
963
path = self ._path_to_function_embedded (tmesh .name , mesh .name , V_name , f .name ())
941
964
self .require_group (path )
942
965
method = get_embedding_method_for_checkpointing (element )
943
- _V = FunctionSpace (mesh , _element )
966
+ if mesh .coordinates .function_space ().ufl_element ().embedded_subdegree > 1 :
967
+ # Handle non-affine mesh; this is only relevant when embedding into a DG space.
968
+ if affine_coordinates is None :
969
+ raise ValueError ("Must provide affine_coordinates to save functions on high-order mesh" )
970
+ if affine_quadrature_degree is None :
971
+ raise ValueError ("Must provide affine_quadrature_degree to save functions on high-order mesh" )
972
+ if isinstance (affine_coordinates , MeshGeometry ):
973
+ affine_coordinates = affine_coordinates .coordinates
974
+ else :
975
+ if not isinstance (affine_coordinates , Function ):
976
+ raise ValueError ("affine_coordinates must be {MeshGeometry, Function}" )
977
+ if affine_coordinates .function_space ().mesh ().topology is not tmesh :
978
+ raise ValueError (f"affine_coordinates.function_space().mesh().topology ({ affine_coordinates .function_space ().mesh ().topology } ) is not f.mesh().topology ({ tmesh } )" )
979
+ if affine_coordinates .function_space ().mesh () is not mesh :
980
+ affine_coordinate_V = FunctionSpace (mesh , affine_coordinates .function_space ().ufl_element ())
981
+ affine_coordinates = Function (affine_coordinate_V , val = affine_coordinates .topological )
982
+ if affine_coordinates .topological .name () == mesh .coordinates .topological .name ():
983
+ raise ValueError (f"affine_coordinate.name() ({ affine_coordinates .name ()} ) == mesh.coordinates.topological.name() ({ mesh .coordinates .topological .name ()} )" )
984
+ self ._save_ufl_element (path , PREFIX_EMBEDDED + "_affine_coordinate_element" , affine_coordinates .topological .function_space ().ufl_element ())
985
+ self .set_attr (path , PREFIX_EMBEDDED + "_affine_coordinates" , affine_coordinates .topological .name ())
986
+ self .set_attr (path , PREFIX_EMBEDDED + "_affine_quadrature_degree" , affine_quadrature_degree )
987
+ self ._save_function_topology (affine_coordinates .topological )
944
988
_name = "_" .join ([PREFIX_EMBEDDED , f .name ()])
989
+ _V = FunctionSpace (mesh , _element )
945
990
_f = Function (_V , name = _name )
946
- self ._project_function_for_checkpointing (_f , f , method )
991
+ self ._project_function_for_checkpointing (_f , f , method , affine_coordinates = affine_coordinates , affine_quadrature_degree = affine_quadrature_degree )
947
992
self .save_function (_f , idx = idx , timestepping_info = timestepping_info )
948
993
self .set_attr (path , PREFIX_EMBEDDED + "_function" , _name )
949
994
else :
@@ -1045,35 +1090,41 @@ def load_mesh(self, name=DEFAULT_MESH_NAME, reorder=None, distribution_parameter
1045
1090
path = self ._path_to_topology_extruded (tmesh_name )
1046
1091
if path in self .h5pyfile :
1047
1092
# -- Load mesh topology --
1048
- base_tmesh_name = self .get_attr (path , PREFIX_EXTRUDED + "_base_mesh" )
1049
- base_tmesh = self ._load_mesh_topology (base_tmesh_name , reorder , distribution_parameters )
1050
- base_tmesh .init ()
1051
- periodic = self .get_attr (path , PREFIX_EXTRUDED + "_periodic" ) if self .has_attr (path , PREFIX_EXTRUDED + "_periodic" ) else False
1052
- variable_layers = self .get_attr (path , PREFIX_EXTRUDED + "_variable_layers" )
1053
- if variable_layers :
1054
- cell = base_tmesh .ufl_cell ()
1055
- element = finat .ufl .VectorElement ("DP" if cell .is_simplex () else "DQ" , cell , 0 , dim = 2 )
1056
- _ = self ._load_function_space_topology (base_tmesh , element )
1057
- base_tmesh_key = self ._generate_mesh_key_from_names (base_tmesh .name ,
1058
- base_tmesh ._distribution_name ,
1059
- base_tmesh ._permutation_name )
1060
- sd_key = self ._get_shared_data_key_for_checkpointing (base_tmesh , element )
1061
- _ , _ , lsf = self ._function_load_utils [base_tmesh_key + sd_key ]
1062
- nroots , _ , _ = lsf .getGraph ()
1063
- layers_a = np .empty (nroots , dtype = utils .IntType )
1064
- layers_a_iset = PETSc .IS ().createGeneral (layers_a , comm = self ._comm )
1065
- layers_a_iset .setName ("_" .join ([PREFIX_EXTRUDED , "layers_iset" ]))
1066
- self .viewer .pushGroup (path )
1067
- layers_a_iset .load (self .viewer )
1068
- self .viewer .popGroup ()
1069
- layers_a = layers_a_iset .getIndices ()
1070
- layers = np .empty ((base_tmesh .cell_set .total_size , 2 ), dtype = utils .IntType )
1071
- unit = MPI ._typedict [np .dtype (utils .IntType ).char ]
1072
- lsf .bcastBegin (unit , layers_a , layers , MPI .REPLACE )
1073
- lsf .bcastEnd (unit , layers_a , layers , MPI .REPLACE )
1093
+ if topology is None :
1094
+ base_tmesh_name = self .get_attr (path , PREFIX_EXTRUDED + "_base_mesh" )
1095
+ base_tmesh = self ._load_mesh_topology (base_tmesh_name , reorder , distribution_parameters )
1096
+ base_tmesh .init ()
1097
+ periodic = self .get_attr (path , PREFIX_EXTRUDED + "_periodic" ) if self .has_attr (path , PREFIX_EXTRUDED + "_periodic" ) else False
1098
+ variable_layers = self .get_attr (path , PREFIX_EXTRUDED + "_variable_layers" )
1099
+ if variable_layers :
1100
+ cell = base_tmesh .ufl_cell ()
1101
+ element = finat .ufl .VectorElement ("DP" if cell .is_simplex () else "DQ" , cell , 0 , dim = 2 )
1102
+ _ = self ._load_function_space_topology (base_tmesh , element )
1103
+ base_tmesh_key = self ._generate_mesh_key_from_names (base_tmesh .name ,
1104
+ base_tmesh ._distribution_name ,
1105
+ base_tmesh ._permutation_name )
1106
+ sd_key = self ._get_shared_data_key_for_checkpointing (base_tmesh , element )
1107
+ _ , _ , lsf = self ._function_load_utils [base_tmesh_key + sd_key ]
1108
+ nroots , _ , _ = lsf .getGraph ()
1109
+ layers_a = np .empty (nroots , dtype = utils .IntType )
1110
+ layers_a_iset = PETSc .IS ().createGeneral (layers_a , comm = self ._comm )
1111
+ layers_a_iset .setName ("_" .join ([PREFIX_EXTRUDED , "layers_iset" ]))
1112
+ self .viewer .pushGroup (path )
1113
+ layers_a_iset .load (self .viewer )
1114
+ self .viewer .popGroup ()
1115
+ layers_a = layers_a_iset .getIndices ()
1116
+ layers = np .empty ((base_tmesh .cell_set .total_size , 2 ), dtype = utils .IntType )
1117
+ unit = MPI ._typedict [np .dtype (utils .IntType ).char ]
1118
+ lsf .bcastBegin (unit , layers_a , layers , MPI .REPLACE )
1119
+ lsf .bcastEnd (unit , layers_a , layers , MPI .REPLACE )
1120
+ else :
1121
+ layers = self .get_attr (path , PREFIX_EXTRUDED + "_layers" )
1122
+ tmesh = ExtrudedMeshTopology (base_tmesh , layers , periodic = periodic , name = tmesh_name )
1074
1123
else :
1075
- layers = self .get_attr (path , PREFIX_EXTRUDED + "_layers" )
1076
- tmesh = ExtrudedMeshTopology (base_tmesh , layers , periodic = periodic , name = tmesh_name )
1124
+ if topology .name != tmesh_name :
1125
+ raise RuntimeError (f"Got wrong mesh topology (f{ topology .name } ): expecting f{ tmesh_name } " )
1126
+ tmesh = topology
1127
+ base_tmesh = topology ._base_mesh
1077
1128
# -- Load mesh --
1078
1129
path = self ._path_to_mesh (tmesh_name , name )
1079
1130
coord_element = self ._load_ufl_element (path , PREFIX + "_coordinate_element" )
@@ -1301,14 +1352,29 @@ def _load_function_space_topology(self, tmesh, element):
1301
1352
return impl .FunctionSpace (tmesh , element )
1302
1353
1303
1354
@PETSc .Log .EventDecorator ("LoadFunction" )
1304
- def load_function (self , mesh , name , idx = None ):
1305
- r"""Load a :class:`~.Function` defined on `mesh`.
1355
+ def load_function (
1356
+ self ,
1357
+ mesh : MeshGeometry ,
1358
+ name : str ,
1359
+ idx : Optional [int ] = None
1360
+ ) -> Function :
1361
+ """Load a :class:`~.Function` defined on ``mesh``.
1306
1362
1307
- :arg mesh: the mesh on which the function is defined.
1308
- :arg name: the name of the :class:`~.Function` to load.
1309
- :kwarg idx: optional timestepping index. A function can
1363
+ Parameters
1364
+ ----------
1365
+ mesh
1366
+ mesh on which the function is defined.
1367
+ name
1368
+ name of the `Function` to load.
1369
+ idx
1370
+ Optional timestepping index. A function can
1310
1371
be loaded with idx only when it was saved with idx.
1311
- :returns: the loaded :class:`~.Function`.
1372
+
1373
+ Returns
1374
+ -------
1375
+ Function
1376
+ Loaded `Function`.
1377
+
1312
1378
"""
1313
1379
tmesh = mesh .topology
1314
1380
if name in self ._get_mixed_function_name_mixed_function_space_name_map (mesh .name ):
@@ -1341,7 +1407,19 @@ def load_function(self, mesh, name, idx=None):
1341
1407
method = get_embedding_method_for_checkpointing (element )
1342
1408
assert _element == _f .function_space ().ufl_element ()
1343
1409
f = Function (V , name = name )
1344
- self ._project_function_for_checkpointing (f , _f , method )
1410
+ if mesh .coordinates .function_space ().ufl_element ().embedded_subdegree > 1 and \
1411
+ self .has_attr (path , PREFIX_EMBEDDED + "_affine_coordinates" ):
1412
+ # Handle non-affine mesh; this is only relevant when embedding into a DG space.
1413
+ affine_coord_element = self ._load_ufl_element (path , PREFIX_EMBEDDED + "_affine_coordinate_element" )
1414
+ affine_coord_name = self .get_attr (path , PREFIX_EMBEDDED + "_affine_coordinates" )
1415
+ affine_quadrature_degree = self .get_attr (path , PREFIX_EMBEDDED + "_affine_quadrature_degree" )
1416
+ affine_coordinates = self ._load_function_topology (tmesh , affine_coord_element , affine_coord_name )
1417
+ affine_coordinate_V = FunctionSpace (mesh , affine_coordinates .function_space ().ufl_element ())
1418
+ affine_coordinates = Function (affine_coordinate_V , val = affine_coordinates .topological )
1419
+ else :
1420
+ affine_coordinates = None
1421
+ affine_quadrature_degree = None
1422
+ self ._project_function_for_checkpointing (f , _f , method , affine_coordinates = affine_coordinates , affine_quadrature_degree = affine_quadrature_degree )
1345
1423
return f
1346
1424
else :
1347
1425
tf_name = self .get_attr (path , PREFIX + "_vec" )
@@ -1637,13 +1715,52 @@ def _is_mixed_function_space(self, mesh_name, V_name):
1637
1715
return True
1638
1716
return False
1639
1717
1640
- def _project_function_for_checkpointing (self , f , _f , method ):
1641
- if method == "project" :
1642
- getattr (f , method )(_f , solver_parameters = {"ksp_rtol" : 1.e-16 })
1643
- elif method == "interpolate" :
1644
- getattr (f , method )(_f )
1718
+ def _project_function_for_checkpointing (self , target , source , method , affine_coordinates = None , affine_quadrature_degree = None ):
1719
+ if affine_coordinates :
1720
+ if affine_quadrature_degree is None :
1721
+ raise ValueError ("Need affine_quadrature_degree to save/load HDiv/HCurl functions on high-order mesh" )
1722
+ # Need to map to/from the representation on a fictitious
1723
+ # affine mesh represented by affine_coordinates.
1724
+ K = firedrake .grad (affine_coordinates ) # K = (\partial X /\partial x) = F^-1.
1725
+ from_elem = source .function_space ().ufl_element ()
1726
+ to_elem = target .function_space ().ufl_element ()
1727
+ if to_elem .mapping () == "identity" :
1728
+ if from_elem .mapping () == "covariant Piola" :
1729
+ source = firedrake .transpose (firedrake .inv (K )) * source
1730
+ elif from_elem .mapping () == "contravariant Piola" :
1731
+ source = 1. / firedrake .det (K ) * K * source
1732
+ else :
1733
+ raise NotImplementedError (f"Unsupported pair: ({ from_elem .mapping ()} , { to_elem .mapping ()} )" )
1734
+ V = target .function_space ()
1735
+ u = firedrake .TrialFunction (V )
1736
+ v = firedrake .TestFunction (V )
1737
+ # Solve projection problem on the fictitious affine mesh.
1738
+ a = firedrake .inner (u , v ) * firedrake .det (K ) * firedrake .dx (degree = affine_quadrature_degree )
1739
+ L = firedrake .inner (source , v ) * firedrake .det (K ) * firedrake .dx (degree = affine_quadrature_degree )
1740
+ firedrake .solve (a == L , target , solver_parameters = {"ksp_rtol" : 1.e-16 })
1741
+ elif from_elem .mapping () == "identity" :
1742
+ if to_elem .mapping () == "covariant Piola" :
1743
+ source = firedrake .transpose (K ) * source
1744
+ elif to_elem .mapping () == "contravariant Piola" :
1745
+ source = firedrake .det (K ) * firedrake .inv (K ) * source
1746
+ else :
1747
+ raise NotImplementedError (f"Unsupported pair: ({ from_elem .mapping ()} , { to_elem .mapping ()} )" )
1748
+ V = target .function_space ()
1749
+ u = firedrake .TrialFunction (V )
1750
+ v = firedrake .TestFunction (V )
1751
+ # Solve projection problem on the high-order mesh.
1752
+ a = firedrake .inner (u , v ) * firedrake .dx (degree = affine_quadrature_degree )
1753
+ L = firedrake .inner (source , v ) * firedrake .dx (degree = affine_quadrature_degree )
1754
+ firedrake .solve (a == L , target , solver_parameters = {"ksp_rtol" : 1.e-16 })
1755
+ else :
1756
+ raise NotImplementedError (f"Unsupported pair: ({ from_elem .mapping ()} , { to_elem .mapping ()} )" )
1645
1757
else :
1646
- raise ValueError (f"Unknown method for projecting: { method } " )
1758
+ if method == "project" :
1759
+ getattr (target , method )(source , solver_parameters = {"ksp_rtol" : 1.e-16 })
1760
+ elif method == "interpolate" :
1761
+ getattr (target , method )(source )
1762
+ else :
1763
+ raise ValueError (f"Unknown method for projecting: { method } " )
1647
1764
1648
1765
@property
1649
1766
def h5pyfile (self ):
0 commit comments