From 00b80e4806c6b734022c85464226bcdd417925be Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Tue, 31 Dec 2024 20:01:24 -0600 Subject: [PATCH] Replace empty Jacobians with ZeroBaseForm --- firedrake/adjoint_utils/variational_solver.py | 8 ++-- firedrake/assemble.py | 11 +++-- firedrake/formmanipulation.py | 45 ++++++++++--------- firedrake/matrix_free/operators.py | 16 +++++-- firedrake/preconditioners/massinv.py | 2 +- firedrake/solving_utils.py | 14 +++--- .../firedrake/slate/test_assemble_tensors.py | 11 ++--- 7 files changed, 61 insertions(+), 46 deletions(-) diff --git a/firedrake/adjoint_utils/variational_solver.py b/firedrake/adjoint_utils/variational_solver.py index c90d2668e0..79eb09096e 100644 --- a/firedrake/adjoint_utils/variational_solver.py +++ b/firedrake/adjoint_utils/variational_solver.py @@ -2,6 +2,7 @@ from functools import wraps from pyadjoint.tape import get_working_tape, stop_annotating, annotate_tape, no_annotations from firedrake.adjoint_utils.blocks import NonlinearVariationalSolveBlock +from firedrake.ufl_expr import derivative, adjoint from ufl import replace @@ -11,7 +12,6 @@ def _ad_annotate_init(init): @no_annotations @wraps(init) def wrapper(self, *args, **kwargs): - from firedrake import derivative, adjoint, TrialFunction init(self, *args, **kwargs) self._ad_F = self.F self._ad_u = self.u_restrict @@ -20,10 +20,8 @@ def wrapper(self, *args, **kwargs): try: # Some forms (e.g. SLATE tensors) are not currently # differentiable. - dFdu = derivative(self.F, - self.u_restrict, - TrialFunction(self.u_restrict.function_space())) - self._ad_adj_F = adjoint(dFdu) + dFdu = derivative(self.F, self.u_restrict) + self._ad_adj_F = adjoint(dFdu, derivatives_expanded=True) except (TypeError, NotImplementedError): self._ad_adj_F = None self._ad_kwargs = {'Jp': self.Jp, 'form_compiler_parameters': self.form_compiler_parameters, 'is_linear': self.is_linear} diff --git a/firedrake/assemble.py b/firedrake/assemble.py index f3049ae01c..60c934b6c7 100644 --- a/firedrake/assemble.py +++ b/firedrake/assemble.py @@ -577,10 +577,15 @@ def base_form_assembly_visitor(self, expr, tensor, *args): @staticmethod def update_tensor(assembled_base_form, tensor): if isinstance(tensor, (firedrake.Function, firedrake.Cofunction)): - assembled_base_form.dat.copy(tensor.dat) + if isinstance(assembled_base_form, ufl.ZeroBaseForm): + tensor.dat.zero() + else: + assembled_base_form.dat.copy(tensor.dat) elif isinstance(tensor, matrix.MatrixBase): - # Uses the PETSc copy method. - assembled_base_form.petscmat.copy(tensor.petscmat) + if isinstance(assembled_base_form, ufl.ZeroBaseForm): + tensor.petscmat.zero() + else: + assembled_base_form.petscmat.copy(tensor.petscmat) else: raise NotImplementedError("Cannot update tensor of type %s" % type(tensor)) diff --git a/firedrake/formmanipulation.py b/firedrake/formmanipulation.py index 35a6789107..fe3598c4e6 100644 --- a/firedrake/formmanipulation.py +++ b/firedrake/formmanipulation.py @@ -2,13 +2,23 @@ import numpy import collections -from ufl import as_vector +from ufl import as_vector, split from ufl.classes import Zero, FixedIndex, ListTensor from ufl.algorithms.map_integrands import map_integrand_dags from ufl.corealg.map_dag import MultiFunction, map_expr_dags from firedrake.petsc import PETSc from firedrake.ufl_expr import Argument +from firedrake.functionspace import MixedFunctionSpace, FunctionSpace + + +def subspace(V, indices): + if len(indices) == 1: + W = V[indices[0]] + W = FunctionSpace(W.mesh(), W.ufl_element()) + else: + W = MixedFunctionSpace([V[i] for i in indices]) + return W class ExtractSubBlock(MultiFunction): @@ -26,9 +36,11 @@ def indexed(self, o, child, multiindex): indices = multiindex.indices() if isinstance(child, ListTensor) and all(isinstance(i, FixedIndex) for i in indices): if len(indices) == 1: - return child.ufl_operands[indices[0]._value] + return child[indices[0]] + elif len(indices) == len(child.ufl_operands) and all(k == int(i) for k, i in enumerate(indices)): + return child else: - return ListTensor(*(child.ufl_operands[i._value] for i in multiindex.indices())) + return ListTensor(*(child[i] for i in indices)) return self.expr(o, child, multiindex) index_inliner = IndexInliner() @@ -85,8 +97,6 @@ def coefficient_derivative(self, o, expr, coefficients, arguments, cds): @PETSc.Log.EventDecorator() def argument(self, o): - from ufl import split - from firedrake import MixedFunctionSpace, FunctionSpace V = o.function_space() if len(V) == 1: # Not on a mixed space, just return ourselves. @@ -95,36 +105,29 @@ def argument(self, o): if o in self._arg_cache: return self._arg_cache[o] - V_is = V.subfunctions indices = self.blocks[o.number()] try: indices = tuple(indices) - nidx = len(indices) except TypeError: # Only one index provided. indices = (indices, ) - nidx = 1 - if nidx == 1: - W = V_is[indices[0]] - W = FunctionSpace(W.mesh(), W.ufl_element()) - a = (Argument(W, o.number(), part=o.part()), ) - else: - W = MixedFunctionSpace([V_is[i] for i in indices]) - a = split(Argument(W, o.number(), part=o.part())) + W = subspace(V, indices) + a = Argument(W, o.number(), part=o.part()) + a = (a, ) if len(indices) == 1 else split(a) + args = [] - for i in range(len(V_is)): + for i in range(len(V)): if i in indices: c = indices.index(i) a_ = a[c] if len(a_.ufl_shape) == 0: - args += [a_] + args.append(a_) else: - args += [a_[j] for j in numpy.ndindex(a_.ufl_shape)] + args.extend(a_[j] for j in numpy.ndindex(a_.ufl_shape)) else: - args += [Zero() - for j in numpy.ndindex(V_is[i].value_shape)] + args.extend(Zero() for j in numpy.ndindex(V[i].value_shape)) return self._arg_cache.setdefault(o, as_vector(args)) @@ -168,7 +171,7 @@ def split_form(form, diagonal=False): assert len(shape) == 2 for idx in numpy.ndindex(shape): f = splitter.split(form, idx) - if len(f.integrals()) > 0: + if not f.empty(): if diagonal: i, j = idx if i != j: diff --git a/firedrake/matrix_free/operators.py b/firedrake/matrix_free/operators.py index 3ee448730e..e3249c5d1c 100644 --- a/firedrake/matrix_free/operators.py +++ b/firedrake/matrix_free/operators.py @@ -5,11 +5,13 @@ import numpy from pyop2.mpi import internal_comm, temp_internal_comm -from firedrake.ufl_expr import adjoint, action -from firedrake.formmanipulation import ExtractSubBlock +from firedrake.ufl_expr import adjoint, action, TestFunction, TrialFunction +from firedrake.formmanipulation import ExtractSubBlock, subspace from firedrake.bcs import DirichletBC, EquationBCSplit from firedrake.petsc import PETSc from firedrake.utils import cached_property +from ufl.form import ZeroBaseForm +from ufl.algorithms import expand_derivatives __all__ = ("ImplicitMatrixContext", ) @@ -383,8 +385,14 @@ def createSubMatrix(self, mat, row_is, col_is, target=None): splitter = ExtractSubBlock() asub = splitter.split(self.a, argument_indices=(row_inds, col_inds)) - Wrow = asub.arguments()[0].function_space() - Wcol = asub.arguments()[1].function_space() + asub = expand_derivatives(asub) + if asub.empty(): + Wrow = subspace(self.a.arguments()[0].function_space(), row_inds) + Wcol = subspace(self.a.arguments()[1].function_space(), col_inds) + asub = ZeroBaseForm((TestFunction(Wrow), TrialFunction(Wcol))) + else: + Wrow = asub.arguments()[0].function_space() + Wcol = asub.arguments()[1].function_space() row_bcs = [] col_bcs = [] diff --git a/firedrake/preconditioners/massinv.py b/firedrake/preconditioners/massinv.py index 92f286c708..d29c704e8b 100644 --- a/firedrake/preconditioners/massinv.py +++ b/firedrake/preconditioners/massinv.py @@ -20,7 +20,7 @@ class MassInvPC(AssembledPC): context, keyed on ``"mu"``. """ def form(self, pc, test, trial): - _, bcs = super(MassInvPC, self).form(pc, test, trial) + _, bcs = super(MassInvPC, self).form(pc) appctx = self.get_appctx(pc) mu = appctx.get("mu", 1.0) diff --git a/firedrake/solving_utils.py b/firedrake/solving_utils.py index 9e843016b5..c067dce467 100644 --- a/firedrake/solving_utils.py +++ b/firedrake/solving_utils.py @@ -9,11 +9,13 @@ from firedrake.formmanipulation import ExtractSubBlock from firedrake.utils import cached_property from firedrake.logging import warning +from firedrake.ufl_expr import TestFunction, TrialFunction +from ufl.form import ZeroBaseForm def _make_reasons(reasons): - return dict([(getattr(reasons, r), r) - for r in dir(reasons) if not r.startswith('_')]) + return {getattr(reasons, r): r + for r in dir(reasons) if not r.startswith('_')} KSPReasons = _make_reasons(PETSc.KSP.ConvergedReason()) @@ -333,7 +335,7 @@ def split(self, fields): # Split it apart to shove in the form. subsplit = split(subu) # Permutation from field indexing to indexing of pieces - field_renumbering = dict([f, i] for i, f in enumerate(field)) + field_renumbering = {f: i for i, f in enumerate(field)} vec = [] for i, u in enumerate(us): if i in field: @@ -344,8 +346,7 @@ def split(self, fields): if u.ufl_shape == (): vec.append(u) else: - for idx in numpy.ndindex(u.ufl_shape): - vec.append(u[idx]) + vec.extend(u[idx] for idx in numpy.ndindex(u.ufl_shape)) # So now we have a new representation for the solution # vector in the old problem. For the fields we're going @@ -359,6 +360,9 @@ def split(self, fields): u = as_vector(vec) F = replace(F, {problem.u_restrict: u}) J = replace(J, {problem.u_restrict: u}) + if J.empty(): + # Handle zero Jacobian + J = ZeroBaseForm((TestFunction(V), TrialFunction(V))) if problem.Jp is not None: Jp = splitter.split(problem.Jp, argument_indices=(field, field)) Jp = replace(Jp, {problem.u_restrict: u}) diff --git a/tests/firedrake/slate/test_assemble_tensors.py b/tests/firedrake/slate/test_assemble_tensors.py index 5aff159b9b..b40d22ca2d 100644 --- a/tests/firedrake/slate/test_assemble_tensors.py +++ b/tests/firedrake/slate/test_assemble_tensors.py @@ -249,9 +249,10 @@ def test_matrix_subblocks(mesh): refs = dict(split_form(A.form)) _A = A.blocks for x, y in indices: - ref = assemble(refs[x, y]).M.values block = _A[x, y] - assert np.allclose(assemble(block).M.values, ref, rtol=1e-14) + if not block.form.empty(): + ref = assemble(refs[x, y]).M.values + assert np.allclose(assemble(block).M.values, ref, rtol=1e-14) # Mixed blocks A0101 = _A[:2, :2] @@ -267,17 +268,13 @@ def test_matrix_subblocks(mesh): A0101_10 = _A0101[1, 0] A1212_00 = _A1212[0, 0] A1212_11 = _A1212[1, 1] - A1212_01 = _A1212[0, 1] - A1212_10 = _A1212[1, 0] items = [(A0101_00, refs[(0, 0)]), (A0101_11, refs[(1, 1)]), (A0101_01, refs[(0, 1)]), (A0101_10, refs[(1, 0)]), (A1212_00, refs[(1, 1)]), - (A1212_11, refs[(2, 2)]), - (A1212_01, refs[(1, 2)]), - (A1212_10, refs[(2, 1)])] + (A1212_11, refs[(2, 2)])] # Test assembly of blocks of mixed blocks for tensor, form in items: