
Conversation

@Ubospica
Collaborator

@Ubospica Ubospica commented Jan 18, 2026

This PR provides a framework for dataset validation, focusing on definition schema and reference correctness.

This PR retires the tests in flashinfer_trace/tests/references. Later submissions should add tests under flashinfer_trace/tests/definitions instead.

cc @yyihuang: this should further simplify your definition submission workflow. Converting the old tests with GenAI should not be very hard.

Signed-off-by: Ubospica [email protected]

Summary by CodeRabbit

  • New Features

    • Added a comprehensive definition-driven testing framework and many new test suites (GQA paged/ragged, MLA paged, MoE FP8, NSA sparse, RMSNorm, sampling).
  • Tests

    • Introduced CUDA-aware test gating and a reusable requires_torch_cuda decorator; converted many tests to use it (a usage sketch follows this summary).
    • Added numerous pytest modules and reference tests.
    • Integrated a Ruff pre-commit hook for automated linting.
  • Documentation

    • Added a "Testing" guide to the README with pytest instructions.
  • Chores

    • Updated linting and pytest config patterns.
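
As a rough illustration of the requires_torch_cuda migration noted in the Tests section above, a converted test might look like the following sketch. The import path matches the flashinfer_bench.testing exports described later in this page; the test body itself is hypothetical.

import torch

from flashinfer_bench.testing import requires_torch_cuda


@requires_torch_cuda  # skips this test when PyTorch reports no CUDA device
def test_add_on_gpu():
    # Hypothetical CUDA-dependent check, just to show the decorator in place.
    x = torch.ones(4, device="cuda")
    assert torch.equal(x + x, torch.full((4,), 2.0, device="cuda"))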


@vercel
Contributor

vercel bot commented Jan 18, 2026

@Ubospica is attempting to deploy a commit to Zihao Ye's projects Team on Vercel.

A member of the Team first needs to authorize it.

@coderabbitai

coderabbitai bot commented Jan 18, 2026

Note

Other AI code review bot(s) detected

CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

📝 Walkthrough


Adds a definition-driven testing framework, comparators, a CUDA gating decorator, many new definition-based tests and input generators, reference test cleanups, and supporting config updates across test infra and linting configs.

Changes

  • Testing Framework Core (flashinfer_bench/testing/__init__.py, flashinfer_bench/testing/comparators.py, flashinfer_bench/testing/definition.py, flashinfer_bench/testing/pytest_config.py):
    New testing core: CompareResult; Comparator families (TensorComparator, MultiOutputComparator, HitRatioComparator); DefinitionRunner and DefinitionTest for running/parametrizing definition-based tests; requires_torch_cuda pytest gating decorator.
  • Utilities (flashinfer_bench/utils.py):
    Renamed is_cuda_available() → is_torch_cuda_available() and made import-safe (handles missing torch).
  • flashinfer_trace tests, definitions (flashinfer_trace/tests/definitions/*; many files: test_gqa_*.py, test_mla_paged.py, test_moe_fp8.py, test_nsa_sparse.py, test_rmsnorm.py, test_sampling.py, ...):
    Added numerous definition-based test modules, input generators, and baseline_fn implementations for GQA (paged/ragged), MLA, MoE FP8, NSA sparse, RMSNorm, and sampling tests; per-test configs and tolerances included.
  • flashinfer_trace tests, references & infra (flashinfer_trace/tests/references/*, flashinfer_trace/tests/conftest.py, flashinfer_trace/tests/definitions/__init__.py):
    Small cleanup in reference tests (removed unnecessary f-strings, unused imports/vars); added conftest to set FIB_DATASET_PATH and exclude references from collection.
  • Test decorator migration (tests/bench/test_benchmark.py, tests/compile/test_torch_builder.py, tests/compile/test_triton_builder.py, tests/compile/test_tvm_ffi_builder.py):
    Replaced @pytest.mark.requires_torch_cuda markers with the new @requires_torch_cuda decorator import.
  • Test collection simplification (tests/conftest.py):
    Removed previous CUDA collection-skipping helpers and related hooks.
  • Linting & config (.pre-commit-config.yaml, pyproject.toml, docs/api/conf.py):
    Added Ruff pre-commit hook entry; expanded Ruff include patterns to cover flashinfer_trace and docs; small docs config string simplification.
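
To make the Testing Framework Core entry above concrete, here is a minimal definition-based test sketch modeled on the patterns that appear later in this review (class-level definition_path, configs, atol, rtol; a static input_generator; a baseline_fn). The base-class contract is not spelled out on this page, so treat the attribute names, the dict returned by input_generator, and the device handling as assumptions rather than the framework's exact API.

import torch
import flashinfer

from flashinfer_bench.testing import DefinitionTest


class TestRMSNormH128(DefinitionTest):
    """Sketch: validate the rmsnorm_h128 definition against a FlashInfer baseline."""

    definition_path = "definitions/rmsnorm/rmsnorm_h128.json"
    configs = [{"batch_size": 1}, {"batch_size": 8}]
    atol = 8e-3
    rtol = 1e-2

    @staticmethod
    def input_generator(**config):
        # Random bf16 inputs sized to the definition (assumed 2-D activations, 1-D weight).
        hidden_states = torch.randn(config["batch_size"], 128, dtype=torch.bfloat16, device="cuda")
        weight = torch.randn(128, dtype=torch.bfloat16, device="cuda")
        return {"hidden_states": hidden_states, "weight": weight}

    def baseline_fn(self, hidden_states, weight):
        # Compare the definition's reference output against FlashInfer's rmsnorm kernel.
        return flashinfer.norm.rmsnorm(hidden_states.contiguous(), weight.contiguous(), eps=1e-6)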

Sequence Diagram(s)

sequenceDiagram
    participant Pytest as DefinitionTest (pytest)
    participant Runner as DefinitionRunner
    participant Ref as ReferenceCode
    participant BL as BaselineFn
    participant Comp as Comparator
    participant Result as CompareResult

    Pytest->>Runner: run(config)
    Runner->>Runner: _extract_reference()
    Runner->>Runner: _generate_inputs(config)
    Runner->>Ref: invoke run(inputs)
    Ref-->>Runner: ref_output
    Runner->>BL: invoke baseline_fn(inputs)
    BL-->>Runner: baseline_output
    Runner->>Comp: compare(ref_output, baseline_output)
    Comp->>Comp: compute stats, diffs, hit-ratio
    alt within tolerance
        Comp-->>Result: CompareResult(passed=True, stats)
    else exceeds tolerance
        Comp->>Comp: gather failure details
        Comp-->>Result: CompareResult(passed=False, stats, details)
    end
    Result-->>Runner: CompareResult
    Runner-->>Pytest: reflect pass/fail
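
For the compare step in the diagram, invoking a comparator directly might look like the rough sketch below. The keyword arguments and the passed/details fields on CompareResult are inferred from snippets elsewhere in this PR, so treat them as assumptions.

import torch

from flashinfer_bench.testing import TensorComparator

ref = torch.randn(4, 128)
baseline = ref + 1e-4 * torch.randn_like(ref)

# compare() is assumed to return a CompareResult mirroring the flow shown above.
result = TensorComparator(atol=1e-3, rtol=1e-3).compare(ref, baseline)
assert result.passed, result.details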

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes


Suggested reviewers

  • Ubospica
  • YiyanZhai
  • xslingcn

Poem

🐰
I hopped through tests both wide and deep,
Built comparators, guards to keep,
Definitions, runners, inputs spawned—
Baselines checked from dusk to dawn,
Now CI hums while carrots sleep! 🥕

🚥 Pre-merge checks | ✅ 2 passed | ❌ 1 failed

❌ Failed checks (1 warning)
  • Docstring Coverage: ⚠️ Warning. Docstring coverage is 78.18%, which is insufficient; the required threshold is 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.

✅ Passed checks (2 passed)
  • Description Check: ✅ Passed. Check skipped - CodeRabbit’s high-level summary is enabled.
  • Title Check: ✅ Passed. The title 'feat: Dataset Validation Framework' is specific and clearly describes the main feature introduced in this PR - a comprehensive dataset validation framework for definition schema and reference correctness.






@gemini-code-assist
Contributor

Summary of Changes

Hello @Ubospica, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the testing infrastructure by introducing a robust dataset validation framework. It provides standardized tools and methodologies for verifying the correctness of FlashInfer operations against their reference implementations, ensuring data integrity and reliable performance benchmarks. The changes streamline the testing process, making it easier to add new tests and maintain existing ones, while also improving code quality through updated linting rules.

Highlights

  • Dataset Validation Framework: Introduced a new framework for dataset validation, focusing on schema definition and correctness of reference implementations. This framework provides a structured way to test various FlashInfer operations.
  • Testing Utilities: Added new testing utilities including Comparator classes (TensorComparator, MultiOutputComparator, HitRatioComparator) and a DefinitionTest base class for pytest-compatible, definition-driven testing.
  • Test Migration and Expansion: Migrated existing tests from flashinfer_trace/tests/references to the new flashinfer_trace/tests/definitions structure. New comprehensive test cases were added for GQA paged decode/prefill, MLA paged attention, MoE FP8, RMSNorm, and sampling operations.
  • Tooling and Configuration Updates: Integrated ruff as a pre-commit hook for unused import detection (F401) and updated pyproject.toml to reflect the new testing structure and ignore old reference tests.



Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces an excellent and comprehensive validation framework for dataset definitions. The DefinitionTest base class and associated comparators provide a powerful and extensible way to ensure schema and reference correctness. The addition of new tests for various kernels using this framework is a great step towards improving the reliability of the definitions.

I've found a few issues: a critical configuration error in the pre-commit setup, a potential bug in the MultiOutputComparator, and some overly lenient assertions in the new sampling tests. My detailed comments provide suggestions to address these points. Overall, this is a very strong contribution that significantly enhances the project's testing capabilities.

args: ["--profile=black"]

- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.9.3
Contributor


critical

The specified rev: v0.9.3 for ruff-pre-commit is not a valid version tag in the astral-sh/ruff-pre-commit repository. This will cause the pre-commit hook installation to fail. Please use a valid version, for instance, the latest version v0.5.0.

    rev: v0.5.0

Comment on lines 142 to 144
self.comparators = comparators or {
    name: TensorComparator(atol, rtol) for name in output_names
}
Contributor


high

The current initialization logic for self.comparators can lead to unexpected behavior. If an empty dictionary is passed for comparators, self.comparators will also be an empty dictionary. This causes the compare method to fall back to TensorComparator() with its own default tolerances, ignoring the atol and rtol passed to MultiOutputComparator. A more robust approach is to set up default comparators for all outputs first, and then override them with any custom comparators provided.

Suggested change
self.comparators = comparators or {
    name: TensorComparator(atol, rtol) for name in output_names
}
self.comparators = {name: TensorComparator(atol, rtol) for name in output_names}
if comparators:
    self.comparators.update(comparators)

if name not in ref_dict or name not in base_dict:
    continue

comparator = self.comparators.get(name, TensorComparator())
Contributor


medium

With the suggested change to __init__, self.comparators is now guaranteed to have a key for every name in self.output_names. Therefore, you can directly access the comparator with self.comparators[name] and remove the fallback. This simplifies the code and avoids the potential bug of using a TensorComparator with default tolerances.

Suggested change
comparator = self.comparators.get(name, TensorComparator())
comparator = self.comparators[name]
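
As a usage illustration of the override-first initialization suggested above, a caller might configure per-output tolerances roughly like this; the constructor signature is assumed from the snippets in this thread, not taken from the source.

from flashinfer_bench.testing import MultiOutputComparator, TensorComparator

# Assumed signature: shared atol/rtol for all outputs, plus optional per-output overrides.
comparator = MultiOutputComparator(
    output_names=["output", "lse"],
    atol=1e-2,
    rtol=1e-2,
    comparators={"lse": TensorComparator(1e-3, 1e-3)},  # tighter check only for lse
)
# With the suggested __init__, "output" keeps the shared TensorComparator(1e-2, 1e-2)
# default while "lse" uses the override.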

Comment on lines +111 to +112
overlap = torch.isin(sampled_tokens, top_k_indices).float().mean()
assert overlap > 0.95, f"Too many samples outside top-k set: {overlap:.2%}"
Contributor


medium

The assertion assert overlap > 0.95 is too lenient for top-k sampling. The top_k_sampling_from_probs function should only sample from the set of top-k tokens. Therefore, all sampled tokens must be within the top_k_indices set, and the overlap should be exactly 1.0. A looser assertion might hide subtle bugs in the sampling implementation.

Suggested change
overlap = torch.isin(sampled_tokens, top_k_indices).float().mean()
assert overlap > 0.95, f"Too many samples outside top-k set: {overlap:.2%}"
overlap = torch.isin(sampled_tokens, top_k_indices).float().mean()
assert overlap == 1.0, f"All sampled tokens must be within the top-k set, but overlap was {overlap:.2%}"

Comment on lines +152 to +153
overlap = torch.isin(sampled_tokens, top_k_indices).float().mean()
assert overlap > 0.95, f"Too many samples outside top-k set: {overlap:.2%}"
Contributor


medium

The assertion assert overlap > 0.95 is too lenient. When top_k_first is used, the sampling is restricted to the top-k tokens. All sampled tokens must be within this set, so the overlap should be exactly 1.0. A looser assertion could mask correctness issues in the sampling logic.

Suggested change
overlap = torch.isin(sampled_tokens, top_k_indices).float().mean()
assert overlap > 0.95, f"Too many samples outside top-k set: {overlap:.2%}"
overlap = torch.isin(sampled_tokens, top_k_indices).float().mean()
assert overlap == 1.0, f"All sampled tokens must be within the top-k set, but overlap was {overlap:.2%}"


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 4

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
flashinfer_bench/data/definition.py (1)

135-156: Critical: Missing None check causes TypeError when reference is omitted.

The reference field is now optional (defaults to None), but _validate_reference_code unconditionally calls ast.parse(self.reference) at line 146. When reference is None, this raises TypeError: expected str, bytes or os.PathLike object, not NoneType.

🐛 Proposed fix
 @model_validator(mode="after")
 def _validate_reference_code(self) -> Definition:
     """Validate that reference contains valid Python code with a 'run' function.

     Raises
     ------
     ValueError
         If the reference code is not valid Python syntax or doesn't contain
         a top-level 'run' function.
     """
+    if self.reference is None:
+        return self
+
     try:
         mod = ast.parse(self.reference, mode="exec")
     except SyntaxError as e:
         raise ValueError(f"Reference must be valid Python code: {e}") from e

     # Check for 'run' function
     has_run_func = any(
         isinstance(node, ast.FunctionDef) and node.name == "run" for node in mod.body
     )
     if not has_run_func:
         raise ValueError("Reference must define a top-level function named 'run'")
     return self
🤖 Fix all issues with AI agents
In `@flashinfer_bench/data/definition.py`:
- Around line 123-126: The _validate_reference_code() validator currently calls
ast.parse(self.reference) without handling None; update _validate_reference_code
to return early (or accept None) when self.reference is None and only parse when
non-None, and then update all downstream callers that assume a non-None
reference (kernel_generator_prompts.py usage in the f-string, registry.py where
SourceFile is created, and evaluators build_reference() calls in sampling.py and
default.py) to check for None and handle the no-reference case (e.g., skip
embedding/passing or raise a clear validation error), or alternatively revert
the dataclass field reference: Optional[...] back to a required non-Optional
type so _validate_reference_code, build_reference, kernel_generator_prompts, and
registry usage can safely assume presence.

In `@flashinfer_bench/testing/comparators.py`:
- Around line 163-172: The tuple-to-dict conversion silently truncates if the
tuple lengths differ from self.output_names; update the conversions where
ref_output and baseline_output are mapped (the blocks using
dict(zip(self.output_names, ref_output)) and dict(zip(self.output_names,
baseline_output))) to call zip(self.output_names, ..., strict=True) so a
ValueError is raised on length mismatch, ensuring mismatched output counts are
caught early.

In `@flashinfer_bench/testing/definition.py`:
- Around line 108-124: The current tensor-generation branch uses
torch.randint(..., dtype=torch.bool) which fails because torch.randint does not
accept torch.bool; in the branch handling tensor generation for boolean dtypes
(where dtype == torch.bool) change the creation to use torch.randint(0, 2,
shape, device=self.device) and then cast the result to boolean (e.g.,
.to(torch.bool) or .bool()) before assigning to inputs[name], ensuring shape,
device and the inputs dict usage remain unchanged.

In `@flashinfer_trace/tests/definitions/test_moe_fp8.py`:
- Around line 34-40: The constant WORKLOAD_JSONL_PATH (and related TRACE_ROOT)
is defined but never used; either remove WORKLOAD_JSONL_PATH (and TRACE_ROOT if
no other use) to eliminate dead code, or if it’s being kept for future use, mark
it explicitly with a TODO comment or rename to _WORKLOAD_JSONL_PATH to signal
intentional unused status so linters won’t complain; update the top of
test_moe_fp8.py accordingly and ensure no other references rely on the constant
before removing it.
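
A small standalone sketch of the strict-zip behavior requested for comparators.py above (Python 3.10+; the names here are illustrative, not the file's actual variables):

output_names = ["output", "lse"]
ref_output = (1.0, 2.0)
baseline_output = (1.0,)  # one output missing

ref_dict = dict(zip(output_names, ref_output, strict=True))  # lengths match, fine
try:
    base_dict = dict(zip(output_names, baseline_output, strict=True))
except ValueError as err:
    # strict=True surfaces the mismatch instead of silently truncating the dict.
    print(f"caught output-count mismatch early: {err}")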
🧹 Nitpick comments (18)
.pre-commit-config.yaml (1)

51-55: Policy inconsistency: Ruff linting in pre-commit contradicts stated guidelines.

Lines 15-16 state that "pre-commit hooks should only be used for formatting, but not for linting." However, F401 (unused imports) is a linting rule. Consider either:

  1. Updating the comment to reflect that auto-fixable lint rules are acceptable, or
  2. Moving this check to CI only

Additionally, --select=F401 limits pre-commit to only unused imports, while pyproject.toml configures broader Ruff rules (C, E, F, W, TID252). This asymmetry could cause CI failures that pre-commit doesn't catch locally.

flashinfer_trace/tests/definitions/test_sampling.py (1)

46-51: Consider using scatter_add_ for more efficient counter updates.

The per-batch loop to update counters can be vectorized using scatter_add_, which is more idiomatic and potentially faster for GPU tensors.

♻️ Suggested optimization (applies to all three test classes)
         for trial in range(num_trials):
             torch.manual_seed(42 + trial)
             # FlashInfer implementation
             fi_samples = flashinfer.sampling.top_p_sampling_from_probs(probs, top_p)
-            for i in range(batch_size):
-                fi_counter[i, fi_samples[i]] += 1
+            fi_counter.scatter_add_(
+                1, fi_samples.unsqueeze(1), torch.ones_like(fi_samples.unsqueeze(1), dtype=torch.int32)
+            )

Note: You'd need to verify fi_samples shape matches expectations (should be (batch_size,)).

flashinfer_trace/tests/definitions/test_gqa_ragged_prefill.py (2)

27-29: Consider vectorizing the KV length generation loop.

The loop can be replaced with vectorized operations for better performance and cleaner code:

♻️ Suggested refactor
     # Generate KV lengths >= query lengths for causal attention
-    kv_lens = torch.zeros(batch_size, dtype=torch.int32)
-    for i in range(batch_size):
-        kv_lens[i] = torch.randint(q_lens[i].item(), max_kv_len + 1, (1,)).item()
+    kv_lens = torch.randint(1, max_kv_len + 1, (batch_size,), dtype=torch.int32)
+    # Ensure KV lengths >= query lengths for causal attention
+    kv_lens = torch.maximum(kv_lens, q_lens)

64-68: Static analysis: Annotate mutable class attribute with ClassVar.

The configs list is a mutable class attribute shared across all instances. While this works correctly for the test framework, adding a type annotation improves clarity.

♻️ Suggested fix
+from typing import ClassVar
+
 class TestGQARaggedPrefillH32KV4(DefinitionTest):
     """Test GQA ragged prefill with 32 QO heads and 4 KV heads."""

     definition_path = "definitions/gqa_ragged/gqa_ragged_prefill_causal_h32_kv4_d128.json"
-    configs = [
+    configs: ClassVar[list[dict]] = [
         {"batch_size": 1, "max_q_len": 8, "max_kv_len": 16},
         {"batch_size": 4, "max_q_len": 16, "max_kv_len": 32},
         {"batch_size": 8, "max_q_len": 32, "max_kv_len": 64},
     ]
flashinfer_trace/tests/definitions/test_moe_fp8.py (2)

122-123: Consider renaming ambiguous variable I.

The variable I can be visually confused with 1 or l in some fonts. Consider using a more descriptive name like inter or intermediate.

♻️ Suggested fix
-    T, H, I = seq_len, hidden_size, intermediate_size
+    T, H, inter = seq_len, hidden_size, intermediate_size
     E_global, E_local = num_experts_global, num_local_experts
     ...
-    w13_bf16 = torch.randn(E_local, 2 * I, H, dtype=torch.bfloat16, device=device)
-    w2_bf16 = torch.randn(E_local, H, I, dtype=torch.bfloat16, device=device)
+    w13_bf16 = torch.randn(E_local, 2 * inter, H, dtype=torch.bfloat16, device=device)
+    w2_bf16 = torch.randn(E_local, H, inter, dtype=torch.bfloat16, device=device)

101-106: Inconsistent handling of zero amax compared to _fp8_block_quant_1d.

Line 103 uses a Python if statement with a 0-dim tensor, while _fp8_block_quant_1d uses torch.where at line 72. While this works because PyTorch evaluates 0-dim tensors in boolean context, using torch.where would be more consistent and explicit.

♻️ Suggested fix for consistency
                 blk = w_f32[(*sel, rs, cs)]
                 amax = torch.amax(torch.abs(blk))
-                s = (amax / max_fp8) if amax > 0 else torch.tensor(1.0, device=w_bf16.device)
+                s = torch.where(amax > 0, amax / max_fp8, torch.ones_like(amax))
                 q = (blk / s).to(torch.float8_e4m3fn)
flashinfer_bench/testing/pytest_config.py (1)

7-10: Post-assignment docstring is non-idiomatic.

The string literal on line 10 after the assignment is not recognized as a docstring by documentation tools (only classes/functions/modules get docstrings). Consider using a comment instead.

♻️ Suggested fix
+# Marker to skip tests when PyTorch CUDA is not available.
 requires_torch_cuda = pytest.mark.skipif(
     not is_torch_cuda_available(), reason="CUDA not available from PyTorch"
 )
-"""Marker to skip tests when PyTorch CUDA is not available."""
flashinfer_trace/tests/conftest.py (1)

11-14: Silence unused-arg lint in pytest_configure.
Optional cleanup to avoid ARG001 while keeping pytest’s expected signature.

♻️ Suggested tweak
-def pytest_configure(config):
+def pytest_configure(_config):
     """Set FIB_DATASET_PATH to flashinfer_trace root for these tests."""
     trace_root = Path(__file__).parent.parent
     os.environ["FIB_DATASET_PATH"] = str(trace_root)
flashinfer_bench/testing/__init__.py (1)

13-21: Clean public API surface.

The exports are well-organized and cover the key testing utilities. Consider sorting __all__ alphabetically for consistency with isort conventions (flagged by RUF022), though this is a minor style preference.

♻️ Optional: Sort __all__ alphabetically
 __all__ = [
-    "DefinitionTest",
-    "CompareResult",
     "Comparator",
-    "TensorComparator",
-    "MultiOutputComparator",
+    "CompareResult",
+    "DefinitionTest",
     "HitRatioComparator",
+    "MultiOutputComparator",
+    "TensorComparator",
     "requires_torch_cuda",
 ]
flashinfer_trace/tests/definitions/test_gqa_paged_prefill.py (1)

99-128: Consider extracting shared baseline logic.

Both baseline_fn implementations are nearly identical, differing only in num_kv_heads (4 vs 8). This duplication could be reduced by extracting a shared helper.

♻️ Optional: Extract shared baseline helper
def _run_flashinfer_prefill(
    q, k_cache, v_cache, qo_indptr, kv_indptr, kv_indices, sm_scale,
    num_qo_heads: int, num_kv_heads: int
):
    """Shared FlashInfer prefill baseline implementation."""
    device = q.device
    batch_size = qo_indptr.shape[0] - 1

    workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device=device)
    last_page_len = torch.ones(batch_size, dtype=torch.int32, device=device)

    paged_kv_cache = torch.stack([k_cache, v_cache], dim=1)

    wrapper = flashinfer.prefill.BatchPrefillWithPagedKVCacheWrapper(
        workspace_buffer, kv_layout="NHD"
    )
    wrapper.plan(
        qo_indptr=qo_indptr,
        paged_kv_indptr=kv_indptr,
        paged_kv_indices=kv_indices,
        paged_kv_last_page_len=last_page_len,
        num_qo_heads=num_qo_heads,
        num_kv_heads=num_kv_heads,
        head_dim_qk=128,
        head_dim_vo=128,
        page_size=1,
        causal=True,
        sm_scale=float(sm_scale),
        q_data_type=torch.bfloat16,
        kv_data_type=torch.bfloat16,
    )

    return wrapper.run(q, paged_kv_cache, return_lse=True)

Then in each class:

def baseline_fn(self, q, k_cache, v_cache, qo_indptr, kv_indptr, kv_indices, sm_scale):
    return _run_flashinfer_prefill(
        q, k_cache, v_cache, qo_indptr, kv_indptr, kv_indices, sm_scale,
        num_qo_heads=32, num_kv_heads=4  # or 8 for the other class
    )

Also applies to: 153-182

flashinfer_trace/tests/definitions/test_rmsnorm.py (1)

24-102: Consider reducing test class boilerplate.

All four test classes share identical atol, rtol, and baseline_fn implementations, differing only in definition_path, configs, and hidden_size. This pattern could be simplified.

♻️ Optional: Use a factory or parametrized approach

One approach is to create a factory function:

def create_rmsnorm_test_class(hidden_size: int, configs: list):
    """Factory to create RMSNorm test classes."""
    class TestRMSNorm(DefinitionTest):
        definition_path = f"definitions/rmsnorm/rmsnorm_h{hidden_size}.json"
        atol = 8e-3
        rtol = 1e-2

        @staticmethod
        def input_generator(**config):
            return generate_rmsnorm_inputs(batch_size=config["batch_size"], hidden_size=hidden_size)

        def baseline_fn(self, hidden_states, weight):
            return flashinfer.norm.rmsnorm(
                hidden_states.contiguous(), weight.contiguous(), eps=1e-6
            )

    TestRMSNorm.configs = configs
    TestRMSNorm.__name__ = f"TestRMSNormH{hidden_size}"
    return TestRMSNorm

TestRMSNormH128 = create_rmsnorm_test_class(128, [{"batch_size": i} for i in [1, 4, 8, 16, 32]])
TestRMSNormH2048 = create_rmsnorm_test_class(2048, [{"batch_size": i} for i in [1, 4, 8]])
TestRMSNormH4096 = create_rmsnorm_test_class(4096, [{"batch_size": i} for i in [1, 4, 8]])
TestRMSNormH7168 = create_rmsnorm_test_class(7168, [{"batch_size": i} for i in [1, 4, 8]])

This is optional—the current explicit approach is more readable if you prefer clarity over conciseness.

flashinfer_trace/tests/definitions/test_gqa_paged_decode.py (1)

64-114: Consider extracting common baseline logic to reduce duplication.

The baseline_fn implementations in both test classes are nearly identical, differing only in num_kv_heads (4 vs 8). Consider extracting a shared helper to reduce maintenance burden.

♻️ Suggested helper extraction
def _run_gqa_decode_baseline(q, k_cache, v_cache, kv_indptr, kv_indices, sm_scale, num_kv_heads):
    """Shared baseline implementation for GQA paged decode."""
    device = q.device
    batch_size = q.shape[0]

    workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device=device)
    kv_last_page_len = torch.ones(batch_size, dtype=torch.int32, device=device)

    wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(workspace_buffer, kv_layout="NHD")
    wrapper.plan(
        indptr=kv_indptr,
        indices=kv_indices,
        last_page_len=kv_last_page_len,
        num_qo_heads=32,
        num_kv_heads=num_kv_heads,
        head_dim=128,
        page_size=1,
        pos_encoding_mode="NONE",
        q_data_type=torch.bfloat16,
        kv_data_type=torch.bfloat16,
        sm_scale=float(sm_scale),
    )

    return wrapper.run(q, (k_cache, v_cache), return_lse=True)

Then each class can call:

def baseline_fn(self, q, k_cache, v_cache, kv_indptr, kv_indices, sm_scale):
    return _run_gqa_decode_baseline(q, k_cache, v_cache, kv_indptr, kv_indices, sm_scale, num_kv_heads=4)
flashinfer_trace/tests/definitions/test_mla_paged.py (2)

73-75: Use vectorized operation instead of loop for ensuring kv_len >= q_len.

The loop modifies individual tensor elements, which is inefficient. Use torch.maximum for a vectorized approach.

♻️ Proposed fix
     # Ensure kv_len >= q_len for causal attention
-    for i in range(batch_size):
-        kv_lens[i] = max(kv_lens[i], q_lens[i])
+    kv_lens = torch.maximum(kv_lens, q_lens)

44-45: Consider using math.sqrt for consistency.

The file already has math usage in other parts of the codebase (e.g., definition.py). Using math.sqrt instead of np.sqrt would allow removing the numpy import for this single usage.

♻️ Proposed fix
-import numpy as np
+import math
...
     # MLA scale
-    sm_scale = 1.0 / np.sqrt(128 + head_dim_kpe)
+    sm_scale = 1.0 / math.sqrt(128 + head_dim_kpe)

Apply the same change at line 99.

flashinfer_trace/tests/definitions/test_nsa_sparse.py (2)

182-188: Prefix unused variable with underscore.

fi_max_logits is unpacked but never used. Prefix with underscore to indicate intentional discard.

♻️ Proposed fix
-        fi_output_full, fi_max_logits, fi_lse_full = flash_mla_sparse_fwd(
+        fi_output_full, _fi_max_logits, fi_lse_full = flash_mla_sparse_fwd(

Apply the same fix at line 250.


154-197: Consider extracting shared baseline logic for decode and prefill.

The baseline_fn implementations in TestNSASparseDecode and TestNSASparsePrefill are nearly identical, differing only in how batch dimension is determined (q_nope.shape[0] meaning either batch_size or total_num_tokens). A shared helper could reduce duplication.

flashinfer_bench/testing/comparators.py (1)

178-183: Consider logging or warning when expected outputs are missing.

When name not in ref_dict or name not in base_dict, the comparison silently skips that output. This could mask issues where an expected output is missing. Consider at least emitting a warning or including this in the details.

♻️ Suggested improvement
         for name in self.output_names:
             if name not in ref_dict or name not in base_dict:
+                all_details.append(f"[{name}] Skipped: missing from ref or baseline")
                 continue
flashinfer_bench/testing/definition.py (1)

73-80: Document the security context for exec() usage.

While exec() is necessary here to run definition reference code, consider adding a comment explaining the trust model (i.e., that definition files are trusted/controlled by the project).

📝 Suggested documentation
     def _extract_reference(self) -> Callable:
         """Extract the run() function from definition's reference field.
+
+        Note: exec() is used here on trusted definition code from the project's
+        dataset. Definition files should not contain untrusted user input.

         Raises
         ------
         ValueError
             If reference code does not contain a 'run' function.
         """
         code = self.definition["reference"]
-        namespace: Dict[str, Any] = {"torch": torch, "math": math}
-        exec(code, namespace)  # noqa: S102
+        # Provide limited namespace for reference code execution
+        namespace: Dict[str, Any] = {"torch": torch, "math": math}
+        exec(code, namespace)  # noqa: S102 - trusted definition code

Comment on lines +108 to +124
if shape is None:
    # Scalar value
    if dtype in (torch.float32, torch.float16, torch.bfloat16):
        inputs[name] = 1.0
    elif dtype == torch.bool:
        inputs[name] = True
    else:
        inputs[name] = 1
else:
    # Tensor
    if dtype == torch.bool:
        tensor = torch.randint(0, 2, shape, dtype=dtype, device=self.device)
    elif dtype in (torch.int64, torch.int32, torch.int16, torch.int8):
        tensor = torch.randint(0, 100, shape, dtype=dtype, device=self.device)
    else:
        tensor = torch.randn(shape, device=self.device).to(dtype)
    inputs[name] = tensor


⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

🌐 Web query:

torch.randint dtype torch.bool PyTorch 2.8 documentation

💡 Result:

Short answer: No — in PyTorch 2.8 torch.randint is for generating integer tensors; dtype should be an integer dtype (e.g. torch.int64, torch.int32, torch.uint8, etc.). torch.bool is not an appropriate dtype for torch.randint and is not supported for this integer-sampling API. [1][2]

Sources:

  • torch.randint documentation (PyTorch 2.8/2.9). [1]
  • Tensor dtypes / dtype table (shows torch.bool is a boolean dtype with different/op-limited semantics). [2]

Fix bool tensor generation: torch.randint does not support dtype=torch.bool.

torch.randint() only accepts integer dtypes. Use torch.randint(0, 2, shape, device=self.device).to(torch.bool) instead to generate boolean tensors.

🤖 Prompt for AI Agents
In `@flashinfer_bench/testing/definition.py` around lines 108 - 124, The current
tensor-generation branch uses torch.randint(..., dtype=torch.bool) which fails
because torch.randint does not accept torch.bool; in the branch handling tensor
generation for boolean dtypes (where dtype == torch.bool) change the creation to
use torch.randint(0, 2, shape, device=self.device) and then cast the result to
boolean (e.g., .to(torch.bool) or .bool()) before assigning to inputs[name],
ensuring shape, device and the inputs dict usage remain unchanged.
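
A quick standalone check of the fix described above: sample integers first, then cast, since torch.randint does not accept torch.bool (the cpu device here is only so the sketch runs anywhere).

import torch

shape = (4, 8)
device = "cpu"

# torch.randint(0, 2, shape, dtype=torch.bool) raises; generate ints and cast instead.
mask = torch.randint(0, 2, shape, device=device).to(torch.bool)
assert mask.dtype == torch.bool and mask.shape == shape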


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🤖 Fix all issues with AI agents
In `@flashinfer_trace/tests/definitions/test_moe_fp8.py`:
- Line 113: Rename the ambiguous single-letter variable I in the tuple
assignment "T, H, I = seq_len, hidden_size, intermediate_size" to a clearer name
(e.g., intermediate_size_var or intermediate) and update every usage of I within
the same test function (including the occurrences currently at lines noted near
126 and 127) to the new name; ensure the new identifier replaces I wherever it's
referenced (alongside T and H) so the test compiles and Ruff E741 is resolved
while preserving semantic meaning tied to intermediate_size.
🧹 Nitpick comments (1)
flashinfer_trace/tests/definitions/test_moe_fp8.py (1)

92-95: Avoid per‑block CPU↔GPU sync in _fp8_block_quant_2d.
if amax > 0 on a CUDA scalar forces sync and allocates a tensor per block. Use a device‑side torch.where/ones_like to keep it on‑device and reduce overhead.

♻️ Proposed refactor
-                s = (amax / max_fp8) if amax > 0 else torch.tensor(1.0, device=w_bf16.device)
+                s = torch.where(amax > 0, amax / max_fp8, torch.ones_like(amax))

device: str = "cuda",
):
"""Generate random inputs for MoE FP8 testing."""
T, H, I = seq_len, hidden_size, intermediate_size


⚠️ Potential issue | 🟡 Minor

🧩 Analysis chain

🏁 Script executed:

# First, let's find and examine the test file
fd -a "test_moe_fp8.py" --type f

Repository: flashinfer-ai/flashinfer-bench

Length of output: 143


🏁 Script executed:

# Read the file to see the context around line 113
cat -n flashinfer_trace/tests/definitions/test_moe_fp8.py

Repository: flashinfer-ai/flashinfer-bench

Length of output: 8297


🌐 Web query:

Ruff E741 ambiguous variable name rule

💡 Result:

Ruff’s E741 (ambiguous-variable-name) flags use of the single characters l, O, or I as variable names because they can be confused with 1 and 0; it’s derived from pycodestyle. Use more descriptive names (or L instead of l). (docs.astral.sh)



Rename I to avoid Ruff E741 (ambiguous variable name).

Ruff flags the single-letter variable I as ambiguous since it can be confused with 1. The variable appears on lines 113, 126, and 127 and should be renamed to a clearer name throughout the function.

🔧 Suggested rename
-    T, H, I = seq_len, hidden_size, intermediate_size
+    T, H, inter = seq_len, hidden_size, intermediate_size
@@
-    w13_bf16 = torch.randn(E_local, 2 * I, H, dtype=torch.bfloat16, device=device)
-    w2_bf16 = torch.randn(E_local, H, I, dtype=torch.bfloat16, device=device)
+    w13_bf16 = torch.randn(E_local, 2 * inter, H, dtype=torch.bfloat16, device=device)
+    w2_bf16 = torch.randn(E_local, H, inter, dtype=torch.bfloat16, device=device)
🧰 Tools
🪛 Ruff (0.14.13)

113-113: Ambiguous variable name: I

(E741)

🤖 Prompt for AI Agents
In `@flashinfer_trace/tests/definitions/test_moe_fp8.py` at line 113, Rename the
ambiguous single-letter variable I in the tuple assignment "T, H, I = seq_len,
hidden_size, intermediate_size" to a clearer name (e.g., intermediate_size_var
or intermediate) and update every usage of I within the same test function
(including the occurrences currently at lines noted near 126 and 127) to the new
name; ensure the new identifier replaces I wherever it's referenced (alongside T
and H) so the test compiles and Ruff E741 is resolved while preserving semantic
meaning tied to intermediate_size.

@yyihuang yyihuang self-requested a review January 18, 2026 02:08
Contributor

@yzh119 yzh119 left a comment


@Ubospica can you document the usage in some Claude skills under .claude/skills?


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🤖 Fix all issues with AI agents
In
`@flashinfer_trace/tests/references/test_moe_fp8_block_scale_ds_routing_topk8_ng8_kg4_e32_h7168_i2048.py`:
- Line 4: Remove the unused import statement "import numpy as np" from the test
file; locate the top-of-file import block (the line containing "import numpy as
np") and delete it so the file only imports required modules (PyTorch), keeping
dependencies minimal.

import json
from pathlib import Path

import numpy as np


⚠️ Potential issue | 🟡 Minor

Unused import: numpy is not used in this file.

The np alias is never referenced. All tensor operations use PyTorch. Remove this import to keep dependencies minimal.

-import numpy as np
📝 Committable suggestion


Suggested change
import numpy as np
🤖 Prompt for AI Agents
In
`@flashinfer_trace/tests/references/test_moe_fp8_block_scale_ds_routing_topk8_ng8_kg4_e32_h7168_i2048.py`
at line 4, Remove the unused import statement "import numpy as np" from the test
file; locate the top-of-file import block (the line containing "import numpy as
np") and delete it so the file only imports required modules (PyTorch), keeping
dependencies minimal.

