
fix: flash-attn-3 error #312

Open

Explorer-Dong wants to merge 1 commit into Wan-Video:main from Explorer-Dong:main

Conversation

@Explorer-Dong

Hi Wan Team!

Recently, I used flash-attention-3 to accelerate inference, but I ran into the error below:

...
[rank7]:   File "/codespace/text-or-image-to-video-model/src/wan/modules/model.py", line 243, in forward
[rank7]:     y = self.self_attn(
[rank7]:         ^^^^^^^^^^^^^^^
[rank7]:   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
[rank7]:     return self._call_impl(*args, **kwargs)
[rank7]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1762, in _call_impl
[rank7]:     return forward_call(*args, **kwargs)
[rank7]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "/codespace/text-or-image-to-video-model/src/wan/distributed/sequence_parallel.py", line 165, in sp_attn_forward
[rank7]:     x = distributed_attention(
[rank7]:         ^^^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "/codespace/text-or-image-to-video-model/src/wan/distributed/ulysses.py", line 37, in distributed_attention
[rank7]:     x = flash_attention(
[rank7]:         ^^^^^^^^^^^^^^^^
[rank7]:   File "/codespace/text-or-image-to-video-model/src/wan/modules/attention.py", line 110, in flash_attention
[rank7]:     deterministic=deterministic)[0].unflatten(0, (b, lq))
[rank7]:                                     ^^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "/usr/local/lib/python3.12/dist-packages/torch/_tensor.py", line 1432, in unflatten
[rank7]:     return super().unflatten(dim, sizes)
[rank7]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank7]: RuntimeError: unflatten: Provided sizes [1, 75600] don't multiply up to the size of dim 0 (5) in the input tensor
...

Clearly, this is an attention error. I checked the attention code in wan/modules/attention.py and found a questionable piece of code:

x = flash_attn_interface.flash_attn_varlen_func(
    q=q,
    k=k,
    v=v,
    cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(
        0, dtype=torch.int32).to(q.device, non_blocking=True),
    cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(
        0, dtype=torch.int32).to(q.device, non_blocking=True),
    seqused_q=None,
    seqused_k=None,
    max_seqlen_q=lq,
    max_seqlen_k=lk,
    softmax_scale=softmax_scale,
    causal=causal,
    deterministic=deterministic)[0].unflatten(0, (b, lq))
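
(For context: the trailing .unflatten(0, (b, lq)) is meant to reshape the packed varlen output from (b * lq, num_heads, head_dim) back to (b, lq, num_heads, head_dim), and the leading [0] assumes the call returns a tuple whose first element is that tensor. A small illustration with assumed shapes:)

import torch

# Assumed, illustrative shapes: b = 1 and lq = 75600 come from the traceback;
# 5 heads per rank and head_dim 128 are just placeholders.
b, lq = 1, 75600
packed_out = torch.randn(b * lq, 5, 128)

# What the last line of the quoted code intends: restore batch and sequence dims.
x = packed_out.unflatten(0, (b, lq))
print(x.shape)  # torch.Size([1, 75600, 5, 128])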

I checked the source code of flash-attention-3: the function flash_attn_interface.flash_attn_varlen_func() only returns a tuple under the following condition:

return out if not return_softmax else (out, softmax_lse, S_dmask)

return_softmax, in turn, is only enabled under the following condition:

return_softmax=return_softmax and dropout_p > 0,

However, neither return_softmax nor dropout_p is passed in your code.

So in this call, flash_attn_interface.flash_attn_varlen_func() returns only out, which is a plain torch.Tensor. The trailing [0].unflatten(0, (b, lq)) therefore indexes into the output tensor itself instead of unpacking a tuple, and that is what triggers the RuntimeError above.
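
A minimal illustration of the failure, with the same assumed shapes as above:

import torch

# Assumed shapes: b = 1, lq = 75600 (from the traceback); 5 heads per rank
# and head_dim 128 are placeholders.
b, lq = 1, 75600
out = torch.randn(b * lq, 5, 128)   # a plain tensor, as returned in this call

sliced = out[0]                     # shape (5, 128): slices dim 0 instead of unpacking a tuple
sliced.unflatten(0, (b, lq))
# RuntimeError: unflatten: Provided sizes [1, 75600] don't multiply up to
# the size of dim 0 (5) in the input tensor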

It works for me.
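
For reference, here is a minimal sketch of the kind of change this implies (the exact diff may differ; the isinstance check is only a defensive fallback for builds that do return a tuple):

out = flash_attn_interface.flash_attn_varlen_func(
    q=q,
    k=k,
    v=v,
    cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(
        0, dtype=torch.int32).to(q.device, non_blocking=True),
    cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(
        0, dtype=torch.int32).to(q.device, non_blocking=True),
    seqused_q=None,
    seqused_k=None,
    max_seqlen_q=lq,
    max_seqlen_k=lk,
    softmax_scale=softmax_scale,
    causal=causal,
    deterministic=deterministic)
# flash-attn-3 returns the output tensor directly when return_softmax is off,
# but unwrap defensively in case a build returns a tuple.
if isinstance(out, tuple):
    out = out[0]
x = out.unflatten(0, (b, lq))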
