Skip to content

Commit d82039d

Browse files
committed
Split Cofunction
2 parents bb04bb0 + bfb7a19 commit d82039d

File tree

5 files changed

+67
-28
lines changed

5 files changed

+67
-28
lines changed

firedrake/formmanipulation.py

+37-19
Original file line numberDiff line numberDiff line change
@@ -2,23 +2,22 @@
22
import numpy
33
import collections
44

5-
from ufl import as_vector, split, ZeroBaseForm
6-
from ufl.classes import Zero, FixedIndex, ListTensor
5+
from ufl import as_vector, split
6+
from ufl.classes import Zero, FixedIndex, ListTensor, ZeroBaseForm
77
from ufl.algorithms.map_integrands import map_integrand_dags
88
from ufl.algorithms import expand_derivatives
99
from ufl.corealg.map_dag import MultiFunction, map_expr_dags
1010

11+
from pyop2 import MixedDat
12+
from pyop2.utils import as_tuple
13+
1114
from firedrake.petsc import PETSc
1215
from firedrake.ufl_expr import Argument
13-
from firedrake.functionspace import MixedFunctionSpace, FunctionSpace
16+
from firedrake.cofunction import Cofunction
17+
from firedrake.functionspace import FunctionSpace, MixedFunctionSpace, DualSpace
1418

1519

1620
def subspace(V, indices):
17-
try:
18-
indices = tuple(indices)
19-
except TypeError:
20-
# Only one index provided.
21-
indices = (indices, )
2221
if len(indices) == 1:
2322
W = V[indices[0]]
2423
W = FunctionSpace(W.mesh(), W.ufl_element())
@@ -66,7 +65,7 @@ def split(self, form, argument_indices):
6665
"""
6766
args = form.arguments()
6867
self._arg_cache = {}
69-
self.blocks = dict(enumerate(argument_indices))
68+
self.blocks = dict(enumerate(map(as_tuple, argument_indices)))
7069
if len(args) == 0:
7170
# Functional can't be split
7271
return form
@@ -75,11 +74,13 @@ def split(self, form, argument_indices):
7574
assert (idx[0] == 0 for idx in self.blocks.values())
7675
return form
7776
f = map_integrand_dags(self, form)
78-
f = expand_derivatives(f)
79-
if f.empty():
80-
f = ZeroBaseForm(tuple(Argument(subspace(arg.function_space(), indices),
77+
# TODO find a way to distinguish empty Forms avoiding expand_derivatives
78+
if expand_derivatives(f).empty():
79+
# Get ZeroBaseForm with the right shape
80+
f = ZeroBaseForm(tuple(Argument(subspace(arg.function_space(),
81+
self.blocks[arg.number()]),
8182
arg.number(), part=arg.part())
82-
for arg, indices in zip(form.arguments(), argument_indices)))
83+
for arg in form.arguments()))
8384
return f
8485

8586
expr = MultiFunction.reuse_if_untouched
@@ -109,6 +110,7 @@ def coefficient_derivative(self, o, expr, coefficients, arguments, cds):
109110
@PETSc.Log.EventDecorator()
110111
def argument(self, o):
111112
V = o.function_space()
113+
112114
if len(V) == 1:
113115
# Not on a mixed space, just return ourselves.
114116
return o
@@ -118,12 +120,6 @@ def argument(self, o):
118120

119121
indices = self.blocks[o.number()]
120122

121-
try:
122-
indices = tuple(indices)
123-
except TypeError:
124-
# Only one index provided.
125-
indices = (indices, )
126-
127123
W = subspace(V, indices)
128124
a = Argument(W, o.number(), part=o.part())
129125
a = (a, ) if len(W) == 1 else split(a)
@@ -141,6 +137,28 @@ def argument(self, o):
141137
args.extend(Zero() for j in numpy.ndindex(V[i].value_shape))
142138
return self._arg_cache.setdefault(o, as_vector(args))
143139

140+
def cofunction(self, o):
141+
V = o.function_space()
142+
143+
if len(V) == 1:
144+
# Not on a mixed space, just return ourselves.
145+
return o
146+
147+
try:
148+
indices, = set(self.blocks.values())
149+
except ValueError:
150+
raise ValueError("Cofunction found on an off-diagonal block")
151+
152+
if len(indices) == 1:
153+
i = indices[0]
154+
W = V[i]
155+
W = DualSpace(W.mesh(), W.ufl_element())
156+
c = Cofunction(W, val=o.dat[i])
157+
else:
158+
W = MixedFunctionSpace([V[i] for i in indices])
159+
c = Cofunction(W, val=MixedDat(o.dat[i] for i in indices))
160+
return c
161+
144162

145163
SplitForm = collections.namedtuple("SplitForm", ["indices", "form"])
146164

scripts/firedrake-install

+12-6
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,21 @@ from glob import iglob
1616
from itertools import chain
1717
import re
1818
import importlib
19+
20+
21+
class InstallError(Exception):
22+
# Exception for generic install problems.
23+
pass
24+
25+
1926
try:
2027
from pkg_resources.extern.packaging.version import Version, InvalidVersion
2128
except ModuleNotFoundError:
22-
from packaging.version import Version, InvalidVersion
29+
try:
30+
from packaging.version import Version, InvalidVersion
31+
except ModuleNotFoundError:
32+
raise InstallError("Neither setuptools or packaging found. Please "
33+
"install one of these packages before trying again.")
2334

2435
osname = platform.uname().system
2536
arch = platform.uname().machine
@@ -52,11 +63,6 @@ firedrake_apps = {
5263
}
5364

5465

55-
class InstallError(Exception):
56-
# Exception for generic install problems.
57-
pass
58-
59-
6066
class FiredrakeConfiguration(dict):
6167
"""A dictionary extended to facilitate the storage of Firedrake
6268
configuration information."""

tests/firedrake/regression/test_linesmoother.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,8 @@ def backend(request):
4343
return request.param
4444

4545

46-
def test_linesmoother(mesh, S1family, expected, backend):
46+
@pytest.mark.parametrize("rhs", ["form_rhs", "cofunc_rhs"])
47+
def test_linesmoother(mesh, S1family, expected, backend, rhs):
4748
base_cell = mesh._base_mesh.ufl_cell()
4849
S2family = "DG" if base_cell.is_simplex() else "DQ"
4950
DGfamily = "DG" if mesh.ufl_cell().is_simplex() else "DQ"
@@ -86,6 +87,10 @@ def test_linesmoother(mesh, S1family, expected, backend):
8687
f = exp(-rsq)
8788

8889
L = inner(f, q)*dx(degree=2*(degree+1))
90+
if rhs == 'cofunc_rhs':
91+
L = assemble(L)
92+
elif rhs != 'form_rhs':
93+
raise ValueError("Unknown right hand side type")
8994

9095
w0 = Function(W)
9196
problem = LinearVariationalProblem(a, L, w0, bcs=bcs, aP=aP, form_compiler_parameters={"mode": "vanilla"})

tests/firedrake/regression/test_matrix_free.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,7 @@ def test_matrixfree_action(a, V, bcs):
130130

131131
@pytest.mark.parametrize("preassembled", [False, True],
132132
ids=["variational", "preassembled"])
133+
@pytest.mark.parametrize("rhs", ["form_rhs", "cofunc_rhs"])
133134
@pytest.mark.parametrize("parameters",
134135
[{"ksp_type": "preonly",
135136
"pc_type": "python",
@@ -168,7 +169,7 @@ def test_matrixfree_action(a, V, bcs):
168169
"fieldsplit_1_fieldsplit_1_pc_type": "python",
169170
"fieldsplit_1_fieldsplit_1_pc_python_type": "firedrake.AssembledPC",
170171
"fieldsplit_1_fieldsplit_1_assembled_pc_type": "lu"}])
171-
def test_fieldsplitting(mesh, preassembled, parameters):
172+
def test_fieldsplitting(mesh, preassembled, parameters, rhs):
172173
V = FunctionSpace(mesh, "CG", 1)
173174
P = FunctionSpace(mesh, "DG", 0)
174175
Q = VectorFunctionSpace(mesh, "DG", 1)
@@ -185,6 +186,10 @@ def test_fieldsplitting(mesh, preassembled, parameters):
185186
a = inner(u, v)*dx
186187

187188
L = inner(expect, v)*dx
189+
if rhs == 'cofunc_rhs':
190+
L = assemble(L)
191+
elif rhs != 'form_rhs':
192+
raise ValueError("Unknown right hand side type")
188193

189194
f = Function(W)
190195

tests/firedrake/regression/test_nullspace.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -291,7 +291,8 @@ def test_nullspace_mixed_multiple_components():
291291

292292
@pytest.mark.parallel(nprocs=2)
293293
@pytest.mark.parametrize("aux_pc", [False, True], ids=["PC(mu)", "PC(DG0-mu)"])
294-
def test_near_nullspace_mixed(aux_pc):
294+
@pytest.mark.parametrize("rhs", ["form_rhs", "cofunc_rhs"])
295+
def test_near_nullspace_mixed(aux_pc, rhs):
295296
# test nullspace and nearnullspace for a mixed Stokes system
296297
# this is tested on the SINKER case of May and Moresi https://doi.org/10.1016/j.pepi.2008.07.036
297298
# fails in parallel if nullspace is copied to fieldsplit_1_Mp_ksp solve (see PR #3488)
@@ -323,6 +324,10 @@ def test_near_nullspace_mixed(aux_pc):
323324

324325
f = as_vector((0, -9.8*conditional(inside_box, 2, 1)))
325326
L = inner(f, v)*dx
327+
if rhs == 'cofunc_rhs':
328+
L = assemble(L)
329+
elif rhs != 'form_rhs':
330+
raise ValueError("Unknown right hand side type")
326331

327332
bcs = [DirichletBC(W[0].sub(0), 0, (1, 2)), DirichletBC(W[0].sub(1), 0, (3, 4))]
328333

0 commit comments

Comments
 (0)