Skip to content

Commit 51bbd96

Browse files
Function space method to create broken function space. (#4773)
* WithGeometryBase.broken_space * REVERT BEFORE MERGE: broken element fiat branch * use WithGeometry.broken_space internally * Update firedrake/functionspaceimpl.py Co-authored-by: Connor Ward <[email protected]> * test for MixedFunctionSpace.broken_space * cannot use type(self) in type hint * only set V.broken_space().name if V.name is set * transformed reduced functional: only use fs.broken_space for non-dg spaces * Apply suggestion from @JHopeCollins --------- Co-authored-by: Connor Ward <[email protected]>
1 parent d172626 commit 51bbd96

File tree

4 files changed

+60
-27
lines changed

4 files changed

+60
-27
lines changed

firedrake/adjoint/transformed_functional.py

Lines changed: 2 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -176,28 +176,6 @@ def is_dg_space(space: WithGeometry) -> bool:
176176
return e.is_dg()
177177

178178

179-
def dg_space(space: WithGeometry) -> WithGeometry:
180-
"""Construct a DG space containing a given function space as a subspace.
181-
182-
Parameters
183-
----------
184-
185-
space
186-
A function space.
187-
188-
Returns
189-
-------
190-
191-
firedrake.functionspaceimpl.WithGeometry
192-
A DG space containing `space` as a subspace. May be `space`.
193-
"""
194-
195-
if is_dg_space(space):
196-
return space
197-
else:
198-
return fd.FunctionSpace(space.mesh(), finat.ufl.BrokenElement(space.ufl_element()))
199-
200-
201179
class L2TransformedFunctional(AbstractReducedFunctional):
202180
r"""Represents the functional
203181
@@ -265,7 +243,8 @@ def __init__(self, functional: pyadjoint.OverloadedType, controls: Union[Control
265243
self._space_D = Enlist(space_D)
266244
if len(self._space_D) != len(self._space):
267245
raise ValueError("Invalid length")
268-
self._space_D = tuple(dg_space(space) if space_D is None else space_D
246+
self._space_D = tuple((space if is_dg_space(space) else space.broken_space())
247+
if space_D is None else space_D
269248
for space, space_D in zip(self._space, self._space_D))
270249

271250
self._controls = tuple(Control(fd.Function(space_D), riesz_map="l2")

firedrake/functionspaceimpl.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -403,6 +403,19 @@ def make_function_space(cls, mesh, element, name=None):
403403
new = cls.create(new, mesh)
404404
return new
405405

406+
def broken_space(self):
407+
"""Return a :class:`.WithGeometryBase` with a :class:`finat.ufl.BrokenElement`
408+
constructed from this function space's FiniteElement.
409+
410+
Returns
411+
-------
412+
WithGeometryBase :
413+
The new function space with a :class:`~finat.ufl.BrokenElement`.
414+
"""
415+
return type(self).make_function_space(
416+
self.mesh(), finat.ufl.BrokenElement(self.ufl_element()),
417+
name=f"{self.name}_broken" if self.name else None)
418+
406419
def reconstruct(
407420
self,
408421
mesh: MeshGeometry | None = None,

firedrake/slate/static_condensation/hybridization.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import functools
22

33
import ufl
4-
import finat.ufl
54

65
import firedrake.dmhooks as dmhooks
76
from firedrake.slate.static_condensation.sc_base import SCBase
@@ -90,8 +89,7 @@ def initialize(self, pc):
9089
TraceSpace = FunctionSpace(mesh[self.vidx], "HDiv Trace", tdegree)
9190

9291
# Break the function spaces and define fully discontinuous spaces
93-
broken_elements = finat.ufl.MixedElement([finat.ufl.BrokenElement(Vi.ufl_element()) for Vi in V])
94-
V_d = FunctionSpace(mesh, broken_elements)
92+
V_d = V.broken_space()
9593

9694
# Set up the functions for the original, hybridized
9795
# and schur complement systems

tests/firedrake/regression/test_function_spaces.py

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ def test_function_space_variant(mesh, space):
120120

121121

122122
@pytest.mark.parametrize("modifier",
123-
[BrokenElement, HDivElement,
123+
[HDivElement,
124124
HCurlElement])
125125
@pytest.mark.parametrize("element",
126126
[FiniteElement("CG", triangle, 1),
@@ -313,3 +313,46 @@ def test_reconstruct_sub_component(dg0, rt1, mesh, mesh2, dual):
313313
assert is_primal(V1.parent) == is_primal(V2.parent) != dual
314314
assert V1.parent.ufl_element() == V2.parent.ufl_element()
315315
assert V1.parent.index == V2.parent.index == index
316+
317+
318+
@pytest.mark.parametrize("family", ("CG", "BDM", "DG"))
319+
@pytest.mark.parametrize("shape", (0, 2, (2, 3)), ids=("0", "2", "(2,3)"))
320+
def test_broken_space(mesh, shape, family):
321+
"""Check that FunctionSpace.broken_space returns the a
322+
FunctionSpace with the correct element.
323+
"""
324+
kwargs = {"variant": "spectral"} if family == "DG" else {}
325+
326+
elem = FiniteElement(family, mesh.ufl_cell(), 1, **kwargs)
327+
328+
if not isinstance(shape, int):
329+
make_element = lambda elem: TensorElement(elem, shape=shape)
330+
elif shape > 0:
331+
make_element = lambda elem: VectorElement(elem, dim=shape)
332+
else:
333+
make_element = lambda elem: elem
334+
335+
fs = FunctionSpace(mesh, make_element(elem))
336+
broken = fs.broken_space()
337+
expected = FunctionSpace(mesh, make_element(BrokenElement(elem)))
338+
339+
assert broken == expected
340+
341+
342+
def test_mixed_broken_space(mesh):
343+
"""Check that MixedFunctionSpace.broken_space returns the a
344+
MixedFunctionSpace with the correct element.
345+
"""
346+
347+
mixed_elem = MixedElement([
348+
FiniteElement("CG", mesh.ufl_cell(), 1),
349+
VectorElement("BDM", mesh.ufl_cell(), 2, dim=2),
350+
TensorElement("DG", mesh.ufl_cell(), 1, shape=(2, 3), variant="spectral")
351+
])
352+
broken_elem = MixedElement([BrokenElement(elem) for elem in mixed_elem.sub_elements])
353+
354+
mfs = FunctionSpace(mesh, mixed_elem)
355+
broken = mfs.broken_space()
356+
expected = FunctionSpace(mesh, broken_elem)
357+
358+
assert broken == expected

0 commit comments

Comments
 (0)