Fix flash attention for non-pow2 dims, large dims, and backward NaN#7

Open
murrellb wants to merge 1 commit into main from fix-flash-attention-padding

Conversation

@murrellb
Member

  • Support non-pow2 emb_dim via kernel-level zero-padding (padded_dim)
  • Add smaller groupsizes (8, 4) for large dims that exceed the 48KB shmem limit
  • Clamp MMA tile sizes (TM ≤ BM, TN ≤ BN) to prevent OOB shmem access
  • Fix backward NaN: mask OOB K positions in the softmax reconstruction so that exp(0 − m_i) → Inf followed by Inf · 0 → NaN cannot contaminate dQ
  • Guard the preprocess inv(ls) computation against division by ls = 0
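The backward-pass NaN fix can be sketched in scalar Python (a minimal illustration of the masking idea, not the actual kernel code; `reconstruct_softmax_row`, `m_i`, and `l_i` are illustrative stand-ins for the saved forward-pass row max and row sum):

```python
import math

def reconstruct_softmax_row(scores, m_i, l_i, seq_len):
    """Recompute one softmax row during the backward pass.

    scores: raw Q·K^T tile entries; indices >= seq_len are out-of-bounds
    padding and may hold arbitrary values (often 0).
    m_i: row max saved from the forward pass.
    l_i: row sum of exp(score - m_i) saved from the forward pass.
    """
    probs = []
    for j, s in enumerate(scores):
        if j >= seq_len:
            # The fix: force OOB K positions to probability 0 instead of
            # evaluating exp(0 - m_i), which overflows to Inf for very
            # negative m_i and later yields Inf * 0 = NaN in the dQ update.
            probs.append(0.0)
        else:
            probs.append(math.exp(s - m_i) / l_i)
    return probs
```

With the mask in place, every reconstructed probability is finite and the valid positions still sum to 1, so the padded tail contributes exactly nothing to dQ.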

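The zero-padding approach for non-pow2 emb_dim can likewise be sketched in Python (a host-side illustration under the assumption that the kernel rounds the head dimension up to a power of two; `next_pow2` and `load_padded` are hypothetical helper names, not functions from this repo):

```python
def next_pow2(n):
    """Smallest power of two >= n, e.g. the padded_dim for a head dim."""
    p = 1
    while p < n:
        p *= 2
    return p

def load_padded(vec, padded_dim):
    """Copy a head vector into a padded_dim-length buffer, zero-filling
    the tail. Zero entries contribute nothing to Q·K dot products, so
    the kernel can assume a power-of-two inner dimension throughout."""
    return list(vec) + [0.0] * (padded_dim - len(vec))
```

For example, a 48-dim head pads to 64, and the dot product of two padded vectors equals the dot product of the originals, which is why the padding is numerically free.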

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>