4 changes: 3 additions & 1 deletion CLAUDE.md
@@ -75,7 +75,8 @@ FlashInfer-Bench supports the following op_types (corresponding to different Def
| `gemm` | General Matrix Multiplication | `gemm_n_6144_k_4096` |
@Ubospica (Collaborator) commented on Jan 23, 2026:
Do we actually need to embed the exact op types in CLAUDE.md? I think we can just provide a general introduction here.

The PR author (Contributor) replied:

Good point. These are automatic edits made by Claude Code itself; we should make CLAUDE.md simpler (not necessarily in this PR).

| `gqa_ragged` | Group Query Attention (ragged) | `gqa_ragged_prefill_causal_h32_kv8_d128` |
| `gqa_paged` | Group Query Attention (paged) | `gqa_paged_decode_h32_kv8_d128_ps1` |
| `mla_paged` | Multi-Head Linear Attention | `mla_paged_decode_h16_ckv512_kpe64_ps1` |
| `mla_paged` | Multi-Head Latent Attention (paged) | `mla_paged_decode_h16_ckv512_kpe64_ps1` |
| `dsa_paged` | DeepSeek Sparse Attention (paged) | `dsa_sparse_decode_h16_ckv512_kpe64_topk256_ps1` |

⚠️ Potential issue | 🟡 Minor

Documentation example doesn't match actual definition name.

The example shows dsa_sparse_decode_h16_ckv512_kpe64_topk256_ps1 but the actual JSON definition is named dsa_sparse_attention_h16_ckv512_kpe64_topk256_ps1. This inconsistency will confuse users.

Proposed fix
-| `dsa_paged` | DeepSeek Sparse Attention (paged) | `dsa_sparse_decode_h16_ckv512_kpe64_topk256_ps1` |
+| `dsa_paged` | DeepSeek Sparse Attention (paged) | `dsa_sparse_attention_h16_ckv512_kpe64_topk256_ps1` |
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
| `dsa_paged` | DeepSeek Sparse Attention (paged) | `dsa_sparse_decode_h16_ckv512_kpe64_topk256_ps1` |
| `dsa_paged` | DeepSeek Sparse Attention (paged) | `dsa_sparse_attention_h16_ckv512_kpe64_topk256_ps1` |
🤖 Prompt for AI Agents
In `@CLAUDE.md` at line 79, The example in the table uses the wrong definition
name; replace the example value `dsa_sparse_decode_h16_ckv512_kpe64_topk256_ps1`
with the actual JSON definition name
`dsa_sparse_attention_h16_ckv512_kpe64_topk256_ps1` (or vice versa if the JSON
should be renamed) and ensure any other occurrences of `dsa_paged` examples
reference `dsa_sparse_attention_h16_ckv512_kpe64_topk256_ps1` so the
documentation and JSON definition names are consistent.

| `gdn` | Gated Delta Net (linear attention) | `gdn_decode_qk16_v32_d128_k_last` |
| `moe` | Mixture of Experts | `moe_fp8_block_scale_ds_routing_topk8_ng8_kg4_e32_h7168_i2048` |
| `sampling` | Sampling operations | - |
@@ -168,6 +169,7 @@ Associate each module with corresponding Definitions:
- **Attention layers**:
- GQA: `gqa_paged_decode_h{num_heads}_kv{kv_heads}_d{head_dim}_ps1`
- MLA: `mla_paged_decode_h{num_heads}_ckv{ckv_dim}_kpe{kpe_dim}_ps1`
- DSA: `dsa_sparse_decode_h{num_heads}_ckv{ckv_dim}_kpe{kpe_dim}_topk{topk}_ps1` (sparse MLA)
- GDN: `gdn_decode_qk{q_heads}_v{v_heads}_d{head_dim}` (linear attention)
- **GEMM layers**: `gemm_n_{out_dim}_k_{in_dim}`
- **MoE layers**: `moe_fp8_block_scale_ds_routing_topk{topk}_ng{num_groups}_kg{group_size}_e{num_experts}_h{hidden}_i{intermediate}`
55 changes: 55 additions & 0 deletions docs/op_type_schema/dsa_paged.md
@@ -0,0 +1,55 @@
# dsa_paged

DeepSeek Sparse Attention (DSA) with paged memory layout. DSA is a two-stage sparse attention mechanism: first an indexer selects top-K relevant KV cache entries using ReLU scoring, then MLA-style attention is performed only on selected entries. This reduces attention computation from O(n) to O(k) where k << n.

Variants:
- indexer
- sparse_attention

## indexer

Computes sparse attention scores using ReLU activation and learned weights, then selects top-K KV cache indices. Uses FP8 quantization with deep_gemm format.

Axes (9 dimensions):
- `batch_size`, `max_num_pages`, `num_pages`: variable
- `num_index_heads`, `index_head_dim`, `page_size`, `topk`, `kv_cache_num_heads`, `head_dim_with_scale`: constant

Inputs (5 tensors):
- `q_index_fp8`: FP8 query for indexing [batch_size, num_index_heads, index_head_dim]
- `k_index_cache_fp8`: FP8 key index cache with scales [num_pages, page_size, kv_cache_num_heads, head_dim_with_scale]
- `weights`: learned head weights [batch_size, num_index_heads]
- `seq_lens`: sequence lengths [batch_size]
- `block_table`: page mapping [batch_size, max_num_pages]

Outputs (1 tensor):
- `topk_indices`: selected token indices [batch_size, topk], -1 indicates padding

Constraints:
- `topk <= max_num_pages * page_size`
- `num_index_heads == 64`, `index_head_dim == 128` (deep_gemm requirement)
- `head_dim_with_scale == 132` (128 + 4 scale bytes)
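
For illustration only (this sketch is not part of the schema or of this PR's files), a minimal PyTorch rendering of the indexer's selection step might look like the following. It assumes `kv_cache_num_heads == 1`, already-dequantized float `q_index`/`k_index` tensors (FP8 values decoded and the 4 trailing scale bytes stripped), and a simple per-head ReLU score combined with the learned `weights`; the deep_gemm FP8 kernel details are omitted, and the helper name is hypothetical.

```python
import torch


def indexer_topk(q_index, k_index, weights, seq_lens, block_table, page_size, topk):
    # q_index:     [batch_size, num_index_heads, index_head_dim]   (dequantized)
    # k_index:     [num_pages, page_size, index_head_dim]          (dequantized, scales stripped)
    # weights:     [batch_size, num_index_heads]
    # seq_lens:    [batch_size]
    # block_table: [batch_size, max_num_pages]
    batch_size = q_index.shape[0]
    device = q_index.device
    topk_indices = torch.full((batch_size, topk), -1, dtype=torch.int32, device=device)
    flat_k = k_index.reshape(-1, k_index.shape[-1])  # [num_pages * page_size, index_head_dim]
    for b in range(batch_size):
        seq_len = int(seq_lens[b])
        n_pages = (seq_len + page_size - 1) // page_size
        pages = block_table[b, :n_pages].long()
        # Global cache-slot ids of this sequence's tokens, in token order.
        slots = (pages.unsqueeze(-1) * page_size + torch.arange(page_size, device=device)).reshape(-1)[:seq_len]
        k = flat_k[slots].float()                                    # [seq_len, index_head_dim]
        # ReLU score per (index head, token), then a weighted sum over heads.
        scores = torch.relu(q_index[b].float() @ k.T)                # [num_index_heads, seq_len]
        scores = (weights[b].float().unsqueeze(-1) * scores).sum(0)  # [seq_len]
        k_sel = min(topk, seq_len)
        topk_indices[b, :k_sel] = slots[torch.topk(scores, k_sel).indices].to(torch.int32)
    return topk_indices  # [batch_size, topk]; -1 pads unused entries
```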

## sparse_attention

Performs MLA-style attention on the top-K selected KV entries. Works for both prefill (multiple tokens) and decode (one token per sequence); the computation is identical, only the first dimension differs.

Axes (7 dimensions):
- `num_tokens`, `num_pages`: variable
- `num_qo_heads`, `head_dim_ckv`, `head_dim_kpe`, `page_size`, `topk`: constant

Inputs (5 tensors + 1 scalar):
- `q_nope`: query without positional encoding [num_tokens, num_qo_heads, head_dim_ckv]
- `q_pe`: query positional encoding [num_tokens, num_qo_heads, head_dim_kpe]
- `ckv_cache`: compressed KV cache [num_pages, page_size, head_dim_ckv]
- `kpe_cache`: key positional encoding cache [num_pages, page_size, head_dim_kpe]
- `sparse_indices`: top-K indices per token [num_tokens, topk], -1 indicates padding
- `sm_scale`: softmax scale (scalar)

Outputs (2 tensors):
- `output`: attention output [num_tokens, num_qo_heads, head_dim_ckv]
- `lse`: 2-based log-sum-exp [num_tokens, num_qo_heads]

Constraints:
- `sparse_indices.shape[0] == num_tokens`
- `sparse_indices.shape[-1] == topk`
- `ckv_cache.shape[1] == page_size`
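
For reference, the per-token computation implemented by the reference code in the definitions below can be stated compactly. With $K_c$ and $K_p$ the `ckv_cache` and `kpe_cache` rows gathered at the token's valid (non-padding) `sparse_indices`, and $s$ the `sm_scale`:

$$
\ell = q_{\text{nope}} K_c^{\top} + q_{\text{pe}} K_p^{\top},
\qquad
o = \operatorname{softmax}(s\,\ell)\, K_c,
\qquad
\mathrm{lse} = \log_2 \sum_{j} e^{\,s\,\ell_j}
$$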
@@ -1,18 +1,16 @@
{
"name": "dsa_sparse_decode_h16_ckv512_kpe64_topk256_ps1",
"description": "Batched Native Sparse Attention (DSA) decode with sparse TopK KV cache selection. Captured from DeepSeek-V3 with tensor parallel size 8. Uses sparse indexing to select only top-K KV cache entries for attention computation.",
"name": "dsa_sparse_attention_h16_ckv512_kpe64_topk256_ps1",
"description": "Batched Native Sparse Attention (DSA) with sparse TopK KV cache selection. Captured from DeepSeek-V3 with tensor parallel size 8. Uses sparse indexing to select only top-K KV cache entries for attention computation. Works for both prefill and decode stages.",
"op_type": "dsa_paged",
"tags": [
"stage:decode",
"status:verified",
"model:deepseek-v3",
"model:deepseek-r1",
"model:deepseek-v3.2",
"sparse:topk"
],
"axes": {
"batch_size": {
"num_tokens": {
"type": "var",
"description": "Batch size (number of sequences)."
"description": "Number of tokens (batch_size for decode, total_num_tokens for prefill)."
},
"num_qo_heads": {
"type": "const",
@@ -45,13 +43,14 @@
}
},
"constraints": [
"sparse_indices.shape[0] == num_tokens",
"sparse_indices.shape[-1] == topk",
"ckv_cache.shape[1] == page_size"
],
"inputs": {
"q_nope": {
"shape": [
"batch_size",
"num_tokens",
"num_qo_heads",
"head_dim_ckv"
],
@@ -60,7 +59,7 @@
},
"q_pe": {
"shape": [
"batch_size",
"num_tokens",
"num_qo_heads",
"head_dim_kpe"
],
@@ -87,11 +86,11 @@
},
"sparse_indices": {
"shape": [
"batch_size",
"num_tokens",
"topk"
],
"dtype": "int32",
"description": "Sparse indices selecting top-K KV cache entries for each batch element. Values of -1 indicate padding (invalid indices)."
"description": "Sparse indices selecting top-K KV cache entries for each token. Values of -1 indicate padding (invalid indices)."
},
"sm_scale": {
"shape": null,
@@ -102,20 +101,20 @@
"outputs": {
"output": {
"shape": [
"batch_size",
"num_tokens",
"num_qo_heads",
"head_dim_ckv"
],
"dtype": "bfloat16"
},
"lse": {
"shape": [
"batch_size",
"num_tokens",
"num_qo_heads"
],
"dtype": "float32",
"description": "The 2-based log-sum-exp of attention logits."
}
},
"reference": "import math\nimport torch\n\n\n@torch.no_grad()\ndef run(q_nope, q_pe, ckv_cache, kpe_cache, sparse_indices, sm_scale):\n batch_size, num_qo_heads, head_dim_ckv = q_nope.shape\n head_dim_kpe = q_pe.shape[-1]\n page_size = ckv_cache.shape[1]\n topk = sparse_indices.shape[-1]\n\n # Check constants\n assert num_qo_heads == 16\n assert head_dim_ckv == 512\n assert head_dim_kpe == 64\n assert page_size == 1\n assert topk == 256\n\n # Check constraints\n assert sparse_indices.shape[-1] == topk\n assert ckv_cache.shape[1] == page_size\n\n device = q_nope.device\n\n # Squeeze page dimension (page_size=1)\n Kc_all = ckv_cache.squeeze(1).to(torch.float32) # [num_pages, head_dim_ckv]\n Kp_all = kpe_cache.squeeze(1).to(torch.float32) # [num_pages, head_dim_kpe]\n\n output = torch.zeros(\n (batch_size, num_qo_heads, head_dim_ckv), dtype=torch.bfloat16, device=device\n )\n lse = torch.full((batch_size, num_qo_heads), -float(\"inf\"), dtype=torch.float32, device=device)\n\n for b in range(batch_size):\n indices = sparse_indices[b] # [topk]\n\n # Handle padding: -1 indicates invalid indices\n valid_mask = indices != -1\n valid_indices = indices[valid_mask]\n\n if valid_indices.numel() == 0:\n output[b].zero_()\n continue\n\n tok_idx = valid_indices.to(torch.long)\n\n Kc = Kc_all[tok_idx] # [num_valid, head_dim_ckv]\n Kp = Kp_all[tok_idx] # [num_valid, head_dim_kpe]\n qn = q_nope[b].to(torch.float32) # [num_qo_heads, head_dim_ckv]\n qp = q_pe[b].to(torch.float32) # [num_qo_heads, head_dim_kpe]\n\n # Compute attention logits\n logits = (qn @ Kc.T) + (qp @ Kp.T) # [num_qo_heads, num_valid]\n logits_scaled = logits * sm_scale\n\n # Compute 2-base LSE\n lse[b] = torch.logsumexp(logits_scaled, dim=-1) / math.log(2.0)\n\n # Compute attention output\n attn = torch.softmax(logits_scaled, dim=-1) # [num_qo_heads, num_valid]\n out = attn @ Kc # [num_qo_heads, head_dim_ckv]\n output[b] = out.to(torch.bfloat16)\n\n return output, lse"
"reference": "import math\nimport torch\n\n\n@torch.no_grad()\ndef run(q_nope, q_pe, ckv_cache, kpe_cache, sparse_indices, sm_scale):\n num_tokens, num_qo_heads, head_dim_ckv = q_nope.shape\n head_dim_kpe = q_pe.shape[-1]\n page_size = ckv_cache.shape[1]\n topk = sparse_indices.shape[-1]\n\n # Check constants\n assert num_qo_heads == 16\n assert head_dim_ckv == 512\n assert head_dim_kpe == 64\n assert page_size == 1\n assert topk == 256\n\n # Check constraints\n assert sparse_indices.shape[0] == num_tokens\n assert sparse_indices.shape[-1] == topk\n assert ckv_cache.shape[1] == page_size\n\n device = q_nope.device\n\n # Squeeze page dimension (page_size=1)\n Kc_all = ckv_cache.squeeze(1).to(torch.float32) # [num_pages, head_dim_ckv]\n Kp_all = kpe_cache.squeeze(1).to(torch.float32) # [num_pages, head_dim_kpe]\n\n output = torch.zeros(\n (num_tokens, num_qo_heads, head_dim_ckv), dtype=torch.bfloat16, device=device\n )\n lse = torch.full((num_tokens, num_qo_heads), -float(\"inf\"), dtype=torch.float32, device=device)\n\n for t in range(num_tokens):\n indices = sparse_indices[t] # [topk]\n\n # Handle padding: -1 indicates invalid indices\n valid_mask = indices != -1\n valid_indices = indices[valid_mask]\n\n if valid_indices.numel() == 0:\n output[t].zero_()\n continue\n\n tok_idx = valid_indices.to(torch.long)\n\n Kc = Kc_all[tok_idx] # [num_valid, head_dim_ckv]\n Kp = Kp_all[tok_idx] # [num_valid, head_dim_kpe]\n qn = q_nope[t].to(torch.float32) # [num_qo_heads, head_dim_ckv]\n qp = q_pe[t].to(torch.float32) # [num_qo_heads, head_dim_kpe]\n\n # Compute attention logits\n logits = (qn @ Kc.T) + (qp @ Kp.T) # [num_qo_heads, num_valid]\n logits_scaled = logits * sm_scale\n\n # Compute 2-base LSE\n lse[t] = torch.logsumexp(logits_scaled, dim=-1) / math.log(2.0)\n\n # Compute attention output\n attn = torch.softmax(logits_scaled, dim=-1) # [num_qo_heads, num_valid]\n out = attn @ Kc # [num_qo_heads, head_dim_ckv]\n output[t] = out.to(torch.bfloat16)\n\n return output, lse"
}
A Contributor commented (severity: medium):

The change in the return type of the run function from a dictionary to a tuple aligns with the PR's objective to unify return types. This improves consistency across the API.

    return output, lse

@@ -1,18 +1,16 @@
{
"name": "dsa_sparse_decode_h16_ckv512_kpe64_topk256_ps64",
"description": "Batched Native Sparse Attention (DSA) decode with sparse TopK KV cache selection. Captured from DeepSeek-V3 with tensor parallel size 8. Uses sparse indexing to select only top-K KV cache entries for attention computation. Page size 64 variant.",
"name": "dsa_sparse_attention_h16_ckv512_kpe64_topk256_ps64",
"description": "Batched Native Sparse Attention (DSA) with sparse TopK KV cache selection. Captured from DeepSeek-V3 with tensor parallel size 8. Uses sparse indexing to select only top-K KV cache entries for attention computation. Page size 64 variant. Works for both prefill and decode stages.",
"op_type": "dsa_paged",
"tags": [
"stage:decode",
"status:verified",
"model:deepseek-v3",
"model:deepseek-r1",
"model:deepseek-v3.2",
"sparse:topk"
],
"axes": {
"batch_size": {
"num_tokens": {
"type": "var",
"description": "Batch size (number of sequences)."
"description": "Number of tokens (batch_size for decode, total_num_tokens for prefill)."
},
"num_qo_heads": {
"type": "const",
@@ -45,13 +43,14 @@
}
},
"constraints": [
"sparse_indices.shape[0] == num_tokens",
"sparse_indices.shape[-1] == topk",
"ckv_cache.shape[1] == page_size"
],
"inputs": {
"q_nope": {
"shape": [
"batch_size",
"num_tokens",
"num_qo_heads",
"head_dim_ckv"
],
@@ -60,7 +59,7 @@
},
"q_pe": {
"shape": [
"batch_size",
"num_tokens",
"num_qo_heads",
"head_dim_kpe"
],
@@ -87,11 +86,11 @@
},
"sparse_indices": {
"shape": [
"batch_size",
"num_tokens",
"topk"
],
"dtype": "int32",
"description": "Sparse indices selecting top-K KV cache entries for each batch element. Values of -1 indicate padding (invalid indices). For page_size=64, indices encode (page_idx * 64 + offset)."
"description": "Sparse indices selecting top-K KV cache entries for each token. Values of -1 indicate padding (invalid indices). For page_size=64, indices encode (page_idx * 64 + offset)."
},
"sm_scale": {
"shape": null,
@@ -102,20 +101,20 @@
"outputs": {
"output": {
"shape": [
"batch_size",
"num_tokens",
"num_qo_heads",
"head_dim_ckv"
],
"dtype": "bfloat16"
},
"lse": {
"shape": [
"batch_size",
"num_tokens",
"num_qo_heads"
],
"dtype": "float32",
"description": "The 2-based log-sum-exp of attention logits."
}
},
"reference": "import math\nimport torch\n\n\n@torch.no_grad()\ndef run(q_nope, q_pe, ckv_cache, kpe_cache, sparse_indices, sm_scale):\n batch_size, num_qo_heads, head_dim_ckv = q_nope.shape\n head_dim_kpe = q_pe.shape[-1]\n num_pages, page_size, _ = ckv_cache.shape\n topk = sparse_indices.shape[-1]\n\n # Check constants\n assert num_qo_heads == 16\n assert head_dim_ckv == 512\n assert head_dim_kpe == 64\n assert page_size == 64\n assert topk == 256\n\n # Check constraints\n assert sparse_indices.shape[-1] == topk\n assert ckv_cache.shape[1] == page_size\n\n device = q_nope.device\n\n # Flatten paged KV cache to token-level: [num_pages, page_size, dim] -> [num_pages * page_size, dim]\n Kc_all = ckv_cache.reshape(-1, head_dim_ckv).to(torch.float32) # [total_kv_tokens, head_dim_ckv]\n Kp_all = kpe_cache.reshape(-1, head_dim_kpe).to(torch.float32) # [total_kv_tokens, head_dim_kpe]\n\n output = torch.zeros(\n (batch_size, num_qo_heads, head_dim_ckv), dtype=torch.bfloat16, device=device\n )\n lse = torch.full((batch_size, num_qo_heads), -float(\"inf\"), dtype=torch.float32, device=device)\n\n for b in range(batch_size):\n indices = sparse_indices[b] # [topk]\n\n # Handle padding: -1 indicates invalid indices\n valid_mask = indices != -1\n valid_indices = indices[valid_mask]\n\n if valid_indices.numel() == 0:\n output[b].zero_()\n continue\n\n # For page_size=64, indices encode (page_idx * 64 + offset)\n tok_idx = valid_indices.to(torch.long)\n\n Kc = Kc_all[tok_idx] # [num_valid, head_dim_ckv]\n Kp = Kp_all[tok_idx] # [num_valid, head_dim_kpe]\n qn = q_nope[b].to(torch.float32) # [num_qo_heads, head_dim_ckv]\n qp = q_pe[b].to(torch.float32) # [num_qo_heads, head_dim_kpe]\n\n # Compute attention logits\n logits = (qn @ Kc.T) + (qp @ Kp.T) # [num_qo_heads, num_valid]\n logits_scaled = logits * sm_scale\n\n # Compute 2-base LSE\n lse[b] = torch.logsumexp(logits_scaled, dim=-1) / math.log(2.0)\n\n # Compute attention output\n attn = torch.softmax(logits_scaled, dim=-1) # [num_qo_heads, num_valid]\n out = attn @ Kc # [num_qo_heads, head_dim_ckv]\n output[b] = out.to(torch.bfloat16)\n\n return {\"output\": output, \"lse\": lse}"
"reference": "import math\nimport torch\n\n\n@torch.no_grad()\ndef run(q_nope, q_pe, ckv_cache, kpe_cache, sparse_indices, sm_scale):\n num_tokens, num_qo_heads, head_dim_ckv = q_nope.shape\n head_dim_kpe = q_pe.shape[-1]\n num_pages, page_size, _ = ckv_cache.shape\n topk = sparse_indices.shape[-1]\n\n # Check constants\n assert num_qo_heads == 16\n assert head_dim_ckv == 512\n assert head_dim_kpe == 64\n assert page_size == 64\n assert topk == 256\n\n # Check constraints\n assert sparse_indices.shape[0] == num_tokens\n assert sparse_indices.shape[-1] == topk\n assert ckv_cache.shape[1] == page_size\n\n device = q_nope.device\n\n # Flatten paged KV cache to token-level: [num_pages, page_size, dim] -> [num_pages * page_size, dim]\n Kc_all = ckv_cache.reshape(-1, head_dim_ckv).to(torch.float32) # [total_kv_tokens, head_dim_ckv]\n Kp_all = kpe_cache.reshape(-1, head_dim_kpe).to(torch.float32) # [total_kv_tokens, head_dim_kpe]\n\n output = torch.zeros(\n (num_tokens, num_qo_heads, head_dim_ckv), dtype=torch.bfloat16, device=device\n )\n lse = torch.full((num_tokens, num_qo_heads), -float(\"inf\"), dtype=torch.float32, device=device)\n\n for t in range(num_tokens):\n indices = sparse_indices[t] # [topk]\n\n # Handle padding: -1 indicates invalid indices\n valid_mask = indices != -1\n valid_indices = indices[valid_mask]\n\n if valid_indices.numel() == 0:\n output[t].zero_()\n continue\n\n # For page_size=64, indices encode (page_idx * 64 + offset)\n tok_idx = valid_indices.to(torch.long)\n\n Kc = Kc_all[tok_idx] # [num_valid, head_dim_ckv]\n Kp = Kp_all[tok_idx] # [num_valid, head_dim_kpe]\n qn = q_nope[t].to(torch.float32) # [num_qo_heads, head_dim_ckv]\n qp = q_pe[t].to(torch.float32) # [num_qo_heads, head_dim_kpe]\n\n # Compute attention logits\n logits = (qn @ Kc.T) + (qp @ Kp.T) # [num_qo_heads, num_valid]\n logits_scaled = logits * sm_scale\n\n # Compute 2-base LSE\n lse[t] = torch.logsumexp(logits_scaled, dim=-1) / math.log(2.0)\n\n # Compute attention output\n attn = torch.softmax(logits_scaled, dim=-1) # [num_qo_heads, num_valid]\n out = attn @ Kc # [num_qo_heads, head_dim_ckv]\n output[t] = out.to(torch.bfloat16)\n\n return output, lse"
}
A Contributor commented (severity: medium):

The change in the return type of the run function from a dictionary to a tuple aligns with the PR's objective to unify return types. This improves consistency across the API.

    return output, lse
