LongLoRA + Flash Attention 2 causing illegal memory access #148
Hi, many thanks for your interest in our work. Let's take a step-by-step example to understand this flash-attention version of the implementation.

(1) To understand the flash-attention implementation: we split the multi-head dimension into the batch dimension, which is why the batch size is doubled. Tokens are then split into groups within each batch element (the batch dimension now carries the original attention-head split); see the sketch below.

(2) As for debugging why there is a 520 in cu_q_lens (the "most weird thing" you mention): based on my guess, because this function is designed for continued pre-training, the …
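Roughly, the head-to-batch split can be pictured like this (a simplified sketch for illustration, not the exact code in this repository; `split_heads_to_batch` and `group_size` are placeholder names):

```python
import torch

def split_heads_to_batch(qkv: torch.Tensor, group_size: int) -> torch.Tensor:
    # qkv: (bsz, seqlen, 3, num_heads, head_dim). The heads are divided into two
    # halves (in S2-Attn, one half carries the shifted pattern; the shift itself
    # is omitted in this sketch), and that factor of 2 is moved into the batch
    # dimension -- which is why the batch size is doubled.
    bsz, seqlen, three, num_heads, head_dim = qkv.shape
    assert seqlen % group_size == 0, "sketch assumes seqlen divisible by group_size"

    qkv = qkv.reshape(bsz, seqlen, three, 2, num_heads // 2, head_dim)
    qkv = qkv.permute(0, 3, 1, 2, 4, 5).reshape(
        bsz * 2, seqlen, three, num_heads // 2, head_dim
    )

    # Within each of the 2 * bsz batch elements, the seqlen tokens are then split
    # into groups of `group_size`; in the flash-attention path the group
    # boundaries are encoded via cumulative sequence lengths (cu_q_lens).
    return qkv
```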
Hi @yukang2017, and sorry for the late response. Thanks for taking the time to look at this issue.
This is for Llama 2 with a context length of 4096 and an input sequence of 8 tokens padded to max_length (4096).
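Concretely, the padded input I mean looks roughly like this (placeholder token ids, not my actual repro script):

```python
import torch

# 8 real tokens followed by right-padding up to max_length = 4096.
# The ids below are placeholders, not the ones from my actual script.
max_length, num_real_tokens = 4096, 8
pad_token_id = 0

input_ids = torch.full((1, max_length), pad_token_id, dtype=torch.long)
input_ids[0, :num_real_tokens] = 1  # stand-in for the 8 real token ids

attention_mask = torch.zeros((1, max_length), dtype=torch.long)
attention_mask[0, :num_real_tokens] = 1  # only the first 8 positions are valid
```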
Here's the repro script I mentioned:
@ArturNiederfahrenhorst Hi, did you fix the issue? I'm running into the same one.
@jcao-ai No, I'm waiting for a response from @yukang2017.
I also encountered this issue, especially when I tried to increase the "per_device_train_batch_size" parameter. After repeated experiments, I confirmed that it is not caused by insufficient GPU/CPU memory.
I am also getting a similar error when trying to extend the code base to support Phi-2. Is anyone interested in jumping on a call to try to solve this? I'm not sure about the solution yet, but it would be good to brainstorm and fix it.
Thanks for providing the LongLoRA forward functions.
Your flash-attn/non-flash-attn implementations of SSN show divergent behavior in my case.
For a repro script, please have a look at the issue I opened over at the flash-attention repo: Dao-AILab/flash-attention#670
The one without flash attention works without problems for me. I stepped through it, and the ops and shapes make sense to me.
There, the shift is implemented by rolling.
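For concreteness, the rolling I'm referring to looks roughly like this (a simplified sketch with illustrative names, not the exact code from this repository):

```python
import torch

def roll_shift(states: torch.Tensor, group_size: int) -> torch.Tensor:
    # states: (bsz, seqlen, num_heads, head_dim). The second half of the heads
    # is rolled by half a group along the sequence dimension, so the two head
    # halves end up with different group boundaries.
    num_heads = states.shape[2]
    shifted = states.clone()
    shifted[:, :, num_heads // 2 :] = shifted[:, :, num_heads // 2 :].roll(
        -group_size // 2, dims=1
    )
    return shifted
```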
The one with flash attention shows weird behaviour. The shift is not just a roll; we also mess with cu_q_lens. To me, the code looks like it was written with token sequences longer than half the group size in mind, or something like that. For a batch with 4k context length but only 8 unpadded tokens, I end up with cu_q_lens = [0, 8, 520, 16]. For smaller group sizes, the 520 in this tensor "shrinks". Can you please elaborate on the calculations, or help me fix this?
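For context, my understanding is that the cumulative sequence lengths passed to flash-attention's varlen kernels should be a non-decreasing prefix sum of the valid token counts per sequence, e.g.:

```python
import torch
import torch.nn.functional as F

# What I would expect cu_q_lens to look like: cumulative offsets into the packed
# token buffer, so the values must be non-decreasing. For two "sequences" of
# 8 valid tokens each (e.g. the two head-halves of my single sample):
seqlens = torch.tensor([8, 8], dtype=torch.int32)
cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))
print(cu_seqlens)  # tensor([ 0,  8, 16], dtype=torch.int32)

# A tensor like [0, 8, 520, 16] is not non-decreasing, so the kernel would index
# outside the packed buffer, which would be consistent with the illegal memory
# access I'm seeing.
```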