@@ -1,16 +1,21 @@
"""
Tests for DSA (DeepSeek Sparse Attention) sparse attention reference implementation.
Test the DSA (DeepSeek Sparse Attention) reference implementation.

This test validates that the reference implementation from the definition
produces correct output shapes and handles padding correctly.

Ground truth comparison tests are in test_dsa_vs_definition_reference.py
which tests against FlashInfer's trtllm_batch_decode_with_kv_cache_mla.
"""

import math
from pathlib import Path

import flashinfer
import numpy as np
import pytest
import torch
from test_utils import compare_tensors, get_reference_run, print_comparison_metrics

# Load reference implementation from definition
run = get_reference_run("dsa_sparse_attention_h16_ckv512_kpe64_topk256_ps1")
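# Per the call sites below, the loaded reference has the signature
# run(q_nope, q_pe, ckv_cache, kpe_cache, sparse_indices, sm_scale) -> (output, lse).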

# Module-level constants (DeepSeek V3/R1 with TP=8)
NUM_QO_HEADS = 16
@@ -19,97 +24,27 @@
PAGE_SIZE = 1
TOPK = 256

TRACE_ROOT = Path(__file__).resolve().parents[2]


@torch.no_grad()
def run(q_nope, q_pe, ckv_cache, kpe_cache, sparse_indices, sm_scale):
    """Reference implementation for DSA sparse attention."""
    num_tokens, num_qo_heads, head_dim_ckv = q_nope.shape
    head_dim_kpe = q_pe.shape[-1]
    page_size = ckv_cache.shape[1]
    topk = sparse_indices.shape[-1]

    # Check constants
    assert num_qo_heads == NUM_QO_HEADS
    assert head_dim_ckv == HEAD_DIM_CKV
    assert head_dim_kpe == HEAD_DIM_KPE
    assert page_size == PAGE_SIZE
    assert topk == TOPK

    # Check constraints
    assert sparse_indices.shape[0] == num_tokens
    assert sparse_indices.shape[-1] == topk
    assert ckv_cache.shape[1] == page_size

    device = q_nope.device

    # Squeeze page dimension (page_size=1)
    Kc_all = ckv_cache.squeeze(1).to(torch.float32)  # [num_pages, head_dim_ckv]
    Kp_all = kpe_cache.squeeze(1).to(torch.float32)  # [num_pages, head_dim_kpe]

    output = torch.zeros(
        (num_tokens, num_qo_heads, head_dim_ckv), dtype=torch.bfloat16, device=device
    )
    lse = torch.full((num_tokens, num_qo_heads), -float("inf"), dtype=torch.float32, device=device)
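    # lse stays at -inf for tokens whose index list is all padding (no keys attended).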

    for t in range(num_tokens):
        indices = sparse_indices[t]  # [topk]

        # Handle padding: -1 indicates invalid indices
        valid_mask = indices != -1
        valid_indices = indices[valid_mask]

        if valid_indices.numel() == 0:
            output[t].zero_()
            continue

        tok_idx = valid_indices.to(torch.long)

        Kc = Kc_all[tok_idx]  # [num_valid, head_dim_ckv]
        Kp = Kp_all[tok_idx]  # [num_valid, head_dim_kpe]
        qn = q_nope[t].to(torch.float32)  # [num_qo_heads, head_dim_ckv]
        qp = q_pe[t].to(torch.float32)  # [num_qo_heads, head_dim_kpe]

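        # MLA absorbed form: q_nope scores against the compressed KV latents (ckv) and
        # q_pe against the decoupled RoPE keys (kpe); their sum gives the full QK logit.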
        # Compute attention logits
        logits = (qn @ Kc.T) + (qp @ Kp.T)  # [num_qo_heads, num_valid]
        logits_scaled = logits * sm_scale

        # Compute 2-base LSE
        lse[t] = torch.logsumexp(logits_scaled, dim=-1) / math.log(2.0)
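        # torch.logsumexp is natural-log based; dividing by ln(2) converts to base 2,
        # i.e. log2(x) = ln(x) / ln(2).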

        # Compute attention output
        attn = torch.softmax(logits_scaled, dim=-1)  # [num_qo_heads, num_valid]
        out = attn @ Kc  # [num_qo_heads, head_dim_ckv]
        output[t] = out.to(torch.bfloat16)

return {"output": output, "lse": lse}


def generate_random_inputs(
    num_tokens,
    num_qo_heads=NUM_QO_HEADS,
    head_dim_ckv=HEAD_DIM_CKV,
    head_dim_kpe=HEAD_DIM_KPE,
    topk=TOPK,
    device="cuda",
):
def generate_random_inputs(num_tokens, topk=TOPK, device="cuda"):
    """Generate random inputs for DSA sparse attention testing."""
    num_pages = max(num_tokens * 2, 1024)
    total_kv_tokens = max(num_tokens * 4, 2048)
    num_pages = (total_kv_tokens + PAGE_SIZE - 1) // PAGE_SIZE
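    # Ceil-divide so the page count covers every KV token.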

    total_tokens_in_cache = num_pages * PAGE_SIZE

    sparse_indices = torch.randint(
        0, num_pages, (num_tokens, topk), dtype=torch.int32, device=device
        0, total_tokens_in_cache, (num_tokens, topk), dtype=torch.int32, device=device
    )

    q_nope = torch.randn(
        num_tokens, num_qo_heads, head_dim_ckv, dtype=torch.bfloat16, device=device
        num_tokens, NUM_QO_HEADS, HEAD_DIM_CKV, dtype=torch.bfloat16, device=device
    )
    q_pe = torch.randn(num_tokens, num_qo_heads, head_dim_kpe, dtype=torch.bfloat16, device=device)
    q_pe = torch.randn(num_tokens, NUM_QO_HEADS, HEAD_DIM_KPE, dtype=torch.bfloat16, device=device)

    ckv_cache = torch.randn(num_pages, 1, head_dim_ckv, dtype=torch.bfloat16, device=device)
    kpe_cache = torch.randn(num_pages, 1, head_dim_kpe, dtype=torch.bfloat16, device=device)
    ckv_cache = torch.randn(num_pages, 1, HEAD_DIM_CKV, dtype=torch.bfloat16, device=device)
    kpe_cache = torch.randn(num_pages, 1, HEAD_DIM_KPE, dtype=torch.bfloat16, device=device)

    sm_scale = 1.0 / np.sqrt(128 + head_dim_kpe)
    sm_scale = 1.0 / np.sqrt(128 + HEAD_DIM_KPE)
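    # 128 matches qk_nope_head_dim in the FlashInfer call below, so the softmax
    # scale is 1 / sqrt(qk_nope_head_dim + qk_rope_head_dim) = 1 / sqrt(192).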

    return {
        "q_nope": q_nope,
@@ -122,6 +57,70 @@ def generate_random_inputs(
    }


def test_correctness(num_tokens=64, topk=TOPK, atol=1e-2, rtol=5e-2):
    """Test correctness of the DSA sparse attention reference implementation against FlashInfer."""
    print(f"\n{'='*60}")
    print(f"Testing DSA Sparse Attention num_tokens={num_tokens}, topk={topk}")
    print(f"{'='*60}")

    device = "cuda" if torch.cuda.is_available() else "cpu"
    if device == "cpu":
        print("WARNING: CUDA not available, skipping test")
        return True

    inputs = generate_random_inputs(num_tokens, topk=topk, device=device)

    print("Running reference implementation from definition...")
    ref_o, ref_lse = run(
        inputs["q_nope"],
        inputs["q_pe"],
        inputs["ckv_cache"],
        inputs["kpe_cache"],
        inputs["sparse_indices"],
        inputs["sm_scale"],
    )

    # Prepare FlashInfer inputs (trtllm-gen format)
    # Query: concatenate q_nope and q_pe, add seq_len dim
    query = torch.cat([inputs["q_nope"], inputs["q_pe"]], dim=-1).unsqueeze(1)
    # KV cache: concatenate ckv and kpe caches
    kv_cache = torch.cat([inputs["ckv_cache"], inputs["kpe_cache"]], dim=-1)
    # Block tables: add seq_len dim to sparse_indices
    block_tables = inputs["sparse_indices"].unsqueeze(1)
    workspace = torch.zeros(16 * 1024 * 1024, dtype=torch.uint8, device=device)
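    # 16 MiB scratch workspace for the trtllm-gen decode kernel.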
    total_tokens = inputs["num_pages"] * PAGE_SIZE
    seq_lens = torch.full((num_tokens,), total_tokens, dtype=torch.int32, device=device)
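    # Every query advertises the full cache length; the top-k indices in
    # block_tables are what actually restrict attention to TOPK tokens.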

print("Running FlashInfer...")
fi_output = flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla(
query=query,
kv_cache=kv_cache,
workspace_buffer=workspace,
qk_nope_head_dim=128, # QK_NOPE_HEAD_DIM
kv_lora_rank=HEAD_DIM_CKV,
qk_rope_head_dim=HEAD_DIM_KPE,
block_tables=block_tables,
seq_lens=seq_lens,
max_seq_len=total_tokens,
sparse_mla_top_k=topk,
bmm1_scale=inputs["sm_scale"].item(),
)
fi_output = fi_output.squeeze(1) # Remove seq_len dim

print("\nComparing outputs...")
output_metrics = compare_tensors(ref_o, fi_output, atol=atol, rtol=rtol)
print_comparison_metrics(output_metrics, tensor_name="Output tensor")

all_close = output_metrics.all_close

if all_close:
print(f"\n✓ PASSED: Outputs match within tolerance (atol={atol}, rtol={rtol})")
else:
print(f"\n✗ FAILED: Outputs differ beyond tolerance")

return all_close


def test_output_shape(num_tokens=64, topk=TOPK):
    """Test that reference produces correct output shapes."""
    device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -136,8 +135,7 @@ def test_output_shape(num_tokens=64, topk=TOPK):
inputs["sm_scale"],
)

output = result["output"]
lse = result["lse"]
output, lse = result

assert output.shape == (num_tokens, NUM_QO_HEADS, HEAD_DIM_CKV)
assert lse.shape == (num_tokens, NUM_QO_HEADS)
@@ -167,8 +165,7 @@ def test_padding_handling(num_tokens=64, topk=TOPK):
    )

    result = run(q_nope, q_pe, ckv_cache, kpe_cache, sparse_indices, sm_scale)
    output = result["output"]
    lse = result["lse"]
    output, lse = result

    # Verify outputs are valid
    assert not torch.isnan(output).any()
@@ -177,12 +174,16 @@


if __name__ == "__main__":
    print("Testing DSA Sparse Attention Reference (page_size=1)")
    print("Testing DSA Sparse Attention Reference (from definition)")
    print(
        f"Constants: h={NUM_QO_HEADS}, ckv={HEAD_DIM_CKV}, kpe={HEAD_DIM_KPE}, ps={PAGE_SIZE}, topk={TOPK}"
    )
    print("=" * 70)

    test_configs = [(16, TOPK), (32, TOPK), (64, TOPK), (128, TOPK)]
    passed = sum(1 for cfg in test_configs if test_correctness(*cfg))
    print(f"\n{'='*60}\nCorrectness: {passed}/{len(test_configs)} tests passed\n{'='*60}")

    test_output_shape()
    print("test_output_shape: PASSED")
