
Conversation

@tpn (Contributor) commented Jun 10, 2025

This allows downstream passes, such as rewriting, to access information about the kernel launch in which they have been enlisted to participate.

Posting this as a PR now to get feedback on the overall approach. Assuming this solution is acceptable, I'll follow up with tests and docs.
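As a rough illustration of the consumption pattern this enables (the accessor name current_launch_config below is an assumption for the sketch; only the LaunchConfig fields discussed later in this thread are taken from the PR):

# Hypothetical downstream pass that consults the active launch configuration.
from numba.cuda.launchconfig import current_launch_config  # assumed accessor name


def rewrite_pass(func_ir):
    cfg = current_launch_config()  # raises if no launch config is set
    # LaunchConfig carries dispatcher, args, griddim, blockdim, stream,
    # sharedmem and pre_launch_callbacks for the launch being serviced.
    if cfg.sharedmem > 0:
        pass  # e.g. rewrite based on the dynamic shared memory size
    return func_ir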


@tpn (Contributor, Author) commented Jun 10, 2025

Will close #280.

@gmarkall (Contributor)

/ok to test

@gmarkall (Contributor)

Using the simple benchmark from numba/numba#3003 (comment):

from numba import cuda
import numpy as np


@cuda.jit('void()')
def some_kernel_1():
    return

@cuda.jit('void(float32[:])')
def some_kernel_2(arr1):
    return

@cuda.jit('void(float32[:],float32[:])')
def some_kernel_3(arr1,arr2):
    return

@cuda.jit('void(float32[:],float32[:],float32[:])')
def some_kernel_4(arr1,arr2,arr3):
    return

@cuda.jit('void(float32[:],float32[:],float32[:],float32[:])')
def some_kernel_5(arr1,arr2,arr3,arr4):
    return

arr = cuda.device_array(10000, dtype=np.float32)

%timeit some_kernel_1[1, 1]()
%timeit some_kernel_2[1, 1](arr)
%timeit some_kernel_3[1, 1](arr,arr)
%timeit some_kernel_4[1, 1](arr,arr,arr)
%timeit some_kernel_5[1, 1](arr,arr,arr,arr)

On main:

6.6 μs ± 11.3 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
10.6 μs ± 79.9 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
14.1 μs ± 65.8 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
17.4 μs ± 236 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
20.3 μs ± 38.5 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

On this branch:

8.6 μs ± 70.7 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
13 μs ± 41.3 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
17.2 μs ± 152 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
20.3 μs ± 35.9 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
24.4 μs ± 44.4 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

From this crude benchmark, it appears to add 2-4μs per launch, or 17-23% overhead. I don't yet know how we should consider this (or whether the benchmark is appropriate) but I think it's a consideration to keep in mind when thinking about this approach.
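For context, the cost of the ContextVar set/reset alone can be measured in isolation; this rough, standalone sketch (not this PR's actual code path) helps separate the context-variable machinery from the cost of building the config object and invoking callbacks:

from contextvars import ContextVar
from dataclasses import dataclass
from timeit import timeit

_cfg = ContextVar("launch_config", default=None)


@dataclass(frozen=True, slots=True)
class _FakeConfig:
    griddim: tuple
    blockdim: tuple


def enter_and_exit():
    # Mimics entering and leaving the launch-config context once.
    token = _cfg.set(_FakeConfig((1, 1, 1), (1, 1, 1)))
    _cfg.reset(token)


# Total seconds for one million set/reset round trips; on CPython this is
# typically well under a microsecond per call, so it would not account for
# the full 2-4 μs by itself.
print(timeit(enter_and_exit, number=1_000_000))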

    stream=stream,
    sharedmem=sharedmem,
):
    if self.specialized:

Specialized kernels cannot be recompiled, so a new launch configuration cannot affect the compilation of a new version; this check could therefore be kept outside the context manager.



@dataclass(frozen=True, slots=True)
class LaunchConfig:

There seems to be quite a lot of overlap between this class and dispatcher._LaunchConfiguration (an observation at this point; I don't know whether it makes sense to combine them).

@kkraus14 (Contributor)

> From this crude benchmark, it appears to add 2-4μs per launch, or 17-23% overhead. I don't yet know how we should consider this (or whether the benchmark is appropriate) but I think it's a consideration to keep in mind when thinking about this approach.

Kernel launch latency is something we need to care about moving forward; it is becoming more and more of a bottleneck in important workloads. Adding 2-4μs per launch for this is probably unacceptable. That said, our launch latency is currently much higher than we'd like in general, and we will probably need to rework the entire launch path at some point in the not too distant future.

@gmarkall added the 2 - In Progress label (Currently a work in progress) on Jun 12, 2025
@gmarkall (Contributor)

> we will probably need to rework the entire launch path at some point in the not too distant future.

I think that's another thing that concerns me: if we implement something like this, it constrains how we can rework the launch path in the future if we have to go on supporting it.

tpn added 4 commits January 12, 2026 08:55

  • This allows downstream passes, such as rewriting, to access information about the kernel launch in which they have been enlisted to participate.
  • This routine raises an error if no launch config is set, which is inevitably going to be the preferred way of obtaining the current launch config.
  • This is required by cuda.coop in order to pass two-phase primitive instances as kernel parameters without having to call the @cuda.jit decorator with extensions=[...] up-front.
@tpn force-pushed the 280-launch-config-contextvar branch from fea9f79 to 23405f3 on January 12, 2026 16:56
@greptile-apps (bot) commented Jan 12, 2026

Greptile Summary

Overview

This PR implements a thread-local mechanism to access kernel launch configuration information using Python's contextvars.ContextVar. The implementation adds:

  1. New file launchconfig.py: Defines LaunchConfig dataclass and context manager launch_config_ctx for storing launch parameters
  2. Modified dispatcher.py: Wraps kernel calls in launch_config_ctx and invokes pre-launch callbacks

Architecture

The design uses ContextVar for thread-safe and asyncio-task-safe storage of launch configuration, which is appropriate for this use case. The context manager properly handles nested kernel launches by using tokens to restore previous values.
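In isolation, the token-restore pattern described here looks roughly like this (a simplified, self-contained sketch rather than the PR's actual implementation):

from contextlib import contextmanager
from contextvars import ContextVar

_current = ContextVar("launch_config", default=None)


@contextmanager
def launch_config_ctx(config):
    # set() returns a token remembering the previous value, so a nested
    # launch restores the outer config on exit instead of clearing it.
    token = _current.set(config)
    try:
        yield config
    finally:
        _current.reset(token)


with launch_config_ctx("outer"):
    with launch_config_ctx("inner"):
        assert _current.get() == "inner"
    assert _current.get() == "outer"  # outer config restored after nesting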

The pre_launch_callbacks mechanism allows downstream passes (like rewriting) to register callbacks that execute just before kernel launch, enabling dynamic kernel modifications.

Key Concerns

1. Frozen Dataclass with Mutable Field

The LaunchConfig dataclass is marked frozen=True but contains a mutable List[Callable] field. This creates an inconsistent immutability guarantee and could surprise users who expect frozen dataclasses to be truly immutable.

2. Missing Exception Handling

Pre-launch callbacks are invoked without try-except blocks. If a callback fails, the error message won't indicate which callback caused the problem, making debugging difficult.

3. Performance Considerations

The __str__ method converts all args and callbacks to strings without truncation, which could be expensive for large arrays or cause memory issues.

4. Missing Tests and Documentation

As noted by the author, tests and documentation are not yet included. Given this is a core dispatch mechanism change, comprehensive testing is critical before merge.

Recommendations

  1. Add exception handling around callback invocation with clear error messages
  2. Clarify the design intent for the frozen dataclass with mutable list (or change to non-frozen)
  3. Optimize __str__ to handle large objects gracefully
  4. Add comprehensive tests covering:
    • Nested kernel launches (a sketch follows this list)
    • Callback exceptions
    • Thread safety
    • Callback registration and execution order
  5. Document the callback mechanism including signature requirements and execution guarantees
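A pytest-style sketch of the nested-launch test from recommendation 4 might look like the following (launch_config_ctx and current_launch_config are assumed names, and the keyword signature is illustrative rather than the exact API in this PR):

from numba.cuda.launchconfig import launch_config_ctx, current_launch_config


def _cfg_kwargs(dispatcher, griddim):
    # Hypothetical keyword set mirroring the LaunchConfig fields in this PR.
    return dict(dispatcher=dispatcher, args=(), griddim=griddim,
                blockdim=(1, 1, 1), stream=0, sharedmem=0,
                pre_launch_callbacks=[])


def test_nested_launch_config_is_restored():
    outer_disp, inner_disp = object(), object()
    with launch_config_ctx(**_cfg_kwargs(outer_disp, (1, 1, 1))):
        outer = current_launch_config()
        with launch_config_ctx(**_cfg_kwargs(inner_disp, (2, 1, 1))):
            assert current_launch_config().dispatcher is inner_disp
        # Leaving the inner context must restore the outer configuration.
        assert current_launch_config() is outer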

Confidence Score: 3/5

  • This PR introduces a solid architectural pattern but has design inconsistencies and missing error handling that should be addressed before merge.
  • Score of 3 reflects that while the core implementation using ContextVar is sound and thread-safe, there are several style and robustness issues: (1) frozen dataclass with mutable list creates misleading immutability guarantees, (2) missing exception handling for callbacks could cause unclear failures, (3) the __str__ method could have performance issues, and (4) no tests or documentation yet. The functionality appears correct but needs polish before production use.
  • Both files need attention: launchconfig.py for design clarification on immutability, and dispatcher.py for adding error handling around callback invocation.

Important Files Changed

  • numba_cuda/numba/cuda/launchconfig.py (3/5): New file introducing the thread-local launch config context. Design issues with a frozen dataclass containing a mutable list, and the __str__ method could be optimized.
  • numba_cuda/numba/cuda/dispatcher.py (3/5): Integrates the launch config context into kernel dispatch. Missing exception handling for pre-launch callbacks could cause unclear errors.

@greptile-apps (bot) left a comment

3 files reviewed, 3 comments

Comment on lines +19 to +32
@dataclass(frozen=True, slots=True)
class LaunchConfig:
    """
    Helper class used to encapsulate kernel launch configuration for storing
    and retrieving from a thread-local ContextVar.
    """

    dispatcher: "CUDADispatcher"
    args: Tuple[Any, ...]
    griddim: Tuple[int, int, int]
    blockdim: Tuple[int, int, int]
    stream: Any
    sharedmem: int
    pre_launch_callbacks: List[Callable[["_Kernel", "LaunchConfig"], None]]

The LaunchConfig dataclass is marked as frozen=True, but contains a mutable List field pre_launch_callbacks. This creates an inconsistent state where the dataclass appears immutable but the list can still be modified via append(), clear(), etc.

While this may be intentional to allow dynamic callback registration, it violates the principle of least surprise. Consider one of these alternatives:

  1. Remove frozen=True if mutability is intended
  2. Use tuple instead of List for pre_launch_callbacks to maintain true immutability (see the sketch below)
  3. Add a comment explaining why frozen + mutable list is intentional

The current design could lead to subtle bugs if code assumes the frozen dataclass is truly immutable.
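For reference, option 2 above could look roughly like this (a sketch only; the field types are reduced to what appears in the snippet under review, and using a tuple means callers must build the callback collection up front):

from dataclasses import dataclass
from typing import Any, Callable, Tuple


@dataclass(frozen=True, slots=True)
class LaunchConfig:
    dispatcher: Any
    args: Tuple[Any, ...]
    griddim: Tuple[int, int, int]
    blockdim: Tuple[int, int, int]
    stream: Any
    sharedmem: int
    # A tuple keeps the container immutable along with the frozen dataclass.
    pre_launch_callbacks: Tuple[Callable[[Any, "LaunchConfig"], None], ...] = ()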

Comment on lines +1627 to +1628
for callback in launch_config.pre_launch_callbacks:
    callback(kernel, launch_config)

Pre-launch callbacks are invoked without any exception handling. If a callback raises an exception:

  1. The kernel launch will be aborted (which may be intended)
  2. The error won't indicate which callback failed or provide context
  3. If earlier callbacks modified state, there's no rollback mechanism

Consider wrapping the callback invocation in a try-except block to provide better error messages:

Suggested change

-for callback in launch_config.pre_launch_callbacks:
-    callback(kernel, launch_config)
+for callback in launch_config.pre_launch_callbacks:
+    try:
+        callback(kernel, launch_config)
+    except Exception as e:
+        raise RuntimeError(
+            f"Pre-launch callback {callback} failed: {e}"
+        ) from e

Alternatively, if callback failures should not prevent kernel launch, you could log the error and continue.

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

Comment on lines +41 to +45
def __str__(self) -> str:
    a = ", ".join(map(str, self.args))
    g = "×".join(map(str, self.griddim))
    b = "×".join(map(str, self.blockdim))
    cb = ", ".join(map(str, self.pre_launch_callbacks))

The __str__ method has potential performance and usability issues:

  1. Line 42: Converting all args to strings could be expensive if args contains large arrays or GPU memory. Consider truncating or limiting the representation.
  2. Line 45: Function objects will display as <function name at 0x...> which isn't very informative for debugging.

Consider improving the string representation:

Suggested change

-def __str__(self) -> str:
-    a = ", ".join(map(str, self.args))
-    g = "×".join(map(str, self.griddim))
-    b = "×".join(map(str, self.blockdim))
-    cb = ", ".join(map(str, self.pre_launch_callbacks))
+def __str__(self) -> str:
+    # Truncate args representation if too long
+    a = ", ".join(str(arg)[:50] + "..." if len(str(arg)) > 50 else str(arg) for arg in self.args)
+    g = "×".join(map(str, self.griddim))
+    b = "×".join(map(str, self.blockdim))
+    # Show callback names instead of memory addresses
+    cb = ", ".join(getattr(c, '__name__', repr(c)) for c in self.pre_launch_callbacks)
+    return (

This provides more useful debug output without performance issues.
