update dsa and gdn definitions, op schema and unittests #164
Conversation
📝 Walkthrough
This PR refactors DeepSeek Sparse Attention (DSA) definitions and tests to unify prefill/decode operations under a single "attention" naming convention, converts reference implementations from dictionary to tuple returns, adds scale parameters to GDN definitions, and removes stage-specific test variants while introducing consolidated test modules with updated reference imports and assertions.
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes
🚥 Pre-merge checks: ✅ Passed checks (3 passed)
Summary of Changes
Hello @yzh119, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request significantly refactors the DeepSeek Sparse Attention (DSA) and Gated Delta Net (GDN) definitions and their associated tests. The core changes involve consolidating DSA's decode and prefill operations into a single, more generalized sparse attention definition, enhancing documentation, and standardizing the return types of reference implementations. These updates improve consistency and maintainability across the codebase.
Actionable comments posted: 1
🤖 Fix all issues with AI agents
In `@CLAUDE.md`:
- Line 79: The example in the table uses the wrong definition name; replace the
example value `dsa_sparse_decode_h16_ckv512_kpe64_topk256_ps1` with the actual
JSON definition name `dsa_sparse_attention_h16_ckv512_kpe64_topk256_ps1` (or
vice versa if the JSON should be renamed) and ensure any other occurrences of
`dsa_paged` examples reference
`dsa_sparse_attention_h16_ckv512_kpe64_topk256_ps1` so the documentation and
JSON definition names are consistent.
🧹 Nitpick comments (7)
flashinfer_trace/tests/references/test_gdn_decode_qk16_v32_d128_k_last.py (1)
232-232: Consider aligning tolerances between `test_correctness` and `test_gdn_decode_k_last`.
`test_gdn_decode_k_last` now uses `atol=1e-2, rtol=1e-2` (line 232), but `test_correctness` still uses `atol=5e-3, rtol=5e-3` (line 120 default parameter). This inconsistency could cause `test_correctness` to fail for inputs that `test_gdn_decode_k_last` would pass. Consider either:
- Updating `test_correctness` to use the same relaxed tolerance
- Documenting why different tolerances are appropriate for each test
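One way to keep the two tests in sync is to route both assertions through a single shared tolerance. A minimal sketch, assuming pytest/torch; the constant and helper names here are hypothetical, not the test file's actual API:

```python
import torch

# Hypothetical shared constants; the actual test file may organize this differently.
GDN_ATOL = 1e-2
GDN_RTOL = 1e-2


def assert_gdn_close(actual: torch.Tensor, expected: torch.Tensor) -> None:
    """Single comparison helper so test_correctness and test_gdn_decode_k_last
    cannot drift apart in tolerance."""
    torch.testing.assert_close(actual, expected, atol=GDN_ATOL, rtol=GDN_RTOL)
```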
docs/op_type_schema/dsa_paged.md (1)
50-50: Minor: Clarify "2-based" log-sum-exp.
The term "2-based log-sum-exp" may be unclear to readers. Consider expanding to:
- `lse`: log-sum-exp for online softmax (base-2 logarithm) [num_tokens, num_qo_heads]

Or add a brief note explaining that this uses log base 2 (common in FlashAttention implementations for numerical stability).
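As a concrete illustration of what "base-2" means here, the reference implementations in this PR compute the natural-log `logsumexp` and rescale by `1/ln(2)`. A minimal sketch (tensor shapes are illustrative only):

```python
import math
import torch

# Attention logits already multiplied by sm_scale: [num_qo_heads, num_valid]
logits_scaled = torch.randn(16, 256)

# Base-2 log-sum-exp, matching `torch.logsumexp(...) / math.log(2.0)` in the reference code.
lse_base2 = torch.logsumexp(logits_scaled, dim=-1) / math.log(2.0)  # [num_qo_heads]
```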
flashinfer_trace/tests/references/test_dsa_vs_definition_reference.py (2)
110-110: Prefix unused `ref_lse` with underscore.
The `ref_lse` variable is unpacked but never used. Per static analysis (RUF059), prefix it with an underscore to indicate it's intentionally unused.
Proposed fix:
- ref_output, ref_lse = reference(q_nope, q_pe, ckv_cache, kpe_cache, sparse_indices, sm_scale)
+ ref_output, _ref_lse = reference(q_nope, q_pe, ckv_cache, kpe_cache, sparse_indices, sm_scale)
355-355: Prefix unused `ref_lse` with underscore.
Same issue as in the other test function.
Proposed fix:
- ref_output, ref_lse = reference(q_nope, q_pe, ckv_cache, kpe_cache, sparse_indices, sm_scale)
+ ref_output, _ref_lse = reference(q_nope, q_pe, ckv_cache, kpe_cache, sparse_indices, sm_scale)

flashinfer_trace/tests/references/test_dsa_sparse_attention_h16_ckv512_kpe64_topk256_ps64.py (2)
83-83: Return type inconsistent with JSON definition reference.
This `run` function returns a dict `{"output": output, "lse": lse}`, but the corresponding JSON definition (dsa_sparse_attention_h16_ckv512_kpe64_topk256_ps64.json) returns a tuple `(output, lse)`. Consider aligning return types for consistency across the codebase, especially since the PR objective is to change return types from dict to tuple.
Proposed fix to return tuple:
- return {"output": output, "lse": lse}
+ return output, lse

And update the tests to unpack accordingly:
- result = run(...)
- output = result["output"]
- lse = result["lse"]
+ output, lse = run(...)
31-31: Prefix unused `num_pages` with underscore.
The `num_pages` variable is unpacked but never used in this function. Per static analysis (RUF059), prefix it with an underscore.
Proposed fix:
- num_pages, page_size, _ = ckv_cache.shape
+ _num_pages, page_size, _ = ckv_cache.shape

flashinfer_trace/tests/references/test_dsa_sparse_attention_h16_ckv512_kpe64_topk256_ps1.py (1)
86-86: Return type inconsistent with JSON definition reference.
Same issue as the ps64 variant. This `run` function returns a dict, but the JSON definition returns a tuple. Consider aligning for consistency with the PR's objective to change return types from dict to tuple.
Proposed fix to return tuple:
- return {"output": output, "lse": lse}
+ return output, lse

And update the tests to unpack accordingly.
| `gqa_paged` | Group Query Attention (paged) | `gqa_paged_decode_h32_kv8_d128_ps1` |
- | `mla_paged` | Multi-Head Linear Attention | `mla_paged_decode_h16_ckv512_kpe64_ps1` |
+ | `mla_paged` | Multi-Head Latent Attention (paged) | `mla_paged_decode_h16_ckv512_kpe64_ps1` |
| `dsa_paged` | DeepSeek Sparse Attention (paged) | `dsa_sparse_decode_h16_ckv512_kpe64_topk256_ps1` |
Documentation example doesn't match actual definition name.
The example shows `dsa_sparse_decode_h16_ckv512_kpe64_topk256_ps1`, but the actual JSON definition is named `dsa_sparse_attention_h16_ckv512_kpe64_topk256_ps1`. This inconsistency will confuse users.
Proposed fix
-| `dsa_paged` | DeepSeek Sparse Attention (paged) | `dsa_sparse_decode_h16_ckv512_kpe64_topk256_ps1` |
+| `dsa_paged` | DeepSeek Sparse Attention (paged) | `dsa_sparse_attention_h16_ckv512_kpe64_topk256_ps1` |

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
- | `dsa_paged` | DeepSeek Sparse Attention (paged) | `dsa_sparse_decode_h16_ckv512_kpe64_topk256_ps1` |
+ | `dsa_paged` | DeepSeek Sparse Attention (paged) | `dsa_sparse_attention_h16_ckv512_kpe64_topk256_ps1` |
Code Review
This pull request successfully implements the unification of DSA decode and prefill operations into a single dsa_sparse_attention definition, along with the addition of its missing op schema documentation. The return types for the reference implementations of DSA and GDN operations have been consistently changed from dictionaries to tuples, improving API consistency. Furthermore, the unittest tolerances for GDN operations have been adjusted, and the GDN prefill tests now correctly specify the required SM90 architecture. Overall, the changes enhance the clarity, correctness, and maintainability of the codebase.
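Regarding the SM90 requirement noted above: a common pytest pattern for this kind of architecture gate is a capability-based skip marker. The following is only a sketch of that general pattern under assumed names, not the exact marker used in the PR's test files:

```python
import pytest
import torch

# Assumed pattern: skip GDN prefill tests on GPUs older than SM90 (compute capability 9.0).
requires_sm90 = pytest.mark.skipif(
    not torch.cuda.is_available() or torch.cuda.get_device_capability() < (9, 0),
    reason="GDN prefill reference requires SM90 or newer",
)


@requires_sm90
def test_gdn_prefill_placeholder():
    # Placeholder body; the real tests would exercise the GDN prefill reference here.
    assert torch.cuda.get_device_capability() >= (9, 0)
```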
| }, | ||
| "reference": "import math\nimport torch\n\n\n@torch.no_grad()\ndef run(q_nope, q_pe, ckv_cache, kpe_cache, sparse_indices, sm_scale):\n batch_size, num_qo_heads, head_dim_ckv = q_nope.shape\n head_dim_kpe = q_pe.shape[-1]\n page_size = ckv_cache.shape[1]\n topk = sparse_indices.shape[-1]\n\n # Check constants\n assert num_qo_heads == 16\n assert head_dim_ckv == 512\n assert head_dim_kpe == 64\n assert page_size == 1\n assert topk == 256\n\n # Check constraints\n assert sparse_indices.shape[-1] == topk\n assert ckv_cache.shape[1] == page_size\n\n device = q_nope.device\n\n # Squeeze page dimension (page_size=1)\n Kc_all = ckv_cache.squeeze(1).to(torch.float32) # [num_pages, head_dim_ckv]\n Kp_all = kpe_cache.squeeze(1).to(torch.float32) # [num_pages, head_dim_kpe]\n\n output = torch.zeros(\n (batch_size, num_qo_heads, head_dim_ckv), dtype=torch.bfloat16, device=device\n )\n lse = torch.full((batch_size, num_qo_heads), -float(\"inf\"), dtype=torch.float32, device=device)\n\n for b in range(batch_size):\n indices = sparse_indices[b] # [topk]\n\n # Handle padding: -1 indicates invalid indices\n valid_mask = indices != -1\n valid_indices = indices[valid_mask]\n\n if valid_indices.numel() == 0:\n output[b].zero_()\n continue\n\n tok_idx = valid_indices.to(torch.long)\n\n Kc = Kc_all[tok_idx] # [num_valid, head_dim_ckv]\n Kp = Kp_all[tok_idx] # [num_valid, head_dim_kpe]\n qn = q_nope[b].to(torch.float32) # [num_qo_heads, head_dim_ckv]\n qp = q_pe[b].to(torch.float32) # [num_qo_heads, head_dim_kpe]\n\n # Compute attention logits\n logits = (qn @ Kc.T) + (qp @ Kp.T) # [num_qo_heads, num_valid]\n logits_scaled = logits * sm_scale\n\n # Compute 2-base LSE\n lse[b] = torch.logsumexp(logits_scaled, dim=-1) / math.log(2.0)\n\n # Compute attention output\n attn = torch.softmax(logits_scaled, dim=-1) # [num_qo_heads, num_valid]\n out = attn @ Kc # [num_qo_heads, head_dim_ckv]\n output[b] = out.to(torch.bfloat16)\n\n return output, lse" | ||
| "reference": "import math\nimport torch\n\n\n@torch.no_grad()\ndef run(q_nope, q_pe, ckv_cache, kpe_cache, sparse_indices, sm_scale):\n num_tokens, num_qo_heads, head_dim_ckv = q_nope.shape\n head_dim_kpe = q_pe.shape[-1]\n page_size = ckv_cache.shape[1]\n topk = sparse_indices.shape[-1]\n\n # Check constants\n assert num_qo_heads == 16\n assert head_dim_ckv == 512\n assert head_dim_kpe == 64\n assert page_size == 1\n assert topk == 256\n\n # Check constraints\n assert sparse_indices.shape[0] == num_tokens\n assert sparse_indices.shape[-1] == topk\n assert ckv_cache.shape[1] == page_size\n\n device = q_nope.device\n\n # Squeeze page dimension (page_size=1)\n Kc_all = ckv_cache.squeeze(1).to(torch.float32) # [num_pages, head_dim_ckv]\n Kp_all = kpe_cache.squeeze(1).to(torch.float32) # [num_pages, head_dim_kpe]\n\n output = torch.zeros(\n (num_tokens, num_qo_heads, head_dim_ckv), dtype=torch.bfloat16, device=device\n )\n lse = torch.full((num_tokens, num_qo_heads), -float(\"inf\"), dtype=torch.float32, device=device)\n\n for t in range(num_tokens):\n indices = sparse_indices[t] # [topk]\n\n # Handle padding: -1 indicates invalid indices\n valid_mask = indices != -1\n valid_indices = indices[valid_mask]\n\n if valid_indices.numel() == 0:\n output[t].zero_()\n continue\n\n tok_idx = valid_indices.to(torch.long)\n\n Kc = Kc_all[tok_idx] # [num_valid, head_dim_ckv]\n Kp = Kp_all[tok_idx] # [num_valid, head_dim_kpe]\n qn = q_nope[t].to(torch.float32) # [num_qo_heads, head_dim_ckv]\n qp = q_pe[t].to(torch.float32) # [num_qo_heads, head_dim_kpe]\n\n # Compute attention logits\n logits = (qn @ Kc.T) + (qp @ Kp.T) # [num_qo_heads, num_valid]\n logits_scaled = logits * sm_scale\n\n # Compute 2-base LSE\n lse[t] = torch.logsumexp(logits_scaled, dim=-1) / math.log(2.0)\n\n # Compute attention output\n attn = torch.softmax(logits_scaled, dim=-1) # [num_qo_heads, num_valid]\n out = attn @ Kc # [num_qo_heads, head_dim_ckv]\n output[t] = out.to(torch.bfloat16)\n\n return output, lse" | ||
| } |
| }, | ||
| "reference": "import math\nimport torch\n\n\n@torch.no_grad()\ndef run(q_nope, q_pe, ckv_cache, kpe_cache, sparse_indices, sm_scale):\n batch_size, num_qo_heads, head_dim_ckv = q_nope.shape\n head_dim_kpe = q_pe.shape[-1]\n num_pages, page_size, _ = ckv_cache.shape\n topk = sparse_indices.shape[-1]\n\n # Check constants\n assert num_qo_heads == 16\n assert head_dim_ckv == 512\n assert head_dim_kpe == 64\n assert page_size == 64\n assert topk == 256\n\n # Check constraints\n assert sparse_indices.shape[-1] == topk\n assert ckv_cache.shape[1] == page_size\n\n device = q_nope.device\n\n # Flatten paged KV cache to token-level: [num_pages, page_size, dim] -> [num_pages * page_size, dim]\n Kc_all = ckv_cache.reshape(-1, head_dim_ckv).to(torch.float32) # [total_kv_tokens, head_dim_ckv]\n Kp_all = kpe_cache.reshape(-1, head_dim_kpe).to(torch.float32) # [total_kv_tokens, head_dim_kpe]\n\n output = torch.zeros(\n (batch_size, num_qo_heads, head_dim_ckv), dtype=torch.bfloat16, device=device\n )\n lse = torch.full((batch_size, num_qo_heads), -float(\"inf\"), dtype=torch.float32, device=device)\n\n for b in range(batch_size):\n indices = sparse_indices[b] # [topk]\n\n # Handle padding: -1 indicates invalid indices\n valid_mask = indices != -1\n valid_indices = indices[valid_mask]\n\n if valid_indices.numel() == 0:\n output[b].zero_()\n continue\n\n # For page_size=64, indices encode (page_idx * 64 + offset)\n tok_idx = valid_indices.to(torch.long)\n\n Kc = Kc_all[tok_idx] # [num_valid, head_dim_ckv]\n Kp = Kp_all[tok_idx] # [num_valid, head_dim_kpe]\n qn = q_nope[b].to(torch.float32) # [num_qo_heads, head_dim_ckv]\n qp = q_pe[b].to(torch.float32) # [num_qo_heads, head_dim_kpe]\n\n # Compute attention logits\n logits = (qn @ Kc.T) + (qp @ Kp.T) # [num_qo_heads, num_valid]\n logits_scaled = logits * sm_scale\n\n # Compute 2-base LSE\n lse[b] = torch.logsumexp(logits_scaled, dim=-1) / math.log(2.0)\n\n # Compute attention output\n attn = torch.softmax(logits_scaled, dim=-1) # [num_qo_heads, num_valid]\n out = attn @ Kc # [num_qo_heads, head_dim_ckv]\n output[b] = out.to(torch.bfloat16)\n\n return {\"output\": output, \"lse\": lse}" | ||
| "reference": "import math\nimport torch\n\n\n@torch.no_grad()\ndef run(q_nope, q_pe, ckv_cache, kpe_cache, sparse_indices, sm_scale):\n num_tokens, num_qo_heads, head_dim_ckv = q_nope.shape\n head_dim_kpe = q_pe.shape[-1]\n num_pages, page_size, _ = ckv_cache.shape\n topk = sparse_indices.shape[-1]\n\n # Check constants\n assert num_qo_heads == 16\n assert head_dim_ckv == 512\n assert head_dim_kpe == 64\n assert page_size == 64\n assert topk == 256\n\n # Check constraints\n assert sparse_indices.shape[0] == num_tokens\n assert sparse_indices.shape[-1] == topk\n assert ckv_cache.shape[1] == page_size\n\n device = q_nope.device\n\n # Flatten paged KV cache to token-level: [num_pages, page_size, dim] -> [num_pages * page_size, dim]\n Kc_all = ckv_cache.reshape(-1, head_dim_ckv).to(torch.float32) # [total_kv_tokens, head_dim_ckv]\n Kp_all = kpe_cache.reshape(-1, head_dim_kpe).to(torch.float32) # [total_kv_tokens, head_dim_kpe]\n\n output = torch.zeros(\n (num_tokens, num_qo_heads, head_dim_ckv), dtype=torch.bfloat16, device=device\n )\n lse = torch.full((num_tokens, num_qo_heads), -float(\"inf\"), dtype=torch.float32, device=device)\n\n for t in range(num_tokens):\n indices = sparse_indices[t] # [topk]\n\n # Handle padding: -1 indicates invalid indices\n valid_mask = indices != -1\n valid_indices = indices[valid_mask]\n\n if valid_indices.numel() == 0:\n output[t].zero_()\n continue\n\n # For page_size=64, indices encode (page_idx * 64 + offset)\n tok_idx = valid_indices.to(torch.long)\n\n Kc = Kc_all[tok_idx] # [num_valid, head_dim_ckv]\n Kp = Kp_all[tok_idx] # [num_valid, head_dim_kpe]\n qn = q_nope[t].to(torch.float32) # [num_qo_heads, head_dim_ckv]\n qp = q_pe[t].to(torch.float32) # [num_qo_heads, head_dim_kpe]\n\n # Compute attention logits\n logits = (qn @ Kc.T) + (qp @ Kp.T) # [num_qo_heads, num_valid]\n logits_scaled = logits * sm_scale\n\n # Compute 2-base LSE\n lse[t] = torch.logsumexp(logits_scaled, dim=-1) / math.log(2.0)\n\n # Compute attention output\n attn = torch.softmax(logits_scaled, dim=-1) # [num_qo_heads, num_valid]\n out = attn @ Kc # [num_qo_heads, head_dim_ckv]\n output[t] = out.to(torch.bfloat16)\n\n return output, lse" | ||
| } |
| } | ||
| }, | ||
| "reference": "import torch\n\n\ndef dequant_fp8_kv_cache(k_index_cache_fp8):\n \"\"\"Dequantize FP8 KV cache from deep_gemm format.\n \n Input: [num_pages, page_size, 1, 132] int8 (interpreted as uint8)\n Memory layout (per page): [fp8_data (page_size * 128 bytes), scales (page_size * 4 bytes)]\n After view to [num_pages, page_size, 1, 132]: NOT directly indexable as [fp8, scale] per token!\n Output: [num_pages, page_size, 128] float32\n \"\"\"\n # View as uint8 for correct byte interpretation\n k_index_cache_fp8 = k_index_cache_fp8.view(torch.uint8)\n num_pages, page_size, num_heads, head_dim_sf = k_index_cache_fp8.shape\n head_dim = head_dim_sf - 4 # 128\n \n # Go back to flat format to reverse the packing\n kv_flat = k_index_cache_fp8.view(num_pages, page_size * head_dim_sf)\n \n # FP8 part: first page_size * head_dim bytes\n fp8_bytes = kv_flat[:, :page_size * head_dim].contiguous()\n fp8_tensor = fp8_bytes.view(num_pages, page_size, head_dim).view(torch.float8_e4m3fn)\n fp8_float = fp8_tensor.to(torch.float32)\n \n # Scale part: last page_size * 4 bytes -> page_size float32 values\n scale_bytes = kv_flat[:, page_size * head_dim:].contiguous()\n scale = scale_bytes.view(num_pages, page_size, 4).view(torch.float32) # [num_pages, page_size, 1]\n \n return fp8_float * scale\n\n\n@torch.no_grad()\ndef run(q_index_fp8, k_index_cache_fp8, weights, seq_lens, block_table):\n batch_size, num_index_heads, index_head_dim = q_index_fp8.shape\n num_pages, page_size, _, _ = k_index_cache_fp8.shape\n topk = 256\n\n # Check constants\n assert num_index_heads == 64\n assert index_head_dim == 128\n assert page_size == 64\n\n device = q_index_fp8.device\n\n # Dequantize inputs\n q = q_index_fp8.to(torch.float32) # [batch, heads, head_dim]\n K_all = dequant_fp8_kv_cache(k_index_cache_fp8) # [num_pages, page_size, head_dim]\n\n topk_indices = torch.full((batch_size, topk), -1, dtype=torch.int32, device=device)\n max_num_pages = block_table.shape[1]\n\n for b in range(batch_size):\n seq_len = int(seq_lens[b].item())\n \n if seq_len == 0:\n continue\n\n # Get pages for this sequence\n num_pages_for_seq = (seq_len + page_size - 1) // page_size\n page_indices = block_table[b, :num_pages_for_seq].to(torch.long)\n \n # Gather K from pages\n K_paged = K_all[page_indices] # [num_pages_for_seq, page_size, head_dim]\n K = K_paged.reshape(-1, index_head_dim)[:seq_len] # [seq_len, head_dim]\n \n # Query for this batch element\n q_b = q[b] # [num_heads, head_dim]\n \n # Compute attention scores\n scores = q_b @ K.T # [num_heads, seq_len]\n \n # Apply ReLU (deep_gemm uses ReLU activation)\n scores_relu = torch.relu(scores) # [num_heads, seq_len]\n \n # Apply learned weights and sum across heads\n w = weights[b] # [num_heads]\n weighted_scores = scores_relu * w[:, None] # [num_heads, seq_len]\n final_scores = weighted_scores.sum(dim=0) # [seq_len]\n \n # Select top-K\n actual_topk = min(topk, seq_len)\n _, topk_idx = torch.topk(final_scores, actual_topk)\n \n # Convert to global token indices\n # Token index = page_idx * page_size + offset_in_page\n page_idx_per_token = topk_idx // page_size\n offset_per_token = topk_idx % page_size\n global_page_idx = page_indices[page_idx_per_token]\n topk_tokens = global_page_idx * page_size + offset_per_token\n \n topk_indices[b, :actual_topk] = topk_tokens.to(torch.int32)\n\n return {\"topk_indices\": topk_indices}" | ||
| "reference": "import torch\n\n\ndef dequant_fp8_kv_cache(k_index_cache_fp8):\n \"\"\"Dequantize FP8 KV cache from deep_gemm format.\n \n Input: [num_pages, page_size, 1, 132] int8 (interpreted as uint8)\n Memory layout (per page): [fp8_data (page_size * 128 bytes), scales (page_size * 4 bytes)]\n After view to [num_pages, page_size, 1, 132]: NOT directly indexable as [fp8, scale] per token!\n Output: [num_pages, page_size, 128] float32\n \"\"\"\n # View as uint8 for correct byte interpretation\n k_index_cache_fp8 = k_index_cache_fp8.view(torch.uint8)\n num_pages, page_size, num_heads, head_dim_sf = k_index_cache_fp8.shape\n head_dim = head_dim_sf - 4 # 128\n \n # Go back to flat format to reverse the packing\n kv_flat = k_index_cache_fp8.view(num_pages, page_size * head_dim_sf)\n \n # FP8 part: first page_size * head_dim bytes\n fp8_bytes = kv_flat[:, :page_size * head_dim].contiguous()\n fp8_tensor = fp8_bytes.view(num_pages, page_size, head_dim).view(torch.float8_e4m3fn)\n fp8_float = fp8_tensor.to(torch.float32)\n \n # Scale part: last page_size * 4 bytes -> page_size float32 values\n scale_bytes = kv_flat[:, page_size * head_dim:].contiguous()\n scale = scale_bytes.view(num_pages, page_size, 4).view(torch.float32) # [num_pages, page_size, 1]\n \n return fp8_float * scale\n\n\n@torch.no_grad()\ndef run(q_index_fp8, k_index_cache_fp8, weights, seq_lens, block_table):\n batch_size, num_index_heads, index_head_dim = q_index_fp8.shape\n num_pages, page_size, _, _ = k_index_cache_fp8.shape\n topk = 256\n\n # Check constants\n assert num_index_heads == 64\n assert index_head_dim == 128\n assert page_size == 64\n\n device = q_index_fp8.device\n\n # Dequantize inputs\n q = q_index_fp8.to(torch.float32) # [batch, heads, head_dim]\n K_all = dequant_fp8_kv_cache(k_index_cache_fp8) # [num_pages, page_size, head_dim]\n\n topk_indices = torch.full((batch_size, topk), -1, dtype=torch.int32, device=device)\n max_num_pages = block_table.shape[1]\n\n for b in range(batch_size):\n seq_len = int(seq_lens[b].item())\n \n if seq_len == 0:\n continue\n\n # Get pages for this sequence\n num_pages_for_seq = (seq_len + page_size - 1) // page_size\n page_indices = block_table[b, :num_pages_for_seq].to(torch.long)\n \n # Gather K from pages\n K_paged = K_all[page_indices] # [num_pages_for_seq, page_size, head_dim]\n K = K_paged.reshape(-1, index_head_dim)[:seq_len] # [seq_len, head_dim]\n \n # Query for this batch element\n q_b = q[b] # [num_heads, head_dim]\n \n # Compute attention scores\n scores = q_b @ K.T # [num_heads, seq_len]\n \n # Apply ReLU (deep_gemm uses ReLU activation)\n scores_relu = torch.relu(scores) # [num_heads, seq_len]\n \n # Apply learned weights and sum across heads\n w = weights[b] # [num_heads]\n weighted_scores = scores_relu * w[:, None] # [num_heads, seq_len]\n final_scores = weighted_scores.sum(dim=0) # [seq_len]\n \n # Select top-K\n actual_topk = min(topk, seq_len)\n _, topk_idx = torch.topk(final_scores, actual_topk)\n \n # Convert to global token indices\n # Token index = page_idx * page_size + offset_in_page\n page_idx_per_token = topk_idx // page_size\n offset_per_token = topk_idx % page_size\n global_page_idx = page_indices[page_idx_per_token]\n topk_tokens = global_page_idx * page_size + offset_per_token\n \n topk_indices[b, :actual_topk] = topk_tokens.to(torch.int32)\n\n return (topk_indices,)" |
| "reference": "import math\nimport torch\nimport torch.nn.functional as F\n\n\ndef matmul(a: torch.Tensor, b: torch.Tensor):\n \"\"\"Float32 matmul for numerical stability.\"\"\"\n return a.float() @ b.float()\n\n\n@torch.no_grad()\ndef run(q, k, v, state, A_log, a, dt_bias, b, scale):\n \"\"\"\n Gated Delta Net decode reference implementation (k-last layout).\n \n State layout: [B, H, V, K] (k-last, K dimension at the end)\n \n Gate computation:\n g = exp(-exp(A_log) * softplus(a + dt_bias))\n beta = sigmoid(b)\n \n Delta rule update:\n state_new = g * state_old + k^T @ (beta * v + (1-beta) * k @ state_old) - k^T @ (k @ state_old)\n output = scale * q @ state_new\n \"\"\"\n B, T, num_q_heads, K = q.shape\n _, _, num_k_heads, _ = k.shape\n _, _, num_v_heads, V = v.shape\n num_heads = num_v_heads\n device = q.device\n \n assert num_q_heads == 16\n assert num_k_heads == 16\n assert num_v_heads == 32\n assert K == 128 and V == 128\n assert T == 1\n \n if scale is None or scale == 0.0:\n scale = 1.0 / math.sqrt(K)\n \n # Compute g and beta from raw parameters\n x = a.float() + dt_bias.float() # [B, 1, HV]\n g = torch.exp(-torch.exp(A_log.float()) * F.softplus(x)) # [B, 1, HV]\n beta = torch.sigmoid(b.float()) # [B, 1, HV]\n \n q_f32 = q.squeeze(1).float()\n k_f32 = k.squeeze(1).float()\n v_f32 = v.squeeze(1).float()\n g_f32 = g.squeeze(1).float()\n beta_f32 = beta.squeeze(1).float()\n \n if state is not None:\n state_f32 = state.float()\n else:\n state_f32 = torch.zeros(B, num_heads, V, K, dtype=torch.float32, device=device)\n \n q_exp = q_f32.repeat_interleave(num_v_heads // num_q_heads, dim=1)\n k_exp = k_f32.repeat_interleave(num_v_heads // num_k_heads, dim=1)\n \n new_state = torch.zeros_like(state_f32)\n output = torch.zeros(B, num_heads, V, dtype=torch.float32, device=device)\n \n for b_idx in range(B):\n for h_idx in range(num_heads):\n q_h = q_exp[b_idx, h_idx]\n k_h = k_exp[b_idx, h_idx]\n v_h = v_f32[b_idx, h_idx]\n h_state = state_f32[b_idx, h_idx].clone().transpose(-1, -2) # [V,K] -> [K,V]\n g_val = g_f32[b_idx, h_idx]\n beta_val = beta_f32[b_idx, h_idx]\n \n old_state = g_val * h_state\n old_v = k_h @ old_state\n new_v = beta_val * v_h + (1 - beta_val) * old_v\n state_remove = k_h.unsqueeze(1) @ old_v.unsqueeze(0)\n state_update = k_h.unsqueeze(1) @ new_v.unsqueeze(0)\n h_state = old_state - state_remove + state_update\n \n output[b_idx, h_idx] = scale * (q_h @ h_state)\n new_state[b_idx, h_idx] = h_state.transpose(-1, -2) # [K,V] -> [V,K]\n \n output = output.unsqueeze(1).to(torch.bfloat16)\n return output, new_state" | ||
| } |
| "reference": "import math\nimport torch\nimport torch.nn.functional as F\n\n\ndef matmul(a: torch.Tensor, b: torch.Tensor):\n \"\"\"Float32 matmul for numerical stability.\"\"\"\n return a.float() @ b.float()\n\n\n@torch.no_grad()\ndef run(q, k, v, state, A_log, a, dt_bias, b, cu_seqlens, scale):\n \"\"\"\n Gated Delta Net prefill reference implementation (k-last layout).\n \n State layout: [H, V, K] (k-last, K dimension at the end)\n \n Gate computation:\n g = exp(-exp(A_log) * softplus(a + dt_bias))\n beta = sigmoid(b)\n \n Delta rule update:\n state_new = g * state_old + k^T @ (beta * v + (1-beta) * k @ state_old) - k^T @ (k @ state_old)\n output = scale * q @ state_new\n \"\"\"\n total_seq_len, num_q_heads, head_size = q.shape\n num_v_heads = v.shape[1]\n num_k_heads = k.shape[1]\n num_sab_heads = max(num_q_heads, num_v_heads)\n num_seqs = cu_seqlens.size(0) - 1\n device = q.device\n\n assert num_q_heads == 16\n assert num_k_heads == 16\n assert num_v_heads == 32\n assert head_size == 128\n\n if scale is None or scale == 0.0:\n scale = 1.0 / math.sqrt(head_size)\n\n # Compute g and beta from raw parameters\n x = a.float() + dt_bias.float() # [total_seq_len, HV]\n g = torch.exp(-torch.exp(A_log.float()) * F.softplus(x)) # [total_seq_len, HV]\n beta = torch.sigmoid(b.float()) # [total_seq_len, HV]\n\n q_exp = q.repeat_interleave(num_v_heads // num_q_heads, dim=1)\n k_exp = k.repeat_interleave(num_v_heads // num_k_heads, dim=1)\n\n output = torch.zeros(\n (total_seq_len, num_sab_heads, head_size), dtype=torch.bfloat16, device=device\n )\n new_state = torch.zeros(\n (num_seqs, num_sab_heads, head_size, head_size), dtype=torch.float32, device=device\n )\n\n for seq_idx in range(num_seqs):\n seq_start = int(cu_seqlens[seq_idx].item())\n seq_end = int(cu_seqlens[seq_idx + 1].item())\n seq_len = seq_end - seq_start\n\n if seq_len <= 0:\n continue\n\n if state is not None:\n state_HKV = state[seq_idx].clone().float().transpose(-1, -2) # [H,V,K] -> [H,K,V]\n else:\n state_HKV = torch.zeros(\n (num_sab_heads, head_size, head_size), dtype=torch.float32, device=device\n )\n\n for i in range(seq_len):\n t = seq_start + i\n q_H1K = q_exp[t].unsqueeze(1).float()\n k_H1K = k_exp[t].unsqueeze(1).float()\n v_H1V = v[t].unsqueeze(1).float()\n g_H11 = g[t].unsqueeze(1).unsqueeze(2)\n beta_H11 = beta[t].unsqueeze(1).unsqueeze(2)\n\n old_state_HKV = g_H11 * state_HKV\n old_v_H1V = matmul(k_H1K, old_state_HKV)\n new_v_H1V = beta_H11 * v_H1V + (1 - beta_H11) * old_v_H1V\n state_remove = torch.einsum('hkl,hlv->hkv', k_H1K.transpose(-1, -2), old_v_H1V)\n state_update = torch.einsum('hkl,hlv->hkv', k_H1K.transpose(-1, -2), new_v_H1V)\n state_HKV = old_state_HKV - state_remove + state_update\n\n o_H1V = scale * matmul(q_H1K, state_HKV)\n output[t] = o_H1V.squeeze(1).to(torch.bfloat16)\n\n new_state[seq_idx] = state_HKV.transpose(-1, -2) # [H,K,V] -> [H,V,K]\n\n return output, new_state" | ||
| } |
Ubospica left a comment
LGTM. Thanks!
| Operation Type | Description | Example |
|---------------|-------------|---------|
| `rmsnorm` | RMS Layer Normalization | `rmsnorm_h4096` |
| `gemm` | General Matrix Multiplication | `gemm_n_6144_k_4096` |
Do we actually need to embed the exact op types in CLAUDE.md? I think we can just provide a general introduction here.
Good point, these were automatic edits by Claude Code itself; we should make CLAUDE.md simpler (not necessarily in this PR).
This PR implements these features:
Summary by CodeRabbit
Release Notes
New Features
Improvements
Tests