Skip to content

Commit 245463d

Browse files
majosminducer
authored andcommitted
get mpi_communicator from actx instead of dcoll
1 parent 1bb9a14 commit 245463d

File tree

2 files changed

+20
-12
lines changed

2 files changed

+20
-12
lines changed

grudge/reductions.py

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@
8080
from pytools import memoize_in
8181

8282
import grudge.dof_desc as dof_desc
83+
from grudge.array_context import MPIBasedArrayContext
8384
from grudge.discretization import DiscretizationCollection
8485

8586

@@ -128,16 +129,17 @@ def nodal_sum(dcoll: DiscretizationCollection, dd, vec) -> Scalar:
128129
:class:`~arraycontext.ArrayContainer`.
129130
:returns: a device scalar denoting the nodal sum.
130131
"""
131-
comm = dcoll.mpi_communicator
132-
if comm is None:
132+
from arraycontext import get_container_context_recursively
133+
actx = get_container_context_recursively(vec)
134+
135+
if not isinstance(actx, MPIBasedArrayContext):
133136
return nodal_sum_loc(dcoll, dd, vec)
134137

138+
comm = actx.mpi_communicator
139+
135140
# NOTE: Do not move, we do not want to import mpi4py in single-rank computations
136141
from mpi4py import MPI
137142

138-
from arraycontext import get_container_context_recursively
139-
actx = get_container_context_recursively(vec)
140-
141143
return actx.from_numpy(
142144
comm.allreduce(actx.to_numpy(nodal_sum_loc(dcoll, dd, vec)), op=MPI.SUM))
143145

@@ -174,13 +176,16 @@ def nodal_min(dcoll: DiscretizationCollection, dd, vec, *, initial=None) -> Scal
174176
:arg initial: an optional initial value. Defaults to `numpy.inf`.
175177
:returns: a device scalar denoting the nodal minimum.
176178
"""
177-
comm = dcoll.mpi_communicator
178-
if comm is None:
179+
from arraycontext import get_container_context_recursively
180+
actx = get_container_context_recursively(vec)
181+
182+
if not isinstance(actx, MPIBasedArrayContext):
179183
return nodal_min_loc(dcoll, dd, vec, initial=initial)
180184

185+
comm = actx.mpi_communicator
186+
181187
# NOTE: Do not move, we do not want to import mpi4py in single-rank computations
182188
from mpi4py import MPI
183-
actx = vec.array_context
184189

185190
return actx.from_numpy(
186191
comm.allreduce(
@@ -231,13 +236,16 @@ def nodal_max(dcoll: DiscretizationCollection, dd, vec, *, initial=None) -> Scal
231236
:arg initial: an optional initial value. Defaults to `-numpy.inf`.
232237
:returns: a device scalar denoting the nodal maximum.
233238
"""
234-
comm = dcoll.mpi_communicator
235-
if comm is None:
239+
from arraycontext import get_container_context_recursively
240+
actx = get_container_context_recursively(vec)
241+
242+
if not isinstance(actx, MPIBasedArrayContext):
236243
return nodal_max_loc(dcoll, dd, vec, initial=initial)
237244

245+
comm = actx.mpi_communicator
246+
238247
# NOTE: Do not move, we do not want to import mpi4py in single-rank computations
239248
from mpi4py import MPI
240-
actx = vec.array_context
241249

242250
return actx.from_numpy(
243251
comm.allreduce(

grudge/trace_pair.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -410,7 +410,7 @@ def __init__(self,
410410
bdry_dd = volume_dd.trace(BTAG_PARTITION(remote_rank))
411411

412412
local_bdry_data = project(dcoll, volume_dd, bdry_dd, array_container)
413-
comm = dcoll.mpi_communicator
413+
comm = actx.mpi_communicator
414414
assert comm is not None
415415

416416
self.dcoll = dcoll

0 commit comments

Comments
 (0)