diff --git a/CLAUDE.md b/CLAUDE.md
index 9b906c9d..4b3cfa73 100644
--- a/CLAUDE.md
+++ b/CLAUDE.md
@@ -75,7 +75,8 @@ FlashInfer-Bench supports the following op_types (corresponding to different Def
 | `gemm` | General Matrix Multiplication | `gemm_n_6144_k_4096` |
 | `gqa_ragged` | Group Query Attention (ragged) | `gqa_ragged_prefill_causal_h32_kv8_d128` |
 | `gqa_paged` | Group Query Attention (paged) | `gqa_paged_decode_h32_kv8_d128_ps1` |
-| `mla_paged` | Multi-Head Linear Attention | `mla_paged_decode_h16_ckv512_kpe64_ps1` |
+| `mla_paged` | Multi-Head Latent Attention (paged) | `mla_paged_decode_h16_ckv512_kpe64_ps1` |
+| `dsa_paged` | DeepSeek Sparse Attention (paged) | `dsa_sparse_attention_h16_ckv512_kpe64_topk256_ps1` |
 | `gdn` | Gated Delta Net (linear attention) | `gdn_decode_qk16_v32_d128_k_last` |
 | `moe` | Mixture of Experts | `moe_fp8_block_scale_ds_routing_topk8_ng8_kg4_e32_h7168_i2048` |
 | `sampling` | Sampling operations | - |
@@ -168,6 +169,7 @@ Associate each module with corresponding Definitions:
 - **Attention layers**:
   - GQA: `gqa_paged_decode_h{num_heads}_kv{kv_heads}_d{head_dim}_ps1`
   - MLA: `mla_paged_decode_h{num_heads}_ckv{ckv_dim}_kpe{kpe_dim}_ps1`
+  - DSA: `dsa_sparse_attention_h{num_heads}_ckv{ckv_dim}_kpe{kpe_dim}_topk{topk}_ps1` (sparse MLA)
   - GDN: `gdn_decode_qk{q_heads}_v{v_heads}_d{head_dim}` (linear attention)
 - **GEMM layers**: `gemm_n_{out_dim}_k_{in_dim}`
 - **MoE layers**: `moe_fp8_block_scale_ds_routing_topk{topk}_ng{num_groups}_kg{group_size}_e{num_experts}_h{hidden}_i{intermediate}`
diff --git a/docs/op_type_schema/dsa_paged.md b/docs/op_type_schema/dsa_paged.md
new file mode 100644
index 00000000..2c8a2517
--- /dev/null
+++ b/docs/op_type_schema/dsa_paged.md
@@ -0,0 +1,55 @@
+# dsa_paged
+
+DeepSeek Sparse Attention (DSA) with paged memory layout. DSA is a two-stage sparse attention mechanism: first an indexer selects the top-K relevant KV cache entries using ReLU scoring, then MLA-style attention is performed only on the selected entries. This reduces attention computation from O(n) to O(k), where k << n.
+
+Variants:
+- indexer
+- sparse_attention
+
+## indexer
+
+Computes sparse attention scores using ReLU activation and learned weights, then selects top-K KV cache indices. Uses FP8 quantization with deep_gemm format.
+
+Axes (9 dimensions):
+- `batch_size`, `max_num_pages`, `num_pages`: variable
+- `num_index_heads`, `index_head_dim`, `page_size`, `topk`, `kv_cache_num_heads`, `head_dim_with_scale`: constant
+
+Inputs (5 tensors):
+- `q_index_fp8`: FP8 query for indexing [batch_size, num_index_heads, index_head_dim]
+- `k_index_cache_fp8`: FP8 key index cache with scales [num_pages, page_size, kv_cache_num_heads, head_dim_with_scale]
+- `weights`: learned head weights [batch_size, num_index_heads]
+- `seq_lens`: sequence lengths [batch_size]
+- `block_table`: page mapping [batch_size, max_num_pages]
+
+Outputs (1 tensor):
+- `topk_indices`: selected token indices [batch_size, topk], -1 indicates padding
+
+Constraints:
+- `topk <= max_num_pages * page_size`
+- `num_index_heads == 64`, `index_head_dim == 128` (deep_gemm requirement)
+- `head_dim_with_scale == 132` (128 + 4 scale bytes)
+
+## sparse_attention
+
+Performs MLA-style attention on the top-K selected KV entries. Works for both prefill (multiple tokens) and decode (one token per sequence): the computation is identical; only the first dimension differs.
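A minimal sketch of this per-token computation, using synthetic float32 tensors, `page_size=1`, and an arbitrary `num_pages`; it mirrors the per-token loop in the definition references further down in this change, which use bfloat16 inputs with float32 accumulation:

```python
import math
import torch

num_qo_heads, head_dim_ckv, head_dim_kpe, topk, num_pages = 16, 512, 64, 256, 1024

q_nope = torch.randn(num_qo_heads, head_dim_ckv)       # query without positional encoding
q_pe = torch.randn(num_qo_heads, head_dim_kpe)         # query positional encoding
ckv_cache = torch.randn(num_pages, 1, head_dim_ckv)    # compressed KV cache [num_pages, page_size, head_dim_ckv]
kpe_cache = torch.randn(num_pages, 1, head_dim_kpe)    # key positional encoding cache
sparse_indices = torch.randint(0, num_pages, (topk,))  # top-K indices, as produced by the indexer
sm_scale = 1.0 / math.sqrt(128 + head_dim_kpe)         # pre-absorption scale: 1/sqrt(192)

valid = sparse_indices[sparse_indices != -1].long()    # drop -1 padding entries
Kc = ckv_cache.squeeze(1)[valid]                       # [num_valid, head_dim_ckv]
Kp = kpe_cache.squeeze(1)[valid]                       # [num_valid, head_dim_kpe]

logits = (q_nope @ Kc.T + q_pe @ Kp.T) * sm_scale      # [num_qo_heads, num_valid]
lse = torch.logsumexp(logits, dim=-1) / math.log(2.0)  # 2-based log-sum-exp
output = torch.softmax(logits, dim=-1) @ Kc            # [num_qo_heads, head_dim_ckv]
```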
+ +Axes (7 dimensions): +- `num_tokens`, `num_pages`: variable +- `num_qo_heads`, `head_dim_ckv`, `head_dim_kpe`, `page_size`, `topk`: constant + +Inputs (5 tensors + 1 scalar): +- `q_nope`: query without positional encoding [num_tokens, num_qo_heads, head_dim_ckv] +- `q_pe`: query positional encoding [num_tokens, num_qo_heads, head_dim_kpe] +- `ckv_cache`: compressed KV cache [num_pages, page_size, head_dim_ckv] +- `kpe_cache`: key positional encoding cache [num_pages, page_size, head_dim_kpe] +- `sparse_indices`: top-K indices per token [num_tokens, topk], -1 indicates padding +- `sm_scale`: softmax scale (scalar) + +Outputs (2 tensors): +- `output`: attention output [num_tokens, num_qo_heads, head_dim_ckv] +- `lse`: 2-based log-sum-exp [num_tokens, num_qo_heads] + +Constraints: +- `sparse_indices.shape[0] == num_tokens` +- `sparse_indices.shape[-1] == topk` +- `ckv_cache.shape[1] == page_size` diff --git a/flashinfer_trace/definitions/dsa_paged/dsa_sparse_decode_h16_ckv512_kpe64_topk256_ps1.json b/flashinfer_trace/definitions/dsa_paged/dsa_sparse_attention_h16_ckv512_kpe64_topk256_ps1.json similarity index 54% rename from flashinfer_trace/definitions/dsa_paged/dsa_sparse_decode_h16_ckv512_kpe64_topk256_ps1.json rename to flashinfer_trace/definitions/dsa_paged/dsa_sparse_attention_h16_ckv512_kpe64_topk256_ps1.json index 95e8f44c..45e764fb 100644 --- a/flashinfer_trace/definitions/dsa_paged/dsa_sparse_decode_h16_ckv512_kpe64_topk256_ps1.json +++ b/flashinfer_trace/definitions/dsa_paged/dsa_sparse_attention_h16_ckv512_kpe64_topk256_ps1.json @@ -1,18 +1,16 @@ { - "name": "dsa_sparse_decode_h16_ckv512_kpe64_topk256_ps1", - "description": "Batched Native Sparse Attention (DSA) decode with sparse TopK KV cache selection. Captured from DeepSeek-V3 with tensor parallel size 8. Uses sparse indexing to select only top-K KV cache entries for attention computation.", + "name": "dsa_sparse_attention_h16_ckv512_kpe64_topk256_ps1", + "description": "Batched Native Sparse Attention (DSA) with sparse TopK KV cache selection. Captured from DeepSeek-V3 with tensor parallel size 8. Uses sparse indexing to select only top-K KV cache entries for attention computation. Works for both prefill and decode stages.", "op_type": "dsa_paged", "tags": [ - "stage:decode", "status:verified", - "model:deepseek-v3", - "model:deepseek-r1", + "model:deepseek-v3.2", "sparse:topk" ], "axes": { - "batch_size": { + "num_tokens": { "type": "var", - "description": "Batch size (number of sequences)." + "description": "Number of tokens (batch_size for decode, total_num_tokens for prefill)." }, "num_qo_heads": { "type": "const", @@ -45,13 +43,14 @@ } }, "constraints": [ + "sparse_indices.shape[0] == num_tokens", "sparse_indices.shape[-1] == topk", "ckv_cache.shape[1] == page_size" ], "inputs": { "q_nope": { "shape": [ - "batch_size", + "num_tokens", "num_qo_heads", "head_dim_ckv" ], @@ -60,7 +59,7 @@ }, "q_pe": { "shape": [ - "batch_size", + "num_tokens", "num_qo_heads", "head_dim_kpe" ], @@ -87,11 +86,11 @@ }, "sparse_indices": { "shape": [ - "batch_size", + "num_tokens", "topk" ], "dtype": "int32", - "description": "Sparse indices selecting top-K KV cache entries for each batch element. Values of -1 indicate padding (invalid indices)." + "description": "Sparse indices selecting top-K KV cache entries for each token. Values of -1 indicate padding (invalid indices)." 
}, "sm_scale": { "shape": null, @@ -102,7 +101,7 @@ "outputs": { "output": { "shape": [ - "batch_size", + "num_tokens", "num_qo_heads", "head_dim_ckv" ], @@ -110,12 +109,12 @@ }, "lse": { "shape": [ - "batch_size", + "num_tokens", "num_qo_heads" ], "dtype": "float32", "description": "The 2-based log-sum-exp of attention logits." } }, - "reference": "import math\nimport torch\n\n\n@torch.no_grad()\ndef run(q_nope, q_pe, ckv_cache, kpe_cache, sparse_indices, sm_scale):\n batch_size, num_qo_heads, head_dim_ckv = q_nope.shape\n head_dim_kpe = q_pe.shape[-1]\n page_size = ckv_cache.shape[1]\n topk = sparse_indices.shape[-1]\n\n # Check constants\n assert num_qo_heads == 16\n assert head_dim_ckv == 512\n assert head_dim_kpe == 64\n assert page_size == 1\n assert topk == 256\n\n # Check constraints\n assert sparse_indices.shape[-1] == topk\n assert ckv_cache.shape[1] == page_size\n\n device = q_nope.device\n\n # Squeeze page dimension (page_size=1)\n Kc_all = ckv_cache.squeeze(1).to(torch.float32) # [num_pages, head_dim_ckv]\n Kp_all = kpe_cache.squeeze(1).to(torch.float32) # [num_pages, head_dim_kpe]\n\n output = torch.zeros(\n (batch_size, num_qo_heads, head_dim_ckv), dtype=torch.bfloat16, device=device\n )\n lse = torch.full((batch_size, num_qo_heads), -float(\"inf\"), dtype=torch.float32, device=device)\n\n for b in range(batch_size):\n indices = sparse_indices[b] # [topk]\n\n # Handle padding: -1 indicates invalid indices\n valid_mask = indices != -1\n valid_indices = indices[valid_mask]\n\n if valid_indices.numel() == 0:\n output[b].zero_()\n continue\n\n tok_idx = valid_indices.to(torch.long)\n\n Kc = Kc_all[tok_idx] # [num_valid, head_dim_ckv]\n Kp = Kp_all[tok_idx] # [num_valid, head_dim_kpe]\n qn = q_nope[b].to(torch.float32) # [num_qo_heads, head_dim_ckv]\n qp = q_pe[b].to(torch.float32) # [num_qo_heads, head_dim_kpe]\n\n # Compute attention logits\n logits = (qn @ Kc.T) + (qp @ Kp.T) # [num_qo_heads, num_valid]\n logits_scaled = logits * sm_scale\n\n # Compute 2-base LSE\n lse[b] = torch.logsumexp(logits_scaled, dim=-1) / math.log(2.0)\n\n # Compute attention output\n attn = torch.softmax(logits_scaled, dim=-1) # [num_qo_heads, num_valid]\n out = attn @ Kc # [num_qo_heads, head_dim_ckv]\n output[b] = out.to(torch.bfloat16)\n\n return output, lse" + "reference": "import math\nimport torch\n\n\n@torch.no_grad()\ndef run(q_nope, q_pe, ckv_cache, kpe_cache, sparse_indices, sm_scale):\n num_tokens, num_qo_heads, head_dim_ckv = q_nope.shape\n head_dim_kpe = q_pe.shape[-1]\n page_size = ckv_cache.shape[1]\n topk = sparse_indices.shape[-1]\n\n # Check constants\n assert num_qo_heads == 16\n assert head_dim_ckv == 512\n assert head_dim_kpe == 64\n assert page_size == 1\n assert topk == 256\n\n # Check constraints\n assert sparse_indices.shape[0] == num_tokens\n assert sparse_indices.shape[-1] == topk\n assert ckv_cache.shape[1] == page_size\n\n device = q_nope.device\n\n # Squeeze page dimension (page_size=1)\n Kc_all = ckv_cache.squeeze(1).to(torch.float32) # [num_pages, head_dim_ckv]\n Kp_all = kpe_cache.squeeze(1).to(torch.float32) # [num_pages, head_dim_kpe]\n\n output = torch.zeros(\n (num_tokens, num_qo_heads, head_dim_ckv), dtype=torch.bfloat16, device=device\n )\n lse = torch.full((num_tokens, num_qo_heads), -float(\"inf\"), dtype=torch.float32, device=device)\n\n for t in range(num_tokens):\n indices = sparse_indices[t] # [topk]\n\n # Handle padding: -1 indicates invalid indices\n valid_mask = indices != -1\n valid_indices = indices[valid_mask]\n\n if valid_indices.numel() == 0:\n 
output[t].zero_()\n continue\n\n tok_idx = valid_indices.to(torch.long)\n\n Kc = Kc_all[tok_idx] # [num_valid, head_dim_ckv]\n Kp = Kp_all[tok_idx] # [num_valid, head_dim_kpe]\n qn = q_nope[t].to(torch.float32) # [num_qo_heads, head_dim_ckv]\n qp = q_pe[t].to(torch.float32) # [num_qo_heads, head_dim_kpe]\n\n # Compute attention logits\n logits = (qn @ Kc.T) + (qp @ Kp.T) # [num_qo_heads, num_valid]\n logits_scaled = logits * sm_scale\n\n # Compute 2-base LSE\n lse[t] = torch.logsumexp(logits_scaled, dim=-1) / math.log(2.0)\n\n # Compute attention output\n attn = torch.softmax(logits_scaled, dim=-1) # [num_qo_heads, num_valid]\n out = attn @ Kc # [num_qo_heads, head_dim_ckv]\n output[t] = out.to(torch.bfloat16)\n\n return output, lse" } diff --git a/flashinfer_trace/definitions/dsa_paged/dsa_sparse_decode_h16_ckv512_kpe64_topk256_ps64.json b/flashinfer_trace/definitions/dsa_paged/dsa_sparse_attention_h16_ckv512_kpe64_topk256_ps64.json similarity index 51% rename from flashinfer_trace/definitions/dsa_paged/dsa_sparse_decode_h16_ckv512_kpe64_topk256_ps64.json rename to flashinfer_trace/definitions/dsa_paged/dsa_sparse_attention_h16_ckv512_kpe64_topk256_ps64.json index 2ea2edf1..9355773a 100644 --- a/flashinfer_trace/definitions/dsa_paged/dsa_sparse_decode_h16_ckv512_kpe64_topk256_ps64.json +++ b/flashinfer_trace/definitions/dsa_paged/dsa_sparse_attention_h16_ckv512_kpe64_topk256_ps64.json @@ -1,18 +1,16 @@ { - "name": "dsa_sparse_decode_h16_ckv512_kpe64_topk256_ps64", - "description": "Batched Native Sparse Attention (DSA) decode with sparse TopK KV cache selection. Captured from DeepSeek-V3 with tensor parallel size 8. Uses sparse indexing to select only top-K KV cache entries for attention computation. Page size 64 variant.", + "name": "dsa_sparse_attention_h16_ckv512_kpe64_topk256_ps64", + "description": "Batched Native Sparse Attention (DSA) with sparse TopK KV cache selection. Captured from DeepSeek-V3 with tensor parallel size 8. Uses sparse indexing to select only top-K KV cache entries for attention computation. Page size 64 variant. Works for both prefill and decode stages.", "op_type": "dsa_paged", "tags": [ - "stage:decode", "status:verified", - "model:deepseek-v3", - "model:deepseek-r1", + "model:deepseek-v3.2", "sparse:topk" ], "axes": { - "batch_size": { + "num_tokens": { "type": "var", - "description": "Batch size (number of sequences)." + "description": "Number of tokens (batch_size for decode, total_num_tokens for prefill)." }, "num_qo_heads": { "type": "const", @@ -45,13 +43,14 @@ } }, "constraints": [ + "sparse_indices.shape[0] == num_tokens", "sparse_indices.shape[-1] == topk", "ckv_cache.shape[1] == page_size" ], "inputs": { "q_nope": { "shape": [ - "batch_size", + "num_tokens", "num_qo_heads", "head_dim_ckv" ], @@ -60,7 +59,7 @@ }, "q_pe": { "shape": [ - "batch_size", + "num_tokens", "num_qo_heads", "head_dim_kpe" ], @@ -87,11 +86,11 @@ }, "sparse_indices": { "shape": [ - "batch_size", + "num_tokens", "topk" ], "dtype": "int32", - "description": "Sparse indices selecting top-K KV cache entries for each batch element. Values of -1 indicate padding (invalid indices). For page_size=64, indices encode (page_idx * 64 + offset)." + "description": "Sparse indices selecting top-K KV cache entries for each token. Values of -1 indicate padding (invalid indices). For page_size=64, indices encode (page_idx * 64 + offset)." 
}, "sm_scale": { "shape": null, @@ -102,7 +101,7 @@ "outputs": { "output": { "shape": [ - "batch_size", + "num_tokens", "num_qo_heads", "head_dim_ckv" ], @@ -110,12 +109,12 @@ }, "lse": { "shape": [ - "batch_size", + "num_tokens", "num_qo_heads" ], "dtype": "float32", "description": "The 2-based log-sum-exp of attention logits." } }, - "reference": "import math\nimport torch\n\n\n@torch.no_grad()\ndef run(q_nope, q_pe, ckv_cache, kpe_cache, sparse_indices, sm_scale):\n batch_size, num_qo_heads, head_dim_ckv = q_nope.shape\n head_dim_kpe = q_pe.shape[-1]\n num_pages, page_size, _ = ckv_cache.shape\n topk = sparse_indices.shape[-1]\n\n # Check constants\n assert num_qo_heads == 16\n assert head_dim_ckv == 512\n assert head_dim_kpe == 64\n assert page_size == 64\n assert topk == 256\n\n # Check constraints\n assert sparse_indices.shape[-1] == topk\n assert ckv_cache.shape[1] == page_size\n\n device = q_nope.device\n\n # Flatten paged KV cache to token-level: [num_pages, page_size, dim] -> [num_pages * page_size, dim]\n Kc_all = ckv_cache.reshape(-1, head_dim_ckv).to(torch.float32) # [total_kv_tokens, head_dim_ckv]\n Kp_all = kpe_cache.reshape(-1, head_dim_kpe).to(torch.float32) # [total_kv_tokens, head_dim_kpe]\n\n output = torch.zeros(\n (batch_size, num_qo_heads, head_dim_ckv), dtype=torch.bfloat16, device=device\n )\n lse = torch.full((batch_size, num_qo_heads), -float(\"inf\"), dtype=torch.float32, device=device)\n\n for b in range(batch_size):\n indices = sparse_indices[b] # [topk]\n\n # Handle padding: -1 indicates invalid indices\n valid_mask = indices != -1\n valid_indices = indices[valid_mask]\n\n if valid_indices.numel() == 0:\n output[b].zero_()\n continue\n\n # For page_size=64, indices encode (page_idx * 64 + offset)\n tok_idx = valid_indices.to(torch.long)\n\n Kc = Kc_all[tok_idx] # [num_valid, head_dim_ckv]\n Kp = Kp_all[tok_idx] # [num_valid, head_dim_kpe]\n qn = q_nope[b].to(torch.float32) # [num_qo_heads, head_dim_ckv]\n qp = q_pe[b].to(torch.float32) # [num_qo_heads, head_dim_kpe]\n\n # Compute attention logits\n logits = (qn @ Kc.T) + (qp @ Kp.T) # [num_qo_heads, num_valid]\n logits_scaled = logits * sm_scale\n\n # Compute 2-base LSE\n lse[b] = torch.logsumexp(logits_scaled, dim=-1) / math.log(2.0)\n\n # Compute attention output\n attn = torch.softmax(logits_scaled, dim=-1) # [num_qo_heads, num_valid]\n out = attn @ Kc # [num_qo_heads, head_dim_ckv]\n output[b] = out.to(torch.bfloat16)\n\n return {\"output\": output, \"lse\": lse}" + "reference": "import math\nimport torch\n\n\n@torch.no_grad()\ndef run(q_nope, q_pe, ckv_cache, kpe_cache, sparse_indices, sm_scale):\n num_tokens, num_qo_heads, head_dim_ckv = q_nope.shape\n head_dim_kpe = q_pe.shape[-1]\n num_pages, page_size, _ = ckv_cache.shape\n topk = sparse_indices.shape[-1]\n\n # Check constants\n assert num_qo_heads == 16\n assert head_dim_ckv == 512\n assert head_dim_kpe == 64\n assert page_size == 64\n assert topk == 256\n\n # Check constraints\n assert sparse_indices.shape[0] == num_tokens\n assert sparse_indices.shape[-1] == topk\n assert ckv_cache.shape[1] == page_size\n\n device = q_nope.device\n\n # Flatten paged KV cache to token-level: [num_pages, page_size, dim] -> [num_pages * page_size, dim]\n Kc_all = ckv_cache.reshape(-1, head_dim_ckv).to(torch.float32) # [total_kv_tokens, head_dim_ckv]\n Kp_all = kpe_cache.reshape(-1, head_dim_kpe).to(torch.float32) # [total_kv_tokens, head_dim_kpe]\n\n output = torch.zeros(\n (num_tokens, num_qo_heads, head_dim_ckv), dtype=torch.bfloat16, device=device\n )\n lse = 
torch.full((num_tokens, num_qo_heads), -float(\"inf\"), dtype=torch.float32, device=device)\n\n for t in range(num_tokens):\n indices = sparse_indices[t] # [topk]\n\n # Handle padding: -1 indicates invalid indices\n valid_mask = indices != -1\n valid_indices = indices[valid_mask]\n\n if valid_indices.numel() == 0:\n output[t].zero_()\n continue\n\n # For page_size=64, indices encode (page_idx * 64 + offset)\n tok_idx = valid_indices.to(torch.long)\n\n Kc = Kc_all[tok_idx] # [num_valid, head_dim_ckv]\n Kp = Kp_all[tok_idx] # [num_valid, head_dim_kpe]\n qn = q_nope[t].to(torch.float32) # [num_qo_heads, head_dim_ckv]\n qp = q_pe[t].to(torch.float32) # [num_qo_heads, head_dim_kpe]\n\n # Compute attention logits\n logits = (qn @ Kc.T) + (qp @ Kp.T) # [num_qo_heads, num_valid]\n logits_scaled = logits * sm_scale\n\n # Compute 2-base LSE\n lse[t] = torch.logsumexp(logits_scaled, dim=-1) / math.log(2.0)\n\n # Compute attention output\n attn = torch.softmax(logits_scaled, dim=-1) # [num_qo_heads, num_valid]\n out = attn @ Kc # [num_qo_heads, head_dim_ckv]\n output[t] = out.to(torch.bfloat16)\n\n return output, lse" } diff --git a/flashinfer_trace/definitions/dsa_paged/dsa_sparse_prefill_causal_h16_ckv512_kpe64_topk256_ps1.json b/flashinfer_trace/definitions/dsa_paged/dsa_sparse_prefill_causal_h16_ckv512_kpe64_topk256_ps1.json deleted file mode 100644 index 2cfa9b72..00000000 --- a/flashinfer_trace/definitions/dsa_paged/dsa_sparse_prefill_causal_h16_ckv512_kpe64_topk256_ps1.json +++ /dev/null @@ -1,123 +0,0 @@ -{ - "name": "dsa_sparse_prefill_causal_h16_ckv512_kpe64_topk256_ps1", - "description": "Batched Native Sparse Attention (DSA) prefill with causal masking and sparse TopK KV cache selection. Captured from DeepSeek-V3 with tensor parallel size 8. Uses sparse indexing to select only top-K KV cache entries for attention computation during prefill.", - "op_type": "dsa_paged", - "tags": [ - "stage:prefill", - "status:verified", - "model:deepseek-v3", - "model:deepseek-r1", - "sparse:topk", - "mask:causal" - ], - "axes": { - "total_num_tokens": { - "type": "var", - "description": "Total number of tokens across all sequences in the batch." - }, - "num_qo_heads": { - "type": "const", - "value": 16, - "description": "Number of query heads after tensor parallel split (128/8=16)." - }, - "head_dim_ckv": { - "type": "const", - "value": 512, - "description": "Compressed KV head dimension." - }, - "head_dim_kpe": { - "type": "const", - "value": 64, - "description": "Key positional encoding dimension." - }, - "page_size": { - "type": "const", - "value": 1, - "description": "Page size for KV cache (token-level)." - }, - "topk": { - "type": "const", - "value": 256, - "description": "Number of top-K KV cache entries selected for sparse attention per token." - }, - "num_pages": { - "type": "var", - "description": "Total number of allocated pages in the KV cache." - } - }, - "constraints": [ - "sparse_indices.shape[0] == total_num_tokens", - "sparse_indices.shape[-1] == topk", - "ckv_cache.shape[1] == page_size" - ], - "inputs": { - "q_nope": { - "shape": [ - "total_num_tokens", - "num_qo_heads", - "head_dim_ckv" - ], - "dtype": "bfloat16", - "description": "Query tensor without positional encoding component." - }, - "q_pe": { - "shape": [ - "total_num_tokens", - "num_qo_heads", - "head_dim_kpe" - ], - "dtype": "bfloat16", - "description": "Query positional encoding component." 
- }, - "ckv_cache": { - "shape": [ - "num_pages", - "page_size", - "head_dim_ckv" - ], - "dtype": "bfloat16", - "description": "Compressed key-value cache with page_size=1." - }, - "kpe_cache": { - "shape": [ - "num_pages", - "page_size", - "head_dim_kpe" - ], - "dtype": "bfloat16", - "description": "Key positional encoding cache." - }, - "sparse_indices": { - "shape": [ - "total_num_tokens", - "topk" - ], - "dtype": "int32", - "description": "Sparse indices selecting top-K KV cache entries for each token. Values of -1 indicate padding (invalid indices)." - }, - "sm_scale": { - "shape": null, - "dtype": "float32", - "description": "Softmax scale. For MLA, uses pre-absorption head dimension: 1/sqrt(head_dim_qk + head_dim_kpe) = 1/sqrt(128 + 64) = 1/sqrt(192)." - } - }, - "outputs": { - "output": { - "shape": [ - "total_num_tokens", - "num_qo_heads", - "head_dim_ckv" - ], - "dtype": "bfloat16" - }, - "lse": { - "shape": [ - "total_num_tokens", - "num_qo_heads" - ], - "dtype": "float32", - "description": "The 2-based log-sum-exp of attention logits." - } - }, - "reference": "import math\nimport torch\n\n\n@torch.no_grad()\ndef run(q_nope, q_pe, ckv_cache, kpe_cache, sparse_indices, sm_scale):\n total_num_tokens, num_qo_heads, head_dim_ckv = q_nope.shape\n head_dim_kpe = q_pe.shape[-1]\n page_size = ckv_cache.shape[1]\n topk = sparse_indices.shape[-1]\n\n # Check constants\n assert num_qo_heads == 16\n assert head_dim_ckv == 512\n assert head_dim_kpe == 64\n assert page_size == 1\n assert topk == 256\n\n # Check constraints\n assert sparse_indices.shape[0] == total_num_tokens\n assert sparse_indices.shape[-1] == topk\n assert ckv_cache.shape[1] == page_size\n\n device = q_nope.device\n\n # Squeeze page dimension (page_size=1)\n Kc_all = ckv_cache.squeeze(1).to(torch.float32) # [num_pages, head_dim_ckv]\n Kp_all = kpe_cache.squeeze(1).to(torch.float32) # [num_pages, head_dim_kpe]\n\n output = torch.zeros(\n (total_num_tokens, num_qo_heads, head_dim_ckv), dtype=torch.bfloat16, device=device\n )\n lse = torch.full((total_num_tokens, num_qo_heads), -float(\"inf\"), dtype=torch.float32, device=device)\n\n for t in range(total_num_tokens):\n indices = sparse_indices[t] # [topk]\n\n # Handle padding: -1 indicates invalid indices\n valid_mask = indices != -1\n valid_indices = indices[valid_mask]\n\n if valid_indices.numel() == 0:\n output[t].zero_()\n continue\n\n tok_idx = valid_indices.to(torch.long)\n\n Kc = Kc_all[tok_idx] # [num_valid, head_dim_ckv]\n Kp = Kp_all[tok_idx] # [num_valid, head_dim_kpe]\n qn = q_nope[t].to(torch.float32) # [num_qo_heads, head_dim_ckv]\n qp = q_pe[t].to(torch.float32) # [num_qo_heads, head_dim_kpe]\n\n # Compute attention logits\n logits = (qn @ Kc.T) + (qp @ Kp.T) # [num_qo_heads, num_valid]\n logits_scaled = logits * sm_scale\n\n # Compute 2-base LSE\n lse[t] = torch.logsumexp(logits_scaled, dim=-1) / math.log(2.0)\n\n # Compute attention output\n attn = torch.softmax(logits_scaled, dim=-1) # [num_qo_heads, num_valid]\n out = attn @ Kc # [num_qo_heads, head_dim_ckv]\n output[t] = out.to(torch.bfloat16)\n\n return output, lse" -} diff --git a/flashinfer_trace/definitions/dsa_paged/dsa_sparse_prefill_causal_h16_ckv512_kpe64_topk256_ps64.json b/flashinfer_trace/definitions/dsa_paged/dsa_sparse_prefill_causal_h16_ckv512_kpe64_topk256_ps64.json deleted file mode 100644 index 03bd3e85..00000000 --- a/flashinfer_trace/definitions/dsa_paged/dsa_sparse_prefill_causal_h16_ckv512_kpe64_topk256_ps64.json +++ /dev/null @@ -1,123 +0,0 @@ -{ - "name": 
"dsa_sparse_prefill_causal_h16_ckv512_kpe64_topk256_ps64", - "description": "Batched Native Sparse Attention (DSA) prefill with causal masking and sparse TopK KV cache selection. Captured from DeepSeek-V3 with tensor parallel size 8. Uses sparse indexing to select only top-K KV cache entries for attention computation during prefill. Page size 64 variant.", - "op_type": "dsa_paged", - "tags": [ - "stage:prefill", - "status:verified", - "model:deepseek-v3", - "model:deepseek-r1", - "sparse:topk", - "mask:causal" - ], - "axes": { - "total_num_tokens": { - "type": "var", - "description": "Total number of tokens across all sequences in the batch." - }, - "num_qo_heads": { - "type": "const", - "value": 16, - "description": "Number of query heads after tensor parallel split (128/8=16)." - }, - "head_dim_ckv": { - "type": "const", - "value": 512, - "description": "Compressed KV head dimension." - }, - "head_dim_kpe": { - "type": "const", - "value": 64, - "description": "Key positional encoding dimension." - }, - "page_size": { - "type": "const", - "value": 64, - "description": "Page size for KV cache (64 tokens per page)." - }, - "topk": { - "type": "const", - "value": 256, - "description": "Number of top-K KV cache entries selected for sparse attention per token." - }, - "num_pages": { - "type": "var", - "description": "Total number of allocated pages in the KV cache." - } - }, - "constraints": [ - "sparse_indices.shape[0] == total_num_tokens", - "sparse_indices.shape[-1] == topk", - "ckv_cache.shape[1] == page_size" - ], - "inputs": { - "q_nope": { - "shape": [ - "total_num_tokens", - "num_qo_heads", - "head_dim_ckv" - ], - "dtype": "bfloat16", - "description": "Query tensor without positional encoding component." - }, - "q_pe": { - "shape": [ - "total_num_tokens", - "num_qo_heads", - "head_dim_kpe" - ], - "dtype": "bfloat16", - "description": "Query positional encoding component." - }, - "ckv_cache": { - "shape": [ - "num_pages", - "page_size", - "head_dim_ckv" - ], - "dtype": "bfloat16", - "description": "Compressed key-value cache with page_size=64." - }, - "kpe_cache": { - "shape": [ - "num_pages", - "page_size", - "head_dim_kpe" - ], - "dtype": "bfloat16", - "description": "Key positional encoding cache." - }, - "sparse_indices": { - "shape": [ - "total_num_tokens", - "topk" - ], - "dtype": "int32", - "description": "Sparse indices selecting top-K KV cache entries for each token. Values of -1 indicate padding (invalid indices). For page_size=64, indices encode (page_idx * 64 + offset)." - }, - "sm_scale": { - "shape": null, - "dtype": "float32", - "description": "Softmax scale. For MLA, uses pre-absorption head dimension: 1/sqrt(head_dim_qk + head_dim_kpe) = 1/sqrt(128 + 64) = 1/sqrt(192)." - } - }, - "outputs": { - "output": { - "shape": [ - "total_num_tokens", - "num_qo_heads", - "head_dim_ckv" - ], - "dtype": "bfloat16" - }, - "lse": { - "shape": [ - "total_num_tokens", - "num_qo_heads" - ], - "dtype": "float32", - "description": "The 2-based log-sum-exp of attention logits." 
- } - }, - "reference": "import math\nimport torch\n\n\n@torch.no_grad()\ndef run(q_nope, q_pe, ckv_cache, kpe_cache, sparse_indices, sm_scale):\n total_num_tokens, num_qo_heads, head_dim_ckv = q_nope.shape\n head_dim_kpe = q_pe.shape[-1]\n num_pages, page_size, _ = ckv_cache.shape\n topk = sparse_indices.shape[-1]\n\n # Check constants\n assert num_qo_heads == 16\n assert head_dim_ckv == 512\n assert head_dim_kpe == 64\n assert page_size == 64\n assert topk == 256\n\n # Check constraints\n assert sparse_indices.shape[0] == total_num_tokens\n assert sparse_indices.shape[-1] == topk\n assert ckv_cache.shape[1] == page_size\n\n device = q_nope.device\n\n # Flatten paged KV cache to token-level: [num_pages, page_size, dim] -> [num_pages * page_size, dim]\n Kc_all = ckv_cache.reshape(-1, head_dim_ckv).to(torch.float32) # [total_kv_tokens, head_dim_ckv]\n Kp_all = kpe_cache.reshape(-1, head_dim_kpe).to(torch.float32) # [total_kv_tokens, head_dim_kpe]\n\n output = torch.zeros(\n (total_num_tokens, num_qo_heads, head_dim_ckv), dtype=torch.bfloat16, device=device\n )\n lse = torch.full((total_num_tokens, num_qo_heads), -float(\"inf\"), dtype=torch.float32, device=device)\n\n for t in range(total_num_tokens):\n indices = sparse_indices[t] # [topk]\n\n # Handle padding: -1 indicates invalid indices\n valid_mask = indices != -1\n valid_indices = indices[valid_mask]\n\n if valid_indices.numel() == 0:\n output[t].zero_()\n continue\n\n # For page_size=64, indices encode (page_idx * 64 + offset)\n tok_idx = valid_indices.to(torch.long)\n\n Kc = Kc_all[tok_idx] # [num_valid, head_dim_ckv]\n Kp = Kp_all[tok_idx] # [num_valid, head_dim_kpe]\n qn = q_nope[t].to(torch.float32) # [num_qo_heads, head_dim_ckv]\n qp = q_pe[t].to(torch.float32) # [num_qo_heads, head_dim_kpe]\n\n # Compute attention logits\n logits = (qn @ Kc.T) + (qp @ Kp.T) # [num_qo_heads, num_valid]\n logits_scaled = logits * sm_scale\n\n # Compute 2-base LSE\n lse[t] = torch.logsumexp(logits_scaled, dim=-1) / math.log(2.0)\n\n # Compute attention output\n attn = torch.softmax(logits_scaled, dim=-1) # [num_qo_heads, num_valid]\n out = attn @ Kc # [num_qo_heads, head_dim_ckv]\n output[t] = out.to(torch.bfloat16)\n\n return {\"output\": output, \"lse\": lse}" -} diff --git a/flashinfer_trace/definitions/dsa_paged/dsa_topk_indexer_fp8_h64_d128_topk256_ps64.json b/flashinfer_trace/definitions/dsa_paged/dsa_topk_indexer_fp8_h64_d128_topk256_ps64.json index 680a4197..6c95d092 100644 --- a/flashinfer_trace/definitions/dsa_paged/dsa_topk_indexer_fp8_h64_d128_topk256_ps64.json +++ b/flashinfer_trace/definitions/dsa_paged/dsa_topk_indexer_fp8_h64_d128_topk256_ps64.json @@ -109,5 +109,5 @@ "description": "Top-K token indices for each batch element. Values of -1 indicate padding." 
} }, - "reference": "import torch\n\n\ndef dequant_fp8_kv_cache(k_index_cache_fp8):\n \"\"\"Dequantize FP8 KV cache from deep_gemm format.\n \n Input: [num_pages, page_size, 1, 132] int8 (interpreted as uint8)\n Memory layout (per page): [fp8_data (page_size * 128 bytes), scales (page_size * 4 bytes)]\n After view to [num_pages, page_size, 1, 132]: NOT directly indexable as [fp8, scale] per token!\n Output: [num_pages, page_size, 128] float32\n \"\"\"\n # View as uint8 for correct byte interpretation\n k_index_cache_fp8 = k_index_cache_fp8.view(torch.uint8)\n num_pages, page_size, num_heads, head_dim_sf = k_index_cache_fp8.shape\n head_dim = head_dim_sf - 4 # 128\n \n # Go back to flat format to reverse the packing\n kv_flat = k_index_cache_fp8.view(num_pages, page_size * head_dim_sf)\n \n # FP8 part: first page_size * head_dim bytes\n fp8_bytes = kv_flat[:, :page_size * head_dim].contiguous()\n fp8_tensor = fp8_bytes.view(num_pages, page_size, head_dim).view(torch.float8_e4m3fn)\n fp8_float = fp8_tensor.to(torch.float32)\n \n # Scale part: last page_size * 4 bytes -> page_size float32 values\n scale_bytes = kv_flat[:, page_size * head_dim:].contiguous()\n scale = scale_bytes.view(num_pages, page_size, 4).view(torch.float32) # [num_pages, page_size, 1]\n \n return fp8_float * scale\n\n\n@torch.no_grad()\ndef run(q_index_fp8, k_index_cache_fp8, weights, seq_lens, block_table):\n batch_size, num_index_heads, index_head_dim = q_index_fp8.shape\n num_pages, page_size, _, _ = k_index_cache_fp8.shape\n topk = 256\n\n # Check constants\n assert num_index_heads == 64\n assert index_head_dim == 128\n assert page_size == 64\n\n device = q_index_fp8.device\n\n # Dequantize inputs\n q = q_index_fp8.to(torch.float32) # [batch, heads, head_dim]\n K_all = dequant_fp8_kv_cache(k_index_cache_fp8) # [num_pages, page_size, head_dim]\n\n topk_indices = torch.full((batch_size, topk), -1, dtype=torch.int32, device=device)\n max_num_pages = block_table.shape[1]\n\n for b in range(batch_size):\n seq_len = int(seq_lens[b].item())\n \n if seq_len == 0:\n continue\n\n # Get pages for this sequence\n num_pages_for_seq = (seq_len + page_size - 1) // page_size\n page_indices = block_table[b, :num_pages_for_seq].to(torch.long)\n \n # Gather K from pages\n K_paged = K_all[page_indices] # [num_pages_for_seq, page_size, head_dim]\n K = K_paged.reshape(-1, index_head_dim)[:seq_len] # [seq_len, head_dim]\n \n # Query for this batch element\n q_b = q[b] # [num_heads, head_dim]\n \n # Compute attention scores\n scores = q_b @ K.T # [num_heads, seq_len]\n \n # Apply ReLU (deep_gemm uses ReLU activation)\n scores_relu = torch.relu(scores) # [num_heads, seq_len]\n \n # Apply learned weights and sum across heads\n w = weights[b] # [num_heads]\n weighted_scores = scores_relu * w[:, None] # [num_heads, seq_len]\n final_scores = weighted_scores.sum(dim=0) # [seq_len]\n \n # Select top-K\n actual_topk = min(topk, seq_len)\n _, topk_idx = torch.topk(final_scores, actual_topk)\n \n # Convert to global token indices\n # Token index = page_idx * page_size + offset_in_page\n page_idx_per_token = topk_idx // page_size\n offset_per_token = topk_idx % page_size\n global_page_idx = page_indices[page_idx_per_token]\n topk_tokens = global_page_idx * page_size + offset_per_token\n \n topk_indices[b, :actual_topk] = topk_tokens.to(torch.int32)\n\n return {\"topk_indices\": topk_indices}" + "reference": "import torch\n\n\ndef dequant_fp8_kv_cache(k_index_cache_fp8):\n \"\"\"Dequantize FP8 KV cache from deep_gemm format.\n \n Input: [num_pages, 
page_size, 1, 132] int8 (interpreted as uint8)\n Memory layout (per page): [fp8_data (page_size * 128 bytes), scales (page_size * 4 bytes)]\n After view to [num_pages, page_size, 1, 132]: NOT directly indexable as [fp8, scale] per token!\n Output: [num_pages, page_size, 128] float32\n \"\"\"\n # View as uint8 for correct byte interpretation\n k_index_cache_fp8 = k_index_cache_fp8.view(torch.uint8)\n num_pages, page_size, num_heads, head_dim_sf = k_index_cache_fp8.shape\n head_dim = head_dim_sf - 4 # 128\n \n # Go back to flat format to reverse the packing\n kv_flat = k_index_cache_fp8.view(num_pages, page_size * head_dim_sf)\n \n # FP8 part: first page_size * head_dim bytes\n fp8_bytes = kv_flat[:, :page_size * head_dim].contiguous()\n fp8_tensor = fp8_bytes.view(num_pages, page_size, head_dim).view(torch.float8_e4m3fn)\n fp8_float = fp8_tensor.to(torch.float32)\n \n # Scale part: last page_size * 4 bytes -> page_size float32 values\n scale_bytes = kv_flat[:, page_size * head_dim:].contiguous()\n scale = scale_bytes.view(num_pages, page_size, 4).view(torch.float32) # [num_pages, page_size, 1]\n \n return fp8_float * scale\n\n\n@torch.no_grad()\ndef run(q_index_fp8, k_index_cache_fp8, weights, seq_lens, block_table):\n batch_size, num_index_heads, index_head_dim = q_index_fp8.shape\n num_pages, page_size, _, _ = k_index_cache_fp8.shape\n topk = 256\n\n # Check constants\n assert num_index_heads == 64\n assert index_head_dim == 128\n assert page_size == 64\n\n device = q_index_fp8.device\n\n # Dequantize inputs\n q = q_index_fp8.to(torch.float32) # [batch, heads, head_dim]\n K_all = dequant_fp8_kv_cache(k_index_cache_fp8) # [num_pages, page_size, head_dim]\n\n topk_indices = torch.full((batch_size, topk), -1, dtype=torch.int32, device=device)\n max_num_pages = block_table.shape[1]\n\n for b in range(batch_size):\n seq_len = int(seq_lens[b].item())\n \n if seq_len == 0:\n continue\n\n # Get pages for this sequence\n num_pages_for_seq = (seq_len + page_size - 1) // page_size\n page_indices = block_table[b, :num_pages_for_seq].to(torch.long)\n \n # Gather K from pages\n K_paged = K_all[page_indices] # [num_pages_for_seq, page_size, head_dim]\n K = K_paged.reshape(-1, index_head_dim)[:seq_len] # [seq_len, head_dim]\n \n # Query for this batch element\n q_b = q[b] # [num_heads, head_dim]\n \n # Compute attention scores\n scores = q_b @ K.T # [num_heads, seq_len]\n \n # Apply ReLU (deep_gemm uses ReLU activation)\n scores_relu = torch.relu(scores) # [num_heads, seq_len]\n \n # Apply learned weights and sum across heads\n w = weights[b] # [num_heads]\n weighted_scores = scores_relu * w[:, None] # [num_heads, seq_len]\n final_scores = weighted_scores.sum(dim=0) # [seq_len]\n \n # Select top-K\n actual_topk = min(topk, seq_len)\n _, topk_idx = torch.topk(final_scores, actual_topk)\n \n # Convert to global token indices\n # Token index = page_idx * page_size + offset_in_page\n page_idx_per_token = topk_idx // page_size\n offset_per_token = topk_idx % page_size\n global_page_idx = page_indices[page_idx_per_token]\n topk_tokens = global_page_idx * page_size + offset_per_token\n \n topk_indices[b, :actual_topk] = topk_tokens.to(torch.int32)\n\n return (topk_indices,)" } diff --git a/flashinfer_trace/definitions/gdn/gdn_decode_qk16_v32_d128_k_last.json b/flashinfer_trace/definitions/gdn/gdn_decode_qk16_v32_d128_k_last.json index d5eeac1d..d1fc27c6 100644 --- a/flashinfer_trace/definitions/gdn/gdn_decode_qk16_v32_d128_k_last.json +++ b/flashinfer_trace/definitions/gdn/gdn_decode_qk16_v32_d128_k_last.json 
@@ -45,43 +45,75 @@ ], "inputs": { "q": { - "shape": ["batch_size", "seq_len", "num_q_heads", "head_size"], + "shape": [ + "batch_size", + "seq_len", + "num_q_heads", + "head_size" + ], "dtype": "bfloat16", "description": "Query tensor for single token decode." }, "k": { - "shape": ["batch_size", "seq_len", "num_k_heads", "head_size"], + "shape": [ + "batch_size", + "seq_len", + "num_k_heads", + "head_size" + ], "dtype": "bfloat16", "description": "Key tensor for single token decode." }, "v": { - "shape": ["batch_size", "seq_len", "num_v_heads", "head_size"], + "shape": [ + "batch_size", + "seq_len", + "num_v_heads", + "head_size" + ], "dtype": "bfloat16", "description": "Value tensor for single token decode." }, "state": { - "shape": ["batch_size", "num_v_heads", "head_size", "head_size"], + "shape": [ + "batch_size", + "num_v_heads", + "head_size", + "head_size" + ], "dtype": "float32", "description": "Recurrent state in k-last layout [B, H, V, K].", "optional": true }, "A_log": { - "shape": ["num_v_heads"], + "shape": [ + "num_v_heads" + ], "dtype": "float32", "description": "Log decay parameter (learnable). Used to compute g = exp(-exp(A_log) * softplus(a + dt_bias))." }, "a": { - "shape": ["batch_size", "seq_len", "num_v_heads"], + "shape": [ + "batch_size", + "seq_len", + "num_v_heads" + ], "dtype": "bfloat16", "description": "Input-dependent decay from projection." }, "dt_bias": { - "shape": ["num_v_heads"], + "shape": [ + "num_v_heads" + ], "dtype": "float32", "description": "Decay bias (learnable). Added to 'a' before softplus." }, "b": { - "shape": ["batch_size", "seq_len", "num_v_heads"], + "shape": [ + "batch_size", + "seq_len", + "num_v_heads" + ], "dtype": "bfloat16", "description": "Update gate input from projection. beta = sigmoid(b)." }, @@ -93,15 +125,25 @@ }, "outputs": { "output": { - "shape": ["batch_size", "seq_len", "num_v_heads", "head_size"], + "shape": [ + "batch_size", + "seq_len", + "num_v_heads", + "head_size" + ], "dtype": "bfloat16", "description": "Attention output. Shape follows num_v_heads in GVA mode." }, "new_state": { - "shape": ["batch_size", "num_v_heads", "head_size", "head_size"], + "shape": [ + "batch_size", + "num_v_heads", + "head_size", + "head_size" + ], "dtype": "float32", "description": "Updated recurrent state in k-last layout [B, H, V, K]." 
} }, - "reference": "import math\nimport torch\nimport torch.nn.functional as F\n\n\ndef matmul(a: torch.Tensor, b: torch.Tensor):\n \"\"\"Float32 matmul for numerical stability.\"\"\"\n return a.float() @ b.float()\n\n\n@torch.no_grad()\ndef run(q, k, v, state, A_log, a, dt_bias, b, scale):\n \"\"\"\n Gated Delta Net decode reference implementation (k-last layout).\n \n State layout: [B, H, V, K] (k-last, K dimension at the end)\n \n Gate computation:\n g = exp(-exp(A_log) * softplus(a + dt_bias))\n beta = sigmoid(b)\n \n Delta rule update:\n state_new = g * state_old + k^T @ (beta * v + (1-beta) * k @ state_old) - k^T @ (k @ state_old)\n output = scale * q @ state_new\n \"\"\"\n B, T, num_q_heads, K = q.shape\n _, _, num_k_heads, _ = k.shape\n _, _, num_v_heads, V = v.shape\n num_heads = num_v_heads\n device = q.device\n \n assert num_q_heads == 16\n assert num_k_heads == 16\n assert num_v_heads == 32\n assert K == 128 and V == 128\n assert T == 1\n \n if scale is None or scale == 0.0:\n scale = 1.0 / math.sqrt(K)\n \n # Compute g and beta from raw parameters\n x = a.float() + dt_bias.float() # [B, 1, HV]\n g = torch.exp(-torch.exp(A_log.float()) * F.softplus(x)) # [B, 1, HV]\n beta = torch.sigmoid(b.float()) # [B, 1, HV]\n \n q_f32 = q.squeeze(1).float()\n k_f32 = k.squeeze(1).float()\n v_f32 = v.squeeze(1).float()\n g_f32 = g.squeeze(1).float()\n beta_f32 = beta.squeeze(1).float()\n \n if state is not None:\n state_f32 = state.float()\n else:\n state_f32 = torch.zeros(B, num_heads, V, K, dtype=torch.float32, device=device)\n \n q_exp = q_f32.repeat_interleave(num_v_heads // num_q_heads, dim=1)\n k_exp = k_f32.repeat_interleave(num_v_heads // num_k_heads, dim=1)\n \n new_state = torch.zeros_like(state_f32)\n output = torch.zeros(B, num_heads, V, dtype=torch.float32, device=device)\n \n for b_idx in range(B):\n for h_idx in range(num_heads):\n q_h = q_exp[b_idx, h_idx]\n k_h = k_exp[b_idx, h_idx]\n v_h = v_f32[b_idx, h_idx]\n h_state = state_f32[b_idx, h_idx].clone().transpose(-1, -2) # [V,K] -> [K,V]\n g_val = g_f32[b_idx, h_idx]\n beta_val = beta_f32[b_idx, h_idx]\n \n old_state = g_val * h_state\n old_v = k_h @ old_state\n new_v = beta_val * v_h + (1 - beta_val) * old_v\n state_remove = k_h.unsqueeze(1) @ old_v.unsqueeze(0)\n state_update = k_h.unsqueeze(1) @ new_v.unsqueeze(0)\n h_state = old_state - state_remove + state_update\n \n output[b_idx, h_idx] = scale * (q_h @ h_state)\n new_state[b_idx, h_idx] = h_state.transpose(-1, -2) # [K,V] -> [V,K]\n \n output = output.unsqueeze(1).to(torch.bfloat16)\n return {\"output\": output, \"new_state\": new_state}" + "reference": "import math\nimport torch\nimport torch.nn.functional as F\n\n\ndef matmul(a: torch.Tensor, b: torch.Tensor):\n \"\"\"Float32 matmul for numerical stability.\"\"\"\n return a.float() @ b.float()\n\n\n@torch.no_grad()\ndef run(q, k, v, state, A_log, a, dt_bias, b, scale):\n \"\"\"\n Gated Delta Net decode reference implementation (k-last layout).\n \n State layout: [B, H, V, K] (k-last, K dimension at the end)\n \n Gate computation:\n g = exp(-exp(A_log) * softplus(a + dt_bias))\n beta = sigmoid(b)\n \n Delta rule update:\n state_new = g * state_old + k^T @ (beta * v + (1-beta) * k @ state_old) - k^T @ (k @ state_old)\n output = scale * q @ state_new\n \"\"\"\n B, T, num_q_heads, K = q.shape\n _, _, num_k_heads, _ = k.shape\n _, _, num_v_heads, V = v.shape\n num_heads = num_v_heads\n device = q.device\n \n assert num_q_heads == 16\n assert num_k_heads == 16\n assert num_v_heads == 32\n assert K == 128 and V == 128\n 
assert T == 1\n \n if scale is None or scale == 0.0:\n scale = 1.0 / math.sqrt(K)\n \n # Compute g and beta from raw parameters\n x = a.float() + dt_bias.float() # [B, 1, HV]\n g = torch.exp(-torch.exp(A_log.float()) * F.softplus(x)) # [B, 1, HV]\n beta = torch.sigmoid(b.float()) # [B, 1, HV]\n \n q_f32 = q.squeeze(1).float()\n k_f32 = k.squeeze(1).float()\n v_f32 = v.squeeze(1).float()\n g_f32 = g.squeeze(1).float()\n beta_f32 = beta.squeeze(1).float()\n \n if state is not None:\n state_f32 = state.float()\n else:\n state_f32 = torch.zeros(B, num_heads, V, K, dtype=torch.float32, device=device)\n \n q_exp = q_f32.repeat_interleave(num_v_heads // num_q_heads, dim=1)\n k_exp = k_f32.repeat_interleave(num_v_heads // num_k_heads, dim=1)\n \n new_state = torch.zeros_like(state_f32)\n output = torch.zeros(B, num_heads, V, dtype=torch.float32, device=device)\n \n for b_idx in range(B):\n for h_idx in range(num_heads):\n q_h = q_exp[b_idx, h_idx]\n k_h = k_exp[b_idx, h_idx]\n v_h = v_f32[b_idx, h_idx]\n h_state = state_f32[b_idx, h_idx].clone().transpose(-1, -2) # [V,K] -> [K,V]\n g_val = g_f32[b_idx, h_idx]\n beta_val = beta_f32[b_idx, h_idx]\n \n old_state = g_val * h_state\n old_v = k_h @ old_state\n new_v = beta_val * v_h + (1 - beta_val) * old_v\n state_remove = k_h.unsqueeze(1) @ old_v.unsqueeze(0)\n state_update = k_h.unsqueeze(1) @ new_v.unsqueeze(0)\n h_state = old_state - state_remove + state_update\n \n output[b_idx, h_idx] = scale * (q_h @ h_state)\n new_state[b_idx, h_idx] = h_state.transpose(-1, -2) # [K,V] -> [V,K]\n \n output = output.unsqueeze(1).to(torch.bfloat16)\n return output, new_state" } diff --git a/flashinfer_trace/definitions/gdn/gdn_prefill_qk16_v32_d128_k_last.json b/flashinfer_trace/definitions/gdn/gdn_prefill_qk16_v32_d128_k_last.json index 32351b7e..fcdbf63c 100644 --- a/flashinfer_trace/definitions/gdn/gdn_prefill_qk16_v32_d128_k_last.json +++ b/flashinfer_trace/definitions/gdn/gdn_prefill_qk16_v32_d128_k_last.json @@ -45,48 +45,77 @@ ], "inputs": { "q": { - "shape": ["total_seq_len", "num_q_heads", "head_size"], + "shape": [ + "total_seq_len", + "num_q_heads", + "head_size" + ], "dtype": "bfloat16", "description": "Query tensor." }, "k": { - "shape": ["total_seq_len", "num_k_heads", "head_size"], + "shape": [ + "total_seq_len", + "num_k_heads", + "head_size" + ], "dtype": "bfloat16", "description": "Key tensor." }, "v": { - "shape": ["total_seq_len", "num_v_heads", "head_size"], + "shape": [ + "total_seq_len", + "num_v_heads", + "head_size" + ], "dtype": "bfloat16", "description": "Value tensor." }, "state": { - "shape": ["num_seqs", "num_v_heads", "head_size", "head_size"], + "shape": [ + "num_seqs", + "num_v_heads", + "head_size", + "head_size" + ], "dtype": "float32", "description": "Recurrent state in k-last layout [N, H, V, K].", "optional": true }, "A_log": { - "shape": ["num_v_heads"], + "shape": [ + "num_v_heads" + ], "dtype": "float32", "description": "Log decay parameter (learnable). Used to compute g = exp(-exp(A_log) * softplus(a + dt_bias))." }, "a": { - "shape": ["total_seq_len", "num_v_heads"], + "shape": [ + "total_seq_len", + "num_v_heads" + ], "dtype": "bfloat16", "description": "Input-dependent decay from projection." }, "dt_bias": { - "shape": ["num_v_heads"], + "shape": [ + "num_v_heads" + ], "dtype": "float32", "description": "Decay bias (learnable). Added to 'a' before softplus." 
}, "b": { - "shape": ["total_seq_len", "num_v_heads"], + "shape": [ + "total_seq_len", + "num_v_heads" + ], "dtype": "bfloat16", "description": "Update gate input from projection. beta = sigmoid(b)." }, "cu_seqlens": { - "shape": ["len_cu_seqlens"], + "shape": [ + "len_cu_seqlens" + ], "dtype": "int64", "description": "Cumulative sequence lengths for variable-length batching." }, @@ -98,15 +127,24 @@ }, "outputs": { "output": { - "shape": ["total_seq_len", "num_v_heads", "head_size"], + "shape": [ + "total_seq_len", + "num_v_heads", + "head_size" + ], "dtype": "bfloat16", "description": "Attention output. Shape follows num_v_heads in GVA mode." }, "new_state": { - "shape": ["num_seqs", "num_v_heads", "head_size", "head_size"], + "shape": [ + "num_seqs", + "num_v_heads", + "head_size", + "head_size" + ], "dtype": "float32", "description": "Updated recurrent state in k-last layout [N, H, V, K]." } }, - "reference": "import math\nimport torch\nimport torch.nn.functional as F\n\n\ndef matmul(a: torch.Tensor, b: torch.Tensor):\n \"\"\"Float32 matmul for numerical stability.\"\"\"\n return a.float() @ b.float()\n\n\n@torch.no_grad()\ndef run(q, k, v, state, A_log, a, dt_bias, b, cu_seqlens, scale):\n \"\"\"\n Gated Delta Net prefill reference implementation (k-last layout).\n \n State layout: [H, V, K] (k-last, K dimension at the end)\n \n Gate computation:\n g = exp(-exp(A_log) * softplus(a + dt_bias))\n beta = sigmoid(b)\n \n Delta rule update:\n state_new = g * state_old + k^T @ (beta * v + (1-beta) * k @ state_old) - k^T @ (k @ state_old)\n output = scale * q @ state_new\n \"\"\"\n total_seq_len, num_q_heads, head_size = q.shape\n num_v_heads = v.shape[1]\n num_k_heads = k.shape[1]\n num_sab_heads = max(num_q_heads, num_v_heads)\n num_seqs = cu_seqlens.size(0) - 1\n device = q.device\n\n assert num_q_heads == 16\n assert num_k_heads == 16\n assert num_v_heads == 32\n assert head_size == 128\n\n if scale is None or scale == 0.0:\n scale = 1.0 / math.sqrt(head_size)\n\n # Compute g and beta from raw parameters\n x = a.float() + dt_bias.float() # [total_seq_len, HV]\n g = torch.exp(-torch.exp(A_log.float()) * F.softplus(x)) # [total_seq_len, HV]\n beta = torch.sigmoid(b.float()) # [total_seq_len, HV]\n\n q_exp = q.repeat_interleave(num_v_heads // num_q_heads, dim=1)\n k_exp = k.repeat_interleave(num_v_heads // num_k_heads, dim=1)\n\n output = torch.zeros(\n (total_seq_len, num_sab_heads, head_size), dtype=torch.bfloat16, device=device\n )\n new_state = torch.zeros(\n (num_seqs, num_sab_heads, head_size, head_size), dtype=torch.float32, device=device\n )\n\n for seq_idx in range(num_seqs):\n seq_start = int(cu_seqlens[seq_idx].item())\n seq_end = int(cu_seqlens[seq_idx + 1].item())\n seq_len = seq_end - seq_start\n\n if seq_len <= 0:\n continue\n\n if state is not None:\n state_HKV = state[seq_idx].clone().float().transpose(-1, -2) # [H,V,K] -> [H,K,V]\n else:\n state_HKV = torch.zeros(\n (num_sab_heads, head_size, head_size), dtype=torch.float32, device=device\n )\n\n for i in range(seq_len):\n t = seq_start + i\n q_H1K = q_exp[t].unsqueeze(1).float()\n k_H1K = k_exp[t].unsqueeze(1).float()\n v_H1V = v[t].unsqueeze(1).float()\n g_H11 = g[t].unsqueeze(1).unsqueeze(2)\n beta_H11 = beta[t].unsqueeze(1).unsqueeze(2)\n\n old_state_HKV = g_H11 * state_HKV\n old_v_H1V = matmul(k_H1K, old_state_HKV)\n new_v_H1V = beta_H11 * v_H1V + (1 - beta_H11) * old_v_H1V\n state_remove = torch.einsum('hkl,hlv->hkv', k_H1K.transpose(-1, -2), old_v_H1V)\n state_update = torch.einsum('hkl,hlv->hkv', 
k_H1K.transpose(-1, -2), new_v_H1V)\n state_HKV = old_state_HKV - state_remove + state_update\n\n o_H1V = scale * matmul(q_H1K, state_HKV)\n output[t] = o_H1V.squeeze(1).to(torch.bfloat16)\n\n new_state[seq_idx] = state_HKV.transpose(-1, -2) # [H,K,V] -> [H,V,K]\n\n return {\"output\": output, \"new_state\": new_state}" + "reference": "import math\nimport torch\nimport torch.nn.functional as F\n\n\ndef matmul(a: torch.Tensor, b: torch.Tensor):\n \"\"\"Float32 matmul for numerical stability.\"\"\"\n return a.float() @ b.float()\n\n\n@torch.no_grad()\ndef run(q, k, v, state, A_log, a, dt_bias, b, cu_seqlens, scale):\n \"\"\"\n Gated Delta Net prefill reference implementation (k-last layout).\n \n State layout: [H, V, K] (k-last, K dimension at the end)\n \n Gate computation:\n g = exp(-exp(A_log) * softplus(a + dt_bias))\n beta = sigmoid(b)\n \n Delta rule update:\n state_new = g * state_old + k^T @ (beta * v + (1-beta) * k @ state_old) - k^T @ (k @ state_old)\n output = scale * q @ state_new\n \"\"\"\n total_seq_len, num_q_heads, head_size = q.shape\n num_v_heads = v.shape[1]\n num_k_heads = k.shape[1]\n num_sab_heads = max(num_q_heads, num_v_heads)\n num_seqs = cu_seqlens.size(0) - 1\n device = q.device\n\n assert num_q_heads == 16\n assert num_k_heads == 16\n assert num_v_heads == 32\n assert head_size == 128\n\n if scale is None or scale == 0.0:\n scale = 1.0 / math.sqrt(head_size)\n\n # Compute g and beta from raw parameters\n x = a.float() + dt_bias.float() # [total_seq_len, HV]\n g = torch.exp(-torch.exp(A_log.float()) * F.softplus(x)) # [total_seq_len, HV]\n beta = torch.sigmoid(b.float()) # [total_seq_len, HV]\n\n q_exp = q.repeat_interleave(num_v_heads // num_q_heads, dim=1)\n k_exp = k.repeat_interleave(num_v_heads // num_k_heads, dim=1)\n\n output = torch.zeros(\n (total_seq_len, num_sab_heads, head_size), dtype=torch.bfloat16, device=device\n )\n new_state = torch.zeros(\n (num_seqs, num_sab_heads, head_size, head_size), dtype=torch.float32, device=device\n )\n\n for seq_idx in range(num_seqs):\n seq_start = int(cu_seqlens[seq_idx].item())\n seq_end = int(cu_seqlens[seq_idx + 1].item())\n seq_len = seq_end - seq_start\n\n if seq_len <= 0:\n continue\n\n if state is not None:\n state_HKV = state[seq_idx].clone().float().transpose(-1, -2) # [H,V,K] -> [H,K,V]\n else:\n state_HKV = torch.zeros(\n (num_sab_heads, head_size, head_size), dtype=torch.float32, device=device\n )\n\n for i in range(seq_len):\n t = seq_start + i\n q_H1K = q_exp[t].unsqueeze(1).float()\n k_H1K = k_exp[t].unsqueeze(1).float()\n v_H1V = v[t].unsqueeze(1).float()\n g_H11 = g[t].unsqueeze(1).unsqueeze(2)\n beta_H11 = beta[t].unsqueeze(1).unsqueeze(2)\n\n old_state_HKV = g_H11 * state_HKV\n old_v_H1V = matmul(k_H1K, old_state_HKV)\n new_v_H1V = beta_H11 * v_H1V + (1 - beta_H11) * old_v_H1V\n state_remove = torch.einsum('hkl,hlv->hkv', k_H1K.transpose(-1, -2), old_v_H1V)\n state_update = torch.einsum('hkl,hlv->hkv', k_H1K.transpose(-1, -2), new_v_H1V)\n state_HKV = old_state_HKV - state_remove + state_update\n\n o_H1V = scale * matmul(q_H1K, state_HKV)\n output[t] = o_H1V.squeeze(1).to(torch.bfloat16)\n\n new_state[seq_idx] = state_HKV.transpose(-1, -2) # [H,K,V] -> [H,V,K]\n\n return output, new_state" } diff --git a/flashinfer_trace/tests/references/test_dsa_sparse_attention_h16_ckv512_kpe64_topk256_ps1.py b/flashinfer_trace/tests/references/test_dsa_sparse_attention_h16_ckv512_kpe64_topk256_ps1.py new file mode 100644 index 00000000..773f4911 --- /dev/null +++ 
b/flashinfer_trace/tests/references/test_dsa_sparse_attention_h16_ckv512_kpe64_topk256_ps1.py @@ -0,0 +1,192 @@ +""" +Tests for DSA (DeepSeek Sparse Attention) sparse attention reference implementation. + +Ground truth comparison tests are in test_dsa_vs_definition_reference.py +which tests against FlashInfer's trtllm_batch_decode_with_kv_cache_mla. +""" + +import math +from pathlib import Path + +import numpy as np +import pytest +import torch + +# Module-level constants (DeepSeek V3/R1 with TP=8) +NUM_QO_HEADS = 16 +HEAD_DIM_CKV = 512 +HEAD_DIM_KPE = 64 +PAGE_SIZE = 1 +TOPK = 256 + +TRACE_ROOT = Path(__file__).resolve().parents[2] + + +@torch.no_grad() +def run(q_nope, q_pe, ckv_cache, kpe_cache, sparse_indices, sm_scale): + """Reference implementation for DSA sparse attention.""" + num_tokens, num_qo_heads, head_dim_ckv = q_nope.shape + head_dim_kpe = q_pe.shape[-1] + page_size = ckv_cache.shape[1] + topk = sparse_indices.shape[-1] + + # Check constants + assert num_qo_heads == NUM_QO_HEADS + assert head_dim_ckv == HEAD_DIM_CKV + assert head_dim_kpe == HEAD_DIM_KPE + assert page_size == PAGE_SIZE + assert topk == TOPK + + # Check constraints + assert sparse_indices.shape[0] == num_tokens + assert sparse_indices.shape[-1] == topk + assert ckv_cache.shape[1] == page_size + + device = q_nope.device + + # Squeeze page dimension (page_size=1) + Kc_all = ckv_cache.squeeze(1).to(torch.float32) # [num_pages, head_dim_ckv] + Kp_all = kpe_cache.squeeze(1).to(torch.float32) # [num_pages, head_dim_kpe] + + output = torch.zeros( + (num_tokens, num_qo_heads, head_dim_ckv), dtype=torch.bfloat16, device=device + ) + lse = torch.full((num_tokens, num_qo_heads), -float("inf"), dtype=torch.float32, device=device) + + for t in range(num_tokens): + indices = sparse_indices[t] # [topk] + + # Handle padding: -1 indicates invalid indices + valid_mask = indices != -1 + valid_indices = indices[valid_mask] + + if valid_indices.numel() == 0: + output[t].zero_() + continue + + tok_idx = valid_indices.to(torch.long) + + Kc = Kc_all[tok_idx] # [num_valid, head_dim_ckv] + Kp = Kp_all[tok_idx] # [num_valid, head_dim_kpe] + qn = q_nope[t].to(torch.float32) # [num_qo_heads, head_dim_ckv] + qp = q_pe[t].to(torch.float32) # [num_qo_heads, head_dim_kpe] + + # Compute attention logits + logits = (qn @ Kc.T) + (qp @ Kp.T) # [num_qo_heads, num_valid] + logits_scaled = logits * sm_scale + + # Compute 2-base LSE + lse[t] = torch.logsumexp(logits_scaled, dim=-1) / math.log(2.0) + + # Compute attention output + attn = torch.softmax(logits_scaled, dim=-1) # [num_qo_heads, num_valid] + out = attn @ Kc # [num_qo_heads, head_dim_ckv] + output[t] = out.to(torch.bfloat16) + + return {"output": output, "lse": lse} + + +def generate_random_inputs( + num_tokens, + num_qo_heads=NUM_QO_HEADS, + head_dim_ckv=HEAD_DIM_CKV, + head_dim_kpe=HEAD_DIM_KPE, + topk=TOPK, + device="cuda", +): + """Generate random inputs for DSA sparse attention testing.""" + num_pages = max(num_tokens * 2, 1024) + + sparse_indices = torch.randint( + 0, num_pages, (num_tokens, topk), dtype=torch.int32, device=device + ) + + q_nope = torch.randn( + num_tokens, num_qo_heads, head_dim_ckv, dtype=torch.bfloat16, device=device + ) + q_pe = torch.randn(num_tokens, num_qo_heads, head_dim_kpe, dtype=torch.bfloat16, device=device) + + ckv_cache = torch.randn(num_pages, 1, head_dim_ckv, dtype=torch.bfloat16, device=device) + kpe_cache = torch.randn(num_pages, 1, head_dim_kpe, dtype=torch.bfloat16, device=device) + + sm_scale = 1.0 / np.sqrt(128 + head_dim_kpe) + + return { + 
"q_nope": q_nope, + "q_pe": q_pe, + "ckv_cache": ckv_cache, + "kpe_cache": kpe_cache, + "sparse_indices": sparse_indices, + "sm_scale": torch.tensor(sm_scale, dtype=torch.float32, device=device), + "num_pages": num_pages, + } + + +def test_output_shape(num_tokens=64, topk=TOPK): + """Test that reference produces correct output shapes.""" + device = "cuda" if torch.cuda.is_available() else "cpu" + inputs = generate_random_inputs(num_tokens, topk=topk, device=device) + + result = run( + inputs["q_nope"], + inputs["q_pe"], + inputs["ckv_cache"], + inputs["kpe_cache"], + inputs["sparse_indices"], + inputs["sm_scale"], + ) + + output = result["output"] + lse = result["lse"] + + assert output.shape == (num_tokens, NUM_QO_HEADS, HEAD_DIM_CKV) + assert lse.shape == (num_tokens, NUM_QO_HEADS) + + +def test_padding_handling(num_tokens=64, topk=TOPK): + """Test that padding (-1 indices) are handled correctly.""" + device = "cuda" if torch.cuda.is_available() else "cpu" + num_pages = 1000 + + q_nope = torch.randn( + num_tokens, NUM_QO_HEADS, HEAD_DIM_CKV, dtype=torch.bfloat16, device=device + ) + q_pe = torch.randn(num_tokens, NUM_QO_HEADS, HEAD_DIM_KPE, dtype=torch.bfloat16, device=device) + ckv_cache = torch.randn(num_pages, 1, HEAD_DIM_CKV, dtype=torch.bfloat16, device=device) + kpe_cache = torch.randn(num_pages, 1, HEAD_DIM_KPE, dtype=torch.bfloat16, device=device) + sm_scale = torch.tensor(1.0 / np.sqrt(128 + HEAD_DIM_KPE), dtype=torch.float32, device=device) + + # Create sparse indices with varying amounts of padding per token + sparse_indices = torch.full((num_tokens, topk), -1, dtype=torch.int32, device=device) + + for t in range(num_tokens): + valid_count = (t % 4 + 1) * (topk // 4) + valid_count = min(valid_count, topk) + sparse_indices[t, :valid_count] = torch.randint( + 0, num_pages, (valid_count,), dtype=torch.int32, device=device + ) + + result = run(q_nope, q_pe, ckv_cache, kpe_cache, sparse_indices, sm_scale) + output = result["output"] + lse = result["lse"] + + # Verify outputs are valid + assert not torch.isnan(output).any() + assert not torch.isinf(output).any() + assert not torch.isnan(lse).any() + + +if __name__ == "__main__": + print("Testing DSA Sparse Attention Reference (page_size=1)") + print( + f"Constants: h={NUM_QO_HEADS}, ckv={HEAD_DIM_CKV}, kpe={HEAD_DIM_KPE}, ps={PAGE_SIZE}, topk={TOPK}" + ) + print("=" * 70) + + test_output_shape() + print("test_output_shape: PASSED") + + test_padding_handling() + print("test_padding_handling: PASSED") + + print("\nAll tests passed!") diff --git a/flashinfer_trace/tests/references/test_dsa_sparse_attention_h16_ckv512_kpe64_topk256_ps64.py b/flashinfer_trace/tests/references/test_dsa_sparse_attention_h16_ckv512_kpe64_topk256_ps64.py new file mode 100644 index 00000000..685a6916 --- /dev/null +++ b/flashinfer_trace/tests/references/test_dsa_sparse_attention_h16_ckv512_kpe64_topk256_ps64.py @@ -0,0 +1,190 @@ +""" +Tests for DSA (DeepSeek Sparse Attention) sparse attention reference implementation. +Page size 64 variant. + +Ground truth comparison tests are in test_dsa_vs_definition_reference.py +which tests against FlashInfer's trtllm_batch_decode_with_kv_cache_mla. 
+""" + +import math +from pathlib import Path + +import numpy as np +import pytest +import torch + +# Module-level constants (DeepSeek V3/R1 with TP=8) +NUM_QO_HEADS = 16 +HEAD_DIM_CKV = 512 +HEAD_DIM_KPE = 64 +PAGE_SIZE = 64 +TOPK = 256 + +TRACE_ROOT = Path(__file__).resolve().parents[2] + + +@torch.no_grad() +def run(q_nope, q_pe, ckv_cache, kpe_cache, sparse_indices, sm_scale): + """Reference implementation for DSA sparse attention with page_size=64.""" + num_tokens, num_qo_heads, head_dim_ckv = q_nope.shape + head_dim_kpe = q_pe.shape[-1] + num_pages, page_size, _ = ckv_cache.shape + topk = sparse_indices.shape[-1] + + # Check constants + assert num_qo_heads == NUM_QO_HEADS + assert head_dim_ckv == HEAD_DIM_CKV + assert head_dim_kpe == HEAD_DIM_KPE + assert page_size == PAGE_SIZE + assert topk == TOPK + + # Check constraints + assert sparse_indices.shape[0] == num_tokens + assert sparse_indices.shape[-1] == topk + assert ckv_cache.shape[1] == page_size + + device = q_nope.device + + # Flatten paged KV cache to token-level + Kc_all = ckv_cache.reshape(-1, head_dim_ckv).to(torch.float32) + Kp_all = kpe_cache.reshape(-1, head_dim_kpe).to(torch.float32) + + output = torch.zeros( + (num_tokens, num_qo_heads, head_dim_ckv), dtype=torch.bfloat16, device=device + ) + lse = torch.full((num_tokens, num_qo_heads), -float("inf"), dtype=torch.float32, device=device) + + for t in range(num_tokens): + indices = sparse_indices[t] + + valid_mask = indices != -1 + valid_indices = indices[valid_mask] + + if valid_indices.numel() == 0: + output[t].zero_() + continue + + tok_idx = valid_indices.to(torch.long) + + Kc = Kc_all[tok_idx] + Kp = Kp_all[tok_idx] + qn = q_nope[t].to(torch.float32) + qp = q_pe[t].to(torch.float32) + + logits = (qn @ Kc.T) + (qp @ Kp.T) + logits_scaled = logits * sm_scale + + lse[t] = torch.logsumexp(logits_scaled, dim=-1) / math.log(2.0) + + attn = torch.softmax(logits_scaled, dim=-1) + out = attn @ Kc + output[t] = out.to(torch.bfloat16) + + return {"output": output, "lse": lse} + + +def generate_random_inputs( + num_tokens, + num_qo_heads=NUM_QO_HEADS, + head_dim_ckv=HEAD_DIM_CKV, + head_dim_kpe=HEAD_DIM_KPE, + topk=TOPK, + device="cuda", +): + """Generate random inputs for DSA sparse attention testing with page_size=64.""" + total_kv_tokens = max(num_tokens * 4, 2048) + num_pages = (total_kv_tokens + PAGE_SIZE - 1) // PAGE_SIZE + + total_tokens_in_cache = num_pages * PAGE_SIZE + sparse_indices = torch.randint( + 0, total_tokens_in_cache, (num_tokens, topk), dtype=torch.int32, device=device + ) + + q_nope = torch.randn( + num_tokens, num_qo_heads, head_dim_ckv, dtype=torch.bfloat16, device=device + ) + q_pe = torch.randn(num_tokens, num_qo_heads, head_dim_kpe, dtype=torch.bfloat16, device=device) + + ckv_cache = torch.randn(num_pages, PAGE_SIZE, head_dim_ckv, dtype=torch.bfloat16, device=device) + kpe_cache = torch.randn(num_pages, PAGE_SIZE, head_dim_kpe, dtype=torch.bfloat16, device=device) + + sm_scale = 1.0 / np.sqrt(128 + head_dim_kpe) + + return { + "q_nope": q_nope, + "q_pe": q_pe, + "ckv_cache": ckv_cache, + "kpe_cache": kpe_cache, + "sparse_indices": sparse_indices, + "sm_scale": torch.tensor(sm_scale, dtype=torch.float32, device=device), + "num_pages": num_pages, + } + + +def test_output_shape(num_tokens=64, topk=TOPK): + """Test that reference produces correct output shapes.""" + device = "cuda" if torch.cuda.is_available() else "cpu" + inputs = generate_random_inputs(num_tokens, topk=topk, device=device) + + result = run( + inputs["q_nope"], + inputs["q_pe"], + 
inputs["ckv_cache"], + inputs["kpe_cache"], + inputs["sparse_indices"], + inputs["sm_scale"], + ) + + output = result["output"] + lse = result["lse"] + + assert output.shape == (num_tokens, NUM_QO_HEADS, HEAD_DIM_CKV) + assert lse.shape == (num_tokens, NUM_QO_HEADS) + + +def test_padding_handling(num_tokens=64, topk=TOPK): + """Test that padding (-1 indices) are handled correctly.""" + device = "cuda" if torch.cuda.is_available() else "cpu" + num_pages = 64 + + q_nope = torch.randn( + num_tokens, NUM_QO_HEADS, HEAD_DIM_CKV, dtype=torch.bfloat16, device=device + ) + q_pe = torch.randn(num_tokens, NUM_QO_HEADS, HEAD_DIM_KPE, dtype=torch.bfloat16, device=device) + ckv_cache = torch.randn(num_pages, PAGE_SIZE, HEAD_DIM_CKV, dtype=torch.bfloat16, device=device) + kpe_cache = torch.randn(num_pages, PAGE_SIZE, HEAD_DIM_KPE, dtype=torch.bfloat16, device=device) + sm_scale = torch.tensor(1.0 / np.sqrt(128 + HEAD_DIM_KPE), dtype=torch.float32, device=device) + + sparse_indices = torch.full((num_tokens, topk), -1, dtype=torch.int32, device=device) + total_tokens_in_cache = num_pages * PAGE_SIZE + + for t in range(num_tokens): + valid_count = (t % 4 + 1) * (topk // 4) + valid_count = min(valid_count, topk) + sparse_indices[t, :valid_count] = torch.randint( + 0, total_tokens_in_cache, (valid_count,), dtype=torch.int32, device=device + ) + + result = run(q_nope, q_pe, ckv_cache, kpe_cache, sparse_indices, sm_scale) + output = result["output"] + lse = result["lse"] + + assert not torch.isnan(output).any() + assert not torch.isinf(output).any() + assert not torch.isnan(lse).any() + + +if __name__ == "__main__": + print("Testing DSA Sparse Attention Reference (page_size=64)") + print( + f"Constants: h={NUM_QO_HEADS}, ckv={HEAD_DIM_CKV}, kpe={HEAD_DIM_KPE}, ps={PAGE_SIZE}, topk={TOPK}" + ) + print("=" * 70) + + test_output_shape() + print("test_output_shape: PASSED") + + test_padding_handling() + print("test_padding_handling: PASSED") + + print("\nAll tests passed!") diff --git a/flashinfer_trace/tests/references/test_dsa_sparse_decode_h16_ckv512_kpe64_topk256_ps1.py b/flashinfer_trace/tests/references/test_dsa_sparse_decode_h16_ckv512_kpe64_topk256_ps1.py deleted file mode 100644 index 5769f26c..00000000 --- a/flashinfer_trace/tests/references/test_dsa_sparse_decode_h16_ckv512_kpe64_topk256_ps1.py +++ /dev/null @@ -1,565 +0,0 @@ -""" -Tests for DSA (DeepSeek Sparse Attention) sparse decode reference implementation. - -Ground truth sources: -1. SGLang FlashMLA sparse kernel: sgl_kernel.flash_mla.flash_mla_with_kvcache (decode with indices) -2. SGLang FlashMLA sparse prefill: sgl_kernel.flash_mla.flash_mla_sparse_fwd (prefill) - -Note: FlashInfer's sparse.py provides BlockSparseAttentionWrapper which uses BSR format, -different from DeepSeek's DSA token-level sparse attention. 
-""" - -import math -from pathlib import Path - -import numpy as np -import pytest -import torch - -# Ground truth imports with availability checks -try: - from sgl_kernel.flash_mla import flash_mla_sparse_fwd, flash_mla_with_kvcache, get_mla_metadata - - SGLANG_AVAILABLE = True -except ImportError: - SGLANG_AVAILABLE = False - -# FlashInfer sparse is BSR-based, different from DSA's token-level sparse -try: - import flashinfer - - FLASHINFER_AVAILABLE = True -except ImportError: - FLASHINFER_AVAILABLE = False - -# Module-level constants (DeepSeek V3/R1 with TP=8) -NUM_QO_HEADS = 16 -HEAD_DIM_CKV = 512 -HEAD_DIM_KPE = 64 -PAGE_SIZE = 1 -TOPK = 256 - -TRACE_ROOT = Path(__file__).resolve().parents[2] - - -@torch.no_grad() -def run(q_nope, q_pe, ckv_cache, kpe_cache, sparse_indices, sm_scale): - """Reference implementation for DSA sparse decode attention.""" - batch_size, num_qo_heads, head_dim_ckv = q_nope.shape - head_dim_kpe = q_pe.shape[-1] - page_size = ckv_cache.shape[1] - topk = sparse_indices.shape[-1] - - # Check constants - assert num_qo_heads == NUM_QO_HEADS - assert head_dim_ckv == HEAD_DIM_CKV - assert head_dim_kpe == HEAD_DIM_KPE - assert page_size == PAGE_SIZE - assert topk == TOPK - - device = q_nope.device - - # Squeeze page dimension (page_size=1) - Kc_all = ckv_cache.squeeze(1).to(torch.float32) # [num_pages, head_dim_ckv] - Kp_all = kpe_cache.squeeze(1).to(torch.float32) # [num_pages, head_dim_kpe] - - output = torch.zeros( - (batch_size, num_qo_heads, head_dim_ckv), dtype=torch.bfloat16, device=device - ) - lse = torch.full((batch_size, num_qo_heads), -float("inf"), dtype=torch.float32, device=device) - - for b in range(batch_size): - indices = sparse_indices[b] # [topk] - - # Handle padding: -1 indicates invalid indices - valid_mask = indices != -1 - valid_indices = indices[valid_mask] - - if valid_indices.numel() == 0: - output[b].zero_() - continue - - tok_idx = valid_indices.to(torch.long) - - Kc = Kc_all[tok_idx] # [num_valid, head_dim_ckv] - Kp = Kp_all[tok_idx] # [num_valid, head_dim_kpe] - qn = q_nope[b].to(torch.float32) # [num_qo_heads, head_dim_ckv] - qp = q_pe[b].to(torch.float32) # [num_qo_heads, head_dim_kpe] - - # Compute attention logits - logits = (qn @ Kc.T) + (qp @ Kp.T) # [num_qo_heads, num_valid] - logits_scaled = logits * sm_scale - - # Compute 2-base LSE - lse[b] = torch.logsumexp(logits_scaled, dim=-1) / math.log(2.0) - - # Compute attention output - attn = torch.softmax(logits_scaled, dim=-1) # [num_qo_heads, num_valid] - out = attn @ Kc # [num_qo_heads, head_dim_ckv] - output[b] = out.to(torch.bfloat16) - - return {"output": output, "lse": lse} - - -def generate_random_inputs( - batch_size, - max_seq_len, - num_qo_heads=NUM_QO_HEADS, - head_dim_ckv=HEAD_DIM_CKV, - head_dim_kpe=HEAD_DIM_KPE, - topk=TOPK, - device="cuda", -): - """Generate random inputs for DSA sparse attention testing.""" - # Generate random sequence lengths for each batch - # Ensure seq_lens >= topk so we have enough tokens to select - min_seq_len = max(topk, 256) - seq_lens = torch.randint( - min_seq_len, max_seq_len + 1, (batch_size,), dtype=torch.int32, device=device - ) - - # Calculate total pages needed - total_pages_needed = seq_lens.sum().item() - - # Generate page table (mapping sequence positions to page indices) - # For simplicity, use consecutive pages - page_table = torch.zeros(batch_size, max_seq_len, dtype=torch.int32, device=device) - page_offset = 0 - for b in range(batch_size): - seq_len = seq_lens[b].item() - page_table[b, :seq_len] = torch.arange( - page_offset, 
page_offset + seq_len, dtype=torch.int32, device=device - ) - page_offset += seq_len - - # Generate sparse indices (top-K selection for each batch element) - sparse_indices = torch.full((batch_size, topk), -1, dtype=torch.int32, device=device) - for b in range(batch_size): - seq_len = seq_lens[b].item() - actual_topk = min(topk, seq_len) - # Select random indices from available pages - perm = torch.randperm(seq_len, device=device)[:actual_topk] - selected_pages = page_table[b, perm] - sparse_indices[b, :actual_topk] = selected_pages.to(torch.int32) - - # Generate query tensors - q_nope = torch.randn( - batch_size, num_qo_heads, head_dim_ckv, dtype=torch.bfloat16, device=device - ) - q_pe = torch.randn(batch_size, num_qo_heads, head_dim_kpe, dtype=torch.bfloat16, device=device) - - # Generate compressed KV and positional caches - num_pages = total_pages_needed + 100 # Add extra pages - ckv_cache = torch.randn(num_pages, 1, head_dim_ckv, dtype=torch.bfloat16, device=device) - kpe_cache = torch.randn(num_pages, 1, head_dim_kpe, dtype=torch.bfloat16, device=device) - - # Generate softmax scale - # MLA uses head dimension before matrix absorption (128 + 64 = 192) - sm_scale = 1.0 / np.sqrt(128 + head_dim_kpe) - - return { - "q_nope": q_nope, - "q_pe": q_pe, - "ckv_cache": ckv_cache, - "kpe_cache": kpe_cache, - "sparse_indices": sparse_indices, - "sm_scale": torch.tensor(sm_scale, dtype=torch.float32, device=device), - "seq_lens": seq_lens, - "page_table": page_table, - "num_pages": num_pages, - } - - -def compute_error_metrics(ref, gt, name="output"): - """Compute and print detailed error metrics.""" - ref_f32 = ref.float() - gt_f32 = gt.float() - - abs_diff = torch.abs(ref_f32 - gt_f32) - rel_diff = abs_diff / (torch.abs(gt_f32) + 1e-8) - - max_abs_diff = abs_diff.max().item() - max_rel_diff = rel_diff.max().item() - mean_abs_diff = abs_diff.mean().item() - mean_rel_diff = rel_diff.mean().item() - - print(f"\n{name} comparison:") - print(f" Max absolute difference: {max_abs_diff:.6e}") - print(f" Max relative difference: {max_rel_diff:.6e}") - print(f" Mean absolute difference: {mean_abs_diff:.6e}") - print(f" Mean relative difference: {mean_rel_diff:.6e}") - - # Cosine similarity and MSE - cos_sim = torch.nn.functional.cosine_similarity( - ref_f32.flatten(), gt_f32.flatten(), dim=0 - ).item() - mse = torch.mean((ref_f32 - gt_f32) ** 2).item() - print(f" Cosine similarity: {cos_sim:.6f}") - print(f" MSE: {mse:.6e}") - - return abs_diff, rel_diff, cos_sim - - -def check_hit_ratio(ref, gt, atol, rtol, required_percent=0.85): - """Check if hit ratio meets threshold (for quantized kernels).""" - ref_f32 = ref.float() - gt_f32 = gt.float() - - left = (ref_f32 - gt_f32).abs() - right = atol + rtol * gt_f32.abs() - ok = left <= right - hit_ratio = ok.float().mean().item() - - print(f"\nHit ratio: {hit_ratio * 100:.2f}% (need >= {required_percent * 100:.2f}%)") - return hit_ratio >= required_percent - - -def test_output_shape(batch_size=4, max_seq_len=512, topk=TOPK): - """Test that reference produces correct output shapes.""" - print(f"\n{'='*60}") - print(f"Testing DSA output shape: batch_size={batch_size}, topk={topk}") - print(f"{'='*60}") - - device = "cuda" if torch.cuda.is_available() else "cpu" - if device == "cpu": - print("WARNING: CUDA not available, using CPU") - - inputs = generate_random_inputs(batch_size, max_seq_len, topk=topk, device=device) - - result = run( - inputs["q_nope"], - inputs["q_pe"], - inputs["ckv_cache"], - inputs["kpe_cache"], - inputs["sparse_indices"], - 
inputs["sm_scale"], - ) - - output = result["output"] - lse = result["lse"] - - expected_output_shape = (batch_size, NUM_QO_HEADS, HEAD_DIM_CKV) - expected_lse_shape = (batch_size, NUM_QO_HEADS) - - output_shape_correct = output.shape == expected_output_shape - lse_shape_correct = lse.shape == expected_lse_shape - - print(f"Output shape: {output.shape} (expected: {expected_output_shape})") - print(f"LSE shape: {lse.shape} (expected: {expected_lse_shape})") - - if output_shape_correct and lse_shape_correct: - print("PASSED: Output shapes are correct") - return True - else: - print("FAILED: Output shapes are incorrect") - return False - - -def test_sparse_vs_dense_consistency(batch_size=4, topk=TOPK): - """Test that sparse attention with all tokens selected equals dense attention.""" - print(f"\n{'='*60}") - print(f"Testing DSA sparse vs dense consistency") - print(f"{'='*60}") - - device = "cuda" if torch.cuda.is_available() else "cpu" - if device == "cpu": - print("WARNING: CUDA not available, using CPU") - - # Generate inputs where sparse_indices includes all tokens (no sparsity) - # Use a small sequence length equal to topk for full coverage - seq_len = topk - num_pages = seq_len + 10 - - q_nope = torch.randn( - batch_size, NUM_QO_HEADS, HEAD_DIM_CKV, dtype=torch.bfloat16, device=device - ) - q_pe = torch.randn(batch_size, NUM_QO_HEADS, HEAD_DIM_KPE, dtype=torch.bfloat16, device=device) - ckv_cache = torch.randn(num_pages, 1, HEAD_DIM_CKV, dtype=torch.bfloat16, device=device) - kpe_cache = torch.randn(num_pages, 1, HEAD_DIM_KPE, dtype=torch.bfloat16, device=device) - sm_scale = torch.tensor(1.0 / np.sqrt(128 + HEAD_DIM_KPE), dtype=torch.float32, device=device) - - # All indices valid (0 to seq_len-1) - sparse_indices = ( - torch.arange(seq_len, dtype=torch.int32, device=device) - .unsqueeze(0) - .expand(batch_size, -1) - .contiguous() - ) - - result = run(q_nope, q_pe, ckv_cache, kpe_cache, sparse_indices, sm_scale) - output = result["output"] - lse = result["lse"] - - # Check that output is not all zeros (actually computed) - output_nonzero = output.abs().sum() > 0 - lse_finite = torch.all(torch.isfinite(lse)) - - print(f"Output non-zero: {output_nonzero}") - print(f"LSE finite: {lse_finite}") - - if output_nonzero and lse_finite: - print("PASSED: Sparse attention produces valid outputs") - return True - else: - print("FAILED: Sparse attention produces invalid outputs") - return False - - -def test_padding_handling(batch_size=4, topk=TOPK): - """Test that padding (-1 indices) are handled correctly.""" - print(f"\n{'='*60}") - print(f"Testing DSA padding handling") - print(f"{'='*60}") - - device = "cuda" if torch.cuda.is_available() else "cpu" - if device == "cpu": - print("WARNING: CUDA not available, using CPU") - - num_pages = 1000 - - q_nope = torch.randn( - batch_size, NUM_QO_HEADS, HEAD_DIM_CKV, dtype=torch.bfloat16, device=device - ) - q_pe = torch.randn(batch_size, NUM_QO_HEADS, HEAD_DIM_KPE, dtype=torch.bfloat16, device=device) - ckv_cache = torch.randn(num_pages, 1, HEAD_DIM_CKV, dtype=torch.bfloat16, device=device) - kpe_cache = torch.randn(num_pages, 1, HEAD_DIM_KPE, dtype=torch.bfloat16, device=device) - sm_scale = torch.tensor(1.0 / np.sqrt(128 + HEAD_DIM_KPE), dtype=torch.float32, device=device) - - # Create sparse indices with varying amounts of padding - sparse_indices = torch.full((batch_size, topk), -1, dtype=torch.int32, device=device) - valid_counts = [topk, topk // 2, topk // 4, 10] # Different valid counts per batch - - for b in range(batch_size): - valid_count = 
valid_counts[b % len(valid_counts)] - sparse_indices[b, :valid_count] = torch.randint( - 0, num_pages, (valid_count,), dtype=torch.int32, device=device - ) - - result = run(q_nope, q_pe, ckv_cache, kpe_cache, sparse_indices, sm_scale) - output = result["output"] - lse = result["lse"] - - # Verify outputs are valid - output_valid = not torch.isnan(output).any() and not torch.isinf(output).any() - # LSE can be -inf for empty sequences, but should not be +inf or nan - lse_valid = not torch.isnan(lse).any() and not torch.isinf(lse[lse > -float("inf")]).any() - - print(f"Output valid (no nan/inf): {output_valid}") - print(f"LSE valid: {lse_valid}") - - if output_valid and lse_valid: - print("PASSED: Padding handled correctly") - return True - else: - print("FAILED: Padding handling issue") - return False - - -def test_correctness_vs_sglang(batch_size=4, max_seq_len=512, atol=1e-2, rtol=5e-2): - """ - Test correctness against SGLang FlashMLA sparse kernel. - - NOTE: This test requires SGLang sgl_kernel to be installed. - If not available, the test will be skipped. - """ - print(f"\n{'='*60}") - print(f"Testing DSA correctness against SGLang FlashMLA") - print(f"batch_size={batch_size}, max_seq_len={max_seq_len}") - print(f"{'='*60}") - - if not SGLANG_AVAILABLE: - print("SKIPPED: SGLang/sgl_kernel not available") - return None - - device = "cuda" if torch.cuda.is_available() else "cpu" - if device == "cpu": - print("SKIPPED: CUDA not available") - return None - - torch.manual_seed(42) - - # Test parameters - num_pages = 1024 - head_dim = HEAD_DIM_CKV + HEAD_DIM_KPE # Combined head dim = 576 - - # Determine required head padding based on GPU architecture - # FlashMLA kernel requires h_q to be multiple of 64 (Hopper SM90) or 128 (Blackwell SM100+) - device_sm_major = torch.cuda.get_device_properties(device).major - required_padding = 128 if device_sm_major >= 10 else 64 - print(f"GPU SM major: {device_sm_major}, required head padding: {required_padding}") - - # Generate query tensors - q_nope = torch.randn( - batch_size, NUM_QO_HEADS, HEAD_DIM_CKV, dtype=torch.bfloat16, device=device - ) - q_pe = torch.randn(batch_size, NUM_QO_HEADS, HEAD_DIM_KPE, dtype=torch.bfloat16, device=device) - - # Combined q for FlashMLA: [s_q, h_q, d_qk] - # Note: flash_mla_sparse_fwd expects [s_q, h_q, d_qk], not [batch, h, d] - q_all = torch.cat([q_nope, q_pe], dim=-1) # [batch_size, num_qo_heads, head_dim] - - # KV cache (combined) - # flash_mla_sparse_fwd expects: kv [s_kv, h_kv, d_qk] where h_kv=1 for MLA - kv_cache = torch.randn(num_pages, head_dim, dtype=torch.bfloat16, device=device) - ckv_cache = kv_cache[:, :HEAD_DIM_CKV].unsqueeze(1) # [num_pages, 1, ckv] - kpe_cache = kv_cache[:, HEAD_DIM_CKV:].unsqueeze(1) # [num_pages, 1, kpe] - - sm_scale = 1.0 / np.sqrt(128 + HEAD_DIM_KPE) - - # Generate sparse indices: [batch_size, topk] for reference - sparse_indices = torch.randint( - 0, num_pages, (batch_size, TOPK), dtype=torch.int32, device=device - ) - - # Run reference implementation - print("Running reference implementation...") - ref_result = run( - q_nope, - q_pe, - ckv_cache, - kpe_cache, - sparse_indices, - torch.tensor(sm_scale, dtype=torch.float32, device=device), - ) - ref_output = ref_result["output"] - ref_lse = ref_result["lse"] - - # Run FlashMLA sparse - # flash_mla_sparse_fwd expects: - # q: [s_q, h_q, d_qk] bfloat16 - h_q must be multiple of 64 (SM90) or 128 (SM100+) - # kv: [s_kv, h_kv, d_qk] bfloat16 (h_kv=1 for MLA) - # indices: [s_q, h_kv, topk] int32 - print("Running SGLang FlashMLA sparse...") 
- try: - kv_for_mla = kv_cache.unsqueeze(1) # [s_kv, 1, d_qk] - - # indices: [s_q, h_kv=1, topk] - indices_for_mla = sparse_indices.unsqueeze(1) # [batch_size, 1, topk] - - # Pad query heads to required multiple (64 or 128) as done in SGLang's dsa_backend.py - need_padding = NUM_QO_HEADS % required_padding != 0 - if need_padding: - assert ( - required_padding % NUM_QO_HEADS == 0 - ), f"required_padding ({required_padding}) must be divisible by NUM_QO_HEADS ({NUM_QO_HEADS})" - q_padded = q_all.new_zeros((batch_size, required_padding, head_dim)) - q_padded[:, :NUM_QO_HEADS, :] = q_all - q_input = q_padded - print(f"Padded q from {NUM_QO_HEADS} to {required_padding} heads") - else: - q_input = q_all - - fi_output_full, fi_max_logits, fi_lse_full = flash_mla_sparse_fwd( - q=q_input, kv=kv_for_mla, indices=indices_for_mla, sm_scale=sm_scale, d_v=HEAD_DIM_CKV - ) - - # Trim output back to original number of heads if padding was applied - if need_padding: - fi_output = fi_output_full[:, :NUM_QO_HEADS, :] - fi_lse = fi_lse_full[:, :NUM_QO_HEADS] - else: - fi_output = fi_output_full - fi_lse = fi_lse_full - - except Exception as e: - print(f"WARNING: FlashMLA sparse fwd failed: {e}") - print("This may be due to API differences - skipping SGLang test") - import traceback - - traceback.print_exc() - return None - - # Compare outputs - print("\nComparing outputs...") - abs_diff, rel_diff, cos_sim = compute_error_metrics(ref_output, fi_output, "output") - - # Check tolerance - allclose = torch.allclose(ref_output.float(), fi_output.float(), atol=atol, rtol=rtol) - - if allclose: - print(f"\n✓ PASSED: Outputs match within tolerance (atol={atol}, rtol={rtol})") - return True - else: - print(f"\n✗ FAILED: Outputs differ beyond tolerance (atol={atol}, rtol={rtol})") - - # Show top error locations - flat = (ref_output.float() - fi_output.float()).abs().flatten() - k = min(5, flat.numel()) - topv, topi = torch.topk(flat, k) - print(f"\nTop-{k} absolute error locations:") - for rank in range(k): - idx = topi[rank].item() - print( - f" idx={idx}: ref={ref_output.flatten()[idx].item():.6e}, " - f"fi={fi_output.flatten()[idx].item():.6e}, diff={topv[rank].item():.6e}" - ) - - # Use hit ratio as secondary check - passed = check_hit_ratio(ref_output, fi_output, atol, rtol, required_percent=0.85) - return passed - - -def main(): - """Run comprehensive tests.""" - print("Testing DSA (DeepSeek Sparse Attention) Sparse Decode Reference Implementation") - print("=" * 70) - print( - f"Constants: h={NUM_QO_HEADS}, ckv={HEAD_DIM_CKV}, kpe={HEAD_DIM_KPE}, ps={PAGE_SIZE}, topk={TOPK}" - ) - print(f"SGLang available: {SGLANG_AVAILABLE}") - print(f"FlashInfer available: {FLASHINFER_AVAILABLE}") - print("=" * 70) - - test_results = [] - - # Basic functionality tests - test_results.append(("output_shape", test_output_shape())) - test_results.append(("sparse_vs_dense", test_sparse_vs_dense_consistency())) - test_results.append(("padding_handling", test_padding_handling())) - - # Ground truth comparison tests - test_configs = [(1, 512), (4, 512), (8, 1024)] # Single batch # Small batch # Medium batch - - for batch_size, max_seq_len in test_configs: - name = f"sglang_bs{batch_size}_seq{max_seq_len}" - try: - result = test_correctness_vs_sglang(batch_size, max_seq_len) - test_results.append((name, result)) - except Exception as e: - print(f"\n✗ Test {name} crashed: {e}") - import traceback - - traceback.print_exc() - test_results.append((name, False)) - - # Summary - print(f"\n{'='*70}") - print("Test Summary:") - 
print(f"{'='*70}") - - passed = 0 - skipped = 0 - failed = 0 - - for name, result in test_results: - if result is None: - status = "SKIPPED" - skipped += 1 - elif result: - status = "PASSED" - passed += 1 - else: - status = "FAILED" - failed += 1 - print(f" {name}: {status}") - - print(f"\nTotal: {passed} passed, {failed} failed, {skipped} skipped") - - if failed == 0: - print("\n✓ All tests passed!") - else: - print(f"\n✗ {failed} tests failed") - - -if __name__ == "__main__": - main() diff --git a/flashinfer_trace/tests/references/test_dsa_sparse_decode_h16_ckv512_kpe64_topk256_ps64.py b/flashinfer_trace/tests/references/test_dsa_sparse_decode_h16_ckv512_kpe64_topk256_ps64.py deleted file mode 100644 index 4f0dc2e3..00000000 --- a/flashinfer_trace/tests/references/test_dsa_sparse_decode_h16_ckv512_kpe64_topk256_ps64.py +++ /dev/null @@ -1,588 +0,0 @@ -""" -Tests for DSA (DeepSeek Sparse Attention) sparse decode reference implementation. -Page size 64 variant. - -Ground truth sources: -1. SGLang FlashMLA sparse kernel: sgl_kernel.flash_mla.flash_mla_with_kvcache (decode with indices) -2. SGLang FlashMLA sparse prefill: sgl_kernel.flash_mla.flash_mla_sparse_fwd (prefill) - -Note: FlashInfer's sparse.py provides BlockSparseAttentionWrapper which uses BSR format, -different from DeepSeek's DSA token-level sparse attention. -""" - -import math -from pathlib import Path - -import numpy as np -import pytest -import torch - -# Ground truth imports with availability checks -try: - from sgl_kernel.flash_mla import flash_mla_sparse_fwd, flash_mla_with_kvcache, get_mla_metadata - - SGLANG_AVAILABLE = True -except ImportError: - SGLANG_AVAILABLE = False - -# FlashInfer sparse is BSR-based, different from DSA's token-level sparse -try: - import flashinfer - - FLASHINFER_AVAILABLE = True -except ImportError: - FLASHINFER_AVAILABLE = False - -# Module-level constants (DeepSeek V3/R1 with TP=8) -NUM_QO_HEADS = 16 -HEAD_DIM_CKV = 512 -HEAD_DIM_KPE = 64 -PAGE_SIZE = 64 -TOPK = 256 - -TRACE_ROOT = Path(__file__).resolve().parents[2] - - -@torch.no_grad() -def run(q_nope, q_pe, ckv_cache, kpe_cache, sparse_indices, sm_scale): - """Reference implementation for DSA sparse decode attention with page_size=64.""" - batch_size, num_qo_heads, head_dim_ckv = q_nope.shape - head_dim_kpe = q_pe.shape[-1] - num_pages, page_size, _ = ckv_cache.shape - topk = sparse_indices.shape[-1] - - # Check constants - assert num_qo_heads == NUM_QO_HEADS - assert head_dim_ckv == HEAD_DIM_CKV - assert head_dim_kpe == HEAD_DIM_KPE - assert page_size == PAGE_SIZE - assert topk == TOPK - - # Check constraints - assert sparse_indices.shape[-1] == topk - assert ckv_cache.shape[1] == page_size - - device = q_nope.device - - # Flatten paged KV cache to token-level: [num_pages, page_size, dim] -> [num_pages * page_size, dim] - Kc_all = ckv_cache.reshape(-1, head_dim_ckv).to( - torch.float32 - ) # [total_kv_tokens, head_dim_ckv] - Kp_all = kpe_cache.reshape(-1, head_dim_kpe).to( - torch.float32 - ) # [total_kv_tokens, head_dim_kpe] - - output = torch.zeros( - (batch_size, num_qo_heads, head_dim_ckv), dtype=torch.bfloat16, device=device - ) - lse = torch.full((batch_size, num_qo_heads), -float("inf"), dtype=torch.float32, device=device) - - for b in range(batch_size): - indices = sparse_indices[b] # [topk] - - # Handle padding: -1 indicates invalid indices - valid_mask = indices != -1 - valid_indices = indices[valid_mask] - - if valid_indices.numel() == 0: - output[b].zero_() - continue - - # For page_size=64, indices encode (page_idx * 64 + 
offset) - tok_idx = valid_indices.to(torch.long) - - Kc = Kc_all[tok_idx] # [num_valid, head_dim_ckv] - Kp = Kp_all[tok_idx] # [num_valid, head_dim_kpe] - qn = q_nope[b].to(torch.float32) # [num_qo_heads, head_dim_ckv] - qp = q_pe[b].to(torch.float32) # [num_qo_heads, head_dim_kpe] - - # Compute attention logits - logits = (qn @ Kc.T) + (qp @ Kp.T) # [num_qo_heads, num_valid] - logits_scaled = logits * sm_scale - - # Compute 2-base LSE - lse[b] = torch.logsumexp(logits_scaled, dim=-1) / math.log(2.0) - - # Compute attention output - attn = torch.softmax(logits_scaled, dim=-1) # [num_qo_heads, num_valid] - out = attn @ Kc # [num_qo_heads, head_dim_ckv] - output[b] = out.to(torch.bfloat16) - - return {"output": output, "lse": lse} - - -def generate_random_inputs( - batch_size, - max_seq_len, - num_qo_heads=NUM_QO_HEADS, - head_dim_ckv=HEAD_DIM_CKV, - head_dim_kpe=HEAD_DIM_KPE, - topk=TOPK, - device="cuda", -): - """Generate random inputs for DSA sparse attention testing with page_size=64.""" - # Generate random sequence lengths for each batch - # Ensure seq_lens >= topk so we have enough tokens to select - min_seq_len = max(topk, 256) - seq_lens = torch.randint( - min_seq_len, max_seq_len + 1, (batch_size,), dtype=torch.int32, device=device - ) - - # Calculate total pages needed (each page holds 64 tokens) - total_tokens_needed = seq_lens.sum().item() - total_pages_needed = (total_tokens_needed + PAGE_SIZE - 1) // PAGE_SIZE - - # Generate page table (mapping sequence positions to global token indices) - # For page_size=64, page_table[b, pos] = page_idx * 64 + offset - page_table = torch.zeros(batch_size, max_seq_len, dtype=torch.int32, device=device) - token_offset = 0 - for b in range(batch_size): - seq_len = seq_lens[b].item() - page_table[b, :seq_len] = torch.arange( - token_offset, token_offset + seq_len, dtype=torch.int32, device=device - ) - token_offset += seq_len - - # Generate sparse indices (top-K selection for each batch element) - # Indices are global token indices: page_idx * 64 + offset - sparse_indices = torch.full((batch_size, topk), -1, dtype=torch.int32, device=device) - for b in range(batch_size): - seq_len = seq_lens[b].item() - actual_topk = min(topk, seq_len) - # Select random indices from available token positions - perm = torch.randperm(seq_len, device=device)[:actual_topk] - selected_tokens = page_table[b, perm] - sparse_indices[b, :actual_topk] = selected_tokens.to(torch.int32) - - # Generate query tensors - q_nope = torch.randn( - batch_size, num_qo_heads, head_dim_ckv, dtype=torch.bfloat16, device=device - ) - q_pe = torch.randn(batch_size, num_qo_heads, head_dim_kpe, dtype=torch.bfloat16, device=device) - - # Generate compressed KV and positional caches with page_size=64 - num_pages = total_pages_needed + 10 # Add extra pages - ckv_cache = torch.randn(num_pages, PAGE_SIZE, head_dim_ckv, dtype=torch.bfloat16, device=device) - kpe_cache = torch.randn(num_pages, PAGE_SIZE, head_dim_kpe, dtype=torch.bfloat16, device=device) - - # Generate softmax scale - # MLA uses head dimension before matrix absorption (128 + 64 = 192) - sm_scale = 1.0 / np.sqrt(128 + head_dim_kpe) - - return { - "q_nope": q_nope, - "q_pe": q_pe, - "ckv_cache": ckv_cache, - "kpe_cache": kpe_cache, - "sparse_indices": sparse_indices, - "sm_scale": torch.tensor(sm_scale, dtype=torch.float32, device=device), - "seq_lens": seq_lens, - "page_table": page_table, - "num_pages": num_pages, - } - - -def compute_error_metrics(ref, gt, name="output"): - """Compute and print detailed error metrics.""" - 
ref_f32 = ref.float() - gt_f32 = gt.float() - - abs_diff = torch.abs(ref_f32 - gt_f32) - rel_diff = abs_diff / (torch.abs(gt_f32) + 1e-8) - - max_abs_diff = abs_diff.max().item() - max_rel_diff = rel_diff.max().item() - mean_abs_diff = abs_diff.mean().item() - mean_rel_diff = rel_diff.mean().item() - - print(f"\n{name} comparison:") - print(f" Max absolute difference: {max_abs_diff:.6e}") - print(f" Max relative difference: {max_rel_diff:.6e}") - print(f" Mean absolute difference: {mean_abs_diff:.6e}") - print(f" Mean relative difference: {mean_rel_diff:.6e}") - - # Cosine similarity and MSE - cos_sim = torch.nn.functional.cosine_similarity( - ref_f32.flatten(), gt_f32.flatten(), dim=0 - ).item() - mse = torch.mean((ref_f32 - gt_f32) ** 2).item() - print(f" Cosine similarity: {cos_sim:.6f}") - print(f" MSE: {mse:.6e}") - - return abs_diff, rel_diff, cos_sim - - -def check_hit_ratio(ref, gt, atol, rtol, required_percent=0.85): - """Check if hit ratio meets threshold (for quantized kernels).""" - ref_f32 = ref.float() - gt_f32 = gt.float() - - left = (ref_f32 - gt_f32).abs() - right = atol + rtol * gt_f32.abs() - ok = left <= right - hit_ratio = ok.float().mean().item() - - print(f"\nHit ratio: {hit_ratio * 100:.2f}% (need >= {required_percent * 100:.2f}%)") - return hit_ratio >= required_percent - - -def test_output_shape(batch_size=4, max_seq_len=512, topk=TOPK): - """Test that reference produces correct output shapes.""" - print(f"\n{'='*60}") - print(f"Testing DSA decode ps64 output shape: batch_size={batch_size}, topk={topk}") - print(f"{'='*60}") - - device = "cuda" if torch.cuda.is_available() else "cpu" - if device == "cpu": - print("WARNING: CUDA not available, using CPU") - - inputs = generate_random_inputs(batch_size, max_seq_len, topk=topk, device=device) - - result = run( - inputs["q_nope"], - inputs["q_pe"], - inputs["ckv_cache"], - inputs["kpe_cache"], - inputs["sparse_indices"], - inputs["sm_scale"], - ) - - output = result["output"] - lse = result["lse"] - - expected_output_shape = (batch_size, NUM_QO_HEADS, HEAD_DIM_CKV) - expected_lse_shape = (batch_size, NUM_QO_HEADS) - - output_shape_correct = output.shape == expected_output_shape - lse_shape_correct = lse.shape == expected_lse_shape - - print(f"Output shape: {output.shape} (expected: {expected_output_shape})") - print(f"LSE shape: {lse.shape} (expected: {expected_lse_shape})") - - if output_shape_correct and lse_shape_correct: - print("PASSED: Output shapes are correct") - return True - else: - print("FAILED: Output shapes are incorrect") - return False - - -def test_sparse_vs_dense_consistency(batch_size=4, topk=TOPK): - """Test that sparse attention with all tokens selected equals dense attention.""" - print(f"\n{'='*60}") - print(f"Testing DSA decode ps64 sparse vs dense consistency") - print(f"{'='*60}") - - device = "cuda" if torch.cuda.is_available() else "cpu" - if device == "cpu": - print("WARNING: CUDA not available, using CPU") - - # Generate inputs where sparse_indices includes all tokens (no sparsity) - # Use a small sequence length equal to topk for full coverage - seq_len = topk - num_pages = (seq_len + PAGE_SIZE - 1) // PAGE_SIZE + 1 - - q_nope = torch.randn( - batch_size, NUM_QO_HEADS, HEAD_DIM_CKV, dtype=torch.bfloat16, device=device - ) - q_pe = torch.randn(batch_size, NUM_QO_HEADS, HEAD_DIM_KPE, dtype=torch.bfloat16, device=device) - ckv_cache = torch.randn(num_pages, PAGE_SIZE, HEAD_DIM_CKV, dtype=torch.bfloat16, device=device) - kpe_cache = torch.randn(num_pages, PAGE_SIZE, HEAD_DIM_KPE, 
dtype=torch.bfloat16, device=device) - sm_scale = torch.tensor(1.0 / np.sqrt(128 + HEAD_DIM_KPE), dtype=torch.float32, device=device) - - # All indices valid (0 to seq_len-1) - global token indices - sparse_indices = ( - torch.arange(seq_len, dtype=torch.int32, device=device) - .unsqueeze(0) - .expand(batch_size, -1) - .contiguous() - ) - - result = run(q_nope, q_pe, ckv_cache, kpe_cache, sparse_indices, sm_scale) - output = result["output"] - lse = result["lse"] - - # Check that output is not all zeros (actually computed) - output_nonzero = output.abs().sum() > 0 - lse_finite = torch.all(torch.isfinite(lse)) - - print(f"Output non-zero: {output_nonzero}") - print(f"LSE finite: {lse_finite}") - - if output_nonzero and lse_finite: - print("PASSED: Sparse attention produces valid outputs") - return True - else: - print("FAILED: Sparse attention produces invalid outputs") - return False - - -def test_padding_handling(batch_size=4, topk=TOPK): - """Test that padding (-1 indices) are handled correctly.""" - print(f"\n{'='*60}") - print(f"Testing DSA decode ps64 padding handling") - print(f"{'='*60}") - - device = "cuda" if torch.cuda.is_available() else "cpu" - if device == "cpu": - print("WARNING: CUDA not available, using CPU") - - num_pages = 64 # 64 pages * 64 tokens = 4096 total tokens - total_tokens_in_cache = num_pages * PAGE_SIZE - - q_nope = torch.randn( - batch_size, NUM_QO_HEADS, HEAD_DIM_CKV, dtype=torch.bfloat16, device=device - ) - q_pe = torch.randn(batch_size, NUM_QO_HEADS, HEAD_DIM_KPE, dtype=torch.bfloat16, device=device) - ckv_cache = torch.randn(num_pages, PAGE_SIZE, HEAD_DIM_CKV, dtype=torch.bfloat16, device=device) - kpe_cache = torch.randn(num_pages, PAGE_SIZE, HEAD_DIM_KPE, dtype=torch.bfloat16, device=device) - sm_scale = torch.tensor(1.0 / np.sqrt(128 + HEAD_DIM_KPE), dtype=torch.float32, device=device) - - # Create sparse indices with varying amounts of padding - sparse_indices = torch.full((batch_size, topk), -1, dtype=torch.int32, device=device) - valid_counts = [topk, topk // 2, topk // 4, 10] # Different valid counts per batch - - for b in range(batch_size): - valid_count = valid_counts[b % len(valid_counts)] - sparse_indices[b, :valid_count] = torch.randint( - 0, total_tokens_in_cache, (valid_count,), dtype=torch.int32, device=device - ) - - result = run(q_nope, q_pe, ckv_cache, kpe_cache, sparse_indices, sm_scale) - output = result["output"] - lse = result["lse"] - - # Verify outputs are valid - output_valid = not torch.isnan(output).any() and not torch.isinf(output).any() - # LSE can be -inf for empty sequences, but should not be +inf or nan - lse_valid = not torch.isnan(lse).any() and not torch.isinf(lse[lse > -float("inf")]).any() - - print(f"Output valid (no nan/inf): {output_valid}") - print(f"LSE valid: {lse_valid}") - - if output_valid and lse_valid: - print("PASSED: Padding handled correctly") - return True - else: - print("FAILED: Padding handling issue") - return False - - -def test_correctness_vs_sglang(batch_size=4, max_seq_len=512, atol=1e-2, rtol=5e-2): - """ - Test correctness against SGLang FlashMLA sparse kernel. - - NOTE: This test requires SGLang sgl_kernel to be installed. - If not available, the test will be skipped. - - NOTE: FlashMLA sparse kernel operates at token-level (page_size=1). - For page_size=64, we flatten to token-level for comparison. 
- """ - print(f"\n{'='*60}") - print(f"Testing DSA decode ps64 correctness against SGLang FlashMLA") - print(f"batch_size={batch_size}, max_seq_len={max_seq_len}") - print(f"{'='*60}") - - if not SGLANG_AVAILABLE: - print("SKIPPED: SGLang/sgl_kernel not available") - return None - - device = "cuda" if torch.cuda.is_available() else "cpu" - if device == "cpu": - print("SKIPPED: CUDA not available") - return None - - torch.manual_seed(42) - - # Test parameters - num_pages = 32 # 32 pages * 64 tokens/page = 2048 total KV tokens - total_kv_tokens = num_pages * PAGE_SIZE - head_dim = HEAD_DIM_CKV + HEAD_DIM_KPE # Combined head dim = 576 - - # Determine required head padding based on GPU architecture - # FlashMLA kernel requires h_q to be multiple of 64 (Hopper SM90) or 128 (Blackwell SM100+) - device_sm_major = torch.cuda.get_device_properties(device).major - required_padding = 128 if device_sm_major >= 10 else 64 - print(f"GPU SM major: {device_sm_major}, required head padding: {required_padding}") - - # Generate query tensors - q_nope = torch.randn( - batch_size, NUM_QO_HEADS, HEAD_DIM_CKV, dtype=torch.bfloat16, device=device - ) - q_pe = torch.randn(batch_size, NUM_QO_HEADS, HEAD_DIM_KPE, dtype=torch.bfloat16, device=device) - - # Combined q for FlashMLA: [s_q, h_q, d_qk] - q_all = torch.cat([q_nope, q_pe], dim=-1) # [batch_size, num_qo_heads, head_dim] - - # KV cache with page_size=64 - # Shape: [num_pages, page_size, head_dim] - kv_cache_paged = torch.randn( - num_pages, PAGE_SIZE, head_dim, dtype=torch.bfloat16, device=device - ) - ckv_cache = kv_cache_paged[:, :, :HEAD_DIM_CKV] # [num_pages, 64, ckv] - kpe_cache = kv_cache_paged[:, :, HEAD_DIM_CKV:] # [num_pages, 64, kpe] - - # Flatten for FlashMLA (token-level) - kv_cache_flat = kv_cache_paged.reshape(total_kv_tokens, head_dim) # [total_kv_tokens, head_dim] - - sm_scale = 1.0 / np.sqrt(128 + HEAD_DIM_KPE) - - # Generate sparse indices as global token indices: [batch_size, topk] - # Each index is in range [0, total_kv_tokens) - sparse_indices = torch.randint( - 0, total_kv_tokens, (batch_size, TOPK), dtype=torch.int32, device=device - ) - - # Run reference implementation with page_size=64 - print("Running reference implementation (page_size=64)...") - ref_result = run( - q_nope, - q_pe, - ckv_cache, - kpe_cache, - sparse_indices, - torch.tensor(sm_scale, dtype=torch.float32, device=device), - ) - ref_output = ref_result["output"] - ref_lse = ref_result["lse"] - - # Run FlashMLA sparse (token-level) - # flash_mla_sparse_fwd expects: - # q: [s_q, h_q, d_qk] bfloat16 - h_q must be multiple of 64 (SM90) or 128 (SM100+) - # kv: [s_kv, h_kv, d_qk] bfloat16 (h_kv=1 for MLA) - # indices: [s_q, h_kv, topk] int32 - print("Running SGLang FlashMLA sparse (token-level)...") - try: - kv_for_mla = kv_cache_flat.unsqueeze(1) # [s_kv, 1, d_qk] - - # indices: [s_q, h_kv=1, topk] - indices_for_mla = sparse_indices.unsqueeze(1) # [batch_size, 1, topk] - - # Pad query heads to required multiple (64 or 128) as done in SGLang's dsa_backend.py - need_padding = NUM_QO_HEADS % required_padding != 0 - if need_padding: - assert ( - required_padding % NUM_QO_HEADS == 0 - ), f"required_padding ({required_padding}) must be divisible by NUM_QO_HEADS ({NUM_QO_HEADS})" - q_padded = q_all.new_zeros((batch_size, required_padding, head_dim)) - q_padded[:, :NUM_QO_HEADS, :] = q_all - q_input = q_padded - print(f"Padded q from {NUM_QO_HEADS} to {required_padding} heads") - else: - q_input = q_all - - fi_output_full, fi_max_logits, fi_lse_full = flash_mla_sparse_fwd( - 
q=q_input, kv=kv_for_mla, indices=indices_for_mla, sm_scale=sm_scale, d_v=HEAD_DIM_CKV - ) - - # Trim output back to original number of heads if padding was applied - if need_padding: - fi_output = fi_output_full[:, :NUM_QO_HEADS, :] - fi_lse = fi_lse_full[:, :NUM_QO_HEADS] - else: - fi_output = fi_output_full - fi_lse = fi_lse_full - - except Exception as e: - print(f"WARNING: FlashMLA sparse fwd failed: {e}") - print("This may be due to API differences - skipping SGLang test") - import traceback - - traceback.print_exc() - return None - - # Compare outputs - print("\nComparing outputs...") - abs_diff, rel_diff, cos_sim = compute_error_metrics(ref_output, fi_output, "output") - - # Check tolerance - allclose = torch.allclose(ref_output.float(), fi_output.float(), atol=atol, rtol=rtol) - - if allclose: - print(f"\n✓ PASSED: Outputs match within tolerance (atol={atol}, rtol={rtol})") - return True - else: - print(f"\n✗ FAILED: Outputs differ beyond tolerance (atol={atol}, rtol={rtol})") - - # Show top error locations - flat = (ref_output.float() - fi_output.float()).abs().flatten() - k = min(5, flat.numel()) - topv, topi = torch.topk(flat, k) - print(f"\nTop-{k} absolute error locations:") - for rank in range(k): - idx = topi[rank].item() - print( - f" idx={idx}: ref={ref_output.flatten()[idx].item():.6e}, " - f"fi={fi_output.flatten()[idx].item():.6e}, diff={topv[rank].item():.6e}" - ) - - # Use hit ratio as secondary check - passed = check_hit_ratio(ref_output, fi_output, atol, rtol, required_percent=0.85) - return passed - - -def main(): - """Run comprehensive tests.""" - print("Testing DSA (DeepSeek Sparse Attention) Sparse Decode Reference Implementation") - print("Page Size 64 Variant") - print("=" * 70) - print( - f"Constants: h={NUM_QO_HEADS}, ckv={HEAD_DIM_CKV}, kpe={HEAD_DIM_KPE}, ps={PAGE_SIZE}, topk={TOPK}" - ) - print(f"SGLang available: {SGLANG_AVAILABLE}") - print(f"FlashInfer available: {FLASHINFER_AVAILABLE}") - print("=" * 70) - - test_results = [] - - # Basic functionality tests - test_results.append(("output_shape", test_output_shape())) - test_results.append(("sparse_vs_dense", test_sparse_vs_dense_consistency())) - test_results.append(("padding_handling", test_padding_handling())) - - # Ground truth comparison tests - test_configs = [(1, 512), (4, 512), (8, 1024)] # Single batch # Small batch # Medium batch - - for batch_size, max_seq_len in test_configs: - name = f"sglang_bs{batch_size}_seq{max_seq_len}" - try: - result = test_correctness_vs_sglang(batch_size, max_seq_len) - test_results.append((name, result)) - except Exception as e: - print(f"\n✗ Test {name} crashed: {e}") - import traceback - - traceback.print_exc() - test_results.append((name, False)) - - # Summary - print(f"\n{'='*70}") - print("Test Summary:") - print(f"{'='*70}") - - passed = 0 - skipped = 0 - failed = 0 - - for name, result in test_results: - if result is None: - status = "SKIPPED" - skipped += 1 - elif result: - status = "PASSED" - passed += 1 - else: - status = "FAILED" - failed += 1 - print(f" {name}: {status}") - - print(f"\nTotal: {passed} passed, {failed} failed, {skipped} skipped") - - if failed == 0: - print("\n✓ All tests passed!") - else: - print(f"\n✗ {failed} tests failed") - - -if __name__ == "__main__": - main() diff --git a/flashinfer_trace/tests/references/test_dsa_sparse_prefill_causal_h16_ckv512_kpe64_topk256_ps1.py b/flashinfer_trace/tests/references/test_dsa_sparse_prefill_causal_h16_ckv512_kpe64_topk256_ps1.py deleted file mode 100644 index 2ea6099a..00000000 --- 
a/flashinfer_trace/tests/references/test_dsa_sparse_prefill_causal_h16_ckv512_kpe64_topk256_ps1.py +++ /dev/null @@ -1,497 +0,0 @@ -""" -Tests for DSA (DeepSeek Sparse Attention) sparse prefill reference implementation. - -Ground truth sources: -1. SGLang FlashMLA sparse prefill: sgl_kernel.flash_mla.flash_mla_sparse_fwd - -Note: FlashInfer's sparse.py provides BlockSparseAttentionWrapper which uses BSR format, -different from DeepSeek's DSA token-level sparse attention. -""" - -import math -from pathlib import Path - -import numpy as np -import pytest -import torch - -# Ground truth imports with availability checks -try: - from sgl_kernel.flash_mla import flash_mla_sparse_fwd - - SGLANG_AVAILABLE = True -except ImportError: - SGLANG_AVAILABLE = False - -try: - import flashinfer - - FLASHINFER_AVAILABLE = True -except ImportError: - FLASHINFER_AVAILABLE = False - -# Module-level constants (DeepSeek V3/R1 with TP=8) -NUM_QO_HEADS = 16 -HEAD_DIM_CKV = 512 -HEAD_DIM_KPE = 64 -PAGE_SIZE = 1 -TOPK = 256 - -TRACE_ROOT = Path(__file__).resolve().parents[2] - - -@torch.no_grad() -def run(q_nope, q_pe, ckv_cache, kpe_cache, sparse_indices, sm_scale): - """Reference implementation for DSA sparse prefill attention.""" - total_num_tokens, num_qo_heads, head_dim_ckv = q_nope.shape - head_dim_kpe = q_pe.shape[-1] - page_size = ckv_cache.shape[1] - topk = sparse_indices.shape[-1] - - # Check constants - assert num_qo_heads == NUM_QO_HEADS - assert head_dim_ckv == HEAD_DIM_CKV - assert head_dim_kpe == HEAD_DIM_KPE - assert page_size == PAGE_SIZE - assert topk == TOPK - - # Check constraints - assert sparse_indices.shape[0] == total_num_tokens - assert sparse_indices.shape[-1] == topk - assert ckv_cache.shape[1] == page_size - - device = q_nope.device - - # Squeeze page dimension (page_size=1) - Kc_all = ckv_cache.squeeze(1).to(torch.float32) # [num_pages, head_dim_ckv] - Kp_all = kpe_cache.squeeze(1).to(torch.float32) # [num_pages, head_dim_kpe] - - output = torch.zeros( - (total_num_tokens, num_qo_heads, head_dim_ckv), dtype=torch.bfloat16, device=device - ) - lse = torch.full( - (total_num_tokens, num_qo_heads), -float("inf"), dtype=torch.float32, device=device - ) - - for t in range(total_num_tokens): - indices = sparse_indices[t] # [topk] - - # Handle padding: -1 indicates invalid indices - valid_mask = indices != -1 - valid_indices = indices[valid_mask] - - if valid_indices.numel() == 0: - output[t].zero_() - continue - - tok_idx = valid_indices.to(torch.long) - - Kc = Kc_all[tok_idx] # [num_valid, head_dim_ckv] - Kp = Kp_all[tok_idx] # [num_valid, head_dim_kpe] - qn = q_nope[t].to(torch.float32) # [num_qo_heads, head_dim_ckv] - qp = q_pe[t].to(torch.float32) # [num_qo_heads, head_dim_kpe] - - # Compute attention logits - logits = (qn @ Kc.T) + (qp @ Kp.T) # [num_qo_heads, num_valid] - logits_scaled = logits * sm_scale - - # Compute 2-base LSE - lse[t] = torch.logsumexp(logits_scaled, dim=-1) / math.log(2.0) - - # Compute attention output - attn = torch.softmax(logits_scaled, dim=-1) # [num_qo_heads, num_valid] - out = attn @ Kc # [num_qo_heads, head_dim_ckv] - output[t] = out.to(torch.bfloat16) - - return {"output": output, "lse": lse} - - -def generate_random_inputs( - total_num_tokens, - num_qo_heads=NUM_QO_HEADS, - head_dim_ckv=HEAD_DIM_CKV, - head_dim_kpe=HEAD_DIM_KPE, - topk=TOPK, - device="cuda", -): - """Generate random inputs for DSA sparse prefill attention testing.""" - # Generate KV cache with enough pages - num_pages = max(total_num_tokens * 2, 1024) - - # Generate sparse indices (top-K 
selection for each token) - sparse_indices = torch.randint( - 0, num_pages, (total_num_tokens, topk), dtype=torch.int32, device=device - ) - - # Generate query tensors - q_nope = torch.randn( - total_num_tokens, num_qo_heads, head_dim_ckv, dtype=torch.bfloat16, device=device - ) - q_pe = torch.randn( - total_num_tokens, num_qo_heads, head_dim_kpe, dtype=torch.bfloat16, device=device - ) - - # Generate compressed KV and positional caches - ckv_cache = torch.randn(num_pages, 1, head_dim_ckv, dtype=torch.bfloat16, device=device) - kpe_cache = torch.randn(num_pages, 1, head_dim_kpe, dtype=torch.bfloat16, device=device) - - # Generate softmax scale - # MLA uses head dimension before matrix absorption (128 + 64 = 192) - sm_scale = 1.0 / np.sqrt(128 + head_dim_kpe) - - return { - "q_nope": q_nope, - "q_pe": q_pe, - "ckv_cache": ckv_cache, - "kpe_cache": kpe_cache, - "sparse_indices": sparse_indices, - "sm_scale": torch.tensor(sm_scale, dtype=torch.float32, device=device), - "num_pages": num_pages, - } - - -def compute_error_metrics(ref, gt, name="output"): - """Compute and print detailed error metrics.""" - ref_f32 = ref.float() - gt_f32 = gt.float() - - abs_diff = torch.abs(ref_f32 - gt_f32) - rel_diff = abs_diff / (torch.abs(gt_f32) + 1e-8) - - max_abs_diff = abs_diff.max().item() - max_rel_diff = rel_diff.max().item() - mean_abs_diff = abs_diff.mean().item() - mean_rel_diff = rel_diff.mean().item() - - print(f"\n{name} comparison:") - print(f" Max absolute difference: {max_abs_diff:.6e}") - print(f" Max relative difference: {max_rel_diff:.6e}") - print(f" Mean absolute difference: {mean_abs_diff:.6e}") - print(f" Mean relative difference: {mean_rel_diff:.6e}") - - # Cosine similarity and MSE - cos_sim = torch.nn.functional.cosine_similarity( - ref_f32.flatten(), gt_f32.flatten(), dim=0 - ).item() - mse = torch.mean((ref_f32 - gt_f32) ** 2).item() - print(f" Cosine similarity: {cos_sim:.6f}") - print(f" MSE: {mse:.6e}") - - return abs_diff, rel_diff, cos_sim - - -def check_hit_ratio(ref, gt, atol, rtol, required_percent=0.85): - """Check if hit ratio meets threshold (for quantized kernels).""" - ref_f32 = ref.float() - gt_f32 = gt.float() - - left = (ref_f32 - gt_f32).abs() - right = atol + rtol * gt_f32.abs() - ok = left <= right - hit_ratio = ok.float().mean().item() - - print(f"\nHit ratio: {hit_ratio * 100:.2f}% (need >= {required_percent * 100:.2f}%)") - return hit_ratio >= required_percent - - -def test_output_shape(total_num_tokens=64, topk=TOPK): - """Test that reference produces correct output shapes.""" - print(f"\n{'='*60}") - print(f"Testing DSA prefill output shape: total_num_tokens={total_num_tokens}, topk={topk}") - print(f"{'='*60}") - - device = "cuda" if torch.cuda.is_available() else "cpu" - if device == "cpu": - print("WARNING: CUDA not available, using CPU") - - inputs = generate_random_inputs(total_num_tokens, topk=topk, device=device) - - result = run( - inputs["q_nope"], - inputs["q_pe"], - inputs["ckv_cache"], - inputs["kpe_cache"], - inputs["sparse_indices"], - inputs["sm_scale"], - ) - - output = result["output"] - lse = result["lse"] - - expected_output_shape = (total_num_tokens, NUM_QO_HEADS, HEAD_DIM_CKV) - expected_lse_shape = (total_num_tokens, NUM_QO_HEADS) - - output_shape_correct = output.shape == expected_output_shape - lse_shape_correct = lse.shape == expected_lse_shape - - print(f"Output shape: {output.shape} (expected: {expected_output_shape})") - print(f"LSE shape: {lse.shape} (expected: {expected_lse_shape})") - - if output_shape_correct and 
lse_shape_correct: - print("PASSED: Output shapes are correct") - return True - else: - print("FAILED: Output shapes are incorrect") - return False - - -def test_padding_handling(total_num_tokens=64, topk=TOPK): - """Test that padding (-1 indices) are handled correctly.""" - print(f"\n{'='*60}") - print(f"Testing DSA prefill padding handling") - print(f"{'='*60}") - - device = "cuda" if torch.cuda.is_available() else "cpu" - if device == "cpu": - print("WARNING: CUDA not available, using CPU") - - num_pages = 1000 - - q_nope = torch.randn( - total_num_tokens, NUM_QO_HEADS, HEAD_DIM_CKV, dtype=torch.bfloat16, device=device - ) - q_pe = torch.randn( - total_num_tokens, NUM_QO_HEADS, HEAD_DIM_KPE, dtype=torch.bfloat16, device=device - ) - ckv_cache = torch.randn(num_pages, 1, HEAD_DIM_CKV, dtype=torch.bfloat16, device=device) - kpe_cache = torch.randn(num_pages, 1, HEAD_DIM_KPE, dtype=torch.bfloat16, device=device) - sm_scale = torch.tensor(1.0 / np.sqrt(128 + HEAD_DIM_KPE), dtype=torch.float32, device=device) - - # Create sparse indices with varying amounts of padding per token - sparse_indices = torch.full((total_num_tokens, topk), -1, dtype=torch.int32, device=device) - - for t in range(total_num_tokens): - # Vary the number of valid indices - valid_count = (t % 4 + 1) * (topk // 4) # 25%, 50%, 75%, 100% - valid_count = min(valid_count, topk) - sparse_indices[t, :valid_count] = torch.randint( - 0, num_pages, (valid_count,), dtype=torch.int32, device=device - ) - - result = run(q_nope, q_pe, ckv_cache, kpe_cache, sparse_indices, sm_scale) - output = result["output"] - lse = result["lse"] - - # Verify outputs are valid - output_valid = not torch.isnan(output).any() and not torch.isinf(output).any() - lse_valid = not torch.isnan(lse).any() and not torch.isinf(lse[lse > -float("inf")]).any() - - print(f"Output valid (no nan/inf): {output_valid}") - print(f"LSE valid: {lse_valid}") - - if output_valid and lse_valid: - print("PASSED: Padding handled correctly") - return True - else: - print("FAILED: Padding handling issue") - return False - - -def test_correctness_vs_sglang(total_num_tokens=64, atol=1e-2, rtol=5e-2): - """ - Test correctness against SGLang FlashMLA sparse prefill kernel. - - NOTE: This test requires SGLang sgl_kernel to be installed. - If not available, the test will be skipped. 
- """ - print(f"\n{'='*60}") - print(f"Testing DSA prefill correctness against SGLang FlashMLA") - print(f"total_num_tokens={total_num_tokens}") - print(f"{'='*60}") - - if not SGLANG_AVAILABLE: - print("SKIPPED: SGLang/sgl_kernel not available") - return None - - device = "cuda" if torch.cuda.is_available() else "cpu" - if device == "cpu": - print("SKIPPED: CUDA not available") - return None - - torch.manual_seed(42) - - # Test parameters - num_pages = 1024 - head_dim = HEAD_DIM_CKV + HEAD_DIM_KPE # Combined head dim = 576 - - # Determine required head padding based on GPU architecture - # FlashMLA kernel requires h_q to be multiple of 64 (Hopper SM90) or 128 (Blackwell SM100+) - device_sm_major = torch.cuda.get_device_properties(device).major - required_padding = 128 if device_sm_major >= 10 else 64 - print(f"GPU SM major: {device_sm_major}, required head padding: {required_padding}") - - # Generate query tensors - q_nope = torch.randn( - total_num_tokens, NUM_QO_HEADS, HEAD_DIM_CKV, dtype=torch.bfloat16, device=device - ) - q_pe = torch.randn( - total_num_tokens, NUM_QO_HEADS, HEAD_DIM_KPE, dtype=torch.bfloat16, device=device - ) - - # Combined q for FlashMLA: [s_q, h_q, d_qk] - q_all = torch.cat([q_nope, q_pe], dim=-1) # [total_num_tokens, num_qo_heads, head_dim] - - # KV cache (combined) - # flash_mla_sparse_fwd expects: kv [s_kv, h_kv, d_qk] where h_kv=1 for MLA - kv_cache = torch.randn(num_pages, head_dim, dtype=torch.bfloat16, device=device) - ckv_cache = kv_cache[:, :HEAD_DIM_CKV].unsqueeze(1) # [num_pages, 1, ckv] - kpe_cache = kv_cache[:, HEAD_DIM_CKV:].unsqueeze(1) # [num_pages, 1, kpe] - - sm_scale = 1.0 / np.sqrt(128 + HEAD_DIM_KPE) - - # Generate sparse indices: [total_num_tokens, topk] for reference - sparse_indices = torch.randint( - 0, num_pages, (total_num_tokens, TOPK), dtype=torch.int32, device=device - ) - - # Run reference implementation - print("Running reference implementation...") - ref_result = run( - q_nope, - q_pe, - ckv_cache, - kpe_cache, - sparse_indices, - torch.tensor(sm_scale, dtype=torch.float32, device=device), - ) - ref_output = ref_result["output"] - ref_lse = ref_result["lse"] - - # Run FlashMLA sparse prefill - # flash_mla_sparse_fwd expects: - # q: [s_q, h_q, d_qk] bfloat16 - h_q must be multiple of 64 (SM90) or 128 (SM100+) - # kv: [s_kv, h_kv, d_qk] bfloat16 (h_kv=1 for MLA) - # indices: [s_q, h_kv, topk] int32 - print("Running SGLang FlashMLA sparse prefill...") - try: - kv_for_mla = kv_cache.unsqueeze(1) # [s_kv, 1, d_qk] - - # indices: [s_q, h_kv=1, topk] - indices_for_mla = sparse_indices.unsqueeze(1) # [total_num_tokens, 1, topk] - - # Pad query heads to required multiple (64 or 128) as done in SGLang's dsa_backend.py - need_padding = NUM_QO_HEADS % required_padding != 0 - if need_padding: - assert ( - required_padding % NUM_QO_HEADS == 0 - ), f"required_padding ({required_padding}) must be divisible by NUM_QO_HEADS ({NUM_QO_HEADS})" - q_padded = q_all.new_zeros((total_num_tokens, required_padding, head_dim)) - q_padded[:, :NUM_QO_HEADS, :] = q_all - q_input = q_padded - print(f"Padded q from {NUM_QO_HEADS} to {required_padding} heads") - else: - q_input = q_all - - fi_output_full, fi_max_logits, fi_lse_full = flash_mla_sparse_fwd( - q=q_input, kv=kv_for_mla, indices=indices_for_mla, sm_scale=sm_scale, d_v=HEAD_DIM_CKV - ) - - # Trim output back to original number of heads if padding was applied - if need_padding: - fi_output = fi_output_full[:, :NUM_QO_HEADS, :] - fi_lse = fi_lse_full[:, :NUM_QO_HEADS] - else: - fi_output = fi_output_full - 
fi_lse = fi_lse_full - - except Exception as e: - print(f"WARNING: FlashMLA sparse fwd failed: {e}") - print("This may be due to API differences - skipping SGLang test") - import traceback - - traceback.print_exc() - return None - - # Compare outputs - print("\nComparing outputs...") - abs_diff, rel_diff, cos_sim = compute_error_metrics(ref_output, fi_output, "output") - - # Check tolerance - allclose = torch.allclose(ref_output.float(), fi_output.float(), atol=atol, rtol=rtol) - - if allclose: - print(f"\n✓ PASSED: Outputs match within tolerance (atol={atol}, rtol={rtol})") - return True - else: - print(f"\n✗ FAILED: Outputs differ beyond tolerance (atol={atol}, rtol={rtol})") - - # Show top error locations - flat = (ref_output.float() - fi_output.float()).abs().flatten() - k = min(5, flat.numel()) - topv, topi = torch.topk(flat, k) - print(f"\nTop-{k} absolute error locations:") - for rank in range(k): - idx = topi[rank].item() - print( - f" idx={idx}: ref={ref_output.flatten()[idx].item():.6e}, " - f"fi={fi_output.flatten()[idx].item():.6e}, diff={topv[rank].item():.6e}" - ) - - # Use hit ratio as secondary check - passed = check_hit_ratio(ref_output, fi_output, atol, rtol, required_percent=0.85) - return passed - - -def main(): - """Run comprehensive tests.""" - print("Testing DSA (DeepSeek Sparse Attention) Sparse Prefill Reference Implementation") - print("=" * 70) - print( - f"Constants: h={NUM_QO_HEADS}, ckv={HEAD_DIM_CKV}, kpe={HEAD_DIM_KPE}, ps={PAGE_SIZE}, topk={TOPK}" - ) - print(f"SGLang available: {SGLANG_AVAILABLE}") - print(f"FlashInfer available: {FLASHINFER_AVAILABLE}") - print("=" * 70) - - test_results = [] - - # Basic functionality tests - test_results.append(("output_shape", test_output_shape())) - test_results.append(("padding_handling", test_padding_handling())) - - # Ground truth comparison tests - test_configs = [16, 64, 256] # Small # Medium # Large - - for total_num_tokens in test_configs: - name = f"sglang_tokens{total_num_tokens}" - try: - result = test_correctness_vs_sglang(total_num_tokens) - test_results.append((name, result)) - except Exception as e: - print(f"\n✗ Test {name} crashed: {e}") - import traceback - - traceback.print_exc() - test_results.append((name, False)) - - # Summary - print(f"\n{'='*70}") - print("Test Summary:") - print(f"{'='*70}") - - passed = 0 - skipped = 0 - failed = 0 - - for name, result in test_results: - if result is None: - status = "SKIPPED" - skipped += 1 - elif result: - status = "PASSED" - passed += 1 - else: - status = "FAILED" - failed += 1 - print(f" {name}: {status}") - - print(f"\nTotal: {passed} passed, {failed} failed, {skipped} skipped") - - if failed == 0: - print("\n✓ All tests passed!") - else: - print(f"\n✗ {failed} tests failed") - - -if __name__ == "__main__": - main() diff --git a/flashinfer_trace/tests/references/test_dsa_sparse_prefill_causal_h16_ckv512_kpe64_topk256_ps64.py b/flashinfer_trace/tests/references/test_dsa_sparse_prefill_causal_h16_ckv512_kpe64_topk256_ps64.py deleted file mode 100644 index 98460ce6..00000000 --- a/flashinfer_trace/tests/references/test_dsa_sparse_prefill_causal_h16_ckv512_kpe64_topk256_ps64.py +++ /dev/null @@ -1,521 +0,0 @@ -""" -Tests for DSA (DeepSeek Sparse Attention) sparse prefill reference implementation. -Page size 64 variant. - -Ground truth sources: -1. 
SGLang FlashMLA sparse prefill: sgl_kernel.flash_mla.flash_mla_sparse_fwd - -Note: FlashInfer's sparse.py provides BlockSparseAttentionWrapper which uses BSR format, -different from DeepSeek's DSA token-level sparse attention. -""" - -import math -from pathlib import Path - -import numpy as np -import pytest -import torch - -# Ground truth imports with availability checks -try: - from sgl_kernel.flash_mla import flash_mla_sparse_fwd - - SGLANG_AVAILABLE = True -except ImportError: - SGLANG_AVAILABLE = False - -try: - import flashinfer - - FLASHINFER_AVAILABLE = True -except ImportError: - FLASHINFER_AVAILABLE = False - -# Module-level constants (DeepSeek V3/R1 with TP=8) -NUM_QO_HEADS = 16 -HEAD_DIM_CKV = 512 -HEAD_DIM_KPE = 64 -PAGE_SIZE = 64 -TOPK = 256 - -TRACE_ROOT = Path(__file__).resolve().parents[2] - - -@torch.no_grad() -def run(q_nope, q_pe, ckv_cache, kpe_cache, sparse_indices, sm_scale): - """Reference implementation for DSA sparse prefill attention with page_size=64.""" - total_num_tokens, num_qo_heads, head_dim_ckv = q_nope.shape - head_dim_kpe = q_pe.shape[-1] - num_pages, page_size, _ = ckv_cache.shape - topk = sparse_indices.shape[-1] - - # Check constants - assert num_qo_heads == NUM_QO_HEADS - assert head_dim_ckv == HEAD_DIM_CKV - assert head_dim_kpe == HEAD_DIM_KPE - assert page_size == PAGE_SIZE - assert topk == TOPK - - # Check constraints - assert sparse_indices.shape[0] == total_num_tokens - assert sparse_indices.shape[-1] == topk - assert ckv_cache.shape[1] == page_size - - device = q_nope.device - - # Flatten paged KV cache to token-level: [num_pages, page_size, dim] -> [num_pages * page_size, dim] - Kc_all = ckv_cache.reshape(-1, head_dim_ckv).to( - torch.float32 - ) # [total_kv_tokens, head_dim_ckv] - Kp_all = kpe_cache.reshape(-1, head_dim_kpe).to( - torch.float32 - ) # [total_kv_tokens, head_dim_kpe] - - output = torch.zeros( - (total_num_tokens, num_qo_heads, head_dim_ckv), dtype=torch.bfloat16, device=device - ) - lse = torch.full( - (total_num_tokens, num_qo_heads), -float("inf"), dtype=torch.float32, device=device - ) - - for t in range(total_num_tokens): - indices = sparse_indices[t] # [topk] - - # Handle padding: -1 indicates invalid indices - valid_mask = indices != -1 - valid_indices = indices[valid_mask] - - if valid_indices.numel() == 0: - output[t].zero_() - continue - - # For page_size=64, indices encode (page_idx * 64 + offset) - tok_idx = valid_indices.to(torch.long) - - Kc = Kc_all[tok_idx] # [num_valid, head_dim_ckv] - Kp = Kp_all[tok_idx] # [num_valid, head_dim_kpe] - qn = q_nope[t].to(torch.float32) # [num_qo_heads, head_dim_ckv] - qp = q_pe[t].to(torch.float32) # [num_qo_heads, head_dim_kpe] - - # Compute attention logits - logits = (qn @ Kc.T) + (qp @ Kp.T) # [num_qo_heads, num_valid] - logits_scaled = logits * sm_scale - - # Compute 2-base LSE - lse[t] = torch.logsumexp(logits_scaled, dim=-1) / math.log(2.0) - - # Compute attention output - attn = torch.softmax(logits_scaled, dim=-1) # [num_qo_heads, num_valid] - out = attn @ Kc # [num_qo_heads, head_dim_ckv] - output[t] = out.to(torch.bfloat16) - - return {"output": output, "lse": lse} - - -def generate_random_inputs( - total_num_tokens, - num_qo_heads=NUM_QO_HEADS, - head_dim_ckv=HEAD_DIM_CKV, - head_dim_kpe=HEAD_DIM_KPE, - topk=TOPK, - device="cuda", -): - """Generate random inputs for DSA sparse prefill attention testing with page_size=64.""" - # Generate KV cache with enough pages - # Each page holds 64 tokens, so we need more pages to hold enough tokens - total_kv_tokens = 
max(total_num_tokens * 4, 2048) # Ensure enough KV tokens - num_pages = (total_kv_tokens + PAGE_SIZE - 1) // PAGE_SIZE - - # Generate sparse indices (top-K selection for each token) - # Indices encode global token position: page_idx * 64 + offset - total_tokens_in_cache = num_pages * PAGE_SIZE - sparse_indices = torch.randint( - 0, total_tokens_in_cache, (total_num_tokens, topk), dtype=torch.int32, device=device - ) - - # Generate query tensors - q_nope = torch.randn( - total_num_tokens, num_qo_heads, head_dim_ckv, dtype=torch.bfloat16, device=device - ) - q_pe = torch.randn( - total_num_tokens, num_qo_heads, head_dim_kpe, dtype=torch.bfloat16, device=device - ) - - # Generate compressed KV and positional caches with page_size=64 - ckv_cache = torch.randn(num_pages, PAGE_SIZE, head_dim_ckv, dtype=torch.bfloat16, device=device) - kpe_cache = torch.randn(num_pages, PAGE_SIZE, head_dim_kpe, dtype=torch.bfloat16, device=device) - - # Generate softmax scale - # MLA uses head dimension before matrix absorption (128 + 64 = 192) - sm_scale = 1.0 / np.sqrt(128 + head_dim_kpe) - - return { - "q_nope": q_nope, - "q_pe": q_pe, - "ckv_cache": ckv_cache, - "kpe_cache": kpe_cache, - "sparse_indices": sparse_indices, - "sm_scale": torch.tensor(sm_scale, dtype=torch.float32, device=device), - "num_pages": num_pages, - } - - -def compute_error_metrics(ref, gt, name="output"): - """Compute and print detailed error metrics.""" - ref_f32 = ref.float() - gt_f32 = gt.float() - - abs_diff = torch.abs(ref_f32 - gt_f32) - rel_diff = abs_diff / (torch.abs(gt_f32) + 1e-8) - - max_abs_diff = abs_diff.max().item() - max_rel_diff = rel_diff.max().item() - mean_abs_diff = abs_diff.mean().item() - mean_rel_diff = rel_diff.mean().item() - - print(f"\n{name} comparison:") - print(f" Max absolute difference: {max_abs_diff:.6e}") - print(f" Max relative difference: {max_rel_diff:.6e}") - print(f" Mean absolute difference: {mean_abs_diff:.6e}") - print(f" Mean relative difference: {mean_rel_diff:.6e}") - - # Cosine similarity and MSE - cos_sim = torch.nn.functional.cosine_similarity( - ref_f32.flatten(), gt_f32.flatten(), dim=0 - ).item() - mse = torch.mean((ref_f32 - gt_f32) ** 2).item() - print(f" Cosine similarity: {cos_sim:.6f}") - print(f" MSE: {mse:.6e}") - - return abs_diff, rel_diff, cos_sim - - -def check_hit_ratio(ref, gt, atol, rtol, required_percent=0.85): - """Check if hit ratio meets threshold (for quantized kernels).""" - ref_f32 = ref.float() - gt_f32 = gt.float() - - left = (ref_f32 - gt_f32).abs() - right = atol + rtol * gt_f32.abs() - ok = left <= right - hit_ratio = ok.float().mean().item() - - print(f"\nHit ratio: {hit_ratio * 100:.2f}% (need >= {required_percent * 100:.2f}%)") - return hit_ratio >= required_percent - - -def test_output_shape(total_num_tokens=64, topk=TOPK): - """Test that reference produces correct output shapes.""" - print(f"\n{'='*60}") - print( - f"Testing DSA prefill ps64 output shape: total_num_tokens={total_num_tokens}, topk={topk}" - ) - print(f"{'='*60}") - - device = "cuda" if torch.cuda.is_available() else "cpu" - if device == "cpu": - print("WARNING: CUDA not available, using CPU") - - inputs = generate_random_inputs(total_num_tokens, topk=topk, device=device) - - result = run( - inputs["q_nope"], - inputs["q_pe"], - inputs["ckv_cache"], - inputs["kpe_cache"], - inputs["sparse_indices"], - inputs["sm_scale"], - ) - - output = result["output"] - lse = result["lse"] - - expected_output_shape = (total_num_tokens, NUM_QO_HEADS, HEAD_DIM_CKV) - expected_lse_shape = 
(total_num_tokens, NUM_QO_HEADS) - - output_shape_correct = output.shape == expected_output_shape - lse_shape_correct = lse.shape == expected_lse_shape - - print(f"Output shape: {output.shape} (expected: {expected_output_shape})") - print(f"LSE shape: {lse.shape} (expected: {expected_lse_shape})") - - if output_shape_correct and lse_shape_correct: - print("PASSED: Output shapes are correct") - return True - else: - print("FAILED: Output shapes are incorrect") - return False - - -def test_padding_handling(total_num_tokens=64, topk=TOPK): - """Test that padding (-1 indices) are handled correctly.""" - print(f"\n{'='*60}") - print(f"Testing DSA prefill ps64 padding handling") - print(f"{'='*60}") - - device = "cuda" if torch.cuda.is_available() else "cpu" - if device == "cpu": - print("WARNING: CUDA not available, using CPU") - - num_pages = 64 # 64 pages * 64 tokens = 4096 total tokens - - q_nope = torch.randn( - total_num_tokens, NUM_QO_HEADS, HEAD_DIM_CKV, dtype=torch.bfloat16, device=device - ) - q_pe = torch.randn( - total_num_tokens, NUM_QO_HEADS, HEAD_DIM_KPE, dtype=torch.bfloat16, device=device - ) - ckv_cache = torch.randn(num_pages, PAGE_SIZE, HEAD_DIM_CKV, dtype=torch.bfloat16, device=device) - kpe_cache = torch.randn(num_pages, PAGE_SIZE, HEAD_DIM_KPE, dtype=torch.bfloat16, device=device) - sm_scale = torch.tensor(1.0 / np.sqrt(128 + HEAD_DIM_KPE), dtype=torch.float32, device=device) - - # Create sparse indices with varying amounts of padding per token - sparse_indices = torch.full((total_num_tokens, topk), -1, dtype=torch.int32, device=device) - total_tokens_in_cache = num_pages * PAGE_SIZE - - for t in range(total_num_tokens): - # Vary the number of valid indices - valid_count = (t % 4 + 1) * (topk // 4) # 25%, 50%, 75%, 100% - valid_count = min(valid_count, topk) - sparse_indices[t, :valid_count] = torch.randint( - 0, total_tokens_in_cache, (valid_count,), dtype=torch.int32, device=device - ) - - result = run(q_nope, q_pe, ckv_cache, kpe_cache, sparse_indices, sm_scale) - output = result["output"] - lse = result["lse"] - - # Verify outputs are valid - output_valid = not torch.isnan(output).any() and not torch.isinf(output).any() - lse_valid = not torch.isnan(lse).any() and not torch.isinf(lse[lse > -float("inf")]).any() - - print(f"Output valid (no nan/inf): {output_valid}") - print(f"LSE valid: {lse_valid}") - - if output_valid and lse_valid: - print("PASSED: Padding handled correctly") - return True - else: - print("FAILED: Padding handling issue") - return False - - -def test_correctness_vs_sglang(total_num_tokens=64, atol=1e-2, rtol=5e-2): - """ - Test correctness against SGLang FlashMLA sparse prefill kernel. - - NOTE: This test requires SGLang sgl_kernel to be installed. - If not available, the test will be skipped. - - NOTE: FlashMLA sparse kernel operates at token-level (page_size=1). - For page_size=64, we flatten to token-level for comparison. 
- """ - print(f"\n{'='*60}") - print(f"Testing DSA prefill ps64 correctness against SGLang FlashMLA") - print(f"total_num_tokens={total_num_tokens}") - print(f"{'='*60}") - - if not SGLANG_AVAILABLE: - print("SKIPPED: SGLang/sgl_kernel not available") - return None - - device = "cuda" if torch.cuda.is_available() else "cpu" - if device == "cpu": - print("SKIPPED: CUDA not available") - return None - - torch.manual_seed(42) - - # Test parameters - num_pages = 32 # 32 pages * 64 tokens/page = 2048 total KV tokens - total_kv_tokens = num_pages * PAGE_SIZE - head_dim = HEAD_DIM_CKV + HEAD_DIM_KPE # Combined head dim = 576 - - # Determine required head padding based on GPU architecture - # FlashMLA kernel requires h_q to be multiple of 64 (Hopper SM90) or 128 (Blackwell SM100+) - device_sm_major = torch.cuda.get_device_properties(device).major - required_padding = 128 if device_sm_major >= 10 else 64 - print(f"GPU SM major: {device_sm_major}, required head padding: {required_padding}") - - # Generate query tensors - q_nope = torch.randn( - total_num_tokens, NUM_QO_HEADS, HEAD_DIM_CKV, dtype=torch.bfloat16, device=device - ) - q_pe = torch.randn( - total_num_tokens, NUM_QO_HEADS, HEAD_DIM_KPE, dtype=torch.bfloat16, device=device - ) - - # Combined q for FlashMLA: [s_q, h_q, d_qk] - q_all = torch.cat([q_nope, q_pe], dim=-1) # [total_num_tokens, num_qo_heads, head_dim] - - # KV cache with page_size=64 - # Shape: [num_pages, page_size, head_dim] - kv_cache_paged = torch.randn( - num_pages, PAGE_SIZE, head_dim, dtype=torch.bfloat16, device=device - ) - ckv_cache = kv_cache_paged[:, :, :HEAD_DIM_CKV] # [num_pages, 64, ckv] - kpe_cache = kv_cache_paged[:, :, HEAD_DIM_CKV:] # [num_pages, 64, kpe] - - # Flatten for FlashMLA (token-level) - kv_cache_flat = kv_cache_paged.reshape(total_kv_tokens, head_dim) # [total_kv_tokens, head_dim] - - sm_scale = 1.0 / np.sqrt(128 + HEAD_DIM_KPE) - - # Generate sparse indices as global token indices: [total_num_tokens, topk] - # Each index is in range [0, total_kv_tokens) - sparse_indices = torch.randint( - 0, total_kv_tokens, (total_num_tokens, TOPK), dtype=torch.int32, device=device - ) - - # Run reference implementation with page_size=64 - print("Running reference implementation (page_size=64)...") - ref_result = run( - q_nope, - q_pe, - ckv_cache, - kpe_cache, - sparse_indices, - torch.tensor(sm_scale, dtype=torch.float32, device=device), - ) - ref_output = ref_result["output"] - ref_lse = ref_result["lse"] - - # Run FlashMLA sparse prefill (token-level) - # flash_mla_sparse_fwd expects: - # q: [s_q, h_q, d_qk] bfloat16 - h_q must be multiple of 64 (SM90) or 128 (SM100+) - # kv: [s_kv, h_kv, d_qk] bfloat16 (h_kv=1 for MLA) - # indices: [s_q, h_kv, topk] int32 - print("Running SGLang FlashMLA sparse prefill (token-level)...") - try: - kv_for_mla = kv_cache_flat.unsqueeze(1) # [s_kv, 1, d_qk] - - # indices: [s_q, h_kv=1, topk] - indices_for_mla = sparse_indices.unsqueeze(1) # [total_num_tokens, 1, topk] - - # Pad query heads to required multiple (64 or 128) as done in SGLang's dsa_backend.py - need_padding = NUM_QO_HEADS % required_padding != 0 - if need_padding: - assert ( - required_padding % NUM_QO_HEADS == 0 - ), f"required_padding ({required_padding}) must be divisible by NUM_QO_HEADS ({NUM_QO_HEADS})" - q_padded = q_all.new_zeros((total_num_tokens, required_padding, head_dim)) - q_padded[:, :NUM_QO_HEADS, :] = q_all - q_input = q_padded - print(f"Padded q from {NUM_QO_HEADS} to {required_padding} heads") - else: - q_input = q_all - - fi_output_full, 
fi_max_logits, fi_lse_full = flash_mla_sparse_fwd( - q=q_input, kv=kv_for_mla, indices=indices_for_mla, sm_scale=sm_scale, d_v=HEAD_DIM_CKV - ) - - # Trim output back to original number of heads if padding was applied - if need_padding: - fi_output = fi_output_full[:, :NUM_QO_HEADS, :] - fi_lse = fi_lse_full[:, :NUM_QO_HEADS] - else: - fi_output = fi_output_full - fi_lse = fi_lse_full - - except Exception as e: - print(f"WARNING: FlashMLA sparse fwd failed: {e}") - print("This may be due to API differences - skipping SGLang test") - import traceback - - traceback.print_exc() - return None - - # Compare outputs - print("\nComparing outputs...") - abs_diff, rel_diff, cos_sim = compute_error_metrics(ref_output, fi_output, "output") - - # Check tolerance - allclose = torch.allclose(ref_output.float(), fi_output.float(), atol=atol, rtol=rtol) - - if allclose: - print(f"\n✓ PASSED: Outputs match within tolerance (atol={atol}, rtol={rtol})") - return True - else: - print(f"\n✗ FAILED: Outputs differ beyond tolerance (atol={atol}, rtol={rtol})") - - # Show top error locations - flat = (ref_output.float() - fi_output.float()).abs().flatten() - k = min(5, flat.numel()) - topv, topi = torch.topk(flat, k) - print(f"\nTop-{k} absolute error locations:") - for rank in range(k): - idx = topi[rank].item() - print( - f" idx={idx}: ref={ref_output.flatten()[idx].item():.6e}, " - f"fi={fi_output.flatten()[idx].item():.6e}, diff={topv[rank].item():.6e}" - ) - - # Use hit ratio as secondary check - passed = check_hit_ratio(ref_output, fi_output, atol, rtol, required_percent=0.85) - return passed - - -def main(): - """Run comprehensive tests.""" - print("Testing DSA (DeepSeek Sparse Attention) Sparse Prefill Reference Implementation") - print("Page Size 64 Variant") - print("=" * 70) - print( - f"Constants: h={NUM_QO_HEADS}, ckv={HEAD_DIM_CKV}, kpe={HEAD_DIM_KPE}, ps={PAGE_SIZE}, topk={TOPK}" - ) - print(f"SGLang available: {SGLANG_AVAILABLE}") - print(f"FlashInfer available: {FLASHINFER_AVAILABLE}") - print("=" * 70) - - test_results = [] - - # Basic functionality tests - test_results.append(("output_shape", test_output_shape())) - test_results.append(("padding_handling", test_padding_handling())) - - # Ground truth comparison tests - test_configs = [16, 64, 256] # Small # Medium # Large - - for total_num_tokens in test_configs: - name = f"sglang_tokens{total_num_tokens}" - try: - result = test_correctness_vs_sglang(total_num_tokens) - test_results.append((name, result)) - except Exception as e: - print(f"\n✗ Test {name} crashed: {e}") - import traceback - - traceback.print_exc() - test_results.append((name, False)) - - # Summary - print(f"\n{'='*70}") - print("Test Summary:") - print(f"{'='*70}") - - passed = 0 - skipped = 0 - failed = 0 - - for name, result in test_results: - if result is None: - status = "SKIPPED" - skipped += 1 - elif result: - status = "PASSED" - passed += 1 - else: - status = "FAILED" - failed += 1 - print(f" {name}: {status}") - - print(f"\nTotal: {passed} passed, {failed} failed, {skipped} skipped") - - if failed == 0: - print("\n✓ All tests passed!") - else: - print(f"\n✗ {failed} tests failed") - - -if __name__ == "__main__": - main() diff --git a/flashinfer_trace/tests/references/test_dsa_vs_definition_reference.py b/flashinfer_trace/tests/references/test_dsa_vs_definition_reference.py index 216e0109..b60d6afd 100644 --- a/flashinfer_trace/tests/references/test_dsa_vs_definition_reference.py +++ b/flashinfer_trace/tests/references/test_dsa_vs_definition_reference.py @@ -79,7 +79,7 
@@ def test_trtllm_mla_sparse_vs_definition_reference(): device = "cuda" # Load definition and build reference - definition = load_definition("dsa_sparse_decode_h16_ckv512_kpe64_topk256_ps64") + definition = load_definition("dsa_sparse_attention_h16_ckv512_kpe64_topk256_ps64") reference = build_reference_runnable(definition) print(f"\nLoaded definition: {definition.name}") @@ -107,8 +107,7 @@ def test_trtllm_mla_sparse_vs_definition_reference(): # Run definition reference print("\nRunning definition reference...") - ref_result = reference(q_nope, q_pe, ckv_cache, kpe_cache, sparse_indices, sm_scale) - ref_output = ref_result["output"] + ref_output, ref_lse = reference(q_nope, q_pe, ckv_cache, kpe_cache, sparse_indices, sm_scale) # Prepare FlashInfer inputs (trtllm-gen format) query = torch.cat([q_nope, q_pe], dim=-1).unsqueeze(1) # [batch, 1, heads, 576] @@ -250,7 +249,7 @@ def test_topk_indexer_fp8_vs_definition_reference(): # Run definition reference print("\nRunning definition reference...") ref_result = reference(q_index_fp8, k_index_cache_fp8, weights, seq_lens, block_table) - ref_indices = ref_result["topk_indices"] + ref_indices = ref_result # Run deep_gemm to compute FP8 scores (deep_gemm expects uint8) # deep_gemm expects q shape: [batch, next_n, heads, head_dim] @@ -333,7 +332,7 @@ def test_trtllm_mla_sparse_various_configs(batch_size, max_seq_len): torch.manual_seed(42) device = "cuda" - definition = load_definition("dsa_sparse_decode_h16_ckv512_kpe64_topk256_ps64") + definition = load_definition("dsa_sparse_attention_h16_ckv512_kpe64_topk256_ps64") reference = build_reference_runnable(definition) max_num_pages = (max_seq_len + PAGE_SIZE - 1) // PAGE_SIZE @@ -353,8 +352,7 @@ def test_trtllm_mla_sparse_various_configs(batch_size, max_seq_len): 1.0 / math.sqrt(QK_NOPE_HEAD_DIM + QK_ROPE_HEAD_DIM), dtype=torch.float32, device=device ) - ref_result = reference(q_nope, q_pe, ckv_cache, kpe_cache, sparse_indices, sm_scale) - ref_output = ref_result["output"] + ref_output, ref_lse = reference(q_nope, q_pe, ckv_cache, kpe_cache, sparse_indices, sm_scale) query = torch.cat([q_nope, q_pe], dim=-1).unsqueeze(1) kv_cache = torch.cat([ckv_cache, kpe_cache], dim=-1) @@ -432,7 +430,7 @@ def test_topk_indexer_fp8_various_configs(batch_size, max_seq_len): page_offset += num_pages_for_seq ref_result = reference(q_index_fp8, k_index_cache_fp8, weights, seq_lens, block_table) - ref_indices = ref_result["topk_indices"] + ref_indices = ref_result q_index_fp8_4d = q_index_fp8.unsqueeze(1) k_index_cache_uint8 = k_index_cache_fp8.view(torch.uint8) diff --git a/flashinfer_trace/tests/references/test_gdn_decode_qk16_v32_d128_k_last.py b/flashinfer_trace/tests/references/test_gdn_decode_qk16_v32_d128_k_last.py index 6e2df869..5b40369e 100644 --- a/flashinfer_trace/tests/references/test_gdn_decode_qk16_v32_d128_k_last.py +++ b/flashinfer_trace/tests/references/test_gdn_decode_qk16_v32_d128_k_last.py @@ -65,7 +65,7 @@ def run_kernel(q, k, v, state, A_log, a, dt_bias, b, scale): b=b, scale=scale, output=output, - use_qk_l2norm=True, + use_qk_l2norm=False, ) return out, new_state @@ -145,8 +145,7 @@ def test_correctness(batch_size=4, atol=5e-3, rtol=5e-3): inputs["b"].clone(), inputs["scale"], ) - ref_output = ref_result["output"] - ref_new_state = ref_result["new_state"] + ref_output, ref_new_state = ref_result # Run kernel print("Running FlashInfer kernel...") @@ -215,8 +214,7 @@ def test_gdn_decode_k_last(batch_size: int): inputs["b"].clone(), inputs["scale"], ) - ref_output = ref_result["output"] - 
ref_new_state = ref_result["new_state"] + ref_output, ref_new_state = ref_result # Run kernel kernel_output, kernel_new_state = run_kernel( @@ -231,7 +229,7 @@ def test_gdn_decode_k_last(batch_size: int): inputs["scale"], ) - atol, rtol = 5e-3, 5e-3 + atol, rtol = 1e-2, 1e-2 torch.testing.assert_close( kernel_output, diff --git a/flashinfer_trace/tests/references/test_gdn_prefill_qk16_v32_d128_k_last.py b/flashinfer_trace/tests/references/test_gdn_prefill_qk16_v32_d128_k_last.py index 6fdd9f58..f45dbeca 100644 --- a/flashinfer_trace/tests/references/test_gdn_prefill_qk16_v32_d128_k_last.py +++ b/flashinfer_trace/tests/references/test_gdn_prefill_qk16_v32_d128_k_last.py @@ -44,8 +44,9 @@ def get_cuda_capability(): return torch.cuda.get_device_capability(0) -requires_sm90 = pytest.mark.skipif( - get_cuda_capability()[0] < 9, reason="GDN prefill kernel requires SM90 (Hopper) or later" +requires_sm90_only = pytest.mark.skipif( + get_cuda_capability()[0] != 9, + reason="GDN prefill kernel only supports SM90 (Hopper), not SM80 or SM100+", ) requires_cuda = pytest.mark.skipif( @@ -71,7 +72,7 @@ def compute_gates(A_log, a, dt_bias, b): @requires_cuda -@requires_sm90 +@requires_sm90_only @pytest.mark.parametrize("batch_size", [1, 2, 4]) @pytest.mark.parametrize("seq_len", [16, 64, 128]) def test_gdn_prefill_correctness(batch_size: int, seq_len: int): @@ -108,8 +109,7 @@ def test_gdn_prefill_correctness(batch_size: int, seq_len: int): # Reference from definition ref_result = reference_gdn_prefill(q, k, v, None, A_log, a, dt_bias, b, cu_seqlens, scale) - ref_output = ref_result["output"] - ref_new_state = ref_result["new_state"] + ref_output, ref_new_state = ref_result # FlashInfer uses pre-computed g/beta g, beta = compute_gates(A_log, a, dt_bias, b) @@ -143,7 +143,7 @@ def test_gdn_prefill_correctness(batch_size: int, seq_len: int): @requires_cuda -@requires_sm90 +@requires_sm90_only def test_gdn_prefill_with_initial_state(): """Test GDN prefill kernel with non-zero initial state.""" from flashinfer.gdn_prefill import chunk_gated_delta_rule @@ -187,8 +187,7 @@ def test_gdn_prefill_with_initial_state(): scale = 1.0 / math.sqrt(head_size) ref_result = reference_gdn_prefill(q, k, v, state, A_log, a, dt_bias, b, cu_seqlens, scale) - ref_output = ref_result["output"] - ref_new_state = ref_result["new_state"] + ref_output, ref_new_state = ref_result g, beta = compute_gates(A_log, a, dt_bias, b) fi_output, fi_new_state = chunk_gated_delta_rule( @@ -219,7 +218,7 @@ def test_gdn_prefill_with_initial_state(): @requires_cuda -@requires_sm90 +@requires_sm90_only def test_gdn_prefill_variable_seqlen(): """Test GDN prefill kernel with variable sequence lengths.""" from flashinfer.gdn_prefill import chunk_gated_delta_rule @@ -255,8 +254,7 @@ def test_gdn_prefill_variable_seqlen(): scale = 1.0 / math.sqrt(head_size) ref_result = reference_gdn_prefill(q, k, v, None, A_log, a, dt_bias, b, cu_seqlens, scale) - ref_output = ref_result["output"] - ref_new_state = ref_result["new_state"] + ref_output, ref_new_state = ref_result g, beta = compute_gates(A_log, a, dt_bias, b) fi_output, fi_new_state = chunk_gated_delta_rule(