
Conversation

@yzh119 yzh119 (Contributor) commented Jan 23, 2026

This PR implements these features:

  1. Add the missing op schema documentation for DSA.
  2. Unify the DSA decode and prefill definitions.
  3. Change the return type of the DSA and GDN references from dict to tuple.
  4. Fix unit test tolerances.

Summary by CodeRabbit

Release Notes

  • New Features

    • Added DeepSeek Sparse Attention (DSA) paged memory layout documentation
    • Added optional scale parameter to GDN operations
  • Improvements

    • DSA sparse attention now supports both prefill and decode stages
    • Consolidated DSA operation variants into unified implementations
  • Tests

    • Reorganized and simplified test suites for DSA and GDN operations
    • Removed deprecated test configurations


@coderabbitai coderabbitai bot commented Jan 23, 2026

📝 Walkthrough

This PR refactors DeepSeek Sparse Attention (DSA) definitions and tests to unify prefill/decode operations under a single "attention" naming convention, converts reference implementations from dictionary to tuple returns, adds scale parameters to GDN definitions, and removes stage-specific test variants while introducing consolidated test modules with updated reference imports and assertions.
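
To make the return-type change concrete, here is a minimal, self-contained sketch of the old dict-keyed convention versus the new positional-tuple convention used by the reference run() functions. It is not code from the repository; the tensor math inside each function is a placeholder.

    import torch

    def run_dict(q: torch.Tensor):
        # Old convention: outputs keyed by name.
        output = q * 2.0
        lse = q.float().sum(dim=-1)
        return {"output": output, "lse": lse}

    def run_tuple(q: torch.Tensor):
        # New convention: outputs returned positionally, in definition order.
        output = q * 2.0
        lse = q.float().sum(dim=-1)
        return output, lse

    q = torch.randn(4, 16)
    out, lse = run_tuple(q)          # new callers unpack the tuple
    out_old = run_dict(q)["output"]  # old callers indexed by key

The updated tests follow the same pattern, unpacking ref_output, ref_lse = reference(...) instead of indexing into a dict.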

Changes

Cohort / File(s) | Summary
Documentation & Schema
CLAUDE.md, docs/op_type_schema/dsa_paged.md
Updated MLA description to "Multi-Head Latent Attention (paged)", added new DSA operation type schema with indexer and sparse_attention variants, and updated attention definitions to include DSA mappings.
DSA Definition Updates (ps1)
flashinfer_trace/definitions/dsa_paged/dsa_sparse_attention_h16_ckv512_kpe64_topk256_ps1.json
Renamed from dsa_sparse_decode_* to dsa_sparse_attention_*, updated description to support both prefill and decode stages, changed batch_size axis to num_tokens with corresponding shape/constraint updates, refactored reference implementation to iterate over tokens, removed stage:decode tag.
DSA Definition Updates (ps64)
flashinfer_trace/definitions/dsa_paged/dsa_sparse_attention_h16_ckv512_kpe64_topk256_ps64.json, dsa_topk_indexer_fp8_h64_d128_topk256_ps64.json
Similar renaming and num_tokens refactor as ps1; topk_indexer updated to return tuple (topk_indices) instead of dict {"topk_indices": topk_indices}.
DSA Definition Removals
flashinfer_trace/definitions/dsa_paged/dsa_sparse_prefill_causal_*_ps{1,64}.json
Removed stage-specific prefill definitions (123 and 123 lines deleted respectively).
GDN Definition Updates
flashinfer_trace/definitions/gdn/gdn_decode_qk16_v32_d128_k_last.json, gdn_prefill_qk16_v32_d128_k_last.json
Added optional scale input parameter (shape: null, dtype: float32), updated reference implementations to return tuples (output, new_state) instead of dicts with "output" and "new_state" keys (see the scale-handling sketch after this table).
DSA Test Removals
flashinfer_trace/tests/references/test_dsa_sparse_decode_h16_ckv512_kpe64_topk256_ps{1,64}.py, test_dsa_sparse_prefill_causal_h16_ckv512_kpe64_topk256_ps{1,64}.py
Deleted comprehensive test modules with reference implementations, input generation, and SGLang comparisons (565, 588, 497, 521 lines removed respectively).
DSA Test Additions
flashinfer_trace/tests/references/test_dsa_sparse_attention_h16_ckv512_kpe64_topk256_ps{1,64}.py
Added new unified test modules with simplified reference implementations, random input generators, and shape/padding validation tests (192 and 190 lines added).
Test Reference Updates
flashinfer_trace/tests/references/test_dsa_vs_definition_reference.py, test_gdn_decode_qk16_v32_d128_k_last.py, test_gdn_prefill_qk16_v32_d128_k_last.py
Updated definition references to use new naming (dsa_sparse_attention_*), unpacked tuple returns from reference implementations instead of dict access, adjusted tolerances (atol/rtol 5e-3→1e-2), replaced SM90 capability checks with stricter SM90-only guards.
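
The optional scale input added to the GDN definitions (GDN Definition Updates cohort above) falls back to 1/sqrt(head_dim) when unset, as in the reference code quoted later in this review. A minimal sketch of that fallback (resolve_scale is a hypothetical helper name, not from the repository):

    import math
    import torch

    def resolve_scale(scale, head_dim: int) -> float:
        # Mirrors the GDN references: None (or 0.0) falls back to 1/sqrt(head_dim).
        if scale is None or scale == 0.0:
            return 1.0 / math.sqrt(head_dim)
        return scale

    q = torch.randn(2, 1, 16, 128)               # [B, T, H, K] as in the decode reference
    sm_scale = resolve_scale(None, q.shape[-1])  # -> 1 / sqrt(128)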

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes


Suggested reviewers

  • xslingcn
  • Ubospica

Poem

🐰 Whiskers twitching with delight

Sparse attention blooms so bright,
Unified from decode and prefill's plight,
Tuples replace dicts in graceful flight,
From batch to tokens—dimensions unite! 🌟

🚥 Pre-merge checks | ✅ 3
✅ Passed checks (3 passed)
  • Description Check: ✅ Passed. Check skipped; CodeRabbit's high-level summary is enabled.
  • Title Check: ✅ Passed. The title directly addresses the main changes: updating DSA and GDN definitions, operator schema, and unit tests. It accurately summarizes the primary objectives of the PR.
  • Docstring Coverage: ✅ Passed. Docstring coverage is 100.00%, which is sufficient; the required threshold is 80.00%.



@gemini-code-assist (Contributor) commented

Summary of Changes

Hello @yzh119, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly refactors the DeepSeek Sparse Attention (DSA) and Gated Delta Net (GDN) definitions and their associated tests. The core changes involve consolidating DSA's decode and prefill operations into a single, more generalized sparse attention definition, enhancing documentation, and standardizing the return types of reference implementations. These updates improve consistency and maintainability across the codebase.

Highlights

  • DSA Op Schema Documentation: Added comprehensive operation schema documentation for DeepSeek Sparse Attention (DSA) with paged memory layout, detailing its two-stage mechanism (indexer and sparse_attention), axes, inputs, outputs, and constraints (a minimal sketch of the two-stage flow follows this list).
  • DSA Decode and Prefill Unification: Unified the separate decode and prefill definitions for DSA sparse attention into a single dsa_sparse_attention operation. This involved renaming definition files, updating descriptions, and modifying axis and input/output shapes to be generic for both stages (e.g., batch_size changed to num_tokens). Corresponding prefill-specific definition files and test files were removed.
  • Reference Implementation Return Type Change: Modified the reference implementations for DSA sparse attention, DSA top-K indexer, and GDN (Gated Delta Net) operations to return outputs as tuples instead of dictionaries. This change was propagated to relevant test files.
  • GDN Unittest Tolerance Adjustment: Adjusted the absolute and relative tolerance (atol, rtol) for GDN decode unittests to 1e-2 to accommodate potential numerical differences.
  • Documentation Updates: Updated the main documentation (CLAUDE.md) to include dsa_paged as a supported op_type and clarified the description for mla_paged.
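
As a rough illustration of the two-stage DSA flow described in the highlights above (stage 1: the indexer selects the top-k KV token indices per query token, padding with -1; stage 2: sparse attention gathers only those entries), here is a small self-contained sketch using the ps1 shapes (page_size=1, topk=256, 16 query heads, 512-d compressed KV). It simplifies the JSON reference implementations quoted later in this thread and omits the kpe path, sm_scale, and LSE.

    import torch

    num_tokens, num_pages, topk = 4, 1024, 256

    # Stage 1 (indexer): score every cached KV token per query token and keep
    # the top-k global token indices, padding unused slots with -1.
    scores = torch.randn(num_tokens, num_pages)
    topk_indices = torch.full((num_tokens, topk), -1, dtype=torch.int32)
    k = min(topk, num_pages)
    topk_indices[:, :k] = scores.topk(k, dim=-1).indices.to(torch.int32)

    # Stage 2 (sparse attention): gather only the selected KV entries per token.
    ckv_cache = torch.randn(num_pages, 1, 512)   # [num_pages, page_size=1, 512]
    q_nope = torch.randn(num_tokens, 16, 512)
    output = torch.zeros(num_tokens, 16, 512)
    for t in range(num_tokens):
        valid = topk_indices[t][topk_indices[t] != -1].long()
        kc = ckv_cache[valid].squeeze(1)         # [num_valid, 512]
        logits = q_nope[t] @ kc.T                # [16, num_valid]
        output[t] = torch.softmax(logits, dim=-1) @ kc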


@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

🤖 Fix all issues with AI agents
In `@CLAUDE.md`:
- Line 79: The example in the table uses the wrong definition name; replace the
example value `dsa_sparse_decode_h16_ckv512_kpe64_topk256_ps1` with the actual
JSON definition name `dsa_sparse_attention_h16_ckv512_kpe64_topk256_ps1` (or
vice versa if the JSON should be renamed) and ensure any other occurrences of
`dsa_paged` examples reference
`dsa_sparse_attention_h16_ckv512_kpe64_topk256_ps1` so the documentation and
JSON definition names are consistent.
🧹 Nitpick comments (7)
flashinfer_trace/tests/references/test_gdn_decode_qk16_v32_d128_k_last.py (1)

232-232: Consider aligning tolerances between test_correctness and test_gdn_decode_k_last.

test_gdn_decode_k_last now uses atol=1e-2, rtol=1e-2 (line 232), but test_correctness still uses atol=5e-3, rtol=5e-3 (line 120 default parameter). This inconsistency could cause test_correctness to fail for inputs that test_gdn_decode_k_last would pass.

Consider either:

  1. Updating test_correctness to use the same relaxed tolerance (see the sketch after this list)
  2. Documenting why different tolerances are appropriate for each test
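
For the first option, a minimal sketch of sharing one tolerance between the two tests; the constant and helper names here are illustrative, not taken from the test module:

    import torch

    # Shared, relaxed tolerance for the GDN decode comparisons.
    GDN_DECODE_ATOL = 1e-2
    GDN_DECODE_RTOL = 1e-2

    def assert_gdn_close(actual: torch.Tensor, expected: torch.Tensor) -> None:
        torch.testing.assert_close(
            actual, expected, atol=GDN_DECODE_ATOL, rtol=GDN_DECODE_RTOL
        )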
docs/op_type_schema/dsa_paged.md (1)

50-50: Minor: Clarify "2-based" log-sum-exp.

The term "2-based log-sum-exp" may be unclear to readers. Consider expanding to:

- `lse`: log-sum-exp for online softmax (base-2 logarithm) [num_tokens, num_qo_heads]

Or add a brief note explaining that this uses log base 2 (common in FlashAttention implementations for numerical stability).
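
For context, the base-2 convention as it appears in the DSA reference implementations quoted later in this review, shown as a standalone sketch:

    import math
    import torch

    # logits_scaled: attention logits after multiplying by sm_scale,
    # shape [num_qo_heads, num_valid].
    logits_scaled = torch.randn(16, 256)

    # Natural-log logsumexp divided by ln(2) gives the base-2 LSE used by
    # FlashAttention-style kernels.
    lse = torch.logsumexp(logits_scaled, dim=-1) / math.log(2.0)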

flashinfer_trace/tests/references/test_dsa_vs_definition_reference.py (2)

110-110: Prefix unused ref_lse with underscore.

The ref_lse variable is unpacked but never used. Per static analysis (RUF059), prefix it with an underscore to indicate it's intentionally unused.

Proposed fix
-    ref_output, ref_lse = reference(q_nope, q_pe, ckv_cache, kpe_cache, sparse_indices, sm_scale)
+    ref_output, _ref_lse = reference(q_nope, q_pe, ckv_cache, kpe_cache, sparse_indices, sm_scale)

355-355: Prefix unused ref_lse with underscore.

Same issue as in the other test function.

Proposed fix
-    ref_output, ref_lse = reference(q_nope, q_pe, ckv_cache, kpe_cache, sparse_indices, sm_scale)
+    ref_output, _ref_lse = reference(q_nope, q_pe, ckv_cache, kpe_cache, sparse_indices, sm_scale)
flashinfer_trace/tests/references/test_dsa_sparse_attention_h16_ckv512_kpe64_topk256_ps64.py (2)

83-83: Return type inconsistent with JSON definition reference.

This run function returns a dict {"output": output, "lse": lse}, but the corresponding JSON definition (dsa_sparse_attention_h16_ckv512_kpe64_topk256_ps64.json) returns a tuple (output, lse). Consider aligning return types for consistency across the codebase, especially since the PR objective is to change return types from dict to tuple.

Proposed fix to return tuple
-    return {"output": output, "lse": lse}
+    return output, lse

And update the tests to unpack accordingly:

-    result = run(...)
-    output = result["output"]
-    lse = result["lse"]
+    output, lse = run(...)

31-31: Prefix unused num_pages with underscore.

The num_pages variable is unpacked but never used in this function. Per static analysis (RUF059), prefix it with an underscore.

Proposed fix
-    num_pages, page_size, _ = ckv_cache.shape
+    _num_pages, page_size, _ = ckv_cache.shape
flashinfer_trace/tests/references/test_dsa_sparse_attention_h16_ckv512_kpe64_topk256_ps1.py (1)

86-86: Return type inconsistent with JSON definition reference.

Same issue as the ps64 variant. This run function returns a dict, but the JSON definition returns a tuple. Consider aligning for consistency with the PR's objective to change return types from dict to tuple.

Proposed fix to return tuple
-    return {"output": output, "lse": lse}
+    return output, lse

And update the tests to unpack accordingly.

| `gqa_paged` | Group Query Attention (paged) | `gqa_paged_decode_h32_kv8_d128_ps1` |
| `mla_paged` | Multi-Head Linear Attention | `mla_paged_decode_h16_ckv512_kpe64_ps1` |
| `mla_paged` | Multi-Head Latent Attention (paged) | `mla_paged_decode_h16_ckv512_kpe64_ps1` |
| `dsa_paged` | DeepSeek Sparse Attention (paged) | `dsa_sparse_decode_h16_ckv512_kpe64_topk256_ps1` |

⚠️ Potential issue | 🟡 Minor

Documentation example doesn't match actual definition name.

The example shows dsa_sparse_decode_h16_ckv512_kpe64_topk256_ps1 but the actual JSON definition is named dsa_sparse_attention_h16_ckv512_kpe64_topk256_ps1. This inconsistency will confuse users.

Proposed fix
-| `dsa_paged` | DeepSeek Sparse Attention (paged) | `dsa_sparse_decode_h16_ckv512_kpe64_topk256_ps1` |
+| `dsa_paged` | DeepSeek Sparse Attention (paged) | `dsa_sparse_attention_h16_ckv512_kpe64_topk256_ps1` |
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
| `dsa_paged` | DeepSeek Sparse Attention (paged) | `dsa_sparse_decode_h16_ckv512_kpe64_topk256_ps1` |
| `dsa_paged` | DeepSeek Sparse Attention (paged) | `dsa_sparse_attention_h16_ckv512_kpe64_topk256_ps1` |
🤖 Prompt for AI Agents
In `@CLAUDE.md` at line 79, The example in the table uses the wrong definition
name; replace the example value `dsa_sparse_decode_h16_ckv512_kpe64_topk256_ps1`
with the actual JSON definition name
`dsa_sparse_attention_h16_ckv512_kpe64_topk256_ps1` (or vice versa if the JSON
should be renamed) and ensure any other occurrences of `dsa_paged` examples
reference `dsa_sparse_attention_h16_ckv512_kpe64_topk256_ps1` so the
documentation and JSON definition names are consistent.

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request successfully implements the unification of DSA decode and prefill operations into a single dsa_sparse_attention definition, along with the addition of its missing op schema documentation. The return types for the reference implementations of DSA and GDN operations have been consistently changed from dictionaries to tuples, improving API consistency. Furthermore, the unittest tolerances for GDN operations have been adjusted, and the GDN prefill tests now correctly specify the required SM90 architecture. Overall, the changes enhance the clarity, correctness, and maintainability of the codebase.
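
The SM90 requirement mentioned above could be expressed with a guard along these lines; this is an illustrative sketch, not the exact decorator used in the test files:

    import pytest
    import torch

    requires_sm90 = pytest.mark.skipif(
        not torch.cuda.is_available()
        or torch.cuda.get_device_capability() != (9, 0),
        reason="GDN prefill reference comparison requires SM90 (Hopper)",
    )

    @requires_sm90
    def test_gdn_prefill_reference():
        ...  # body elided; the real tests compare against the JSON reference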

},
"reference": "import math\nimport torch\n\n\n@torch.no_grad()\ndef run(q_nope, q_pe, ckv_cache, kpe_cache, sparse_indices, sm_scale):\n batch_size, num_qo_heads, head_dim_ckv = q_nope.shape\n head_dim_kpe = q_pe.shape[-1]\n page_size = ckv_cache.shape[1]\n topk = sparse_indices.shape[-1]\n\n # Check constants\n assert num_qo_heads == 16\n assert head_dim_ckv == 512\n assert head_dim_kpe == 64\n assert page_size == 1\n assert topk == 256\n\n # Check constraints\n assert sparse_indices.shape[-1] == topk\n assert ckv_cache.shape[1] == page_size\n\n device = q_nope.device\n\n # Squeeze page dimension (page_size=1)\n Kc_all = ckv_cache.squeeze(1).to(torch.float32) # [num_pages, head_dim_ckv]\n Kp_all = kpe_cache.squeeze(1).to(torch.float32) # [num_pages, head_dim_kpe]\n\n output = torch.zeros(\n (batch_size, num_qo_heads, head_dim_ckv), dtype=torch.bfloat16, device=device\n )\n lse = torch.full((batch_size, num_qo_heads), -float(\"inf\"), dtype=torch.float32, device=device)\n\n for b in range(batch_size):\n indices = sparse_indices[b] # [topk]\n\n # Handle padding: -1 indicates invalid indices\n valid_mask = indices != -1\n valid_indices = indices[valid_mask]\n\n if valid_indices.numel() == 0:\n output[b].zero_()\n continue\n\n tok_idx = valid_indices.to(torch.long)\n\n Kc = Kc_all[tok_idx] # [num_valid, head_dim_ckv]\n Kp = Kp_all[tok_idx] # [num_valid, head_dim_kpe]\n qn = q_nope[b].to(torch.float32) # [num_qo_heads, head_dim_ckv]\n qp = q_pe[b].to(torch.float32) # [num_qo_heads, head_dim_kpe]\n\n # Compute attention logits\n logits = (qn @ Kc.T) + (qp @ Kp.T) # [num_qo_heads, num_valid]\n logits_scaled = logits * sm_scale\n\n # Compute 2-base LSE\n lse[b] = torch.logsumexp(logits_scaled, dim=-1) / math.log(2.0)\n\n # Compute attention output\n attn = torch.softmax(logits_scaled, dim=-1) # [num_qo_heads, num_valid]\n out = attn @ Kc # [num_qo_heads, head_dim_ckv]\n output[b] = out.to(torch.bfloat16)\n\n return output, lse"
"reference": "import math\nimport torch\n\n\n@torch.no_grad()\ndef run(q_nope, q_pe, ckv_cache, kpe_cache, sparse_indices, sm_scale):\n num_tokens, num_qo_heads, head_dim_ckv = q_nope.shape\n head_dim_kpe = q_pe.shape[-1]\n page_size = ckv_cache.shape[1]\n topk = sparse_indices.shape[-1]\n\n # Check constants\n assert num_qo_heads == 16\n assert head_dim_ckv == 512\n assert head_dim_kpe == 64\n assert page_size == 1\n assert topk == 256\n\n # Check constraints\n assert sparse_indices.shape[0] == num_tokens\n assert sparse_indices.shape[-1] == topk\n assert ckv_cache.shape[1] == page_size\n\n device = q_nope.device\n\n # Squeeze page dimension (page_size=1)\n Kc_all = ckv_cache.squeeze(1).to(torch.float32) # [num_pages, head_dim_ckv]\n Kp_all = kpe_cache.squeeze(1).to(torch.float32) # [num_pages, head_dim_kpe]\n\n output = torch.zeros(\n (num_tokens, num_qo_heads, head_dim_ckv), dtype=torch.bfloat16, device=device\n )\n lse = torch.full((num_tokens, num_qo_heads), -float(\"inf\"), dtype=torch.float32, device=device)\n\n for t in range(num_tokens):\n indices = sparse_indices[t] # [topk]\n\n # Handle padding: -1 indicates invalid indices\n valid_mask = indices != -1\n valid_indices = indices[valid_mask]\n\n if valid_indices.numel() == 0:\n output[t].zero_()\n continue\n\n tok_idx = valid_indices.to(torch.long)\n\n Kc = Kc_all[tok_idx] # [num_valid, head_dim_ckv]\n Kp = Kp_all[tok_idx] # [num_valid, head_dim_kpe]\n qn = q_nope[t].to(torch.float32) # [num_qo_heads, head_dim_ckv]\n qp = q_pe[t].to(torch.float32) # [num_qo_heads, head_dim_kpe]\n\n # Compute attention logits\n logits = (qn @ Kc.T) + (qp @ Kp.T) # [num_qo_heads, num_valid]\n logits_scaled = logits * sm_scale\n\n # Compute 2-base LSE\n lse[t] = torch.logsumexp(logits_scaled, dim=-1) / math.log(2.0)\n\n # Compute attention output\n attn = torch.softmax(logits_scaled, dim=-1) # [num_qo_heads, num_valid]\n out = attn @ Kc # [num_qo_heads, head_dim_ckv]\n output[t] = out.to(torch.bfloat16)\n\n return output, lse"
}
gemini-code-assist bot (Contributor), severity: medium

The change in the return type of the run function from a dictionary to a tuple aligns with the PR's objective to unify return types. This improves consistency across the API.

    return output, lse

},
"reference": "import math\nimport torch\n\n\n@torch.no_grad()\ndef run(q_nope, q_pe, ckv_cache, kpe_cache, sparse_indices, sm_scale):\n batch_size, num_qo_heads, head_dim_ckv = q_nope.shape\n head_dim_kpe = q_pe.shape[-1]\n num_pages, page_size, _ = ckv_cache.shape\n topk = sparse_indices.shape[-1]\n\n # Check constants\n assert num_qo_heads == 16\n assert head_dim_ckv == 512\n assert head_dim_kpe == 64\n assert page_size == 64\n assert topk == 256\n\n # Check constraints\n assert sparse_indices.shape[-1] == topk\n assert ckv_cache.shape[1] == page_size\n\n device = q_nope.device\n\n # Flatten paged KV cache to token-level: [num_pages, page_size, dim] -> [num_pages * page_size, dim]\n Kc_all = ckv_cache.reshape(-1, head_dim_ckv).to(torch.float32) # [total_kv_tokens, head_dim_ckv]\n Kp_all = kpe_cache.reshape(-1, head_dim_kpe).to(torch.float32) # [total_kv_tokens, head_dim_kpe]\n\n output = torch.zeros(\n (batch_size, num_qo_heads, head_dim_ckv), dtype=torch.bfloat16, device=device\n )\n lse = torch.full((batch_size, num_qo_heads), -float(\"inf\"), dtype=torch.float32, device=device)\n\n for b in range(batch_size):\n indices = sparse_indices[b] # [topk]\n\n # Handle padding: -1 indicates invalid indices\n valid_mask = indices != -1\n valid_indices = indices[valid_mask]\n\n if valid_indices.numel() == 0:\n output[b].zero_()\n continue\n\n # For page_size=64, indices encode (page_idx * 64 + offset)\n tok_idx = valid_indices.to(torch.long)\n\n Kc = Kc_all[tok_idx] # [num_valid, head_dim_ckv]\n Kp = Kp_all[tok_idx] # [num_valid, head_dim_kpe]\n qn = q_nope[b].to(torch.float32) # [num_qo_heads, head_dim_ckv]\n qp = q_pe[b].to(torch.float32) # [num_qo_heads, head_dim_kpe]\n\n # Compute attention logits\n logits = (qn @ Kc.T) + (qp @ Kp.T) # [num_qo_heads, num_valid]\n logits_scaled = logits * sm_scale\n\n # Compute 2-base LSE\n lse[b] = torch.logsumexp(logits_scaled, dim=-1) / math.log(2.0)\n\n # Compute attention output\n attn = torch.softmax(logits_scaled, dim=-1) # [num_qo_heads, num_valid]\n out = attn @ Kc # [num_qo_heads, head_dim_ckv]\n output[b] = out.to(torch.bfloat16)\n\n return {\"output\": output, \"lse\": lse}"
"reference": "import math\nimport torch\n\n\n@torch.no_grad()\ndef run(q_nope, q_pe, ckv_cache, kpe_cache, sparse_indices, sm_scale):\n num_tokens, num_qo_heads, head_dim_ckv = q_nope.shape\n head_dim_kpe = q_pe.shape[-1]\n num_pages, page_size, _ = ckv_cache.shape\n topk = sparse_indices.shape[-1]\n\n # Check constants\n assert num_qo_heads == 16\n assert head_dim_ckv == 512\n assert head_dim_kpe == 64\n assert page_size == 64\n assert topk == 256\n\n # Check constraints\n assert sparse_indices.shape[0] == num_tokens\n assert sparse_indices.shape[-1] == topk\n assert ckv_cache.shape[1] == page_size\n\n device = q_nope.device\n\n # Flatten paged KV cache to token-level: [num_pages, page_size, dim] -> [num_pages * page_size, dim]\n Kc_all = ckv_cache.reshape(-1, head_dim_ckv).to(torch.float32) # [total_kv_tokens, head_dim_ckv]\n Kp_all = kpe_cache.reshape(-1, head_dim_kpe).to(torch.float32) # [total_kv_tokens, head_dim_kpe]\n\n output = torch.zeros(\n (num_tokens, num_qo_heads, head_dim_ckv), dtype=torch.bfloat16, device=device\n )\n lse = torch.full((num_tokens, num_qo_heads), -float(\"inf\"), dtype=torch.float32, device=device)\n\n for t in range(num_tokens):\n indices = sparse_indices[t] # [topk]\n\n # Handle padding: -1 indicates invalid indices\n valid_mask = indices != -1\n valid_indices = indices[valid_mask]\n\n if valid_indices.numel() == 0:\n output[t].zero_()\n continue\n\n # For page_size=64, indices encode (page_idx * 64 + offset)\n tok_idx = valid_indices.to(torch.long)\n\n Kc = Kc_all[tok_idx] # [num_valid, head_dim_ckv]\n Kp = Kp_all[tok_idx] # [num_valid, head_dim_kpe]\n qn = q_nope[t].to(torch.float32) # [num_qo_heads, head_dim_ckv]\n qp = q_pe[t].to(torch.float32) # [num_qo_heads, head_dim_kpe]\n\n # Compute attention logits\n logits = (qn @ Kc.T) + (qp @ Kp.T) # [num_qo_heads, num_valid]\n logits_scaled = logits * sm_scale\n\n # Compute 2-base LSE\n lse[t] = torch.logsumexp(logits_scaled, dim=-1) / math.log(2.0)\n\n # Compute attention output\n attn = torch.softmax(logits_scaled, dim=-1) # [num_qo_heads, num_valid]\n out = attn @ Kc # [num_qo_heads, head_dim_ckv]\n output[t] = out.to(torch.bfloat16)\n\n return output, lse"
}
gemini-code-assist bot (Contributor), severity: medium

The change in the return type of the run function from a dictionary to a tuple aligns with the PR's objective to unify return types. This improves consistency across the API.

    return output, lse

}
},
"reference": "import torch\n\n\ndef dequant_fp8_kv_cache(k_index_cache_fp8):\n \"\"\"Dequantize FP8 KV cache from deep_gemm format.\n \n Input: [num_pages, page_size, 1, 132] int8 (interpreted as uint8)\n Memory layout (per page): [fp8_data (page_size * 128 bytes), scales (page_size * 4 bytes)]\n After view to [num_pages, page_size, 1, 132]: NOT directly indexable as [fp8, scale] per token!\n Output: [num_pages, page_size, 128] float32\n \"\"\"\n # View as uint8 for correct byte interpretation\n k_index_cache_fp8 = k_index_cache_fp8.view(torch.uint8)\n num_pages, page_size, num_heads, head_dim_sf = k_index_cache_fp8.shape\n head_dim = head_dim_sf - 4 # 128\n \n # Go back to flat format to reverse the packing\n kv_flat = k_index_cache_fp8.view(num_pages, page_size * head_dim_sf)\n \n # FP8 part: first page_size * head_dim bytes\n fp8_bytes = kv_flat[:, :page_size * head_dim].contiguous()\n fp8_tensor = fp8_bytes.view(num_pages, page_size, head_dim).view(torch.float8_e4m3fn)\n fp8_float = fp8_tensor.to(torch.float32)\n \n # Scale part: last page_size * 4 bytes -> page_size float32 values\n scale_bytes = kv_flat[:, page_size * head_dim:].contiguous()\n scale = scale_bytes.view(num_pages, page_size, 4).view(torch.float32) # [num_pages, page_size, 1]\n \n return fp8_float * scale\n\n\n@torch.no_grad()\ndef run(q_index_fp8, k_index_cache_fp8, weights, seq_lens, block_table):\n batch_size, num_index_heads, index_head_dim = q_index_fp8.shape\n num_pages, page_size, _, _ = k_index_cache_fp8.shape\n topk = 256\n\n # Check constants\n assert num_index_heads == 64\n assert index_head_dim == 128\n assert page_size == 64\n\n device = q_index_fp8.device\n\n # Dequantize inputs\n q = q_index_fp8.to(torch.float32) # [batch, heads, head_dim]\n K_all = dequant_fp8_kv_cache(k_index_cache_fp8) # [num_pages, page_size, head_dim]\n\n topk_indices = torch.full((batch_size, topk), -1, dtype=torch.int32, device=device)\n max_num_pages = block_table.shape[1]\n\n for b in range(batch_size):\n seq_len = int(seq_lens[b].item())\n \n if seq_len == 0:\n continue\n\n # Get pages for this sequence\n num_pages_for_seq = (seq_len + page_size - 1) // page_size\n page_indices = block_table[b, :num_pages_for_seq].to(torch.long)\n \n # Gather K from pages\n K_paged = K_all[page_indices] # [num_pages_for_seq, page_size, head_dim]\n K = K_paged.reshape(-1, index_head_dim)[:seq_len] # [seq_len, head_dim]\n \n # Query for this batch element\n q_b = q[b] # [num_heads, head_dim]\n \n # Compute attention scores\n scores = q_b @ K.T # [num_heads, seq_len]\n \n # Apply ReLU (deep_gemm uses ReLU activation)\n scores_relu = torch.relu(scores) # [num_heads, seq_len]\n \n # Apply learned weights and sum across heads\n w = weights[b] # [num_heads]\n weighted_scores = scores_relu * w[:, None] # [num_heads, seq_len]\n final_scores = weighted_scores.sum(dim=0) # [seq_len]\n \n # Select top-K\n actual_topk = min(topk, seq_len)\n _, topk_idx = torch.topk(final_scores, actual_topk)\n \n # Convert to global token indices\n # Token index = page_idx * page_size + offset_in_page\n page_idx_per_token = topk_idx // page_size\n offset_per_token = topk_idx % page_size\n global_page_idx = page_indices[page_idx_per_token]\n topk_tokens = global_page_idx * page_size + offset_per_token\n \n topk_indices[b, :actual_topk] = topk_tokens.to(torch.int32)\n\n return {\"topk_indices\": topk_indices}"
"reference": "import torch\n\n\ndef dequant_fp8_kv_cache(k_index_cache_fp8):\n \"\"\"Dequantize FP8 KV cache from deep_gemm format.\n \n Input: [num_pages, page_size, 1, 132] int8 (interpreted as uint8)\n Memory layout (per page): [fp8_data (page_size * 128 bytes), scales (page_size * 4 bytes)]\n After view to [num_pages, page_size, 1, 132]: NOT directly indexable as [fp8, scale] per token!\n Output: [num_pages, page_size, 128] float32\n \"\"\"\n # View as uint8 for correct byte interpretation\n k_index_cache_fp8 = k_index_cache_fp8.view(torch.uint8)\n num_pages, page_size, num_heads, head_dim_sf = k_index_cache_fp8.shape\n head_dim = head_dim_sf - 4 # 128\n \n # Go back to flat format to reverse the packing\n kv_flat = k_index_cache_fp8.view(num_pages, page_size * head_dim_sf)\n \n # FP8 part: first page_size * head_dim bytes\n fp8_bytes = kv_flat[:, :page_size * head_dim].contiguous()\n fp8_tensor = fp8_bytes.view(num_pages, page_size, head_dim).view(torch.float8_e4m3fn)\n fp8_float = fp8_tensor.to(torch.float32)\n \n # Scale part: last page_size * 4 bytes -> page_size float32 values\n scale_bytes = kv_flat[:, page_size * head_dim:].contiguous()\n scale = scale_bytes.view(num_pages, page_size, 4).view(torch.float32) # [num_pages, page_size, 1]\n \n return fp8_float * scale\n\n\n@torch.no_grad()\ndef run(q_index_fp8, k_index_cache_fp8, weights, seq_lens, block_table):\n batch_size, num_index_heads, index_head_dim = q_index_fp8.shape\n num_pages, page_size, _, _ = k_index_cache_fp8.shape\n topk = 256\n\n # Check constants\n assert num_index_heads == 64\n assert index_head_dim == 128\n assert page_size == 64\n\n device = q_index_fp8.device\n\n # Dequantize inputs\n q = q_index_fp8.to(torch.float32) # [batch, heads, head_dim]\n K_all = dequant_fp8_kv_cache(k_index_cache_fp8) # [num_pages, page_size, head_dim]\n\n topk_indices = torch.full((batch_size, topk), -1, dtype=torch.int32, device=device)\n max_num_pages = block_table.shape[1]\n\n for b in range(batch_size):\n seq_len = int(seq_lens[b].item())\n \n if seq_len == 0:\n continue\n\n # Get pages for this sequence\n num_pages_for_seq = (seq_len + page_size - 1) // page_size\n page_indices = block_table[b, :num_pages_for_seq].to(torch.long)\n \n # Gather K from pages\n K_paged = K_all[page_indices] # [num_pages_for_seq, page_size, head_dim]\n K = K_paged.reshape(-1, index_head_dim)[:seq_len] # [seq_len, head_dim]\n \n # Query for this batch element\n q_b = q[b] # [num_heads, head_dim]\n \n # Compute attention scores\n scores = q_b @ K.T # [num_heads, seq_len]\n \n # Apply ReLU (deep_gemm uses ReLU activation)\n scores_relu = torch.relu(scores) # [num_heads, seq_len]\n \n # Apply learned weights and sum across heads\n w = weights[b] # [num_heads]\n weighted_scores = scores_relu * w[:, None] # [num_heads, seq_len]\n final_scores = weighted_scores.sum(dim=0) # [seq_len]\n \n # Select top-K\n actual_topk = min(topk, seq_len)\n _, topk_idx = torch.topk(final_scores, actual_topk)\n \n # Convert to global token indices\n # Token index = page_idx * page_size + offset_in_page\n page_idx_per_token = topk_idx // page_size\n offset_per_token = topk_idx % page_size\n global_page_idx = page_indices[page_idx_per_token]\n topk_tokens = global_page_idx * page_size + offset_per_token\n \n topk_indices[b, :actual_topk] = topk_tokens.to(torch.int32)\n\n return (topk_indices,)"
gemini-code-assist bot (Contributor), severity: medium

The return type of the run function has been changed from a dictionary to a tuple, which is consistent with the PR's goal of unifying return types for DSA and GDN definitions.

    return (topk_indices,)

Comment on lines +148 to 149
"reference": "import math\nimport torch\nimport torch.nn.functional as F\n\n\ndef matmul(a: torch.Tensor, b: torch.Tensor):\n \"\"\"Float32 matmul for numerical stability.\"\"\"\n return a.float() @ b.float()\n\n\n@torch.no_grad()\ndef run(q, k, v, state, A_log, a, dt_bias, b, scale):\n \"\"\"\n Gated Delta Net decode reference implementation (k-last layout).\n \n State layout: [B, H, V, K] (k-last, K dimension at the end)\n \n Gate computation:\n g = exp(-exp(A_log) * softplus(a + dt_bias))\n beta = sigmoid(b)\n \n Delta rule update:\n state_new = g * state_old + k^T @ (beta * v + (1-beta) * k @ state_old) - k^T @ (k @ state_old)\n output = scale * q @ state_new\n \"\"\"\n B, T, num_q_heads, K = q.shape\n _, _, num_k_heads, _ = k.shape\n _, _, num_v_heads, V = v.shape\n num_heads = num_v_heads\n device = q.device\n \n assert num_q_heads == 16\n assert num_k_heads == 16\n assert num_v_heads == 32\n assert K == 128 and V == 128\n assert T == 1\n \n if scale is None or scale == 0.0:\n scale = 1.0 / math.sqrt(K)\n \n # Compute g and beta from raw parameters\n x = a.float() + dt_bias.float() # [B, 1, HV]\n g = torch.exp(-torch.exp(A_log.float()) * F.softplus(x)) # [B, 1, HV]\n beta = torch.sigmoid(b.float()) # [B, 1, HV]\n \n q_f32 = q.squeeze(1).float()\n k_f32 = k.squeeze(1).float()\n v_f32 = v.squeeze(1).float()\n g_f32 = g.squeeze(1).float()\n beta_f32 = beta.squeeze(1).float()\n \n if state is not None:\n state_f32 = state.float()\n else:\n state_f32 = torch.zeros(B, num_heads, V, K, dtype=torch.float32, device=device)\n \n q_exp = q_f32.repeat_interleave(num_v_heads // num_q_heads, dim=1)\n k_exp = k_f32.repeat_interleave(num_v_heads // num_k_heads, dim=1)\n \n new_state = torch.zeros_like(state_f32)\n output = torch.zeros(B, num_heads, V, dtype=torch.float32, device=device)\n \n for b_idx in range(B):\n for h_idx in range(num_heads):\n q_h = q_exp[b_idx, h_idx]\n k_h = k_exp[b_idx, h_idx]\n v_h = v_f32[b_idx, h_idx]\n h_state = state_f32[b_idx, h_idx].clone().transpose(-1, -2) # [V,K] -> [K,V]\n g_val = g_f32[b_idx, h_idx]\n beta_val = beta_f32[b_idx, h_idx]\n \n old_state = g_val * h_state\n old_v = k_h @ old_state\n new_v = beta_val * v_h + (1 - beta_val) * old_v\n state_remove = k_h.unsqueeze(1) @ old_v.unsqueeze(0)\n state_update = k_h.unsqueeze(1) @ new_v.unsqueeze(0)\n h_state = old_state - state_remove + state_update\n \n output[b_idx, h_idx] = scale * (q_h @ h_state)\n new_state[b_idx, h_idx] = h_state.transpose(-1, -2) # [K,V] -> [V,K]\n \n output = output.unsqueeze(1).to(torch.bfloat16)\n return output, new_state"
}
gemini-code-assist bot (Contributor), severity: medium

The return type of the run function has been changed from a dictionary to a tuple, which is consistent with the PR's goal of unifying return types for DSA and GDN definitions.

    return output, new_state

Comment on lines +149 to 150
"reference": "import math\nimport torch\nimport torch.nn.functional as F\n\n\ndef matmul(a: torch.Tensor, b: torch.Tensor):\n \"\"\"Float32 matmul for numerical stability.\"\"\"\n return a.float() @ b.float()\n\n\n@torch.no_grad()\ndef run(q, k, v, state, A_log, a, dt_bias, b, cu_seqlens, scale):\n \"\"\"\n Gated Delta Net prefill reference implementation (k-last layout).\n \n State layout: [H, V, K] (k-last, K dimension at the end)\n \n Gate computation:\n g = exp(-exp(A_log) * softplus(a + dt_bias))\n beta = sigmoid(b)\n \n Delta rule update:\n state_new = g * state_old + k^T @ (beta * v + (1-beta) * k @ state_old) - k^T @ (k @ state_old)\n output = scale * q @ state_new\n \"\"\"\n total_seq_len, num_q_heads, head_size = q.shape\n num_v_heads = v.shape[1]\n num_k_heads = k.shape[1]\n num_sab_heads = max(num_q_heads, num_v_heads)\n num_seqs = cu_seqlens.size(0) - 1\n device = q.device\n\n assert num_q_heads == 16\n assert num_k_heads == 16\n assert num_v_heads == 32\n assert head_size == 128\n\n if scale is None or scale == 0.0:\n scale = 1.0 / math.sqrt(head_size)\n\n # Compute g and beta from raw parameters\n x = a.float() + dt_bias.float() # [total_seq_len, HV]\n g = torch.exp(-torch.exp(A_log.float()) * F.softplus(x)) # [total_seq_len, HV]\n beta = torch.sigmoid(b.float()) # [total_seq_len, HV]\n\n q_exp = q.repeat_interleave(num_v_heads // num_q_heads, dim=1)\n k_exp = k.repeat_interleave(num_v_heads // num_k_heads, dim=1)\n\n output = torch.zeros(\n (total_seq_len, num_sab_heads, head_size), dtype=torch.bfloat16, device=device\n )\n new_state = torch.zeros(\n (num_seqs, num_sab_heads, head_size, head_size), dtype=torch.float32, device=device\n )\n\n for seq_idx in range(num_seqs):\n seq_start = int(cu_seqlens[seq_idx].item())\n seq_end = int(cu_seqlens[seq_idx + 1].item())\n seq_len = seq_end - seq_start\n\n if seq_len <= 0:\n continue\n\n if state is not None:\n state_HKV = state[seq_idx].clone().float().transpose(-1, -2) # [H,V,K] -> [H,K,V]\n else:\n state_HKV = torch.zeros(\n (num_sab_heads, head_size, head_size), dtype=torch.float32, device=device\n )\n\n for i in range(seq_len):\n t = seq_start + i\n q_H1K = q_exp[t].unsqueeze(1).float()\n k_H1K = k_exp[t].unsqueeze(1).float()\n v_H1V = v[t].unsqueeze(1).float()\n g_H11 = g[t].unsqueeze(1).unsqueeze(2)\n beta_H11 = beta[t].unsqueeze(1).unsqueeze(2)\n\n old_state_HKV = g_H11 * state_HKV\n old_v_H1V = matmul(k_H1K, old_state_HKV)\n new_v_H1V = beta_H11 * v_H1V + (1 - beta_H11) * old_v_H1V\n state_remove = torch.einsum('hkl,hlv->hkv', k_H1K.transpose(-1, -2), old_v_H1V)\n state_update = torch.einsum('hkl,hlv->hkv', k_H1K.transpose(-1, -2), new_v_H1V)\n state_HKV = old_state_HKV - state_remove + state_update\n\n o_H1V = scale * matmul(q_H1K, state_HKV)\n output[t] = o_H1V.squeeze(1).to(torch.bfloat16)\n\n new_state[seq_idx] = state_HKV.transpose(-1, -2) # [H,K,V] -> [H,V,K]\n\n return output, new_state"
}
gemini-code-assist bot (Contributor), severity: medium

The return type of the run function has been changed from a dictionary to a tuple, which is consistent with the PR's goal of unifying return types for DSA and GDN definitions.

    return output, new_state

@Ubospica Ubospica (Collaborator) left a comment

LGTM. Thanks!

| Operation Type | Description | Example |
|---------------|-------------|---------|
| `rmsnorm` | RMS Layer Normalization | `rmsnorm_h4096` |
| `gemm` | General Matrix Multiplication | `gemm_n_6144_k_4096` |
@Ubospica Ubospica (Collaborator) commented Jan 23, 2026

Do we actually need to embed the exact op types in CLAUDE.md? I think we can just provide a general introduction here.

@yzh119 (Contributor, Author) commented

Good point. These were automatic edits by Claude Code itself; we should make CLAUDE.md simpler (not necessarily in this PR).

@yzh119 yzh119 merged commit 24f77b7 into main Jan 23, 2026
22 checks passed
@yzh119 yzh119 deleted the add-op-schema-dsa branch January 23, 2026 09:22
