Skip to content

Commit

Permalink
fixup
Browse files Browse the repository at this point in the history
  • Loading branch information
connorjward committed Jan 17, 2025
1 parent f850f7a commit 7c153ad
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 4 deletions.
5 changes: 3 additions & 2 deletions firedrake/preconditioners/patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from pyop2.codegen.builder import Pack, MatPack, DatPack
from pyop2.codegen.representation import Comparison, Literal
from pyop2.codegen.rep2loopy import register_petsc_function
from pyop2.global_kernel import compile_global_kernel

__all__ = ("PatchPC", "PlaneSmoother", "PatchSNES")

Expand Down Expand Up @@ -222,7 +223,7 @@ def matrix_funptr(form, state):

wrapper_knl_args = tuple(a.global_kernel_arg for a in args)
mod = op2.GlobalKernel(kinfo.kernel, wrapper_knl_args, subset=True)
kernels.append(CompiledKernel(mod.compile(iterset.comm), kinfo))
kernels.append(CompiledKernel(compile_global_kernel(mod, iterset.comm), kinfo))
return cell_kernels, int_facet_kernels


Expand Down Expand Up @@ -316,7 +317,7 @@ def residual_funptr(form, state):

wrapper_knl_args = tuple(a.global_kernel_arg for a in args)
mod = op2.GlobalKernel(kinfo.kernel, wrapper_knl_args, subset=True)
kernels.append(CompiledKernel(mod.compile(iterset.comm), kinfo))
kernels.append(CompiledKernel(compile_global_kernel(mod, iterset.comm), kinfo))
return cell_kernels, int_facet_kernels


Expand Down
4 changes: 2 additions & 2 deletions pyop2/global_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,7 @@ def __call__(self, comm, *args):
:arg comm: Communicator the execution is collective over.
:*args: Arguments to pass to the compiled kernel.
"""
func = _compile_global_kernel(self, comm)
func = compile_global_kernel(self, comm)
func(*args)

@property
Expand Down Expand Up @@ -423,7 +423,7 @@ def _generate_code_from_global_kernel(kernel):

@parallel_cache(hashkey=lambda knl, _: knl.cache_key)
@mpi.collective
def _compile_global_kernel(kernel, comm):
def compile_global_kernel(kernel, comm):
"""Compile the kernel.
Parameters
Expand Down

0 comments on commit 7c153ad

Please sign in to comment.