Skip to content

Commit cd60b83

Browse files
authored
interpolate two-forms (#4770)
1 parent d5e74a6 commit cd60b83

File tree

2 files changed

+72
-0
lines changed

2 files changed

+72
-0
lines changed

firedrake/assemble.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -855,6 +855,15 @@ def restructure_base_form(expr, visited=None):
855855
if isinstance(expr, ufl.FormSum) and all(ufl.duals.is_dual(a.function_space()) for a in expr.arguments()):
856856
# Return ufl.Sum if we are assembling a FormSum with Coarguments (a primal expression)
857857
return sum(w*c for w, c in zip(expr.weights(), expr.components()))
858+
859+
# If F: V3 x V2 -> R, then
860+
# Interpolate(TestFunction(V1), F) <=> Action(Interpolate(TestFunction(V1), TrialFunction(V2.dual())), F).
861+
# The result is a two-form V3 x V1 -> R.
862+
if isinstance(expr, ufl.Interpolate) and isinstance(expr.argument_slots()[0], ufl.form.Form):
863+
form, operand = expr.argument_slots()
864+
vstar = firedrake.Argument(form.arguments()[0].function_space().dual(), 1)
865+
expr = expr._ufl_expr_reconstruct_(operand, v=vstar)
866+
return ufl.action(expr, form)
858867
return expr
859868

860869
@staticmethod

tests/firedrake/regression/test_interpolate.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,13 @@
66
cwd = abspath(dirname(__file__))
77

88

9+
def mat_equals(a, b):
10+
"""Check that two Matrices are equal."""
11+
a = a.petscmat.copy()
12+
a.axpy(-1.0, b.petscmat)
13+
return a.norm(norm_type=PETSc.NormType.NORM_FROBENIUS) < 1e-14
14+
15+
916
def test_constant():
1017
cg1 = FunctionSpace(UnitSquareMesh(5, 5), "CG", 1)
1118
f = assemble(interpolate(Constant(1.0), cg1))
@@ -660,3 +667,59 @@ def test_interpolate_composition(mode):
660667
res_adj = assemble(u5)
661668
res_adj2 = assemble(interpolate(TestFunction(V5), conj(TestFunction(V1)) * dx))
662669
assert np.allclose(res_adj.dat.data_ro, res_adj2.dat.data_ro)
670+
671+
672+
@pytest.mark.parallel([1, 3])
673+
def test_interpolate_form():
674+
mesh = UnitSquareMesh(5, 5)
675+
V3 = FunctionSpace(mesh, "CG", 3)
676+
V2 = FunctionSpace(mesh, "CG", 2)
677+
V1 = FunctionSpace(mesh, "CG", 1)
678+
679+
V3_trial = TrialFunction(V3)
680+
V2_test = TestFunction(V2)
681+
V1_test = TestFunction(V1)
682+
V2_dual_trial = TrialFunction(V2.dual())
683+
684+
two_form = inner(V3_trial, V2_test) * dx # V3 x V2 -> R, equiv V3 -> V2^*
685+
interp = interpolate(V1_test, two_form) # V3 x V1 -> R, equiv V3 -> V1^*
686+
assert interp.arguments() == (V1_test, V3_trial)
687+
res1 = assemble(interp)
688+
689+
I = interpolate(V1_test, V2_dual_trial) # V2^* x V1 -> R, equiv V2^* -> V1^*
690+
interp2 = action(I, two_form) # V3 -> V1^*
691+
assert interp2.arguments() == (V1_test, V3_trial)
692+
res2 = assemble(interp2)
693+
assert mat_equals(res1, res2)
694+
695+
res3 = assemble(inner(V3_trial, V1_test) * dx) # V3 x V1 -> R
696+
assert mat_equals(res1, res3)
697+
698+
699+
@pytest.mark.parallel([1, 3])
700+
def test_interpolate_form_mixed():
701+
mesh = UnitSquareMesh(3, 3)
702+
V1 = FunctionSpace(mesh, "CG", 1)
703+
V2 = FunctionSpace(mesh, "CG", 2)
704+
V3 = FunctionSpace(mesh, "CG", 3)
705+
V4 = FunctionSpace(mesh, "CG", 4)
706+
V = V3 * V4
707+
W = V1 * V2
708+
709+
u = TrialFunction(V)
710+
v = TestFunction(V)
711+
q = TestFunction(W)
712+
713+
form = inner(u, v) * dx # V x V -> R, equiv V -> V^*
714+
interp = interpolate(q, form) # V -> W^*, equiv V x W -> R
715+
assert interp.arguments() == (q, u)
716+
res1 = assemble(interp)
717+
718+
I = interpolate(q, TrialFunction(V.dual())) # V^* x W -> R, equiv V^* -> W^*
719+
interp2 = action(I, form) # V -> W^*
720+
assert interp2.arguments() == (q, u)
721+
res2 = assemble(interp2)
722+
assert mat_equals(res1, res2)
723+
724+
res3 = assemble(inner(u, q) * dx) # V x W -> R
725+
assert mat_equals(res1, res3)

0 commit comments

Comments
 (0)