Avoid attention masks for Qwen and Chroma #1109
Conversation
For now it doesn't make sense to extend this PR to the other models, because they all use a fixed sequence length, so their attention mask is never a no-op.
Looks good. I just did a quick test combined with #1107 to include the FlashAttention code path on Windows and am seeing a speedup with Chroma, more so when the batch size is 1, which avoids the padded-text-token scenario that requires attention masking. Tested configs:
#chroma LoRA 24GB + bs2 + res 1024
#chroma LoRA 24GB + bs1 + res 1024
It would be more advantageous to train with batch size 1 and use accumulation steps to make up for the smaller batch, in order to take advantage of the performance gains from FlashAttention.
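To illustrate why batch size 1 avoids masking (a rough sketch with made-up caption lengths, not OneTrainer code): a single prompt needs no padding, while a batch of prompts with different lengths must be padded to a common length and therefore needs a padding mask.

```python
import torch

# Hypothetical token counts for two captions (illustrative values only).
prompt_lengths = [77, 102]

# Batch size 1: each sequence is used as-is, so the mask can simply be None
# and torch SDPA is free to dispatch to a FlashAttention kernel.
masks_bs1 = [None for _ in prompt_lengths]

# Batch size 2: pad both captions to the longest length and build a padding
# mask (True = real token, False = padded token that must be ignored).
max_len = max(prompt_lengths)
mask_bs2 = torch.zeros(len(prompt_lengths), max_len, dtype=torch.bool)
for i, n in enumerate(prompt_lengths):
    mask_bs2[i, :n] = True

print(masks_bs1)             # [None, None] -> no mask needed at all
print(bool(mask_bs2.all()))  # False -> the mask actually masks something
```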
@dxqb I just tested this and a previous commit, and it seems avoiding the attention mask started giving me banding issues in my images, like long line artifacts, using the default config for Qwen in OneTrainer. I did not have these on a previous commit.
@dxqb Is the issue due to the removal of the tensor multiplication on the mask for Qwen?
It's not possible that this PR is the reason for this. Please look for other causes, or make a direct comparison. This PR is mathematically identical before and after, both in theory and in tests:
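A minimal check of that kind (an illustrative sketch, not the exact test run for this PR) is to compare SDPA with an all-True mask against SDPA with no mask at all:

```python
import torch
import torch.nn.functional as F

q = torch.randn(1, 8, 128, 64)
k = torch.randn(1, 8, 128, 64)
v = torch.randn(1, 8, 128, 64)

# Boolean mask that allows every position: a "no-op" mask.
noop_mask = torch.ones(1, 1, 128, 128, dtype=torch.bool)

out_with_noop_mask = F.scaled_dot_product_attention(q, k, v, attn_mask=noop_mask)
out_without_mask = F.scaled_dot_product_attention(q, k, v, attn_mask=None)

# The two results agree up to kernel-level floating point differences.
print(torch.allclose(out_with_noop_mask, out_without_mask, atol=1e-5))
```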
Was this generated with a Lightning 8-step or 4-step LoRA?
@dxqb Hi, no, this was a OneTrainer sample, but the LoRA did the same in ComfyUI.
I've just been using an older commit and it works fine now.
I've never seen such an artifact in a OneTrainer sample. If you can reproduce this, please open an issue or join the Discord and show it there.

torch SDPA automatically uses a flash attention kernel if possible.
While torch automatically uses flash attention if no attention mask is given at all, it does not recognize the case where a no-op attention mask is passed (a mask is given, but no tokens are actually masked).
This PR detects that case and passes no mask instead of a no-op mask, resulting in a significant speed-up of 20-25% in those cases. The mask is always a no-op when the batch size is 1, and less often when the batch size is greater than 1.
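A minimal sketch of the idea (the helper name, shapes, and mask layout are hypothetical, not the exact OneTrainer implementation): check whether a boolean mask masks anything at all, and if it does not, pass None so torch SDPA can pick the flash attention kernel.

```python
import torch
import torch.nn.functional as F

def sdpa_without_noop_mask(q, k, v, attn_mask=None):
    # If a boolean mask allows every position, it is a no-op: drop it so that
    # torch SDPA can dispatch to a flash attention kernel instead of the
    # slower masked code path.
    if attn_mask is not None and attn_mask.dtype == torch.bool and bool(attn_mask.all()):
        attn_mask = None
    return F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)

# With batch size 1 the padding mask is typically all True, so this call
# takes the unmasked (flash-capable) path.
q = k = v = torch.randn(1, 8, 256, 64)
mask = torch.ones(1, 1, 256, 256, dtype=torch.bool)
out = sdpa_without_noop_mask(q, k, v, attn_mask=mask)
```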
In principle this can be combined with other upcoming features that improve performance further; see the last two lines.
Thank you to @FurkanGozukara for pointing out this speed difference.