Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fieldsplit: replace empty Forms with ZeroBaseForm #3947

Open
wants to merge 12 commits into
base: master
Choose a base branch
from
1 change: 1 addition & 0 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ jobs:
--install defcon \
--install gadopt \
--install asQ \
--package-branch ufl pbrubeck/simplify-indexed \
|| (cat firedrake-install.log && /bin/false)
- name: Install test dependencies
run: |
Expand Down
13 changes: 8 additions & 5 deletions firedrake/adjoint_utils/variational_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from functools import wraps
from pyadjoint.tape import get_working_tape, stop_annotating, annotate_tape, no_annotations
from firedrake.adjoint_utils.blocks import NonlinearVariationalSolveBlock
from firedrake.ufl_expr import derivative, adjoint
from ufl import replace


Expand All @@ -11,7 +12,6 @@ def _ad_annotate_init(init):
@no_annotations
@wraps(init)
def wrapper(self, *args, **kwargs):
from firedrake import derivative, adjoint, TrialFunction
init(self, *args, **kwargs)
self._ad_F = self.F
self._ad_u = self.u_restrict
Expand All @@ -20,10 +20,13 @@ def wrapper(self, *args, **kwargs):
try:
# Some forms (e.g. SLATE tensors) are not currently
# differentiable.
dFdu = derivative(self.F,
self.u_restrict,
TrialFunction(self.u_restrict.function_space()))
self._ad_adj_F = adjoint(dFdu)
dFdu = derivative(self.F, self.u_restrict)
try:
self._ad_adj_F = adjoint(dFdu)
except ValueError:
# Try again without expanding derivatives,
# as dFdu might have been simplied to an empty Form
self._ad_adj_F = adjoint(dFdu, derivatives_expanded=True)
except (TypeError, NotImplementedError):
self._ad_adj_F = None
self._ad_kwargs = {'Jp': self.Jp, 'form_compiler_parameters': self.form_compiler_parameters, 'is_linear': self.is_linear}
Expand Down
22 changes: 13 additions & 9 deletions firedrake/assemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -577,10 +577,15 @@ def base_form_assembly_visitor(self, expr, tensor, *args):
@staticmethod
def update_tensor(assembled_base_form, tensor):
if isinstance(tensor, (firedrake.Function, firedrake.Cofunction)):
assembled_base_form.dat.copy(tensor.dat)
if isinstance(assembled_base_form, ufl.ZeroBaseForm):
tensor.dat.zero()
else:
assembled_base_form.dat.copy(tensor.dat)
elif isinstance(tensor, matrix.MatrixBase):
# Uses the PETSc copy method.
assembled_base_form.petscmat.copy(tensor.petscmat)
if isinstance(assembled_base_form, ufl.ZeroBaseForm):
tensor.petscmat.zeroEntries()
else:
assembled_base_form.petscmat.copy(tensor.petscmat)
else:
raise NotImplementedError("Cannot update tensor of type %s" % type(tensor))

Expand Down Expand Up @@ -1138,7 +1143,7 @@ class OneFormAssembler(ParloopFormAssembler):
Parameters
----------
form : ufl.Form or slate.TensorBasehe
form : ufl.Form or slate.TensorBase
1-form.
Notes
Expand Down Expand Up @@ -2127,14 +2132,13 @@ def iter_active_coefficients(form, kinfo):

@staticmethod
def iter_constants(form, kinfo):
"""Yield the form constants"""
"""Yield the form constants referenced in ``kinfo``."""
if isinstance(form, slate.TensorBase):
for const in form.constants():
yield const
all_constants = form.constants()
else:
all_constants = extract_firedrake_constants(form)
for constant_index in kinfo.constant_numbers:
yield all_constants[constant_index]
for constant_index in kinfo.constant_numbers:
yield all_constants[constant_index]

@staticmethod
def index_function_spaces(form, indices):
Expand Down
8 changes: 4 additions & 4 deletions firedrake/bcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -634,10 +634,10 @@ def reconstruct(self, field=None, V=None, subu=None, u=None, row_field=None, col
return
rank = len(self.f.arguments())
splitter = ExtractSubBlock()
if rank == 1:
form = splitter.split(self.f, argument_indices=(row_field, ))
elif rank == 2:
form = splitter.split(self.f, argument_indices=(row_field, col_field))
form = splitter.split(self.f, argument_indices=(row_field, col_field)[:rank])
if isinstance(form, ufl.ZeroBaseForm) or form.empty():
# form is empty, do nothing
return
if u is not None:
form = firedrake.replace(form, {self.u: u})
if action_x is not None:
Expand Down
123 changes: 48 additions & 75 deletions firedrake/formmanipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,30 @@
import numpy
import collections

from ufl import as_vector, FormSum, Form, split
from ufl import as_vector, split
from ufl.classes import Zero, FixedIndex, ListTensor, ZeroBaseForm
from ufl.algorithms.map_integrands import map_integrand_dags
from ufl.algorithms import expand_derivatives
from ufl.corealg.map_dag import MultiFunction, map_expr_dags

from pyop2 import MixedDat
from pyop2.utils import as_tuple

from firedrake.petsc import PETSc
from firedrake.ufl_expr import Argument
from firedrake.cofunction import Cofunction
from firedrake.functionspace import FunctionSpace, MixedFunctionSpace, DualSpace


def subspace(V, indices):
if len(indices) == 1:
W = V[indices[0]]
W = FunctionSpace(W.mesh(), W.ufl_element())
else:
W = MixedFunctionSpace([V[i] for i in indices])
return W


class ExtractSubBlock(MultiFunction):

"""Extract a sub-block from a form."""
Expand All @@ -30,9 +41,11 @@ def indexed(self, o, child, multiindex):
indices = multiindex.indices()
if isinstance(child, ListTensor) and all(isinstance(i, FixedIndex) for i in indices):
if len(indices) == 1:
return child.ufl_operands[indices[0]._value]
return child[indices[0]]
elif len(indices) == len(child.ufl_operands) and all(k == int(i) for k, i in enumerate(indices)):
return child
else:
return ListTensor(*(child.ufl_operands[i._value] for i in multiindex.indices()))
return ListTensor(*(child[i] for i in indices))
return self.expr(o, child, multiindex)

index_inliner = IndexInliner()
Expand All @@ -52,15 +65,22 @@ def split(self, form, argument_indices):
"""
args = form.arguments()
self._arg_cache = {}
self.blocks = dict(enumerate(argument_indices))
self.blocks = dict(enumerate(map(as_tuple, argument_indices)))
if len(args) == 0:
# Functional can't be split
return form
if all(len(a.function_space()) == 1 for a in args):
assert (len(idx) == 1 for idx in self.blocks.values())
assert (idx[0] == 0 for idx in self.blocks.values())
return form
# TODO find a way to distinguish empty Forms avoiding expand_derivatives
f = map_integrand_dags(self, form)
if expand_derivatives(f).empty():
# Get ZeroBaseForm with the right shape
f = ZeroBaseForm(tuple(Argument(subspace(arg.function_space(),
self.blocks[arg.number()]),
arg.number(), part=arg.part())
for arg in form.arguments()))
return f

expr = MultiFunction.reuse_if_untouched
Expand Down Expand Up @@ -98,76 +118,42 @@ def argument(self, o):
if o in self._arg_cache:
return self._arg_cache[o]

V_is = V.subfunctions
indices = self.blocks[o.number()]

# Only one index provided.
if isinstance(indices, int):
indices = (indices, )
W = subspace(V, indices)
a = Argument(W, o.number(), part=o.part())
a = (a, ) if len(W) == 1 else split(a)

if len(indices) == 1:
W = V_is[indices[0]]
W = FunctionSpace(W.mesh(), W.ufl_element())
a = (Argument(W, o.number(), part=o.part()), )
else:
W = MixedFunctionSpace([V_is[i] for i in indices])
a = split(Argument(W, o.number(), part=o.part()))
args = []
for i in range(len(V_is)):
for i in range(len(V)):
if i in indices:
c = indices.index(i)
a_ = a[c]
if len(a_.ufl_shape) == 0:
args += [a_]
args.append(a_)
else:
args += [a_[j] for j in numpy.ndindex(a_.ufl_shape)]
args.extend(a_[j] for j in numpy.ndindex(a_.ufl_shape))
else:
args += [Zero()
for j in numpy.ndindex(V_is[i].value_shape)]
args.extend(Zero() for j in numpy.ndindex(V[i].value_shape))
return self._arg_cache.setdefault(o, as_vector(args))

def cofunction(self, o):
V = o.function_space()

# Not on a mixed space, just return ourselves.
if len(V) == 1:
# Not on a mixed space, just return ourselves.
return o

# We only need the test space for Cofunction
# We only need the test space for Cofunction
indices = self.blocks[0]
V_is = V.subfunctions

# Only one index provided.
if isinstance(indices, int):
indices = (indices, )

# for two-forms, the cofunction should only
# be returned for the diagonal blocks, so
# if we are asked for an off-diagonal block
# then we return a zero form, analogously to
# the off components of arguments.
if len(self.blocks) == 2:
itest, itrial = self.blocks
on_diag = (itest == itrial)
else:
on_diag = True

# if we are on the diagonal, then return a Cofunction
# in the relevant subspace that points to the data in
# the full space. This means that the right hand side
# of the fieldsplit problem will be correct.
if on_diag:
if len(indices) == 1:
i = indices[0]
W = V_is[i]
W = DualSpace(W.mesh(), W.ufl_element())
c = Cofunction(W, val=o.subfunctions[i].dat)
else:
W = MixedFunctionSpace([V_is[i] for i in indices])
c = Cofunction(W, val=MixedDat(o.dat[i] for i in indices))
if len(indices) == 1:
i = indices[0]
W = V[i]
W = DualSpace(W.mesh(), W.ufl_element())
c = Cofunction(W, val=o.dat[i])
else:
c = ZeroBaseForm(o.arguments())

W = MixedFunctionSpace([V[i] for i in indices])
c = Cofunction(W, val=MixedDat(o.dat[i] for i in indices))
return c


Expand Down Expand Up @@ -207,28 +193,15 @@ def split_form(form, diagonal=False):
args = form.arguments()
shape = tuple(len(a.function_space()) for a in args)
forms = []
rank = len(shape)
if diagonal:
assert len(shape) == 2
assert rank == 2
rank = 1
for idx in numpy.ndindex(shape):
if diagonal:
i, j = idx
if i != j:
continue
f = splitter.split(form, idx)

# does f actually contain anything?
if isinstance(f, Cofunction):
flen = 1
elif isinstance(f, FormSum):
flen = len(f.components())
elif isinstance(f, Form):
flen = len(f.integrals())
else:
raise ValueError(
"ExtractSubBlock.split should have returned an instance of "
"either Form, FormSum, or Cofunction")

if flen > 0:
if diagonal:
i, j = idx
if i != j:
continue
idx = (i, )
forms.append(SplitForm(indices=idx, form=f))
forms.append(SplitForm(indices=idx[:rank], form=f))
return tuple(forms)
28 changes: 17 additions & 11 deletions firedrake/matrix_free/operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@
from firedrake.bcs import DirichletBC, EquationBCSplit
from firedrake.petsc import PETSc
from firedrake.utils import cached_property
from firedrake.function import Function
from firedrake.cofunction import Cofunction
from ufl.form import ZeroBaseForm


__all__ = ("ImplicitMatrixContext", )
Expand Down Expand Up @@ -107,23 +110,22 @@ def __init__(self, a, row_bcs=[], col_bcs=[],

# create functions from test and trial space to help
# with 1-form assembly
test_space, trial_space = [
a.arguments()[i].function_space() for i in (0, 1)
]
from firedrake import function, cofunction
test_space, trial_space = (
arg.function_space() for arg in a.arguments()
)
# Need a cofunction since y receives the assembled result of Ax
self._ystar = cofunction.Cofunction(test_space.dual())
self._y = function.Function(test_space)
self._x = function.Function(trial_space)
self._xstar = cofunction.Cofunction(trial_space.dual())
self._ystar = Cofunction(test_space.dual())
self._y = Function(test_space)
self._x = Function(trial_space)
self._xstar = Cofunction(trial_space.dual())

# These are temporary storage for holding the BC
# values during matvec application. _xbc is for
# the action and ._ybc is for transpose.
if len(self.bcs) > 0:
self._xbc = cofunction.Cofunction(trial_space.dual())
self._xbc = Cofunction(trial_space.dual())
if len(self.col_bcs) > 0:
self._ybc = cofunction.Cofunction(test_space.dual())
self._ybc = Cofunction(test_space.dual())

# Get size information from template vecs on test and trial spaces
trial_vec = trial_space.dof_dset.layout_vec
Expand All @@ -135,6 +137,11 @@ def __init__(self, a, row_bcs=[], col_bcs=[],

self.action = action(self.a, self._x)
self.actionT = action(self.aT, self._y)
# TODO prevent action from returning empty Forms
if self.action.empty():
self.action = ZeroBaseForm(self.a.arguments()[:-1])
if self.actionT.empty():
self.actionT = ZeroBaseForm(self.aT.arguments()[:-1])

# For assembling action(f, self._x)
self.bcs_action = []
Expand Down Expand Up @@ -170,7 +177,6 @@ def __init__(self, a, row_bcs=[], col_bcs=[],

@cached_property
def _diagonal(self):
from firedrake import Cofunction
assert self.on_diag
return Cofunction(self._x.function_space().dual())

Expand Down
2 changes: 1 addition & 1 deletion firedrake/preconditioners/massinv.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class MassInvPC(AssembledPC):
context, keyed on ``"mu"``.
"""
def form(self, pc, test, trial):
_, bcs = super(MassInvPC, self).form(pc, test, trial)
_, bcs = super(MassInvPC, self).form(pc)

appctx = self.get_appctx(pc)
mu = appctx.get("mu", 1.0)
Expand Down
Loading
Loading