From 7c153ad687cdc5af00685f00b0c08c4e8d643013 Mon Sep 17 00:00:00 2001 From: Connor Ward Date: Fri, 17 Jan 2025 13:47:06 +0000 Subject: [PATCH] fixup --- firedrake/preconditioners/patch.py | 5 +++-- pyop2/global_kernel.py | 4 ++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/firedrake/preconditioners/patch.py b/firedrake/preconditioners/patch.py index 5e9d0d4fa0..d79a03b8a1 100644 --- a/firedrake/preconditioners/patch.py +++ b/firedrake/preconditioners/patch.py @@ -27,6 +27,7 @@ from pyop2.codegen.builder import Pack, MatPack, DatPack from pyop2.codegen.representation import Comparison, Literal from pyop2.codegen.rep2loopy import register_petsc_function +from pyop2.global_kernel import compile_global_kernel __all__ = ("PatchPC", "PlaneSmoother", "PatchSNES") @@ -222,7 +223,7 @@ def matrix_funptr(form, state): wrapper_knl_args = tuple(a.global_kernel_arg for a in args) mod = op2.GlobalKernel(kinfo.kernel, wrapper_knl_args, subset=True) - kernels.append(CompiledKernel(mod.compile(iterset.comm), kinfo)) + kernels.append(CompiledKernel(compile_global_kernel(mod, iterset.comm), kinfo)) return cell_kernels, int_facet_kernels @@ -316,7 +317,7 @@ def residual_funptr(form, state): wrapper_knl_args = tuple(a.global_kernel_arg for a in args) mod = op2.GlobalKernel(kinfo.kernel, wrapper_knl_args, subset=True) - kernels.append(CompiledKernel(mod.compile(iterset.comm), kinfo)) + kernels.append(CompiledKernel(compile_global_kernel(mod, iterset.comm), kinfo)) return cell_kernels, int_facet_kernels diff --git a/pyop2/global_kernel.py b/pyop2/global_kernel.py index 433a2992f4..aac7d30c3d 100644 --- a/pyop2/global_kernel.py +++ b/pyop2/global_kernel.py @@ -329,7 +329,7 @@ def __call__(self, comm, *args): :arg comm: Communicator the execution is collective over. :*args: Arguments to pass to the compiled kernel. """ - func = _compile_global_kernel(self, comm) + func = compile_global_kernel(self, comm) func(*args) @property @@ -423,7 +423,7 @@ def _generate_code_from_global_kernel(kernel): @parallel_cache(hashkey=lambda knl, _: knl.cache_key) @mpi.collective -def _compile_global_kernel(kernel, comm): +def compile_global_kernel(kernel, comm): """Compile the kernel. Parameters