XFormer fails when passing attention mask while using bfloat and key's sequence length not being a multiple of 8 #9637

Open · dhmbb2 opened this issue Oct 10, 2024 · 2 comments · May be fixed by #9678
Labels: bug (Something isn't working)

Comments

@dhmbb2 commented Oct 10, 2024

Describe the bug

xFormers fails when an attention mask is passed whose last dimension (i.e. the key's sequence length) is not a multiple of 8 under bfloat16. This seems to be because xformers needs attn_bias's stride to be a multiple of 8. Padding the attention mask along its last dimension solves the problem.
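As a call-site workaround, here is a minimal sketch based on the hint in the traceback below (batch size, key length, device and dtype are taken from the reproduction; the variable names are introduced here only for illustration): build the mask inside a tensor whose last dimension is padded to the next multiple of 8, then slice it back so the underlying storage stays 8-aligned.

import math
import torch

# pad the mask's last dimension (the key sequence length) to a multiple of 8,
# then slice back so the underlying storage stays 8-aligned
batch, key_len = 2, 700
padded_len = math.ceil(key_len / 8) * 8  # 704
padded = torch.zeros((batch, 1, padded_len), device="cuda", dtype=torch.bfloat16)
attn_mask = padded[:, :, :key_len]       # shape (2, 1, 700), strides (704, 704, 1)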

Reproduction

from diffusers.models.attention_processor import Attention, XFormersAttnProcessor
import torch

attn_processor = XFormersAttnProcessor()

attn = Attention(
    query_dim=256,
    heads=8,
    dim_head=64,
    processor=attn_processor,
).to(device="cuda", dtype=torch.bfloat16)

q = torch.zeros((2, 350, 256), device="cuda", dtype=torch.bfloat16)
kv = torch.zeros((2, 700, 256), device="cuda", dtype=torch.bfloat16)
# key sequence length (700) is not a multiple of 8
attn_mask = torch.zeros((2, 1, 700), device="cuda", dtype=torch.bfloat16)

out = attn(q, kv, attn_mask)

Logs

Traceback (most recent call last):
  File "/mnt/hcufs/home/youjunqi/GENAD2/test/test_diffusers.py", line 17, in <module>
    out = attn(q, kv, attn_mask)
  File "/mnt/hcufs/home/youjunqi/miniconda3/envs/nuplan/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/mnt/hcufs/home/youjunqi/miniconda3/envs/nuplan/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/mnt/hcufs/home/youjunqi/miniconda3/envs/nuplan/lib/python3.9/site-packages/diffusers/models/attention_processor.py", line 490, in forward
    return self.processor(
  File "/mnt/hcufs/home/youjunqi/miniconda3/envs/nuplan/lib/python3.9/site-packages/diffusers/models/attention_processor.py", line 2156, in __call__
    hidden_states = xformers.ops.memory_efficient_attention(
  File "/mnt/hcufs/home/youjunqi/miniconda3/envs/nuplan/lib/python3.9/site-packages/xformers/ops/fmha/__init__.py", line 276, in memory_efficient_attention
    return _memory_efficient_attention(
  File "/mnt/hcufs/home/youjunqi/miniconda3/envs/nuplan/lib/python3.9/site-packages/xformers/ops/fmha/__init__.py", line 403, in _memory_efficient_attention
    return _fMHA.apply(
  File "/mnt/hcufs/home/youjunqi/miniconda3/envs/nuplan/lib/python3.9/site-packages/torch/autograd/function.py", line 574, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/mnt/hcufs/home/youjunqi/miniconda3/envs/nuplan/lib/python3.9/site-packages/xformers/ops/fmha/__init__.py", line 74, in forward
    out, op_ctx = _memory_efficient_attention_forward_requires_grad(
  File "/mnt/hcufs/home/youjunqi/miniconda3/envs/nuplan/lib/python3.9/site-packages/xformers/ops/fmha/__init__.py", line 428, in _memory_efficient_attention_forward_requires_grad
    op = _dispatch_fw(inp, True)
  File "/mnt/hcufs/home/youjunqi/miniconda3/envs/nuplan/lib/python3.9/site-packages/xformers/ops/fmha/dispatch.py", line 119, in _dispatch_fw
    return _run_priority_list(
  File "/mnt/hcufs/home/youjunqi/miniconda3/envs/nuplan/lib/python3.9/site-packages/xformers/ops/fmha/dispatch.py", line 55, in _run_priority_list
    raise NotImplementedError(msg)
NotImplementedError: No operator found for `memory_efficient_attention_forward` with inputs:
     query       : shape=(16, 350, 1, 64) (torch.bfloat16)
     key         : shape=(16, 700, 1, 64) (torch.bfloat16)
     value       : shape=(16, 700, 1, 64) (torch.bfloat16)
     attn_bias   : <class 'torch.Tensor'>
     p           : 0.0
`[email protected]` is not supported because:
    attn_bias type is <class 'torch.Tensor'>
`cutlassF-pt` is not supported because:
    attn_bias.stride(-2) % 8 != 0 (attn_bias.stride() = (700, 0, 0, 1))
    HINT: To use an `attn_bias` with a sequence length that is not a multiple of 8, you need to ensure memory is aligned by slicing a bigger tensor. Example: use `attn_bias = torch.zeros([1, 1, 5, 8])[:,:,:,:5]` instead of `torch.zeros([1, 1, 5, 5])`
`smallkF` is not supported because:
    max(query.shape[-1] != value.shape[-1]) > 32
    dtype=torch.bfloat16 (supported: {torch.float32})
    has custom scale
    unsupported embed per head: 64

System Info

  • 🤗 Diffusers version: 0.30.2
  • Platform: Linux-5.15.0-41-generic-x86_64-with-glibc2.35
  • Running on Google Colab?: No
  • Python version: 3.9.19
  • PyTorch version (GPU?): 2.4.1+cu121 (True)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Huggingface_hub version: 0.24.6
  • Transformers version: 4.44.2
  • Accelerate version: 0.21.0
  • PEFT version: 0.12.0
  • Bitsandbytes version: not installed
  • Safetensors version: 0.4.4
  • xFormers version: 0.0.27.post2
  • Accelerator: NVIDIA H800, 81559 MiB
  • Using GPU in script?: Yes
  • Using distributed or parallel set-up in script?: No

Who can help?

@DN6 @yiyixuxu @sayakpaul

@dhmbb2 added the bug (Something isn't working) label on Oct 10, 2024
@dhmbb2 (Author) commented Oct 10, 2024

Solved this by adding the code below to XFormersAttnProcessor:

if attention_mask is not None:
    # expand our mask's singleton query_tokens dimension:
    #   [batch*heads,            1, key_tokens] ->
    #   [batch*heads, query_tokens, key_tokens]
    # so that it can be added as a bias onto the attention scores that xformers computes:
    #   [batch*heads, query_tokens, key_tokens]
    # we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
    _, query_tokens, _ = hidden_states.shape
+   # pad the mask's last dimension up to a multiple of 8 and slice back, so the
+   # underlying storage stays 8-aligned (assumes `math` is imported in attention_processor.py)
+   if attention_mask.dtype == torch.bfloat16 and attention_mask.shape[-1] % 8 != 0:
+       mask = torch.zeros(
+           (attention_mask.shape[0], attention_mask.shape[1], math.ceil(attention_mask.shape[-1] / 8) * 8),
+           device=attention_mask.device,
+           dtype=attention_mask.dtype,
+       )
+       mask[:, :, :attention_mask.shape[-1]] = attention_mask
+       attention_mask = mask[:, :, :attention_mask.shape[-1]]
    attention_mask = attention_mask.expand(-1, query_tokens, -1)
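
A quick sanity check of the slicing trick (illustrative only; runs on CPU without xformers):

import math
import torch

# the sliced view keeps the row stride of the padded storage, which satisfies the
# cutlassF requirement attn_bias.stride(-2) % 8 == 0 reported in the error above
key_len = 700
padded_len = math.ceil(key_len / 8) * 8              # 704
mask = torch.zeros((2, 1, padded_len))[:, :, :key_len]
print(mask.shape, mask.stride())                     # torch.Size([2, 1, 700]) (704, 704, 1)
assert mask.stride(-2) % 8 == 0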

Not sure whether a PR is needed for this 🤔.

@dhmbb2 changed the title from "XFormer fails when passing attention mask while using bfloat and key's sequence length is not a multiple of 8" to "XFormer fails when passing attention mask while using bfloat and key's sequence length not being a multiple of 8" on Oct 10, 2024
@sayakpaul (Member) commented:
Sure, feel free to open a PR. Thank you for your help!
