[FEA]: Provide generic and safe C++ interfaces for warp shuffle #2976
Comments
Hi, if I understand correctly, you might want something like this. I have only done this for one of the functions mentioned. If this looks right, or if any additional features need to be added, I would be happy to take up this problem.
NOTE: The code provided is only an MWE, meant to check whether it meets the actual expectations.
Hi @soumikiith, this is definitely going in the right direction. We would need some additional constraints, though.
Hi @miscco,
Also, regarding extended floating point support, have a look at e.g. the cmath extensions.
Please note that we can bitcast __half and __nv_bfloat16 just fine and should easily extend that to any other extended floating point vector type.
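For illustration, a minimal sketch of such a bitcast-based shuffle for __half (the helper name shfl_half is hypothetical; __half_as_ushort and __ushort_as_half come from <cuda_fp16.h>, and the same pattern would apply to __nv_bfloat16):

#include <cuda_fp16.h>

// Sketch: shuffle a __half by reinterpreting its 16 bits as an unsigned short,
// widening to a 32-bit word for the intrinsic, and converting back afterwards.
__device__ __half shfl_half(__half value, int src_lane,
                            unsigned mask = 0xFFFFFFFFu, int width = 32)
{
  unsigned bits = __half_as_ushort(value);          // 16-bit payload, zero-extended
  bits = __shfl_sync(mask, bits, src_lane, width);  // shuffle as a 32-bit word
  return __ushort_as_half(static_cast<unsigned short>(bits));
}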
We need to extend this mechanism to raw arrays, cuda::std::array, cuda::std::pair, compositions of them, etc.
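A rough sketch of one way that mechanism could be extended to arbitrary trivially copyable aggregates (raw arrays, cuda::std::pair, compositions of them); the word-wise chunking and the helper name shfl_any are assumptions for illustration, not necessarily the approach CCCL takes:

#include <cstring>
#include <cuda/std/type_traits>

// Sketch: shuffle any trivially copyable object by splitting its bytes into
// 32-bit words and shuffling each word individually.
template <class T>
__device__ T shfl_any(const T& var, int src_lane,
                      unsigned mask = 0xFFFFFFFFu, int width = 32)
{
  static_assert(cuda::std::is_trivially_copyable<T>::value,
                "shuffling requires a trivially copyable type");

  constexpr int num_words = (sizeof(T) + sizeof(unsigned) - 1) / sizeof(unsigned);
  unsigned words[num_words] = {};
  memcpy(words, &var, sizeof(T));                   // copy the object's bytes out

  for (int i = 0; i < num_words; ++i) {
    words[i] = __shfl_sync(mask, words[i], src_lane, width);
  }

  T result = var;                                   // avoid requiring default construction
  memcpy(&result, words, sizeof(T));                // copy the shuffled bytes back
  return result;
}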
Hi, I am currently finishing up the issue. Sorry for the delayed response. I just want to check whether my implementation is aligned with the expectations of the issue. I am sharing a minimal code sample; if anything needs to be addressed, please let me know. I am waiting for any further comments or review. Thank you.
P.S. I am implementing the code inside <cuda/cmath> as you mentioned.
No need to apologize, we deeply appreciate you putting your time in.
Generally we prefer to discuss within a draft PR if there are still design issues. We recently changed our CI so that it does not run on draft PRs, so there is no drawback in opening a draft PR, even if it is just a sketch.
We generally prefer variable templates over types. Note that we might still need the type-based form for C++11 support and for compilers with insufficient variable template support.
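For illustration, the two spellings being compared might look like this (the trait name is hypothetical):

#include <cuda/std/type_traits>

// Hypothetical trait deciding whether a type can be shuffled directly.
// Class-template form: usable in C++11 and on compilers with poor variable
// template support.
template <class T>
struct is_warp_shuffleable
    : cuda::std::integral_constant<bool, cuda::std::is_trivially_copyable<T>::value>
{};

// Variable-template form (C++14 and later): the generally preferred spelling.
template <class T>
constexpr bool is_warp_shuffleable_v = is_warp_shuffleable<T>::value;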
This should just be inlined directly into the code; as written, it would add a function call in debug mode when it is not necessary.
To be sure, we want this code to live in something like ...
"Hi, I'd like to work on this issue. Could you assign it to me?" |
@vrajvaghela89, @soumikiith is already working on it.
@fbusato Why not use cooperative groups shfl? Is it because of some very unusual mask pattern?
That's a good point. Do you mean using CG as a backend?
Either as a backend, or just use CG where the proposed functionality would be used.
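For reference, a minimal usage sketch of the cooperative groups tile shuffle mentioned here; note that the tile API takes no explicit lane mask, since the tile itself defines the participating threads:

#include <cooperative_groups.h>
namespace cg = cooperative_groups;

// Sketch: broadcasting a value from lane 0 of a 32-thread tile.
__global__ void cg_shfl_example(float* out)
{
  cg::thread_block block = cg::this_thread_block();
  auto warp = cg::tiled_partition<32>(block);

  float v = static_cast<float>(warp.thread_rank());
  float from_lane0 = warp.shfl(v, 0);  // every lane receives lane 0's value

  out[block.thread_rank()] = from_lane0;
}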
I guess the main use cases are the extended floating point types.
Is this a duplicate?
Area
libcu++
Is your feature request related to a problem? Please describe.
CUDA provides warp shuffle intrinsics that support only a limited set of types. Additionally, there are no checks to validate the inputs;
see https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#warp-shuffle-functions
Describe the solution you'd like
Provide:
cuda::shfl(T var, int srcLane, unsigned mask = 0xFFFFFFFF, int width=warpSize)
cuda::shfl_up(T var, int delta, unsigned mask = 0xFFFFFFFF, int width=warpSize)
cuda::shfl_down(T var, int delta, unsigned mask = 0xFFFFFFFF, int width=warpSize)
cuda::shfl_xor(T var, int laneMask, unsigned mask = 0xFFFFFFFF, int width=warpSize)
Features and checks:
width is a power of two and 1 <= width <= warpSize
mask is a subset of __activemask()
cuda::shfl:
  srcLane is part of mask
  0 <= srcLane < width ([optional] no modulo behavior)
cuda::shfl_up:
  1 < delta < width
  max(laneid - delta, 0) is part of mask
cuda::shfl_down:
  1 < delta < width
  min(laneid + delta, width) is part of mask
cuda::shfl_xor:
  clamp(laneid ^ laneMask, 0, width) is part of mask
mask value is the same for all participating lanes (__match_all_sync()) [optional]
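A hedged sketch of what such a checked shuffle could look like; the helper name checked_shfl, the use of plain assert, and the simplifications noted in the comments are assumptions, not the eventual CCCL implementation:

#include <cassert>
#include <cstring>

// Sketch: validate the arguments listed above before forwarding to the raw
// intrinsic. Simplifications: only 4-byte trivially copyable types, and the
// mask membership check assumes width == warpSize.
template <class T>
__device__ T checked_shfl(T var, int src_lane,
                          unsigned mask = 0xFFFFFFFFu, int width = 32)
{
  // width is a power of two and 1 <= width <= warpSize
  assert(width >= 1 && width <= warpSize && (width & (width - 1)) == 0);
  // mask is a subset of __activemask()
  assert((mask & __activemask()) == mask);
  // 0 <= srcLane < width and srcLane is part of mask (no modulo behavior)
  assert(src_lane >= 0 && src_lane < width);
  assert(((mask >> src_lane) & 1u) != 0u);

  static_assert(sizeof(T) == sizeof(unsigned), "sketch limited to 4-byte types");
  unsigned bits = 0;
  memcpy(&bits, &var, sizeof(T));
  bits = __shfl_sync(mask, bits, src_lane, width);
  memcpy(&var, &bits, sizeof(T));
  return var;
}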
Describe alternatives you've considered
An alternative could use mask at the end of the parameter list.
Additional context
No response