diff --git a/flashinfer_trace/tests/references/test_dsa_sparse_attention_h16_ckv512_kpe64_topk256_ps1.py b/flashinfer_trace/tests/references/test_dsa_sparse_attention_h16_ckv512_kpe64_topk256_ps1.py index 773f4911..eca3273d 100644 --- a/flashinfer_trace/tests/references/test_dsa_sparse_attention_h16_ckv512_kpe64_topk256_ps1.py +++ b/flashinfer_trace/tests/references/test_dsa_sparse_attention_h16_ckv512_kpe64_topk256_ps1.py @@ -1,16 +1,21 @@ """ -Tests for DSA (DeepSeek Sparse Attention) sparse attention reference implementation. +Test DSA (DeepSeek Sparse Attention) sparse attention reference implementation. + +This test validates that the reference implementation from the definition +produces correct output shapes and handles padding correctly. Ground truth comparison tests are in test_dsa_vs_definition_reference.py which tests against FlashInfer's trtllm_batch_decode_with_kv_cache_mla. """ -import math -from pathlib import Path - +import flashinfer import numpy as np import pytest import torch +from test_utils import compare_tensors, get_reference_run, print_comparison_metrics + +# Load reference implementation from definition +run = get_reference_run("dsa_sparse_attention_h16_ckv512_kpe64_topk256_ps1") # Module-level constants (DeepSeek V3/R1 with TP=8) NUM_QO_HEADS = 16 @@ -19,97 +24,27 @@ PAGE_SIZE = 1 TOPK = 256 -TRACE_ROOT = Path(__file__).resolve().parents[2] - - -@torch.no_grad() -def run(q_nope, q_pe, ckv_cache, kpe_cache, sparse_indices, sm_scale): - """Reference implementation for DSA sparse attention.""" - num_tokens, num_qo_heads, head_dim_ckv = q_nope.shape - head_dim_kpe = q_pe.shape[-1] - page_size = ckv_cache.shape[1] - topk = sparse_indices.shape[-1] - - # Check constants - assert num_qo_heads == NUM_QO_HEADS - assert head_dim_ckv == HEAD_DIM_CKV - assert head_dim_kpe == HEAD_DIM_KPE - assert page_size == PAGE_SIZE - assert topk == TOPK - - # Check constraints - assert sparse_indices.shape[0] == num_tokens - assert sparse_indices.shape[-1] == topk - assert ckv_cache.shape[1] == page_size - - device = q_nope.device - - # Squeeze page dimension (page_size=1) - Kc_all = ckv_cache.squeeze(1).to(torch.float32) # [num_pages, head_dim_ckv] - Kp_all = kpe_cache.squeeze(1).to(torch.float32) # [num_pages, head_dim_kpe] - - output = torch.zeros( - (num_tokens, num_qo_heads, head_dim_ckv), dtype=torch.bfloat16, device=device - ) - lse = torch.full((num_tokens, num_qo_heads), -float("inf"), dtype=torch.float32, device=device) - - for t in range(num_tokens): - indices = sparse_indices[t] # [topk] - - # Handle padding: -1 indicates invalid indices - valid_mask = indices != -1 - valid_indices = indices[valid_mask] - - if valid_indices.numel() == 0: - output[t].zero_() - continue - - tok_idx = valid_indices.to(torch.long) - - Kc = Kc_all[tok_idx] # [num_valid, head_dim_ckv] - Kp = Kp_all[tok_idx] # [num_valid, head_dim_kpe] - qn = q_nope[t].to(torch.float32) # [num_qo_heads, head_dim_ckv] - qp = q_pe[t].to(torch.float32) # [num_qo_heads, head_dim_kpe] - - # Compute attention logits - logits = (qn @ Kc.T) + (qp @ Kp.T) # [num_qo_heads, num_valid] - logits_scaled = logits * sm_scale - # Compute 2-base LSE - lse[t] = torch.logsumexp(logits_scaled, dim=-1) / math.log(2.0) - - # Compute attention output - attn = torch.softmax(logits_scaled, dim=-1) # [num_qo_heads, num_valid] - out = attn @ Kc # [num_qo_heads, head_dim_ckv] - output[t] = out.to(torch.bfloat16) - - return {"output": output, "lse": lse} - - -def generate_random_inputs( - num_tokens, - num_qo_heads=NUM_QO_HEADS, - 
head_dim_ckv=HEAD_DIM_CKV, - head_dim_kpe=HEAD_DIM_KPE, - topk=TOPK, - device="cuda", -): +def generate_random_inputs(num_tokens, topk=TOPK, device="cuda"): """Generate random inputs for DSA sparse attention testing.""" - num_pages = max(num_tokens * 2, 1024) + total_kv_tokens = max(num_tokens * 4, 2048) + num_pages = (total_kv_tokens + PAGE_SIZE - 1) // PAGE_SIZE + + total_tokens_in_cache = num_pages * PAGE_SIZE sparse_indices = torch.randint( - 0, num_pages, (num_tokens, topk), dtype=torch.int32, device=device + 0, total_tokens_in_cache, (num_tokens, topk), dtype=torch.int32, device=device ) q_nope = torch.randn( - num_tokens, num_qo_heads, head_dim_ckv, dtype=torch.bfloat16, device=device + num_tokens, NUM_QO_HEADS, HEAD_DIM_CKV, dtype=torch.bfloat16, device=device ) - q_pe = torch.randn(num_tokens, num_qo_heads, head_dim_kpe, dtype=torch.bfloat16, device=device) + q_pe = torch.randn(num_tokens, NUM_QO_HEADS, HEAD_DIM_KPE, dtype=torch.bfloat16, device=device) - ckv_cache = torch.randn(num_pages, 1, head_dim_ckv, dtype=torch.bfloat16, device=device) - kpe_cache = torch.randn(num_pages, 1, head_dim_kpe, dtype=torch.bfloat16, device=device) + ckv_cache = torch.randn(num_pages, 1, HEAD_DIM_CKV, dtype=torch.bfloat16, device=device) + kpe_cache = torch.randn(num_pages, 1, HEAD_DIM_KPE, dtype=torch.bfloat16, device=device) - sm_scale = 1.0 / np.sqrt(128 + head_dim_kpe) + sm_scale = 1.0 / np.sqrt(128 + HEAD_DIM_KPE) return { "q_nope": q_nope, @@ -122,6 +57,70 @@ def generate_random_inputs( } +def test_correctness(num_tokens=64, topk=TOPK, atol=1e-2, rtol=5e-2): + """Test correctness of DSA sparse attention reference implementation against FlashInfer.""" + print(f"\n{'='*60}") + print(f"Testing DSA Sparse Attention num_tokens={num_tokens}, topk={topk}") + print(f"{'='*60}") + + device = "cuda" if torch.cuda.is_available() else "cpu" + if device == "cpu": + print("WARNING: CUDA not available, skipping test") + return True + + inputs = generate_random_inputs(num_tokens, topk=topk, device=device) + + print("Running reference implementation from definition...") + ref_o, ref_lse = run( + inputs["q_nope"], + inputs["q_pe"], + inputs["ckv_cache"], + inputs["kpe_cache"], + inputs["sparse_indices"], + inputs["sm_scale"], + ) + + # Prepare FlashInfer inputs (trtllm-gen format) + # Query: concatenate q_nope and q_pe, add seq_len dim + query = torch.cat([inputs["q_nope"], inputs["q_pe"]], dim=-1).unsqueeze(1) + # KV cache: concatenate ckv and kpe caches + kv_cache = torch.cat([inputs["ckv_cache"], inputs["kpe_cache"]], dim=-1) + # Block tables: add seq_len dim to sparse_indices + block_tables = inputs["sparse_indices"].unsqueeze(1) + workspace = torch.zeros(16 * 1024 * 1024, dtype=torch.uint8, device=device) + total_tokens = inputs["num_pages"] * PAGE_SIZE + seq_lens = torch.full((num_tokens,), total_tokens, dtype=torch.int32, device=device) + + print("Running FlashInfer...") + fi_output = flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla( + query=query, + kv_cache=kv_cache, + workspace_buffer=workspace, + qk_nope_head_dim=128, # QK_NOPE_HEAD_DIM + kv_lora_rank=HEAD_DIM_CKV, + qk_rope_head_dim=HEAD_DIM_KPE, + block_tables=block_tables, + seq_lens=seq_lens, + max_seq_len=total_tokens, + sparse_mla_top_k=topk, + bmm1_scale=inputs["sm_scale"].item(), + ) + fi_output = fi_output.squeeze(1) # Remove seq_len dim + + print("\nComparing outputs...") + output_metrics = compare_tensors(ref_o, fi_output, atol=atol, rtol=rtol) + print_comparison_metrics(output_metrics, tensor_name="Output tensor") + + all_close = 
output_metrics.all_close + + if all_close: + print(f"\n✓ PASSED: Outputs match within tolerance (atol={atol}, rtol={rtol})") + else: + print(f"\n✗ FAILED: Outputs differ beyond tolerance") + + return all_close + + def test_output_shape(num_tokens=64, topk=TOPK): """Test that reference produces correct output shapes.""" device = "cuda" if torch.cuda.is_available() else "cpu" @@ -136,8 +135,7 @@ def test_output_shape(num_tokens=64, topk=TOPK): inputs["sm_scale"], ) - output = result["output"] - lse = result["lse"] + output, lse = result assert output.shape == (num_tokens, NUM_QO_HEADS, HEAD_DIM_CKV) assert lse.shape == (num_tokens, NUM_QO_HEADS) @@ -167,8 +165,7 @@ def test_padding_handling(num_tokens=64, topk=TOPK): ) result = run(q_nope, q_pe, ckv_cache, kpe_cache, sparse_indices, sm_scale) - output = result["output"] - lse = result["lse"] + output, lse = result # Verify outputs are valid assert not torch.isnan(output).any() @@ -177,12 +174,16 @@ def test_padding_handling(num_tokens=64, topk=TOPK): if __name__ == "__main__": - print("Testing DSA Sparse Attention Reference (page_size=1)") + print("Testing DSA Sparse Attention Reference (from definition)") print( f"Constants: h={NUM_QO_HEADS}, ckv={HEAD_DIM_CKV}, kpe={HEAD_DIM_KPE}, ps={PAGE_SIZE}, topk={TOPK}" ) print("=" * 70) + test_configs = [(16, TOPK), (32, TOPK), (64, TOPK), (128, TOPK)] + passed = sum(1 for cfg in test_configs if test_correctness(*cfg)) + print(f"\n{'='*60}\nCorrectness: {passed}/{len(test_configs)} tests passed\n{'='*60}") + test_output_shape() print("test_output_shape: PASSED") diff --git a/flashinfer_trace/tests/references/test_dsa_sparse_attention_h16_ckv512_kpe64_topk256_ps64.py b/flashinfer_trace/tests/references/test_dsa_sparse_attention_h16_ckv512_kpe64_topk256_ps64.py index 685a6916..1d935aeb 100644 --- a/flashinfer_trace/tests/references/test_dsa_sparse_attention_h16_ckv512_kpe64_topk256_ps64.py +++ b/flashinfer_trace/tests/references/test_dsa_sparse_attention_h16_ckv512_kpe64_topk256_ps64.py @@ -1,17 +1,18 @@ """ -Tests for DSA (DeepSeek Sparse Attention) sparse attention reference implementation. -Page size 64 variant. +Test DSA sparse attention h16_ckv512_kpe64_topk256_ps64 reference implementation. -Ground truth comparison tests are in test_dsa_vs_definition_reference.py -which tests against FlashInfer's trtllm_batch_decode_with_kv_cache_mla. +This test validates that the reference implementation from the definition +produces correct output shapes and handles padding correctly. 
""" -import math -from pathlib import Path - +import flashinfer import numpy as np import pytest import torch +from test_utils import compare_tensors, get_reference_run, print_comparison_metrics + +# Load reference implementation from definition +run = get_reference_run("dsa_sparse_attention_h16_ckv512_kpe64_topk256_ps64") # Module-level constants (DeepSeek V3/R1 with TP=8) NUM_QO_HEADS = 16 @@ -20,95 +21,24 @@ PAGE_SIZE = 64 TOPK = 256 -TRACE_ROOT = Path(__file__).resolve().parents[2] - - -@torch.no_grad() -def run(q_nope, q_pe, ckv_cache, kpe_cache, sparse_indices, sm_scale): - """Reference implementation for DSA sparse attention with page_size=64.""" - num_tokens, num_qo_heads, head_dim_ckv = q_nope.shape - head_dim_kpe = q_pe.shape[-1] - num_pages, page_size, _ = ckv_cache.shape - topk = sparse_indices.shape[-1] - - # Check constants - assert num_qo_heads == NUM_QO_HEADS - assert head_dim_ckv == HEAD_DIM_CKV - assert head_dim_kpe == HEAD_DIM_KPE - assert page_size == PAGE_SIZE - assert topk == TOPK - - # Check constraints - assert sparse_indices.shape[0] == num_tokens - assert sparse_indices.shape[-1] == topk - assert ckv_cache.shape[1] == page_size - - device = q_nope.device - - # Flatten paged KV cache to token-level - Kc_all = ckv_cache.reshape(-1, head_dim_ckv).to(torch.float32) - Kp_all = kpe_cache.reshape(-1, head_dim_kpe).to(torch.float32) - - output = torch.zeros( - (num_tokens, num_qo_heads, head_dim_ckv), dtype=torch.bfloat16, device=device - ) - lse = torch.full((num_tokens, num_qo_heads), -float("inf"), dtype=torch.float32, device=device) - - for t in range(num_tokens): - indices = sparse_indices[t] - - valid_mask = indices != -1 - valid_indices = indices[valid_mask] - - if valid_indices.numel() == 0: - output[t].zero_() - continue - - tok_idx = valid_indices.to(torch.long) - - Kc = Kc_all[tok_idx] - Kp = Kp_all[tok_idx] - qn = q_nope[t].to(torch.float32) - qp = q_pe[t].to(torch.float32) - - logits = (qn @ Kc.T) + (qp @ Kp.T) - logits_scaled = logits * sm_scale - - lse[t] = torch.logsumexp(logits_scaled, dim=-1) / math.log(2.0) - attn = torch.softmax(logits_scaled, dim=-1) - out = attn @ Kc - output[t] = out.to(torch.bfloat16) +def generate_random_inputs(num_tokens, topk=TOPK, device="cuda"): + """Generate random inputs for DSA sparse attention testing.""" + num_pages = max(num_tokens * 2, 1024) - return {"output": output, "lse": lse} - - -def generate_random_inputs( - num_tokens, - num_qo_heads=NUM_QO_HEADS, - head_dim_ckv=HEAD_DIM_CKV, - head_dim_kpe=HEAD_DIM_KPE, - topk=TOPK, - device="cuda", -): - """Generate random inputs for DSA sparse attention testing with page_size=64.""" - total_kv_tokens = max(num_tokens * 4, 2048) - num_pages = (total_kv_tokens + PAGE_SIZE - 1) // PAGE_SIZE - - total_tokens_in_cache = num_pages * PAGE_SIZE sparse_indices = torch.randint( - 0, total_tokens_in_cache, (num_tokens, topk), dtype=torch.int32, device=device + 0, num_pages, (num_tokens, topk), dtype=torch.int32, device=device ) q_nope = torch.randn( - num_tokens, num_qo_heads, head_dim_ckv, dtype=torch.bfloat16, device=device + num_tokens, NUM_QO_HEADS, HEAD_DIM_CKV, dtype=torch.bfloat16, device=device ) - q_pe = torch.randn(num_tokens, num_qo_heads, head_dim_kpe, dtype=torch.bfloat16, device=device) + q_pe = torch.randn(num_tokens, NUM_QO_HEADS, HEAD_DIM_KPE, dtype=torch.bfloat16, device=device) - ckv_cache = torch.randn(num_pages, PAGE_SIZE, head_dim_ckv, dtype=torch.bfloat16, device=device) - kpe_cache = torch.randn(num_pages, PAGE_SIZE, head_dim_kpe, dtype=torch.bfloat16, 
device=device) + ckv_cache = torch.randn(num_pages, PAGE_SIZE, HEAD_DIM_CKV, dtype=torch.bfloat16, device=device) + kpe_cache = torch.randn(num_pages, PAGE_SIZE, HEAD_DIM_KPE, dtype=torch.bfloat16, device=device) - sm_scale = 1.0 / np.sqrt(128 + head_dim_kpe) + sm_scale = 1.0 / np.sqrt(128 + HEAD_DIM_KPE) return { "q_nope": q_nope, @@ -121,6 +51,70 @@ def generate_random_inputs( } +def test_correctness(num_tokens=64, topk=TOPK, atol=1e-2, rtol=5e-2): + """Test correctness of DSA sparse attention reference implementation against FlashInfer.""" + print(f"\n{'='*60}") + print(f"Testing DSA Sparse Attention (ps64) num_tokens={num_tokens}, topk={topk}") + print(f"{'='*60}") + + device = "cuda" if torch.cuda.is_available() else "cpu" + if device == "cpu": + print("WARNING: CUDA not available, skipping test") + return True + + inputs = generate_random_inputs(num_tokens, topk=topk, device=device) + + print("Running reference implementation from definition...") + ref_o, ref_lse = run( + inputs["q_nope"], + inputs["q_pe"], + inputs["ckv_cache"], + inputs["kpe_cache"], + inputs["sparse_indices"], + inputs["sm_scale"], + ) + + # Prepare FlashInfer inputs (trtllm-gen format) + # Query: concatenate q_nope and q_pe, add seq_len dim + query = torch.cat([inputs["q_nope"], inputs["q_pe"]], dim=-1).unsqueeze(1) + # KV cache: concatenate ckv and kpe caches + kv_cache = torch.cat([inputs["ckv_cache"], inputs["kpe_cache"]], dim=-1) + # Block tables: add seq_len dim to sparse_indices + block_tables = inputs["sparse_indices"].unsqueeze(1) + workspace = torch.zeros(16 * 1024 * 1024, dtype=torch.uint8, device=device) + total_tokens = inputs["num_pages"] * PAGE_SIZE + seq_lens = torch.full((num_tokens,), total_tokens, dtype=torch.int32, device=device) + + print("Running FlashInfer...") + fi_output = flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla( + query=query, + kv_cache=kv_cache, + workspace_buffer=workspace, + qk_nope_head_dim=128, # QK_NOPE_HEAD_DIM + kv_lora_rank=HEAD_DIM_CKV, + qk_rope_head_dim=HEAD_DIM_KPE, + block_tables=block_tables, + seq_lens=seq_lens, + max_seq_len=total_tokens, + sparse_mla_top_k=topk, + bmm1_scale=inputs["sm_scale"].item(), + ) + fi_output = fi_output.squeeze(1) # Remove seq_len dim + + print("\nComparing outputs...") + output_metrics = compare_tensors(ref_o, fi_output, atol=atol, rtol=rtol) + print_comparison_metrics(output_metrics, tensor_name="Output tensor") + + all_close = output_metrics.all_close + + if all_close: + print(f"\n✓ PASSED: Outputs match within tolerance (atol={atol}, rtol={rtol})") + else: + print(f"\n✗ FAILED: Outputs differ beyond tolerance") + + return all_close + + def test_output_shape(num_tokens=64, topk=TOPK): """Test that reference produces correct output shapes.""" device = "cuda" if torch.cuda.is_available() else "cpu" @@ -135,8 +129,7 @@ def test_output_shape(num_tokens=64, topk=TOPK): inputs["sm_scale"], ) - output = result["output"] - lse = result["lse"] + output, lse = result assert output.shape == (num_tokens, NUM_QO_HEADS, HEAD_DIM_CKV) assert lse.shape == (num_tokens, NUM_QO_HEADS) @@ -145,7 +138,7 @@ def test_output_shape(num_tokens=64, topk=TOPK): def test_padding_handling(num_tokens=64, topk=TOPK): """Test that padding (-1 indices) are handled correctly.""" device = "cuda" if torch.cuda.is_available() else "cpu" - num_pages = 64 + num_pages = 1000 q_nope = torch.randn( num_tokens, NUM_QO_HEADS, HEAD_DIM_CKV, dtype=torch.bfloat16, device=device @@ -155,7 +148,9 @@ def test_padding_handling(num_tokens=64, topk=TOPK): kpe_cache = 
torch.randn(num_pages, PAGE_SIZE, HEAD_DIM_KPE, dtype=torch.bfloat16, device=device) sm_scale = torch.tensor(1.0 / np.sqrt(128 + HEAD_DIM_KPE), dtype=torch.float32, device=device) + # Create sparse indices with varying amounts of padding per token sparse_indices = torch.full((num_tokens, topk), -1, dtype=torch.int32, device=device) + total_tokens_in_cache = num_pages * PAGE_SIZE for t in range(num_tokens): @@ -166,21 +161,25 @@ def test_padding_handling(num_tokens=64, topk=TOPK): ) result = run(q_nope, q_pe, ckv_cache, kpe_cache, sparse_indices, sm_scale) - output = result["output"] - lse = result["lse"] + output, lse = result + # Verify outputs are valid assert not torch.isnan(output).any() assert not torch.isinf(output).any() assert not torch.isnan(lse).any() if __name__ == "__main__": - print("Testing DSA Sparse Attention Reference (page_size=64)") + print("Testing DSA Sparse Attention Reference (page_size=64, from definition)") print( f"Constants: h={NUM_QO_HEADS}, ckv={HEAD_DIM_CKV}, kpe={HEAD_DIM_KPE}, ps={PAGE_SIZE}, topk={TOPK}" ) print("=" * 70) + test_configs = [(16, TOPK), (32, TOPK), (64, TOPK), (128, TOPK)] + passed = sum(1 for cfg in test_configs if test_correctness(*cfg)) + print(f"\n{'='*60}\nCorrectness: {passed}/{len(test_configs)} tests passed\n{'='*60}") + test_output_shape() print("test_output_shape: PASSED") diff --git a/flashinfer_trace/tests/references/test_gdn_decode_qk16_v32_d128_k_last.py b/flashinfer_trace/tests/references/test_gdn_decode_qk16_v32_d128_k_last.py index b0abef36..0bc6ad81 100644 --- a/flashinfer_trace/tests/references/test_gdn_decode_qk16_v32_d128_k_last.py +++ b/flashinfer_trace/tests/references/test_gdn_decode_qk16_v32_d128_k_last.py @@ -7,35 +7,13 @@ """ import math -from pathlib import Path import pytest import torch import torch.nn.functional as F from flashinfer.gdn_decode import gated_delta_rule_decode_pretranspose from flashinfer.utils import get_compute_capability - -from flashinfer_bench.data import Definition, load_json_file - -# Paths -DEFINITIONS_DIR = Path(__file__).parent.parent.parent / "definitions" - - -def load_definition(name: str) -> Definition: - """Load a definition by name from definitions directory.""" - for op_dir in DEFINITIONS_DIR.iterdir(): - if op_dir.is_dir(): - def_file = op_dir / f"{name}.json" - if def_file.exists(): - return load_json_file(Definition, def_file) - raise FileNotFoundError(f"Definition {name} not found in {DEFINITIONS_DIR}") - - -def compile_reference(reference_code: str): - """Compile reference implementation to callable function.""" - namespace = {"torch": torch, "math": math, "F": F} - exec(reference_code, namespace) - return namespace["run"] +from test_utils import compare_tensors, get_reference_run, print_comparison_metrics def _skip_if_not_sm90_or_later(): @@ -48,10 +26,9 @@ def _skip_if_not_sm90_or_later(): def run_kernel(q, k, v, state, A_log, a, dt_bias, b, scale): """Run FlashInfer kernel (pretranspose version uses k-last layout).""" B, T, num_q_heads, K = q.shape - num_v_heads = v.shape[2] # Pre-allocate output - output = torch.empty(B, T, num_v_heads, K, dtype=q.dtype, device=q.device) + output = torch.empty(B, T, v.shape[2], K, dtype=q.dtype, device=q.device) # Call kernel out, new_state = gated_delta_rule_decode_pretranspose( @@ -130,9 +107,8 @@ def test_correctness(batch_size=4, atol=5e-3, rtol=5e-3): print(f"Testing GDN decode k-last, batch_size={batch_size}") print(f"{'='*60}") - # Load definition and compile reference - definition = 
load_definition("gdn_decode_qk16_v32_d128_k_last") - run = compile_reference(definition.reference) + # Load reference from definition + run = get_reference_run("gdn_decode_qk16_v32_d128_k_last") device = "cuda" inputs = generate_random_inputs(batch_size=batch_size, device=device) @@ -166,70 +142,15 @@ def test_correctness(batch_size=4, atol=5e-3, rtol=5e-3): inputs["scale"], ) - # Compare outputs + # Compare outputs using test_utils print("\nComparing outputs...") + output_metrics = compare_tensors(ref_output, kernel_output, atol=atol, rtol=rtol) + print_comparison_metrics(output_metrics, tensor_name="Output tensor") + + state_metrics = compare_tensors(ref_new_state, kernel_new_state, atol=atol, rtol=rtol) + print_comparison_metrics(state_metrics, tensor_name="State tensor") - ref_o_f32 = ref_output.float() - kernel_o_f32 = kernel_output.float() - - # Absolute difference metrics - abs_diff_o = torch.abs(ref_o_f32 - kernel_o_f32) - max_abs_diff_o = abs_diff_o.max().item() - mean_abs_diff_o = abs_diff_o.mean().item() - - # Relative difference metrics (avoid division by zero) - rel_diff_o = abs_diff_o / (torch.abs(ref_o_f32) + 1e-10) - max_rel_diff_o = rel_diff_o.max().item() - mean_rel_diff_o = rel_diff_o.mean().item() - - # Cosine similarity - ref_flat = ref_o_f32.reshape(-1) - kernel_flat = kernel_o_f32.reshape(-1) - cosine_sim_o = F.cosine_similarity(ref_flat.unsqueeze(0), kernel_flat.unsqueeze(0)).item() - - # Mean Squared Error - mse_o = ((ref_o_f32 - kernel_o_f32) ** 2).mean().item() - - print("\nOutput tensor comparison:") - print(f" Max absolute difference: {max_abs_diff_o:.6e}") - print(f" Max relative difference: {max_rel_diff_o:.6e}") - print(f" Mean absolute difference: {mean_abs_diff_o:.6e}") - print(f" Mean relative difference: {mean_rel_diff_o:.6e}") - print(f" Cosine similarity: {cosine_sim_o:.6f}") - print(f" MSE: {mse_o:.6e}") - - # State comparison - abs_diff_s = torch.abs(ref_new_state - kernel_new_state) - max_abs_diff_s = abs_diff_s.max().item() - mean_abs_diff_s = abs_diff_s.mean().item() - - # State relative difference - rel_diff_s = abs_diff_s / (torch.abs(ref_new_state) + 1e-10) - max_rel_diff_s = rel_diff_s.max().item() - mean_rel_diff_s = rel_diff_s.mean().item() - - # State cosine similarity - ref_state_flat = ref_new_state.reshape(-1) - kernel_state_flat = kernel_new_state.reshape(-1) - cosine_sim_s = F.cosine_similarity( - ref_state_flat.unsqueeze(0), kernel_state_flat.unsqueeze(0) - ).item() - - # State MSE - mse_s = ((ref_new_state - kernel_new_state) ** 2).mean().item() - - print("\nState tensor comparison:") - print(f" Max absolute difference: {max_abs_diff_s:.6e}") - print(f" Max relative difference: {max_rel_diff_s:.6e}") - print(f" Mean absolute difference: {mean_abs_diff_s:.6e}") - print(f" Mean relative difference: {mean_rel_diff_s:.6e}") - print(f" Cosine similarity: {cosine_sim_s:.6f}") - print(f" MSE: {mse_s:.6e}") - - output_close = torch.allclose(ref_o_f32, kernel_o_f32, atol=atol, rtol=rtol) - state_close = torch.allclose(ref_new_state, kernel_new_state, atol=atol, rtol=rtol) - - if output_close and state_close: + if output_metrics.all_close and state_metrics.all_close: print(f"\n✓ PASSED (atol={atol}, rtol={rtol})") return True else: @@ -242,9 +163,8 @@ def test_gdn_decode_k_last(batch_size: int): """Pytest parametrized test for various batch sizes.""" _skip_if_not_sm90_or_later() - # Load definition and compile reference - definition = load_definition("gdn_decode_qk16_v32_d128_k_last") - run = compile_reference(definition.reference) + # Load 
reference from definition + run = get_reference_run("gdn_decode_qk16_v32_d128_k_last") device = "cuda" inputs = generate_random_inputs(batch_size=batch_size, device=device) diff --git a/flashinfer_trace/tests/references/test_gdn_prefill_qk16_v32_d128_k_last.py b/flashinfer_trace/tests/references/test_gdn_prefill_qk16_v32_d128_k_last.py index 19267302..00078378 100644 --- a/flashinfer_trace/tests/references/test_gdn_prefill_qk16_v32_d128_k_last.py +++ b/flashinfer_trace/tests/references/test_gdn_prefill_qk16_v32_d128_k_last.py @@ -8,33 +8,11 @@ import math import sys -from pathlib import Path import pytest import torch import torch.nn.functional as F - -from flashinfer_bench.data import Definition, load_json_file - -# Paths -DEFINITIONS_DIR = Path(__file__).parent.parent.parent / "definitions" - - -def load_definition(name: str) -> Definition: - """Load a definition by name from definitions directory.""" - for op_dir in DEFINITIONS_DIR.iterdir(): - if op_dir.is_dir(): - def_file = op_dir / f"{name}.json" - if def_file.exists(): - return load_json_file(Definition, def_file) - raise FileNotFoundError(f"Definition {name} not found in {DEFINITIONS_DIR}") - - -def compile_reference(reference_code: str): - """Compile reference implementation to callable function.""" - namespace = {"torch": torch, "math": math, "F": F} - exec(reference_code, namespace) - return namespace["run"] +from test_utils import compare_tensors, get_reference_run, print_comparison_metrics def get_cuda_capability(): @@ -66,9 +44,8 @@ def compute_gates(A_log, a, dt_bias, b): return g, beta -# Load definition and compile reference -definition = load_definition("gdn_prefill_qk16_v32_d128_k_last") -reference_gdn_prefill = compile_reference(definition.reference) +# Load reference from definition +reference_gdn_prefill = get_reference_run("gdn_prefill_qk16_v32_d128_k_last") @requires_cuda @@ -125,64 +102,22 @@ def test_gdn_prefill_correctness(batch_size: int, seq_len: int): cu_seqlens=cu_seqlens, ) - # Output comparison metrics - ref_o_f32 = ref_output.float() - fi_o_f32 = fi_output.float() - - abs_diff_o = torch.abs(ref_o_f32 - fi_o_f32) - max_abs_diff_o = abs_diff_o.max().item() - mean_abs_diff_o = abs_diff_o.mean().item() - - rel_diff_o = abs_diff_o / (torch.abs(ref_o_f32) + 1e-10) - max_rel_diff_o = rel_diff_o.max().item() - mean_rel_diff_o = rel_diff_o.mean().item() - - ref_flat = ref_o_f32.reshape(-1) - fi_flat = fi_o_f32.reshape(-1) - cosine_sim_o = F.cosine_similarity(ref_flat.unsqueeze(0), fi_flat.unsqueeze(0)).item() - - mse_o = ((ref_o_f32 - fi_o_f32) ** 2).mean().item() - - # State comparison metrics - abs_diff_s = torch.abs(ref_new_state - fi_new_state) - max_abs_diff_s = abs_diff_s.max().item() - mean_abs_diff_s = abs_diff_s.mean().item() - - rel_diff_s = abs_diff_s / (torch.abs(ref_new_state) + 1e-10) - max_rel_diff_s = rel_diff_s.max().item() - mean_rel_diff_s = rel_diff_s.mean().item() - - ref_state_flat = ref_new_state.reshape(-1) - fi_state_flat = fi_new_state.reshape(-1) - cosine_sim_s = F.cosine_similarity( - ref_state_flat.unsqueeze(0), fi_state_flat.unsqueeze(0) - ).item() + # Compare using test_utils + atol = 0.1 + print(f"\nBatch={batch_size}, SeqLen={seq_len}") - mse_s = ((ref_new_state - fi_new_state) ** 2).mean().item() + output_metrics = compare_tensors(ref_output, fi_output, atol=atol, rtol=atol) + print_comparison_metrics(output_metrics, tensor_name="Output tensor") - print(f"\nBatch={batch_size}, SeqLen={seq_len}") - print("\nOutput tensor comparison:") - print(f" Max absolute difference: 
{max_abs_diff_o:.6e}") - print(f" Max relative difference: {max_rel_diff_o:.6e}") - print(f" Mean absolute difference: {mean_abs_diff_o:.6e}") - print(f" Mean relative difference: {mean_rel_diff_o:.6e}") - print(f" Cosine similarity: {cosine_sim_o:.6f}") - print(f" MSE: {mse_o:.6e}") - - print("\nState tensor comparison:") - print(f" Max absolute difference: {max_abs_diff_s:.6e}") - print(f" Max relative difference: {max_rel_diff_s:.6e}") - print(f" Mean absolute difference: {mean_abs_diff_s:.6e}") - print(f" Mean relative difference: {mean_rel_diff_s:.6e}") - print(f" Cosine similarity: {cosine_sim_s:.6f}") - print(f" MSE: {mse_s:.6e}") - - output_max_err = max_abs_diff_o - state_max_err = max_abs_diff_s + state_metrics = compare_tensors(ref_new_state, fi_new_state, atol=atol, rtol=atol) + print_comparison_metrics(state_metrics, tensor_name="State tensor") - atol = 0.1 - assert output_max_err < atol, f"Output max error {output_max_err} exceeds tolerance" - assert state_max_err < atol, f"State max error {state_max_err} exceeds tolerance" + assert ( + output_metrics.max_abs_diff < atol + ), f"Output max error {output_metrics.max_abs_diff} exceeds tolerance" + assert ( + state_metrics.max_abs_diff < atol + ), f"State max error {state_metrics.max_abs_diff} exceeds tolerance" @requires_cuda @@ -245,64 +180,22 @@ def test_gdn_prefill_with_initial_state(): cu_seqlens=cu_seqlens, ) - # Output comparison metrics - ref_o_f32 = ref_output.float() - fi_o_f32 = fi_output.float() - - abs_diff_o = torch.abs(ref_o_f32 - fi_o_f32) - max_abs_diff_o = abs_diff_o.max().item() - mean_abs_diff_o = abs_diff_o.mean().item() - - rel_diff_o = abs_diff_o / (torch.abs(ref_o_f32) + 1e-10) - max_rel_diff_o = rel_diff_o.max().item() - mean_rel_diff_o = rel_diff_o.mean().item() - - ref_flat = ref_o_f32.reshape(-1) - fi_flat = fi_o_f32.reshape(-1) - cosine_sim_o = F.cosine_similarity(ref_flat.unsqueeze(0), fi_flat.unsqueeze(0)).item() - - mse_o = ((ref_o_f32 - fi_o_f32) ** 2).mean().item() - - # State comparison metrics - abs_diff_s = torch.abs(ref_new_state - fi_new_state) - max_abs_diff_s = abs_diff_s.max().item() - mean_abs_diff_s = abs_diff_s.mean().item() - - rel_diff_s = abs_diff_s / (torch.abs(ref_new_state) + 1e-10) - max_rel_diff_s = rel_diff_s.max().item() - mean_rel_diff_s = rel_diff_s.mean().item() - - ref_state_flat = ref_new_state.reshape(-1) - fi_state_flat = fi_new_state.reshape(-1) - cosine_sim_s = F.cosine_similarity( - ref_state_flat.unsqueeze(0), fi_state_flat.unsqueeze(0) - ).item() - - mse_s = ((ref_new_state - fi_new_state) ** 2).mean().item() - - print(f"\nWith initial state:") - print("\nOutput tensor comparison:") - print(f" Max absolute difference: {max_abs_diff_o:.6e}") - print(f" Max relative difference: {max_rel_diff_o:.6e}") - print(f" Mean absolute difference: {mean_abs_diff_o:.6e}") - print(f" Mean relative difference: {mean_rel_diff_o:.6e}") - print(f" Cosine similarity: {cosine_sim_o:.6f}") - print(f" MSE: {mse_o:.6e}") - - print("\nState tensor comparison:") - print(f" Max absolute difference: {max_abs_diff_s:.6e}") - print(f" Max relative difference: {max_rel_diff_s:.6e}") - print(f" Mean absolute difference: {mean_abs_diff_s:.6e}") - print(f" Mean relative difference: {mean_rel_diff_s:.6e}") - print(f" Cosine similarity: {cosine_sim_s:.6f}") - print(f" MSE: {mse_s:.6e}") - - output_max_err = max_abs_diff_o - state_max_err = max_abs_diff_s - + # Compare using test_utils atol = 0.1 - assert output_max_err < atol, f"Output max error {output_max_err} exceeds tolerance" - assert 
state_max_err < atol, f"State max error {state_max_err} exceeds tolerance" + print("\nWith initial state:") + + output_metrics = compare_tensors(ref_output, fi_output, atol=atol, rtol=atol) + print_comparison_metrics(output_metrics, tensor_name="Output tensor") + + state_metrics = compare_tensors(ref_new_state, fi_new_state, atol=atol, rtol=atol) + print_comparison_metrics(state_metrics, tensor_name="State tensor") + + assert ( + output_metrics.max_abs_diff < atol + ), f"Output max error {output_metrics.max_abs_diff} exceeds tolerance" + assert ( + state_metrics.max_abs_diff < atol + ), f"State max error {state_metrics.max_abs_diff} exceeds tolerance" @requires_cuda @@ -357,64 +250,22 @@ def test_gdn_prefill_variable_seqlen(): cu_seqlens=cu_seqlens, ) - # Output comparison metrics - ref_o_f32 = ref_output.float() - fi_o_f32 = fi_output.float() - - abs_diff_o = torch.abs(ref_o_f32 - fi_o_f32) - max_abs_diff_o = abs_diff_o.max().item() - mean_abs_diff_o = abs_diff_o.mean().item() - - rel_diff_o = abs_diff_o / (torch.abs(ref_o_f32) + 1e-10) - max_rel_diff_o = rel_diff_o.max().item() - mean_rel_diff_o = rel_diff_o.mean().item() - - ref_flat = ref_o_f32.reshape(-1) - fi_flat = fi_o_f32.reshape(-1) - cosine_sim_o = F.cosine_similarity(ref_flat.unsqueeze(0), fi_flat.unsqueeze(0)).item() - - mse_o = ((ref_o_f32 - fi_o_f32) ** 2).mean().item() - - # State comparison metrics - abs_diff_s = torch.abs(ref_new_state - fi_new_state) - max_abs_diff_s = abs_diff_s.max().item() - mean_abs_diff_s = abs_diff_s.mean().item() - - rel_diff_s = abs_diff_s / (torch.abs(ref_new_state) + 1e-10) - max_rel_diff_s = rel_diff_s.max().item() - mean_rel_diff_s = rel_diff_s.mean().item() - - ref_state_flat = ref_new_state.reshape(-1) - fi_state_flat = fi_new_state.reshape(-1) - cosine_sim_s = F.cosine_similarity( - ref_state_flat.unsqueeze(0), fi_state_flat.unsqueeze(0) - ).item() + # Compare using test_utils + atol = 0.1 + print(f"\nVariable seqlens={seq_lens}:") - mse_s = ((ref_new_state - fi_new_state) ** 2).mean().item() + output_metrics = compare_tensors(ref_output, fi_output, atol=atol, rtol=atol) + print_comparison_metrics(output_metrics, tensor_name="Output tensor") - print(f"\nVariable seqlens={seq_lens}:") - print("\nOutput tensor comparison:") - print(f" Max absolute difference: {max_abs_diff_o:.6e}") - print(f" Max relative difference: {max_rel_diff_o:.6e}") - print(f" Mean absolute difference: {mean_abs_diff_o:.6e}") - print(f" Mean relative difference: {mean_rel_diff_o:.6e}") - print(f" Cosine similarity: {cosine_sim_o:.6f}") - print(f" MSE: {mse_o:.6e}") - - print("\nState tensor comparison:") - print(f" Max absolute difference: {max_abs_diff_s:.6e}") - print(f" Max relative difference: {max_rel_diff_s:.6e}") - print(f" Mean absolute difference: {mean_abs_diff_s:.6e}") - print(f" Mean relative difference: {mean_rel_diff_s:.6e}") - print(f" Cosine similarity: {cosine_sim_s:.6f}") - print(f" MSE: {mse_s:.6e}") - - output_max_err = max_abs_diff_o - state_max_err = max_abs_diff_s + state_metrics = compare_tensors(ref_new_state, fi_new_state, atol=atol, rtol=atol) + print_comparison_metrics(state_metrics, tensor_name="State tensor") - atol = 0.1 - assert output_max_err < atol, f"Output max error {output_max_err} exceeds tolerance" - assert state_max_err < atol, f"State max error {state_max_err} exceeds tolerance" + assert ( + output_metrics.max_abs_diff < atol + ), f"Output max error {output_metrics.max_abs_diff} exceeds tolerance" + assert ( + state_metrics.max_abs_diff < atol + ), f"State max error 
{state_metrics.max_abs_diff} exceeds tolerance" if __name__ == "__main__": diff --git a/flashinfer_trace/tests/references/test_gqa_paged_decode_h32_kv4_d128_ps1.py b/flashinfer_trace/tests/references/test_gqa_paged_decode_h32_kv4_d128_ps1.py index 57683955..0bd02852 100644 --- a/flashinfer_trace/tests/references/test_gqa_paged_decode_h32_kv4_d128_ps1.py +++ b/flashinfer_trace/tests/references/test_gqa_paged_decode_h32_kv4_d128_ps1.py @@ -1,97 +1,31 @@ -import math +""" +Test GQA paged decode h32_kv4_d128_ps1 reference implementation against FlashInfer. + +This test validates that the reference implementation from the definition +matches the FlashInfer kernel implementation. +""" import flashinfer import numpy as np import torch +from test_utils import compare_tensors, get_reference_run, print_comparison_metrics +# Load reference implementation from definition +run = get_reference_run("gqa_paged_decode_h32_kv4_d128_ps1") -@torch.no_grad() -def run(q, k_cache, v_cache, kv_indptr, kv_indices, sm_scale): - batch_size, num_qo_heads, head_dim = q.shape - _, page_size, num_kv_heads, _ = k_cache.shape - len_indptr = kv_indptr.shape[0] - num_kv_indices = kv_indices.shape[0] - - # Check constants - assert num_qo_heads == 32 - assert num_kv_heads == 4 - assert head_dim == 128 - assert page_size == 1 - - # Check constraints - assert len_indptr == batch_size + 1 - assert num_kv_indices == kv_indptr[-1].item() - - device = q.device - - output = torch.zeros((batch_size, num_qo_heads, head_dim), dtype=torch.bfloat16, device=device) - lse = torch.full((batch_size, num_qo_heads), -float("inf"), dtype=torch.float32, device=device) - - gqa_ratio = num_qo_heads // num_kv_heads - - k_cache_flat = k_cache.squeeze(1).to(torch.float32) # [num_pages, num_kv_heads, head_dim] - v_cache_flat = v_cache.squeeze(1).to(torch.float32) # [num_pages, num_kv_heads, head_dim] - - for b in range(batch_size): - page_start = int(kv_indptr[b].item()) - page_end = int(kv_indptr[b + 1].item()) - - if page_start >= page_end: - # No KV cache for this batch element - output[b].zero_() - continue - - # Pages are the token indices for page_size=1 - token_indices = kv_indices[page_start:page_end].to(torch.long) - # Number of tokens is the number of pages for page_size=1 - num_tokens = token_indices.shape[0] - - if num_tokens == 0: - output[b].zero_() - continue - - # Get Q, K, V for this batch - k_batch = k_cache_flat[token_indices] # [num_tokens, num_kv_heads, head_dim] - v_batch = v_cache_flat[token_indices] # [num_tokens, num_kv_heads, head_dim] - q_batch = q[b].to(torch.float32) # [num_qo_heads, head_dim] +# Constants from definition +NUM_QO_HEADS = 32 +NUM_KV_HEADS = 4 +HEAD_DIM = 128 +PAGE_SIZE = 1 - for h in range(num_qo_heads): - # Find corresponding KV head for GQA - kv_head = h // gqa_ratio - q_head = q_batch[h] # [head_dim] - k_head = k_batch[:, kv_head] # [num_tokens, head_dim] - v_head = v_batch[:, kv_head] # [num_tokens, head_dim] - - logits = torch.matmul(q_head, k_head.T) # [num_tokens] - logits_scaled = logits * sm_scale - - # Compute 2-base LSE - lse[b, h] = torch.logsumexp(logits_scaled, dim=-1) / math.log(2.0) - - attn = torch.softmax(logits_scaled, dim=-1) # [num_tokens] - out_head = torch.matmul(attn, v_head) # [head_dim] - output[b, h] = out_head.to(torch.bfloat16) - - return output, lse - - -def generate_random_inputs( - batch_size, - max_seq_len, - num_attention_heads=32, - num_key_value_heads=4, - head_dim=128, - page_size=1, - device="cuda", -): +def generate_random_inputs(batch_size, max_seq_len, 
device="cuda"): """Generate random inputs for testing.""" - # Generate random sequence lengths for each batch seq_lens = torch.randint(1, max_seq_len + 1, (batch_size,), dtype=torch.int32, device=device) - # Calculate total pages needed - # Since page_size = 1, num_pages = total_tokens + # Calculate total pages needed (page_size=1 means num_pages = total_tokens) total_pages_needed = seq_lens.sum().item() # Generate kv_indptr based on sequence lengths @@ -99,27 +33,25 @@ def generate_random_inputs( kv_indptr[1:] = torch.cumsum(seq_lens, dim=0) # Generate kv_indices (page indices for each sequence) - # We'll use consecutive pages for simplicity kv_indices = torch.arange(total_pages_needed, dtype=torch.int32, device=device) # For page_size=1, last page always has 1 token kv_last_page_len = torch.ones(batch_size, dtype=torch.int32, device=device) # Generate query tensor - q = torch.randn(batch_size, num_attention_heads, head_dim, dtype=torch.bfloat16, device=device) + q = torch.randn(batch_size, NUM_QO_HEADS, HEAD_DIM, dtype=torch.bfloat16, device=device) # Generate K and V caches - # Add some extra pages to simulate a real scenario num_pages = total_pages_needed + 100 k_cache = torch.randn( - num_pages, page_size, num_key_value_heads, head_dim, dtype=torch.bfloat16, device=device + num_pages, PAGE_SIZE, NUM_KV_HEADS, HEAD_DIM, dtype=torch.bfloat16, device=device ) v_cache = torch.randn( - num_pages, page_size, num_key_value_heads, head_dim, dtype=torch.bfloat16, device=device + num_pages, PAGE_SIZE, NUM_KV_HEADS, HEAD_DIM, dtype=torch.bfloat16, device=device ) # Generate attention parameters - sm_scale = 1.0 / np.sqrt(head_dim) + sm_scale = 1.0 / np.sqrt(HEAD_DIM) sm_scale = torch.tensor(sm_scale, dtype=torch.float32, device=device) return { @@ -145,28 +77,14 @@ def test_correctness(batch_size=4, max_seq_len=64, atol=1e-2, rtol=5e-2): print("WARNING: CUDA not available, skipping test") return - # Constants from kernel definition - num_attention_heads = 32 - num_key_value_heads = 4 - head_dim = 128 - page_size = 1 - # Generate inputs - inputs = generate_random_inputs( - batch_size, - max_seq_len, - num_attention_heads, - num_key_value_heads, - head_dim, - page_size, - device, - ) + inputs = generate_random_inputs(batch_size, max_seq_len, device) print(f"Generated sequences with lengths: {inputs['seq_lens'].cpu().numpy()}") print(f"Total pages used: {inputs['kv_indices'].shape[0]}") - # Run reference implementation - print("\nRunning reference implementation...") + # Run reference implementation from definition + print("\nRunning reference implementation from definition...") ref_o, ref_lse = run( inputs["q"], inputs["k_cache"], @@ -181,7 +99,7 @@ def test_correctness(batch_size=4, max_seq_len=64, atol=1e-2, rtol=5e-2): workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device=device) decode_wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper( - workspace_buffer, kv_layout="NHD" # Matches our cache layout + workspace_buffer, kv_layout="NHD" ) # Plan the attention computation @@ -189,10 +107,10 @@ def test_correctness(batch_size=4, max_seq_len=64, atol=1e-2, rtol=5e-2): indptr=inputs["kv_indptr"], indices=inputs["kv_indices"], last_page_len=inputs["kv_last_page_len"], - num_qo_heads=num_attention_heads, - num_kv_heads=num_key_value_heads, - head_dim=head_dim, - page_size=page_size, + num_qo_heads=NUM_QO_HEADS, + num_kv_heads=NUM_KV_HEADS, + head_dim=HEAD_DIM, + page_size=PAGE_SIZE, pos_encoding_mode="NONE", q_data_type=torch.bfloat16, kv_data_type=torch.bfloat16, @@ -201,123 
+119,33 @@ def test_correctness(batch_size=4, max_seq_len=64, atol=1e-2, rtol=5e-2): # Run FlashInfer print("Running FlashInfer...") - # FlashInfer expects tuple of (k_cache, v_cache) for paged_kv_cache fi_output, fi_lse = decode_wrapper.run( inputs["q"], (inputs["k_cache"], inputs["v_cache"]), return_lse=True ) # Compare outputs print("\nComparing outputs...") + output_metrics = compare_tensors(ref_o, fi_output, atol=atol, rtol=rtol) + print_comparison_metrics(output_metrics, tensor_name="Output tensor") - # Convert to float32 for comparison - ref_o_f32 = ref_o.float() - fi_output_f32 = fi_output.float() - - # Compute errors for output tensor - abs_diff = torch.abs(ref_o_f32 - fi_output_f32) - rel_diff = abs_diff / (torch.abs(fi_output_f32) + 1e-8) - - max_abs_diff = abs_diff.max().item() - max_rel_diff = rel_diff.max().item() - mean_abs_diff = abs_diff.mean().item() - mean_rel_diff = rel_diff.mean().item() - - print(f"\nOutput tensor comparison:") - print(f"Max absolute difference: {max_abs_diff:.6e}") - print(f"Max relative difference: {max_rel_diff:.6e}") - print(f"Mean absolute difference: {mean_abs_diff:.6e}") - print(f"Mean relative difference: {mean_rel_diff:.6e}") - - # Compute cosine similarity and MSE for output tensor - cos_sim = torch.nn.functional.cosine_similarity( - ref_o_f32.flatten(), fi_output_f32.flatten(), dim=0 - ).item() - mse = torch.mean((ref_o_f32 - fi_output_f32) ** 2).item() - print(f"Cosine similarity: {cos_sim:.6f}") - print(f"MSE: {mse:.6e}") - - # Compare LSE values - lse_abs_diff = torch.abs(ref_lse - fi_lse) - lse_rel_diff = lse_abs_diff / (torch.abs(fi_lse) + 1e-8) - - lse_max_abs_diff = lse_abs_diff.max().item() - lse_max_rel_diff = lse_rel_diff.max().item() - lse_mean_abs_diff = lse_abs_diff.mean().item() - lse_mean_rel_diff = lse_rel_diff.mean().item() - - print(f"\nLSE comparison:") - print(f"Max absolute difference: {lse_max_abs_diff:.6e}") - print(f"Max relative difference: {lse_max_rel_diff:.6e}") - print(f"Mean absolute difference: {lse_mean_abs_diff:.6e}") - print(f"Mean relative difference: {lse_mean_rel_diff:.6e}") - - # Check if outputs match within tolerance - output_close = torch.allclose(ref_o_f32, fi_output_f32, atol=atol, rtol=rtol) - lse_close = torch.allclose(ref_lse, fi_lse, atol=atol, rtol=rtol) - all_close = output_close and lse_close + lse_metrics = compare_tensors(ref_lse, fi_lse, atol=atol, rtol=rtol) + print_comparison_metrics(lse_metrics, tensor_name="LSE tensor") + + all_close = output_metrics.all_close and lse_metrics.all_close if all_close: print(f"\n✓ PASSED: Outputs and LSE match within tolerance (atol={atol}, rtol={rtol})") else: print(f"\n✗ FAILED: Outputs differ beyond tolerance (atol={atol}, rtol={rtol})") - if not output_close: - # Find indices with largest errors for debugging - flat_abs_diff = abs_diff.flatten() - top_k = min(5, flat_abs_diff.numel()) - top_errors, top_indices = torch.topk(flat_abs_diff, top_k) - - print(f"\nTop {top_k} output tensor error locations:") - for i in range(top_k): - idx = top_indices[i].item() - # Convert flat index back to 3D indices - batch_idx = idx // (num_attention_heads * head_dim) - head_idx = (idx % (num_attention_heads * head_dim)) // head_dim - dim_idx = idx % head_dim - - ref_val = ref_o_f32.flatten()[idx].item() - fi_val = fi_output_f32.flatten()[idx].item() - - print( - f" [{batch_idx}, {head_idx}, {dim_idx}]: " - f"ref={ref_val:.6f}, fi={fi_val:.6f}, diff={top_errors[i].item():.6e}" - ) - - if not lse_close: - # Find LSE errors - flat_lse_diff = lse_abs_diff.flatten() - 
top_k = min(5, flat_lse_diff.numel()) - top_lse_errors, top_lse_indices = torch.topk(flat_lse_diff, top_k) - - print(f"\nTop {top_k} LSE error locations:") - for i in range(top_k): - idx = top_lse_indices[i].item() - batch_idx = idx // num_attention_heads - head_idx = idx % num_attention_heads - - ref_val = ref_lse.flatten()[idx].item() - fi_val = fi_lse.flatten()[idx].item() - - print( - f" [{batch_idx}, {head_idx}]: " - f"ref={ref_val:.6f}, fi={fi_val:.6f}, diff={top_lse_errors[i].item():.6e}" - ) - return all_close def main(): """Run comprehensive tests.""" - print("Testing Batch GQA Paged Decode Reference Implementation") - - # Test different configurations - test_configs = [ - # (batch_size, max_seq_len) - (1, 16), # Single batch - (4, 32), # Small batch - (8, 64), # Medium batch - (16, 128), # Large batch - ] + print("Testing Batch GQA Paged Decode Reference Implementation (from definition)") + + test_configs = [(1, 16), (4, 32), (8, 64), (16, 128)] passed = 0 total = len(test_configs) diff --git a/flashinfer_trace/tests/references/test_gqa_paged_decode_h32_kv4_d128_ps64.py b/flashinfer_trace/tests/references/test_gqa_paged_decode_h32_kv4_d128_ps64.py index 3c75f133..10f7ed47 100644 --- a/flashinfer_trace/tests/references/test_gqa_paged_decode_h32_kv4_d128_ps64.py +++ b/flashinfer_trace/tests/references/test_gqa_paged_decode_h32_kv4_d128_ps64.py @@ -1,119 +1,32 @@ -import math +""" +Test GQA paged decode h32_kv4_d128_ps64 reference implementation against FlashInfer. + +This test validates that the reference implementation from the definition +matches the FlashInfer kernel implementation. +""" import flashinfer import numpy as np import torch +from test_utils import compare_tensors, get_reference_run, print_comparison_metrics +# Load reference implementation from definition +run = get_reference_run("gqa_paged_decode_h32_kv4_d128_ps64") -@torch.no_grad() -def run(q, k_cache, v_cache, kv_indptr, kv_indices, kv_last_page_len, sm_scale): - batch_size, num_qo_heads, head_dim = q.shape - _, page_size, num_kv_heads, _ = k_cache.shape - len_indptr = kv_indptr.shape[0] - num_kv_indices = kv_indices.shape[0] - - # Check constants - assert num_qo_heads == 32 - assert num_kv_heads == 4 - assert head_dim == 128 - assert page_size == 64 - - # Check constraints - assert len_indptr == batch_size + 1 - assert num_kv_indices == kv_indptr[-1].item() - - device = q.device - - output = torch.zeros((batch_size, num_qo_heads, head_dim), dtype=torch.bfloat16, device=device) - lse = torch.full((batch_size, num_qo_heads), -float("inf"), dtype=torch.float32, device=device) - - gqa_ratio = num_qo_heads // num_kv_heads - - k_cache_f32 = k_cache.to(torch.float32) - v_cache_f32 = v_cache.to(torch.float32) - - for b in range(batch_size): - page_start = int(kv_indptr[b].item()) - page_end = int(kv_indptr[b + 1].item()) - last_page_len = int(kv_last_page_len[b].item()) - - if page_start >= page_end: - output[b].zero_() - continue - - page_ids = kv_indices[page_start:page_end].to(torch.long) - num_pages_for_seq = page_ids.shape[0] - - if num_pages_for_seq == 0: - output[b].zero_() - continue - - num_full_pages = num_pages_for_seq - 1 - total_tokens = num_full_pages * page_size + last_page_len - - if total_tokens == 0: - output[b].zero_() - continue - - k_batch = torch.zeros( - (total_tokens, num_kv_heads, head_dim), dtype=torch.float32, device=device - ) - v_batch = torch.zeros( - (total_tokens, num_kv_heads, head_dim), dtype=torch.float32, device=device - ) - - token_idx = 0 - for p_idx, page_id in 
enumerate(page_ids): - if p_idx < num_full_pages: - k_batch[token_idx : token_idx + page_size] = k_cache_f32[page_id] - v_batch[token_idx : token_idx + page_size] = v_cache_f32[page_id] - token_idx += page_size - else: - k_batch[token_idx : token_idx + last_page_len] = k_cache_f32[ - page_id, :last_page_len - ] - v_batch[token_idx : token_idx + last_page_len] = v_cache_f32[ - page_id, :last_page_len - ] - token_idx += last_page_len - - q_batch = q[b].to(torch.float32) - - for h in range(num_qo_heads): - kv_head = h // gqa_ratio - - q_head = q_batch[h] - k_head = k_batch[:, kv_head] - v_head = v_batch[:, kv_head] - - logits = torch.matmul(q_head, k_head.T) - logits_scaled = logits * sm_scale - - lse[b, h] = torch.logsumexp(logits_scaled, dim=-1) / math.log(2.0) - - attn = torch.softmax(logits_scaled, dim=-1) - out_head = torch.matmul(attn, v_head) - output[b, h] = out_head.to(torch.bfloat16) - - return output, lse - - -def generate_random_inputs( - batch_size, - max_seq_len, - num_attention_heads=32, - num_key_value_heads=4, - head_dim=128, - page_size=64, - device="cuda", -): - """Generate random inputs for testing.""" +# Constants from definition +NUM_QO_HEADS = 32 +NUM_KV_HEADS = 4 +HEAD_DIM = 128 +PAGE_SIZE = 64 + +def generate_random_inputs(batch_size, max_seq_len, device="cuda"): + """Generate random inputs for testing.""" # Generate random sequence lengths for each batch seq_lens = torch.randint(1, max_seq_len + 1, (batch_size,), dtype=torch.int32, device=device) # Calculate pages needed for each sequence - pages_per_seq = (seq_lens + page_size - 1) // page_size # Ceiling division + pages_per_seq = (seq_lens + PAGE_SIZE - 1) // PAGE_SIZE # Ceiling division total_pages_needed = pages_per_seq.sum().item() # Generate kv_indptr based on pages per sequence @@ -124,22 +37,22 @@ def generate_random_inputs( kv_indices = torch.arange(total_pages_needed, dtype=torch.int32, device=device) # Calculate last_page_len for each sequence - kv_last_page_len = ((seq_lens - 1) % page_size) + 1 + kv_last_page_len = ((seq_lens - 1) % PAGE_SIZE) + 1 # Generate query tensor - q = torch.randn(batch_size, num_attention_heads, head_dim, dtype=torch.bfloat16, device=device) + q = torch.randn(batch_size, NUM_QO_HEADS, HEAD_DIM, dtype=torch.bfloat16, device=device) # Generate K and V caches num_pages = total_pages_needed + 100 k_cache = torch.randn( - num_pages, page_size, num_key_value_heads, head_dim, dtype=torch.bfloat16, device=device + num_pages, PAGE_SIZE, NUM_KV_HEADS, HEAD_DIM, dtype=torch.bfloat16, device=device ) v_cache = torch.randn( - num_pages, page_size, num_key_value_heads, head_dim, dtype=torch.bfloat16, device=device + num_pages, PAGE_SIZE, NUM_KV_HEADS, HEAD_DIM, dtype=torch.bfloat16, device=device ) # Generate attention parameters - sm_scale = 1.0 / np.sqrt(head_dim) + sm_scale = 1.0 / np.sqrt(HEAD_DIM) sm_scale = torch.tensor(sm_scale, dtype=torch.float32, device=device) return { @@ -165,29 +78,15 @@ def test_correctness(batch_size=4, max_seq_len=256, atol=1e-2, rtol=5e-2): print("WARNING: CUDA not available, skipping test") return - # Constants from kernel definition - num_attention_heads = 32 - num_key_value_heads = 4 - head_dim = 128 - page_size = 64 - # Generate inputs - inputs = generate_random_inputs( - batch_size, - max_seq_len, - num_attention_heads, - num_key_value_heads, - head_dim, - page_size, - device, - ) + inputs = generate_random_inputs(batch_size, max_seq_len, device) print(f"Generated sequences with lengths: {inputs['seq_lens'].cpu().numpy()}") print(f"Last page lengths: 
{inputs['kv_last_page_len'].cpu().numpy()}") print(f"Total pages used: {inputs['kv_indices'].shape[0]}") - # Run reference implementation - print("\nRunning reference implementation...") + # Run reference implementation from definition (page_size=64 includes kv_last_page_len) + print("\nRunning reference implementation from definition...") ref_o, ref_lse = run( inputs["q"], inputs["k_cache"], @@ -211,10 +110,10 @@ def test_correctness(batch_size=4, max_seq_len=256, atol=1e-2, rtol=5e-2): indptr=inputs["kv_indptr"], indices=inputs["kv_indices"], last_page_len=inputs["kv_last_page_len"], - num_qo_heads=num_attention_heads, - num_kv_heads=num_key_value_heads, - head_dim=head_dim, - page_size=page_size, + num_qo_heads=NUM_QO_HEADS, + num_kv_heads=NUM_KV_HEADS, + head_dim=HEAD_DIM, + page_size=PAGE_SIZE, pos_encoding_mode="NONE", q_data_type=torch.bfloat16, kv_data_type=torch.bfloat16, @@ -229,99 +128,25 @@ def test_correctness(batch_size=4, max_seq_len=256, atol=1e-2, rtol=5e-2): # Compare outputs print("\nComparing outputs...") + output_metrics = compare_tensors(ref_o, fi_output, atol=atol, rtol=rtol) + print_comparison_metrics(output_metrics, tensor_name="Output tensor") - ref_o_f32 = ref_o.float() - fi_output_f32 = fi_output.float() - - abs_diff = torch.abs(ref_o_f32 - fi_output_f32) - rel_diff = abs_diff / (torch.abs(fi_output_f32) + 1e-8) + lse_metrics = compare_tensors(ref_lse, fi_lse, atol=atol, rtol=rtol) + print_comparison_metrics(lse_metrics, tensor_name="LSE tensor") - max_abs_diff = abs_diff.max().item() - max_rel_diff = rel_diff.max().item() - mean_abs_diff = abs_diff.mean().item() - mean_rel_diff = rel_diff.mean().item() - - print(f"\nOutput tensor comparison:") - print(f"Max absolute difference: {max_abs_diff:.6e}") - print(f"Max relative difference: {max_rel_diff:.6e}") - print(f"Mean absolute difference: {mean_abs_diff:.6e}") - print(f"Mean relative difference: {mean_rel_diff:.6e}") - - cos_sim = torch.nn.functional.cosine_similarity( - ref_o_f32.flatten(), fi_output_f32.flatten(), dim=0 - ).item() - mse = torch.mean((ref_o_f32 - fi_output_f32) ** 2).item() - print(f"Cosine similarity: {cos_sim:.6f}") - print(f"MSE: {mse:.6e}") - - lse_abs_diff = torch.abs(ref_lse - fi_lse) - lse_rel_diff = lse_abs_diff / (torch.abs(fi_lse) + 1e-8) - - lse_max_abs_diff = lse_abs_diff.max().item() - lse_max_rel_diff = lse_rel_diff.max().item() - lse_mean_abs_diff = lse_abs_diff.mean().item() - lse_mean_rel_diff = lse_rel_diff.mean().item() - - print(f"\nLSE comparison:") - print(f"Max absolute difference: {lse_max_abs_diff:.6e}") - print(f"Max relative difference: {lse_max_rel_diff:.6e}") - print(f"Mean absolute difference: {lse_mean_abs_diff:.6e}") - print(f"Mean relative difference: {lse_mean_rel_diff:.6e}") - - output_close = torch.allclose(ref_o_f32, fi_output_f32, atol=atol, rtol=rtol) - lse_close = torch.allclose(ref_lse, fi_lse, atol=atol, rtol=rtol) - all_close = output_close and lse_close + all_close = output_metrics.all_close and lse_metrics.all_close if all_close: print(f"\n✓ PASSED: Outputs and LSE match within tolerance (atol={atol}, rtol={rtol})") else: print(f"\n✗ FAILED: Outputs differ beyond tolerance (atol={atol}, rtol={rtol})") - if not output_close: - flat_abs_diff = abs_diff.flatten() - top_k = min(5, flat_abs_diff.numel()) - top_errors, top_indices = torch.topk(flat_abs_diff, top_k) - - print(f"\nTop {top_k} output tensor error locations:") - for i in range(top_k): - idx = top_indices[i].item() - batch_idx = idx // (num_attention_heads * head_dim) - head_idx = (idx % 
(num_attention_heads * head_dim)) // head_dim - dim_idx = idx % head_dim - - ref_val = ref_o_f32.flatten()[idx].item() - fi_val = fi_output_f32.flatten()[idx].item() - - print( - f" [{batch_idx}, {head_idx}, {dim_idx}]: " - f"ref={ref_val:.6f}, fi={fi_val:.6f}, diff={top_errors[i].item():.6e}" - ) - - if not lse_close: - flat_lse_diff = lse_abs_diff.flatten() - top_k = min(5, flat_lse_diff.numel()) - top_lse_errors, top_lse_indices = torch.topk(flat_lse_diff, top_k) - - print(f"\nTop {top_k} LSE error locations:") - for i in range(top_k): - idx = top_lse_indices[i].item() - batch_idx = idx // num_attention_heads - head_idx = idx % num_attention_heads - - ref_val = ref_lse.flatten()[idx].item() - fi_val = fi_lse.flatten()[idx].item() - - print( - f" [{batch_idx}, {head_idx}]: " - f"ref={ref_val:.6f}, fi={fi_val:.6f}, diff={top_lse_errors[i].item():.6e}" - ) - return all_close def main(): """Run comprehensive tests.""" - print("Testing Batch GQA Paged Decode Reference Implementation (page_size=64)") + print("Testing Batch GQA Paged Decode Reference Implementation (page_size=64, from definition)") test_configs = [(1, 64), (4, 128), (8, 256), (16, 512)] diff --git a/flashinfer_trace/tests/references/test_gqa_paged_decode_h32_kv8_d128_ps1.py b/flashinfer_trace/tests/references/test_gqa_paged_decode_h32_kv8_d128_ps1.py index 9230b5ed..4e687f53 100644 --- a/flashinfer_trace/tests/references/test_gqa_paged_decode_h32_kv8_d128_ps1.py +++ b/flashinfer_trace/tests/references/test_gqa_paged_decode_h32_kv8_d128_ps1.py @@ -1,97 +1,31 @@ -import math +""" +Test GQA paged decode h32_kv8_d128_ps1 reference implementation against FlashInfer. + +This test validates that the reference implementation from the definition +matches the FlashInfer kernel implementation. +""" import flashinfer import numpy as np import torch +from test_utils import compare_tensors, get_reference_run, print_comparison_metrics +# Load reference implementation from definition +run = get_reference_run("gqa_paged_decode_h32_kv8_d128_ps1") -@torch.no_grad() -def run(q, k_cache, v_cache, kv_indptr, kv_indices, sm_scale): - batch_size, num_qo_heads, head_dim = q.shape - _, page_size, num_kv_heads, _ = k_cache.shape - len_indptr = kv_indptr.shape[0] - num_kv_indices = kv_indices.shape[0] - - # Check constants - assert num_qo_heads == 32 - assert num_kv_heads == 8 - assert head_dim == 128 - assert page_size == 1 - - # Check constraints - assert len_indptr == batch_size + 1 - assert num_kv_indices == kv_indptr[-1].item() - - device = q.device - - output = torch.zeros((batch_size, num_qo_heads, head_dim), dtype=torch.bfloat16, device=device) - lse = torch.full((batch_size, num_qo_heads), -float("inf"), dtype=torch.float32, device=device) - - gqa_ratio = num_qo_heads // num_kv_heads - - k_cache_flat = k_cache.squeeze(1).to(torch.float32) # [num_pages, num_kv_heads, head_dim] - v_cache_flat = v_cache.squeeze(1).to(torch.float32) # [num_pages, num_kv_heads, head_dim] - - for b in range(batch_size): - page_start = int(kv_indptr[b].item()) - page_end = int(kv_indptr[b + 1].item()) - - if page_start >= page_end: - # No KV cache for this batch element - output[b].zero_() - continue - - # Pages are the token indices for page_size=1 - token_indices = kv_indices[page_start:page_end].to(torch.long) - # Number of tokens is the number of pages for page_size=1 - num_tokens = token_indices.shape[0] - - if num_tokens == 0: - output[b].zero_() - continue - - # Get Q, K, V for this batch - k_batch = k_cache_flat[token_indices] # [num_tokens, num_kv_heads, 
head_dim] - v_batch = v_cache_flat[token_indices] # [num_tokens, num_kv_heads, head_dim] - q_batch = q[b].to(torch.float32) # [num_qo_heads, head_dim] +# Constants from definition +NUM_QO_HEADS = 32 +NUM_KV_HEADS = 8 +HEAD_DIM = 128 +PAGE_SIZE = 1 - for h in range(num_qo_heads): - # Find corresponding KV head for GQA - kv_head = h // gqa_ratio - q_head = q_batch[h] # [head_dim] - k_head = k_batch[:, kv_head] # [num_tokens, head_dim] - v_head = v_batch[:, kv_head] # [num_tokens, head_dim] - - logits = torch.matmul(q_head, k_head.T) # [num_tokens] - logits_scaled = logits * sm_scale - - # Compute 2-base LSE - lse[b, h] = torch.logsumexp(logits_scaled, dim=-1) / math.log(2.0) - - attn = torch.softmax(logits_scaled, dim=-1) # [num_tokens] - out_head = torch.matmul(attn, v_head) # [head_dim] - output[b, h] = out_head.to(torch.bfloat16) - - return output, lse - - -def generate_random_inputs( - batch_size, - max_seq_len, - num_attention_heads=32, - num_key_value_heads=8, - head_dim=128, - page_size=1, - device="cuda", -): +def generate_random_inputs(batch_size, max_seq_len, device="cuda"): """Generate random inputs for testing.""" - # Generate random sequence lengths for each batch seq_lens = torch.randint(1, max_seq_len + 1, (batch_size,), dtype=torch.int32, device=device) - # Calculate total pages needed - # Since page_size = 1, num_pages = total_tokens + # Calculate total pages needed (page_size=1 means num_pages = total_tokens) total_pages_needed = seq_lens.sum().item() # Generate kv_indptr based on sequence lengths @@ -99,27 +33,25 @@ def generate_random_inputs( kv_indptr[1:] = torch.cumsum(seq_lens, dim=0) # Generate kv_indices (page indices for each sequence) - # We'll use consecutive pages for simplicity kv_indices = torch.arange(total_pages_needed, dtype=torch.int32, device=device) # For page_size=1, last page always has 1 token kv_last_page_len = torch.ones(batch_size, dtype=torch.int32, device=device) # Generate query tensor - q = torch.randn(batch_size, num_attention_heads, head_dim, dtype=torch.bfloat16, device=device) + q = torch.randn(batch_size, NUM_QO_HEADS, HEAD_DIM, dtype=torch.bfloat16, device=device) # Generate K and V caches - # Add some extra pages to simulate a real scenario num_pages = total_pages_needed + 100 k_cache = torch.randn( - num_pages, page_size, num_key_value_heads, head_dim, dtype=torch.bfloat16, device=device + num_pages, PAGE_SIZE, NUM_KV_HEADS, HEAD_DIM, dtype=torch.bfloat16, device=device ) v_cache = torch.randn( - num_pages, page_size, num_key_value_heads, head_dim, dtype=torch.bfloat16, device=device + num_pages, PAGE_SIZE, NUM_KV_HEADS, HEAD_DIM, dtype=torch.bfloat16, device=device ) # Generate attention parameters - sm_scale = 1.0 / np.sqrt(head_dim) + sm_scale = 1.0 / np.sqrt(HEAD_DIM) sm_scale = torch.tensor(sm_scale, dtype=torch.float32, device=device) return { @@ -145,28 +77,14 @@ def test_correctness(batch_size=4, max_seq_len=64, atol=1e-2, rtol=5e-2): print("WARNING: CUDA not available, skipping test") return - # Constants from kernel definition - num_attention_heads = 32 - num_key_value_heads = 8 - head_dim = 128 - page_size = 1 - # Generate inputs - inputs = generate_random_inputs( - batch_size, - max_seq_len, - num_attention_heads, - num_key_value_heads, - head_dim, - page_size, - device, - ) + inputs = generate_random_inputs(batch_size, max_seq_len, device) print(f"Generated sequences with lengths: {inputs['seq_lens'].cpu().numpy()}") print(f"Total pages used: {inputs['kv_indices'].shape[0]}") - # Run reference implementation - 
print("\nRunning reference implementation...") + # Run reference implementation from definition + print("\nRunning reference implementation from definition...") ref_o, ref_lse = run( inputs["q"], inputs["k_cache"], @@ -181,7 +99,7 @@ def test_correctness(batch_size=4, max_seq_len=64, atol=1e-2, rtol=5e-2): workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device=device) decode_wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper( - workspace_buffer, kv_layout="NHD" # Matches our cache layout + workspace_buffer, kv_layout="NHD" ) # Plan the attention computation @@ -189,10 +107,10 @@ def test_correctness(batch_size=4, max_seq_len=64, atol=1e-2, rtol=5e-2): indptr=inputs["kv_indptr"], indices=inputs["kv_indices"], last_page_len=inputs["kv_last_page_len"], - num_qo_heads=num_attention_heads, - num_kv_heads=num_key_value_heads, - head_dim=head_dim, - page_size=page_size, + num_qo_heads=NUM_QO_HEADS, + num_kv_heads=NUM_KV_HEADS, + head_dim=HEAD_DIM, + page_size=PAGE_SIZE, pos_encoding_mode="NONE", q_data_type=torch.bfloat16, kv_data_type=torch.bfloat16, @@ -201,123 +119,33 @@ def test_correctness(batch_size=4, max_seq_len=64, atol=1e-2, rtol=5e-2): # Run FlashInfer print("Running FlashInfer...") - # FlashInfer expects tuple of (k_cache, v_cache) for paged_kv_cache fi_output, fi_lse = decode_wrapper.run( inputs["q"], (inputs["k_cache"], inputs["v_cache"]), return_lse=True ) # Compare outputs print("\nComparing outputs...") + output_metrics = compare_tensors(ref_o, fi_output, atol=atol, rtol=rtol) + print_comparison_metrics(output_metrics, tensor_name="Output tensor") - # Convert to float32 for comparison - ref_o_f32 = ref_o.float() - fi_output_f32 = fi_output.float() - - # Compute errors for output tensor - abs_diff = torch.abs(ref_o_f32 - fi_output_f32) - rel_diff = abs_diff / (torch.abs(fi_output_f32) + 1e-8) - - max_abs_diff = abs_diff.max().item() - max_rel_diff = rel_diff.max().item() - mean_abs_diff = abs_diff.mean().item() - mean_rel_diff = rel_diff.mean().item() - - print(f"\nOutput tensor comparison:") - print(f"Max absolute difference: {max_abs_diff:.6e}") - print(f"Max relative difference: {max_rel_diff:.6e}") - print(f"Mean absolute difference: {mean_abs_diff:.6e}") - print(f"Mean relative difference: {mean_rel_diff:.6e}") - - # Compute cosine similarity and MSE for output tensor - cos_sim = torch.nn.functional.cosine_similarity( - ref_o_f32.flatten(), fi_output_f32.flatten(), dim=0 - ).item() - mse = torch.mean((ref_o_f32 - fi_output_f32) ** 2).item() - print(f"Cosine similarity: {cos_sim:.6f}") - print(f"MSE: {mse:.6e}") - - # Compare LSE values - lse_abs_diff = torch.abs(ref_lse - fi_lse) - lse_rel_diff = lse_abs_diff / (torch.abs(fi_lse) + 1e-8) - - lse_max_abs_diff = lse_abs_diff.max().item() - lse_max_rel_diff = lse_rel_diff.max().item() - lse_mean_abs_diff = lse_abs_diff.mean().item() - lse_mean_rel_diff = lse_rel_diff.mean().item() - - print(f"\nLSE comparison:") - print(f"Max absolute difference: {lse_max_abs_diff:.6e}") - print(f"Max relative difference: {lse_max_rel_diff:.6e}") - print(f"Mean absolute difference: {lse_mean_abs_diff:.6e}") - print(f"Mean relative difference: {lse_mean_rel_diff:.6e}") - - # Check if outputs match within tolerance - output_close = torch.allclose(ref_o_f32, fi_output_f32, atol=atol, rtol=rtol) - lse_close = torch.allclose(ref_lse, fi_lse, atol=atol, rtol=rtol) - all_close = output_close and lse_close + lse_metrics = compare_tensors(ref_lse, fi_lse, atol=atol, rtol=rtol) + print_comparison_metrics(lse_metrics, 
tensor_name="LSE tensor") + + all_close = output_metrics.all_close and lse_metrics.all_close if all_close: print(f"\n✓ PASSED: Outputs and LSE match within tolerance (atol={atol}, rtol={rtol})") else: print(f"\n✗ FAILED: Outputs differ beyond tolerance (atol={atol}, rtol={rtol})") - if not output_close: - # Find indices with largest errors for debugging - flat_abs_diff = abs_diff.flatten() - top_k = min(5, flat_abs_diff.numel()) - top_errors, top_indices = torch.topk(flat_abs_diff, top_k) - - print(f"\nTop {top_k} output tensor error locations:") - for i in range(top_k): - idx = top_indices[i].item() - # Convert flat index back to 3D indices - batch_idx = idx // (num_attention_heads * head_dim) - head_idx = (idx % (num_attention_heads * head_dim)) // head_dim - dim_idx = idx % head_dim - - ref_val = ref_o_f32.flatten()[idx].item() - fi_val = fi_output_f32.flatten()[idx].item() - - print( - f" [{batch_idx}, {head_idx}, {dim_idx}]: " - f"ref={ref_val:.6f}, fi={fi_val:.6f}, diff={top_errors[i].item():.6e}" - ) - - if not lse_close: - # Find LSE errors - flat_lse_diff = lse_abs_diff.flatten() - top_k = min(5, flat_lse_diff.numel()) - top_lse_errors, top_lse_indices = torch.topk(flat_lse_diff, top_k) - - print(f"\nTop {top_k} LSE error locations:") - for i in range(top_k): - idx = top_lse_indices[i].item() - batch_idx = idx // num_attention_heads - head_idx = idx % num_attention_heads - - ref_val = ref_lse.flatten()[idx].item() - fi_val = fi_lse.flatten()[idx].item() - - print( - f" [{batch_idx}, {head_idx}]: " - f"ref={ref_val:.6f}, fi={fi_val:.6f}, diff={top_lse_errors[i].item():.6e}" - ) - return all_close def main(): """Run comprehensive tests.""" - print("Testing Batch GQA Paged Decode Reference Implementation") - - # Test different configurations - test_configs = [ - # (batch_size, max_seq_len) - (1, 16), # Single batch - (4, 32), # Small batch - (8, 64), # Medium batch - (16, 128), # Large batch - ] + print("Testing Batch GQA Paged Decode Reference Implementation (from definition)") + + test_configs = [(1, 16), (4, 32), (8, 64), (16, 128)] passed = 0 total = len(test_configs) diff --git a/flashinfer_trace/tests/references/test_gqa_paged_decode_h32_kv8_d128_ps64.py b/flashinfer_trace/tests/references/test_gqa_paged_decode_h32_kv8_d128_ps64.py index 6e7ee062..78fe5bc8 100644 --- a/flashinfer_trace/tests/references/test_gqa_paged_decode_h32_kv8_d128_ps64.py +++ b/flashinfer_trace/tests/references/test_gqa_paged_decode_h32_kv8_d128_ps64.py @@ -1,119 +1,32 @@ -import math +""" +Test GQA paged decode h32_kv8_d128_ps64 reference implementation against FlashInfer. + +This test validates that the reference implementation from the definition +matches the FlashInfer kernel implementation. 
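+
+With page_size=64, a sequence of seq_len tokens spans ceil(seq_len / 64) pages
+and its last page holds ((seq_len - 1) % 64) + 1 tokens; for example, seq_len=130
+occupies 3 pages with 2 tokens in the final page.
+
+Outputs are checked against FlashInfer with the helpers imported from test_utils
+below, roughly:
+
+    metrics = compare_tensors(ref_o, fi_output, atol=1e-2, rtol=5e-2)
+    print_comparison_metrics(metrics, tensor_name="Output tensor")
+    passed = metrics.all_close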
+""" import flashinfer import numpy as np import torch +from test_utils import compare_tensors, get_reference_run, print_comparison_metrics +# Load reference implementation from definition +run = get_reference_run("gqa_paged_decode_h32_kv8_d128_ps64") -@torch.no_grad() -def run(q, k_cache, v_cache, kv_indptr, kv_indices, kv_last_page_len, sm_scale): - batch_size, num_qo_heads, head_dim = q.shape - _, page_size, num_kv_heads, _ = k_cache.shape - len_indptr = kv_indptr.shape[0] - num_kv_indices = kv_indices.shape[0] - - # Check constants - assert num_qo_heads == 32 - assert num_kv_heads == 8 - assert head_dim == 128 - assert page_size == 64 - - # Check constraints - assert len_indptr == batch_size + 1 - assert num_kv_indices == kv_indptr[-1].item() - - device = q.device - - output = torch.zeros((batch_size, num_qo_heads, head_dim), dtype=torch.bfloat16, device=device) - lse = torch.full((batch_size, num_qo_heads), -float("inf"), dtype=torch.float32, device=device) - - gqa_ratio = num_qo_heads // num_kv_heads - - k_cache_f32 = k_cache.to(torch.float32) - v_cache_f32 = v_cache.to(torch.float32) - - for b in range(batch_size): - page_start = int(kv_indptr[b].item()) - page_end = int(kv_indptr[b + 1].item()) - last_page_len = int(kv_last_page_len[b].item()) - - if page_start >= page_end: - output[b].zero_() - continue - - page_ids = kv_indices[page_start:page_end].to(torch.long) - num_pages_for_seq = page_ids.shape[0] - - if num_pages_for_seq == 0: - output[b].zero_() - continue - - num_full_pages = num_pages_for_seq - 1 - total_tokens = num_full_pages * page_size + last_page_len - - if total_tokens == 0: - output[b].zero_() - continue - - k_batch = torch.zeros( - (total_tokens, num_kv_heads, head_dim), dtype=torch.float32, device=device - ) - v_batch = torch.zeros( - (total_tokens, num_kv_heads, head_dim), dtype=torch.float32, device=device - ) - - token_idx = 0 - for p_idx, page_id in enumerate(page_ids): - if p_idx < num_full_pages: - k_batch[token_idx : token_idx + page_size] = k_cache_f32[page_id] - v_batch[token_idx : token_idx + page_size] = v_cache_f32[page_id] - token_idx += page_size - else: - k_batch[token_idx : token_idx + last_page_len] = k_cache_f32[ - page_id, :last_page_len - ] - v_batch[token_idx : token_idx + last_page_len] = v_cache_f32[ - page_id, :last_page_len - ] - token_idx += last_page_len - - q_batch = q[b].to(torch.float32) - - for h in range(num_qo_heads): - kv_head = h // gqa_ratio - - q_head = q_batch[h] - k_head = k_batch[:, kv_head] - v_head = v_batch[:, kv_head] - - logits = torch.matmul(q_head, k_head.T) - logits_scaled = logits * sm_scale - - lse[b, h] = torch.logsumexp(logits_scaled, dim=-1) / math.log(2.0) - - attn = torch.softmax(logits_scaled, dim=-1) - out_head = torch.matmul(attn, v_head) - output[b, h] = out_head.to(torch.bfloat16) - - return output, lse - - -def generate_random_inputs( - batch_size, - max_seq_len, - num_attention_heads=32, - num_key_value_heads=8, - head_dim=128, - page_size=64, - device="cuda", -): - """Generate random inputs for testing.""" +# Constants from definition +NUM_QO_HEADS = 32 +NUM_KV_HEADS = 8 +HEAD_DIM = 128 +PAGE_SIZE = 64 + +def generate_random_inputs(batch_size, max_seq_len, device="cuda"): + """Generate random inputs for testing.""" # Generate random sequence lengths for each batch seq_lens = torch.randint(1, max_seq_len + 1, (batch_size,), dtype=torch.int32, device=device) # Calculate pages needed for each sequence - pages_per_seq = (seq_lens + page_size - 1) // page_size # Ceiling division + pages_per_seq = 
(seq_lens + PAGE_SIZE - 1) // PAGE_SIZE # Ceiling division total_pages_needed = pages_per_seq.sum().item() # Generate kv_indptr based on pages per sequence @@ -124,22 +37,22 @@ def generate_random_inputs( kv_indices = torch.arange(total_pages_needed, dtype=torch.int32, device=device) # Calculate last_page_len for each sequence - kv_last_page_len = ((seq_lens - 1) % page_size) + 1 + kv_last_page_len = ((seq_lens - 1) % PAGE_SIZE) + 1 # Generate query tensor - q = torch.randn(batch_size, num_attention_heads, head_dim, dtype=torch.bfloat16, device=device) + q = torch.randn(batch_size, NUM_QO_HEADS, HEAD_DIM, dtype=torch.bfloat16, device=device) # Generate K and V caches num_pages = total_pages_needed + 100 k_cache = torch.randn( - num_pages, page_size, num_key_value_heads, head_dim, dtype=torch.bfloat16, device=device + num_pages, PAGE_SIZE, NUM_KV_HEADS, HEAD_DIM, dtype=torch.bfloat16, device=device ) v_cache = torch.randn( - num_pages, page_size, num_key_value_heads, head_dim, dtype=torch.bfloat16, device=device + num_pages, PAGE_SIZE, NUM_KV_HEADS, HEAD_DIM, dtype=torch.bfloat16, device=device ) # Generate attention parameters - sm_scale = 1.0 / np.sqrt(head_dim) + sm_scale = 1.0 / np.sqrt(HEAD_DIM) sm_scale = torch.tensor(sm_scale, dtype=torch.float32, device=device) return { @@ -165,29 +78,15 @@ def test_correctness(batch_size=4, max_seq_len=256, atol=1e-2, rtol=5e-2): print("WARNING: CUDA not available, skipping test") return - # Constants from kernel definition - num_attention_heads = 32 - num_key_value_heads = 8 - head_dim = 128 - page_size = 64 - # Generate inputs - inputs = generate_random_inputs( - batch_size, - max_seq_len, - num_attention_heads, - num_key_value_heads, - head_dim, - page_size, - device, - ) + inputs = generate_random_inputs(batch_size, max_seq_len, device) print(f"Generated sequences with lengths: {inputs['seq_lens'].cpu().numpy()}") print(f"Last page lengths: {inputs['kv_last_page_len'].cpu().numpy()}") print(f"Total pages used: {inputs['kv_indices'].shape[0]}") - # Run reference implementation - print("\nRunning reference implementation...") + # Run reference implementation from definition (page_size=64 includes kv_last_page_len) + print("\nRunning reference implementation from definition...") ref_o, ref_lse = run( inputs["q"], inputs["k_cache"], @@ -211,10 +110,10 @@ def test_correctness(batch_size=4, max_seq_len=256, atol=1e-2, rtol=5e-2): indptr=inputs["kv_indptr"], indices=inputs["kv_indices"], last_page_len=inputs["kv_last_page_len"], - num_qo_heads=num_attention_heads, - num_kv_heads=num_key_value_heads, - head_dim=head_dim, - page_size=page_size, + num_qo_heads=NUM_QO_HEADS, + num_kv_heads=NUM_KV_HEADS, + head_dim=HEAD_DIM, + page_size=PAGE_SIZE, pos_encoding_mode="NONE", q_data_type=torch.bfloat16, kv_data_type=torch.bfloat16, @@ -229,99 +128,25 @@ def test_correctness(batch_size=4, max_seq_len=256, atol=1e-2, rtol=5e-2): # Compare outputs print("\nComparing outputs...") + output_metrics = compare_tensors(ref_o, fi_output, atol=atol, rtol=rtol) + print_comparison_metrics(output_metrics, tensor_name="Output tensor") - ref_o_f32 = ref_o.float() - fi_output_f32 = fi_output.float() - - abs_diff = torch.abs(ref_o_f32 - fi_output_f32) - rel_diff = abs_diff / (torch.abs(fi_output_f32) + 1e-8) + lse_metrics = compare_tensors(ref_lse, fi_lse, atol=atol, rtol=rtol) + print_comparison_metrics(lse_metrics, tensor_name="LSE tensor") - max_abs_diff = abs_diff.max().item() - max_rel_diff = rel_diff.max().item() - mean_abs_diff = abs_diff.mean().item() - mean_rel_diff = 
rel_diff.mean().item() - - print(f"\nOutput tensor comparison:") - print(f"Max absolute difference: {max_abs_diff:.6e}") - print(f"Max relative difference: {max_rel_diff:.6e}") - print(f"Mean absolute difference: {mean_abs_diff:.6e}") - print(f"Mean relative difference: {mean_rel_diff:.6e}") - - cos_sim = torch.nn.functional.cosine_similarity( - ref_o_f32.flatten(), fi_output_f32.flatten(), dim=0 - ).item() - mse = torch.mean((ref_o_f32 - fi_output_f32) ** 2).item() - print(f"Cosine similarity: {cos_sim:.6f}") - print(f"MSE: {mse:.6e}") - - lse_abs_diff = torch.abs(ref_lse - fi_lse) - lse_rel_diff = lse_abs_diff / (torch.abs(fi_lse) + 1e-8) - - lse_max_abs_diff = lse_abs_diff.max().item() - lse_max_rel_diff = lse_rel_diff.max().item() - lse_mean_abs_diff = lse_abs_diff.mean().item() - lse_mean_rel_diff = lse_rel_diff.mean().item() - - print(f"\nLSE comparison:") - print(f"Max absolute difference: {lse_max_abs_diff:.6e}") - print(f"Max relative difference: {lse_max_rel_diff:.6e}") - print(f"Mean absolute difference: {lse_mean_abs_diff:.6e}") - print(f"Mean relative difference: {lse_mean_rel_diff:.6e}") - - output_close = torch.allclose(ref_o_f32, fi_output_f32, atol=atol, rtol=rtol) - lse_close = torch.allclose(ref_lse, fi_lse, atol=atol, rtol=rtol) - all_close = output_close and lse_close + all_close = output_metrics.all_close and lse_metrics.all_close if all_close: print(f"\n✓ PASSED: Outputs and LSE match within tolerance (atol={atol}, rtol={rtol})") else: print(f"\n✗ FAILED: Outputs differ beyond tolerance (atol={atol}, rtol={rtol})") - if not output_close: - flat_abs_diff = abs_diff.flatten() - top_k = min(5, flat_abs_diff.numel()) - top_errors, top_indices = torch.topk(flat_abs_diff, top_k) - - print(f"\nTop {top_k} output tensor error locations:") - for i in range(top_k): - idx = top_indices[i].item() - batch_idx = idx // (num_attention_heads * head_dim) - head_idx = (idx % (num_attention_heads * head_dim)) // head_dim - dim_idx = idx % head_dim - - ref_val = ref_o_f32.flatten()[idx].item() - fi_val = fi_output_f32.flatten()[idx].item() - - print( - f" [{batch_idx}, {head_idx}, {dim_idx}]: " - f"ref={ref_val:.6f}, fi={fi_val:.6f}, diff={top_errors[i].item():.6e}" - ) - - if not lse_close: - flat_lse_diff = lse_abs_diff.flatten() - top_k = min(5, flat_lse_diff.numel()) - top_lse_errors, top_lse_indices = torch.topk(flat_lse_diff, top_k) - - print(f"\nTop {top_k} LSE error locations:") - for i in range(top_k): - idx = top_lse_indices[i].item() - batch_idx = idx // num_attention_heads - head_idx = idx % num_attention_heads - - ref_val = ref_lse.flatten()[idx].item() - fi_val = fi_lse.flatten()[idx].item() - - print( - f" [{batch_idx}, {head_idx}]: " - f"ref={ref_val:.6f}, fi={fi_val:.6f}, diff={top_lse_errors[i].item():.6e}" - ) - return all_close def main(): """Run comprehensive tests.""" - print("Testing Batch GQA Paged Decode Reference Implementation (page_size=64)") + print("Testing Batch GQA Paged Decode Reference Implementation (page_size=64, from definition)") test_configs = [(1, 64), (4, 128), (8, 256), (16, 512)] diff --git a/flashinfer_trace/tests/references/test_gqa_paged_prefill_h32_kv4_d128_ps1.py b/flashinfer_trace/tests/references/test_gqa_paged_prefill_h32_kv4_d128_ps1.py index 6755a9b5..848fad3c 100644 --- a/flashinfer_trace/tests/references/test_gqa_paged_prefill_h32_kv4_d128_ps1.py +++ b/flashinfer_trace/tests/references/test_gqa_paged_prefill_h32_kv4_d128_ps1.py @@ -1,137 +1,48 @@ +""" +Test GQA paged prefill h32_kv4_d128_ps1 reference implementation against 
FlashInfer. + +This test validates that the reference implementation from the definition +matches the FlashInfer kernel implementation. +""" + import math import flashinfer import torch +from test_utils import compare_tensors, get_reference_run, print_comparison_metrics +# Load reference implementation from definition +run = get_reference_run("gqa_paged_prefill_causal_h32_kv4_d128_ps1") -@torch.no_grad() -def run(q, k_cache, v_cache, qo_indptr, kv_indptr, kv_indices, sm_scale): - total_q, num_qo_heads, head_dim = q.shape - num_pages, page_size, num_kv_heads, _ = k_cache.shape - len_indptr = qo_indptr.shape[0] - num_kv_indices = kv_indices.shape[0] - - # Check constants - assert num_qo_heads == 32 - assert num_kv_heads == 4 - assert head_dim == 128 - assert page_size == 1 - - # Check constraints - assert total_q == qo_indptr[-1].item() - assert num_kv_indices == kv_indptr[-1].item() - - device = q.device - - output = torch.zeros((total_q, num_qo_heads, head_dim), dtype=torch.bfloat16, device=device) - lse = torch.full((total_q, num_qo_heads), -float("inf"), dtype=torch.float32, device=device) - - gqa_ratio = num_qo_heads // num_kv_heads - - q_f32 = q.to(torch.float32) - # Flatten page dimension since page_size=1 - k_cache_flat = k_cache.squeeze(1).to(torch.float32) # [num_pages, num_kv_heads, head_dim] - v_cache_flat = v_cache.squeeze(1).to(torch.float32) # [num_pages, num_kv_heads, head_dim] - - for b in range(len_indptr - 1): - q_start = int(qo_indptr[b].item()) - q_end = int(qo_indptr[b + 1].item()) - - kv_start = int(kv_indptr[b].item()) - kv_end = int(kv_indptr[b + 1].item()) - - if q_start >= q_end or kv_start >= kv_end: - # No queries or KV for this batch element - continue - - page_ids = kv_indices[kv_start:kv_end].to(torch.long) - - # Number of KV tokens is equal to number of pages for page_size=1 - num_kv_tokens = page_ids.shape[0] - k_batch = k_cache_flat[page_ids] # [num_kv_tokens, num_kv_heads, head_dim] - v_batch = v_cache_flat[page_ids] # [num_kv_tokens, num_kv_heads, head_dim] - - # Get queries for this sequence - q_batch = q_f32[q_start:q_end] # [num_q_tokens, num_qo_heads, head_dim] - num_q_tokens = q_batch.shape[0] - - # Delta for causal masking - delta = num_kv_tokens - num_q_tokens - - for q_idx in range(num_q_tokens): - global_q_idx = q_start + q_idx - - # Apply causal mask - max_kv_idx = min(q_idx + 1 + delta, num_kv_tokens) - if max_kv_idx <= 0: - continue - - q_pos = q_batch[q_idx] # [num_qo_heads, head_dim] - - for h in range(num_qo_heads): - # Find corresponding KV head for GQA - kv_head = h // gqa_ratio - - q_head = q_pos[h] # [head_dim] - k_head = k_batch[:max_kv_idx, kv_head] # [max_kv_idx, head_dim] - v_head = v_batch[:max_kv_idx, kv_head] # [max_kv_idx, head_dim] - - logits = torch.matmul(q_head, k_head.T) # [max_kv_idx] - logits_scaled = logits * sm_scale - - # Compute 2-base LSE - lse[global_q_idx, h] = torch.logsumexp(logits_scaled, dim=-1) / math.log(2.0) - - attn = torch.softmax(logits_scaled, dim=-1) # [max_kv_idx] - out_head = torch.matmul(attn, v_head) # [head_dim] - output[global_q_idx, h] = out_head.to(torch.bfloat16) - - return output, lse +# Constants from definition +NUM_QO_HEADS = 32 +NUM_KV_HEADS = 4 +HEAD_DIM = 128 +PAGE_SIZE = 1 def generate_random_inputs( - batch_size, - max_q_len, - max_kv_len, - max_pages, - num_attention_heads=32, - num_key_value_heads=4, - head_dim=128, - page_size=1, - causal=True, - device="cuda", + batch_size, max_q_len, max_kv_len, max_pages, causal=True, device="cuda" ): """Generate random inputs for paged prefill 
testing.""" - - # Generate random query lengths for each batch element q_lens = torch.randint(1, max_q_len + 1, (batch_size,), dtype=torch.int32) - - # Generate random KV lengths for each batch element - # For prefill, KV length is typically >= query length (includes previous context) kv_lens = torch.zeros(batch_size, dtype=torch.int32) for i in range(batch_size): - # KV length should be at least as long as query length for causal attention if causal: kv_lens[i] = torch.randint(q_lens[i].item(), max_kv_len + 1, (1,)).item() else: kv_lens[i] = torch.randint(1, max_kv_len + 1, (1,)).item() - # Create indptr arrays qo_indptr = torch.zeros(batch_size + 1, dtype=torch.int32, device=device) qo_indptr[1:] = torch.cumsum(q_lens.to(device), dim=0) kv_indptr = torch.zeros(batch_size + 1, dtype=torch.int32, device=device) kv_indptr[1:] = torch.cumsum(kv_lens.to(device), dim=0) - # Get total tokens total_q = qo_indptr[-1].item() num_kv_indices = kv_indptr[-1].item() - # Generate page indices (for page_size=1, we need num_kv_indices unique pages) - # Simulate scattered memory allocation all_page_ids = torch.randperm(max_pages, device=device)[:num_kv_indices] - - # Create kv_indices by assigning pages to each sequence kv_indices = torch.zeros(num_kv_indices, dtype=torch.int32, device=device) idx = 0 for i in range(batch_size): @@ -139,25 +50,14 @@ def generate_random_inputs( kv_indices[idx : idx + seq_len] = all_page_ids[idx : idx + seq_len] idx += seq_len - # Generate KV cache (paged storage) k_cache = torch.randn( - max_pages, page_size, num_key_value_heads, head_dim, dtype=torch.bfloat16, device=device + max_pages, PAGE_SIZE, NUM_KV_HEADS, HEAD_DIM, dtype=torch.bfloat16, device=device ) v_cache = torch.randn( - max_pages, page_size, num_key_value_heads, head_dim, dtype=torch.bfloat16, device=device + max_pages, PAGE_SIZE, NUM_KV_HEADS, HEAD_DIM, dtype=torch.bfloat16, device=device ) - - # Generate query tensor - q = torch.randn(total_q, num_attention_heads, head_dim, dtype=torch.bfloat16, device=device) - - # Generate attention parameters - sm_scale = 1.0 / math.sqrt(head_dim) - sm_scale = torch.tensor(sm_scale, dtype=torch.float32, device=device) - - # Convert causal to tensor - causal = torch.tensor(causal, dtype=torch.bool, device=device) - - # For page_size=1, last_page_len is always all ones + q = torch.randn(total_q, NUM_QO_HEADS, HEAD_DIM, dtype=torch.bfloat16, device=device) + sm_scale = torch.tensor(1.0 / math.sqrt(HEAD_DIM), dtype=torch.float32, device=device) last_page_len = torch.ones(batch_size, dtype=torch.int32, device=device) return { @@ -174,7 +74,6 @@ def generate_random_inputs( "num_kv_indices": num_kv_indices, "sm_scale": sm_scale, "causal": causal, - "page_size": page_size, } @@ -182,7 +81,7 @@ def test_correctness(batch_size=4, max_q_len=32, max_kv_len=64, causal=True, ato """Test correctness of paged prefill reference implementation against FlashInfer.""" print(f"\n{'='*60}") print( - f"Testing GQA Paged Prefill batch_size={batch_size}, max_q_len={max_q_len}, max_kv_len={max_kv_len}, causal={causal}" + f"Testing GQA Paged Prefill batch_size={batch_size}, max_q_len={max_q_len}, max_kv_len={max_kv_len}" ) print(f"{'='*60}") @@ -191,39 +90,13 @@ def test_correctness(batch_size=4, max_q_len=32, max_kv_len=64, causal=True, ato print("WARNING: CUDA not available, skipping test") return - # Constants from kernel definition - num_attention_heads = 32 - num_key_value_heads = 4 - head_dim = 128 - page_size = 1 - - # Maximum number of pages (should be large enough to hold all KV tokens) - 
max_pages = max_kv_len * batch_size * 2 # Extra buffer for scattered allocation - - # Generate inputs - inputs = generate_random_inputs( - batch_size, - max_q_len, - max_kv_len, - max_pages, - num_attention_heads, - num_key_value_heads, - head_dim, - page_size, - causal, - device, - ) + max_pages = max_kv_len * batch_size * 2 + inputs = generate_random_inputs(batch_size, max_q_len, max_kv_len, max_pages, causal, device) print(f"Generated query lengths: {inputs['q_lens'].cpu().numpy()}") print(f"Generated KV lengths: {inputs['kv_lens'].cpu().numpy()}") - print(f"Total query tokens: {inputs['total_q']}") - print(f"Total KV indices: {inputs['num_kv_indices']}") - print(f"Max page ID used: {inputs['kv_indices'].max().item()}") - print(f"Causal mode: {inputs['causal'].item()}") - print(f"Page size: {inputs['page_size']}") - # Run reference implementation - print("\nRunning reference implementation...") + print("\nRunning reference implementation from definition...") ref_o, ref_lse = run( inputs["q"], inputs["k_cache"], @@ -234,177 +107,54 @@ def test_correctness(batch_size=4, max_q_len=32, max_kv_len=64, causal=True, ato inputs["sm_scale"], ) - # Setup FlashInfer print("\nSetting up FlashInfer...") workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device=device) - prefill_wrapper = flashinfer.prefill.BatchPrefillWithPagedKVCacheWrapper( - workspace_buffer, kv_layout="NHD" # Layout for K/V tensors + workspace_buffer, kv_layout="NHD" ) - - # Combine k_cache and v_cache into paged_kv_cache format that FlashInfer expects - # FlashInfer expects shape [max_num_pages, 2, page_size, num_kv_heads, head_dim] for NHD layout paged_kv_cache = torch.stack([inputs["k_cache"], inputs["v_cache"]], dim=1) - # Plan the attention computation prefill_wrapper.plan( qo_indptr=inputs["qo_indptr"], paged_kv_indptr=inputs["kv_indptr"], paged_kv_indices=inputs["kv_indices"], paged_kv_last_page_len=inputs["last_page_len"], - num_qo_heads=num_attention_heads, - num_kv_heads=num_key_value_heads, - head_dim_qk=head_dim, - head_dim_vo=head_dim, - page_size=page_size, - causal=inputs["causal"].item(), + num_qo_heads=NUM_QO_HEADS, + num_kv_heads=NUM_KV_HEADS, + head_dim_qk=HEAD_DIM, + head_dim_vo=HEAD_DIM, + page_size=PAGE_SIZE, + causal=inputs["causal"], sm_scale=inputs["sm_scale"].item(), q_data_type=torch.bfloat16, kv_data_type=torch.bfloat16, ) - # Run FlashInfer print("Running FlashInfer...") fi_output, fi_lse = prefill_wrapper.run(inputs["q"], paged_kv_cache, return_lse=True) - # Compare outputs print("\nComparing outputs...") + output_metrics = compare_tensors(ref_o, fi_output, atol=atol, rtol=rtol) + print_comparison_metrics(output_metrics, tensor_name="Output tensor") - # Convert to float32 for comparison - ref_o_f32 = ref_o.float() - fi_output_f32 = fi_output.float() - - # Compute errors for output tensor - abs_diff = torch.abs(ref_o_f32 - fi_output_f32) - rel_diff = abs_diff / (torch.abs(fi_output_f32) + 1e-8) - - max_abs_diff = abs_diff.max().item() - max_rel_diff = rel_diff.max().item() - mean_abs_diff = abs_diff.mean().item() - mean_rel_diff = rel_diff.mean().item() + lse_metrics = compare_tensors(ref_lse, fi_lse, atol=atol, rtol=rtol) + print_comparison_metrics(lse_metrics, tensor_name="LSE tensor") - print(f"\nOutput tensor comparison:") - print(f"Max absolute difference: {max_abs_diff:.6e}") - print(f"Max relative difference: {max_rel_diff:.6e}") - print(f"Mean absolute difference: {mean_abs_diff:.6e}") - print(f"Mean relative difference: {mean_rel_diff:.6e}") - - # Compute cosine similarity 
and MSE for output tensor - cos_sim = torch.nn.functional.cosine_similarity( - ref_o_f32.flatten(), fi_output_f32.flatten(), dim=0 - ).item() - mse = torch.mean((ref_o_f32 - fi_output_f32) ** 2).item() - print(f"Cosine similarity: {cos_sim:.6f}") - print(f"MSE: {mse:.6e}") - - # Compare LSE values - lse_abs_diff = torch.abs(ref_lse - fi_lse) - lse_rel_diff = lse_abs_diff / (torch.abs(fi_lse) + 1e-8) - - lse_max_abs_diff = lse_abs_diff.max().item() - lse_max_rel_diff = lse_rel_diff.max().item() - lse_mean_abs_diff = lse_abs_diff.mean().item() - lse_mean_rel_diff = lse_rel_diff.mean().item() - - print(f"\nLSE comparison:") - print(f"Max absolute difference: {lse_max_abs_diff:.6e}") - print(f"Max relative difference: {lse_max_rel_diff:.6e}") - print(f"Mean absolute difference: {lse_mean_abs_diff:.6e}") - print(f"Mean relative difference: {lse_mean_rel_diff:.6e}") - - # Check if outputs match within tolerance - output_close = torch.allclose(ref_o_f32, fi_output_f32, atol=atol, rtol=rtol) - lse_close = torch.allclose(ref_lse, fi_lse, atol=atol, rtol=rtol) - all_close = output_close and lse_close + all_close = output_metrics.all_close and lse_metrics.all_close if all_close: - print(f"\n✓ PASSED: Outputs and LSE match within tolerance (atol={atol}, rtol={rtol})") + print(f"\n✓ PASSED: Outputs match within tolerance (atol={atol}, rtol={rtol})") else: - print(f"\n✗ FAILED: Outputs differ beyond tolerance (atol={atol}, rtol={rtol})") - - if not output_close: - # Find indices with largest errors for debugging - flat_abs_diff = abs_diff.flatten() - top_k = min(5, flat_abs_diff.numel()) - top_errors, top_indices = torch.topk(flat_abs_diff, top_k) - - print(f"\nTop {top_k} output tensor error locations:") - for i in range(top_k): - idx = top_indices[i].item() - # Convert flat index back to 3D indices - q_idx = idx // (num_attention_heads * head_dim) - head_idx = (idx % (num_attention_heads * head_dim)) // head_dim - dim_idx = idx % head_dim - - ref_val = ref_o_f32.flatten()[idx].item() - fi_val = fi_output_f32.flatten()[idx].item() - - print( - f" [q_idx={q_idx}, head={head_idx}, dim={dim_idx}]: " - f"ref={ref_val:.6f}, fi={fi_val:.6f}, diff={top_errors[i].item():.6e}" - ) - - if not lse_close: - # Find LSE errors - flat_lse_diff = lse_abs_diff.flatten() - top_k = min(5, flat_lse_diff.numel()) - top_lse_errors, top_lse_indices = torch.topk(flat_lse_diff, top_k) - - print(f"\nTop {top_k} LSE error locations:") - for i in range(top_k): - idx = top_lse_indices[i].item() - q_idx = idx // num_attention_heads - head_idx = idx % num_attention_heads - - ref_val = ref_lse.flatten()[idx].item() - fi_val = fi_lse.flatten()[idx].item() - - print( - f" [q_idx={q_idx}, head={head_idx}]: " - f"ref={ref_val:.6f}, fi={fi_val:.6f}, diff={top_lse_errors[i].item():.6e}" - ) + print(f"\n✗ FAILED: Outputs differ beyond tolerance") return all_close def main(): - """Run comprehensive tests.""" - print("Testing Batch GQA Paged Prefill Reference Implementation") - - # Test different configurations - test_configs = [ - # (batch_size, max_q_len, max_kv_len, causal) - (1, 8, 16, True), # Single batch, small, causal - # (1, 8, 16, False), # Single batch, small, non-causal - (4, 16, 32, True), # Small batch, causal - # (4, 16, 32, False), # Small batch, non-causal - (8, 32, 64, True), # Medium batch, causal - # (8, 32, 64, False), # Medium batch, non-causal - (16, 64, 128, True), # Large batch, causal - # (16, 64, 128, False), # Large batch, non-causal - ] - - passed = 0 - total = len(test_configs) - - for batch_size, max_q_len, 
max_kv_len, causal in test_configs: - try: - if test_correctness(batch_size, max_q_len, max_kv_len, causal): - passed += 1 - except Exception as e: - print(f"✗ Test failed with exception: {str(e)}") - import traceback - - traceback.print_exc() - - print(f"\n{'='*60}") - print(f"Summary: {passed}/{total} tests passed") - print(f"{'='*60}") - - if passed == total: - print("✓ All tests passed!") - else: - print(f"✗ {total - passed} tests failed") + print("Testing Batch GQA Paged Prefill Reference Implementation (from definition)") + test_configs = [(1, 8, 16, True), (4, 16, 32, True), (8, 32, 64, True), (16, 64, 128, True)] + passed = sum(1 for cfg in test_configs if test_correctness(*cfg)) + print(f"\n{'='*60}\nSummary: {passed}/{len(test_configs)} tests passed\n{'='*60}") if __name__ == "__main__": diff --git a/flashinfer_trace/tests/references/test_gqa_paged_prefill_h32_kv4_d128_ps64.py b/flashinfer_trace/tests/references/test_gqa_paged_prefill_h32_kv4_d128_ps64.py index 409f6301..3aa26c10 100644 --- a/flashinfer_trace/tests/references/test_gqa_paged_prefill_h32_kv4_d128_ps64.py +++ b/flashinfer_trace/tests/references/test_gqa_paged_prefill_h32_kv4_d128_ps64.py @@ -1,131 +1,31 @@ +""" +Test GQA paged prefill h32_kv4_d128_ps64 reference implementation against FlashInfer. + +This test validates that the reference implementation from the definition +matches the FlashInfer kernel implementation. +""" + import math import flashinfer import torch +from test_utils import compare_tensors, get_reference_run, print_comparison_metrics +# Load reference implementation from definition +run = get_reference_run("gqa_paged_prefill_causal_h32_kv4_d128_ps64") -@torch.no_grad() -def run(q, k_cache, v_cache, qo_indptr, kv_indptr, kv_indices, kv_last_page_len, sm_scale): - total_q, num_qo_heads, head_dim = q.shape - num_pages, page_size, num_kv_heads, _ = k_cache.shape - len_indptr = qo_indptr.shape[0] - num_kv_indices = kv_indices.shape[0] - - # Check constants - assert num_qo_heads == 32 - assert num_kv_heads == 4 - assert head_dim == 128 - assert page_size == 64 - - # Check constraints - assert total_q == qo_indptr[-1].item() - - device = q.device - batch_size = len_indptr - 1 - - output = torch.zeros((total_q, num_qo_heads, head_dim), dtype=torch.bfloat16, device=device) - lse = torch.full((total_q, num_qo_heads), -float("inf"), dtype=torch.float32, device=device) - - gqa_ratio = num_qo_heads // num_kv_heads - - q_f32 = q.to(torch.float32) - k_cache_f32 = k_cache.to(torch.float32) - v_cache_f32 = v_cache.to(torch.float32) - - for b in range(batch_size): - q_start = int(qo_indptr[b].item()) - q_end = int(qo_indptr[b + 1].item()) - - kv_start = int(kv_indptr[b].item()) - kv_end = int(kv_indptr[b + 1].item()) - last_page_len = int(kv_last_page_len[b].item()) - - if q_start >= q_end or kv_start >= kv_end: - continue - - page_ids = kv_indices[kv_start:kv_end].to(torch.long) - num_pages_for_seq = page_ids.shape[0] - - # Calculate total KV tokens - num_full_pages = num_pages_for_seq - 1 - num_kv_tokens = num_full_pages * page_size + last_page_len - - # Gather K and V from pages - k_batch = torch.zeros( - (num_kv_tokens, num_kv_heads, head_dim), dtype=torch.float32, device=device - ) - v_batch = torch.zeros( - (num_kv_tokens, num_kv_heads, head_dim), dtype=torch.float32, device=device - ) - - token_idx = 0 - for p_idx, page_id in enumerate(page_ids): - if p_idx < num_full_pages: - k_batch[token_idx : token_idx + page_size] = k_cache_f32[page_id] - v_batch[token_idx : token_idx + page_size] = v_cache_f32[page_id] - 
token_idx += page_size - else: - k_batch[token_idx : token_idx + last_page_len] = k_cache_f32[ - page_id, :last_page_len - ] - v_batch[token_idx : token_idx + last_page_len] = v_cache_f32[ - page_id, :last_page_len - ] - token_idx += last_page_len - - q_batch = q_f32[q_start:q_end] - num_q_tokens = q_batch.shape[0] - - # Delta for causal masking - delta = num_kv_tokens - num_q_tokens - - for q_idx in range(num_q_tokens): - global_q_idx = q_start + q_idx - - # Apply causal mask - max_kv_idx = min(q_idx + 1 + delta, num_kv_tokens) - if max_kv_idx <= 0: - continue - - q_pos = q_batch[q_idx] - - for h in range(num_qo_heads): - kv_head = h // gqa_ratio - - q_head = q_pos[h] - k_head = k_batch[:max_kv_idx, kv_head] - v_head = v_batch[:max_kv_idx, kv_head] - - logits = torch.matmul(q_head, k_head.T) - logits_scaled = logits * sm_scale - - lse[global_q_idx, h] = torch.logsumexp(logits_scaled, dim=-1) / math.log(2.0) - - attn = torch.softmax(logits_scaled, dim=-1) - out_head = torch.matmul(attn, v_head) - output[global_q_idx, h] = out_head.to(torch.bfloat16) - - return output, lse +# Constants from definition +NUM_QO_HEADS = 32 +NUM_KV_HEADS = 4 +HEAD_DIM = 128 +PAGE_SIZE = 64 def generate_random_inputs( - batch_size, - max_q_len, - max_kv_len, - max_pages, - num_attention_heads=32, - num_key_value_heads=4, - head_dim=128, - page_size=64, - causal=True, - device="cuda", + batch_size, max_q_len, max_kv_len, max_pages, causal=True, device="cuda" ): - """Generate random inputs for paged prefill testing.""" - - # Generate random query lengths for each batch element + """Generate random inputs for paged prefill testing with page_size=64.""" q_lens = torch.randint(1, max_q_len + 1, (batch_size,), dtype=torch.int32) - - # Generate random KV lengths for each batch element kv_lens = torch.zeros(batch_size, dtype=torch.int32) for i in range(batch_size): if causal: @@ -133,45 +33,37 @@ def generate_random_inputs( else: kv_lens[i] = torch.randint(1, max_kv_len + 1, (1,)).item() - # Create qo_indptr qo_indptr = torch.zeros(batch_size + 1, dtype=torch.int32, device=device) qo_indptr[1:] = torch.cumsum(q_lens.to(device), dim=0) # Calculate pages needed for each sequence - pages_per_seq = (kv_lens + page_size - 1) // page_size # Ceiling division - total_pages_needed = pages_per_seq.sum().item() - - # Create kv_indptr based on pages per sequence + pages_per_seq = (kv_lens + PAGE_SIZE - 1) // PAGE_SIZE kv_indptr = torch.zeros(batch_size + 1, dtype=torch.int32, device=device) kv_indptr[1:] = torch.cumsum(pages_per_seq.to(device), dim=0) - # Generate page indices - kv_indices = torch.arange(total_pages_needed, dtype=torch.int32, device=device) + total_q = qo_indptr[-1].item() + num_kv_pages = kv_indptr[-1].item() - # Calculate last_page_len for each sequence - kv_last_page_len = ((kv_lens - 1) % page_size) + 1 - kv_last_page_len = kv_last_page_len.to(device) + all_page_ids = torch.randperm(max_pages, device=device)[:num_kv_pages] + kv_indices = torch.zeros(num_kv_pages, dtype=torch.int32, device=device) + idx = 0 + for i in range(batch_size): + num_pages = pages_per_seq[i].item() + kv_indices[idx : idx + num_pages] = all_page_ids[idx : idx + num_pages] + idx += num_pages - # Get total tokens - total_q = qo_indptr[-1].item() + # Calculate last_page_len for each sequence + last_page_len = ((kv_lens - 1) % PAGE_SIZE) + 1 + last_page_len = last_page_len.to(torch.int32).to(device) - # Generate KV cache (paged storage) k_cache = torch.randn( - max_pages, page_size, num_key_value_heads, head_dim, dtype=torch.bfloat16, 
device=device + max_pages, PAGE_SIZE, NUM_KV_HEADS, HEAD_DIM, dtype=torch.bfloat16, device=device ) v_cache = torch.randn( - max_pages, page_size, num_key_value_heads, head_dim, dtype=torch.bfloat16, device=device + max_pages, PAGE_SIZE, NUM_KV_HEADS, HEAD_DIM, dtype=torch.bfloat16, device=device ) - - # Generate query tensor - q = torch.randn(total_q, num_attention_heads, head_dim, dtype=torch.bfloat16, device=device) - - # Generate attention parameters - sm_scale = 1.0 / math.sqrt(head_dim) - sm_scale = torch.tensor(sm_scale, dtype=torch.float32, device=device) - - # Convert causal to tensor - causal = torch.tensor(causal, dtype=torch.bool, device=device) + q = torch.randn(total_q, NUM_QO_HEADS, HEAD_DIM, dtype=torch.bfloat16, device=device) + sm_scale = torch.tensor(1.0 / math.sqrt(HEAD_DIM), dtype=torch.float32, device=device) return { "q": q, @@ -180,21 +72,20 @@ def generate_random_inputs( "qo_indptr": qo_indptr, "kv_indptr": kv_indptr, "kv_indices": kv_indices, - "kv_last_page_len": kv_last_page_len, + "last_page_len": last_page_len, "q_lens": q_lens, "kv_lens": kv_lens, "total_q": total_q, "sm_scale": sm_scale, "causal": causal, - "page_size": page_size, } -def test_correctness(batch_size=4, max_q_len=32, max_kv_len=128, causal=True, atol=1e-2, rtol=5e-2): +def test_correctness(batch_size=4, max_q_len=32, max_kv_len=256, causal=True, atol=1e-2, rtol=5e-2): """Test correctness of paged prefill reference implementation against FlashInfer.""" print(f"\n{'='*60}") print( - f"Testing GQA Paged Prefill batch_size={batch_size}, max_q_len={max_q_len}, max_kv_len={max_kv_len}, causal={causal}" + f"Testing GQA Paged Prefill (ps64) batch_size={batch_size}, max_q_len={max_q_len}, max_kv_len={max_kv_len}" ) print(f"{'='*60}") @@ -203,39 +94,14 @@ def test_correctness(batch_size=4, max_q_len=32, max_kv_len=128, causal=True, at print("WARNING: CUDA not available, skipping test") return - # Constants from kernel definition - num_attention_heads = 32 - num_key_value_heads = 4 - head_dim = 128 - page_size = 64 - - # Maximum number of pages - max_pages = (max_kv_len * batch_size * 2 + page_size - 1) // page_size + 100 - - # Generate inputs - inputs = generate_random_inputs( - batch_size, - max_q_len, - max_kv_len, - max_pages, - num_attention_heads, - num_key_value_heads, - head_dim, - page_size, - causal, - device, - ) + max_pages = (max_kv_len * batch_size * 2) // PAGE_SIZE + 100 + inputs = generate_random_inputs(batch_size, max_q_len, max_kv_len, max_pages, causal, device) print(f"Generated query lengths: {inputs['q_lens'].cpu().numpy()}") print(f"Generated KV lengths: {inputs['kv_lens'].cpu().numpy()}") - print(f"Last page lengths: {inputs['kv_last_page_len'].cpu().numpy()}") - print(f"Total query tokens: {inputs['total_q']}") - print(f"Total pages: {inputs['kv_indices'].shape[0]}") - print(f"Causal mode: {inputs['causal'].item()}") - print(f"Page size: {inputs['page_size']}") - - # Run reference implementation - print("\nRunning reference implementation...") + print(f"Last page lengths: {inputs['last_page_len'].cpu().numpy()}") + + print("\nRunning reference implementation from definition...") ref_o, ref_lse = run( inputs["q"], inputs["k_cache"], @@ -243,161 +109,60 @@ def test_correctness(batch_size=4, max_q_len=32, max_kv_len=128, causal=True, at inputs["qo_indptr"], inputs["kv_indptr"], inputs["kv_indices"], - inputs["kv_last_page_len"], + inputs["last_page_len"], inputs["sm_scale"], ) - # Setup FlashInfer print("\nSetting up FlashInfer...") workspace_buffer = torch.empty(128 * 1024 * 1024, 
dtype=torch.uint8, device=device) - prefill_wrapper = flashinfer.prefill.BatchPrefillWithPagedKVCacheWrapper( workspace_buffer, kv_layout="NHD" ) - - # Combine k_cache and v_cache into paged_kv_cache format paged_kv_cache = torch.stack([inputs["k_cache"], inputs["v_cache"]], dim=1) - # Plan the attention computation prefill_wrapper.plan( qo_indptr=inputs["qo_indptr"], paged_kv_indptr=inputs["kv_indptr"], paged_kv_indices=inputs["kv_indices"], - paged_kv_last_page_len=inputs["kv_last_page_len"], - num_qo_heads=num_attention_heads, - num_kv_heads=num_key_value_heads, - head_dim_qk=head_dim, - head_dim_vo=head_dim, - page_size=page_size, - causal=inputs["causal"].item(), + paged_kv_last_page_len=inputs["last_page_len"], + num_qo_heads=NUM_QO_HEADS, + num_kv_heads=NUM_KV_HEADS, + head_dim_qk=HEAD_DIM, + head_dim_vo=HEAD_DIM, + page_size=PAGE_SIZE, + causal=inputs["causal"], sm_scale=inputs["sm_scale"].item(), q_data_type=torch.bfloat16, kv_data_type=torch.bfloat16, ) - # Run FlashInfer print("Running FlashInfer...") fi_output, fi_lse = prefill_wrapper.run(inputs["q"], paged_kv_cache, return_lse=True) - # Compare outputs print("\nComparing outputs...") + output_metrics = compare_tensors(ref_o, fi_output, atol=atol, rtol=rtol) + print_comparison_metrics(output_metrics, tensor_name="Output tensor") - ref_o_f32 = ref_o.float() - fi_output_f32 = fi_output.float() - - abs_diff = torch.abs(ref_o_f32 - fi_output_f32) - rel_diff = abs_diff / (torch.abs(fi_output_f32) + 1e-8) - - max_abs_diff = abs_diff.max().item() - max_rel_diff = rel_diff.max().item() - mean_abs_diff = abs_diff.mean().item() - mean_rel_diff = rel_diff.mean().item() - - print(f"\nOutput tensor comparison:") - print(f"Max absolute difference: {max_abs_diff:.6e}") - print(f"Max relative difference: {max_rel_diff:.6e}") - print(f"Mean absolute difference: {mean_abs_diff:.6e}") - print(f"Mean relative difference: {mean_rel_diff:.6e}") - - cos_sim = torch.nn.functional.cosine_similarity( - ref_o_f32.flatten(), fi_output_f32.flatten(), dim=0 - ).item() - mse = torch.mean((ref_o_f32 - fi_output_f32) ** 2).item() - print(f"Cosine similarity: {cos_sim:.6f}") - print(f"MSE: {mse:.6e}") + lse_metrics = compare_tensors(ref_lse, fi_lse, atol=atol, rtol=rtol) + print_comparison_metrics(lse_metrics, tensor_name="LSE tensor") - lse_abs_diff = torch.abs(ref_lse - fi_lse) - lse_rel_diff = lse_abs_diff / (torch.abs(fi_lse) + 1e-8) - - lse_max_abs_diff = lse_abs_diff.max().item() - lse_max_rel_diff = lse_rel_diff.max().item() - lse_mean_abs_diff = lse_abs_diff.mean().item() - lse_mean_rel_diff = lse_rel_diff.mean().item() - - print(f"\nLSE comparison:") - print(f"Max absolute difference: {lse_max_abs_diff:.6e}") - print(f"Max relative difference: {lse_max_rel_diff:.6e}") - print(f"Mean absolute difference: {lse_mean_abs_diff:.6e}") - print(f"Mean relative difference: {lse_mean_rel_diff:.6e}") - - output_close = torch.allclose(ref_o_f32, fi_output_f32, atol=atol, rtol=rtol) - lse_close = torch.allclose(ref_lse, fi_lse, atol=atol, rtol=rtol) - all_close = output_close and lse_close + all_close = output_metrics.all_close and lse_metrics.all_close if all_close: - print(f"\n✓ PASSED: Outputs and LSE match within tolerance (atol={atol}, rtol={rtol})") + print(f"\n✓ PASSED: Outputs match within tolerance (atol={atol}, rtol={rtol})") else: - print(f"\n✗ FAILED: Outputs differ beyond tolerance (atol={atol}, rtol={rtol})") - - if not output_close: - flat_abs_diff = abs_diff.flatten() - top_k = min(5, flat_abs_diff.numel()) - top_errors, top_indices = 
torch.topk(flat_abs_diff, top_k) - - print(f"\nTop {top_k} output tensor error locations:") - for i in range(top_k): - idx = top_indices[i].item() - q_idx = idx // (num_attention_heads * head_dim) - head_idx = (idx % (num_attention_heads * head_dim)) // head_dim - dim_idx = idx % head_dim - - ref_val = ref_o_f32.flatten()[idx].item() - fi_val = fi_output_f32.flatten()[idx].item() - - print( - f" [q_idx={q_idx}, head={head_idx}, dim={dim_idx}]: " - f"ref={ref_val:.6f}, fi={fi_val:.6f}, diff={top_errors[i].item():.6e}" - ) - - if not lse_close: - flat_lse_diff = lse_abs_diff.flatten() - top_k = min(5, flat_lse_diff.numel()) - top_lse_errors, top_lse_indices = torch.topk(flat_lse_diff, top_k) - - print(f"\nTop {top_k} LSE error locations:") - for i in range(top_k): - idx = top_lse_indices[i].item() - q_idx = idx // num_attention_heads - head_idx = idx % num_attention_heads - - ref_val = ref_lse.flatten()[idx].item() - fi_val = fi_lse.flatten()[idx].item() - - print( - f" [q_idx={q_idx}, head={head_idx}]: " - f"ref={ref_val:.6f}, fi={fi_val:.6f}, diff={top_lse_errors[i].item():.6e}" - ) + print(f"\n✗ FAILED: Outputs differ beyond tolerance") return all_close def main(): - """Run comprehensive tests.""" - print("Testing Batch GQA Paged Prefill Reference Implementation (page_size=64)") - - test_configs = [(1, 16, 64, True), (4, 32, 128, True), (8, 64, 256, True), (16, 128, 512, True)] - - passed = 0 - total = len(test_configs) - - for batch_size, max_q_len, max_kv_len, causal in test_configs: - try: - if test_correctness(batch_size, max_q_len, max_kv_len, causal): - passed += 1 - except Exception as e: - print(f"✗ Test failed with exception: {str(e)}") - import traceback - - traceback.print_exc() - - print(f"\n{'='*60}") - print(f"Summary: {passed}/{total} tests passed") - print(f"{'='*60}") - - if passed == total: - print("✓ All tests passed!") - else: - print(f"✗ {total - passed} tests failed") + print( + "Testing Batch GQA Paged Prefill Reference Implementation (page_size=64, from definition)" + ) + test_configs = [(1, 16, 64, True), (4, 32, 128, True), (8, 64, 256, True)] + passed = sum(1 for cfg in test_configs if test_correctness(*cfg)) + print(f"\n{'='*60}\nSummary: {passed}/{len(test_configs)} tests passed\n{'='*60}") if __name__ == "__main__": diff --git a/flashinfer_trace/tests/references/test_gqa_paged_prefill_h32_kv8_d128_ps1.py b/flashinfer_trace/tests/references/test_gqa_paged_prefill_h32_kv8_d128_ps1.py index 4c5ad6f7..c8f0b5a1 100644 --- a/flashinfer_trace/tests/references/test_gqa_paged_prefill_h32_kv8_d128_ps1.py +++ b/flashinfer_trace/tests/references/test_gqa_paged_prefill_h32_kv8_d128_ps1.py @@ -1,105 +1,28 @@ +""" +Test GQA paged prefill h32_kv8_d128_ps1 reference implementation against FlashInfer. + +This test validates that the reference implementation from the definition +matches the FlashInfer kernel implementation. 
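+
+Causal masking over the ragged layout: for a sequence with num_q query tokens
+and num_kv cached tokens, query position q_idx attends to the first
+min(q_idx + 1 + (num_kv - num_q), num_kv) keys; for example, with num_q=3 and
+num_kv=5 the first query sees 3 keys and the last query sees all 5.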
+""" + import math import flashinfer import torch +from test_utils import compare_tensors, get_reference_run, print_comparison_metrics +# Load reference implementation from definition +run = get_reference_run("gqa_paged_prefill_causal_h32_kv8_d128_ps1") -@torch.no_grad() -def run(q, k_cache, v_cache, qo_indptr, kv_indptr, kv_indices, sm_scale): - total_q, num_qo_heads, head_dim = q.shape - num_pages, page_size, num_kv_heads, _ = k_cache.shape - len_indptr = qo_indptr.shape[0] - num_kv_indices = kv_indices.shape[0] - - # Check constants - assert num_qo_heads == 32 - assert num_kv_heads == 8 - assert head_dim == 128 - assert page_size == 1 - - # Check constraints - assert total_q == qo_indptr[-1].item() - assert num_kv_indices == kv_indptr[-1].item() - - device = q.device - - output = torch.zeros((total_q, num_qo_heads, head_dim), dtype=torch.bfloat16, device=device) - lse = torch.full((total_q, num_qo_heads), -float("inf"), dtype=torch.float32, device=device) - - gqa_ratio = num_qo_heads // num_kv_heads - - q_f32 = q.to(torch.float32) - # Flatten page dimension since page_size=1 - k_cache_flat = k_cache.squeeze(1).to(torch.float32) # [num_pages, num_kv_heads, head_dim] - v_cache_flat = v_cache.squeeze(1).to(torch.float32) # [num_pages, num_kv_heads, head_dim] - - for b in range(len_indptr - 1): - q_start = int(qo_indptr[b].item()) - q_end = int(qo_indptr[b + 1].item()) - - kv_start = int(kv_indptr[b].item()) - kv_end = int(kv_indptr[b + 1].item()) - - if q_start >= q_end or kv_start >= kv_end: - # No queries or KV for this batch element - continue - - page_ids = kv_indices[kv_start:kv_end].to(torch.long) - - # Number of KV tokens is equal to number of pages for page_size=1 - num_kv_tokens = page_ids.shape[0] - k_batch = k_cache_flat[page_ids] # [num_kv_tokens, num_kv_heads, head_dim] - v_batch = v_cache_flat[page_ids] # [num_kv_tokens, num_kv_heads, head_dim] - - # Get queries for this sequence - q_batch = q_f32[q_start:q_end] # [num_q_tokens, num_qo_heads, head_dim] - num_q_tokens = q_batch.shape[0] - - # Delta for causal masking - delta = num_kv_tokens - num_q_tokens - - for q_idx in range(num_q_tokens): - global_q_idx = q_start + q_idx - - # Apply causal mask - max_kv_idx = min(q_idx + 1 + delta, num_kv_tokens) - if max_kv_idx <= 0: - continue - - q_pos = q_batch[q_idx] # [num_qo_heads, head_dim] - - for h in range(num_qo_heads): - # Find corresponding KV head for GQA - kv_head = h // gqa_ratio - - q_head = q_pos[h] # [head_dim] - k_head = k_batch[:max_kv_idx, kv_head] # [max_kv_idx, head_dim] - v_head = v_batch[:max_kv_idx, kv_head] # [max_kv_idx, head_dim] - - logits = torch.matmul(q_head, k_head.T) # [max_kv_idx] - logits_scaled = logits * sm_scale - - # Compute 2-base LSE - lse[global_q_idx, h] = torch.logsumexp(logits_scaled, dim=-1) / math.log(2.0) - - attn = torch.softmax(logits_scaled, dim=-1) # [max_kv_idx] - out_head = torch.matmul(attn, v_head) # [head_dim] - output[global_q_idx, h] = out_head.to(torch.bfloat16) - - return output, lse +# Constants from definition +NUM_QO_HEADS = 32 +NUM_KV_HEADS = 8 +HEAD_DIM = 128 +PAGE_SIZE = 1 def generate_random_inputs( - batch_size, - max_q_len, - max_kv_len, - max_pages, - num_attention_heads=32, - num_key_value_heads=8, - head_dim=128, - page_size=1, - causal=True, - device="cuda", + batch_size, max_q_len, max_kv_len, max_pages, causal=True, device="cuda" ): """Generate random inputs for paged prefill testing.""" @@ -107,10 +30,8 @@ def generate_random_inputs( q_lens = torch.randint(1, max_q_len + 1, (batch_size,), dtype=torch.int32) # 
Generate random KV lengths for each batch element - # For prefill, KV length is typically >= query length (includes previous context) kv_lens = torch.zeros(batch_size, dtype=torch.int32) for i in range(batch_size): - # KV length should be at least as long as query length for causal attention if causal: kv_lens[i] = torch.randint(q_lens[i].item(), max_kv_len + 1, (1,)).item() else: @@ -127,8 +48,7 @@ def generate_random_inputs( total_q = qo_indptr[-1].item() num_kv_indices = kv_indptr[-1].item() - # Generate page indices (for page_size=1, we need num_kv_indices unique pages) - # Simulate scattered memory allocation + # Generate page indices all_page_ids = torch.randperm(max_pages, device=device)[:num_kv_indices] # Create kv_indices by assigning pages to each sequence @@ -139,24 +59,21 @@ def generate_random_inputs( kv_indices[idx : idx + seq_len] = all_page_ids[idx : idx + seq_len] idx += seq_len - # Generate KV cache (paged storage) + # Generate KV cache k_cache = torch.randn( - max_pages, page_size, num_key_value_heads, head_dim, dtype=torch.bfloat16, device=device + max_pages, PAGE_SIZE, NUM_KV_HEADS, HEAD_DIM, dtype=torch.bfloat16, device=device ) v_cache = torch.randn( - max_pages, page_size, num_key_value_heads, head_dim, dtype=torch.bfloat16, device=device + max_pages, PAGE_SIZE, NUM_KV_HEADS, HEAD_DIM, dtype=torch.bfloat16, device=device ) # Generate query tensor - q = torch.randn(total_q, num_attention_heads, head_dim, dtype=torch.bfloat16, device=device) + q = torch.randn(total_q, NUM_QO_HEADS, HEAD_DIM, dtype=torch.bfloat16, device=device) # Generate attention parameters - sm_scale = 1.0 / math.sqrt(head_dim) + sm_scale = 1.0 / math.sqrt(HEAD_DIM) sm_scale = torch.tensor(sm_scale, dtype=torch.float32, device=device) - # Convert causal to tensor - causal = torch.tensor(causal, dtype=torch.bool, device=device) - # For page_size=1, last_page_len is always all ones last_page_len = torch.ones(batch_size, dtype=torch.int32, device=device) @@ -174,7 +91,6 @@ def generate_random_inputs( "num_kv_indices": num_kv_indices, "sm_scale": sm_scale, "causal": causal, - "page_size": page_size, } @@ -191,39 +107,19 @@ def test_correctness(batch_size=4, max_q_len=32, max_kv_len=64, causal=True, ato print("WARNING: CUDA not available, skipping test") return - # Constants from kernel definition - num_attention_heads = 32 - num_key_value_heads = 8 - head_dim = 128 - page_size = 1 - - # Maximum number of pages (should be large enough to hold all KV tokens) - max_pages = max_kv_len * batch_size * 2 # Extra buffer for scattered allocation + # Maximum number of pages + max_pages = max_kv_len * batch_size * 2 # Generate inputs - inputs = generate_random_inputs( - batch_size, - max_q_len, - max_kv_len, - max_pages, - num_attention_heads, - num_key_value_heads, - head_dim, - page_size, - causal, - device, - ) + inputs = generate_random_inputs(batch_size, max_q_len, max_kv_len, max_pages, causal, device) print(f"Generated query lengths: {inputs['q_lens'].cpu().numpy()}") print(f"Generated KV lengths: {inputs['kv_lens'].cpu().numpy()}") print(f"Total query tokens: {inputs['total_q']}") print(f"Total KV indices: {inputs['num_kv_indices']}") - print(f"Max page ID used: {inputs['kv_indices'].max().item()}") - print(f"Causal mode: {inputs['causal'].item()}") - print(f"Page size: {inputs['page_size']}") - # Run reference implementation - print("\nRunning reference implementation...") + # Run reference implementation from definition + print("\nRunning reference implementation from definition...") ref_o, ref_lse = 
run( inputs["q"], inputs["k_cache"], @@ -239,25 +135,22 @@ def test_correctness(batch_size=4, max_q_len=32, max_kv_len=64, causal=True, ato workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device=device) prefill_wrapper = flashinfer.prefill.BatchPrefillWithPagedKVCacheWrapper( - workspace_buffer, kv_layout="NHD" # Layout for K/V tensors + workspace_buffer, kv_layout="NHD" ) - # Combine k_cache and v_cache into paged_kv_cache format that FlashInfer expects - # FlashInfer expects shape [max_num_pages, 2, page_size, num_kv_heads, head_dim] for NHD layout paged_kv_cache = torch.stack([inputs["k_cache"], inputs["v_cache"]], dim=1) - # Plan the attention computation prefill_wrapper.plan( qo_indptr=inputs["qo_indptr"], paged_kv_indptr=inputs["kv_indptr"], paged_kv_indices=inputs["kv_indices"], paged_kv_last_page_len=inputs["last_page_len"], - num_qo_heads=num_attention_heads, - num_kv_heads=num_key_value_heads, - head_dim_qk=head_dim, - head_dim_vo=head_dim, - page_size=page_size, - causal=inputs["causal"].item(), + num_qo_heads=NUM_QO_HEADS, + num_kv_heads=NUM_KV_HEADS, + head_dim_qk=HEAD_DIM, + head_dim_vo=HEAD_DIM, + page_size=PAGE_SIZE, + causal=inputs["causal"], sm_scale=inputs["sm_scale"].item(), q_data_type=torch.bfloat16, kv_data_type=torch.bfloat16, @@ -269,120 +162,27 @@ def test_correctness(batch_size=4, max_q_len=32, max_kv_len=64, causal=True, ato # Compare outputs print("\nComparing outputs...") + output_metrics = compare_tensors(ref_o, fi_output, atol=atol, rtol=rtol) + print_comparison_metrics(output_metrics, tensor_name="Output tensor") - # Convert to float32 for comparison - ref_o_f32 = ref_o.float() - fi_output_f32 = fi_output.float() - - # Compute errors for output tensor - abs_diff = torch.abs(ref_o_f32 - fi_output_f32) - rel_diff = abs_diff / (torch.abs(fi_output_f32) + 1e-8) - - max_abs_diff = abs_diff.max().item() - max_rel_diff = rel_diff.max().item() - mean_abs_diff = abs_diff.mean().item() - mean_rel_diff = rel_diff.mean().item() - - print(f"\nOutput tensor comparison:") - print(f"Max absolute difference: {max_abs_diff:.6e}") - print(f"Max relative difference: {max_rel_diff:.6e}") - print(f"Mean absolute difference: {mean_abs_diff:.6e}") - print(f"Mean relative difference: {mean_rel_diff:.6e}") - - # Compute cosine similarity and MSE for output tensor - cos_sim = torch.nn.functional.cosine_similarity( - ref_o_f32.flatten(), fi_output_f32.flatten(), dim=0 - ).item() - mse = torch.mean((ref_o_f32 - fi_output_f32) ** 2).item() - print(f"Cosine similarity: {cos_sim:.6f}") - print(f"MSE: {mse:.6e}") - - # Compare LSE values - lse_abs_diff = torch.abs(ref_lse - fi_lse) - lse_rel_diff = lse_abs_diff / (torch.abs(fi_lse) + 1e-8) - - lse_max_abs_diff = lse_abs_diff.max().item() - lse_max_rel_diff = lse_rel_diff.max().item() - lse_mean_abs_diff = lse_abs_diff.mean().item() - lse_mean_rel_diff = lse_rel_diff.mean().item() - - print(f"\nLSE comparison:") - print(f"Max absolute difference: {lse_max_abs_diff:.6e}") - print(f"Max relative difference: {lse_max_rel_diff:.6e}") - print(f"Mean absolute difference: {lse_mean_abs_diff:.6e}") - print(f"Mean relative difference: {lse_mean_rel_diff:.6e}") - - # Check if outputs match within tolerance - output_close = torch.allclose(ref_o_f32, fi_output_f32, atol=atol, rtol=rtol) - lse_close = torch.allclose(ref_lse, fi_lse, atol=atol, rtol=rtol) - all_close = output_close and lse_close + lse_metrics = compare_tensors(ref_lse, fi_lse, atol=atol, rtol=rtol) + print_comparison_metrics(lse_metrics, tensor_name="LSE tensor") + + 
all_close = output_metrics.all_close and lse_metrics.all_close if all_close: print(f"\n✓ PASSED: Outputs and LSE match within tolerance (atol={atol}, rtol={rtol})") else: print(f"\n✗ FAILED: Outputs differ beyond tolerance (atol={atol}, rtol={rtol})") - if not output_close: - # Find indices with largest errors for debugging - flat_abs_diff = abs_diff.flatten() - top_k = min(5, flat_abs_diff.numel()) - top_errors, top_indices = torch.topk(flat_abs_diff, top_k) - - print(f"\nTop {top_k} output tensor error locations:") - for i in range(top_k): - idx = top_indices[i].item() - # Convert flat index back to 3D indices - q_idx = idx // (num_attention_heads * head_dim) - head_idx = (idx % (num_attention_heads * head_dim)) // head_dim - dim_idx = idx % head_dim - - ref_val = ref_o_f32.flatten()[idx].item() - fi_val = fi_output_f32.flatten()[idx].item() - - print( - f" [q_idx={q_idx}, head={head_idx}, dim={dim_idx}]: " - f"ref={ref_val:.6f}, fi={fi_val:.6f}, diff={top_errors[i].item():.6e}" - ) - - if not lse_close: - # Find LSE errors - flat_lse_diff = lse_abs_diff.flatten() - top_k = min(5, flat_lse_diff.numel()) - top_lse_errors, top_lse_indices = torch.topk(flat_lse_diff, top_k) - - print(f"\nTop {top_k} LSE error locations:") - for i in range(top_k): - idx = top_lse_indices[i].item() - q_idx = idx // num_attention_heads - head_idx = idx % num_attention_heads - - ref_val = ref_lse.flatten()[idx].item() - fi_val = fi_lse.flatten()[idx].item() - - print( - f" [q_idx={q_idx}, head={head_idx}]: " - f"ref={ref_val:.6f}, fi={fi_val:.6f}, diff={top_lse_errors[i].item():.6e}" - ) - return all_close def main(): """Run comprehensive tests.""" - print("Testing Batch GQA Paged Prefill Reference Implementation") - - # Test different configurations - test_configs = [ - # (batch_size, max_q_len, max_kv_len, causal) - (1, 8, 16, True), # Single batch, small, causal - # (1, 8, 16, False), # Single batch, small, non-causal - (4, 16, 32, True), # Small batch, causal - # (4, 16, 32, False), # Small batch, non-causal - (8, 32, 64, True), # Medium batch, causal - # (8, 32, 64, False), # Medium batch, non-causal - (16, 64, 128, True), # Large batch, causal - # (16, 64, 128, False), # Large batch, non-causal - ] + print("Testing Batch GQA Paged Prefill Reference Implementation (from definition)") + + test_configs = [(1, 8, 16, True), (4, 16, 32, True), (8, 32, 64, True), (16, 64, 128, True)] passed = 0 total = len(test_configs) diff --git a/flashinfer_trace/tests/references/test_gqa_paged_prefill_h32_kv8_d128_ps64.py b/flashinfer_trace/tests/references/test_gqa_paged_prefill_h32_kv8_d128_ps64.py index 962d30ff..b259d07b 100644 --- a/flashinfer_trace/tests/references/test_gqa_paged_prefill_h32_kv8_d128_ps64.py +++ b/flashinfer_trace/tests/references/test_gqa_paged_prefill_h32_kv8_d128_ps64.py @@ -1,131 +1,31 @@ +""" +Test GQA paged prefill h32_kv8_d128_ps64 reference implementation against FlashInfer. + +This test validates that the reference implementation from the definition +matches the FlashInfer kernel implementation. 
+""" + import math import flashinfer import torch +from test_utils import compare_tensors, get_reference_run, print_comparison_metrics +# Load reference implementation from definition +run = get_reference_run("gqa_paged_prefill_causal_h32_kv8_d128_ps64") -@torch.no_grad() -def run(q, k_cache, v_cache, qo_indptr, kv_indptr, kv_indices, kv_last_page_len, sm_scale): - total_q, num_qo_heads, head_dim = q.shape - num_pages, page_size, num_kv_heads, _ = k_cache.shape - len_indptr = qo_indptr.shape[0] - num_kv_indices = kv_indices.shape[0] - - # Check constants - assert num_qo_heads == 32 - assert num_kv_heads == 8 - assert head_dim == 128 - assert page_size == 64 - - # Check constraints - assert total_q == qo_indptr[-1].item() - - device = q.device - batch_size = len_indptr - 1 - - output = torch.zeros((total_q, num_qo_heads, head_dim), dtype=torch.bfloat16, device=device) - lse = torch.full((total_q, num_qo_heads), -float("inf"), dtype=torch.float32, device=device) - - gqa_ratio = num_qo_heads // num_kv_heads - - q_f32 = q.to(torch.float32) - k_cache_f32 = k_cache.to(torch.float32) - v_cache_f32 = v_cache.to(torch.float32) - - for b in range(batch_size): - q_start = int(qo_indptr[b].item()) - q_end = int(qo_indptr[b + 1].item()) - - kv_start = int(kv_indptr[b].item()) - kv_end = int(kv_indptr[b + 1].item()) - last_page_len = int(kv_last_page_len[b].item()) - - if q_start >= q_end or kv_start >= kv_end: - continue - - page_ids = kv_indices[kv_start:kv_end].to(torch.long) - num_pages_for_seq = page_ids.shape[0] - - # Calculate total KV tokens - num_full_pages = num_pages_for_seq - 1 - num_kv_tokens = num_full_pages * page_size + last_page_len - - # Gather K and V from pages - k_batch = torch.zeros( - (num_kv_tokens, num_kv_heads, head_dim), dtype=torch.float32, device=device - ) - v_batch = torch.zeros( - (num_kv_tokens, num_kv_heads, head_dim), dtype=torch.float32, device=device - ) - - token_idx = 0 - for p_idx, page_id in enumerate(page_ids): - if p_idx < num_full_pages: - k_batch[token_idx : token_idx + page_size] = k_cache_f32[page_id] - v_batch[token_idx : token_idx + page_size] = v_cache_f32[page_id] - token_idx += page_size - else: - k_batch[token_idx : token_idx + last_page_len] = k_cache_f32[ - page_id, :last_page_len - ] - v_batch[token_idx : token_idx + last_page_len] = v_cache_f32[ - page_id, :last_page_len - ] - token_idx += last_page_len - - q_batch = q_f32[q_start:q_end] - num_q_tokens = q_batch.shape[0] - - # Delta for causal masking - delta = num_kv_tokens - num_q_tokens - - for q_idx in range(num_q_tokens): - global_q_idx = q_start + q_idx - - # Apply causal mask - max_kv_idx = min(q_idx + 1 + delta, num_kv_tokens) - if max_kv_idx <= 0: - continue - - q_pos = q_batch[q_idx] - - for h in range(num_qo_heads): - kv_head = h // gqa_ratio - - q_head = q_pos[h] - k_head = k_batch[:max_kv_idx, kv_head] - v_head = v_batch[:max_kv_idx, kv_head] - - logits = torch.matmul(q_head, k_head.T) - logits_scaled = logits * sm_scale - - lse[global_q_idx, h] = torch.logsumexp(logits_scaled, dim=-1) / math.log(2.0) - - attn = torch.softmax(logits_scaled, dim=-1) - out_head = torch.matmul(attn, v_head) - output[global_q_idx, h] = out_head.to(torch.bfloat16) - - return output, lse +# Constants from definition +NUM_QO_HEADS = 32 +NUM_KV_HEADS = 8 +HEAD_DIM = 128 +PAGE_SIZE = 64 def generate_random_inputs( - batch_size, - max_q_len, - max_kv_len, - max_pages, - num_attention_heads=32, - num_key_value_heads=8, - head_dim=128, - page_size=64, - causal=True, - device="cuda", + batch_size, max_q_len, 
max_kv_len, max_pages, causal=True, device="cuda" ): - """Generate random inputs for paged prefill testing.""" - - # Generate random query lengths for each batch element + """Generate random inputs for paged prefill testing with page_size=64.""" q_lens = torch.randint(1, max_q_len + 1, (batch_size,), dtype=torch.int32) - - # Generate random KV lengths for each batch element kv_lens = torch.zeros(batch_size, dtype=torch.int32) for i in range(batch_size): if causal: @@ -133,45 +33,35 @@ def generate_random_inputs( else: kv_lens[i] = torch.randint(1, max_kv_len + 1, (1,)).item() - # Create qo_indptr qo_indptr = torch.zeros(batch_size + 1, dtype=torch.int32, device=device) qo_indptr[1:] = torch.cumsum(q_lens.to(device), dim=0) - # Calculate pages needed for each sequence - pages_per_seq = (kv_lens + page_size - 1) // page_size # Ceiling division - total_pages_needed = pages_per_seq.sum().item() - - # Create kv_indptr based on pages per sequence + pages_per_seq = (kv_lens + PAGE_SIZE - 1) // PAGE_SIZE kv_indptr = torch.zeros(batch_size + 1, dtype=torch.int32, device=device) kv_indptr[1:] = torch.cumsum(pages_per_seq.to(device), dim=0) - # Generate page indices - kv_indices = torch.arange(total_pages_needed, dtype=torch.int32, device=device) + total_q = qo_indptr[-1].item() + num_kv_pages = kv_indptr[-1].item() - # Calculate last_page_len for each sequence - kv_last_page_len = ((kv_lens - 1) % page_size) + 1 - kv_last_page_len = kv_last_page_len.to(device) + all_page_ids = torch.randperm(max_pages, device=device)[:num_kv_pages] + kv_indices = torch.zeros(num_kv_pages, dtype=torch.int32, device=device) + idx = 0 + for i in range(batch_size): + num_pages = pages_per_seq[i].item() + kv_indices[idx : idx + num_pages] = all_page_ids[idx : idx + num_pages] + idx += num_pages - # Get total tokens - total_q = qo_indptr[-1].item() + last_page_len = ((kv_lens - 1) % PAGE_SIZE) + 1 + last_page_len = last_page_len.to(torch.int32).to(device) - # Generate KV cache (paged storage) k_cache = torch.randn( - max_pages, page_size, num_key_value_heads, head_dim, dtype=torch.bfloat16, device=device + max_pages, PAGE_SIZE, NUM_KV_HEADS, HEAD_DIM, dtype=torch.bfloat16, device=device ) v_cache = torch.randn( - max_pages, page_size, num_key_value_heads, head_dim, dtype=torch.bfloat16, device=device + max_pages, PAGE_SIZE, NUM_KV_HEADS, HEAD_DIM, dtype=torch.bfloat16, device=device ) - - # Generate query tensor - q = torch.randn(total_q, num_attention_heads, head_dim, dtype=torch.bfloat16, device=device) - - # Generate attention parameters - sm_scale = 1.0 / math.sqrt(head_dim) - sm_scale = torch.tensor(sm_scale, dtype=torch.float32, device=device) - - # Convert causal to tensor - causal = torch.tensor(causal, dtype=torch.bool, device=device) + q = torch.randn(total_q, NUM_QO_HEADS, HEAD_DIM, dtype=torch.bfloat16, device=device) + sm_scale = torch.tensor(1.0 / math.sqrt(HEAD_DIM), dtype=torch.float32, device=device) return { "q": q, @@ -180,21 +70,20 @@ def generate_random_inputs( "qo_indptr": qo_indptr, "kv_indptr": kv_indptr, "kv_indices": kv_indices, - "kv_last_page_len": kv_last_page_len, + "last_page_len": last_page_len, "q_lens": q_lens, "kv_lens": kv_lens, "total_q": total_q, "sm_scale": sm_scale, "causal": causal, - "page_size": page_size, } -def test_correctness(batch_size=4, max_q_len=32, max_kv_len=128, causal=True, atol=1e-2, rtol=5e-2): +def test_correctness(batch_size=4, max_q_len=32, max_kv_len=256, causal=True, atol=1e-2, rtol=5e-2): """Test correctness of paged prefill reference implementation against 
FlashInfer.""" print(f"\n{'='*60}") print( - f"Testing GQA Paged Prefill batch_size={batch_size}, max_q_len={max_q_len}, max_kv_len={max_kv_len}, causal={causal}" + f"Testing GQA Paged Prefill (ps64) batch_size={batch_size}, max_q_len={max_q_len}, max_kv_len={max_kv_len}" ) print(f"{'='*60}") @@ -203,39 +92,13 @@ def test_correctness(batch_size=4, max_q_len=32, max_kv_len=128, causal=True, at print("WARNING: CUDA not available, skipping test") return - # Constants from kernel definition - num_attention_heads = 32 - num_key_value_heads = 8 - head_dim = 128 - page_size = 64 - - # Maximum number of pages - max_pages = (max_kv_len * batch_size * 2 + page_size - 1) // page_size + 100 - - # Generate inputs - inputs = generate_random_inputs( - batch_size, - max_q_len, - max_kv_len, - max_pages, - num_attention_heads, - num_key_value_heads, - head_dim, - page_size, - causal, - device, - ) + max_pages = (max_kv_len * batch_size * 2) // PAGE_SIZE + 100 + inputs = generate_random_inputs(batch_size, max_q_len, max_kv_len, max_pages, causal, device) print(f"Generated query lengths: {inputs['q_lens'].cpu().numpy()}") print(f"Generated KV lengths: {inputs['kv_lens'].cpu().numpy()}") - print(f"Last page lengths: {inputs['kv_last_page_len'].cpu().numpy()}") - print(f"Total query tokens: {inputs['total_q']}") - print(f"Total pages: {inputs['kv_indices'].shape[0]}") - print(f"Causal mode: {inputs['causal'].item()}") - print(f"Page size: {inputs['page_size']}") - - # Run reference implementation - print("\nRunning reference implementation...") + + print("\nRunning reference implementation from definition...") ref_o, ref_lse = run( inputs["q"], inputs["k_cache"], @@ -243,161 +106,60 @@ def test_correctness(batch_size=4, max_q_len=32, max_kv_len=128, causal=True, at inputs["qo_indptr"], inputs["kv_indptr"], inputs["kv_indices"], - inputs["kv_last_page_len"], + inputs["last_page_len"], inputs["sm_scale"], ) - # Setup FlashInfer print("\nSetting up FlashInfer...") workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device=device) - prefill_wrapper = flashinfer.prefill.BatchPrefillWithPagedKVCacheWrapper( workspace_buffer, kv_layout="NHD" ) - - # Combine k_cache and v_cache into paged_kv_cache format paged_kv_cache = torch.stack([inputs["k_cache"], inputs["v_cache"]], dim=1) - # Plan the attention computation prefill_wrapper.plan( qo_indptr=inputs["qo_indptr"], paged_kv_indptr=inputs["kv_indptr"], paged_kv_indices=inputs["kv_indices"], - paged_kv_last_page_len=inputs["kv_last_page_len"], - num_qo_heads=num_attention_heads, - num_kv_heads=num_key_value_heads, - head_dim_qk=head_dim, - head_dim_vo=head_dim, - page_size=page_size, - causal=inputs["causal"].item(), + paged_kv_last_page_len=inputs["last_page_len"], + num_qo_heads=NUM_QO_HEADS, + num_kv_heads=NUM_KV_HEADS, + head_dim_qk=HEAD_DIM, + head_dim_vo=HEAD_DIM, + page_size=PAGE_SIZE, + causal=inputs["causal"], sm_scale=inputs["sm_scale"].item(), q_data_type=torch.bfloat16, kv_data_type=torch.bfloat16, ) - # Run FlashInfer print("Running FlashInfer...") fi_output, fi_lse = prefill_wrapper.run(inputs["q"], paged_kv_cache, return_lse=True) - # Compare outputs print("\nComparing outputs...") + output_metrics = compare_tensors(ref_o, fi_output, atol=atol, rtol=rtol) + print_comparison_metrics(output_metrics, tensor_name="Output tensor") - ref_o_f32 = ref_o.float() - fi_output_f32 = fi_output.float() - - abs_diff = torch.abs(ref_o_f32 - fi_output_f32) - rel_diff = abs_diff / (torch.abs(fi_output_f32) + 1e-8) - - max_abs_diff = abs_diff.max().item() - 
max_rel_diff = rel_diff.max().item() - mean_abs_diff = abs_diff.mean().item() - mean_rel_diff = rel_diff.mean().item() - - print(f"\nOutput tensor comparison:") - print(f"Max absolute difference: {max_abs_diff:.6e}") - print(f"Max relative difference: {max_rel_diff:.6e}") - print(f"Mean absolute difference: {mean_abs_diff:.6e}") - print(f"Mean relative difference: {mean_rel_diff:.6e}") - - cos_sim = torch.nn.functional.cosine_similarity( - ref_o_f32.flatten(), fi_output_f32.flatten(), dim=0 - ).item() - mse = torch.mean((ref_o_f32 - fi_output_f32) ** 2).item() - print(f"Cosine similarity: {cos_sim:.6f}") - print(f"MSE: {mse:.6e}") + lse_metrics = compare_tensors(ref_lse, fi_lse, atol=atol, rtol=rtol) + print_comparison_metrics(lse_metrics, tensor_name="LSE tensor") - lse_abs_diff = torch.abs(ref_lse - fi_lse) - lse_rel_diff = lse_abs_diff / (torch.abs(fi_lse) + 1e-8) - - lse_max_abs_diff = lse_abs_diff.max().item() - lse_max_rel_diff = lse_rel_diff.max().item() - lse_mean_abs_diff = lse_abs_diff.mean().item() - lse_mean_rel_diff = lse_rel_diff.mean().item() - - print(f"\nLSE comparison:") - print(f"Max absolute difference: {lse_max_abs_diff:.6e}") - print(f"Max relative difference: {lse_max_rel_diff:.6e}") - print(f"Mean absolute difference: {lse_mean_abs_diff:.6e}") - print(f"Mean relative difference: {lse_mean_rel_diff:.6e}") - - output_close = torch.allclose(ref_o_f32, fi_output_f32, atol=atol, rtol=rtol) - lse_close = torch.allclose(ref_lse, fi_lse, atol=atol, rtol=rtol) - all_close = output_close and lse_close + all_close = output_metrics.all_close and lse_metrics.all_close if all_close: - print(f"\n✓ PASSED: Outputs and LSE match within tolerance (atol={atol}, rtol={rtol})") + print(f"\n✓ PASSED: Outputs match within tolerance (atol={atol}, rtol={rtol})") else: - print(f"\n✗ FAILED: Outputs differ beyond tolerance (atol={atol}, rtol={rtol})") - - if not output_close: - flat_abs_diff = abs_diff.flatten() - top_k = min(5, flat_abs_diff.numel()) - top_errors, top_indices = torch.topk(flat_abs_diff, top_k) - - print(f"\nTop {top_k} output tensor error locations:") - for i in range(top_k): - idx = top_indices[i].item() - q_idx = idx // (num_attention_heads * head_dim) - head_idx = (idx % (num_attention_heads * head_dim)) // head_dim - dim_idx = idx % head_dim - - ref_val = ref_o_f32.flatten()[idx].item() - fi_val = fi_output_f32.flatten()[idx].item() - - print( - f" [q_idx={q_idx}, head={head_idx}, dim={dim_idx}]: " - f"ref={ref_val:.6f}, fi={fi_val:.6f}, diff={top_errors[i].item():.6e}" - ) - - if not lse_close: - flat_lse_diff = lse_abs_diff.flatten() - top_k = min(5, flat_lse_diff.numel()) - top_lse_errors, top_lse_indices = torch.topk(flat_lse_diff, top_k) - - print(f"\nTop {top_k} LSE error locations:") - for i in range(top_k): - idx = top_lse_indices[i].item() - q_idx = idx // num_attention_heads - head_idx = idx % num_attention_heads - - ref_val = ref_lse.flatten()[idx].item() - fi_val = fi_lse.flatten()[idx].item() - - print( - f" [q_idx={q_idx}, head={head_idx}]: " - f"ref={ref_val:.6f}, fi={fi_val:.6f}, diff={top_lse_errors[i].item():.6e}" - ) + print(f"\n✗ FAILED: Outputs differ beyond tolerance") return all_close def main(): - """Run comprehensive tests.""" - print("Testing Batch GQA Paged Prefill Reference Implementation (page_size=64)") - - test_configs = [(1, 16, 64, True), (4, 32, 128, True), (8, 64, 256, True), (16, 128, 512, True)] - - passed = 0 - total = len(test_configs) - - for batch_size, max_q_len, max_kv_len, causal in test_configs: - try: - if 
test_correctness(batch_size, max_q_len, max_kv_len, causal): - passed += 1 - except Exception as e: - print(f"✗ Test failed with exception: {str(e)}") - import traceback - - traceback.print_exc() - - print(f"\n{'='*60}") - print(f"Summary: {passed}/{total} tests passed") - print(f"{'='*60}") - - if passed == total: - print("✓ All tests passed!") - else: - print(f"✗ {total - passed} tests failed") + print( + "Testing Batch GQA Paged Prefill Reference Implementation (page_size=64, from definition)" + ) + test_configs = [(1, 16, 64, True), (4, 32, 128, True), (8, 64, 256, True)] + passed = sum(1 for cfg in test_configs if test_correctness(*cfg)) + print(f"\n{'='*60}\nSummary: {passed}/{len(test_configs)} tests passed\n{'='*60}") if __name__ == "__main__": diff --git a/flashinfer_trace/tests/references/test_gqa_ragged_prefill_h32_kv4_d128.py b/flashinfer_trace/tests/references/test_gqa_ragged_prefill_h32_kv4_d128.py index 98e70ec1..dbbf5ae7 100644 --- a/flashinfer_trace/tests/references/test_gqa_ragged_prefill_h32_kv4_d128.py +++ b/flashinfer_trace/tests/references/test_gqa_ragged_prefill_h32_kv4_d128.py @@ -1,124 +1,48 @@ +""" +Test GQA ragged prefill h32_kv4_d128 reference implementation against FlashInfer. + +This test validates that the reference implementation from the definition +matches the FlashInfer kernel implementation. +""" + import math import flashinfer import torch +from test_utils import compare_tensors, get_reference_run, print_comparison_metrics +# Load reference implementation from definition +run = get_reference_run("gqa_ragged_prefill_causal_h32_kv4_d128") -@torch.no_grad() -def run(q, k, v, qo_indptr, kv_indptr, sm_scale): - total_q, num_qo_heads, head_dim = q.shape - total_kv, num_kv_heads, _ = k.shape - len_indptr = qo_indptr.shape[0] - - # Check constants - assert num_qo_heads == 32 - assert num_kv_heads == 4 - assert head_dim == 128 - - # Check constraints - assert total_q == qo_indptr[-1].item() - assert total_kv == kv_indptr[-1].item() - - device = q.device - - output = torch.zeros((total_q, num_qo_heads, head_dim), dtype=torch.bfloat16, device=device) - lse = torch.full((total_q, num_qo_heads), -float("inf"), dtype=torch.float32, device=device) - - gqa_ratio = num_qo_heads // num_kv_heads - - q_f32 = q.to(torch.float32) - k_f32 = k.to(torch.float32) - v_f32 = v.to(torch.float32) - - for b in range(len_indptr - 1): - q_start = int(qo_indptr[b].item()) - q_end = int(qo_indptr[b + 1].item()) - - kv_start = int(kv_indptr[b].item()) - kv_end = int(kv_indptr[b + 1].item()) - - if q_start >= q_end or kv_start >= kv_end: - # No queries or KV for this batch element - continue - - # Get Q, K, V for this batch - q_batch = q_f32[q_start:q_end] # [num_q_tokens, num_qo_heads, head_dim] - k_batch = k_f32[kv_start:kv_end] # [num_kv_tokens, num_kv_heads, head_dim] - v_batch = v_f32[kv_start:kv_end] # [num_kv_tokens, num_kv_heads, head_dim] - - num_q_tokens = q_batch.shape[0] - num_kv_tokens = k_batch.shape[0] - delta = num_kv_tokens - num_q_tokens +# Constants from definition +NUM_QO_HEADS = 32 +NUM_KV_HEADS = 4 +HEAD_DIM = 128 - k_expanded = k_batch.repeat_interleave(gqa_ratio, dim=1) - v_expanded = v_batch.repeat_interleave(gqa_ratio, dim=1) - # Compute attention scores: Q @ K^T - logits = torch.einsum("qhd,khd->qhk", q_batch, k_expanded) * sm_scale - - # For position q_idx, can attend to KV positions [0, min(q_idx + 1 + delta, num_kv_tokens)) - q_positions = torch.arange(num_q_tokens, device=device) # [num_q_tokens] - kv_positions = torch.arange(num_kv_tokens, device=device) # 
[num_kv_tokens] - - # Apply causal mask - causal_mask = kv_positions[None, :] < (q_positions[:, None] + 1 + delta) - logits = logits.masked_fill(~causal_mask[:, None, :], float("-inf")) - - # Compute 2-base LSE - lse_batch = torch.logsumexp(logits, dim=-1) / math.log(2.0) - lse[q_start:q_end] = lse_batch - - attn_weights = torch.softmax(logits, dim=-1) # [num_q_tokens, num_qo_heads, num_kv_tokens] - output_batch = torch.einsum("qhk,khd->qhd", attn_weights, v_expanded) - output[q_start:q_end] = output_batch.to(torch.bfloat16) - - return output, lse - - -def generate_random_inputs( - batch_size, - max_q_len, - max_kv_len, - num_attention_heads=32, - num_key_value_heads=4, - head_dim=128, - causal=True, - device="cuda", -): +def generate_random_inputs(batch_size, max_q_len, max_kv_len, causal=True, device="cuda"): """Generate random inputs for ragged prefill testing.""" - - # Generate random query lengths for each batch element q_lens = torch.randint(1, max_q_len + 1, (batch_size,), dtype=torch.int32) - - # Generate random KV lengths for each batch element - # For prefill, KV length is typically >= query length (includes previous context) kv_lens = torch.zeros(batch_size, dtype=torch.int32) for i in range(batch_size): - # KV length should be at least as long as query length for causal attention - kv_lens[i] = torch.randint(q_lens[i].item(), max_kv_len + 1, (1,)).item() + if causal: + kv_lens[i] = torch.randint(q_lens[i].item(), max_kv_len + 1, (1,)).item() + else: + kv_lens[i] = torch.randint(1, max_kv_len + 1, (1,)).item() - # Create indptr arrays qo_indptr = torch.zeros(batch_size + 1, dtype=torch.int32, device=device) qo_indptr[1:] = torch.cumsum(q_lens.to(device), dim=0) kv_indptr = torch.zeros(batch_size + 1, dtype=torch.int32, device=device) kv_indptr[1:] = torch.cumsum(kv_lens.to(device), dim=0) - # Get total tokens total_q = qo_indptr[-1].item() total_kv = kv_indptr[-1].item() - # Generate tensors - q = torch.randn(total_q, num_attention_heads, head_dim, dtype=torch.bfloat16, device=device) - k = torch.randn(total_kv, num_key_value_heads, head_dim, dtype=torch.bfloat16, device=device) - v = torch.randn(total_kv, num_key_value_heads, head_dim, dtype=torch.bfloat16, device=device) - - # Generate attention parameters - sm_scale = 1.0 / math.sqrt(head_dim) - sm_scale = torch.tensor(sm_scale, dtype=torch.float32, device=device) - - # Convert causal to tensor - causal = torch.tensor(causal, dtype=torch.bool, device=device) + q = torch.randn(total_q, NUM_QO_HEADS, HEAD_DIM, dtype=torch.bfloat16, device=device) + k = torch.randn(total_kv, NUM_KV_HEADS, HEAD_DIM, dtype=torch.bfloat16, device=device) + v = torch.randn(total_kv, NUM_KV_HEADS, HEAD_DIM, dtype=torch.bfloat16, device=device) + sm_scale = torch.tensor(1.0 / math.sqrt(HEAD_DIM), dtype=torch.float32, device=device) return { "q": q, @@ -139,7 +63,7 @@ def test_correctness(batch_size=4, max_q_len=32, max_kv_len=64, causal=True, ato """Test correctness of ragged prefill reference implementation against FlashInfer.""" print(f"\n{'='*60}") print( - f"Testing GQA Ragged Prefill batch_size={batch_size}, max_q_len={max_q_len}, max_kv_len={max_kv_len}, causal={causal}" + f"Testing GQA Ragged Prefill batch_size={batch_size}, max_q_len={max_q_len}, max_kv_len={max_kv_len}" ) print(f"{'='*60}") @@ -148,31 +72,12 @@ def test_correctness(batch_size=4, max_q_len=32, max_kv_len=64, causal=True, ato print("WARNING: CUDA not available, skipping test") return - # Constants from kernel definition - num_attention_heads = 32 - num_key_value_heads = 4 - 
head_dim = 128 - - # Generate inputs - inputs = generate_random_inputs( - batch_size, - max_q_len, - max_kv_len, - num_attention_heads, - num_key_value_heads, - head_dim, - causal, - device, - ) + inputs = generate_random_inputs(batch_size, max_q_len, max_kv_len, causal, device) print(f"Generated query lengths: {inputs['q_lens'].cpu().numpy()}") print(f"Generated KV lengths: {inputs['kv_lens'].cpu().numpy()}") - print(f"Total query tokens: {inputs['total_q']}") - print(f"Total KV tokens: {inputs['total_kv']}") - print(f"Causal mode: {inputs['causal'].item()}") - # Run reference implementation - print("\nRunning reference implementation...") + print("\nRunning reference implementation from definition...") ref_o, ref_lse = run( inputs["q"], inputs["k"], @@ -182,172 +87,50 @@ def test_correctness(batch_size=4, max_q_len=32, max_kv_len=64, causal=True, ato inputs["sm_scale"], ) - # Setup FlashInfer print("\nSetting up FlashInfer...") workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device=device) - prefill_wrapper = flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper( - workspace_buffer, kv_layout="NHD" # Layout for K/V tensors + workspace_buffer, kv_layout="NHD" ) - # Plan the attention computation prefill_wrapper.plan( qo_indptr=inputs["qo_indptr"], kv_indptr=inputs["kv_indptr"], - num_qo_heads=num_attention_heads, - num_kv_heads=num_key_value_heads, - head_dim_qk=head_dim, # head dimension for query/key - head_dim_vo=head_dim, # head dimension for value/output (same as qk for standard attention) - causal=inputs["causal"].item(), # Use the randomly generated causal flag - sm_scale=inputs["sm_scale"], # Scale factor for softmax + num_qo_heads=NUM_QO_HEADS, + num_kv_heads=NUM_KV_HEADS, + head_dim_qk=HEAD_DIM, + head_dim_vo=HEAD_DIM, + causal=inputs["causal"], + sm_scale=inputs["sm_scale"].item(), q_data_type=torch.bfloat16, kv_data_type=torch.bfloat16, ) - # Run FlashInfer print("Running FlashInfer...") fi_output, fi_lse = prefill_wrapper.run(inputs["q"], inputs["k"], inputs["v"], return_lse=True) - # Compare outputs print("\nComparing outputs...") + output_metrics = compare_tensors(ref_o, fi_output, atol=atol, rtol=rtol) + print_comparison_metrics(output_metrics, tensor_name="Output tensor") - # Convert to float32 for comparison - ref_o_f32 = ref_o.float() - fi_output_f32 = fi_output.float() - - # Compute errors for output tensor - abs_diff = torch.abs(ref_o_f32 - fi_output_f32) - rel_diff = abs_diff / (torch.abs(fi_output_f32) + 1e-8) - - max_abs_diff = abs_diff.max().item() - max_rel_diff = rel_diff.max().item() - mean_abs_diff = abs_diff.mean().item() - mean_rel_diff = rel_diff.mean().item() - - print(f"\nOutput tensor comparison:") - print(f"Max absolute difference: {max_abs_diff:.6e}") - print(f"Max relative difference: {max_rel_diff:.6e}") - print(f"Mean absolute difference: {mean_abs_diff:.6e}") - print(f"Mean relative difference: {mean_rel_diff:.6e}") - - # Compute cosine similarity and MSE for output tensor - cos_sim = torch.nn.functional.cosine_similarity( - ref_o_f32.flatten(), fi_output_f32.flatten(), dim=0 - ).item() - mse = torch.mean((ref_o_f32 - fi_output_f32) ** 2).item() - print(f"Cosine similarity: {cos_sim:.6f}") - print(f"MSE: {mse:.6e}") - - # Compare LSE values - lse_abs_diff = torch.abs(ref_lse - fi_lse) - lse_rel_diff = lse_abs_diff / (torch.abs(fi_lse) + 1e-8) - - lse_max_abs_diff = lse_abs_diff.max().item() - lse_max_rel_diff = lse_rel_diff.max().item() - lse_mean_abs_diff = lse_abs_diff.mean().item() - lse_mean_rel_diff = 
lse_rel_diff.mean().item() + lse_metrics = compare_tensors(ref_lse, fi_lse, atol=atol, rtol=rtol) + print_comparison_metrics(lse_metrics, tensor_name="LSE tensor") - print(f"\nLSE comparison:") - print(f"Max absolute difference: {lse_max_abs_diff:.6e}") - print(f"Max relative difference: {lse_max_rel_diff:.6e}") - print(f"Mean absolute difference: {lse_mean_abs_diff:.6e}") - print(f"Mean relative difference: {lse_mean_rel_diff:.6e}") - - # Check if outputs match within tolerance - output_close = torch.allclose(ref_o_f32, fi_output_f32, atol=atol, rtol=rtol) - lse_close = torch.allclose(ref_lse, fi_lse, atol=atol, rtol=rtol) - all_close = output_close and lse_close + all_close = output_metrics.all_close and lse_metrics.all_close if all_close: - print(f"\n✓ PASSED: Outputs and LSE match within tolerance (atol={atol}, rtol={rtol})") + print(f"\n✓ PASSED: Outputs match within tolerance (atol={atol}, rtol={rtol})") else: - print(f"\n✗ FAILED: Outputs differ beyond tolerance (atol={atol}, rtol={rtol})") - - if not output_close: - # Find indices with largest errors for debugging - flat_abs_diff = abs_diff.flatten() - top_k = min(5, flat_abs_diff.numel()) - top_errors, top_indices = torch.topk(flat_abs_diff, top_k) - - print(f"\nTop {top_k} output tensor error locations:") - for i in range(top_k): - idx = top_indices[i].item() - # Convert flat index back to 3D indices - q_idx = idx // (num_attention_heads * head_dim) - head_idx = (idx % (num_attention_heads * head_dim)) // head_dim - dim_idx = idx % head_dim - - ref_val = ref_o_f32.flatten()[idx].item() - fi_val = fi_output_f32.flatten()[idx].item() - - print( - f" [q_idx={q_idx}, head={head_idx}, dim={dim_idx}]: " - f"ref={ref_val:.6f}, fi={fi_val:.6f}, diff={top_errors[i].item():.6e}" - ) - - if not lse_close: - # Find LSE errors - flat_lse_diff = lse_abs_diff.flatten() - top_k = min(5, flat_lse_diff.numel()) - top_lse_errors, top_lse_indices = torch.topk(flat_lse_diff, top_k) - - print(f"\nTop {top_k} LSE error locations:") - for i in range(top_k): - idx = top_lse_indices[i].item() - q_idx = idx // num_attention_heads - head_idx = idx % num_attention_heads - - ref_val = ref_lse.flatten()[idx].item() - fi_val = fi_lse.flatten()[idx].item() - - print( - f" [q_idx={q_idx}, head={head_idx}]: " - f"ref={ref_val:.6f}, fi={fi_val:.6f}, diff={top_lse_errors[i].item():.6e}" - ) + print(f"\n✗ FAILED: Outputs differ beyond tolerance") return all_close def main(): - """Run comprehensive tests.""" - print("Testing Batch GQA Ragged Prefill Reference Implementation") - - # Test different configurations - test_configs = [ - # (batch_size, max_q_len, max_kv_len, causal) - (1, 8, 16, True), # Single batch, small, causal - # (1, 8, 16, False), # Single batch, small, non-causal - (4, 16, 32, True), # Small batch, causal - # (4, 16, 32, False), # Small batch, non-causal - (8, 32, 64, True), # Medium batch, causal - # (8, 32, 64, False), # Medium batch, non-causal - (16, 64, 128, True), # Large batch, causal - # (16, 64, 128, False), # Large batch, non-causal - (32, 128, 256, True), # Very large batch, causal - # (32, 128, 256, False), # Very large batch, non-causal - ] - - passed = 0 - total = len(test_configs) - - for batch_size, max_q_len, max_kv_len, causal in test_configs: - try: - if test_correctness(batch_size, max_q_len, max_kv_len, causal): - passed += 1 - except Exception as e: - print(f"✗ Test failed with exception: {str(e)}") - import traceback - - traceback.print_exc() - - print(f"\n{'='*60}") - print(f"Summary: {passed}/{total} tests passed") - 
print(f"{'='*60}") - - if passed == total: - print("✓ All tests passed!") - else: - print(f"✗ {total - passed} tests failed") + print("Testing Batch GQA Ragged Prefill Reference Implementation (from definition)") + test_configs = [(1, 8, 16, True), (4, 16, 32, True), (8, 32, 64, True), (16, 64, 128, True)] + passed = sum(1 for cfg in test_configs if test_correctness(*cfg)) + print(f"\n{'='*60}\nSummary: {passed}/{len(test_configs)} tests passed\n{'='*60}") if __name__ == "__main__": diff --git a/flashinfer_trace/tests/references/test_gqa_ragged_prefill_h32_kv8_d128.py b/flashinfer_trace/tests/references/test_gqa_ragged_prefill_h32_kv8_d128.py index 788ddfce..aeba031e 100644 --- a/flashinfer_trace/tests/references/test_gqa_ragged_prefill_h32_kv8_d128.py +++ b/flashinfer_trace/tests/references/test_gqa_ragged_prefill_h32_kv8_d128.py @@ -1,124 +1,48 @@ +""" +Test GQA ragged prefill h32_kv8_d128 reference implementation against FlashInfer. + +This test validates that the reference implementation from the definition +matches the FlashInfer kernel implementation. +""" + import math import flashinfer import torch +from test_utils import compare_tensors, get_reference_run, print_comparison_metrics +# Load reference implementation from definition +run = get_reference_run("gqa_ragged_prefill_causal_h32_kv8_d128") -@torch.no_grad() -def run(q, k, v, qo_indptr, kv_indptr, sm_scale): - total_q, num_qo_heads, head_dim = q.shape - total_kv, num_kv_heads, _ = k.shape - len_indptr = qo_indptr.shape[0] - - # Check constants - assert num_qo_heads == 32 - assert num_kv_heads == 8 - assert head_dim == 128 - - # Check constraints - assert total_q == qo_indptr[-1].item() - assert total_kv == kv_indptr[-1].item() - - device = q.device - - output = torch.zeros((total_q, num_qo_heads, head_dim), dtype=torch.bfloat16, device=device) - lse = torch.full((total_q, num_qo_heads), -float("inf"), dtype=torch.float32, device=device) - - gqa_ratio = num_qo_heads // num_kv_heads - - q_f32 = q.to(torch.float32) - k_f32 = k.to(torch.float32) - v_f32 = v.to(torch.float32) - - for b in range(len_indptr - 1): - q_start = int(qo_indptr[b].item()) - q_end = int(qo_indptr[b + 1].item()) - - kv_start = int(kv_indptr[b].item()) - kv_end = int(kv_indptr[b + 1].item()) - - if q_start >= q_end or kv_start >= kv_end: - # No queries or KV for this batch element - continue - - # Get Q, K, V for this batch - q_batch = q_f32[q_start:q_end] # [num_q_tokens, num_qo_heads, head_dim] - k_batch = k_f32[kv_start:kv_end] # [num_kv_tokens, num_kv_heads, head_dim] - v_batch = v_f32[kv_start:kv_end] # [num_kv_tokens, num_kv_heads, head_dim] - - num_q_tokens = q_batch.shape[0] - num_kv_tokens = k_batch.shape[0] - delta = num_kv_tokens - num_q_tokens +# Constants from definition +NUM_QO_HEADS = 32 +NUM_KV_HEADS = 8 +HEAD_DIM = 128 - k_expanded = k_batch.repeat_interleave(gqa_ratio, dim=1) - v_expanded = v_batch.repeat_interleave(gqa_ratio, dim=1) - # Compute attention scores: Q @ K^T - logits = torch.einsum("qhd,khd->qhk", q_batch, k_expanded) * sm_scale - - # For position q_idx, can attend to KV positions [0, min(q_idx + 1 + delta, num_kv_tokens)) - q_positions = torch.arange(num_q_tokens, device=device) # [num_q_tokens] - kv_positions = torch.arange(num_kv_tokens, device=device) # [num_kv_tokens] - - # Apply causal mask - causal_mask = kv_positions[None, :] < (q_positions[:, None] + 1 + delta) - logits = logits.masked_fill(~causal_mask[:, None, :], float("-inf")) - - # Compute 2-base LSE - lse_batch = torch.logsumexp(logits, dim=-1) / math.log(2.0) - 
lse[q_start:q_end] = lse_batch - - attn_weights = torch.softmax(logits, dim=-1) # [num_q_tokens, num_qo_heads, num_kv_tokens] - output_batch = torch.einsum("qhk,khd->qhd", attn_weights, v_expanded) - output[q_start:q_end] = output_batch.to(torch.bfloat16) - - return output, lse - - -def generate_random_inputs( - batch_size, - max_q_len, - max_kv_len, - num_attention_heads=32, - num_key_value_heads=8, - head_dim=128, - causal=True, - device="cuda", -): +def generate_random_inputs(batch_size, max_q_len, max_kv_len, causal=True, device="cuda"): """Generate random inputs for ragged prefill testing.""" - - # Generate random query lengths for each batch element q_lens = torch.randint(1, max_q_len + 1, (batch_size,), dtype=torch.int32) - - # Generate random KV lengths for each batch element - # For prefill, KV length is typically >= query length (includes previous context) kv_lens = torch.zeros(batch_size, dtype=torch.int32) for i in range(batch_size): - # KV length should be at least as long as query length for causal attention - kv_lens[i] = torch.randint(q_lens[i].item(), max_kv_len + 1, (1,)).item() + if causal: + kv_lens[i] = torch.randint(q_lens[i].item(), max_kv_len + 1, (1,)).item() + else: + kv_lens[i] = torch.randint(1, max_kv_len + 1, (1,)).item() - # Create indptr arrays qo_indptr = torch.zeros(batch_size + 1, dtype=torch.int32, device=device) qo_indptr[1:] = torch.cumsum(q_lens.to(device), dim=0) kv_indptr = torch.zeros(batch_size + 1, dtype=torch.int32, device=device) kv_indptr[1:] = torch.cumsum(kv_lens.to(device), dim=0) - # Get total tokens total_q = qo_indptr[-1].item() total_kv = kv_indptr[-1].item() - # Generate tensors - q = torch.randn(total_q, num_attention_heads, head_dim, dtype=torch.bfloat16, device=device) - k = torch.randn(total_kv, num_key_value_heads, head_dim, dtype=torch.bfloat16, device=device) - v = torch.randn(total_kv, num_key_value_heads, head_dim, dtype=torch.bfloat16, device=device) - - # Generate attention parameters - sm_scale = 1.0 / math.sqrt(head_dim) - sm_scale = torch.tensor(sm_scale, dtype=torch.float32, device=device) - - # Convert causal to tensor - causal = torch.tensor(causal, dtype=torch.bool, device=device) + q = torch.randn(total_q, NUM_QO_HEADS, HEAD_DIM, dtype=torch.bfloat16, device=device) + k = torch.randn(total_kv, NUM_KV_HEADS, HEAD_DIM, dtype=torch.bfloat16, device=device) + v = torch.randn(total_kv, NUM_KV_HEADS, HEAD_DIM, dtype=torch.bfloat16, device=device) + sm_scale = torch.tensor(1.0 / math.sqrt(HEAD_DIM), dtype=torch.float32, device=device) return { "q": q, @@ -139,7 +63,7 @@ def test_correctness(batch_size=4, max_q_len=32, max_kv_len=64, causal=True, ato """Test correctness of ragged prefill reference implementation against FlashInfer.""" print(f"\n{'='*60}") print( - f"Testing GQA Ragged Prefill batch_size={batch_size}, max_q_len={max_q_len}, max_kv_len={max_kv_len}, causal={causal}" + f"Testing GQA Ragged Prefill batch_size={batch_size}, max_q_len={max_q_len}, max_kv_len={max_kv_len}" ) print(f"{'='*60}") @@ -148,31 +72,12 @@ def test_correctness(batch_size=4, max_q_len=32, max_kv_len=64, causal=True, ato print("WARNING: CUDA not available, skipping test") return - # Constants from kernel definition - num_attention_heads = 32 - num_key_value_heads = 8 - head_dim = 128 - - # Generate inputs - inputs = generate_random_inputs( - batch_size, - max_q_len, - max_kv_len, - num_attention_heads, - num_key_value_heads, - head_dim, - causal, - device, - ) + inputs = generate_random_inputs(batch_size, max_q_len, max_kv_len, causal, 
device) print(f"Generated query lengths: {inputs['q_lens'].cpu().numpy()}") print(f"Generated KV lengths: {inputs['kv_lens'].cpu().numpy()}") - print(f"Total query tokens: {inputs['total_q']}") - print(f"Total KV tokens: {inputs['total_kv']}") - print(f"Causal mode: {inputs['causal'].item()}") - # Run reference implementation - print("\nRunning reference implementation...") + print("\nRunning reference implementation from definition...") ref_o, ref_lse = run( inputs["q"], inputs["k"], @@ -182,172 +87,50 @@ def test_correctness(batch_size=4, max_q_len=32, max_kv_len=64, causal=True, ato inputs["sm_scale"], ) - # Setup FlashInfer print("\nSetting up FlashInfer...") workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device=device) - prefill_wrapper = flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper( - workspace_buffer, kv_layout="NHD" # Layout for K/V tensors + workspace_buffer, kv_layout="NHD" ) - # Plan the attention computation prefill_wrapper.plan( qo_indptr=inputs["qo_indptr"], kv_indptr=inputs["kv_indptr"], - num_qo_heads=num_attention_heads, - num_kv_heads=num_key_value_heads, - head_dim_qk=head_dim, # head dimension for query/key - head_dim_vo=head_dim, # head dimension for value/output (same as qk for standard attention) - causal=inputs["causal"].item(), # Use the randomly generated causal flag - sm_scale=inputs["sm_scale"], # Scale factor for softmax + num_qo_heads=NUM_QO_HEADS, + num_kv_heads=NUM_KV_HEADS, + head_dim_qk=HEAD_DIM, + head_dim_vo=HEAD_DIM, + causal=inputs["causal"], + sm_scale=inputs["sm_scale"].item(), q_data_type=torch.bfloat16, kv_data_type=torch.bfloat16, ) - # Run FlashInfer print("Running FlashInfer...") fi_output, fi_lse = prefill_wrapper.run(inputs["q"], inputs["k"], inputs["v"], return_lse=True) - # Compare outputs print("\nComparing outputs...") + output_metrics = compare_tensors(ref_o, fi_output, atol=atol, rtol=rtol) + print_comparison_metrics(output_metrics, tensor_name="Output tensor") - # Convert to float32 for comparison - ref_o_f32 = ref_o.float() - fi_output_f32 = fi_output.float() - - # Compute errors for output tensor - abs_diff = torch.abs(ref_o_f32 - fi_output_f32) - rel_diff = abs_diff / (torch.abs(fi_output_f32) + 1e-8) - - max_abs_diff = abs_diff.max().item() - max_rel_diff = rel_diff.max().item() - mean_abs_diff = abs_diff.mean().item() - mean_rel_diff = rel_diff.mean().item() - - print(f"\nOutput tensor comparison:") - print(f"Max absolute difference: {max_abs_diff:.6e}") - print(f"Max relative difference: {max_rel_diff:.6e}") - print(f"Mean absolute difference: {mean_abs_diff:.6e}") - print(f"Mean relative difference: {mean_rel_diff:.6e}") - - # Compute cosine similarity and MSE for output tensor - cos_sim = torch.nn.functional.cosine_similarity( - ref_o_f32.flatten(), fi_output_f32.flatten(), dim=0 - ).item() - mse = torch.mean((ref_o_f32 - fi_output_f32) ** 2).item() - print(f"Cosine similarity: {cos_sim:.6f}") - print(f"MSE: {mse:.6e}") - - # Compare LSE values - lse_abs_diff = torch.abs(ref_lse - fi_lse) - lse_rel_diff = lse_abs_diff / (torch.abs(fi_lse) + 1e-8) - - lse_max_abs_diff = lse_abs_diff.max().item() - lse_max_rel_diff = lse_rel_diff.max().item() - lse_mean_abs_diff = lse_abs_diff.mean().item() - lse_mean_rel_diff = lse_rel_diff.mean().item() + lse_metrics = compare_tensors(ref_lse, fi_lse, atol=atol, rtol=rtol) + print_comparison_metrics(lse_metrics, tensor_name="LSE tensor") - print(f"\nLSE comparison:") - print(f"Max absolute difference: {lse_max_abs_diff:.6e}") - print(f"Max relative difference: 
{lse_max_rel_diff:.6e}") - print(f"Mean absolute difference: {lse_mean_abs_diff:.6e}") - print(f"Mean relative difference: {lse_mean_rel_diff:.6e}") - - # Check if outputs match within tolerance - output_close = torch.allclose(ref_o_f32, fi_output_f32, atol=atol, rtol=rtol) - lse_close = torch.allclose(ref_lse, fi_lse, atol=atol, rtol=rtol) - all_close = output_close and lse_close + all_close = output_metrics.all_close and lse_metrics.all_close if all_close: - print(f"\n✓ PASSED: Outputs and LSE match within tolerance (atol={atol}, rtol={rtol})") + print(f"\n✓ PASSED: Outputs match within tolerance (atol={atol}, rtol={rtol})") else: - print(f"\n✗ FAILED: Outputs differ beyond tolerance (atol={atol}, rtol={rtol})") - - if not output_close: - # Find indices with largest errors for debugging - flat_abs_diff = abs_diff.flatten() - top_k = min(5, flat_abs_diff.numel()) - top_errors, top_indices = torch.topk(flat_abs_diff, top_k) - - print(f"\nTop {top_k} output tensor error locations:") - for i in range(top_k): - idx = top_indices[i].item() - # Convert flat index back to 3D indices - q_idx = idx // (num_attention_heads * head_dim) - head_idx = (idx % (num_attention_heads * head_dim)) // head_dim - dim_idx = idx % head_dim - - ref_val = ref_o_f32.flatten()[idx].item() - fi_val = fi_output_f32.flatten()[idx].item() - - print( - f" [q_idx={q_idx}, head={head_idx}, dim={dim_idx}]: " - f"ref={ref_val:.6f}, fi={fi_val:.6f}, diff={top_errors[i].item():.6e}" - ) - - if not lse_close: - # Find LSE errors - flat_lse_diff = lse_abs_diff.flatten() - top_k = min(5, flat_lse_diff.numel()) - top_lse_errors, top_lse_indices = torch.topk(flat_lse_diff, top_k) - - print(f"\nTop {top_k} LSE error locations:") - for i in range(top_k): - idx = top_lse_indices[i].item() - q_idx = idx // num_attention_heads - head_idx = idx % num_attention_heads - - ref_val = ref_lse.flatten()[idx].item() - fi_val = fi_lse.flatten()[idx].item() - - print( - f" [q_idx={q_idx}, head={head_idx}]: " - f"ref={ref_val:.6f}, fi={fi_val:.6f}, diff={top_lse_errors[i].item():.6e}" - ) + print(f"\n✗ FAILED: Outputs differ beyond tolerance") return all_close def main(): - """Run comprehensive tests.""" - print("Testing Batch GQA Ragged Prefill Reference Implementation") - - # Test different configurations - test_configs = [ - # (batch_size, max_q_len, max_kv_len, causal) - (1, 8, 16, True), # Single batch, small, causal - # (1, 8, 16, False), # Single batch, small, non-causal - (4, 16, 32, True), # Small batch, causal - # (4, 16, 32, False), # Small batch, non-causal - (8, 32, 64, True), # Medium batch, causal - # (8, 32, 64, False), # Medium batch, non-causal - (16, 64, 128, True), # Large batch, causal - # (16, 64, 128, False), # Large batch, non-causal - (32, 128, 256, True), # Very large batch, causal - # (32, 128, 256, False), # Very large batch, non-causal - ] - - passed = 0 - total = len(test_configs) - - for batch_size, max_q_len, max_kv_len, causal in test_configs: - try: - if test_correctness(batch_size, max_q_len, max_kv_len, causal): - passed += 1 - except Exception as e: - print(f"✗ Test failed with exception: {str(e)}") - import traceback - - traceback.print_exc() - - print(f"\n{'='*60}") - print(f"Summary: {passed}/{total} tests passed") - print(f"{'='*60}") - - if passed == total: - print("✓ All tests passed!") - else: - print(f"✗ {total - passed} tests failed") + print("Testing Batch GQA Ragged Prefill Reference Implementation (from definition)") + test_configs = [(1, 8, 16, True), (4, 16, 32, True), (8, 32, 64, True), (16, 64, 
128, True)] + passed = sum(1 for cfg in test_configs if test_correctness(*cfg)) + print(f"\n{'='*60}\nSummary: {passed}/{len(test_configs)} tests passed\n{'='*60}") if __name__ == "__main__": diff --git a/flashinfer_trace/tests/references/test_mla_paged_decode_h16_ckv512_kpe64_ps1.py b/flashinfer_trace/tests/references/test_mla_paged_decode_h16_ckv512_kpe64_ps1.py index 93ad6ecb..435db8fe 100644 --- a/flashinfer_trace/tests/references/test_mla_paged_decode_h16_ckv512_kpe64_ps1.py +++ b/flashinfer_trace/tests/references/test_mla_paged_decode_h16_ckv512_kpe64_ps1.py @@ -1,100 +1,38 @@ -import math +""" +Test MLA paged decode h16_ckv512_kpe64_ps1 reference implementation against FlashInfer. + +This test validates that the reference implementation from the definition +matches the FlashInfer kernel implementation. +""" import flashinfer import numpy as np import torch +from test_utils import compare_tensors, get_reference_run, print_comparison_metrics +# Load reference implementation from definition +run = get_reference_run("mla_paged_decode_h16_ckv512_kpe64_ps1") -@torch.no_grad() -def run(q_nope, q_pe, ckv_cache, kpe_cache, kv_indptr, kv_indices, sm_scale): - batch_size, num_qo_heads, head_dim_ckv = q_nope.shape - head_dim_kpe = q_pe.shape[-1] - page_size = ckv_cache.shape[1] - len_indptr = kv_indptr.shape[0] - num_kv_indices = kv_indices.shape[0] - - # Check constants - assert num_qo_heads == 16 - assert head_dim_ckv == 512 - assert head_dim_kpe == 64 - assert page_size == 1 - - # Check constraints - assert len_indptr == batch_size + 1 - assert num_kv_indices == kv_indptr[-1].item() - - device = q_nope.device - - Kc_all = ckv_cache.squeeze(1).to(torch.float32) # [num_pages, head_dim_ckv] - Kp_all = kpe_cache.squeeze(1).to(torch.float32) # [num_pages, head_dim_kpe] - - output = torch.zeros( - (batch_size, num_qo_heads, head_dim_ckv), dtype=torch.bfloat16, device=device - ) - lse = torch.full((batch_size, num_qo_heads), -float("inf"), dtype=torch.float32, device=device) - - for b in range(batch_size): - page_beg = int(kv_indptr[b].item()) - page_end = int(kv_indptr[b + 1].item()) - - if page_beg >= page_end: - # No KV cache for this batch element - output[b].zero_() - continue - - pages = kv_indices[page_beg:page_end] - # Derive kv_len from kv_indptr (for page_size=1, num_pages == num_tokens) - L_tokens = page_end - page_beg - - if L_tokens <= 0 or pages.numel() == 0: - output[b].zero_() - continue - - # Pages are token indices for page_size=1 - tok_idx = pages[:L_tokens].to(torch.long) +# Constants from definition +NUM_QO_HEADS = 16 +HEAD_DIM_CKV = 512 +HEAD_DIM_KPE = 64 +PAGE_SIZE = 1 - Kc = Kc_all[tok_idx] # [L_tokens, head_dim_ckv] - Kp = Kp_all[tok_idx] # [L_tokens, head_dim_kpe] - qn = q_nope[b].to(torch.float32) # [num_qo_heads, head_dim_ckv] - qp = q_pe[b].to(torch.float32) # [num_qo_heads, head_dim_kpe] - logits = (qn @ Kc.T) + (qp @ Kp.T) # [num_qo_heads, L_tokens] - logits_scaled = logits * sm_scale - - # Compute 2-base LSE - lse[b] = torch.logsumexp(logits_scaled, dim=-1) / math.log(2.0) - - attn = torch.softmax(logits_scaled, dim=-1) # [num_qo_heads, L_tokens] - out = attn @ Kc # [num_qo_heads, head_dim_ckv] - output[b] = out.to(torch.bfloat16) - - return output, lse - - -def generate_random_inputs( - batch_size, - max_seq_len, - num_qo_heads=16, - head_dim_ckv=512, - head_dim_kpe=64, - page_size=1, - device="cuda", -): +def generate_random_inputs(batch_size, max_seq_len, device="cuda"): """Generate random inputs for MLA testing.""" - # Generate random sequence lengths for each 
batch seq_lens = torch.randint(1, max_seq_len + 1, (batch_size,), dtype=torch.int32, device=device) - # Calculate total pages needed - # Since page_size = 1, num_pages = total_tokens + # Calculate total pages needed (page_size=1) total_pages_needed = seq_lens.sum().item() # Generate kv_indptr based on sequence lengths kv_indptr = torch.zeros(batch_size + 1, dtype=torch.int32, device=device) kv_indptr[1:] = torch.cumsum(seq_lens, dim=0) - # Generate kv_indices (page indices for each sequence) - # We'll use consecutive pages for simplicity + # Generate kv_indices kv_indices = torch.arange(total_pages_needed, dtype=torch.int32, device=device) # kv_len_arr stores the actual sequence lengths @@ -102,19 +40,17 @@ def generate_random_inputs( # Generate query tensors q_nope = torch.randn( - batch_size, num_qo_heads, head_dim_ckv, dtype=torch.bfloat16, device=device + batch_size, NUM_QO_HEADS, HEAD_DIM_CKV, dtype=torch.bfloat16, device=device ) - q_pe = torch.randn(batch_size, num_qo_heads, head_dim_kpe, dtype=torch.bfloat16, device=device) + q_pe = torch.randn(batch_size, NUM_QO_HEADS, HEAD_DIM_KPE, dtype=torch.bfloat16, device=device) # Generate compressed KV and positional caches - # Add some extra pages to simulate a real scenario num_pages = total_pages_needed + 100 - ckv_cache = torch.randn(num_pages, page_size, head_dim_ckv, dtype=torch.bfloat16, device=device) - kpe_cache = torch.randn(num_pages, page_size, head_dim_kpe, dtype=torch.bfloat16, device=device) + ckv_cache = torch.randn(num_pages, PAGE_SIZE, HEAD_DIM_CKV, dtype=torch.bfloat16, device=device) + kpe_cache = torch.randn(num_pages, PAGE_SIZE, HEAD_DIM_KPE, dtype=torch.bfloat16, device=device) - # Generate attention parameters # MLA uses head dimension before matrix absorption (128 + 64 = 192) - sm_scale = 1.0 / np.sqrt(128 + head_dim_kpe) + sm_scale = 1.0 / np.sqrt(128 + HEAD_DIM_KPE) sm_scale = torch.tensor(sm_scale, dtype=torch.float32, device=device) # For decode, qo_indptr is just [0, 1, 2, ..., batch_size] @@ -145,22 +81,14 @@ def test_correctness(batch_size=4, max_seq_len=64, atol=1e-2, rtol=5e-2): print("WARNING: CUDA not available, skipping test") return - # Constants from kernel definition - num_qo_heads = 16 - head_dim_ckv = 512 - head_dim_kpe = 64 - page_size = 1 - # Generate inputs - inputs = generate_random_inputs( - batch_size, max_seq_len, num_qo_heads, head_dim_ckv, head_dim_kpe, page_size, device - ) + inputs = generate_random_inputs(batch_size, max_seq_len, device) print(f"Generated sequences with lengths: {inputs['seq_lens'].cpu().numpy()}") print(f"Total pages used: {inputs['kv_indices'].shape[0]}") - # Run reference implementation - print("\nRunning reference implementation...") + # Run reference implementation from definition + print("\nRunning reference implementation from definition...") ref_o, ref_lse = run( inputs["q_nope"], inputs["q_pe"], @@ -175,9 +103,7 @@ def test_correctness(batch_size=4, max_seq_len=64, atol=1e-2, rtol=5e-2): print("\nSetting up FlashInfer...") workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device=device) - mla_wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper( - workspace_buffer, backend="auto" # Will choose the best backend automatically - ) + mla_wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper(workspace_buffer, backend="auto") # Plan the attention computation mla_wrapper.plan( @@ -185,11 +111,11 @@ def test_correctness(batch_size=4, max_seq_len=64, atol=1e-2, rtol=5e-2): kv_indptr=inputs["kv_indptr"], kv_indices=inputs["kv_indices"], 
kv_len_arr=inputs["kv_len_arr"], - num_heads=num_qo_heads, - head_dim_ckv=head_dim_ckv, - head_dim_kpe=head_dim_kpe, - page_size=page_size, - causal=False, # For decode, causal doesn't matter as each query has length 1 + num_heads=NUM_QO_HEADS, + head_dim_ckv=HEAD_DIM_CKV, + head_dim_kpe=HEAD_DIM_KPE, + page_size=PAGE_SIZE, + causal=False, sm_scale=inputs["sm_scale"].item(), q_data_type=torch.bfloat16, kv_data_type=torch.bfloat16, @@ -203,117 +129,27 @@ def test_correctness(batch_size=4, max_seq_len=64, atol=1e-2, rtol=5e-2): # Compare outputs print("\nComparing outputs...") + output_metrics = compare_tensors(ref_o, fi_output, atol=atol, rtol=rtol) + print_comparison_metrics(output_metrics, tensor_name="Output tensor") - # Convert to float32 for comparison - ref_o_f32 = ref_o.float() - fi_output_f32 = fi_output.float() - - # Compute errors for output tensor - abs_diff = torch.abs(ref_o_f32 - fi_output_f32) - rel_diff = abs_diff / (torch.abs(fi_output_f32) + 1e-8) - - max_abs_diff = abs_diff.max().item() - max_rel_diff = rel_diff.max().item() - mean_abs_diff = abs_diff.mean().item() - mean_rel_diff = rel_diff.mean().item() - - print(f"\nOutput tensor comparison:") - print(f"Max absolute difference: {max_abs_diff:.6e}") - print(f"Max relative difference: {max_rel_diff:.6e}") - print(f"Mean absolute difference: {mean_abs_diff:.6e}") - print(f"Mean relative difference: {mean_rel_diff:.6e}") - - # Compute cosine similarity and MSE for output tensor - cos_sim = torch.nn.functional.cosine_similarity( - ref_o_f32.flatten(), fi_output_f32.flatten(), dim=0 - ).item() - mse = torch.mean((ref_o_f32 - fi_output_f32) ** 2).item() - print(f"Cosine similarity: {cos_sim:.6f}") - print(f"MSE: {mse:.6e}") - - # Compare LSE values - lse_abs_diff = torch.abs(ref_lse - fi_lse) - lse_rel_diff = lse_abs_diff / (torch.abs(fi_lse) + 1e-8) - - lse_max_abs_diff = lse_abs_diff.max().item() - lse_max_rel_diff = lse_rel_diff.max().item() - lse_mean_abs_diff = lse_abs_diff.mean().item() - lse_mean_rel_diff = lse_rel_diff.mean().item() - - print(f"\nLSE comparison:") - print(f"Max absolute difference: {lse_max_abs_diff:.6e}") - print(f"Max relative difference: {lse_max_rel_diff:.6e}") - print(f"Mean absolute difference: {lse_mean_abs_diff:.6e}") - print(f"Mean relative difference: {lse_mean_rel_diff:.6e}") - - # Check if outputs match within tolerance - output_close = torch.allclose(ref_o_f32, fi_output_f32, atol=atol, rtol=rtol) - lse_close = torch.allclose(ref_lse, fi_lse, atol=atol, rtol=rtol) - all_close = output_close and lse_close + lse_metrics = compare_tensors(ref_lse, fi_lse, atol=atol, rtol=rtol) + print_comparison_metrics(lse_metrics, tensor_name="LSE tensor") + + all_close = output_metrics.all_close and lse_metrics.all_close if all_close: print(f"\n✓ PASSED: Outputs and LSE match within tolerance (atol={atol}, rtol={rtol})") else: print(f"\n✗ FAILED: Outputs differ beyond tolerance (atol={atol}, rtol={rtol})") - if not output_close: - # Find indices with largest errors for debugging - flat_abs_diff = abs_diff.flatten() - top_k = min(5, flat_abs_diff.numel()) - top_errors, top_indices = torch.topk(flat_abs_diff, top_k) - - print(f"\nTop {top_k} output tensor error locations:") - for i in range(top_k): - idx = top_indices[i].item() - # Convert flat index back to 3D indices - batch_idx = idx // (num_qo_heads * head_dim_ckv) - head_idx = (idx % (num_qo_heads * head_dim_ckv)) // head_dim_ckv - dim_idx = idx % head_dim_ckv - - ref_val = ref_o_f32.flatten()[idx].item() - fi_val = fi_output_f32.flatten()[idx].item() - 
- print( - f" [{batch_idx}, {head_idx}, {dim_idx}]: " - f"ref={ref_val:.6f}, fi={fi_val:.6f}, diff={top_errors[i].item():.6e}" - ) - - if not lse_close: - # Find LSE errors - flat_lse_diff = lse_abs_diff.flatten() - top_k = min(5, flat_lse_diff.numel()) - top_lse_errors, top_lse_indices = torch.topk(flat_lse_diff, top_k) - - print(f"\nTop {top_k} LSE error locations:") - for i in range(top_k): - idx = top_lse_indices[i].item() - batch_idx = idx // num_qo_heads - head_idx = idx % num_qo_heads - - ref_val = ref_lse.flatten()[idx].item() - fi_val = fi_lse.flatten()[idx].item() - - print( - f" [{batch_idx}, {head_idx}]: " - f"ref={ref_val:.6f}, fi={fi_val:.6f}, diff={top_lse_errors[i].item():.6e}" - ) - return all_close def main(): """Run comprehensive tests.""" - print("Testing Batch MLA Paged Decode Reference Implementation") - - # Test different configurations - test_configs = [ - # (batch_size, max_seq_len) - (1, 16), # Single batch - (4, 32), # Small batch - (8, 64), # Medium batch - (16, 128), # Large batch - (32, 256), # Very large batch - ] + print("Testing Batch MLA Paged Decode Reference Implementation (from definition)") + + test_configs = [(1, 16), (4, 32), (8, 64), (16, 128), (32, 256)] passed = 0 total = len(test_configs) diff --git a/flashinfer_trace/tests/references/test_mla_paged_decode_h16_ckv512_kpe64_ps64.py b/flashinfer_trace/tests/references/test_mla_paged_decode_h16_ckv512_kpe64_ps64.py index 9e70d7fc..6b6ad4f0 100644 --- a/flashinfer_trace/tests/references/test_mla_paged_decode_h16_ckv512_kpe64_ps64.py +++ b/flashinfer_trace/tests/references/test_mla_paged_decode_h16_ckv512_kpe64_ps64.py @@ -1,136 +1,51 @@ -import math +""" +Test MLA paged decode h16_ckv512_kpe64_ps64 reference implementation against FlashInfer. + +This test validates that the reference implementation from the definition +matches the FlashInfer kernel implementation. 
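+
+As a sketch (assuming the definition matches the previous in-file reference),
+for each request b the pages listed by kv_indptr/kv_indices are gathered into
+Kc [kv_len, 512] and Kp [kv_len, 64], and then:
+
+    logits    = q_nope[b] @ Kc.T + q_pe[b] @ Kp.T       # [16, kv_len]
+    lse[b]    = logsumexp(sm_scale * logits) / log(2)   # base-2 LSE
+    output[b] = softmax(sm_scale * logits) @ Kc          # [16, 512]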
+""" import flashinfer import numpy as np import torch +from test_utils import compare_tensors, get_reference_run, print_comparison_metrics +# Load reference implementation from definition +run = get_reference_run("mla_paged_decode_h16_ckv512_kpe64_ps64") -@torch.no_grad() -def run(q_nope, q_pe, ckv_cache, kpe_cache, kv_indptr, kv_indices, kv_last_page_len, sm_scale): - batch_size, num_qo_heads, head_dim_ckv = q_nope.shape - head_dim_kpe = q_pe.shape[-1] - page_size = ckv_cache.shape[1] - len_indptr = kv_indptr.shape[0] - num_kv_indices = kv_indices.shape[0] - - # Check constants - assert num_qo_heads == 16 - assert head_dim_ckv == 512 - assert head_dim_kpe == 64 - assert page_size == 64 - - # Check constraints - assert len_indptr == batch_size + 1 - assert num_kv_indices == kv_indptr[-1].item() - - device = q_nope.device +# Constants from definition +NUM_QO_HEADS = 16 +HEAD_DIM_CKV = 512 +HEAD_DIM_KPE = 64 +PAGE_SIZE = 64 - ckv_cache_f32 = ckv_cache.to(torch.float32) - kpe_cache_f32 = kpe_cache.to(torch.float32) - output = torch.zeros( - (batch_size, num_qo_heads, head_dim_ckv), dtype=torch.bfloat16, device=device - ) - lse = torch.full((batch_size, num_qo_heads), -float("inf"), dtype=torch.float32, device=device) - - for b in range(batch_size): - page_beg = int(kv_indptr[b].item()) - page_end = int(kv_indptr[b + 1].item()) - last_page_len = int(kv_last_page_len[b].item()) - - if page_beg >= page_end: - output[b].zero_() - continue - - page_ids = kv_indices[page_beg:page_end].to(torch.long) - num_pages_for_seq = page_ids.shape[0] - - # Calculate total tokens - num_full_pages = num_pages_for_seq - 1 - L_tokens = num_full_pages * page_size + last_page_len - - if L_tokens <= 0: - output[b].zero_() - continue - - # Gather Kc and Kp from pages - Kc = torch.zeros((L_tokens, head_dim_ckv), dtype=torch.float32, device=device) - Kp = torch.zeros((L_tokens, head_dim_kpe), dtype=torch.float32, device=device) - - token_idx = 0 - for p_idx, page_id in enumerate(page_ids): - if p_idx < num_full_pages: - Kc[token_idx : token_idx + page_size] = ckv_cache_f32[page_id] - Kp[token_idx : token_idx + page_size] = kpe_cache_f32[page_id] - token_idx += page_size - else: - Kc[token_idx : token_idx + last_page_len] = ckv_cache_f32[page_id, :last_page_len] - Kp[token_idx : token_idx + last_page_len] = kpe_cache_f32[page_id, :last_page_len] - token_idx += last_page_len - - qn = q_nope[b].to(torch.float32) - qp = q_pe[b].to(torch.float32) - - logits = (qn @ Kc.T) + (qp @ Kp.T) - logits_scaled = logits * sm_scale - - lse[b] = torch.logsumexp(logits_scaled, dim=-1) / math.log(2.0) - - attn = torch.softmax(logits_scaled, dim=-1) - out = attn @ Kc - output[b] = out.to(torch.bfloat16) - - return output, lse - - -def generate_random_inputs( - batch_size, - max_seq_len, - num_qo_heads=16, - head_dim_ckv=512, - head_dim_kpe=64, - page_size=64, - device="cuda", -): - """Generate random inputs for MLA testing.""" - - # Generate random sequence lengths for each batch +def generate_random_inputs(batch_size, max_seq_len, device="cuda"): + """Generate random inputs for MLA testing with page_size=64.""" seq_lens = torch.randint(1, max_seq_len + 1, (batch_size,), dtype=torch.int32, device=device) - # Calculate pages needed for each sequence - pages_per_seq = (seq_lens + page_size - 1) // page_size # Ceiling division + pages_per_seq = (seq_lens + PAGE_SIZE - 1) // PAGE_SIZE total_pages_needed = pages_per_seq.sum().item() - # Generate kv_indptr based on pages per sequence kv_indptr = torch.zeros(batch_size + 1, dtype=torch.int32, 
device=device) kv_indptr[1:] = torch.cumsum(pages_per_seq, dim=0) - # Generate kv_indices (page indices for each sequence) kv_indices = torch.arange(total_pages_needed, dtype=torch.int32, device=device) - - # Calculate last_page_len for each sequence - kv_last_page_len = ((seq_lens - 1) % page_size) + 1 - - # kv_len_arr stores the actual sequence lengths kv_len_arr = seq_lens.clone() + kv_last_page_len = ((seq_lens - 1) % PAGE_SIZE) + 1 - # Generate query tensors q_nope = torch.randn( - batch_size, num_qo_heads, head_dim_ckv, dtype=torch.bfloat16, device=device + batch_size, NUM_QO_HEADS, HEAD_DIM_CKV, dtype=torch.bfloat16, device=device ) - q_pe = torch.randn(batch_size, num_qo_heads, head_dim_kpe, dtype=torch.bfloat16, device=device) + q_pe = torch.randn(batch_size, NUM_QO_HEADS, HEAD_DIM_KPE, dtype=torch.bfloat16, device=device) - # Generate compressed KV and positional caches num_pages = total_pages_needed + 100 - ckv_cache = torch.randn(num_pages, page_size, head_dim_ckv, dtype=torch.bfloat16, device=device) - kpe_cache = torch.randn(num_pages, page_size, head_dim_kpe, dtype=torch.bfloat16, device=device) + ckv_cache = torch.randn(num_pages, PAGE_SIZE, HEAD_DIM_CKV, dtype=torch.bfloat16, device=device) + kpe_cache = torch.randn(num_pages, PAGE_SIZE, HEAD_DIM_KPE, dtype=torch.bfloat16, device=device) - # Generate attention parameters - # MLA uses head dimension before matrix absorption (128 + 64 = 192) - sm_scale = 1.0 / np.sqrt(128 + head_dim_kpe) + sm_scale = 1.0 / np.sqrt(128 + HEAD_DIM_KPE) sm_scale = torch.tensor(sm_scale, dtype=torch.float32, device=device) - # For decode, qo_indptr is just [0, 1, 2, ..., batch_size] qo_indptr = torch.arange(0, batch_size + 1, dtype=torch.int32, device=device) return { @@ -140,8 +55,8 @@ def generate_random_inputs( "kpe_cache": kpe_cache, "kv_indptr": kv_indptr, "kv_indices": kv_indices, - "kv_last_page_len": kv_last_page_len, "kv_len_arr": kv_len_arr, + "kv_last_page_len": kv_last_page_len, "sm_scale": sm_scale, "qo_indptr": qo_indptr, "seq_lens": seq_lens, @@ -151,7 +66,7 @@ def generate_random_inputs( def test_correctness(batch_size=4, max_seq_len=256, atol=1e-2, rtol=5e-2): """Test correctness of MLA reference implementation against FlashInfer.""" print(f"\n{'='*60}") - print(f"Testing MLA batch_size={batch_size}, max_seq_len={max_seq_len}") + print(f"Testing MLA (ps64) batch_size={batch_size}, max_seq_len={max_seq_len}") print(f"{'='*60}") device = "cuda" if torch.cuda.is_available() else "cpu" @@ -159,23 +74,12 @@ def test_correctness(batch_size=4, max_seq_len=256, atol=1e-2, rtol=5e-2): print("WARNING: CUDA not available, skipping test") return - # Constants from kernel definition - num_qo_heads = 16 - head_dim_ckv = 512 - head_dim_kpe = 64 - page_size = 64 - - # Generate inputs - inputs = generate_random_inputs( - batch_size, max_seq_len, num_qo_heads, head_dim_ckv, head_dim_kpe, page_size, device - ) + inputs = generate_random_inputs(batch_size, max_seq_len, device) print(f"Generated sequences with lengths: {inputs['seq_lens'].cpu().numpy()}") print(f"Last page lengths: {inputs['kv_last_page_len'].cpu().numpy()}") - print(f"Total pages used: {inputs['kv_indices'].shape[0]}") - # Run reference implementation - print("\nRunning reference implementation...") + print("\nRunning reference implementation from definition...") ref_o, ref_lse = run( inputs["q_nope"], inputs["q_pe"], @@ -187,153 +91,52 @@ def test_correctness(batch_size=4, max_seq_len=256, atol=1e-2, rtol=5e-2): inputs["sm_scale"], ) - # Setup FlashInfer print("\nSetting up 
FlashInfer...") workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device=device) - mla_wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper(workspace_buffer, backend="auto") - # Plan the attention computation mla_wrapper.plan( qo_indptr=inputs["qo_indptr"], kv_indptr=inputs["kv_indptr"], kv_indices=inputs["kv_indices"], kv_len_arr=inputs["kv_len_arr"], - num_heads=num_qo_heads, - head_dim_ckv=head_dim_ckv, - head_dim_kpe=head_dim_kpe, - page_size=page_size, - causal=False, # For decode, causal doesn't matter as each query has length 1 + num_heads=NUM_QO_HEADS, + head_dim_ckv=HEAD_DIM_CKV, + head_dim_kpe=HEAD_DIM_KPE, + page_size=PAGE_SIZE, + causal=False, sm_scale=inputs["sm_scale"].item(), q_data_type=torch.bfloat16, kv_data_type=torch.bfloat16, ) - # Run FlashInfer print("Running FlashInfer...") fi_output, fi_lse = mla_wrapper.run( inputs["q_nope"], inputs["q_pe"], inputs["ckv_cache"], inputs["kpe_cache"], return_lse=True ) - # Compare outputs print("\nComparing outputs...") + output_metrics = compare_tensors(ref_o, fi_output, atol=atol, rtol=rtol) + print_comparison_metrics(output_metrics, tensor_name="Output tensor") - ref_o_f32 = ref_o.float() - fi_output_f32 = fi_output.float() - - abs_diff = torch.abs(ref_o_f32 - fi_output_f32) - rel_diff = abs_diff / (torch.abs(fi_output_f32) + 1e-8) - - max_abs_diff = abs_diff.max().item() - max_rel_diff = rel_diff.max().item() - mean_abs_diff = abs_diff.mean().item() - mean_rel_diff = rel_diff.mean().item() + lse_metrics = compare_tensors(ref_lse, fi_lse, atol=atol, rtol=rtol) + print_comparison_metrics(lse_metrics, tensor_name="LSE tensor") - print(f"\nOutput tensor comparison:") - print(f"Max absolute difference: {max_abs_diff:.6e}") - print(f"Max relative difference: {max_rel_diff:.6e}") - print(f"Mean absolute difference: {mean_abs_diff:.6e}") - print(f"Mean relative difference: {mean_rel_diff:.6e}") - - cos_sim = torch.nn.functional.cosine_similarity( - ref_o_f32.flatten(), fi_output_f32.flatten(), dim=0 - ).item() - mse = torch.mean((ref_o_f32 - fi_output_f32) ** 2).item() - print(f"Cosine similarity: {cos_sim:.6f}") - print(f"MSE: {mse:.6e}") - - lse_abs_diff = torch.abs(ref_lse - fi_lse) - lse_rel_diff = lse_abs_diff / (torch.abs(fi_lse) + 1e-8) - - lse_max_abs_diff = lse_abs_diff.max().item() - lse_max_rel_diff = lse_rel_diff.max().item() - lse_mean_abs_diff = lse_abs_diff.mean().item() - lse_mean_rel_diff = lse_rel_diff.mean().item() - - print(f"\nLSE comparison:") - print(f"Max absolute difference: {lse_max_abs_diff:.6e}") - print(f"Max relative difference: {lse_max_rel_diff:.6e}") - print(f"Mean absolute difference: {lse_mean_abs_diff:.6e}") - print(f"Mean relative difference: {lse_mean_rel_diff:.6e}") - - output_close = torch.allclose(ref_o_f32, fi_output_f32, atol=atol, rtol=rtol) - lse_close = torch.allclose(ref_lse, fi_lse, atol=atol, rtol=rtol) - all_close = output_close and lse_close + all_close = output_metrics.all_close and lse_metrics.all_close if all_close: - print(f"\n✓ PASSED: Outputs and LSE match within tolerance (atol={atol}, rtol={rtol})") + print(f"\n✓ PASSED: Outputs match within tolerance (atol={atol}, rtol={rtol})") else: - print(f"\n✗ FAILED: Outputs differ beyond tolerance (atol={atol}, rtol={rtol})") - - if not output_close: - flat_abs_diff = abs_diff.flatten() - top_k = min(5, flat_abs_diff.numel()) - top_errors, top_indices = torch.topk(flat_abs_diff, top_k) - - print(f"\nTop {top_k} output tensor error locations:") - for i in range(top_k): - idx = top_indices[i].item() - batch_idx = idx // 
(num_qo_heads * head_dim_ckv) - head_idx = (idx % (num_qo_heads * head_dim_ckv)) // head_dim_ckv - dim_idx = idx % head_dim_ckv - - ref_val = ref_o_f32.flatten()[idx].item() - fi_val = fi_output_f32.flatten()[idx].item() - - print( - f" [{batch_idx}, {head_idx}, {dim_idx}]: " - f"ref={ref_val:.6f}, fi={fi_val:.6f}, diff={top_errors[i].item():.6e}" - ) - - if not lse_close: - flat_lse_diff = lse_abs_diff.flatten() - top_k = min(5, flat_lse_diff.numel()) - top_lse_errors, top_lse_indices = torch.topk(flat_lse_diff, top_k) - - print(f"\nTop {top_k} LSE error locations:") - for i in range(top_k): - idx = top_lse_indices[i].item() - batch_idx = idx // num_qo_heads - head_idx = idx % num_qo_heads - - ref_val = ref_lse.flatten()[idx].item() - fi_val = fi_lse.flatten()[idx].item() - - print( - f" [{batch_idx}, {head_idx}]: " - f"ref={ref_val:.6f}, fi={fi_val:.6f}, diff={top_lse_errors[i].item():.6e}" - ) + print(f"\n✗ FAILED: Outputs differ beyond tolerance") return all_close def main(): - """Run comprehensive tests.""" - print("Testing Batch MLA Paged Decode Reference Implementation (page_size=64)") - - test_configs = [(1, 64), (4, 128), (8, 256), (16, 512), (32, 1024)] - - passed = 0 - total = len(test_configs) - - for batch_size, max_seq_len in test_configs: - try: - if test_correctness(batch_size, max_seq_len): - passed += 1 - except Exception as e: - print(f"✗ Test failed with exception: {str(e)}") - import traceback - - traceback.print_exc() - - print(f"\n{'='*60}") - print(f"Summary: {passed}/{total} tests passed") - print(f"{'='*60}") - - if passed == total: - print("✓ All tests passed!") - else: - print(f"✗ {total - passed} tests failed") + print("Testing Batch MLA Paged Decode Reference Implementation (page_size=64, from definition)") + test_configs = [(1, 64), (4, 128), (8, 256), (16, 512)] + passed = sum(1 for cfg in test_configs if test_correctness(*cfg)) + print(f"\n{'='*60}\nSummary: {passed}/{len(test_configs)} tests passed\n{'='*60}") if __name__ == "__main__": diff --git a/flashinfer_trace/tests/references/test_mla_paged_prefill_h16_ckv512_kpe64_ps1.py b/flashinfer_trace/tests/references/test_mla_paged_prefill_h16_ckv512_kpe64_ps1.py index 9d0f0849..07f1810d 100644 --- a/flashinfer_trace/tests/references/test_mla_paged_prefill_h16_ckv512_kpe64_ps1.py +++ b/flashinfer_trace/tests/references/test_mla_paged_prefill_h16_ckv512_kpe64_ps1.py @@ -1,143 +1,57 @@ -import math +""" +Test MLA paged prefill h16_ckv512_kpe64_ps1 reference implementation against FlashInfer. + +This test validates that the reference implementation from the definition +matches the FlashInfer kernel implementation. 
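+
+A sketch of the causal masking the definition is expected to apply (matching
+the previous in-file reference): query i of a request with q_len new tokens
+and kv_len total KV tokens may attend to positions 0..(kv_len - q_len + i),
+i.e.
+
+    causal_mask = torch.arange(kv_len) > (kv_len - q_len + i)
+
+Masked logits are set to -inf before the softmax and the base-2 LSE.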
+""" import flashinfer import numpy as np import torch +from test_utils import compare_tensors, get_reference_run, print_comparison_metrics +# Load reference implementation from definition +run = get_reference_run("mla_paged_prefill_causal_h16_ckv512_kpe64_ps1") -@torch.no_grad() -def run(q_nope, q_pe, ckv_cache, kpe_cache, qo_indptr, kv_indptr, kv_indices, sm_scale): - total_q, num_qo_heads, head_dim_ckv = q_nope.shape - head_dim_kpe = q_pe.shape[-1] - page_size = ckv_cache.shape[1] - len_indptr = qo_indptr.shape[0] - batch_size = len_indptr - 1 - num_kv_indices = kv_indices.shape[0] - - # Check constants - assert num_qo_heads == 16 - assert head_dim_ckv == 512 - assert head_dim_kpe == 64 - assert page_size == 1 - - # Check constraints - assert total_q == qo_indptr[-1].item() - assert num_kv_indices == kv_indptr[-1].item() - - device = q_nope.device - - Kc_all = ckv_cache.squeeze(1).to(torch.float32) # [num_pages, head_dim_ckv] - Kp_all = kpe_cache.squeeze(1).to(torch.float32) # [num_pages, head_dim_kpe] - - output = torch.zeros((total_q, num_qo_heads, head_dim_ckv), dtype=torch.bfloat16, device=device) - lse = torch.full((total_q, num_qo_heads), -float("inf"), dtype=torch.float32, device=device) - - for b in range(batch_size): - q_start = int(qo_indptr[b].item()) - q_end = int(qo_indptr[b + 1].item()) - - page_beg = int(kv_indptr[b].item()) - page_end = int(kv_indptr[b + 1].item()) - - if q_start >= q_end or page_beg >= page_end: - # No queries or KV for this batch element - continue - - kv_len = page_end - page_beg - pages = kv_indices[page_beg:page_end] - - # Since page_size=1, pages are token indices - tok_idx = pages[:kv_len].to(torch.long) - Kc = Kc_all[tok_idx] # [kv_len, head_dim_ckv] - Kp = Kp_all[tok_idx] # [kv_len, head_dim_kpe] - - q_nope_batch = q_nope[q_start:q_end].to(torch.float32) # [q_len, num_heads, head_dim_ckv] - q_pe_batch = q_pe[q_start:q_end].to(torch.float32) # [q_len, num_heads, head_dim_kpe] - - q_len = q_end - q_start +# Constants from definition +NUM_QO_HEADS = 16 +HEAD_DIM_CKV = 512 +HEAD_DIM_KPE = 64 +PAGE_SIZE = 1 - for i in range(q_len): - qn = q_nope_batch[i] # [num_heads, head_dim_ckv] - qp = q_pe_batch[i] # [num_heads, head_dim_kpe] - logits = (qn @ Kc.T) + (qp @ Kp.T) # [num_heads, kv_len] - logits_scaled = logits * sm_scale - - # Apply causal mask - prefix_len = kv_len - q_len # Number of previously cached tokens - query_abs_pos = prefix_len + i # Absolute position of current query - - causal_mask = torch.arange(kv_len, device=logits_scaled.device) > query_abs_pos - logits_scaled.masked_fill_(causal_mask.unsqueeze(0), -float("inf")) - - # Compute 2-base LSE - lse[q_start + i] = torch.logsumexp(logits_scaled, dim=-1) / math.log(2.0) - - attn = torch.softmax(logits_scaled, dim=-1) # [num_heads, L_tokens] - out = attn @ Kc # [num_heads, head_dim_ckv] - output[q_start + i] = out.to(torch.bfloat16) - - return output, lse - - -def generate_random_inputs( - batch_size, - max_q_len, - max_kv_len, - num_qo_heads=16, - head_dim_ckv=512, - head_dim_kpe=64, - page_size=1, - causal=True, - device="cuda", -): - """Generate random inputs for MLA paged prefill testing.""" - - # Generate random sequence lengths for each batch - q_lens = torch.randint(1, max_q_len + 1, (batch_size,), dtype=torch.int32, device=device) - kv_lens = torch.randint(1, max_kv_len + 1, (batch_size,), dtype=torch.int32, device=device) - - # For prefill, ensure kv_len >= q_len for causal attention +def generate_random_inputs(batch_size, max_q_len, max_kv_len, causal=True, device="cuda"): + 
"""Generate random inputs for MLA prefill testing.""" + q_lens = torch.randint(1, max_q_len + 1, (batch_size,), dtype=torch.int32) + kv_lens = torch.zeros(batch_size, dtype=torch.int32) for i in range(batch_size): - kv_lens[i] = max(kv_lens[i], q_lens[i]) - - total_q = q_lens.sum().item() + if causal: + kv_lens[i] = torch.randint(q_lens[i].item(), max_kv_len + 1, (1,)).item() + else: + kv_lens[i] = torch.randint(1, max_kv_len + 1, (1,)).item() - # Calculate total pages needed (since page_size=1, num_pages = total_kv_tokens) - total_pages_needed = kv_lens.sum().item() - - # Generate qo_indptr based on query lengths qo_indptr = torch.zeros(batch_size + 1, dtype=torch.int32, device=device) - qo_indptr[1:] = torch.cumsum(q_lens, dim=0) + qo_indptr[1:] = torch.cumsum(q_lens.to(device), dim=0) - # Generate kv_indptr based on KV lengths kv_indptr = torch.zeros(batch_size + 1, dtype=torch.int32, device=device) - kv_indptr[1:] = torch.cumsum(kv_lens, dim=0) - - # Generate kv_indices (page indices for each sequence) - kv_indices = torch.arange(total_pages_needed, dtype=torch.int32, device=device) + kv_indptr[1:] = torch.cumsum(kv_lens.to(device), dim=0) - # kv_len_arr stores the actual KV sequence lengths - kv_len_arr = kv_lens.clone() + total_q = qo_indptr[-1].item() + num_kv_indices = kv_indptr[-1].item() + kv_len_arr = kv_lens.to(device) - # Generate query tensors with Matrix Absorption dimensions - q_nope = torch.randn(total_q, num_qo_heads, head_dim_ckv, dtype=torch.bfloat16, device=device) - q_pe = torch.randn(total_q, num_qo_heads, head_dim_kpe, dtype=torch.bfloat16, device=device) + max_pages = num_kv_indices + 100 + kv_indices = torch.arange(num_kv_indices, dtype=torch.int32, device=device) + last_page_len = torch.ones(batch_size, dtype=torch.int32, device=device) - # Generate compressed KV and positional caches - # Add some extra pages to simulate a real scenario - num_pages = total_pages_needed + 100 - ckv_cache = torch.randn(num_pages, page_size, head_dim_ckv, dtype=torch.bfloat16, device=device) - kpe_cache = torch.randn(num_pages, page_size, head_dim_kpe, dtype=torch.bfloat16, device=device) + q_nope = torch.randn(total_q, NUM_QO_HEADS, HEAD_DIM_CKV, dtype=torch.bfloat16, device=device) + q_pe = torch.randn(total_q, NUM_QO_HEADS, HEAD_DIM_KPE, dtype=torch.bfloat16, device=device) + ckv_cache = torch.randn(max_pages, PAGE_SIZE, HEAD_DIM_CKV, dtype=torch.bfloat16, device=device) + kpe_cache = torch.randn(max_pages, PAGE_SIZE, HEAD_DIM_KPE, dtype=torch.bfloat16, device=device) - # Generate attention parameters - # MLA uses head dimension before matrix absorption - sm_scale = 1.0 / np.sqrt(128 + head_dim_kpe) + sm_scale = 1.0 / np.sqrt(128 + HEAD_DIM_KPE) sm_scale = torch.tensor(sm_scale, dtype=torch.float32, device=device) - # Convert causal to tensor - causal = torch.tensor(causal, dtype=torch.bool, device=device) - return { "q_nope": q_nope, "q_pe": q_pe, @@ -147,19 +61,20 @@ def generate_random_inputs( "kv_indptr": kv_indptr, "kv_indices": kv_indices, "kv_len_arr": kv_len_arr, - "sm_scale": sm_scale, - "causal": causal, + "last_page_len": last_page_len, "q_lens": q_lens, "kv_lens": kv_lens, "total_q": total_q, + "sm_scale": sm_scale, + "causal": causal, } def test_correctness(batch_size=4, max_q_len=32, max_kv_len=64, causal=True, atol=1e-2, rtol=5e-2): - """Test correctness of MLA paged prefill reference implementation against FlashInfer.""" + """Test correctness of MLA prefill reference implementation against FlashInfer.""" print(f"\n{'='*60}") print( - f"Testing 
batch_size={batch_size}, max_q_len={max_q_len}, max_kv_len={max_kv_len}, causal={causal}" + f"Testing MLA Paged Prefill batch_size={batch_size}, max_q_len={max_q_len}, max_kv_len={max_kv_len}" ) print(f"{'='*60}") @@ -168,33 +83,12 @@ def test_correctness(batch_size=4, max_q_len=32, max_kv_len=64, causal=True, ato print("WARNING: CUDA not available, skipping test") return - # Constants from kernel definition - num_qo_heads = 16 - head_dim_ckv = 512 - head_dim_kpe = 64 - page_size = 1 - - # Generate inputs - inputs = generate_random_inputs( - batch_size, - max_q_len, - max_kv_len, - num_qo_heads, - head_dim_ckv, - head_dim_kpe, - page_size, - causal, - device, - ) + inputs = generate_random_inputs(batch_size, max_q_len, max_kv_len, causal, device) print(f"Generated query lengths: {inputs['q_lens'].cpu().numpy()}") print(f"Generated KV lengths: {inputs['kv_lens'].cpu().numpy()}") - print(f"Total Q tokens: {inputs['total_q']}") - print(f"Total pages used: {inputs['kv_indices'].shape[0]}") - print(f"Causal: {inputs['causal'].item()}") - # Run reference implementation - print("\nRunning reference implementation...") + print("\nRunning reference implementation from definition...") ref_o, ref_lse = run( inputs["q_nope"], inputs["q_pe"], @@ -206,148 +100,52 @@ def test_correctness(batch_size=4, max_q_len=32, max_kv_len=64, causal=True, ato inputs["sm_scale"], ) - # Setup FlashInfer print("\nSetting up FlashInfer...") workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device=device) - - # For paged prefill with Matrix Absorption, use BatchMLAPagedAttentionWrapper mla_wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper(workspace_buffer, backend="auto") - # Plan the attention computation mla_wrapper.plan( qo_indptr=inputs["qo_indptr"], kv_indptr=inputs["kv_indptr"], kv_indices=inputs["kv_indices"], kv_len_arr=inputs["kv_len_arr"], - num_heads=num_qo_heads, - head_dim_ckv=head_dim_ckv, - head_dim_kpe=head_dim_kpe, - page_size=page_size, - causal=inputs["causal"].item(), # Causal masking configuration + num_heads=NUM_QO_HEADS, + head_dim_ckv=HEAD_DIM_CKV, + head_dim_kpe=HEAD_DIM_KPE, + page_size=PAGE_SIZE, + causal=inputs["causal"], sm_scale=inputs["sm_scale"].item(), q_data_type=torch.bfloat16, kv_data_type=torch.bfloat16, ) - # Run FlashInfer print("Running FlashInfer...") - flashinfer_o, flashinfer_lse = mla_wrapper.run( - q_nope=inputs["q_nope"], - q_pe=inputs["q_pe"], - ckv_cache=inputs["ckv_cache"], - kpe_cache=inputs["kpe_cache"], - return_lse=True, + fi_output, fi_lse = mla_wrapper.run( + inputs["q_nope"], inputs["q_pe"], inputs["ckv_cache"], inputs["kpe_cache"], return_lse=True ) - # Compare outputs print("\nComparing outputs...") - print(f"Reference output shape: {ref_o.shape}") - print(f"FlashInfer output shape: {flashinfer_o.shape}") - print(f"Reference LSE shape: {ref_lse.shape}") - print(f"FlashInfer LSE shape: {flashinfer_lse.shape}") + output_metrics = compare_tensors(ref_o, fi_output, atol=atol, rtol=rtol) + print_comparison_metrics(output_metrics, tensor_name="Output tensor") - # Check numerical accuracy - o_diff = torch.abs(ref_o - flashinfer_o) - lse_diff = torch.abs(ref_lse - flashinfer_lse) + lse_metrics = compare_tensors(ref_lse, fi_lse, atol=atol, rtol=rtol) + print_comparison_metrics(lse_metrics, tensor_name="LSE tensor") - print(f"\nOutput max diff: {o_diff.max().item():.6f}") - print(f"Output mean diff: {o_diff.mean().item():.6f}") - print(f"LSE max diff: {lse_diff.max().item():.6f}") - print(f"LSE mean diff: {lse_diff.mean().item():.6f}") - - # Check if 
outputs match within tolerance - output_close = torch.allclose(ref_o.float(), flashinfer_o.float(), atol=atol, rtol=rtol) - lse_close = torch.allclose(ref_lse, flashinfer_lse, atol=atol, rtol=rtol) - all_close = output_close and lse_close + all_close = output_metrics.all_close and lse_metrics.all_close if all_close: - print(f"\n✓ PASSED: Outputs and LSE match within tolerance (atol={atol}, rtol={rtol})") + print(f"\n✓ PASSED: Outputs match within tolerance (atol={atol}, rtol={rtol})") else: - print(f"\n✗ FAILED: Outputs differ beyond tolerance (atol={atol}, rtol={rtol})") - - if not output_close: - # Find indices with largest errors for debugging - flat_abs_diff = o_diff.flatten() - top_k = min(5, flat_abs_diff.numel()) - top_errors, top_indices = torch.topk(flat_abs_diff, top_k) - - print(f"\nTop {top_k} output tensor error locations:") - for i in range(top_k): - idx = top_indices[i].item() - # Convert flat index back to 3D indices - _, num_qo_heads, head_dim_ckv = ref_o.shape - batch_idx = idx // (num_qo_heads * head_dim_ckv) - head_idx = (idx % (num_qo_heads * head_dim_ckv)) // head_dim_ckv - dim_idx = idx % head_dim_ckv - - ref_val = ref_o.flatten()[idx].item() - fi_val = flashinfer_o.flatten()[idx].item() - - print( - f" [{batch_idx}, {head_idx}, {dim_idx}]: " - f"ref={ref_val:.6f}, fi={fi_val:.6f}, diff={top_errors[i].item():.6e}" - ) - - if not lse_close: - # Find LSE errors - flat_lse_diff = lse_diff.flatten() - top_k = min(5, flat_lse_diff.numel()) - top_lse_errors, top_lse_indices = torch.topk(flat_lse_diff, top_k) - - print(f"\nTop {top_k} LSE error locations:") - for i in range(top_k): - idx = top_lse_indices[i].item() - _, num_qo_heads = ref_lse.shape - batch_idx = idx // num_qo_heads - head_idx = idx % num_qo_heads - - ref_val = ref_lse.flatten()[idx].item() - fi_val = flashinfer_lse.flatten()[idx].item() - - print( - f" [{batch_idx}, {head_idx}]: " - f"ref={ref_val:.6f}, fi={fi_val:.6f}, diff={top_lse_errors[i].item():.6e}" - ) + print(f"\n✗ FAILED: Outputs differ beyond tolerance") return all_close def main(): - """Run comprehensive tests.""" - print("Testing Batch MLA Paged Prefill Reference Implementation") - - # Test different configurations - test_configs = [ - # (batch_size, max_q_len, max_kv_len, causal) - (1, 8, 16, True), # Small causal - # (1, 8, 16, False), # Small non-causal - (4, 16, 32, True), # Medium causal - # (4, 16, 32, False), # Medium non-causal - (8, 32, 64, True), # Large causal - # (8, 32, 64, False), # Large non-causal - ] - - passed = 0 - total = len(test_configs) - - for batch_size, max_q_len, max_kv_len, causal in test_configs: - try: - if test_correctness(batch_size, max_q_len, max_kv_len, causal): - passed += 1 - except Exception as e: - print(f"✗ Test failed with exception: {str(e)}") - import traceback - - traceback.print_exc() - - print(f"\n{'='*60}") - print(f"Summary: {passed}/{total} tests passed") - print(f"{'='*60}") - - if passed == total: - print("✓ All tests passed!") - else: - print(f"✗ {total - passed} tests failed") + print("Testing Batch MLA Paged Prefill Reference Implementation (from definition)") + test_configs = [(1, 8, 16, True), (4, 16, 32, True), (8, 32, 64, True)] + passed = sum(1 for cfg in test_configs if test_correctness(*cfg)) + print(f"\n{'='*60}\nSummary: {passed}/{len(test_configs)} tests passed\n{'='*60}") if __name__ == "__main__": diff --git a/flashinfer_trace/tests/references/test_mla_paged_prefill_h16_ckv512_kpe64_ps64.py b/flashinfer_trace/tests/references/test_mla_paged_prefill_h16_ckv512_kpe64_ps64.py index 
9199a23a..b77b35d9 100644 --- a/flashinfer_trace/tests/references/test_mla_paged_prefill_h16_ckv512_kpe64_ps64.py +++ b/flashinfer_trace/tests/references/test_mla_paged_prefill_h16_ckv512_kpe64_ps64.py @@ -1,161 +1,59 @@ -import math +""" +Test MLA paged prefill h16_ckv512_kpe64_ps64 reference implementation against FlashInfer. + +This test validates that the reference implementation from the definition +matches the FlashInfer kernel implementation. +""" import flashinfer import numpy as np import torch +from test_utils import compare_tensors, get_reference_run, print_comparison_metrics +# Load reference implementation from definition +run = get_reference_run("mla_paged_prefill_causal_h16_ckv512_kpe64_ps64") -@torch.no_grad() -def run( - q_nope, q_pe, ckv_cache, kpe_cache, qo_indptr, kv_indptr, kv_indices, kv_last_page_len, sm_scale -): - total_q, num_qo_heads, head_dim_ckv = q_nope.shape - head_dim_kpe = q_pe.shape[-1] - page_size = ckv_cache.shape[1] - len_indptr = qo_indptr.shape[0] - batch_size = len_indptr - 1 - num_kv_indices = kv_indices.shape[0] - - # Check constants - assert num_qo_heads == 16 - assert head_dim_ckv == 512 - assert head_dim_kpe == 64 - assert page_size == 64 - - # Check constraints - assert total_q == qo_indptr[-1].item() - assert num_kv_indices == kv_indptr[-1].item() - - device = q_nope.device - - ckv_cache_f32 = ckv_cache.to(torch.float32) - kpe_cache_f32 = kpe_cache.to(torch.float32) - - output = torch.zeros((total_q, num_qo_heads, head_dim_ckv), dtype=torch.bfloat16, device=device) - lse = torch.full((total_q, num_qo_heads), -float("inf"), dtype=torch.float32, device=device) - - for b in range(batch_size): - q_start = int(qo_indptr[b].item()) - q_end = int(qo_indptr[b + 1].item()) - - page_beg = int(kv_indptr[b].item()) - page_end = int(kv_indptr[b + 1].item()) - last_page_len = int(kv_last_page_len[b].item()) - - if q_start >= q_end or page_beg >= page_end: - continue - - page_ids = kv_indices[page_beg:page_end].to(torch.long) - num_pages_for_seq = page_ids.shape[0] - - # Calculate total KV tokens - num_full_pages = num_pages_for_seq - 1 - kv_len = num_full_pages * page_size + last_page_len - - # Gather Kc and Kp from pages - Kc = torch.zeros((kv_len, head_dim_ckv), dtype=torch.float32, device=device) - Kp = torch.zeros((kv_len, head_dim_kpe), dtype=torch.float32, device=device) - - token_idx = 0 - for p_idx, page_id in enumerate(page_ids): - if p_idx < num_full_pages: - Kc[token_idx : token_idx + page_size] = ckv_cache_f32[page_id] - Kp[token_idx : token_idx + page_size] = kpe_cache_f32[page_id] - token_idx += page_size - else: - Kc[token_idx : token_idx + last_page_len] = ckv_cache_f32[page_id, :last_page_len] - Kp[token_idx : token_idx + last_page_len] = kpe_cache_f32[page_id, :last_page_len] - token_idx += last_page_len - - q_nope_batch = q_nope[q_start:q_end].to(torch.float32) - q_pe_batch = q_pe[q_start:q_end].to(torch.float32) +# Constants from definition +NUM_QO_HEADS = 16 +HEAD_DIM_CKV = 512 +HEAD_DIM_KPE = 64 +PAGE_SIZE = 64 - q_len = q_end - q_start - for i in range(q_len): - qn = q_nope_batch[i] - qp = q_pe_batch[i] - - logits = (qn @ Kc.T) + (qp @ Kp.T) - logits_scaled = logits * sm_scale - - # Apply causal mask - prefix_len = kv_len - q_len - query_abs_pos = prefix_len + i - - causal_mask = torch.arange(kv_len, device=logits_scaled.device) > query_abs_pos - logits_scaled.masked_fill_(causal_mask.unsqueeze(0), -float("inf")) - - lse[q_start + i] = torch.logsumexp(logits_scaled, dim=-1) / math.log(2.0) - - attn = torch.softmax(logits_scaled, 
dim=-1) - out = attn @ Kc - output[q_start + i] = out.to(torch.bfloat16) - - return output, lse - - -def generate_random_inputs( - batch_size, - max_q_len, - max_kv_len, - num_qo_heads=16, - head_dim_ckv=512, - head_dim_kpe=64, - page_size=64, - causal=True, - device="cuda", -): - """Generate random inputs for MLA paged prefill testing.""" - - # Generate random sequence lengths for each batch - q_lens = torch.randint(1, max_q_len + 1, (batch_size,), dtype=torch.int32, device=device) - kv_lens = torch.randint(1, max_kv_len + 1, (batch_size,), dtype=torch.int32, device=device) - - # For prefill, ensure kv_len >= q_len for causal attention +def generate_random_inputs(batch_size, max_q_len, max_kv_len, causal=True, device="cuda"): + """Generate random inputs for MLA prefill testing with page_size=64.""" + q_lens = torch.randint(1, max_q_len + 1, (batch_size,), dtype=torch.int32) + kv_lens = torch.zeros(batch_size, dtype=torch.int32) for i in range(batch_size): - kv_lens[i] = max(kv_lens[i], q_lens[i]) - - total_q = q_lens.sum().item() + if causal: + kv_lens[i] = torch.randint(q_lens[i].item(), max_kv_len + 1, (1,)).item() + else: + kv_lens[i] = torch.randint(1, max_kv_len + 1, (1,)).item() - # Calculate pages needed for each sequence - pages_per_seq = (kv_lens + page_size - 1) // page_size # Ceiling division - total_pages_needed = pages_per_seq.sum().item() - - # Generate qo_indptr based on query lengths qo_indptr = torch.zeros(batch_size + 1, dtype=torch.int32, device=device) - qo_indptr[1:] = torch.cumsum(q_lens, dim=0) + qo_indptr[1:] = torch.cumsum(q_lens.to(device), dim=0) - # Generate kv_indptr based on pages per sequence + pages_per_seq = (kv_lens + PAGE_SIZE - 1) // PAGE_SIZE kv_indptr = torch.zeros(batch_size + 1, dtype=torch.int32, device=device) - kv_indptr[1:] = torch.cumsum(pages_per_seq, dim=0) - - # Generate kv_indices (page indices for each sequence) - kv_indices = torch.arange(total_pages_needed, dtype=torch.int32, device=device) - - # Calculate last_page_len for each sequence - kv_last_page_len = ((kv_lens - 1) % page_size) + 1 + kv_indptr[1:] = torch.cumsum(pages_per_seq.to(device), dim=0) - # kv_len_arr stores the actual KV sequence lengths - kv_len_arr = kv_lens.clone() + total_q = qo_indptr[-1].item() + num_kv_pages = kv_indptr[-1].item() + kv_len_arr = kv_lens.to(device) - # Generate query tensors with Matrix Absorption dimensions - q_nope = torch.randn(total_q, num_qo_heads, head_dim_ckv, dtype=torch.bfloat16, device=device) - q_pe = torch.randn(total_q, num_qo_heads, head_dim_kpe, dtype=torch.bfloat16, device=device) + max_pages = num_kv_pages + 100 + kv_indices = torch.arange(num_kv_pages, dtype=torch.int32, device=device) + last_page_len = ((kv_lens - 1) % PAGE_SIZE) + 1 + last_page_len = last_page_len.to(torch.int32).to(device) - # Generate compressed KV and positional caches - num_pages = total_pages_needed + 100 - ckv_cache = torch.randn(num_pages, page_size, head_dim_ckv, dtype=torch.bfloat16, device=device) - kpe_cache = torch.randn(num_pages, page_size, head_dim_kpe, dtype=torch.bfloat16, device=device) + q_nope = torch.randn(total_q, NUM_QO_HEADS, HEAD_DIM_CKV, dtype=torch.bfloat16, device=device) + q_pe = torch.randn(total_q, NUM_QO_HEADS, HEAD_DIM_KPE, dtype=torch.bfloat16, device=device) + ckv_cache = torch.randn(max_pages, PAGE_SIZE, HEAD_DIM_CKV, dtype=torch.bfloat16, device=device) + kpe_cache = torch.randn(max_pages, PAGE_SIZE, HEAD_DIM_KPE, dtype=torch.bfloat16, device=device) - # Generate attention parameters - # MLA uses head dimension before matrix 
absorption - sm_scale = 1.0 / np.sqrt(128 + head_dim_kpe) + sm_scale = 1.0 / np.sqrt(128 + HEAD_DIM_KPE) sm_scale = torch.tensor(sm_scale, dtype=torch.float32, device=device) - # Convert causal to tensor - causal = torch.tensor(causal, dtype=torch.bool, device=device) - return { "q_nope": q_nope, "q_pe": q_pe, @@ -164,21 +62,21 @@ def generate_random_inputs( "qo_indptr": qo_indptr, "kv_indptr": kv_indptr, "kv_indices": kv_indices, - "kv_last_page_len": kv_last_page_len, "kv_len_arr": kv_len_arr, - "sm_scale": sm_scale, - "causal": causal, + "last_page_len": last_page_len, "q_lens": q_lens, "kv_lens": kv_lens, "total_q": total_q, + "sm_scale": sm_scale, + "causal": causal, } -def test_correctness(batch_size=4, max_q_len=32, max_kv_len=128, causal=True, atol=1e-2, rtol=5e-2): - """Test correctness of MLA paged prefill reference implementation against FlashInfer.""" +def test_correctness(batch_size=4, max_q_len=32, max_kv_len=256, causal=True, atol=1e-2, rtol=5e-2): + """Test correctness of MLA prefill reference implementation against FlashInfer.""" print(f"\n{'='*60}") print( - f"Testing batch_size={batch_size}, max_q_len={max_q_len}, max_kv_len={max_kv_len}, causal={causal}" + f"Testing MLA Paged Prefill (ps64) batch_size={batch_size}, max_q_len={max_q_len}, max_kv_len={max_kv_len}" ) print(f"{'='*60}") @@ -187,34 +85,13 @@ def test_correctness(batch_size=4, max_q_len=32, max_kv_len=128, causal=True, at print("WARNING: CUDA not available, skipping test") return - # Constants from kernel definition - num_qo_heads = 16 - head_dim_ckv = 512 - head_dim_kpe = 64 - page_size = 64 - - # Generate inputs - inputs = generate_random_inputs( - batch_size, - max_q_len, - max_kv_len, - num_qo_heads, - head_dim_ckv, - head_dim_kpe, - page_size, - causal, - device, - ) + inputs = generate_random_inputs(batch_size, max_q_len, max_kv_len, causal, device) print(f"Generated query lengths: {inputs['q_lens'].cpu().numpy()}") print(f"Generated KV lengths: {inputs['kv_lens'].cpu().numpy()}") - print(f"Last page lengths: {inputs['kv_last_page_len'].cpu().numpy()}") - print(f"Total Q tokens: {inputs['total_q']}") - print(f"Total pages used: {inputs['kv_indices'].shape[0]}") - print(f"Causal: {inputs['causal'].item()}") + print(f"Last page lengths: {inputs['last_page_len'].cpu().numpy()}") - # Run reference implementation - print("\nRunning reference implementation...") + print("\nRunning reference implementation from definition...") ref_o, ref_lse = run( inputs["q_nope"], inputs["q_pe"], @@ -223,140 +100,58 @@ def test_correctness(batch_size=4, max_q_len=32, max_kv_len=128, causal=True, at inputs["qo_indptr"], inputs["kv_indptr"], inputs["kv_indices"], - inputs["kv_last_page_len"], + inputs["last_page_len"], inputs["sm_scale"], ) - # Setup FlashInfer print("\nSetting up FlashInfer...") workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device=device) - - # For paged prefill with Matrix Absorption, use BatchMLAPagedAttentionWrapper mla_wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper(workspace_buffer, backend="auto") - # Plan the attention computation mla_wrapper.plan( qo_indptr=inputs["qo_indptr"], kv_indptr=inputs["kv_indptr"], kv_indices=inputs["kv_indices"], kv_len_arr=inputs["kv_len_arr"], - num_heads=num_qo_heads, - head_dim_ckv=head_dim_ckv, - head_dim_kpe=head_dim_kpe, - page_size=page_size, - causal=inputs["causal"].item(), + num_heads=NUM_QO_HEADS, + head_dim_ckv=HEAD_DIM_CKV, + head_dim_kpe=HEAD_DIM_KPE, + page_size=PAGE_SIZE, + causal=inputs["causal"], 
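+        # causal is a plain Python bool here; True makes FlashInfer apply the
+        # same per-query causal mask as the reference implementation.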
sm_scale=inputs["sm_scale"].item(), q_data_type=torch.bfloat16, kv_data_type=torch.bfloat16, ) - # Run FlashInfer print("Running FlashInfer...") - flashinfer_o, flashinfer_lse = mla_wrapper.run( - q_nope=inputs["q_nope"], - q_pe=inputs["q_pe"], - ckv_cache=inputs["ckv_cache"], - kpe_cache=inputs["kpe_cache"], - return_lse=True, + fi_output, fi_lse = mla_wrapper.run( + inputs["q_nope"], inputs["q_pe"], inputs["ckv_cache"], inputs["kpe_cache"], return_lse=True ) - # Compare outputs print("\nComparing outputs...") - print(f"Reference output shape: {ref_o.shape}") - print(f"FlashInfer output shape: {flashinfer_o.shape}") - print(f"Reference LSE shape: {ref_lse.shape}") - print(f"FlashInfer LSE shape: {flashinfer_lse.shape}") - - # Check numerical accuracy - o_diff = torch.abs(ref_o - flashinfer_o) - lse_diff = torch.abs(ref_lse - flashinfer_lse) + output_metrics = compare_tensors(ref_o, fi_output, atol=atol, rtol=rtol) + print_comparison_metrics(output_metrics, tensor_name="Output tensor") - print(f"\nOutput max diff: {o_diff.max().item():.6f}") - print(f"Output mean diff: {o_diff.mean().item():.6f}") - print(f"LSE max diff: {lse_diff.max().item():.6f}") - print(f"LSE mean diff: {lse_diff.mean().item():.6f}") + lse_metrics = compare_tensors(ref_lse, fi_lse, atol=atol, rtol=rtol) + print_comparison_metrics(lse_metrics, tensor_name="LSE tensor") - # Check if outputs match within tolerance - output_close = torch.allclose(ref_o.float(), flashinfer_o.float(), atol=atol, rtol=rtol) - lse_close = torch.allclose(ref_lse, flashinfer_lse, atol=atol, rtol=rtol) - all_close = output_close and lse_close + all_close = output_metrics.all_close and lse_metrics.all_close if all_close: - print(f"\n✓ PASSED: Outputs and LSE match within tolerance (atol={atol}, rtol={rtol})") + print(f"\n✓ PASSED: Outputs match within tolerance (atol={atol}, rtol={rtol})") else: - print(f"\n✗ FAILED: Outputs differ beyond tolerance (atol={atol}, rtol={rtol})") - - if not output_close: - flat_abs_diff = o_diff.flatten() - top_k = min(5, flat_abs_diff.numel()) - top_errors, top_indices = torch.topk(flat_abs_diff, top_k) - - print(f"\nTop {top_k} output tensor error locations:") - for i in range(top_k): - idx = top_indices[i].item() - _, num_qo_heads, head_dim_ckv = ref_o.shape - batch_idx = idx // (num_qo_heads * head_dim_ckv) - head_idx = (idx % (num_qo_heads * head_dim_ckv)) // head_dim_ckv - dim_idx = idx % head_dim_ckv - - ref_val = ref_o.flatten()[idx].item() - fi_val = flashinfer_o.flatten()[idx].item() - - print( - f" [{batch_idx}, {head_idx}, {dim_idx}]: " - f"ref={ref_val:.6f}, fi={fi_val:.6f}, diff={top_errors[i].item():.6e}" - ) - - if not lse_close: - flat_lse_diff = lse_diff.flatten() - top_k = min(5, flat_lse_diff.numel()) - top_lse_errors, top_lse_indices = torch.topk(flat_lse_diff, top_k) - - print(f"\nTop {top_k} LSE error locations:") - for i in range(top_k): - idx = top_lse_indices[i].item() - _, num_qo_heads = ref_lse.shape - batch_idx = idx // num_qo_heads - head_idx = idx % num_qo_heads - - ref_val = ref_lse.flatten()[idx].item() - fi_val = flashinfer_lse.flatten()[idx].item() - - print( - f" [{batch_idx}, {head_idx}]: " - f"ref={ref_val:.6f}, fi={fi_val:.6f}, diff={top_lse_errors[i].item():.6e}" - ) + print(f"\n✗ FAILED: Outputs differ beyond tolerance") return all_close def main(): - """Run comprehensive tests.""" - print("Testing Batch MLA Paged Prefill Reference Implementation (page_size=64)") - + print( + "Testing Batch MLA Paged Prefill Reference Implementation (page_size=64, from definition)" + ) 
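+    # Each config is (batch_size, max_q_len, max_kv_len, causal), unpacked
+    # into test_correctness(*cfg) below.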
test_configs = [(1, 16, 64, True), (4, 32, 128, True), (8, 64, 256, True)] - - passed = 0 - total = len(test_configs) - - for batch_size, max_q_len, max_kv_len, causal in test_configs: - try: - if test_correctness(batch_size, max_q_len, max_kv_len, causal): - passed += 1 - except Exception as e: - print(f"✗ Test failed with exception: {str(e)}") - import traceback - - traceback.print_exc() - - print(f"\n{'='*60}") - print(f"Summary: {passed}/{total} tests passed") - print(f"{'='*60}") - - if passed == total: - print("✓ All tests passed!") - else: - print(f"✗ {total - passed} tests failed") + passed = sum(1 for cfg in test_configs if test_correctness(*cfg)) + print(f"\n{'='*60}\nSummary: {passed}/{len(test_configs)} tests passed\n{'='*60}") if __name__ == "__main__": diff --git a/flashinfer_trace/tests/references/test_moe_fp8_block_scale_ds_routing_topk8_ng8_kg4_e32_h7168_i2048.py b/flashinfer_trace/tests/references/test_moe_fp8_block_scale_ds_routing_topk8_ng8_kg4_e32_h7168_i2048.py index 67dc70c9..ee276767 100644 --- a/flashinfer_trace/tests/references/test_moe_fp8_block_scale_ds_routing_topk8_ng8_kg4_e32_h7168_i2048.py +++ b/flashinfer_trace/tests/references/test_moe_fp8_block_scale_ds_routing_topk8_ng8_kg4_e32_h7168_i2048.py @@ -1,9 +1,20 @@ +""" +Test MoE FP8 Block Scale reference implementation against FlashInfer. + +This test validates that the reference implementation from the definition +matches the FlashInfer kernel implementation. +""" + import json from pathlib import Path import torch from flashinfer.fused_moe import trtllm_fp8_block_scale_moe from safetensors.torch import load_file +from test_utils import get_reference_run + +# Load reference implementation from definition +run = get_reference_run("moe_fp8_block_scale_ds_routing_topk8_ng8_kg4_e32_h7168_i2048") TRACE_ROOT = Path(__file__).resolve().parents[2] WORKLOAD_JSONL_PATH = ( @@ -14,177 +25,6 @@ ) -@torch.no_grad() -def run( - routing_logits: torch.Tensor, - routing_bias: torch.Tensor, - hidden_states: torch.Tensor, - hidden_states_scale: torch.Tensor, - gemm1_weights: torch.Tensor, - gemm1_weights_scale: torch.Tensor, - gemm2_weights: torch.Tensor, - gemm2_weights_scale: torch.Tensor, - local_expert_offset: int, - routed_scaling_factor: float, -): - """ - • FP8 block-scale dequantization: float ≈ fp8 * scale - • DeepSeek-V3 no-aux routing: - s = sigmoid(logits) - s_with_bias = s + bias - group by n_group=8; per group take top-2 sum → pick topk_group=4 groups - on the kept groups, take global top_k=8 experts - combine with weights derived from s (without bias), normalized and - scaled by routed_scaling_factor - • Local computation: - only experts in [local_expert_offset, local_expert_offset + E_local) are - computed on this rank (GEMM1 → SwiGLU → GEMM2), then per-token weighted - accumulation. 
- """ - - # Fixed DeepSeek-V3/R1 geometry - H = 7168 - I = 2048 - E_local = gemm1_weights.shape[0] - - BLOCK = 128 - E_global = routing_logits.shape[1] - T = routing_logits.shape[0] - - assert H == 7168, "hidden_size must be 7168" - assert I == 2048, "intermediate_size must be 2048" - assert E_global == 256, "num_experts must be 256" - assert E_local == 32, "num_local_experts must be 32" - - # Routing constants - TOP_K = 8 - N_GROUP = 8 - TOPK_GROUP = 4 - - # Block counts - num_hidden_blocks = H // BLOCK # 56 - num_intermediate_blocks = I // BLOCK # 16 - num_gemm1_out_blocks = (2 * I) // BLOCK # 32 - - # Shape checks - assert hidden_states.shape == (T, H) - assert hidden_states_scale.shape == (num_hidden_blocks, T) - assert gemm1_weights.shape == (E_local, 2 * I, H) - assert gemm1_weights_scale.shape == (E_local, num_gemm1_out_blocks, num_hidden_blocks) - assert gemm2_weights.shape == (E_local, H, I) - assert gemm2_weights_scale.shape == (E_local, num_hidden_blocks, num_intermediate_blocks) - assert routing_bias.shape[-1] == E_global - - device = hidden_states.device - - # 1) FP8 block-scale dequantization - # hidden_states: [T, H], scale: [H/128, T] (transposed layout) - A_fp32 = hidden_states.to(torch.float32) - A_scale = hidden_states_scale.to(torch.float32) # [H/128, T] - A_scale_TH = A_scale.permute(1, 0).contiguous() # [T, H/128] - A_scale_expanded = ( - A_scale_TH.unsqueeze(-1) - .repeat(1, 1, BLOCK) # [T, H/128, 128] - .reshape(T, H) # [T, H] - .contiguous() - ) - A = A_fp32 * A_scale_expanded # [T, H] float32 - - # W13: [E_local, 2I, H], scale: [E_local, (2I)/128, H/128] - W13_fp32 = gemm1_weights.to(torch.float32) - S13 = gemm1_weights_scale.to(torch.float32) - S13_expanded = torch.repeat_interleave(S13, BLOCK, dim=1) # [E, 2I, H/128] - S13_expanded = torch.repeat_interleave(S13_expanded, BLOCK, dim=2) # [E, 2I, H] - W13 = W13_fp32 * S13_expanded # [E, 2I, H] float32 - - # W2: [E_local, H, I], scale: [E_local, H/128, I/128] - W2_fp32 = gemm2_weights.to(torch.float32) - S2 = gemm2_weights_scale.to(torch.float32) - S2_expanded = torch.repeat_interleave(S2, BLOCK, dim=1) # [E, H, I/128] - S2_expanded = torch.repeat_interleave(S2_expanded, BLOCK, dim=2) # [E, H, I] - W2 = W2_fp32 * S2_expanded # [E, H, I] float32 - - # 2) No-aux routing - logits = routing_logits.to(torch.float32) # [T, E_global] - bias = routing_bias.to(torch.float32).reshape(-1) # [E_global] - - # Sigmoid - s = 1.0 / (1.0 + torch.exp(-logits)) # [T, E] - s_with_bias = s + bias # [T, E] (broadcast) - - # Grouping - group_size = E_global // N_GROUP # 32 - s_wb_grouped = s_with_bias.view(T, N_GROUP, group_size) # [T, 8, 32] - - # Group scores = sum of top-2 values within each group - top2_vals, _ = torch.topk(s_wb_grouped, k=2, dim=2, largest=True, sorted=False) # [T, 8, 2] - group_scores = top2_vals.sum(dim=2) # [T, 8] - - # Select topk_group groups → group mask - _, group_idx = torch.topk( - group_scores, k=TOPK_GROUP, dim=1, largest=True, sorted=False - ) # [T, 4] - group_mask = torch.zeros_like(group_scores) # [T, 8] - group_mask.scatter_(1, group_idx, 1.0) - score_mask = ( - group_mask.unsqueeze(2).expand(T, N_GROUP, group_size).reshape(T, E_global) - ) # [T, E] - - # Global top-k (within kept groups), based on s_with_bias - neg_inf = torch.finfo(torch.float32).min - scores_pruned = s_with_bias.masked_fill(score_mask == 0, neg_inf) # [T, E] - _, topk_idx = torch.topk(scores_pruned, k=TOP_K, dim=1, largest=True, sorted=False) # [T, 8] - - # Combination weights: use s (without bias) for normalization - M = 
torch.zeros_like(s) # [T, E] - M.scatter_(1, topk_idx, 1.0) # 0/1 mask - weights = s * M # [T, E] - weights_sum = weights.sum(dim=1, keepdim=True) + 1e-20 - weights = (weights / weights_sum) * routed_scaling_factor # [T, E] - - # 3) Local expert compute and accumulation - output = torch.zeros((T, H), dtype=torch.float32, device=device) - - local_start = int(local_expert_offset) - - # For each local expert: find selected tokens, run GEMM1→SwiGLU→GEMM2, accumulate by weights - for le in range(E_local): - ge = local_start + le - if ge < 0 or ge >= E_global: - continue - - # Tokens that selected this global expert ge in their top-k - sel_mask_per_token = (topk_idx == ge).any(dim=1) # [T] bool - if not sel_mask_per_token.any(): - continue - - token_idx = torch.nonzero(sel_mask_per_token, as_tuple=False).squeeze(1) # [Tk] - Tk = token_idx.numel() - - # Gather inputs and weights for this expert - A_e = A.index_select(0, token_idx) # [Tk, H] - W13_e = W13[le] # [2I, H] - W2_e = W2[le] # [H, I] - - # GEMM1: [Tk, H] @ [H, 2I] = [Tk, 2I] - G1 = A_e.matmul(W13_e.t()) # [Tk, 2I] - - # SwiGLU: split and apply silu(x) = x / (1 + exp(-x)) - X1 = G1[:, :I] # [Tk, I] - X2 = G1[:, I:] # [Tk, I] - silu_X2 = X2 / (1.0 + torch.exp(-X2)) # [Tk, I] - C = silu_X2 * X1 # [Tk, I] - - # GEMM2: [Tk, I] @ [I, H] = [Tk, H] - O = C.matmul(W2_e.t()) # [Tk, H] - - # Accumulate with per-token routing weights for this expert - w_tok = weights.index_select(0, token_idx)[:, ge] # [Tk] - output.index_add_(0, token_idx, O * w_tok.unsqueeze(1)) # [Tk,H] * [Tk,1] - - return output.to(torch.bfloat16) - - # ----------------------------- # Helpers: FP8 block quantization (dequant scale semantics) - Vectorized # ----------------------------- diff --git a/flashinfer_trace/tests/references/test_rmsnorm_h128.py b/flashinfer_trace/tests/references/test_rmsnorm_h128.py index e9fc9080..e89d0c50 100644 --- a/flashinfer_trace/tests/references/test_rmsnorm_h128.py +++ b/flashinfer_trace/tests/references/test_rmsnorm_h128.py @@ -1,70 +1,32 @@ -import flashinfer -import torch - - -@torch.no_grad() -def run(input, weight, eps, residual=None): - """ - Reference implementation of RMSNorm with hidden_size=128. - - Args: - input: Input tensor of shape (B, 128) in bfloat16 - weight: Weight tensor of shape (128,) in bfloat16 - eps: Small epsilon value for numerical stability - residual: Optional residual tensor of shape (B, 128) in bfloat16 +""" +Test RMSNorm h128 reference implementation against FlashInfer. - Returns: - dict with 'output' key containing normalized output in bfloat16 - """ - batch_size, hidden_size = input.shape +This test validates that the reference implementation from the definition +matches the FlashInfer kernel implementation. 
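+
+A sketch of what the definition is expected to compute, in float32 (matching
+the previous in-file reference):
+
+    output = hidden_states / sqrt(mean(hidden_states**2, dim=-1) + eps) * weight
+
+The test compares this against flashinfer.norm.rmsnorm with eps=1e-6.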
+""" - # Check constants - assert hidden_size == 128 - - # Perform computation in float32 for accuracy - orig_dtype = input.dtype - input_fp32 = input.to(torch.float32) - weight_fp32 = weight.to(torch.float32) - - if residual is not None: - residual_fp32 = residual.to(torch.float32) - input_fp32 = input_fp32 + residual_fp32 - - # Compute RMS - variance = input_fp32.pow(2).mean(dim=-1, keepdim=True) - rstd = torch.rsqrt(variance + eps) +import flashinfer +import torch +from test_utils import get_reference_run - # Apply normalization and weight - output = (input_fp32 * rstd) * weight_fp32 +# Load reference implementations from definitions +run_rmsnorm = get_reference_run("rmsnorm_h128") - # Convert back to original dtype - return {"output": output.to(orig_dtype)} +# Hidden size constant +HIDDEN_SIZE = 128 -def generate_random_inputs(batch_size, with_residual=True, device="cuda"): +def generate_random_inputs(batch_size, device="cuda"): """Generate random inputs for testing RMSNorm with hidden_size=128.""" + hidden_states = torch.randn(batch_size, HIDDEN_SIZE, dtype=torch.bfloat16, device=device) + weight = torch.randn(HIDDEN_SIZE, dtype=torch.bfloat16, device=device) + return {"hidden_states": hidden_states, "weight": weight} - hidden_size = 128 - eps = 1e-6 # Common value for this configuration - # Generate input tensor - input = torch.randn(batch_size, hidden_size, dtype=torch.bfloat16, device=device) - - # Generate weight tensor - weight = torch.randn(hidden_size, dtype=torch.bfloat16, device=device) - - # Generate residual if needed - residual = None - if with_residual: - residual = torch.randn(batch_size, hidden_size, dtype=torch.bfloat16, device=device) - - return {"input": input, "weight": weight, "eps": eps, "residual": residual} - - -def test_correctness(batch_size=8, with_residual=True, atol=8e-3, rtol=1e-2): +def test_correctness(batch_size=8, atol=8e-3, rtol=1e-2): """Test correctness of reference implementation against FlashInfer.""" print(f"\n{'='*60}") - print(f"Testing RMSNorm h128: batch_size={batch_size}, with_residual={with_residual}") + print(f"Testing RMSNorm h128: batch_size={batch_size}") print(f"{'='*60}") device = "cuda" if torch.cuda.is_available() else "cpu" @@ -73,43 +35,27 @@ def test_correctness(batch_size=8, with_residual=True, atol=8e-3, rtol=1e-2): return False # Generate inputs - inputs = generate_random_inputs(batch_size, with_residual, device) + inputs = generate_random_inputs(batch_size, device) - print(f"Input shape: {inputs['input'].shape}") + print(f"Hidden states shape: {inputs['hidden_states'].shape}") print(f"Weight shape: {inputs['weight'].shape}") - print(f"Epsilon: {inputs['eps']}") - print(f"Has residual: {inputs['residual'] is not None}") - - # Run reference implementation - print("\nRunning reference implementation...") - ref_output = run( - inputs["input"].clone(), - inputs["weight"], - inputs["eps"], - inputs["residual"].clone() if inputs["residual"] is not None else None, - ) + + # Run reference implementation from definition + print("\nRunning reference implementation from definition...") + ref_output = run_rmsnorm(inputs["hidden_states"].clone(), inputs["weight"]) # Run FlashInfer implementation print("Running FlashInfer implementation...") - input_fi = inputs["input"].clone().contiguous() + input_fi = inputs["hidden_states"].clone().contiguous() weight_fi = inputs["weight"].contiguous() - - if inputs["residual"] is not None: - residual_fi = inputs["residual"].clone().contiguous() - # Use fused kernel for residual case - 
flashinfer.norm.fused_add_rmsnorm(input_fi, residual_fi, weight_fi, inputs["eps"]) - fi_output = {"output": input_fi} - else: - # Standard RMSNorm without residual - fi_out = flashinfer.norm.rmsnorm(input_fi, weight_fi, eps=inputs["eps"]) - fi_output = {"output": fi_out} + fi_output = flashinfer.norm.rmsnorm(input_fi, weight_fi, eps=1e-6) # Compare outputs print("\nComparing outputs...") # Convert to float32 for comparison - ref_out_f32 = ref_output["output"].float() - fi_out_f32 = fi_output["output"].float() + ref_out_f32 = ref_output.float() + fi_out_f32 = fi_output.float() # Compute errors abs_diff = torch.abs(ref_out_f32 - fi_out_f32) @@ -139,18 +85,10 @@ def test_correctness(batch_size=8, with_residual=True, atol=8e-3, rtol=1e-2): def main(): """Run comprehensive tests for RMSNorm h128.""" - print("Testing RMSNorm h128 Reference Implementation") - - # Test different configurations - test_configs = [ - # (batch_size, with_residual) - (1, True), # Single batch with residual - (1, False), # Single batch without residual - (4, True), # Small batch with residual - (8, True), # Medium batch with residual - (16, True), # Large batch with residual - (32, True), # Very large batch with residual - ] + print("Testing RMSNorm h128 Reference Implementation (from definition)") + + # Test different batch sizes + test_configs = [1, 4, 8, 16, 32] passed = 0 total = len(test_configs) @@ -159,9 +97,9 @@ def main(): atol = 8e-3 # 0.8% absolute tolerance rtol = 1e-2 # 1% relative tolerance - for batch_size, with_residual in test_configs: + for batch_size in test_configs: try: - if test_correctness(batch_size, with_residual, atol, rtol): + if test_correctness(batch_size, atol, rtol): passed += 1 except Exception as e: print(f"✗ Test failed with exception: {str(e)}") diff --git a/flashinfer_trace/tests/references/test_rmsnorm_h2048.py b/flashinfer_trace/tests/references/test_rmsnorm_h2048.py index 87427b09..816caf25 100644 --- a/flashinfer_trace/tests/references/test_rmsnorm_h2048.py +++ b/flashinfer_trace/tests/references/test_rmsnorm_h2048.py @@ -1,70 +1,32 @@ -import flashinfer -import torch - - -@torch.no_grad() -def run(input, weight, eps, residual=None): - """ - Reference implementation of RMSNorm with hidden_size=2048. - - Args: - input: Input tensor of shape (B, 2048) in bfloat16 - weight: Weight tensor of shape (2048,) in bfloat16 - eps: Small epsilon value for numerical stability - residual: Optional residual tensor of shape (B, 2048) in bfloat16 +""" +Test RMSNorm h2048 reference implementation against FlashInfer. - Returns: - dict with 'output' key containing normalized output in bfloat16 - """ - batch_size, hidden_size = input.shape +This test validates that the reference implementation from the definition +matches the FlashInfer kernel implementation. 
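+
+As in the h128 variant, eps is fixed to 1e-6 when calling
+flashinfer.norm.rmsnorm, and outputs are compared in float32 with
+atol=8e-3 and rtol=1e-2.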
+""" - # Check constants - assert hidden_size == 2048 - - # Perform computation in float32 for accuracy - orig_dtype = input.dtype - input_fp32 = input.to(torch.float32) - weight_fp32 = weight.to(torch.float32) - - if residual is not None: - residual_fp32 = residual.to(torch.float32) - input_fp32 = input_fp32 + residual_fp32 - - # Compute RMS - variance = input_fp32.pow(2).mean(dim=-1, keepdim=True) - rstd = torch.rsqrt(variance + eps) +import flashinfer +import torch +from test_utils import get_reference_run - # Apply normalization and weight - output = (input_fp32 * rstd) * weight_fp32 +# Load reference implementations from definitions +run_rmsnorm = get_reference_run("rmsnorm_h2048") - # Convert back to original dtype - return {"output": output.to(orig_dtype)} +# Hidden size constant +HIDDEN_SIZE = 2048 -def generate_random_inputs(batch_size, with_residual=True, device="cuda"): +def generate_random_inputs(batch_size, device="cuda"): """Generate random inputs for testing RMSNorm with hidden_size=2048.""" + hidden_states = torch.randn(batch_size, HIDDEN_SIZE, dtype=torch.bfloat16, device=device) + weight = torch.randn(HIDDEN_SIZE, dtype=torch.bfloat16, device=device) + return {"hidden_states": hidden_states, "weight": weight} - hidden_size = 2048 - eps = 1e-6 # Common value for this configuration - # Generate input tensor - input = torch.randn(batch_size, hidden_size, dtype=torch.bfloat16, device=device) - - # Generate weight tensor - weight = torch.randn(hidden_size, dtype=torch.bfloat16, device=device) - - # Generate residual if needed - residual = None - if with_residual: - residual = torch.randn(batch_size, hidden_size, dtype=torch.bfloat16, device=device) - - return {"input": input, "weight": weight, "eps": eps, "residual": residual} - - -def test_correctness(batch_size=8, with_residual=True, atol=8e-3, rtol=1e-2): +def test_correctness(batch_size=8, atol=8e-3, rtol=1e-2): """Test correctness of reference implementation against FlashInfer.""" print(f"\n{'='*60}") - print(f"Testing RMSNorm h2048: batch_size={batch_size}, with_residual={with_residual}") + print(f"Testing RMSNorm h2048: batch_size={batch_size}") print(f"{'='*60}") device = "cuda" if torch.cuda.is_available() else "cpu" @@ -73,43 +35,27 @@ def test_correctness(batch_size=8, with_residual=True, atol=8e-3, rtol=1e-2): return False # Generate inputs - inputs = generate_random_inputs(batch_size, with_residual, device) + inputs = generate_random_inputs(batch_size, device) - print(f"Input shape: {inputs['input'].shape}") + print(f"Hidden states shape: {inputs['hidden_states'].shape}") print(f"Weight shape: {inputs['weight'].shape}") - print(f"Epsilon: {inputs['eps']}") - print(f"Has residual: {inputs['residual'] is not None}") - - # Run reference implementation - print("\nRunning reference implementation...") - ref_output = run( - inputs["input"].clone(), - inputs["weight"], - inputs["eps"], - inputs["residual"].clone() if inputs["residual"] is not None else None, - ) + + # Run reference implementation from definition + print("\nRunning reference implementation from definition...") + ref_output = run_rmsnorm(inputs["hidden_states"].clone(), inputs["weight"]) # Run FlashInfer implementation print("Running FlashInfer implementation...") - input_fi = inputs["input"].clone().contiguous() + input_fi = inputs["hidden_states"].clone().contiguous() weight_fi = inputs["weight"].contiguous() - - if inputs["residual"] is not None: - residual_fi = inputs["residual"].clone().contiguous() - # Use fused kernel for residual case - 
flashinfer.norm.fused_add_rmsnorm(input_fi, residual_fi, weight_fi, inputs["eps"]) - fi_output = {"output": input_fi} - else: - # Standard RMSNorm without residual - fi_out = flashinfer.norm.rmsnorm(input_fi, weight_fi, eps=inputs["eps"]) - fi_output = {"output": fi_out} + fi_output = flashinfer.norm.rmsnorm(input_fi, weight_fi, eps=1e-6) # Compare outputs print("\nComparing outputs...") # Convert to float32 for comparison - ref_out_f32 = ref_output["output"].float() - fi_out_f32 = fi_output["output"].float() + ref_out_f32 = ref_output.float() + fi_out_f32 = fi_output.float() # Compute errors abs_diff = torch.abs(ref_out_f32 - fi_out_f32) @@ -139,18 +85,10 @@ def test_correctness(batch_size=8, with_residual=True, atol=8e-3, rtol=1e-2): def main(): """Run comprehensive tests for RMSNorm h2048.""" - print("Testing RMSNorm h2048 Reference Implementation") - - # Test different configurations - test_configs = [ - # (batch_size, with_residual) - (1, True), # Single batch with residual - (1, False), # Single batch without residual - (4, True), # Small batch with residual - (8, True), # Medium batch with residual - (16, True), # Large batch with residual - (32, True), # Very large batch with residual - ] + print("Testing RMSNorm h2048 Reference Implementation (from definition)") + + # Test different batch sizes + test_configs = [1, 4, 8, 16, 32] passed = 0 total = len(test_configs) @@ -159,9 +97,9 @@ def main(): atol = 8e-3 # 0.8% absolute tolerance rtol = 1e-2 # 1% relative tolerance - for batch_size, with_residual in test_configs: + for batch_size in test_configs: try: - if test_correctness(batch_size, with_residual, atol, rtol): + if test_correctness(batch_size, atol, rtol): passed += 1 except Exception as e: print(f"✗ Test failed with exception: {str(e)}") diff --git a/flashinfer_trace/tests/references/test_rmsnorm_h4096.py b/flashinfer_trace/tests/references/test_rmsnorm_h4096.py index 580b9fb7..232fc3ea 100644 --- a/flashinfer_trace/tests/references/test_rmsnorm_h4096.py +++ b/flashinfer_trace/tests/references/test_rmsnorm_h4096.py @@ -1,70 +1,32 @@ -import flashinfer -import torch - - -@torch.no_grad() -def run(input, weight, eps, residual=None): - """ - Reference implementation of RMSNorm with hidden_size=4096. - - Args: - input: Input tensor of shape (B, 4096) in bfloat16 - weight: Weight tensor of shape (4096,) in bfloat16 - eps: Small epsilon value for numerical stability - residual: Optional residual tensor of shape (B, 4096) in bfloat16 +""" +Test RMSNorm h4096 reference implementation against FlashInfer. - Returns: - dict with 'output' key containing normalized output in bfloat16 - """ - batch_size, hidden_size = input.shape +This test validates that the reference implementation from the definition +matches the FlashInfer kernel implementation. 
+""" - # Check constants - assert hidden_size == 4096 - - # Perform computation in float32 for accuracy - orig_dtype = input.dtype - input_fp32 = input.to(torch.float32) - weight_fp32 = weight.to(torch.float32) - - if residual is not None: - residual_fp32 = residual.to(torch.float32) - input_fp32 = input_fp32 + residual_fp32 - - # Compute RMS - variance = input_fp32.pow(2).mean(dim=-1, keepdim=True) - rstd = torch.rsqrt(variance + eps) +import flashinfer +import torch +from test_utils import get_reference_run - # Apply normalization and weight - output = (input_fp32 * rstd) * weight_fp32 +# Load reference implementations from definitions +run_rmsnorm = get_reference_run("rmsnorm_h4096") - # Convert back to original dtype - return {"output": output.to(orig_dtype)} +# Hidden size constant +HIDDEN_SIZE = 4096 -def generate_random_inputs(batch_size, with_residual=True, device="cuda"): +def generate_random_inputs(batch_size, device="cuda"): """Generate random inputs for testing RMSNorm with hidden_size=4096.""" + hidden_states = torch.randn(batch_size, HIDDEN_SIZE, dtype=torch.bfloat16, device=device) + weight = torch.randn(HIDDEN_SIZE, dtype=torch.bfloat16, device=device) + return {"hidden_states": hidden_states, "weight": weight} - hidden_size = 4096 - eps = 1e-5 # Common value for this configuration - # Generate input tensor - input = torch.randn(batch_size, hidden_size, dtype=torch.bfloat16, device=device) - - # Generate weight tensor - weight = torch.randn(hidden_size, dtype=torch.bfloat16, device=device) - - # Generate residual if needed - residual = None - if with_residual: - residual = torch.randn(batch_size, hidden_size, dtype=torch.bfloat16, device=device) - - return {"input": input, "weight": weight, "eps": eps, "residual": residual} - - -def test_correctness(batch_size=8, with_residual=True, atol=8e-3, rtol=1e-2): +def test_correctness(batch_size=8, atol=8e-3, rtol=1e-2): """Test correctness of reference implementation against FlashInfer.""" print(f"\n{'='*60}") - print(f"Testing RMSNorm h4096: batch_size={batch_size}, with_residual={with_residual}") + print(f"Testing RMSNorm h4096: batch_size={batch_size}") print(f"{'='*60}") device = "cuda" if torch.cuda.is_available() else "cpu" @@ -73,43 +35,27 @@ def test_correctness(batch_size=8, with_residual=True, atol=8e-3, rtol=1e-2): return False # Generate inputs - inputs = generate_random_inputs(batch_size, with_residual, device) + inputs = generate_random_inputs(batch_size, device) - print(f"Input shape: {inputs['input'].shape}") + print(f"Hidden states shape: {inputs['hidden_states'].shape}") print(f"Weight shape: {inputs['weight'].shape}") - print(f"Epsilon: {inputs['eps']}") - print(f"Has residual: {inputs['residual'] is not None}") - - # Run reference implementation - print("\nRunning reference implementation...") - ref_output = run( - inputs["input"].clone(), - inputs["weight"], - inputs["eps"], - inputs["residual"].clone() if inputs["residual"] is not None else None, - ) + + # Run reference implementation from definition + print("\nRunning reference implementation from definition...") + ref_output = run_rmsnorm(inputs["hidden_states"].clone(), inputs["weight"]) # Run FlashInfer implementation print("Running FlashInfer implementation...") - input_fi = inputs["input"].clone().contiguous() + input_fi = inputs["hidden_states"].clone().contiguous() weight_fi = inputs["weight"].contiguous() - - if inputs["residual"] is not None: - residual_fi = inputs["residual"].clone().contiguous() - # Use fused kernel for residual case - 
flashinfer.norm.fused_add_rmsnorm(input_fi, residual_fi, weight_fi, inputs["eps"]) - fi_output = {"output": input_fi} - else: - # Standard RMSNorm without residual - fi_out = flashinfer.norm.rmsnorm(input_fi, weight_fi, eps=inputs["eps"]) - fi_output = {"output": fi_out} + fi_output = flashinfer.norm.rmsnorm(input_fi, weight_fi, eps=1e-5) # Compare outputs print("\nComparing outputs...") # Convert to float32 for comparison - ref_out_f32 = ref_output["output"].float() - fi_out_f32 = fi_output["output"].float() + ref_out_f32 = ref_output.float() + fi_out_f32 = fi_output.float() # Compute errors abs_diff = torch.abs(ref_out_f32 - fi_out_f32) @@ -139,18 +85,10 @@ def test_correctness(batch_size=8, with_residual=True, atol=8e-3, rtol=1e-2): def main(): """Run comprehensive tests for RMSNorm h4096.""" - print("Testing RMSNorm h4096 Reference Implementation") - - # Test different configurations - test_configs = [ - # (batch_size, with_residual) - (1, True), # Single batch with residual - (1, False), # Single batch without residual - (4, True), # Small batch with residual - (8, True), # Medium batch with residual - (16, True), # Large batch with residual - (32, True), # Very large batch with residual - ] + print("Testing RMSNorm h4096 Reference Implementation (from definition)") + + # Test different batch sizes + test_configs = [1, 4, 8, 16, 32] passed = 0 total = len(test_configs) @@ -159,9 +97,9 @@ def main(): atol = 8e-3 # 0.8% absolute tolerance rtol = 1e-2 # 1% relative tolerance - for batch_size, with_residual in test_configs: + for batch_size in test_configs: try: - if test_correctness(batch_size, with_residual, atol, rtol): + if test_correctness(batch_size, atol, rtol): passed += 1 except Exception as e: print(f"✗ Test failed with exception: {str(e)}") diff --git a/flashinfer_trace/tests/references/test_rmsnorm_h7168.py b/flashinfer_trace/tests/references/test_rmsnorm_h7168.py index 926e07dc..81d3feb2 100644 --- a/flashinfer_trace/tests/references/test_rmsnorm_h7168.py +++ b/flashinfer_trace/tests/references/test_rmsnorm_h7168.py @@ -1,70 +1,32 @@ -import flashinfer -import torch - - -@torch.no_grad() -def run(input, weight, eps, residual=None): - """ - Reference implementation of RMSNorm with hidden_size=7168. - - Args: - input: Input tensor of shape (B, 7168) in bfloat16 - weight: Weight tensor of shape (7168,) in bfloat16 - eps: Small epsilon value for numerical stability - residual: Optional residual tensor of shape (B, 7168) in bfloat16 +""" +Test RMSNorm h7168 reference implementation against FlashInfer. - Returns: - dict with 'output' key containing normalized output in bfloat16 - """ - batch_size, hidden_size = input.shape +This test validates that the reference implementation from the definition +matches the FlashInfer kernel implementation. 
+""" - # Check constants - assert hidden_size == 7168 - - # Perform computation in float32 for accuracy - orig_dtype = input.dtype - input_fp32 = input.to(torch.float32) - weight_fp32 = weight.to(torch.float32) - - if residual is not None: - residual_fp32 = residual.to(torch.float32) - input_fp32 = input_fp32 + residual_fp32 - - # Compute RMS - variance = input_fp32.pow(2).mean(dim=-1, keepdim=True) - rstd = torch.rsqrt(variance + eps) +import flashinfer +import torch +from test_utils import get_reference_run - # Apply normalization and weight - output = (input_fp32 * rstd) * weight_fp32 +# Load reference implementations from definitions +run_rmsnorm = get_reference_run("rmsnorm_h7168") - # Convert back to original dtype - return {"output": output.to(orig_dtype)} +# Hidden size constant +HIDDEN_SIZE = 7168 -def generate_random_inputs(batch_size, with_residual=True, device="cuda"): +def generate_random_inputs(batch_size, device="cuda"): """Generate random inputs for testing RMSNorm with hidden_size=7168.""" + hidden_states = torch.randn(batch_size, HIDDEN_SIZE, dtype=torch.bfloat16, device=device) + weight = torch.randn(HIDDEN_SIZE, dtype=torch.bfloat16, device=device) + return {"hidden_states": hidden_states, "weight": weight} - hidden_size = 7168 - eps = 1e-6 # Common value for this configuration - # Generate input tensor - input = torch.randn(batch_size, hidden_size, dtype=torch.bfloat16, device=device) - - # Generate weight tensor - weight = torch.randn(hidden_size, dtype=torch.bfloat16, device=device) - - # Generate residual if needed - residual = None - if with_residual: - residual = torch.randn(batch_size, hidden_size, dtype=torch.bfloat16, device=device) - - return {"input": input, "weight": weight, "eps": eps, "residual": residual} - - -def test_correctness(batch_size=8, with_residual=True, atol=8e-3, rtol=1e-2): +def test_correctness(batch_size=8, atol=8e-3, rtol=1e-2): """Test correctness of reference implementation against FlashInfer.""" print(f"\n{'='*60}") - print(f"Testing RMSNorm h7168: batch_size={batch_size}, with_residual={with_residual}") + print(f"Testing RMSNorm h7168: batch_size={batch_size}") print(f"{'='*60}") device = "cuda" if torch.cuda.is_available() else "cpu" @@ -73,43 +35,27 @@ def test_correctness(batch_size=8, with_residual=True, atol=8e-3, rtol=1e-2): return False # Generate inputs - inputs = generate_random_inputs(batch_size, with_residual, device) + inputs = generate_random_inputs(batch_size, device) - print(f"Input shape: {inputs['input'].shape}") + print(f"Hidden states shape: {inputs['hidden_states'].shape}") print(f"Weight shape: {inputs['weight'].shape}") - print(f"Epsilon: {inputs['eps']}") - print(f"Has residual: {inputs['residual'] is not None}") - - # Run reference implementation - print("\nRunning reference implementation...") - ref_output = run( - inputs["input"].clone(), - inputs["weight"], - inputs["eps"], - inputs["residual"].clone() if inputs["residual"] is not None else None, - ) + + # Run reference implementation from definition + print("\nRunning reference implementation from definition...") + ref_output = run_rmsnorm(inputs["hidden_states"].clone(), inputs["weight"]) # Run FlashInfer implementation print("Running FlashInfer implementation...") - input_fi = inputs["input"].clone().contiguous() + input_fi = inputs["hidden_states"].clone().contiguous() weight_fi = inputs["weight"].contiguous() - - if inputs["residual"] is not None: - residual_fi = inputs["residual"].clone().contiguous() - # Use fused kernel for residual case - 
flashinfer.norm.fused_add_rmsnorm(input_fi, residual_fi, weight_fi, inputs["eps"]) - fi_output = {"output": input_fi} - else: - # Standard RMSNorm without residual - fi_out = flashinfer.norm.rmsnorm(input_fi, weight_fi, eps=inputs["eps"]) - fi_output = {"output": fi_out} + fi_output = flashinfer.norm.rmsnorm(input_fi, weight_fi, eps=1e-6) # Compare outputs print("\nComparing outputs...") # Convert to float32 for comparison - ref_out_f32 = ref_output["output"].float() - fi_out_f32 = fi_output["output"].float() + ref_out_f32 = ref_output.float() + fi_out_f32 = fi_output.float() # Compute errors abs_diff = torch.abs(ref_out_f32 - fi_out_f32) @@ -139,18 +85,10 @@ def test_correctness(batch_size=8, with_residual=True, atol=8e-3, rtol=1e-2): def main(): """Run comprehensive tests for RMSNorm h7168.""" - print("Testing RMSNorm h7168 Reference Implementation") - - # Test different configurations - test_configs = [ - # (batch_size, with_residual) - (1, True), # Single batch with residual - (1, False), # Single batch without residual - (4, True), # Small batch with residual - (8, True), # Medium batch with residual - (16, True), # Large batch with residual - (32, True), # Very large batch with residual - ] + print("Testing RMSNorm h7168 Reference Implementation (from definition)") + + # Test different batch sizes + test_configs = [1, 4, 8, 16, 32] passed = 0 total = len(test_configs) @@ -159,9 +97,9 @@ def main(): atol = 8e-3 # 0.8% absolute tolerance rtol = 1e-2 # 1% relative tolerance - for batch_size, with_residual in test_configs: + for batch_size in test_configs: try: - if test_correctness(batch_size, with_residual, atol, rtol): + if test_correctness(batch_size, atol, rtol): passed += 1 except Exception as e: print(f"✗ Test failed with exception: {str(e)}") diff --git a/flashinfer_trace/tests/references/test_top_k_sampling_from_probs.py b/flashinfer_trace/tests/references/test_top_k_sampling_from_probs.py index a5cbf822..95833595 100644 --- a/flashinfer_trace/tests/references/test_top_k_sampling_from_probs.py +++ b/flashinfer_trace/tests/references/test_top_k_sampling_from_probs.py @@ -1,44 +1,23 @@ -import flashinfer -import torch - - -@torch.no_grad() -def run(probs, top_k): - batch_size, vocab_size = probs.shape - device = probs.device +""" +Test top_k_sampling_from_probs reference implementation against FlashInfer. - # Check constants - # assert vocab_size == 128256 - - probs = probs.to(torch.float32) - samples = torch.empty(batch_size, dtype=torch.int64, device=device) - - for i in range(batch_size): - row = probs[i] - k = int(top_k[i].item()) +This test validates that the reference implementation from the definition +matches the FlashInfer kernel implementation in terms of distribution. 
+""" - # No filtering on invalid k - if 0 < k < vocab_size: - idx_sorted = torch.argsort(row, descending=True) - keep_idx = idx_sorted[:k] - - filtered = torch.zeros_like(row) - filtered[keep_idx] = row[keep_idx] - - row = filtered / filtered.sum() - - samples[i] = torch.multinomial(row, 1, replacement=True).squeeze(0) +import flashinfer +import torch +from test_utils import get_reference_run - return samples +# Load reference implementation from definition (use v128256 as default) +run = get_reference_run("top_k_sampling_from_probs_v128256") def generate_random_inputs(batch_size, vocab_size=128256, distribution="normal", device="cuda"): """Generate random test inputs.""" - # Generate probabilities if distribution == "normal": logits = torch.randn(batch_size, vocab_size, device=device) elif distribution == "peaked": - # Create peaked distribution logits = torch.randn(batch_size, vocab_size, device=device) * 0.1 peak_indices = torch.randint(0, vocab_size, (batch_size,), device=device) for i in range(batch_size): @@ -48,10 +27,7 @@ def generate_random_inputs(batch_size, vocab_size=128256, distribution="normal", else: raise ValueError(f"Unknown distribution: {distribution}") - # Convert to probabilities probs = torch.softmax(logits, dim=-1).to(torch.float32) - - # Generate varying top_k values top_k = torch.randint( 10, min(500, vocab_size // 2), (batch_size,), dtype=torch.int32, device=device ) @@ -69,10 +45,8 @@ def test_correctness(batch_size=8, vocab_size=128256, num_trials=10000): device = "cuda" torch.manual_seed(42) - # Generate inputs probs, top_k = generate_random_inputs(batch_size, vocab_size, "peaked", device) - # Count frequencies for both implementations ref_counter = torch.zeros(batch_size, vocab_size, dtype=torch.int32, device=device) fi_counter = torch.zeros(batch_size, vocab_size, dtype=torch.int32, device=device) @@ -81,26 +55,21 @@ def test_correctness(batch_size=8, vocab_size=128256, num_trials=10000): if trial % progress_interval == 0: print(f" Trial {trial}/{num_trials}...") - # Reference implementation torch.manual_seed(42 + trial) ref_samples = run(probs, top_k) for i in range(batch_size): ref_counter[i, ref_samples[i]] += 1 - # FlashInfer implementation torch.manual_seed(42 + trial) fi_samples = flashinfer.sampling.top_k_sampling_from_probs(probs, top_k) for i in range(batch_size): fi_counter[i, fi_samples[i]] += 1 - # Calculate frequencies ref_freq = ref_counter.float() / num_trials fi_freq = fi_counter.float() / num_trials - # Calculate cosine similarity similarities = [] for i in range(batch_size): - # Only compare tokens that were sampled at least once mask = (ref_freq[i] > 0) | (fi_freq[i] > 0) if mask.sum() > 0: ref = ref_freq[i][mask] @@ -112,7 +81,6 @@ def test_correctness(batch_size=8, vocab_size=128256, num_trials=10000): avg_similarity = sum(similarities) / len(similarities) print(f"\n Average cosine similarity: {avg_similarity:.4f}") - # Check similarity assert avg_similarity > 0.95, f"Implementations diverge too much: {avg_similarity:.4f} < 0.95" print(" Correctness test passed!") @@ -121,13 +89,11 @@ def test_correctness(batch_size=8, vocab_size=128256, num_trials=10000): def main(): """Run comprehensive tests for top_k_sampling_from_probs.""" - print("Testing Top-K Sampling from Probabilities") + print("Testing Top-K Sampling from Probabilities (from definition)") all_passed = True - # Test correctness by comparing with FlashInfer try: - # Test with different configurations test_configs = [(2, 128256, 10000), (4, 129280, 10000), (8, 151936, 10000)] for 
batch_size, vocab_size, num_trials in test_configs: @@ -138,7 +104,6 @@ def main(): print(f"Correctness test failed: {e}") all_passed = False - # Summary print(f"\n{'=' * 60}") if all_passed: print("All tests passed!") diff --git a/flashinfer_trace/tests/references/test_top_k_top_p_sampling_from_probs.py b/flashinfer_trace/tests/references/test_top_k_top_p_sampling_from_probs.py index ab8af3ce..b2f5f516 100644 --- a/flashinfer_trace/tests/references/test_top_k_top_p_sampling_from_probs.py +++ b/flashinfer_trace/tests/references/test_top_k_top_p_sampling_from_probs.py @@ -1,63 +1,23 @@ -import flashinfer -import torch - - -@torch.no_grad() -def run(probs, top_k, top_p): - batch_size, vocab_size = probs.shape - device = probs.device - - # Check constants - # assert vocab_size == 128256 - - probs = probs.to(torch.float32) - samples = torch.empty(batch_size, dtype=torch.int64, device=device) - - for i in range(batch_size): - row = probs[i] - k = int(top_k[i].item()) - p = float(top_p[i].item()) +""" +Test top_k_top_p_sampling_from_probs reference implementation against FlashInfer. - # Apply top-k filtering - if 0 < k < vocab_size: - idx_sorted = torch.argsort(row, descending=True) - keep_idx_k = idx_sorted[:k] - filtered_k = torch.zeros_like(row) - filtered_k[keep_idx_k] = row[keep_idx_k] - row = filtered_k / filtered_k.sum() +This test validates that the reference implementation from the definition +matches the FlashInfer kernel implementation in terms of distribution. +""" - # Then apply top-p filtering - if p <= 0.0: - samples[i] = torch.argmax(row).to(torch.int64) - continue - - if p < 1.0: - vals, idx = torch.sort(row, descending=True) - cdf = torch.cumsum(vals, dim=0) - - to_remove = cdf > p - if vocab_size > 1: - to_remove[1:] = to_remove[:-1].clone() - to_remove[0] = False - - keep_idx_p = idx[~to_remove] - filtered_p = torch.zeros_like(row) - filtered_p[keep_idx_p] = row[keep_idx_p] - row = filtered_p / filtered_p.sum() - - # sample - samples[i] = torch.multinomial(row, 1, replacement=True).squeeze(0) +import flashinfer +import torch +from test_utils import get_reference_run - return samples +# Load reference implementation from definition (use v128256 as default) +run = get_reference_run("top_k_top_p_sampling_from_probs_v128256") def generate_random_inputs(batch_size, vocab_size=128256, distribution="normal", device="cuda"): """Generate random test inputs.""" - # Generate probabilities if distribution == "normal": logits = torch.randn(batch_size, vocab_size, device=device) elif distribution == "peaked": - # Create peaked distribution logits = torch.randn(batch_size, vocab_size, device=device) * 0.1 peak_indices = torch.randint(0, vocab_size, (batch_size,), device=device) for i in range(batch_size): @@ -67,14 +27,11 @@ def generate_random_inputs(batch_size, vocab_size=128256, distribution="normal", else: raise ValueError(f"Unknown distribution: {distribution}") - # Convert to probabilities probs = torch.softmax(logits, dim=-1).to(torch.float32) - - # Generate varying top_k and top_p values top_k = torch.randint( 10, min(500, vocab_size // 2), (batch_size,), dtype=torch.int32, device=device ) - top_p = torch.rand(batch_size, device=device) * 0.8 + 0.1 # Range [0.1, 0.9] + top_p = torch.rand(batch_size, dtype=torch.float32, device=device) * 0.5 + 0.5 # 0.5-1.0 return probs, top_k, top_p @@ -89,10 +46,8 @@ def test_correctness(batch_size=8, vocab_size=128256, num_trials=10000): device = "cuda" torch.manual_seed(42) - # Generate inputs probs, top_k, top_p = 
generate_random_inputs(batch_size, vocab_size, "peaked", device) - # Count frequencies for both implementations ref_counter = torch.zeros(batch_size, vocab_size, dtype=torch.int32, device=device) fi_counter = torch.zeros(batch_size, vocab_size, dtype=torch.int32, device=device) @@ -101,28 +56,21 @@ def test_correctness(batch_size=8, vocab_size=128256, num_trials=10000): if trial % progress_interval == 0: print(f" Trial {trial}/{num_trials}...") - # Reference implementation torch.manual_seed(42 + trial) ref_samples = run(probs, top_k, top_p) for i in range(batch_size): ref_counter[i, ref_samples[i]] += 1 - # FlashInfer implementation torch.manual_seed(42 + trial) - fi_samples = flashinfer.sampling.top_k_top_p_sampling_from_probs( - probs, top_k, top_p, filter_apply_order="top_k_first" - ) + fi_samples = flashinfer.sampling.top_k_top_p_sampling_from_probs(probs, top_k, top_p) for i in range(batch_size): fi_counter[i, fi_samples[i]] += 1 - # Calculate frequencies ref_freq = ref_counter.float() / num_trials fi_freq = fi_counter.float() / num_trials - # Calculate cosine similarity similarities = [] for i in range(batch_size): - # Only compare tokens that were sampled at least once mask = (ref_freq[i] > 0) | (fi_freq[i] > 0) if mask.sum() > 0: ref = ref_freq[i][mask] @@ -134,7 +82,6 @@ def test_correctness(batch_size=8, vocab_size=128256, num_trials=10000): avg_similarity = sum(similarities) / len(similarities) print(f"\n Average cosine similarity: {avg_similarity:.4f}") - # Check similarity assert avg_similarity > 0.95, f"Implementations diverge too much: {avg_similarity:.4f} < 0.95" print(" Correctness test passed!") @@ -143,13 +90,11 @@ def test_correctness(batch_size=8, vocab_size=128256, num_trials=10000): def main(): """Run comprehensive tests for top_k_top_p_sampling_from_probs.""" - print("Testing Combined Top-K Top-P Sampling from Probabilities") + print("Testing Top-K + Top-P Sampling from Probabilities (from definition)") all_passed = True - # Test correctness by comparing with FlashInfer try: - # Test with different configurations test_configs = [(2, 128256, 10000), (4, 129280, 10000), (8, 151936, 10000)] for batch_size, vocab_size, num_trials in test_configs: @@ -160,7 +105,6 @@ def main(): print(f"Correctness test failed: {e}") all_passed = False - # Summary print(f"\n{'=' * 60}") if all_passed: print("All tests passed!") diff --git a/flashinfer_trace/tests/references/test_top_p_sampling_from_probs.py b/flashinfer_trace/tests/references/test_top_p_sampling_from_probs.py index 3dbecb1d..e992217d 100644 --- a/flashinfer_trace/tests/references/test_top_p_sampling_from_probs.py +++ b/flashinfer_trace/tests/references/test_top_p_sampling_from_probs.py @@ -1,55 +1,23 @@ -import flashinfer -import torch - - -@torch.no_grad() -def run(probs, top_p): - batch_size, vocab_size = probs.shape - device = probs.device - - # Check constants - # assert vocab_size == 129280 - - probs = probs.to(torch.float32) - out = torch.empty(batch_size, dtype=torch.int64, device=device) - - for i in range(batch_size): - row = probs[i] - p = float(top_p[i].item()) - - if p <= 0.0: - # Degenerate to argmax - out[i] = torch.argmax(row).to(torch.int64) - continue - - if p < 1.0: - vals, idx = torch.sort(row, descending=True) - cdf = torch.cumsum(vals, dim=0) - - # Shift mask to keep the first token that crosses p - to_remove = cdf > p - to_remove[1:] = to_remove[:-1].clone() - to_remove[0] = False - keep = ~to_remove - keep_idx = idx[keep] +""" +Test top_p_sampling_from_probs reference implementation against 
FlashInfer. - # Build filtered distribution in original index space - filtered = torch.zeros_like(row) - filtered[keep_idx] = row[keep_idx] - row = filtered / filtered.sum() +This test validates that the reference implementation from the definition +matches the FlashInfer kernel implementation in terms of distribution. +""" - out[i] = torch.multinomial(row, 1, replacement=True).squeeze(0) +import flashinfer +import torch +from test_utils import get_reference_run - return out +# Load reference implementation from definition (use v128256 as default) +run = get_reference_run("top_p_sampling_from_probs_v128256") def generate_random_inputs(batch_size, vocab_size=128256, distribution="normal", device="cuda"): """Generate random test inputs.""" - # Generate probabilities if distribution == "normal": logits = torch.randn(batch_size, vocab_size, device=device) elif distribution == "peaked": - # Create peaked distribution logits = torch.randn(batch_size, vocab_size, device=device) * 0.1 peak_indices = torch.randint(0, vocab_size, (batch_size,), device=device) for i in range(batch_size): @@ -59,30 +27,24 @@ def generate_random_inputs(batch_size, vocab_size=128256, distribution="normal", else: raise ValueError(f"Unknown distribution: {distribution}") - # Convert to probabilities probs = torch.softmax(logits, dim=-1).to(torch.float32) - - # Generate varying top_p values - top_p = torch.rand(batch_size, device=device) * 0.8 + 0.1 # Range [0.1, 0.9] + top_p = torch.rand(batch_size, dtype=torch.float32, device=device) * 0.5 + 0.5 # 0.5-1.0 return probs, top_p -def test_correctness(batch_size=1, vocab_size=128256, num_trials=10000): - """Test correctness by comparing sampling frequency with expected renormalized probabilities. - Uses the same approach as FlashInfer's test_top_p_sampling_freq.""" +def test_correctness(batch_size=8, vocab_size=128256, num_trials=10000): + """Test correctness by comparing with FlashInfer implementation.""" print(f"\n{'=' * 60}") - print("Testing correctness against expected probabilities") - print(f"batch_size={batch_size}, vocab_size={vocab_size}, num_trials={num_trials}") + print("Testing correctness against FlashInfer") + print(f"batch_size={batch_size}, num_trials={num_trials}") print(f"{'=' * 60}") device = "cuda" torch.manual_seed(42) - # Generate inputs probs, top_p = generate_random_inputs(batch_size, vocab_size, "peaked", device) - # Count frequencies for both implementations ref_counter = torch.zeros(batch_size, vocab_size, dtype=torch.int32, device=device) fi_counter = torch.zeros(batch_size, vocab_size, dtype=torch.int32, device=device) @@ -91,26 +53,21 @@ def test_correctness(batch_size=1, vocab_size=128256, num_trials=10000): if trial % progress_interval == 0: print(f" Trial {trial}/{num_trials}...") - # Reference implementation torch.manual_seed(42 + trial) ref_samples = run(probs, top_p) for i in range(batch_size): ref_counter[i, ref_samples[i]] += 1 - # FlashInfer implementation torch.manual_seed(42 + trial) fi_samples = flashinfer.sampling.top_p_sampling_from_probs(probs, top_p) for i in range(batch_size): fi_counter[i, fi_samples[i]] += 1 - # Calculate frequencies ref_freq = ref_counter.float() / num_trials fi_freq = fi_counter.float() / num_trials - # Calculate cosine similarity similarities = [] for i in range(batch_size): - # Only compare tokens that were sampled at least once mask = (ref_freq[i] > 0) | (fi_freq[i] > 0) if mask.sum() > 0: ref = ref_freq[i][mask] @@ -122,7 +79,6 @@ def test_correctness(batch_size=1, vocab_size=128256, num_trials=10000): 
avg_similarity = sum(similarities) / len(similarities) print(f"\n Average cosine similarity: {avg_similarity:.4f}") - # Check similarity assert avg_similarity > 0.95, f"Implementations diverge too much: {avg_similarity:.4f} < 0.95" print(" Correctness test passed!") @@ -131,20 +87,12 @@ def test_correctness(batch_size=1, vocab_size=128256, num_trials=10000): def main(): """Run comprehensive tests for top_p_sampling_from_probs.""" - print("Testing Top-P (Nucleus) Sampling from Probabilities") + print("Testing Top-P Sampling from Probabilities (from definition)") all_passed = True - # Test correctness by comparing with FlashInfer try: - # Test with different configurations (matching FlashInfer's approach) - # Test different p values with batch_size=1 for efficiency - test_configs = [ - # (batch_size, vocab_size, num_trials) - (2, 128256, 10000), - (4, 129280, 10000), - (8, 151936, 10000), - ] + test_configs = [(2, 128256, 10000), (4, 129280, 10000), (8, 151936, 10000)] for batch_size, vocab_size, num_trials in test_configs: if not test_correctness(batch_size, vocab_size, num_trials): @@ -154,7 +102,6 @@ def main(): print(f"Correctness test failed: {e}") all_passed = False - # Summary print(f"\n{'=' * 60}") if all_passed: print("All tests passed!") diff --git a/flashinfer_trace/tests/references/test_utils.py b/flashinfer_trace/tests/references/test_utils.py new file mode 100644 index 00000000..39d105b8 --- /dev/null +++ b/flashinfer_trace/tests/references/test_utils.py @@ -0,0 +1,262 @@ +""" +Utility functions for reference implementation testing. + +Provides reusable tensor comparison, error reporting, and definition loading functions. +""" + +import math +from dataclasses import dataclass +from pathlib import Path +from typing import Callable, Optional + +import torch +import torch.nn.functional as F + +from flashinfer_bench.data import Definition, load_json_file + +# Path to definitions directory (relative to flashinfer_trace/) +DEFINITIONS_DIR = Path(__file__).parent.parent.parent / "definitions" + + +def load_definition(name: str) -> Definition: + """ + Load a definition by name from definitions directory. + + Searches through all op_type subdirectories to find the matching definition file. + + Args: + name: The definition name (e.g., "rmsnorm_h128", "gqa_paged_decode_h32_kv8_d128_ps1") + + Returns: + Definition object loaded from the JSON file + + Raises: + FileNotFoundError: If no definition with the given name is found + """ + for op_dir in DEFINITIONS_DIR.iterdir(): + if op_dir.is_dir(): + def_file = op_dir / f"{name}.json" + if def_file.exists(): + return load_json_file(Definition, def_file) + raise FileNotFoundError(f"Definition {name} not found in {DEFINITIONS_DIR}") + + +def compile_reference(reference_code: str) -> Callable: + """ + Compile reference implementation code to a callable function. + + The reference code is expected to define a `run()` function that takes + the input tensors and returns the output tensors. + + Args: + reference_code: Python source code containing the run() function definition + + Returns: + The compiled run() function + + Example: + >>> definition = load_definition("rmsnorm_h128") + >>> run = compile_reference(definition.reference) + >>> output = run(hidden_states, weight) + """ + namespace = {"torch": torch, "math": math, "F": F} + exec(reference_code, namespace) + return namespace["run"] + + +def get_reference_run(definition_name: str) -> Callable: + """ + Convenience function to load a definition and compile its reference implementation. 
+ + Args: + definition_name: The definition name (e.g., "rmsnorm_h128") + + Returns: + The compiled run() function from the definition's reference code + """ + definition = load_definition(definition_name) + return compile_reference(definition.reference) + + +@dataclass +class TensorComparisonMetrics: + """Metrics for comparing two tensors.""" + + max_abs_diff: float + max_rel_diff: float + mean_abs_diff: float + mean_rel_diff: float + cosine_similarity: float + mse: float + all_close: bool + + +def compare_tensors( + ref: torch.Tensor, + actual: torch.Tensor, + atol: float = 1e-2, + rtol: float = 5e-2, + eps: float = 1e-8, +) -> TensorComparisonMetrics: + """ + Compare two tensors and compute various error metrics. + + Args: + ref: Reference tensor + actual: Actual tensor to compare against reference + atol: Absolute tolerance for allclose check + rtol: Relative tolerance for allclose check + eps: Small epsilon value for numerical stability in relative difference + + Returns: + TensorComparisonMetrics object containing all comparison metrics + """ + # Convert to float32 for comparison + ref_f32 = ref.float() + actual_f32 = actual.float() + + # Compute absolute and relative differences + abs_diff = torch.abs(ref_f32 - actual_f32) + rel_diff = abs_diff / (torch.abs(actual_f32) + eps) + + # Compute error metrics + max_abs_diff = abs_diff.max().item() + max_rel_diff = rel_diff.max().item() + mean_abs_diff = abs_diff.mean().item() + mean_rel_diff = rel_diff.mean().item() + + # Compute cosine similarity + ref_flat = ref_f32.flatten() + actual_flat = actual_f32.flatten() + cosine_sim = F.cosine_similarity(ref_flat.unsqueeze(0), actual_flat.unsqueeze(0), dim=1).item() + + # Compute MSE + mse = ((ref_f32 - actual_f32) ** 2).mean().item() + + # Check if tensors are close + all_close = torch.allclose(ref_f32, actual_f32, atol=atol, rtol=rtol) + + return TensorComparisonMetrics( + max_abs_diff=max_abs_diff, + max_rel_diff=max_rel_diff, + mean_abs_diff=mean_abs_diff, + mean_rel_diff=mean_rel_diff, + cosine_similarity=cosine_sim, + mse=mse, + all_close=all_close, + ) + + +def print_comparison_metrics( + metrics: TensorComparisonMetrics, tensor_name: str = "Tensor", indent: str = "" +): + """ + Print tensor comparison metrics in a formatted way. + + Args: + metrics: TensorComparisonMetrics object + tensor_name: Name of the tensor for display + indent: String to prepend to each line for indentation + """ + print(f"{indent}{tensor_name} comparison:") + print(f"{indent} Max absolute difference: {metrics.max_abs_diff:.6e}") + print(f"{indent} Max relative difference: {metrics.max_rel_diff:.6e}") + print(f"{indent} Mean absolute difference: {metrics.mean_abs_diff:.6e}") + print(f"{indent} Mean relative difference: {metrics.mean_rel_diff:.6e}") + print(f"{indent} Cosine similarity: {metrics.cosine_similarity:.6f}") + print(f"{indent} MSE: {metrics.mse:.6e}") + + +def find_and_print_top_errors( + ref: torch.Tensor, + actual: torch.Tensor, + shape_names: Optional[tuple[str, ...]] = None, + top_k: int = 5, + tensor_name: str = "Tensor", +): + """ + Find and print top error locations for debugging. 
+ + Args: + ref: Reference tensor + actual: Actual tensor to compare against reference + shape_names: Names for each dimension (e.g., ("batch", "heads", "dim")) + top_k: Number of top errors to print + tensor_name: Name of the tensor for display + """ + ref_f32 = ref.float() + actual_f32 = actual.float() + + abs_diff = torch.abs(ref_f32 - actual_f32) + flat_abs_diff = abs_diff.flatten() + + k = min(top_k, flat_abs_diff.numel()) + if k == 0: + return + + top_errors, top_indices = torch.topk(flat_abs_diff, k) + + print(f"\nTop {k} {tensor_name} error locations:") + for i in range(k): + idx = top_indices[i].item() + + # Convert flat index to multi-dimensional indices + indices = [] + remaining = idx + for dim_size in reversed(ref.shape): + indices.append(remaining % dim_size) + remaining //= dim_size + indices = list(reversed(indices)) + + # Format indices with names if provided + if shape_names and len(shape_names) == len(indices): + index_str = ", ".join(f"{name}={val}" for name, val in zip(shape_names, indices)) + else: + index_str = ", ".join(str(i) for i in indices) + + ref_val = ref_f32.flatten()[idx].item() + actual_val = actual_f32.flatten()[idx].item() + + print( + f" [{index_str}]: " + f"ref={ref_val:.6f}, actual={actual_val:.6f}, diff={top_errors[i].item():.6e}" + ) + + +def compare_and_report( + ref: torch.Tensor, + actual: torch.Tensor, + tensor_name: str = "Output", + shape_names: Optional[tuple[str, ...]] = None, + atol: float = 1e-2, + rtol: float = 5e-2, + show_top_errors: bool = True, + top_k: int = 5, +) -> bool: + """ + Compare two tensors, print metrics, and optionally show top errors. + + This is a convenience function that combines compare_tensors, print_comparison_metrics, + and find_and_print_top_errors. + + Args: + ref: Reference tensor + actual: Actual tensor to compare against reference + tensor_name: Name of the tensor for display + shape_names: Names for each dimension (e.g., ("batch", "heads", "dim")) + atol: Absolute tolerance for allclose check + rtol: Relative tolerance for allclose check + show_top_errors: Whether to print top error locations if tensors don't match + top_k: Number of top errors to print + + Returns: + True if tensors match within tolerance, False otherwise + """ + metrics = compare_tensors(ref, actual, atol=atol, rtol=rtol) + print(f"\n{tensor_name} comparison:") + print_comparison_metrics(metrics, tensor_name="", indent=" ") + + if not metrics.all_close and show_top_errors: + find_and_print_top_errors(ref, actual, shape_names, top_k, tensor_name) + + return metrics.all_close
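
A minimal usage sketch of the new test_utils helpers, assuming the rmsnorm_h128 definition is available and that its compiled run() takes (hidden_states, weight) with eps=1e-6 baked in, as in the rewritten h128 test above:

    import flashinfer
    import torch
    from test_utils import compare_and_report, get_reference_run

    # Compile run() from the definition's reference code.
    run_rmsnorm = get_reference_run("rmsnorm_h128")

    # Random bfloat16 inputs, matching the shapes used by the h128 test.
    hidden_states = torch.randn(8, 128, dtype=torch.bfloat16, device="cuda")
    weight = torch.randn(128, dtype=torch.bfloat16, device="cuda")

    # Reference output vs. the FlashInfer kernel (eps assumed to match the definition).
    ref_out = run_rmsnorm(hidden_states.clone(), weight)
    fi_out = flashinfer.norm.rmsnorm(hidden_states.clone().contiguous(), weight.contiguous(), eps=1e-6)

    # Prints comparison metrics and, on mismatch, the top error locations; returns True on success.
    assert compare_and_report(ref_out, fi_out, tensor_name="Output", shape_names=("batch", "hidden"))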