Fix illegal memory accesses in multistage Mma's for k=0 #1593
base: main
Changed the author of the commit, no functional changes intended.
@hwu36 Hi, could you please take a quick look at this PR?
Thanks for reporting this, @dfyz. For cases in which one of the modes is zero, we recommend simply not including that problem in the grouped GEMM arguments. Is that a possibility for your application?
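The suggested workaround of leaving zero-sized problems out of the grouped GEMM arguments could look roughly like the host-side sketch below. `ProblemSize`, `FilteredProblems`, and `drop_empty_problems` are hypothetical names for illustration and are not CUTLASS types or functions; the sketch only assumes that each group is described by an `(m, n, k)` triple.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for a per-group problem size (not a CUTLASS type).
struct ProblemSize {
  int m, n, k;
};

struct FilteredProblems {
  std::vector<ProblemSize> problems;        // problems with all modes non-zero
  std::vector<std::size_t> original_index;  // position of each kept problem in the input
};

// Keep only the problems in which every mode (m, n, k) is non-zero.
inline FilteredProblems drop_empty_problems(const std::vector<ProblemSize>& all) {
  FilteredProblems out;
  for (std::size_t i = 0; i < all.size(); ++i) {
    const ProblemSize& p = all[i];
    if (p.m > 0 && p.n > 0 && p.k > 0) {
      out.problems.push_back(p);
      out.original_index.push_back(i);
    }
  }
  return out;
}
```

The `original_index` vector is there because the per-group pointer and leading-dimension arrays would presumably need to be filtered in the same way. A skipped problem with `k=0` but non-zero `m` and `n` still has an output matrix, so a caller that expects zeros there would have to fill it separately.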
That would imply a performance hit for two different reasons:
1 is probably more of a minor inconvenience, but 2 will really kill the performance in my case. I can try to come up with some workarounds, but isn't it better long-term to properly handle
What is the name of the kernel used by cublas?
I was a little unclear here: in the I see that cuBLAS also has a grouped GEMM implementation in recent CUDA versions, but that appears to just run CUTLASS under the hood, so it also crashes when Here's a quick test program I made to illustrate the observed behaviors of cuBLAS (12.5) and CUTLASS (the most recent code from
So cuBLAS and CUTLASS only actually exhibit different behavior for the non-grouped case. Note that the fix I'm proposing should fix all scenarios in the above table.
This PR has been labeled
Well, I still think this is something worth fixing in CUTLASS directly. For now, I implemented a workaround in
@dfyz, cuBLAS will resolve this issue with Grouped GEMM in an upcoming release. I agree it would be good to fix in CUTLASS, but we'll need to revisit when we have more time.
@mnicely Thank you, this sounds great! A couple of follow-up questions:
Hi @dfyz, it's more a matter of us spending time to review the code and any ripple effects, internal verification and testing, and then productization. Combine this with high-priority tasks and bugs, and I'm unable to commit to a date. We really appreciate the PR, and your WAR may help other customers while we work on this issue. :)
An eventual fix for this on the CUTLASS side would be very appreciated, but I understand that this is a low-priority issue.
By the way, CUDA 12.6 is out, and this release note looks promising: However, when I run the test program from this comment as
Is this an unrelated fix, or am I doing something wrong in my test program?
Ping to keep alive.
The forward/backward passes of MLPs in mixture-of-experts models are a perfect fit for the grouped GEMM implementation in CUTLASS (for example, the `grouped_gemm` library uses CUTLASS for the forward pass). Unfortunately, when we tried to use CUTLASS for the backward pass, we started randomly getting `CUDA illegal memory access` errors when no tokens were assigned to an expert (this means that `k=0` when computing the gradients for its weights).

I'm not sure if my analysis of the code is correct, but the problem seems to be a missing edge case in multistage MMAs. When `k=0`, `gemm_k_iterations=0` when we enter the prologue, so the following happens for the `A`/`B` iterators: `gemm_k_iterations` is `-1` now, not `0`, so we don't clear the mask.

This implies that a fix might be just clearing the mask whenever `gemm_k_iterations` is `<= 0`, not only when it is `== 0`. As far as I can see, this only matters for multistage prologues, but just to be safe, in this PR I'm replacing all `gemm_k_iterations == 0` comparisons I could find. I'm also slightly changing the grouped GEMM test so that it introduces a problem with `k=0` (the test crashes without the fix in this PR).

P.S. I should also note that cuBLAS appears to handle `k=0` GEMMs correctly (no crashing, and the output matrix is filled with zeros).
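The difference between the two comparisons can be illustrated with a toy host-side model. This is a sketch and not the CUTLASS code: `ToyIterator`, `Check`, and `loads_still_enabled` are hypothetical names, and the model assumes that the counter is decremented once before the comparison is evaluated, which is how the description above reads.

```cpp
#include <cassert>

// Toy stand-in for a tile iterator's load mask (not the CUTLASS iterator).
struct ToyIterator {
  bool mask_enabled = true;
  // Clears the mask when the condition holds; a cleared mask stays cleared.
  void clear_mask(bool condition) {
    if (condition) mask_enabled = false;
  }
};

enum class Check { kEqualZero, kLessEqualZero };

// Assumed flow: decrement the iteration counter, then decide whether to
// disable further global loads. Returns true if loads are still enabled.
inline bool loads_still_enabled(int gemm_k_iterations, Check check) {
  ToyIterator it;
  --gemm_k_iterations;
  const bool done = (check == Check::kEqualZero) ? (gemm_k_iterations == 0)
                                                 : (gemm_k_iterations <= 0);
  it.clear_mask(done);
  return it.mask_enabled;
}
```

In this model, starting from `gemm_k_iterations = 0` leaves the mask enabled under the `== 0` check, because the counter is already `-1` when it is tested, while the `<= 0` check disables it. For a starting value of `1` both checks behave the same.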