2
2
import numpy
3
3
import collections
4
4
5
- from ufl import as_vector , split , ZeroBaseForm
6
- from ufl .classes import Zero , FixedIndex , ListTensor
5
+ from ufl import as_vector , split
6
+ from ufl .classes import Zero , FixedIndex , ListTensor , ZeroBaseForm
7
7
from ufl .algorithms .map_integrands import map_integrand_dags
8
8
from ufl .algorithms import expand_derivatives
9
9
from ufl .corealg .map_dag import MultiFunction , map_expr_dags
10
10
11
+ from pyop2 import MixedDat
12
+ from pyop2 .utils import as_tuple
13
+
11
14
from firedrake .petsc import PETSc
12
15
from firedrake .ufl_expr import Argument
13
- from firedrake .functionspace import MixedFunctionSpace , FunctionSpace
16
+ from firedrake .cofunction import Cofunction
17
+ from firedrake .functionspace import FunctionSpace , MixedFunctionSpace , DualSpace
14
18
15
19
16
20
def subspace (V , indices ):
17
- try :
18
- indices = tuple (indices )
19
- except TypeError :
20
- # Only one index provided.
21
- indices = (indices , )
22
21
if len (indices ) == 1 :
23
22
W = V [indices [0 ]]
24
23
W = FunctionSpace (W .mesh (), W .ufl_element ())
@@ -66,7 +65,7 @@ def split(self, form, argument_indices):
66
65
"""
67
66
args = form .arguments ()
68
67
self ._arg_cache = {}
69
- self .blocks = dict (enumerate (argument_indices ))
68
+ self .blocks = dict (enumerate (map ( as_tuple , argument_indices ) ))
70
69
if len (args ) == 0 :
71
70
# Functional can't be split
72
71
return form
@@ -75,11 +74,13 @@ def split(self, form, argument_indices):
75
74
assert (idx [0 ] == 0 for idx in self .blocks .values ())
76
75
return form
77
76
f = map_integrand_dags (self , form )
78
- f = expand_derivatives (f )
79
- if f .empty ():
80
- f = ZeroBaseForm (tuple (Argument (subspace (arg .function_space (), indices ),
77
+ # TODO find a way to distinguish empty Forms avoiding expand_derivatives
78
+ if expand_derivatives (f ).empty ():
79
+ # Get ZeroBaseForm with the right shape
80
+ f = ZeroBaseForm (tuple (Argument (subspace (arg .function_space (),
81
+ self .blocks [arg .number ()]),
81
82
arg .number (), part = arg .part ())
82
- for arg , indices in zip ( form .arguments (), argument_indices )))
83
+ for arg in form .arguments ()))
83
84
return f
84
85
85
86
expr = MultiFunction .reuse_if_untouched
@@ -109,6 +110,7 @@ def coefficient_derivative(self, o, expr, coefficients, arguments, cds):
109
110
@PETSc .Log .EventDecorator ()
110
111
def argument (self , o ):
111
112
V = o .function_space ()
113
+
112
114
if len (V ) == 1 :
113
115
# Not on a mixed space, just return ourselves.
114
116
return o
@@ -118,12 +120,6 @@ def argument(self, o):
118
120
119
121
indices = self .blocks [o .number ()]
120
122
121
- try :
122
- indices = tuple (indices )
123
- except TypeError :
124
- # Only one index provided.
125
- indices = (indices , )
126
-
127
123
W = subspace (V , indices )
128
124
a = Argument (W , o .number (), part = o .part ())
129
125
a = (a , ) if len (W ) == 1 else split (a )
@@ -141,6 +137,28 @@ def argument(self, o):
141
137
args .extend (Zero () for j in numpy .ndindex (V [i ].value_shape ))
142
138
return self ._arg_cache .setdefault (o , as_vector (args ))
143
139
140
+ def cofunction (self , o ):
141
+ V = o .function_space ()
142
+
143
+ if len (V ) == 1 :
144
+ # Not on a mixed space, just return ourselves.
145
+ return o
146
+
147
+ try :
148
+ indices , = set (self .blocks .values ())
149
+ except ValueError :
150
+ raise ValueError ("Cofunction found on an off-diagonal block" )
151
+
152
+ if len (indices ) == 1 :
153
+ i = indices [0 ]
154
+ W = V [i ]
155
+ W = DualSpace (W .mesh (), W .ufl_element ())
156
+ c = Cofunction (W , val = o .dat [i ])
157
+ else :
158
+ W = MixedFunctionSpace ([V [i ] for i in indices ])
159
+ c = Cofunction (W , val = MixedDat (o .dat [i ] for i in indices ))
160
+ return c
161
+
144
162
145
163
SplitForm = collections .namedtuple ("SplitForm" , ["indices" , "form" ])
146
164
0 commit comments