@@ -80,6 +80,7 @@
 from pytools import memoize_in

 import grudge.dof_desc as dof_desc
+from grudge.array_context import MPIBasedArrayContext
 from grudge.discretization import DiscretizationCollection


@@ -128,16 +129,17 @@ def nodal_sum(dcoll: DiscretizationCollection, dd, vec) -> Scalar:
         :class:`~arraycontext.ArrayContainer`.
     :returns: a device scalar denoting the nodal sum.
     """
-    comm = dcoll.mpi_communicator
-    if comm is None:
+    from arraycontext import get_container_context_recursively
+    actx = get_container_context_recursively(vec)
+
+    if not isinstance(actx, MPIBasedArrayContext):
         return nodal_sum_loc(dcoll, dd, vec)

+    comm = actx.mpi_communicator
+
     # NOTE: Do not move, we do not want to import mpi4py in single-rank computations
     from mpi4py import MPI

-    from arraycontext import get_container_context_recursively
-    actx = get_container_context_recursively(vec)
-
     return actx.from_numpy(
         comm.allreduce(actx.to_numpy(nodal_sum_loc(dcoll, dd, vec)), op=MPI.SUM))

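Note (not part of the diff): the hunk above changes how nodal_sum selects its code path — it now inspects the array context attached to *vec* via get_container_context_recursively and checks for MPIBasedArrayContext, instead of looking at dcoll.mpi_communicator. Below is a minimal standalone sketch of that dispatch pattern; describe_reduction_path is a made-up helper for illustration, and the only assumption beyond the diff itself is that an MPIBasedArrayContext exposes an mpi_communicator attribute.

from arraycontext import get_container_context_recursively
from grudge.array_context import MPIBasedArrayContext


def describe_reduction_path(vec) -> str:
    # Pull the array context out of the (possibly nested) array container.
    actx = get_container_context_recursively(vec)

    if not isinstance(actx, MPIBasedArrayContext):
        # Non-MPI array context: the nodal reduction stays rank-local.
        return "single-rank: local reduction only"

    # MPI-aware array context: the local result would be combined across
    # ranks with comm.allreduce, as in the hunk above.
    comm = actx.mpi_communicator
    return f"distributed: allreduce over {comm.Get_size()} rank(s)"
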
@@ -174,13 +176,16 @@ def nodal_min(dcoll: DiscretizationCollection, dd, vec, *, initial=None) -> Scalar:
     :arg initial: an optional initial value. Defaults to `numpy.inf`.
     :returns: a device scalar denoting the nodal minimum.
     """
-    comm = dcoll.mpi_communicator
-    if comm is None:
+    from arraycontext import get_container_context_recursively
+    actx = get_container_context_recursively(vec)
+
+    if not isinstance(actx, MPIBasedArrayContext):
         return nodal_min_loc(dcoll, dd, vec, initial=initial)

+    comm = actx.mpi_communicator
+
     # NOTE: Do not move, we do not want to import mpi4py in single-rank computations
     from mpi4py import MPI
-    actx = vec.array_context

     return actx.from_numpy(
         comm.allreduce(

@@ -231,13 +236,16 @@ def nodal_max(dcoll: DiscretizationCollection, dd, vec, *, initial=None) -> Scalar:
     :arg initial: an optional initial value. Defaults to `-numpy.inf`.
     :returns: a device scalar denoting the nodal maximum.
     """
-    comm = dcoll.mpi_communicator
-    if comm is None:
+    from arraycontext import get_container_context_recursively
+    actx = get_container_context_recursively(vec)
+
+    if not isinstance(actx, MPIBasedArrayContext):
         return nodal_max_loc(dcoll, dd, vec, initial=initial)

+    comm = actx.mpi_communicator
+
     # NOTE: Do not move, we do not want to import mpi4py in single-rank computations
     from mpi4py import MPI
-    actx = vec.array_context

     return actx.from_numpy(
         comm.allreduce(

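Note (not part of the diff): a hedged end-to-end sketch of the single-rank path after this change. It assumes a standard grudge/meshmode setup (PyOpenCLArrayContext from grudge.array_context, a regular rectangular mesh from meshmode, grudge.op re-exporting the reductions) and an available OpenCL device; because the array context is not an MPIBasedArrayContext, all three reductions return the rank-local result and mpi4py is never imported.

import pyopencl as cl

import grudge.op as op
from grudge.array_context import PyOpenCLArrayContext
from grudge.discretization import DiscretizationCollection
from meshmode.mesh.generation import generate_regular_rect_mesh

ctx = cl.create_some_context()
queue = cl.CommandQueue(ctx)
actx = PyOpenCLArrayContext(queue)

# A small 2D box mesh and a discretization collection on a single rank.
mesh = generate_regular_rect_mesh(a=(0.0, 0.0), b=(1.0, 1.0),
                                  nelements_per_axis=(8, 8))
dcoll = DiscretizationCollection(actx, mesh, order=3)

# A constant field of ones on the volume discretization.
ones = dcoll.zeros(actx) + 1.0

# actx is not an MPIBasedArrayContext, so these reduce locally only.
print(actx.to_numpy(op.nodal_sum(dcoll, "vol", ones)))
print(actx.to_numpy(op.nodal_min(dcoll, "vol", ones)))
print(actx.to_numpy(op.nodal_max(dcoll, "vol", ones)))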