Skip to content

Commit 730e299

Browse files
committed
ImplicitMatrixContext: handle empty action
1 parent 7f40504 commit 730e299

File tree

2 files changed

+24
-16
lines changed

2 files changed

+24
-16
lines changed

firedrake/assemble.py

+8-5
Original file line numberDiff line numberDiff line change
@@ -311,7 +311,8 @@ def __init__(self,
311311
zero_bc_nodes=False,
312312
diagonal=False,
313313
weight=1.0,
314-
allocation_integral_types=None):
314+
allocation_integral_types=None,
315+
needs_zeroing=False):
315316
super().__init__(form, bcs=bcs, form_compiler_parameters=form_compiler_parameters)
316317
self._mat_type = mat_type
317318
self._sub_mat_type = sub_mat_type
@@ -321,6 +322,7 @@ def __init__(self,
321322
self._diagonal = diagonal
322323
self._weight = weight
323324
self._allocation_integral_types = allocation_integral_types
325+
assert not needs_zeroing
324326

325327
def allocate(self):
326328
rank = len(self._form.arguments())
@@ -1127,7 +1129,8 @@ def _apply_bc(self, tensor, bc):
11271129
pass
11281130

11291131
def _check_tensor(self, tensor):
1130-
pass
1132+
if not isinstance(tensor, op2.Global):
1133+
raise TypeError(f"Expecting a op2.Global, got {tensor!r}.")
11311134

11321135
@staticmethod
11331136
def _as_pyop2_type(tensor, indices=None):
@@ -1143,7 +1146,7 @@ class OneFormAssembler(ParloopFormAssembler):
11431146
11441147
Parameters
11451148
----------
1146-
form : ufl.Form or slate.TensorBasehe
1149+
form : ufl.Form or slate.TensorBase
11471150
1-form.
11481151
11491152
Notes
@@ -1189,8 +1192,8 @@ def _apply_bc(self, tensor, bc):
11891192
self._apply_dirichlet_bc(tensor, bc)
11901193
elif isinstance(bc, EquationBCSplit):
11911194
bc.zero(tensor)
1192-
type(self)(bc.f, bcs=bc.bcs, form_compiler_parameters=self._form_compiler_params, needs_zeroing=False,
1193-
zero_bc_nodes=self._zero_bc_nodes, diagonal=self._diagonal).assemble(tensor=tensor)
1195+
get_assembler(bc.f, bcs=bc.bcs, form_compiler_parameters=self._form_compiler_params, needs_zeroing=False,
1196+
zero_bc_nodes=self._zero_bc_nodes, diagonal=self._diagonal).assemble(tensor=tensor)
11941197
else:
11951198
raise AssertionError
11961199

firedrake/matrix_free/operators.py

+16-11
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
from firedrake.bcs import DirichletBC, EquationBCSplit
1111
from firedrake.petsc import PETSc
1212
from firedrake.utils import cached_property
13+
from firedrake.function import Function
14+
from firedrake.cofunction import Cofunction
1315

1416

1517
__all__ = ("ImplicitMatrixContext", )
@@ -107,23 +109,22 @@ def __init__(self, a, row_bcs=[], col_bcs=[],
107109

108110
# create functions from test and trial space to help
109111
# with 1-form assembly
110-
test_space, trial_space = [
111-
a.arguments()[i].function_space() for i in (0, 1)
112-
]
113-
from firedrake import function, cofunction
112+
test_space, trial_space = (
113+
arg.function_space() for arg in a.arguments()
114+
)
114115
# Need a cofunction since y receives the assembled result of Ax
115-
self._ystar = cofunction.Cofunction(test_space.dual())
116-
self._y = function.Function(test_space)
117-
self._x = function.Function(trial_space)
118-
self._xstar = cofunction.Cofunction(trial_space.dual())
116+
self._ystar = Cofunction(test_space.dual())
117+
self._y = Function(test_space)
118+
self._x = Function(trial_space)
119+
self._xstar = Cofunction(trial_space.dual())
119120

120121
# These are temporary storage for holding the BC
121122
# values during matvec application. _xbc is for
122123
# the action and ._ybc is for transpose.
123124
if len(self.bcs) > 0:
124-
self._xbc = cofunction.Cofunction(trial_space.dual())
125+
self._xbc = Cofunction(trial_space.dual())
125126
if len(self.col_bcs) > 0:
126-
self._ybc = cofunction.Cofunction(test_space.dual())
127+
self._ybc = Cofunction(test_space.dual())
127128

128129
# Get size information from template vecs on test and trial spaces
129130
trial_vec = trial_space.dof_dset.layout_vec
@@ -135,6 +136,11 @@ def __init__(self, a, row_bcs=[], col_bcs=[],
135136

136137
self.action = action(self.a, self._x)
137138
self.actionT = action(self.aT, self._y)
139+
# TODO prevent action from returning empty Forms
140+
if self.action.empty():
141+
self.action = Cofunction(test_space.dual())
142+
if self.actionT.empty():
143+
self.action = Cofunction(trial_space.dual())
138144

139145
# For assembling action(f, self._x)
140146
self.bcs_action = []
@@ -170,7 +176,6 @@ def __init__(self, a, row_bcs=[], col_bcs=[],
170176

171177
@cached_property
172178
def _diagonal(self):
173-
from firedrake import Cofunction
174179
assert self.on_diag
175180
return Cofunction(self._x.function_space().dual())
176181

0 commit comments

Comments
 (0)