Implement a thread-local means to access kernel launch config. #288
Conversation
Will close #280.
/ok to test
Using the simple benchmark from numba/numba#3003 (comment):

```python
from numba import cuda
import numpy as np


@cuda.jit('void()')
def some_kernel_1():
    return


@cuda.jit('void(float32[:])')
def some_kernel_2(arr1):
    return


@cuda.jit('void(float32[:],float32[:])')
def some_kernel_3(arr1, arr2):
    return


@cuda.jit('void(float32[:],float32[:],float32[:])')
def some_kernel_4(arr1, arr2, arr3):
    return


@cuda.jit('void(float32[:],float32[:],float32[:],float32[:])')
def some_kernel_5(arr1, arr2, arr3, arr4):
    return


arr = cuda.device_array(10000, dtype=np.float32)

%timeit some_kernel_1[1, 1]()
%timeit some_kernel_2[1, 1](arr)
%timeit some_kernel_3[1, 1](arr, arr)
%timeit some_kernel_4[1, 1](arr, arr, arr)
%timeit some_kernel_5[1, 1](arr, arr, arr, arr)
```

Comparing `main` against this branch, this crude benchmark suggests it adds 2-4 µs per launch, or 17-23% overhead. I don't yet know how much weight we should give this (or whether the benchmark is appropriate), but I think it's a consideration to keep in mind when thinking about this approach.
```python
        stream=stream,
        sharedmem=sharedmem,
    ):
        if self.specialized:
```
Specialized kernels cannot be recompiled, so a new launch configuration would not be able to affect the compilation of a new version - so this check could be kept outside the context manager.
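A rough sketch of that suggestion (the helper `set_launch_config`, the `_resolve_kernel` lookup, and the overall method shape are assumptions for illustration, not the PR's actual diff):

```python
def call(self, args, griddim, blockdim, stream, sharedmem):
    if self.specialized:
        # Specialized kernels are already compiled and can never be
        # recompiled, so the launch config cannot influence them.
        kernel = next(iter(self.overloads.values()))  # assumed lookup
        return kernel.launch(args, griddim, blockdim, stream, sharedmem)

    config = LaunchConfig(self, args, griddim, blockdim, stream, sharedmem, [])
    with set_launch_config(config):  # hypothetical context manager
        # Only unspecialized kernels may compile under the config.
        kernel = self._resolve_kernel(args)  # hypothetical compile/lookup
        return kernel.launch(args, griddim, blockdim, stream, sharedmem)
```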
```python
@dataclass(frozen=True, slots=True)
class LaunchConfig:
```
There seems to be quite a lot of overlap between this class and `dispatcher._LaunchConfiguration` (just an observation at this point; I don't know whether it makes sense to combine them).
Kernel launch latency is something we need to care about moving forward, and it is becoming more and more of a bottleneck in important workloads. Adding 2-4 µs per launch for this is probably unacceptable. That said, our launch latency right now is much higher than we'd like in general, and we will probably need to rework the entire launch path at some point in the not-too-distant future.
I think that's another thing that concerns me: if we implement something like this, it constrains how we can rework the launch path in the future, because we would have to go on supporting it.
Force-pushed 019d60d to bba597d
Force-pushed bba597d to fea9f79
This allows downstream passes, such as rewriting, to access information about the kernel launch in which they have been enlisted to participate.
This routine raises an error if no launch config is set; it is likely to be the preferred way of obtaining the current launch config.
This is required by cuda.coop in order to pass two-phase primitive instances as kernel parameters without having to call the @cuda.jit decorator with extensions=[...] up-front.
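For context, here is a minimal sketch of how such a `ContextVar`-backed accessor can work. The names `_launch_config_var`, `set_launch_config`, and `current_launch_config` are placeholders for illustration, not necessarily the API added by this PR; only the `ContextVar` storage and the error-on-unset behaviour are taken from the description above.

```python
from contextlib import contextmanager
from contextvars import ContextVar

# Hypothetical module-level store; the PR keeps a LaunchConfig here.
_launch_config_var = ContextVar("cuda_launch_config", default=None)


@contextmanager
def set_launch_config(config):
    """Make `config` visible to downstream code for the duration of a launch."""
    token = _launch_config_var.set(config)
    try:
        yield config
    finally:
        _launch_config_var.reset(token)


def current_launch_config():
    """Return the active launch config, raising if none is set."""
    config = _launch_config_var.get()
    if config is None:
        raise RuntimeError("No kernel launch configuration is currently set.")
    return config
```

A downstream pass (e.g. a rewrite extension) would call `current_launch_config()` while the dispatcher holds the context manager open around compilation and launch.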
Force-pushed fea9f79 to 23405f3
Greptile Overview

Greptile Summary

This PR implements a thread-local mechanism to access kernel launch configuration information using Python's `contextvars`.

Architecture

The design stores a frozen `LaunchConfig` dataclass in a `ContextVar` for the duration of each kernel launch, capturing the dispatcher, arguments, grid and block dimensions, stream, shared memory, and a list of pre-launch callbacks.

Key Concerns

1. Frozen dataclass with mutable field. `LaunchConfig` is declared `frozen=True` but contains a mutable `List` field, `pre_launch_callbacks`, so it is not truly immutable.
2. Missing exception handling. Pre-launch callbacks are invoked without try-except blocks. If a callback fails, the error message won't indicate which callback caused the problem, making debugging difficult.
3. Performance considerations. The additional per-launch bookkeeping adds measurable overhead (roughly 2-4 µs per launch in the author's benchmark).
4. Missing tests and documentation. As noted by the author, tests and documentation are not yet included. Given this is a core dispatch mechanism change, comprehensive testing is critical before merge.

Confidence Score: 3/5
3 files reviewed, 3 comments
```python
@dataclass(frozen=True, slots=True)
class LaunchConfig:
    """
    Helper class used to encapsulate kernel launch configuration for storing
    and retrieving from a thread-local ContextVar.
    """

    dispatcher: "CUDADispatcher"
    args: Tuple[Any, ...]
    griddim: Tuple[int, int, int]
    blockdim: Tuple[int, int, int]
    stream: Any
    sharedmem: int
    pre_launch_callbacks: List[Callable[["_Kernel", "LaunchConfig"], None]]
```
The LaunchConfig dataclass is marked as frozen=True, but contains a mutable List field pre_launch_callbacks. This creates an inconsistent state where the dataclass appears immutable but the list can still be modified via append(), clear(), etc.
While this may be intentional to allow dynamic callback registration, it violates the principle of least surprise. Consider one of these alternatives:
- Remove `frozen=True` if mutability is intended
- Use `tuple` instead of `List` for `pre_launch_callbacks` to maintain true immutability
- Add a comment explaining why frozen + mutable list is intentional
The current design could lead to subtle bugs if code assumes the frozen dataclass is truly immutable.
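A small, self-contained illustration of the point (not code from the PR; the `Config` class is invented for the example):

```python
from dataclasses import dataclass, FrozenInstanceError
from typing import List, Tuple


@dataclass(frozen=True, slots=True)
class Config:
    callbacks: List  # mutable field inside a "frozen" dataclass


cfg = Config(callbacks=[])

try:
    cfg.callbacks = []          # rebinding the field is blocked...
except FrozenInstanceError:
    pass

cfg.callbacks.append(print)     # ...but mutating the list is not
assert cfg.callbacks == [print]


# Using a tuple instead restores real immutability:
@dataclass(frozen=True, slots=True)
class ImmutableConfig:
    callbacks: Tuple = ()
```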
```python
for callback in launch_config.pre_launch_callbacks:
    callback(kernel, launch_config)
```
Pre-launch callbacks are invoked without any exception handling. If a callback raises an exception:
- The kernel launch will be aborted (which may be intended)
- The error won't indicate which callback failed or provide context
- If earlier callbacks modified state, there's no rollback mechanism
Consider wrapping the callback invocation in a try-except block to provide better error messages:
Suggested change:

```python
for callback in launch_config.pre_launch_callbacks:
    try:
        callback(kernel, launch_config)
    except Exception as e:
        raise RuntimeError(
            f"Pre-launch callback {callback} failed: {e}"
        ) from e
```
Alternatively, if callback failures should not prevent kernel launch, you could log the error and continue.
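A minimal sketch of that log-and-continue alternative (the logger setup is illustrative and not part of the PR):

```python
import logging

logger = logging.getLogger(__name__)

for callback in launch_config.pre_launch_callbacks:
    try:
        callback(kernel, launch_config)
    except Exception:
        # Record which callback failed, but let the launch proceed.
        logger.exception("Pre-launch callback %r failed; continuing", callback)
```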
```python
def __str__(self) -> str:
    a = ", ".join(map(str, self.args))
    g = "×".join(map(str, self.griddim))
    b = "×".join(map(str, self.blockdim))
    cb = ", ".join(map(str, self.pre_launch_callbacks))
```
The __str__ method has potential performance and usability issues:
- Line 42: Converting all args to strings could be expensive if args contains large arrays or GPU memory. Consider truncating or limiting the representation.
- Line 45: Function objects will display as `<function name at 0x...>`, which isn't very informative for debugging.
Consider improving the string representation:
Suggested change:

```python
def __str__(self) -> str:
    # Truncate args representation if too long
    a = ", ".join(
        str(arg)[:50] + "..." if len(str(arg)) > 50 else str(arg)
        for arg in self.args
    )
    g = "×".join(map(str, self.griddim))
    b = "×".join(map(str, self.blockdim))
    # Show callback names instead of memory addresses
    cb = ", ".join(
        getattr(c, '__name__', repr(c)) for c in self.pre_launch_callbacks
    )
    return (
        # ... rest of __str__ unchanged
```
This provides more useful debug output without performance issues.
This allows downstream passes, such as rewriting, to access information about the kernel launch in which they have been enlisted to participate.
Posting this as a PR now to get feedback on the overall approach. Assuming this solution is acceptable, I'll follow up with tests and docs.
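To make the intended usage concrete, here is a hedged sketch of how a downstream pass might consume the thread-local config. `current_launch_config` is the hypothetical accessor sketched earlier, and `lower_coop_primitive` is an invented name used only for illustration, not part of cuda.coop's real API; the callback signature and the config fields follow the `LaunchConfig` definition quoted above.

```python
def lower_coop_primitive(argument):
    """Hypothetical rewrite hook for a cuda.coop-style two-phase primitive."""
    config = current_launch_config()  # raises if no launch is in progress

    # Grid and block dimensions are available without having passed
    # extensions=[...] to @cuda.jit up front.
    blocks_per_grid = config.griddim
    threads_per_block = config.blockdim

    def finalize(kernel, launch_config):
        # Runs immediately before launch; see the dispatcher loop above.
        print(f"Launching {kernel} with grid {blocks_per_grid} "
              f"and block {threads_per_block}")

    config.pre_launch_callbacks.append(finalize)
    return argument
```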