XFormer fails when passing attention mask while using bfloat and key's sequence length not being a multiple of 8 #9637
Labels: bug
You can solve this by adding the code below to `XFormersAttnProcessor` (it assumes `math` is imported at the top of the file):

```python
if attention_mask is not None:
    # expand our mask's singleton query_tokens dimension:
    #   [batch*heads,            1, key_tokens] ->
    #   [batch*heads, query_tokens, key_tokens]
    # so that it can be added as a bias onto the attention scores that xformers computes:
    #   [batch*heads, query_tokens, key_tokens]
    # we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
    _, query_tokens, _ = hidden_states.shape
    if attention_mask.dtype == torch.bfloat16 and attention_mask.shape[-1] % 8 != 0:
        # allocate storage whose last dimension is padded up to a multiple of 8 ...
        mask = torch.zeros(
            (attention_mask.shape[0], attention_mask.shape[1], math.ceil(attention_mask.shape[-1] / 8) * 8),
            device=attention_mask.device,
            dtype=attention_mask.dtype,
        )
        mask[:, :, : attention_mask.shape[-1]] = attention_mask
        # ... then slice back to the original length: the logical shape is unchanged,
        # but the underlying stride of the padded storage is now a multiple of 8,
        # which is what xformers requires of attn_bias
        attention_mask = mask[:, :, : attention_mask.shape[-1]]
    attention_mask = attention_mask.expand(-1, query_tokens, -1)
```

Don't know if a PR is needed about this 🤔.
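The pad-then-slice trick can be isolated and checked on its own: slicing a padded tensor back to the original length keeps the padded storage, so the stride of the second-to-last dimension becomes a multiple of 8 while the logical shape is unchanged. A minimal sketch (the helper name `pad_mask_for_xformers` is hypothetical, not part of diffusers):

```python
import math

import torch


def pad_mask_for_xformers(attention_mask: torch.Tensor) -> torch.Tensor:
    """Pad the last dim of a mask up to a multiple of 8, then slice back,
    so the logical shape is unchanged but the storage stride is aligned."""
    key_tokens = attention_mask.shape[-1]
    padded_len = math.ceil(key_tokens / 8) * 8
    if padded_len == key_tokens:
        return attention_mask
    # allocate aligned storage and copy the original mask values into it
    mask = torch.zeros(
        (*attention_mask.shape[:-1], padded_len),
        device=attention_mask.device,
        dtype=attention_mask.dtype,
    )
    mask[..., :key_tokens] = attention_mask
    # the slice is a view into the padded storage: shape is the original
    # key length, but stride(-2) is now padded_len (a multiple of 8)
    return mask[..., :key_tokens]


m = torch.zeros(2, 1, 77, dtype=torch.bfloat16)  # 77 is not a multiple of 8
aligned = pad_mask_for_xformers(m)
print(aligned.shape, aligned.stride(-2))  # torch.Size([2, 1, 77]) 80
```

The same shape goes into the subsequent `expand(-1, query_tokens, -1)` call, so downstream code is unaffected; only the memory layout changes.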
dhmbb2 changed the title to "XFormer fails when passing attention mask while using bfloat and key's sequence length not being a multiple of 8" on Oct 10, 2024
Sure, feel free to open a PR. Thank you for your help!
Describe the bug
xformers fails when it is passed an attention mask whose last dimension (the key's sequence length) is not a multiple of 8 under bfloat16. This seems to be because xformers requires the stride of `attn_bias` to be a multiple of 8. Padding the attention mask along its last dimension solves the problem.
Reproduction
Logs
System Info
Who can help?
@DN6 @yiyixuxu @sayakpaul