From 22865961201f2a68cab66ddcaea14f0742c5f071 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Sun, 29 Dec 2024 22:03:29 -0600 Subject: [PATCH 01/11] DO NOT MERGE --- .github/workflows/build.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 0eb616c24d..fa437c82e5 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -84,6 +84,7 @@ jobs: --install defcon \ --install gadopt \ --install asQ \ + --package-branch ufl pbrubeck/simplify-indexed \ || (cat firedrake-install.log && /bin/false) - name: Install test dependencies run: | From bb04bb00860b9f697df226b1ec1ce35f35e4c2bb Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Tue, 31 Dec 2024 20:01:24 -0600 Subject: [PATCH 02/11] Replace empty Jacobians with ZeroBaseForm --- firedrake/adjoint_utils/variational_solver.py | 8 +-- firedrake/assemble.py | 11 ++- firedrake/formmanipulation.py | 67 +++++++++++-------- firedrake/preconditioners/massinv.py | 2 +- firedrake/solving_utils.py | 9 ++- firedrake/tsfc_interface.py | 4 +- .../firedrake/slate/test_assemble_tensors.py | 17 +++-- 7 files changed, 70 insertions(+), 48 deletions(-) diff --git a/firedrake/adjoint_utils/variational_solver.py b/firedrake/adjoint_utils/variational_solver.py index c90d2668e0..79eb09096e 100644 --- a/firedrake/adjoint_utils/variational_solver.py +++ b/firedrake/adjoint_utils/variational_solver.py @@ -2,6 +2,7 @@ from functools import wraps from pyadjoint.tape import get_working_tape, stop_annotating, annotate_tape, no_annotations from firedrake.adjoint_utils.blocks import NonlinearVariationalSolveBlock +from firedrake.ufl_expr import derivative, adjoint from ufl import replace @@ -11,7 +12,6 @@ def _ad_annotate_init(init): @no_annotations @wraps(init) def wrapper(self, *args, **kwargs): - from firedrake import derivative, adjoint, TrialFunction init(self, *args, **kwargs) self._ad_F = self.F self._ad_u = self.u_restrict @@ -20,10 +20,8 @@ def wrapper(self, *args, **kwargs): try: # Some forms (e.g. SLATE tensors) are not currently # differentiable. - dFdu = derivative(self.F, - self.u_restrict, - TrialFunction(self.u_restrict.function_space())) - self._ad_adj_F = adjoint(dFdu) + dFdu = derivative(self.F, self.u_restrict) + self._ad_adj_F = adjoint(dFdu, derivatives_expanded=True) except (TypeError, NotImplementedError): self._ad_adj_F = None self._ad_kwargs = {'Jp': self.Jp, 'form_compiler_parameters': self.form_compiler_parameters, 'is_linear': self.is_linear} diff --git a/firedrake/assemble.py b/firedrake/assemble.py index f3049ae01c..60c934b6c7 100644 --- a/firedrake/assemble.py +++ b/firedrake/assemble.py @@ -577,10 +577,15 @@ def base_form_assembly_visitor(self, expr, tensor, *args): @staticmethod def update_tensor(assembled_base_form, tensor): if isinstance(tensor, (firedrake.Function, firedrake.Cofunction)): - assembled_base_form.dat.copy(tensor.dat) + if isinstance(assembled_base_form, ufl.ZeroBaseForm): + tensor.dat.zero() + else: + assembled_base_form.dat.copy(tensor.dat) elif isinstance(tensor, matrix.MatrixBase): - # Uses the PETSc copy method. - assembled_base_form.petscmat.copy(tensor.petscmat) + if isinstance(assembled_base_form, ufl.ZeroBaseForm): + tensor.petscmat.zero() + else: + assembled_base_form.petscmat.copy(tensor.petscmat) else: raise NotImplementedError("Cannot update tensor of type %s" % type(tensor)) diff --git a/firedrake/formmanipulation.py b/firedrake/formmanipulation.py index 35a6789107..3179961df8 100644 --- a/firedrake/formmanipulation.py +++ b/firedrake/formmanipulation.py @@ -2,13 +2,29 @@ import numpy import collections -from ufl import as_vector +from ufl import as_vector, split, ZeroBaseForm from ufl.classes import Zero, FixedIndex, ListTensor from ufl.algorithms.map_integrands import map_integrand_dags +from ufl.algorithms import expand_derivatives from ufl.corealg.map_dag import MultiFunction, map_expr_dags from firedrake.petsc import PETSc from firedrake.ufl_expr import Argument +from firedrake.functionspace import MixedFunctionSpace, FunctionSpace + + +def subspace(V, indices): + try: + indices = tuple(indices) + except TypeError: + # Only one index provided. + indices = (indices, ) + if len(indices) == 1: + W = V[indices[0]] + W = FunctionSpace(W.mesh(), W.ufl_element()) + else: + W = MixedFunctionSpace([V[i] for i in indices]) + return W class ExtractSubBlock(MultiFunction): @@ -26,9 +42,11 @@ def indexed(self, o, child, multiindex): indices = multiindex.indices() if isinstance(child, ListTensor) and all(isinstance(i, FixedIndex) for i in indices): if len(indices) == 1: - return child.ufl_operands[indices[0]._value] + return child[indices[0]] + elif len(indices) == len(child.ufl_operands) and all(k == int(i) for k, i in enumerate(indices)): + return child else: - return ListTensor(*(child.ufl_operands[i._value] for i in multiindex.indices())) + return ListTensor(*(child[i] for i in indices)) return self.expr(o, child, multiindex) index_inliner = IndexInliner() @@ -57,6 +75,11 @@ def split(self, form, argument_indices): assert (idx[0] == 0 for idx in self.blocks.values()) return form f = map_integrand_dags(self, form) + f = expand_derivatives(f) + if f.empty(): + f = ZeroBaseForm(tuple(Argument(subspace(arg.function_space(), indices), + arg.number(), part=arg.part()) + for arg, indices in zip(form.arguments(), argument_indices))) return f expr = MultiFunction.reuse_if_untouched @@ -85,8 +108,6 @@ def coefficient_derivative(self, o, expr, coefficients, arguments, cds): @PETSc.Log.EventDecorator() def argument(self, o): - from ufl import split - from firedrake import MixedFunctionSpace, FunctionSpace V = o.function_space() if len(V) == 1: # Not on a mixed space, just return ourselves. @@ -95,36 +116,29 @@ def argument(self, o): if o in self._arg_cache: return self._arg_cache[o] - V_is = V.subfunctions indices = self.blocks[o.number()] try: indices = tuple(indices) - nidx = len(indices) except TypeError: # Only one index provided. indices = (indices, ) - nidx = 1 - if nidx == 1: - W = V_is[indices[0]] - W = FunctionSpace(W.mesh(), W.ufl_element()) - a = (Argument(W, o.number(), part=o.part()), ) - else: - W = MixedFunctionSpace([V_is[i] for i in indices]) - a = split(Argument(W, o.number(), part=o.part())) + W = subspace(V, indices) + a = Argument(W, o.number(), part=o.part()) + a = (a, ) if len(W) == 1 else split(a) + args = [] - for i in range(len(V_is)): + for i in range(len(V)): if i in indices: c = indices.index(i) a_ = a[c] if len(a_.ufl_shape) == 0: - args += [a_] + args.append(a_) else: - args += [a_[j] for j in numpy.ndindex(a_.ufl_shape)] + args.extend(a_[j] for j in numpy.ndindex(a_.ufl_shape)) else: - args += [Zero() - for j in numpy.ndindex(V_is[i].value_shape)] + args.extend(Zero() for j in numpy.ndindex(V[i].value_shape)) return self._arg_cache.setdefault(o, as_vector(args)) @@ -168,11 +182,10 @@ def split_form(form, diagonal=False): assert len(shape) == 2 for idx in numpy.ndindex(shape): f = splitter.split(form, idx) - if len(f.integrals()) > 0: - if diagonal: - i, j = idx - if i != j: - continue - idx = (i, ) - forms.append(SplitForm(indices=idx, form=f)) + if diagonal: + i, j = idx + if i != j: + continue + idx = (i, ) + forms.append(SplitForm(indices=idx, form=f)) return tuple(forms) diff --git a/firedrake/preconditioners/massinv.py b/firedrake/preconditioners/massinv.py index 92f286c708..d29c704e8b 100644 --- a/firedrake/preconditioners/massinv.py +++ b/firedrake/preconditioners/massinv.py @@ -20,7 +20,7 @@ class MassInvPC(AssembledPC): context, keyed on ``"mu"``. """ def form(self, pc, test, trial): - _, bcs = super(MassInvPC, self).form(pc, test, trial) + _, bcs = super(MassInvPC, self).form(pc) appctx = self.get_appctx(pc) mu = appctx.get("mu", 1.0) diff --git a/firedrake/solving_utils.py b/firedrake/solving_utils.py index 9e843016b5..789a6f1880 100644 --- a/firedrake/solving_utils.py +++ b/firedrake/solving_utils.py @@ -12,8 +12,8 @@ def _make_reasons(reasons): - return dict([(getattr(reasons, r), r) - for r in dir(reasons) if not r.startswith('_')]) + return {getattr(reasons, r): r + for r in dir(reasons) if not r.startswith('_')} KSPReasons = _make_reasons(PETSc.KSP.ConvergedReason()) @@ -333,7 +333,7 @@ def split(self, fields): # Split it apart to shove in the form. subsplit = split(subu) # Permutation from field indexing to indexing of pieces - field_renumbering = dict([f, i] for i, f in enumerate(field)) + field_renumbering = {f: i for i, f in enumerate(field)} vec = [] for i, u in enumerate(us): if i in field: @@ -344,8 +344,7 @@ def split(self, fields): if u.ufl_shape == (): vec.append(u) else: - for idx in numpy.ndindex(u.ufl_shape): - vec.append(u[idx]) + vec.extend(u[idx] for idx in numpy.ndindex(u.ufl_shape)) # So now we have a new representation for the solution # vector in the old problem. For the fields we're going diff --git a/firedrake/tsfc_interface.py b/firedrake/tsfc_interface.py index ba10d79507..1117f54bd4 100644 --- a/firedrake/tsfc_interface.py +++ b/firedrake/tsfc_interface.py @@ -11,7 +11,7 @@ import ufl import finat.ufl -from ufl import Form, conj +from ufl import conj, Form, ZeroBaseForm from .ufl_expr import TestFunction from tsfc import compile_form as original_tsfc_compile_form @@ -203,7 +203,7 @@ def compile_form(form, name, parameters=None, split=True, interface=None, diagon iterable = ([(None, )*nargs, form], ) for idx, f in iterable: f = _real_mangle(f) - if not f.integrals(): + if isinstance(f, ZeroBaseForm) or f.empty(): # If we're assembling the R space component of a mixed argument, # and that component doesn't actually appear in the form then we # have an empty form, which we should not attempt to assemble. diff --git a/tests/firedrake/slate/test_assemble_tensors.py b/tests/firedrake/slate/test_assemble_tensors.py index 5aff159b9b..c35d43e27e 100644 --- a/tests/firedrake/slate/test_assemble_tensors.py +++ b/tests/firedrake/slate/test_assemble_tensors.py @@ -249,9 +249,13 @@ def test_matrix_subblocks(mesh): refs = dict(split_form(A.form)) _A = A.blocks for x, y in indices: - ref = assemble(refs[x, y]).M.values block = _A[x, y] - assert np.allclose(assemble(block).M.values, ref, rtol=1e-14) + ref = refs[x, y] + if isinstance(ref, Form): + assert np.allclose(assemble(block).M.values, + assemble(ref).M.values, rtol=1e-14) + elif isinstance(ref, ZeroBaseForm): + assert block.form == ref # Mixed blocks A0101 = _A[:2, :2] @@ -280,9 +284,12 @@ def test_matrix_subblocks(mesh): (A1212_10, refs[(2, 1)])] # Test assembly of blocks of mixed blocks - for tensor, form in items: - ref = assemble(form).M.values - assert np.allclose(assemble(tensor).M.values, ref, rtol=1e-14) + for block, ref in items: + if isinstance(ref, Form): + assert np.allclose(assemble(block).M.values, + assemble(ref).M.values, rtol=1e-14) + elif isinstance(ref, ZeroBaseForm): + assert block.form == ref def test_diagonal(mass, matrix_mixed_nofacet): From af53302c7c9c3eeb185887e779b6cca0bead02e2 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Thu, 2 Jan 2025 10:07:05 -0600 Subject: [PATCH 03/11] Do not split off-diagonal blocks if we only want the diagonal --- firedrake/formmanipulation.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/firedrake/formmanipulation.py b/firedrake/formmanipulation.py index 708adfd8e5..114651f793 100644 --- a/firedrake/formmanipulation.py +++ b/firedrake/formmanipulation.py @@ -196,14 +196,15 @@ def split_form(form, diagonal=False): args = form.arguments() shape = tuple(len(a.function_space()) for a in args) forms = [] + arity = len(shape) if diagonal: - assert len(shape) == 2 + assert arity == 2 + arity = 1 for idx in numpy.ndindex(shape): - f = splitter.split(form, idx) if diagonal: i, j = idx if i != j: continue - idx = (i, ) - forms.append(SplitForm(indices=idx, form=f)) + f = splitter.split(form, idx) + forms.append(SplitForm(indices=idx[:arity], form=f)) return tuple(forms) From 7f40504b440d8735c414f5f919c083be877e8da7 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Thu, 2 Jan 2025 20:29:13 -0600 Subject: [PATCH 04/11] Zero-simplify slate Tensors --- firedrake/slate/slac/tsfc_driver.py | 1 + firedrake/slate/slate.py | 45 ++++++++++++++++++++++------- 2 files changed, 36 insertions(+), 10 deletions(-) diff --git a/firedrake/slate/slac/tsfc_driver.py b/firedrake/slate/slac/tsfc_driver.py index 0f5fbf96d3..136b2a0084 100644 --- a/firedrake/slate/slac/tsfc_driver.py +++ b/firedrake/slate/slac/tsfc_driver.py @@ -50,6 +50,7 @@ def compile_terminal_form(tensor, prefix, *, tsfc_parameters=None): assert tensor.terminal, ( "Only terminal tensors have forms associated with them!" ) + # Sets a default name for the subkernel prefix. mapper = RemoveRestrictions() integrals = map(partial(map_integrand_dags, mapper), diff --git a/firedrake/slate/slate.py b/firedrake/slate/slate.py index 85b7af3635..fd9535c31a 100644 --- a/firedrake/slate/slate.py +++ b/firedrake/slate/slate.py @@ -32,7 +32,7 @@ from ufl.corealg.multifunction import MultiFunction from ufl.classes import Zero from ufl.domain import join_domains, sort_domains -from ufl.form import Form +from ufl.form import Form, ZeroBaseForm import hashlib from firedrake.formmanipulation import ExtractSubBlock @@ -237,7 +237,7 @@ def coeff_map(self): coeff_map[m].update(c.indices[0]) else: m = self.coefficients().index(c) - split_map = tuple(range(len(c.subfunctions))) if isinstance(c, Function) or isinstance(c, Constant) or isinstance(c, Cofunction) else tuple(range(1)) + split_map = tuple(range(len(c.subfunctions))) if isinstance(c, (Function, Constant, Cofunction)) else (0,) coeff_map[m].update(split_map) return tuple((k, tuple(sorted(v)))for k, v in coeff_map.items()) @@ -382,6 +382,10 @@ def __eq__(self, other): """Determines whether two TensorBase objects are equal using their associated keys. """ + if isinstance(other, (int, float)) and other == 0: + if isinstance(self, Tensor): + return isinstance(self.form, ZeroBaseForm) or self.form.empty() + return False return self._key == other._key def __ne__(self, other): @@ -650,7 +654,7 @@ def __init__(self, tensor, indices): """Constructor for the Block class.""" super(Block, self).__init__() self.operands = (tensor,) - self._blocks = dict(enumerate(indices)) + self._blocks = dict(enumerate(map(as_tuple, indices))) self._indices = indices @cached_property @@ -671,14 +675,12 @@ def _split_arguments(self): nargs = [] for i, arg in enumerate(tensor.arguments()): V = arg.function_space() - V_is = V.subfunctions - idx = as_tuple(self._blocks[i]) + idx = self._blocks[i] if len(idx) == 1: - fidx, = idx - W = V_is[fidx] + W = V[idx[0]] W = FunctionSpace(W.mesh(), W.ufl_element()) else: - W = MixedFunctionSpace([V_is[fidx] for fidx in idx]) + W = MixedFunctionSpace([V[fidx] for fidx in idx]) nargs.append(Argument(W, arg.number(), part=arg.part())) @@ -880,7 +882,7 @@ class Tensor(TensorBase): def __init__(self, form, diagonal=False): """Constructor for the Tensor class.""" - if not isinstance(form, Form): + if not isinstance(form, (Form, ZeroBaseForm)): if isinstance(form, Function): raise TypeError("Use AssembledVector instead of Tensor.") raise TypeError("Only UFL forms are acceptable inputs.") @@ -1103,6 +1105,10 @@ def _output_string(self, prec=None): class Transpose(UnaryOp): """An abstract Slate class representing the transpose of a tensor.""" + def __new__(cls, A): + if A == 0: + return Tensor(ZeroBaseForm(A.form.arguments()[::-1])) + return BinaryOp.__new__(cls) @cached_property def arg_function_spaces(self): @@ -1127,6 +1133,10 @@ def _output_string(self, prec=None): class Negative(UnaryOp): """Abstract Slate class representing the negation of a tensor object.""" + def __new__(cls, A): + if A == 0: + return A + return BinaryOp.__new__(cls) @cached_property def arg_function_spaces(self): @@ -1197,6 +1207,12 @@ class Add(BinaryOp): :arg A: a :class:`~.firedrake.slate.TensorBase` object. :arg B: another :class:`~.firedrake.slate.TensorBase` object. """ + def __new__(cls, A, B): + if A == 0: + return B + elif B == 0: + return A + return BinaryOp.__new__(cls) def __init__(self, A, B): """Constructor for the Add class.""" @@ -1238,6 +1254,10 @@ class Mul(BinaryOp): :arg A: a :class:`~.firedrake.slate.TensorBase` object. :arg B: another :class:`~.firedrake.slate.TensorBase` object. """ + def __new__(cls, A, B): + if A == 0 or B == 0: + return Tensor(ZeroBaseForm(A.arguments()[:-1] + B.arguments()[1:])) + return BinaryOp.__new__(cls) def __init__(self, A, B): """Constructor for the Mul class.""" @@ -1295,7 +1315,7 @@ def __new__(cls, A, B, decomposition=None): raise ValueError("Illegal op on a %s-tensor with a %s-tensor." % (A.shape, B.shape)) - fsA = A.arg_function_spaces[::-1][-1] + fsA = A.arg_function_spaces[0] fsB = B.arg_function_spaces[0] assert space_equivalence(fsA, fsB), ( @@ -1348,6 +1368,11 @@ class DiagonalTensor(UnaryOp): """ diagonal = True + def __new__(cls, A): + if A == 0: + return Tensor(ZeroBaseForm(A.arguments()[:1])) + return BinaryOp.__new__(cls) + def __init__(self, A): """Constructor for the Diagonal class.""" assert A.rank == 2, "The tensor must be rank 2." From 3d06fc56e58302c9ff179eb9890986fe9e4f50ee Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Fri, 3 Jan 2025 11:26:48 -0600 Subject: [PATCH 05/11] ImplicitMatrixContext: handle empty action --- firedrake/assemble.py | 13 ++++++++----- firedrake/matrix_free/operators.py | 28 +++++++++++++++++----------- firedrake/slate/slate.py | 19 +++++++++++-------- 3 files changed, 36 insertions(+), 24 deletions(-) diff --git a/firedrake/assemble.py b/firedrake/assemble.py index 60c934b6c7..88d00c6db8 100644 --- a/firedrake/assemble.py +++ b/firedrake/assemble.py @@ -311,7 +311,8 @@ def __init__(self, zero_bc_nodes=False, diagonal=False, weight=1.0, - allocation_integral_types=None): + allocation_integral_types=None, + needs_zeroing=False): super().__init__(form, bcs=bcs, form_compiler_parameters=form_compiler_parameters) self._mat_type = mat_type self._sub_mat_type = sub_mat_type @@ -321,6 +322,7 @@ def __init__(self, self._diagonal = diagonal self._weight = weight self._allocation_integral_types = allocation_integral_types + assert not needs_zeroing def allocate(self): rank = len(self._form.arguments()) @@ -1127,7 +1129,8 @@ def _apply_bc(self, tensor, bc): pass def _check_tensor(self, tensor): - pass + if not isinstance(tensor, op2.Global): + raise TypeError(f"Expecting a op2.Global, got {tensor!r}.") @staticmethod def _as_pyop2_type(tensor, indices=None): @@ -1143,7 +1146,7 @@ class OneFormAssembler(ParloopFormAssembler): Parameters ---------- - form : ufl.Form or slate.TensorBasehe + form : ufl.Form or slate.TensorBase 1-form. Notes @@ -1189,8 +1192,8 @@ def _apply_bc(self, tensor, bc): self._apply_dirichlet_bc(tensor, bc) elif isinstance(bc, EquationBCSplit): bc.zero(tensor) - type(self)(bc.f, bcs=bc.bcs, form_compiler_parameters=self._form_compiler_params, needs_zeroing=False, - zero_bc_nodes=self._zero_bc_nodes, diagonal=self._diagonal).assemble(tensor=tensor) + get_assembler(bc.f, bcs=bc.bcs, form_compiler_parameters=self._form_compiler_params, needs_zeroing=False, + zero_bc_nodes=self._zero_bc_nodes, diagonal=self._diagonal).assemble(tensor=tensor) else: raise AssertionError diff --git a/firedrake/matrix_free/operators.py b/firedrake/matrix_free/operators.py index 3ee448730e..b111cbde76 100644 --- a/firedrake/matrix_free/operators.py +++ b/firedrake/matrix_free/operators.py @@ -10,6 +10,9 @@ from firedrake.bcs import DirichletBC, EquationBCSplit from firedrake.petsc import PETSc from firedrake.utils import cached_property +from firedrake.function import Function +from firedrake.cofunction import Cofunction +from ufl.form import ZeroBaseForm __all__ = ("ImplicitMatrixContext", ) @@ -107,23 +110,22 @@ def __init__(self, a, row_bcs=[], col_bcs=[], # create functions from test and trial space to help # with 1-form assembly - test_space, trial_space = [ - a.arguments()[i].function_space() for i in (0, 1) - ] - from firedrake import function, cofunction + test_space, trial_space = ( + arg.function_space() for arg in a.arguments() + ) # Need a cofunction since y receives the assembled result of Ax - self._ystar = cofunction.Cofunction(test_space.dual()) - self._y = function.Function(test_space) - self._x = function.Function(trial_space) - self._xstar = cofunction.Cofunction(trial_space.dual()) + self._ystar = Cofunction(test_space.dual()) + self._y = Function(test_space) + self._x = Function(trial_space) + self._xstar = Cofunction(trial_space.dual()) # These are temporary storage for holding the BC # values during matvec application. _xbc is for # the action and ._ybc is for transpose. if len(self.bcs) > 0: - self._xbc = cofunction.Cofunction(trial_space.dual()) + self._xbc = Cofunction(trial_space.dual()) if len(self.col_bcs) > 0: - self._ybc = cofunction.Cofunction(test_space.dual()) + self._ybc = Cofunction(test_space.dual()) # Get size information from template vecs on test and trial spaces trial_vec = trial_space.dof_dset.layout_vec @@ -135,6 +137,11 @@ def __init__(self, a, row_bcs=[], col_bcs=[], self.action = action(self.a, self._x) self.actionT = action(self.aT, self._y) + # TODO prevent action from returning empty Forms + if self.action.empty(): + self.action = ZeroBaseForm(self.a.arguments()[:-1]) + if self.actionT.empty(): + self.actionT = ZeroBaseForm(self.aT.arguments()[:-1]) # For assembling action(f, self._x) self.bcs_action = [] @@ -170,7 +177,6 @@ def __init__(self, a, row_bcs=[], col_bcs=[], @cached_property def _diagonal(self): - from firedrake import Cofunction assert self.on_diag return Cofunction(self._x.function_space().dual()) diff --git a/firedrake/slate/slate.py b/firedrake/slate/slate.py index fd9535c31a..1a8c792414 100644 --- a/firedrake/slate/slate.py +++ b/firedrake/slate/slate.py @@ -21,7 +21,10 @@ from ufl import Constant from ufl.coefficient import BaseCoefficient +from firedrake.formmanipulation import ExtractSubBlock from firedrake.function import Function, Cofunction +from firedrake.functionspace import FunctionSpace, MixedFunctionSpace +from firedrake.ufl_expr import Argument, TestFunction from firedrake.utils import cached_property, unique from itertools import chain, count @@ -35,8 +38,6 @@ from ufl.form import Form, ZeroBaseForm import hashlib -from firedrake.formmanipulation import ExtractSubBlock - from tsfc.ufl_utils import extract_firedrake_constants @@ -293,6 +294,10 @@ def solve(self, B, decomposition=None): """ return Solve(self, B, decomposition=decomposition) + def empty(self): + """Returns whether the form associated with the tensor is empty.""" + return False + @cached_property def blocks(self): """Returns an object containing the blocks of the tensor defined @@ -461,8 +466,6 @@ def arg_function_spaces(self): @cached_property def _argument(self): """Generates a 'test function' associated with this class.""" - from firedrake.ufl_expr import TestFunction - V, = self.arg_function_spaces return TestFunction(V) @@ -543,7 +546,6 @@ def arg_function_spaces(self): @cached_property def _argument(self): """Generates a tuple of 'test function' associated with this class.""" - from firedrake.ufl_expr import TestFunction return tuple(TestFunction(fs) for fs in self.arg_function_spaces) def arguments(self): @@ -668,9 +670,6 @@ def _split_arguments(self): """Splits the function space and stores the component spaces determined by the indices. """ - from firedrake.functionspace import FunctionSpace, MixedFunctionSpace - from firedrake.ufl_expr import Argument - tensor, = self.operands nargs = [] for i, arg in enumerate(tensor.arguments()): @@ -938,6 +937,10 @@ def subdomain_data(self): """ return self.form.subdomain_data() + def empty(self): + """Returns whether the form associated with the tensor is empty.""" + return self.form.empty() + def _output_string(self, prec=None): """Creates a string representation of the tensor.""" return ["S", "V", "M"][self.rank] + "_%d" % self.id From 6078f93243c1655d923c6aa7fc81e2fdeffd0e8d Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Fri, 3 Jan 2025 20:27:42 -0600 Subject: [PATCH 06/11] Only extract constants referenced in the kernel --- firedrake/assemble.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/firedrake/assemble.py b/firedrake/assemble.py index 88d00c6db8..dafe2a32ab 100644 --- a/firedrake/assemble.py +++ b/firedrake/assemble.py @@ -585,7 +585,7 @@ def update_tensor(assembled_base_form, tensor): assembled_base_form.dat.copy(tensor.dat) elif isinstance(tensor, matrix.MatrixBase): if isinstance(assembled_base_form, ufl.ZeroBaseForm): - tensor.petscmat.zero() + tensor.petscmat.zeroEntries() else: assembled_base_form.petscmat.copy(tensor.petscmat) else: @@ -2135,14 +2135,13 @@ def iter_active_coefficients(form, kinfo): @staticmethod def iter_constants(form, kinfo): - """Yield the form constants""" + """Yield the form constants referenced in ``kinfo``.""" if isinstance(form, slate.TensorBase): - for const in form.constants(): - yield const + all_constants = form.constants() else: all_constants = extract_firedrake_constants(form) - for constant_index in kinfo.constant_numbers: - yield all_constants[constant_index] + for constant_index in kinfo.constant_numbers: + yield all_constants[constant_index] @staticmethod def index_function_spaces(form, indices): From 5894b490cecb51532dab35d13b243fa8b4507f44 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Fri, 3 Jan 2025 20:37:40 -0600 Subject: [PATCH 07/11] Adjoint: only skip expand_derivatives if necessary --- firedrake/adjoint_utils/variational_solver.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/firedrake/adjoint_utils/variational_solver.py b/firedrake/adjoint_utils/variational_solver.py index 79eb09096e..c191308adc 100644 --- a/firedrake/adjoint_utils/variational_solver.py +++ b/firedrake/adjoint_utils/variational_solver.py @@ -21,7 +21,12 @@ def wrapper(self, *args, **kwargs): # Some forms (e.g. SLATE tensors) are not currently # differentiable. dFdu = derivative(self.F, self.u_restrict) - self._ad_adj_F = adjoint(dFdu, derivatives_expanded=True) + try: + self._ad_adj_F = adjoint(dFdu) + except ValueError: + # Try again without expanding derivatives, + # as dFdu might have been simplied to an empty Form + self._ad_adj_F = adjoint(dFdu, derivatives_expanded=True) except (TypeError, NotImplementedError): self._ad_adj_F = None self._ad_kwargs = {'Jp': self.Jp, 'form_compiler_parameters': self.form_compiler_parameters, 'is_linear': self.is_linear} From d99ba50b42874f0e61381c5c6ff679064287616b Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Sat, 4 Jan 2025 17:48:35 -0600 Subject: [PATCH 08/11] style --- firedrake/formmanipulation.py | 15 ++++++--------- firedrake/slate/slac/tsfc_driver.py | 1 - 2 files changed, 6 insertions(+), 10 deletions(-) diff --git a/firedrake/formmanipulation.py b/firedrake/formmanipulation.py index 114651f793..5e92bd8e8a 100644 --- a/firedrake/formmanipulation.py +++ b/firedrake/formmanipulation.py @@ -144,11 +144,8 @@ def cofunction(self, o): # Not on a mixed space, just return ourselves. return o - try: - indices, = set(self.blocks.values()) - except ValueError: - raise ValueError("Cofunction found on an off-diagonal block") - + # We only need the test space for Cofunction  + indices = self.blocks[0] if len(indices) == 1: i = indices[0] W = V[i] @@ -196,15 +193,15 @@ def split_form(form, diagonal=False): args = form.arguments() shape = tuple(len(a.function_space()) for a in args) forms = [] - arity = len(shape) + rank = len(shape) if diagonal: - assert arity == 2 - arity = 1 + assert rank == 2 + rank = 1 for idx in numpy.ndindex(shape): if diagonal: i, j = idx if i != j: continue f = splitter.split(form, idx) - forms.append(SplitForm(indices=idx[:arity], form=f)) + forms.append(SplitForm(indices=idx[:rank], form=f)) return tuple(forms) diff --git a/firedrake/slate/slac/tsfc_driver.py b/firedrake/slate/slac/tsfc_driver.py index 136b2a0084..0f5fbf96d3 100644 --- a/firedrake/slate/slac/tsfc_driver.py +++ b/firedrake/slate/slac/tsfc_driver.py @@ -50,7 +50,6 @@ def compile_terminal_form(tensor, prefix, *, tsfc_parameters=None): assert tensor.terminal, ( "Only terminal tensors have forms associated with them!" ) - # Sets a default name for the subkernel prefix. mapper = RemoveRestrictions() integrals = map(partial(map_integrand_dags, mapper), From d6bb7dd76ae1fed2a6b4ff6ccb5195516d104869 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Sat, 4 Jan 2025 20:54:36 -0600 Subject: [PATCH 09/11] EquationBC: do not reconstruct empty Forms --- firedrake/assemble.py | 11 ++++------- firedrake/bcs.py | 8 ++++---- 2 files changed, 8 insertions(+), 11 deletions(-) diff --git a/firedrake/assemble.py b/firedrake/assemble.py index dafe2a32ab..61909e9955 100644 --- a/firedrake/assemble.py +++ b/firedrake/assemble.py @@ -311,8 +311,7 @@ def __init__(self, zero_bc_nodes=False, diagonal=False, weight=1.0, - allocation_integral_types=None, - needs_zeroing=False): + allocation_integral_types=None): super().__init__(form, bcs=bcs, form_compiler_parameters=form_compiler_parameters) self._mat_type = mat_type self._sub_mat_type = sub_mat_type @@ -322,7 +321,6 @@ def __init__(self, self._diagonal = diagonal self._weight = weight self._allocation_integral_types = allocation_integral_types - assert not needs_zeroing def allocate(self): rank = len(self._form.arguments()) @@ -1129,8 +1127,7 @@ def _apply_bc(self, tensor, bc): pass def _check_tensor(self, tensor): - if not isinstance(tensor, op2.Global): - raise TypeError(f"Expecting a op2.Global, got {tensor!r}.") + pass @staticmethod def _as_pyop2_type(tensor, indices=None): @@ -1192,8 +1189,8 @@ def _apply_bc(self, tensor, bc): self._apply_dirichlet_bc(tensor, bc) elif isinstance(bc, EquationBCSplit): bc.zero(tensor) - get_assembler(bc.f, bcs=bc.bcs, form_compiler_parameters=self._form_compiler_params, needs_zeroing=False, - zero_bc_nodes=self._zero_bc_nodes, diagonal=self._diagonal).assemble(tensor=tensor) + type(self)(bc.f, bcs=bc.bcs, form_compiler_parameters=self._form_compiler_params, needs_zeroing=False, + zero_bc_nodes=self._zero_bc_nodes, diagonal=self._diagonal).assemble(tensor=tensor) else: raise AssertionError diff --git a/firedrake/bcs.py b/firedrake/bcs.py index f0d007ede4..7c6821b3e3 100644 --- a/firedrake/bcs.py +++ b/firedrake/bcs.py @@ -634,10 +634,10 @@ def reconstruct(self, field=None, V=None, subu=None, u=None, row_field=None, col return rank = len(self.f.arguments()) splitter = ExtractSubBlock() - if rank == 1: - form = splitter.split(self.f, argument_indices=(row_field, )) - elif rank == 2: - form = splitter.split(self.f, argument_indices=(row_field, col_field)) + form = splitter.split(self.f, argument_indices=(row_field, col_field)[:rank]) + if form == 0: + # form is empty, do nothing + return if u is not None: form = firedrake.replace(form, {self.u: u}) if action_x is not None: From ed584675ee3c795ae4b6ea6606dadb745515cc03 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Mon, 6 Jan 2025 17:33:43 -0600 Subject: [PATCH 10/11] lower degree for EquationBC tests --- .../equation_bcs/test_equation_bcs.py | 80 +++++++++---------- 1 file changed, 40 insertions(+), 40 deletions(-) diff --git a/tests/firedrake/equation_bcs/test_equation_bcs.py b/tests/firedrake/equation_bcs/test_equation_bcs.py index 087b07aa36..3929eeaddb 100644 --- a/tests/firedrake/equation_bcs/test_equation_bcs.py +++ b/tests/firedrake/equation_bcs/test_equation_bcs.py @@ -17,15 +17,13 @@ def nonlinear_poisson(solver_parameters, mesh_num, porder): u = Function(V) v = TestFunction(V) - f = Function(V) x, y = SpatialCoordinate(mesh) - f.interpolate(- 8.0 * pi * pi * cos(x * pi * 2) * cos(y * pi * 2)) + f = - 8.0 * pi * pi * cos(x * pi * 2) * cos(y * pi * 2) a = - inner(grad(u), grad(v)) * dx L = inner(f, v) * dx - g = Function(V) - g.interpolate(cos(2 * pi * x) * cos(2 * pi * y)) + g = cos(2 * pi * x) * cos(2 * pi * y) # Equivalent to bc1 = EquationBC(v * (u - g1) * ds(1) == 0, u, 1) e2 = as_vector([0., 1.]) @@ -33,7 +31,7 @@ def nonlinear_poisson(solver_parameters, mesh_num, porder): solve(a - L == 0, u, bcs=[bc1], solver_parameters=solver_parameters) - f.interpolate(cos(x * pi * 2) * cos(y * pi * 2)) + f = cos(x * pi * 2) * cos(y * pi * 2) return sqrt(assemble(inner(u - f, u - f) * dx)) @@ -46,15 +44,13 @@ def linear_poisson(solver_parameters, mesh_num, porder): u = TrialFunction(V) v = TestFunction(V) - f = Function(V) x, y = SpatialCoordinate(mesh) - f.interpolate(- 8.0 * pi * pi * cos(x * pi * 2) * cos(y * pi * 2)) + f = - 8.0 * pi * pi * cos(x * pi * 2) * cos(y * pi * 2) a = - inner(grad(u), grad(v)) * dx L = inner(f, v) * dx - g = Function(V) - g.interpolate(cos(2 * pi * x) * cos(2 * pi * y)) + g = cos(2 * pi * x) * cos(2 * pi * y) u_ = Function(V) @@ -62,7 +58,7 @@ def linear_poisson(solver_parameters, mesh_num, porder): solve(a == L, u_, bcs=[bc1], solver_parameters=solver_parameters) - f.interpolate(cos(x * pi * 2) * cos(y * pi * 2)) + f = cos(x * pi * 2) * cos(y * pi * 2) return sqrt(assemble(inner(u_ - f, u_ - f) * dx)) @@ -75,9 +71,8 @@ def nonlinear_poisson_bbc(solver_parameters, mesh_num, porder): u = Function(V) v = TestFunction(V) - f = Function(V) x, y = SpatialCoordinate(mesh) - f.interpolate(- 8.0 * pi * pi * cos(x * pi * 2)*cos(y * pi * 2)) + f = - 8.0 * pi * pi * cos(x * pi * 2)*cos(y * pi * 2) a = - inner(grad(u), grad(v)) * dx L = inner(f, v) * dx @@ -85,13 +80,13 @@ def nonlinear_poisson_bbc(solver_parameters, mesh_num, porder): e2 = as_vector([0., 1.]) a1 = (-inner(dot(grad(u), e2), dot(grad(v), e2)) + 4 * pi * pi * inner(u, v)) * ds(1) - g = Function(V).interpolate(cos(2 * pi * x) * cos(2 * pi * y)) + g = cos(2 * pi * x) * cos(2 * pi * y) bbc = DirichletBC(V, g, ((1, 3), (1, 4))) bc1 = EquationBC(a1 == 0, u, 1, bcs=[bbc]) solve(a - L == 0, u, bcs=[bc1], solver_parameters=solver_parameters) - f.interpolate(cos(x * pi * 2) * cos(y * pi * 2)) + f = cos(x * pi * 2) * cos(y * pi * 2) return sqrt(assemble(inner(u - f, u - f) * dx)) @@ -104,9 +99,8 @@ def linear_poisson_bbc(solver_parameters, mesh_num, porder): u = TrialFunction(V) v = TestFunction(V) - f = Function(V) x, y = SpatialCoordinate(mesh) - f.interpolate(- 8.0 * pi * pi * cos(x * pi * 2)*cos(y * pi * 2)) + f = - 8.0 * pi * pi * cos(x * pi * 2)*cos(y * pi * 2) a = - inner(grad(u), grad(v)) * dx L = inner(f, v) * dx @@ -117,13 +111,13 @@ def linear_poisson_bbc(solver_parameters, mesh_num, porder): u = Function(V) - g = Function(V).interpolate(cos(2 * pi * x) * cos(2 * pi * y)) + g = cos(2 * pi * x) * cos(2 * pi * y) bbc = DirichletBC(V, g, ((1, 3), (1, 4))) bc1 = EquationBC(a1 == L1, u, 1, bcs=[bbc]) solve(a == L, u, bcs=[bc1], solver_parameters=solver_parameters) - f.interpolate(cos(x * pi * 2)*cos(y * pi * 2)) + f = cos(x * pi * 2)*cos(y * pi * 2) return sqrt(assemble(inner(u - f, u - f) * dx)) @@ -141,22 +135,25 @@ def nonlinear_poisson_mixed(solver_parameters, mesh_num, porder): n = FacetNormal(mesh) x, y = SpatialCoordinate(mesh) - f = Function(DG).interpolate(-8 * pi * pi * cos(2 * pi * x + pi / 3) * cos(2 * pi * y)) - u1 = Function(DG).interpolate(cos(2 * pi * y) / 2) + f = -8 * pi * pi * cos(2 * pi * x + pi / 3) * cos(2 * pi * y) + u1 = cos(2 * pi * y) / 2 - a = (inner(sigma, tau) + inner(u, div(tau)) + inner(div(sigma), v)) * dx + a = inner(sigma, tau) * dx + inner(u, div(tau)) * dx + inner(div(sigma), v) * dx L = inner(u1, dot(tau, n)) * ds(1) + inner(f, v) * dx - g = Function(BDM).project(as_vector([-2 * pi * sin(2 * pi * x + pi / 3) * cos(2 * pi * y), -2 * pi * cos(2 * pi * x + pi / 3) * sin(2 * pi * y)])) + g = as_vector([-2 * pi * sin(2 * pi * x + pi / 3) * cos(2 * pi * y), -2 * pi * cos(2 * pi * x + pi / 3) * sin(2 * pi * y)]) - bc2 = EquationBC(inner((dot(sigma, n) - dot(g, n)), dot(tau, n)) * ds(2) == 0, w, 2, V=W.sub(0)) - bc3 = EquationBC(inner((dot(sigma, n) - dot(g, n)), dot(tau, n)) * ds(3) == 0, w, 3, V=W.sub(0)) + tau_n = dot(tau, n) + sig_n = dot(sigma, n) + g_n = dot(g, n) + bc2 = EquationBC(inner(sig_n - g_n, tau_n) * ds(2) == 0, w, 2, V=W.sub(0)) + bc3 = EquationBC(inner(sig_n - g_n, tau_n) * ds(3) == 0, w, 3, V=W.sub(0)) bc4 = DirichletBC(W.sub(0), g, 4) solve(a - L == 0, w, bcs=[bc2, bc3, bc4], solver_parameters=solver_parameters) - f.interpolate(cos(2 * pi * x + pi / 3) * cos(2 * pi * y)) - g = Function(BDM).project(as_vector([-2 * pi * sin(2 * pi * x + pi / 3) * cos(2 * pi * y), -2 * pi * cos(2 * pi * x + pi / 3) * sin(2 * pi * y)])) + f = cos(2 * pi * x + pi / 3) * cos(2 * pi * y) + g = as_vector([-2 * pi * sin(2 * pi * x + pi / 3) * cos(2 * pi * y), -2 * pi * cos(2 * pi * x + pi / 3) * sin(2 * pi * y)]) return sqrt(assemble(inner(u - f, u - f) * dx)), sqrt(assemble(inner(sigma - g, sigma - g) * dx)) @@ -173,28 +170,31 @@ def linear_poisson_mixed(solver_parameters, mesh_num, porder): tau, v = TestFunctions(W) x, y = SpatialCoordinate(mesh) - f = Function(DG).interpolate(-8 * pi * pi * cos(2 * pi * x + pi / 3) * cos(2 * pi * y)) - u1 = Function(DG).interpolate(cos(2 * pi * y) / 2) + f = -8 * pi * pi * cos(2 * pi * x + pi / 3) * cos(2 * pi * y) + u1 = cos(2 * pi * y) / 2 n = FacetNormal(mesh) a = (inner(sigma, tau) + inner(u, div(tau)) + inner(div(sigma), v)) * dx L = inner(u1, dot(tau, n)) * ds(1) + inner(f, v) * dx - g = Function(BDM).project(as_vector([-2 * pi * sin(2 * pi * x + pi / 3) * cos(2 * pi * y), -2 * pi * cos(2 * pi * x + pi / 3) * sin(2 * pi * y)])) + g = as_vector([-2 * pi * sin(2 * pi * x + pi / 3) * cos(2 * pi * y), -2 * pi * cos(2 * pi * x + pi / 3) * sin(2 * pi * y)]) w = Function(W) - bc2 = EquationBC(inner(n, tau) * inner(sigma, n) * ds(2) == inner(n, tau) * inner(g, n) * ds(2), w, 2, V=W.sub(0)) - bc3 = EquationBC(inner(n, tau) * inner(sigma, n) * ds(3) == inner(n, tau) * inner(g, n) * ds(3), w, 3, V=W.sub(0)) + tau_n = dot(tau, n) + sig_n = dot(sigma, n) + g_n = dot(g, n) + bc2 = EquationBC(inner(sig_n, tau_n) * ds(2) == inner(g_n, tau_n) * ds(2), w, 2, V=W.sub(0)) + bc3 = EquationBC(inner(sig_n, tau_n) * ds(3) == inner(g_n, tau_n) * ds(3), w, 3, V=W.sub(0)) bc4 = DirichletBC(W.sub(0), g, 4) solve(a == L, w, bcs=[bc2, bc3, bc4], solver_parameters=solver_parameters) - sigma, u = w.subfunctions - f.interpolate(cos(2 * pi * x + pi / 3) * cos(2 * pi * y)) - g = Function(BDM).project(as_vector([-2 * pi * sin(2 * pi * x + pi / 3) * cos(2 * pi * y), -2 * pi * cos(2 * pi * x + pi / 3) * sin(2 * pi * y)])) + f = cos(2 * pi * x + pi / 3) * cos(2 * pi * y) + g = as_vector([-2 * pi * sin(2 * pi * x + pi / 3) * cos(2 * pi * y), -2 * pi * cos(2 * pi * x + pi / 3) * sin(2 * pi * y)]) + sigma, u = w.subfunctions return sqrt(assemble(inner(u - f, u - f) * dx)), sqrt(assemble(inner(sigma - g, sigma - g) * dx)) @@ -202,7 +202,7 @@ def linear_poisson_mixed(solver_parameters, mesh_num, porder): @pytest.mark.parametrize("with_bbc", [False, True]) def test_EquationBC_poisson_matrix(eq_type, with_bbc): mat_type = "aij" - porder = 3 + porder = 2 # Test standard poisson with EquationBCs # aij @@ -235,7 +235,7 @@ def test_EquationBC_poisson_matrix(eq_type, with_bbc): def test_EquationBC_poisson_matfree(with_bbc): eq_type = "linear" mat_type = "matfree" - porder = 3 + porder = 2 # Test standard poisson with EquationBCs # matfree @@ -271,7 +271,7 @@ def test_EquationBC_poisson_matfree(with_bbc): @pytest.mark.parametrize("eq_type", ["linear", "nonlinear"]) def test_EquationBC_mixedpoisson_matrix(eq_type): mat_type = "aij" - porder = 2 + porder = 0 # Mixed poisson with EquationBCs # aij @@ -294,7 +294,7 @@ def test_EquationBC_mixedpoisson_matrix(eq_type): def test_EquationBC_mixedpoisson_matrix_fieldsplit(): mat_type = "aij" eq_type = "linear" - porder = 2 + porder = 0 # Mixed poisson with EquationBCs # aij with fieldsplit pc @@ -324,7 +324,7 @@ def test_EquationBC_mixedpoisson_matrix_fieldsplit(): def test_EquationBC_mixedpoisson_matfree_fieldsplit(): mat_type = "matfree" eq_type = "linear" - porder = 2 + porder = 0 # Mixed poisson with EquationBCs # matfree with fieldsplit pc @@ -366,7 +366,7 @@ def test_equation_bcs_pc(): v, w = split(TestFunction(V)) x, y = SpatialCoordinate(mesh) exact = cos(2 * pi * x) * cos(2 * pi * y) - g = Function(CG).interpolate(8 * pi**2 * exact) + g = 8 * pi**2 * exact F = inner(grad(u), grad(v)) * dx + inner(l, w) * dx - inner(g, v) * dx bc = EquationBC(inner((u - exact), v) * ds == 0, f, (1, 2, 3, 4), V=V.sub(0)) params = { From 2a0c03b244ef9f641fe959d026225d49dd007a14 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Mon, 6 Jan 2025 17:35:03 -0600 Subject: [PATCH 11/11] style --- firedrake/bcs.py | 2 +- firedrake/formmanipulation.py | 2 +- tests/firedrake/equation_bcs/test_equation_bcs.py | 1 - 3 files changed, 2 insertions(+), 3 deletions(-) diff --git a/firedrake/bcs.py b/firedrake/bcs.py index 7c6821b3e3..5884907feb 100644 --- a/firedrake/bcs.py +++ b/firedrake/bcs.py @@ -635,7 +635,7 @@ def reconstruct(self, field=None, V=None, subu=None, u=None, row_field=None, col rank = len(self.f.arguments()) splitter = ExtractSubBlock() form = splitter.split(self.f, argument_indices=(row_field, col_field)[:rank]) - if form == 0: + if isinstance(form, ufl.ZeroBaseForm) or form.empty(): # form is empty, do nothing return if u is not None: diff --git a/firedrake/formmanipulation.py b/firedrake/formmanipulation.py index 5e92bd8e8a..97f2b3c43e 100644 --- a/firedrake/formmanipulation.py +++ b/firedrake/formmanipulation.py @@ -73,8 +73,8 @@ def split(self, form, argument_indices): assert (len(idx) == 1 for idx in self.blocks.values()) assert (idx[0] == 0 for idx in self.blocks.values()) return form - f = map_integrand_dags(self, form) # TODO find a way to distinguish empty Forms avoiding expand_derivatives + f = map_integrand_dags(self, form) if expand_derivatives(f).empty(): # Get ZeroBaseForm with the right shape f = ZeroBaseForm(tuple(Argument(subspace(arg.function_space(), diff --git a/tests/firedrake/equation_bcs/test_equation_bcs.py b/tests/firedrake/equation_bcs/test_equation_bcs.py index 3929eeaddb..fdd05b7f2e 100644 --- a/tests/firedrake/equation_bcs/test_equation_bcs.py +++ b/tests/firedrake/equation_bcs/test_equation_bcs.py @@ -190,7 +190,6 @@ def linear_poisson_mixed(solver_parameters, mesh_num, porder): solve(a == L, w, bcs=[bc2, bc3, bc4], solver_parameters=solver_parameters) - f = cos(2 * pi * x + pi / 3) * cos(2 * pi * y) g = as_vector([-2 * pi * sin(2 * pi * x + pi / 3) * cos(2 * pi * y), -2 * pi * cos(2 * pi * x + pi / 3) * sin(2 * pi * y)])