Skip to content

Commit

Permalink
Fieldsplit: replace empty Forms with ZeroBaseForm (#3947)
Browse files Browse the repository at this point in the history
* Restricted Cofunction RHS

* Fix BCs on Cofunction

* LinearSolver: check function spaces

* assemble(form, zero_bc_nodes=True) as default

* Fix FunctionAssignBlock

* Allow Cofunction.assign take in constants

* remove BaseFormAssembler test

* only supply relevant kwargs to OneFormAssembler

* Only interpolate the residual, not every cofunction in the RHS

* Fix tests

* Fix adjoint utils

* More robust test for (unrestricted) Cofunction RHS

* Replace empty Jacobians with ZeroBaseForm

* Do not split off-diagonal blocks if we only want the diagonal

* Zero-simplify slate Tensors

* set bcs directly on diagonal Cofunction

* ImplicitMatrixContext: handle empty action

* Only extract constants referenced in the kernel

* Adjoint: only skip expand_derivatives if necessary

* EquationBC: do not reconstruct empty Forms

* lower degree for EquationBC tests

* Update .github/workflows/build.yml
  • Loading branch information
pbrubeck authored Jan 15, 2025
1 parent ad9fe2c commit 9c5ec2f
Show file tree
Hide file tree
Showing 27 changed files with 310 additions and 320 deletions.
3 changes: 1 addition & 2 deletions demos/netgen/netgen_mesh.py.rst
Original file line number Diff line number Diff line change
Expand Up @@ -380,8 +380,7 @@ We will now show how to solve the Poisson problem on a high-order mesh, of order

bc = DirichletBC(V, 0.0, [1])
A = assemble(a, bcs=bc)
b = assemble(l)
bc.apply(b)
b = assemble(l, bcs=bc)
solve(A, sol, b, solver_parameters={"ksp_type": "cg", "pc_type": "lu"})

VTKFile("output/Sphere.pvd").write(sol)
Expand Down
6 changes: 3 additions & 3 deletions firedrake/adjoint_utils/blocks/dirichlet_bc.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx,
adj_output = None
for adj_input in adj_inputs:
if isconstant(c):
adj_value = firedrake.Function(self.parent_space.dual())
adj_value = firedrake.Function(self.parent_space)
adj_input.apply(adj_value)
if self.function_space != self.parent_space:
vec = extract_bc_subvector(
Expand Down Expand Up @@ -88,11 +88,11 @@ def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx,
# you can even use the Function outside its domain.
# For now we will just assume the FunctionSpace is the same for
# the BC and the Function.
adj_value = firedrake.Function(self.parent_space.dual())
adj_value = firedrake.Function(self.parent_space)
adj_input.apply(adj_value)
r = extract_bc_subvector(
adj_value, c.function_space(), bc
)
).riesz_representation("l2")
if adj_output is None:
adj_output = r
else:
Expand Down
1 change: 1 addition & 0 deletions firedrake/adjoint_utils/blocks/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx,
)
diff_expr_assembled = firedrake.Function(adj_input_func.function_space())
diff_expr_assembled.interpolate(ufl.conj(diff_expr))
diff_expr_assembled = diff_expr_assembled.riesz_representation(riesz_map="l2")
adj_output = firedrake.Function(
R, val=firedrake.assemble(ufl.Action(diff_expr_assembled, adj_input_func))
)
Expand Down
26 changes: 12 additions & 14 deletions firedrake/adjoint_utils/blocks/solving.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,14 +197,12 @@ def _assemble_dFdu_adj(self, dFdu_adj_form, **kwargs):

def _assemble_and_solve_adj_eq(self, dFdu_adj_form, dJdu, compute_bdy):
dJdu_copy = dJdu.copy()
kwargs = self.assemble_kwargs.copy()
# Homogenize and apply boundary conditions on adj_dFdu and dJdu.
bcs = self._homogenize_bcs()
kwargs["bcs"] = bcs
dFdu = self._assemble_dFdu_adj(dFdu_adj_form, **kwargs)
dFdu = firedrake.assemble(dFdu_adj_form, bcs=bcs, **self.assemble_kwargs)

for bc in bcs:
bc.apply(dJdu)
bc.zero(dJdu)

adj_sol = firedrake.Function(self.function_space)
firedrake.solve(
Expand All @@ -219,10 +217,8 @@ def _assemble_and_solve_adj_eq(self, dFdu_adj_form, dJdu, compute_bdy):
return adj_sol, adj_sol_bdy

def _compute_adj_bdy(self, adj_sol, adj_sol_bdy, dFdu_adj_form, dJdu):
adj_sol_bdy = firedrake.Function(
self.function_space.dual(), dJdu.dat - firedrake.assemble(
firedrake.action(dFdu_adj_form, adj_sol)).dat)
return adj_sol_bdy
adj_sol_bdy = firedrake.assemble(dJdu - firedrake.action(dFdu_adj_form, adj_sol))
return adj_sol_bdy.riesz_representation("l2")

def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx,
prepared=None):
Expand Down Expand Up @@ -264,8 +260,11 @@ def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx,
return dFdm

dFdm = -firedrake.derivative(F_form, c_rep, trial_function)
dFdm = firedrake.adjoint(dFdm)
dFdm = dFdm * adj_sol
if isinstance(dFdm, ufl.Form):
dFdm = firedrake.adjoint(dFdm)
dFdm = firedrake.action(dFdm, adj_sol)
else:
dFdm = dFdm(adj_sol)
dFdm = firedrake.assemble(dFdm, **self.assemble_kwargs)
return dFdm

Expand Down Expand Up @@ -654,9 +653,8 @@ def _forward_solve(self, lhs, rhs, func, bcs, **kwargs):
def _adjoint_solve(self, dJdu, compute_bdy):
dJdu_copy = dJdu.copy()
# Homogenize and apply boundary conditions on adj_dFdu and dJdu.
bcs = self._homogenize_bcs()
for bc in bcs:
bc.apply(dJdu)
for bc in self.bcs:
bc.zero(dJdu)

if (
self._ad_solvers["forward_nlvs"]._problem._constant_jacobian
Expand Down Expand Up @@ -876,7 +874,7 @@ def __init__(self, source, target_space, target, bcs=[], **kwargs):
self.add_dependency(bc, no_duplicates=True)

def apply_mixedmass(self, a):
b = firedrake.Function(self.target_space)
b = firedrake.Function(self.target_space.dual())
with a.dat.vec_ro as vsrc, b.dat.vec_wo as vrhs:
self.mixed_mass.mult(vsrc, vrhs)
return b
Expand Down
13 changes: 8 additions & 5 deletions firedrake/adjoint_utils/variational_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from functools import wraps
from pyadjoint.tape import get_working_tape, stop_annotating, annotate_tape, no_annotations
from firedrake.adjoint_utils.blocks import NonlinearVariationalSolveBlock
from firedrake.ufl_expr import derivative, adjoint
from ufl import replace


Expand All @@ -11,7 +12,6 @@ def _ad_annotate_init(init):
@no_annotations
@wraps(init)
def wrapper(self, *args, **kwargs):
from firedrake import derivative, adjoint, TrialFunction
init(self, *args, **kwargs)
self._ad_F = self.F
self._ad_u = self.u_restrict
Expand All @@ -20,10 +20,13 @@ def wrapper(self, *args, **kwargs):
try:
# Some forms (e.g. SLATE tensors) are not currently
# differentiable.
dFdu = derivative(self.F,
self.u_restrict,
TrialFunction(self.u_restrict.function_space()))
self._ad_adj_F = adjoint(dFdu)
dFdu = derivative(self.F, self.u_restrict)
try:
self._ad_adj_F = adjoint(dFdu)
except ValueError:
# Try again without expanding derivatives,
# as dFdu might have been simplied to an empty Form
self._ad_adj_F = adjoint(dFdu, derivatives_expanded=True)
except (TypeError, NotImplementedError):
self._ad_adj_F = None
self._ad_kwargs = {'Jp': self.Jp, 'form_compiler_parameters': self.form_compiler_parameters, 'is_linear': self.is_linear}
Expand Down
75 changes: 44 additions & 31 deletions firedrake/assemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def assemble(expr, *args, **kwargs):
zero_bc_nodes : bool
If `True`, set the boundary condition nodes in the
output tensor to zero rather than to the values prescribed by the
boundary condition. Default is `False`.
boundary condition. Default is `True`.
diagonal : bool
If assembling a matrix is it diagonal?
weight : float
Expand Down Expand Up @@ -143,7 +143,6 @@ def get_assembler(form, *args, **kwargs):
"""
is_base_form_preprocessed = kwargs.pop('is_base_form_preprocessed', False)
bcs = kwargs.get('bcs', None)
fc_params = kwargs.get('form_compiler_parameters', None)
if isinstance(form, ufl.form.BaseForm) and not is_base_form_preprocessed:
mat_type = kwargs.get('mat_type', None)
Expand All @@ -155,8 +154,13 @@ def get_assembler(form, *args, **kwargs):
if len(form.arguments()) == 0:
return ZeroFormAssembler(form, form_compiler_parameters=fc_params)
elif len(form.arguments()) == 1 or diagonal:
return OneFormAssembler(form, *args, bcs=bcs, form_compiler_parameters=fc_params, needs_zeroing=kwargs.get('needs_zeroing', True),
zero_bc_nodes=kwargs.get('zero_bc_nodes', False), diagonal=diagonal)
return OneFormAssembler(form, *args,
bcs=kwargs.get("bcs", None),
form_compiler_parameters=fc_params,
needs_zeroing=kwargs.get("needs_zeroing", True),
zero_bc_nodes=kwargs.get("zero_bc_nodes", True),
diagonal=diagonal,
weight=kwargs.get("weight", 1.0))
elif len(form.arguments()) == 2:
return TwoFormAssembler(form, *args, **kwargs)
else:
Expand Down Expand Up @@ -308,7 +312,7 @@ def __init__(self,
sub_mat_type=None,
options_prefix=None,
appctx=None,
zero_bc_nodes=False,
zero_bc_nodes=True,
diagonal=False,
weight=1.0,
allocation_integral_types=None):
Expand Down Expand Up @@ -381,6 +385,12 @@ def visitor(e, *operands):
visited = {}
result = BaseFormAssembler.base_form_postorder_traversal(self._form, visitor, visited)

# Apply BCs after assembly
rank = len(self._form.arguments())
if rank == 1 and not isinstance(result, ufl.ZeroBaseForm):
for bc in self._bcs:
bc.zero(result)

if tensor:
BaseFormAssembler.update_tensor(result, tensor)
return tensor
Expand All @@ -405,8 +415,8 @@ def base_form_assembly_visitor(self, expr, tensor, *args):
if rank == 0:
assembler = ZeroFormAssembler(form, form_compiler_parameters=self._form_compiler_params)
elif rank == 1 or (rank == 2 and self._diagonal):
assembler = OneFormAssembler(form, bcs=self._bcs, form_compiler_parameters=self._form_compiler_params,
zero_bc_nodes=self._zero_bc_nodes, diagonal=self._diagonal)
assembler = OneFormAssembler(form, form_compiler_parameters=self._form_compiler_params,
zero_bc_nodes=self._zero_bc_nodes, diagonal=self._diagonal, weight=self._weight)
elif rank == 2:
assembler = TwoFormAssembler(form, bcs=self._bcs, form_compiler_parameters=self._form_compiler_params,
mat_type=self._mat_type, sub_mat_type=self._sub_mat_type,
Expand Down Expand Up @@ -577,10 +587,15 @@ def base_form_assembly_visitor(self, expr, tensor, *args):
@staticmethod
def update_tensor(assembled_base_form, tensor):
if isinstance(tensor, (firedrake.Function, firedrake.Cofunction)):
assembled_base_form.dat.copy(tensor.dat)
if isinstance(assembled_base_form, ufl.ZeroBaseForm):
tensor.dat.zero()
else:
assembled_base_form.dat.copy(tensor.dat)
elif isinstance(tensor, matrix.MatrixBase):
# Uses the PETSc copy method.
assembled_base_form.petscmat.copy(tensor.petscmat)
if isinstance(assembled_base_form, ufl.ZeroBaseForm):
tensor.petscmat.zeroEntries()
else:
assembled_base_form.petscmat.copy(tensor.petscmat)
else:
raise NotImplementedError("Cannot update tensor of type %s" % type(tensor))

Expand Down Expand Up @@ -807,9 +822,9 @@ def restructure_base_form(expr, visited=None):
return ufl.action(expr, ustar)

# -- Case (6) -- #
if isinstance(expr, ufl.FormSum) and all(isinstance(c, ufl.core.base_form_operator.BaseFormOperator) for c in expr.components()):
# Return ufl.Sum
return sum([c for c in expr.components()])
if isinstance(expr, ufl.FormSum) and all(ufl.duals.is_dual(a.function_space()) for a in expr.arguments()):
# Return ufl.Sum if we are assembling a FormSum with Coarguments (a primal expression)
return sum(w*c for w, c in zip(expr.weights(), expr.components()))
return expr

@staticmethod
Expand Down Expand Up @@ -1138,7 +1153,7 @@ class OneFormAssembler(ParloopFormAssembler):
Parameters
----------
form : ufl.Form or slate.TensorBasehe
form : ufl.Form or slate.TensorBase
1-form.
Notes
Expand All @@ -1149,14 +1164,15 @@ class OneFormAssembler(ParloopFormAssembler):

@classmethod
def _cache_key(cls, form, bcs=None, form_compiler_parameters=None, needs_zeroing=True,
zero_bc_nodes=False, diagonal=False):
zero_bc_nodes=True, diagonal=False, weight=1.0):
bcs = solving._extract_bcs(bcs)
return tuple(bcs), tuplify(form_compiler_parameters), needs_zeroing, zero_bc_nodes, diagonal
return tuple(bcs), tuplify(form_compiler_parameters), needs_zeroing, zero_bc_nodes, diagonal, weight

@FormAssembler._skip_if_initialised
def __init__(self, form, bcs=None, form_compiler_parameters=None, needs_zeroing=True,
zero_bc_nodes=False, diagonal=False):
zero_bc_nodes=True, diagonal=False, weight=1.0):
super().__init__(form, bcs=bcs, form_compiler_parameters=form_compiler_parameters, needs_zeroing=needs_zeroing)
self._weight = weight
self._diagonal = diagonal
self._zero_bc_nodes = zero_bc_nodes
if self._diagonal and any(isinstance(bc, EquationBCSplit) for bc in self._bcs):
Expand Down Expand Up @@ -1185,23 +1201,21 @@ def _apply_bc(self, tensor, bc):
elif isinstance(bc, EquationBCSplit):
bc.zero(tensor)
type(self)(bc.f, bcs=bc.bcs, form_compiler_parameters=self._form_compiler_params, needs_zeroing=False,
zero_bc_nodes=self._zero_bc_nodes, diagonal=self._diagonal).assemble(tensor=tensor)
zero_bc_nodes=self._zero_bc_nodes, diagonal=self._diagonal, weight=self._weight).assemble(tensor=tensor)
else:
raise AssertionError

def _apply_dirichlet_bc(self, tensor, bc):
if not self._zero_bc_nodes:
tensor_func = tensor.riesz_representation(riesz_map="l2")
if self._diagonal:
bc.set(tensor_func, 1)
else:
bc.apply(tensor_func)
tensor.assign(tensor_func.riesz_representation(riesz_map="l2"))
if self._diagonal:
bc.set(tensor, self._weight)
elif not self._zero_bc_nodes:
# NOTE this only works if tensor is a Function and not a Cofunction
bc.apply(tensor)
else:
bc.zero(tensor)

def _check_tensor(self, tensor):
if tensor.function_space() != self._form.arguments()[0].function_space():
if tensor.function_space() != self._form.arguments()[0].function_space().dual():
raise ValueError("Form's argument does not match provided result tensor")

@staticmethod
Expand Down Expand Up @@ -2127,14 +2141,13 @@ def iter_active_coefficients(form, kinfo):

@staticmethod
def iter_constants(form, kinfo):
"""Yield the form constants"""
"""Yield the form constants referenced in ``kinfo``."""
if isinstance(form, slate.TensorBase):
for const in form.constants():
yield const
all_constants = form.constants()
else:
all_constants = extract_firedrake_constants(form)
for constant_index in kinfo.constant_numbers:
yield all_constants[constant_index]
for constant_index in kinfo.constant_numbers:
yield all_constants[constant_index]

@staticmethod
def index_function_spaces(form, indices):
Expand Down
8 changes: 4 additions & 4 deletions firedrake/bcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -634,10 +634,10 @@ def reconstruct(self, field=None, V=None, subu=None, u=None, row_field=None, col
return
rank = len(self.f.arguments())
splitter = ExtractSubBlock()
if rank == 1:
form = splitter.split(self.f, argument_indices=(row_field, ))
elif rank == 2:
form = splitter.split(self.f, argument_indices=(row_field, col_field))
form = splitter.split(self.f, argument_indices=(row_field, col_field)[:rank])
if isinstance(form, ufl.ZeroBaseForm) or form.empty():
# form is empty, do nothing
return
if u is not None:
form = firedrake.replace(form, {self.u: u})
if action_x is not None:
Expand Down
6 changes: 4 additions & 2 deletions firedrake/cofunction.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,8 +229,10 @@ def assign(self, expr, subset=None, expr_from_assemble=False):
return self.assign(
assembled_expr, subset=subset,
expr_from_assemble=True)

raise ValueError('Cannot assign %s' % expr)
else:
from firedrake.assign import Assigner
Assigner(self, expr, subset).assign()
return self

def riesz_representation(self, riesz_map='L2', **solver_options):
"""Return the Riesz representation of this :class:`Cofunction` with respect to the given Riesz map.
Expand Down
Loading

0 comments on commit 9c5ec2f

Please sign in to comment.