@jberchtold-nvidia
Collaborator

Description

Prevents TE/JAX from accidentally changing FFI interfaces and breaking backwards compatibility with older HLO.

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Add new tests in test_custom_call.py and an associated HLO text file that do both of the following:
    • Load the HLO text file, compile it, and execute it with dummy values
    • Ensure all TE FFI registrations are exercised in the HLO files, so that any new FFI interface must come with a corresponding HLO test
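The coverage check in the second bullet can be sketched roughly as follows. This is a minimal illustration, not the code from the PR: the helper names `custom_call_targets` and `untested_targets`, the `te_example_*` target names, and the sample HLO text are all hypothetical, and the regex assumes custom calls appear either in the pretty form `stablehlo.custom_call @name(...)` or with a `call_target_name = "name"` attribute.

```python
import re

# Assumption: a custom call is printed either as `stablehlo.custom_call @name(`
# or carries a `call_target_name = "name"` attribute (generic form).
_CUSTOM_CALL_RE = re.compile(
    r'stablehlo\.custom_call\s+@([A-Za-z_][\w.]*)|call_target_name\s*=\s*"([^"]+)"'
)


def custom_call_targets(hlo_text):
    """Return the set of custom-call target names found in one HLO text."""
    targets = set()
    for pretty, generic in _CUSTOM_CALL_RE.findall(hlo_text):
        # Exactly one of the two groups is non-empty for each match.
        targets.add(pretty or generic)
    return targets


def untested_targets(registered, hlo_texts):
    """Return registered target names that no HLO text exercises."""
    covered = set()
    for text in hlo_texts:
        covered |= custom_call_targets(text)
    return set(registered) - covered


# Hypothetical sample; the target names are made up for illustration.
SAMPLE = """
func.func public @main(%arg0: tensor<4x8xbf16>) -> tensor<4x8xbf16> {
  %0 = stablehlo.custom_call @te_example_norm_ffi(%arg0) : (tensor<4x8xbf16>) -> tensor<4x8xbf16>
  %1 = "stablehlo.custom_call"(%0) {call_target_name = "te_example_gemm_ffi"} : (tensor<4x8xbf16>) -> tensor<4x8xbf16>
  return %1 : tensor<4x8xbf16>
}
"""

registered = {"te_example_norm_ffi", "te_example_gemm_ffi", "te_example_softmax_ffi"}
missing = untested_targets(registered, [SAMPLE])
```

A test built on this would assert that `missing` is empty, so registering a new FFI target without regenerating the HLO file fails the suite.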

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Jeremy Berchtold <[email protected]>
@jberchtold-nvidia
Collaborator Author

/te-ci jax

@greptile-apps
Contributor

greptile-apps bot commented Dec 16, 2025

Greptile Summary

Added FFI backwards compatibility tests for JAX to prevent accidental breaking changes to FFI interfaces. The PR introduces three test methods:

  • test_generate_hlo - generates StableHLO text files by running TE operations (manually triggered via env var)
  • test_ffi_compatibility - loads and executes saved HLO files to verify FFI interfaces still work
  • test_all_primitive_ffi_tested - ensures all registered FFI primitives are covered in test files

The generated HLO file (transformer_stablehlo.txt) contains serialized StableHLO covering 14 unique TE FFI operations including gemm, attention, norm, softmax, quantization, and grouped operations.

Confidence Score: 4/5

  • This PR is safe to merge with minor fixes needed
  • Score reflects two logic/syntax issues that should be fixed: the fixture has an unused parameter that will cause errors, and error handling for unsupported dtypes is missing. The approach is sound and adds valuable backwards compatibility testing.
  • Pay close attention to tests/jax/test_custom_call_compute.py - fix the fixture parameter and add dtype validation

Important Files Changed

Filename | Overview
tests/jax/test_custom_call_compute.py | Added comprehensive FFI compatibility tests with fixture issue in hlo_fixture parameter
tests/jax/ffi_hlo/transformer_stablehlo.txt | Generated StableHLO text file containing serialized HLO for FFI compatibility testing

Sequence Diagram

sequenceDiagram
    participant Dev as Developer
    participant Test as test_generate_hlo
    participant Model as TransformerLayer/Model
    participant JAX as JAX Runtime
    participant XLA as XLA Compiler
    participant HLO as HLO Text File
    
    Note over Dev,HLO: Generate Phase (NVTE_JAX_FFI_HLO_GENERATE=1)
    Dev->>Test: Run test_generate_hlo()
    Test->>Model: Initialize Model with TransformerLayer
    Test->>Model: Call model with dummy inputs
    Model->>JAX: Execute TE operations (attention, norm, gemm, etc.)
    JAX->>XLA: Compile to StableHLO with FFI calls
    XLA->>HLO: Write module.mlir to dump directory
    Dev->>HLO: Copy to ffi_hlo/transformer_stablehlo.txt
    
    Note over Dev,HLO: Compatibility Test Phase
    participant Compat as test_ffi_compatibility
    participant Parser as _make_args_based_on_input_tensor_shape_and_dtype
    participant Backend as CUDA Backend
    
    Compat->>HLO: Load HLO text file
    Compat->>Parser: Parse @main signature
    Parser->>Parser: Extract tensor shapes and dtypes via regex
    Parser-->>Compat: Return dummy JAX arrays
    Compat->>Backend: Compile HLO with current FFI bindings
    Backend->>Backend: Verify FFI interfaces match
    Backend->>Backend: Execute with dummy inputs
    Backend-->>Compat: Return results (or error if incompatible)
    
    Note over Dev,HLO: Coverage Verification Phase
    participant Coverage as test_all_primitive_ffi_tested
    participant FFIReg as FFI Registrations
    
    Coverage->>HLO: Scan all HLO files for custom_call
    Coverage->>Coverage: Extract FFI names via regex
    Coverage->>FFIReg: Get all registered FFI primitives
    Coverage->>Coverage: Compare sets and check coverage
    Coverage-->>Dev: Assert all primitives tested

Contributor

@greptile-apps greptile-apps bot left a comment

Additional Comments (2)

  1. tests/jax/test_custom_call_compute.py, line 1934 (link)

    syntax: fixture method should have self as first parameter, not shape

  2. tests/jax/test_custom_call_compute.py, line 2045 (link)

    logic: regex captures too much - it includes the return type. It should stop before ->.
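One way to keep the return type out of the match is to anchor the pattern on the closing parenthesis that precedes `->`. The following is a minimal sketch under stated assumptions, not the regex from the PR: `parse_main_inputs` and the sample signature are hypothetical, and the parser assumes static shapes and simple element types such as `bf16` or `f32`.

```python
import re

# Non-greedy match of the @main argument list, ending at the first ")" that is
# followed by "->", so result types after the arrow are never captured.
_MAIN_ARGS_RE = re.compile(r"@main\((.*?)\)\s*->", re.DOTALL)
_TENSOR_RE = re.compile(r"tensor<([^>]+)>")


def parse_main_inputs(hlo_text):
    """Return [(shape, dtype_str), ...] for the inputs of @main."""
    match = _MAIN_ARGS_RE.search(hlo_text)
    if match is None:
        raise ValueError("no @main signature with a return type found")
    inputs = []
    for spec in _TENSOR_RE.findall(match.group(1)):
        # "2x128x64xbf16" -> dims ["2", "128", "64"], dtype "bf16";
        # a scalar such as "f32" yields an empty shape.
        *dims, dtype = spec.split("x")
        inputs.append((tuple(int(d) for d in dims), dtype))
    return inputs


# Hypothetical signature for illustration only.
SIGNATURE = (
    'func.func public @main(%arg0: tensor<2x128x64xbf16>, '
    '%arg1: tensor<f32> {mhlo.layout_mode = "default"}) '
    '-> (tensor<2x128x64xbf16> {jax.result_info = ""}) {'
)

parsed = parse_main_inputs(SIGNATURE)
```

Here `parsed` holds only the two input tensors; the result tensor after `->` is excluded.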

2 files reviewed, 2 comments

Signed-off-by: Jeremy Berchtold <[email protected]>
@jberchtold-nvidia
Collaborator Author

/te-ci

Contributor

@greptile-apps greptile-apps bot left a comment

2 files reviewed, 4 comments

huanghua1994
huanghua1994 previously approved these changes Dec 16, 2025
Collaborator

@huanghua1994 huanghua1994 left a comment

LGTM

@jberchtold-nvidia
Collaborator Author

/te-ci jax

Contributor

@greptile-apps greptile-apps bot left a comment

Additional Comments (2)

  1. tests/jax/test_custom_call_compute.py, line 1935 (link)

    syntax: fixture should not have a parameter shape - it's not used in the function

  2. tests/jax/test_custom_call_compute.py, line 2076 (link)

    logic: if dtype_str is not in dtype_map, dtype will be None, causing jnp.ones on line 2078 to fail
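A possible shape for the missing validation is to fail early with a descriptive error instead of letting a `None` dtype reach the array constructor. This is only a sketch: `DTYPE_NAMES` and `resolve_dtype_name` are hypothetical, the table is deliberately incomplete, and the mapped names are meant to correspond to dtype attribute names in `jax.numpy`, which is an assumption to check against the installed version.

```python
# Hypothetical mapping from StableHLO element types to dtype names.
# Assumption: the real dtype_map in the test covers more types than this.
DTYPE_NAMES = {
    "f32": "float32",
    "f16": "float16",
    "bf16": "bfloat16",
    "i32": "int32",
    "i8": "int8",
    "ui8": "uint8",
}


def resolve_dtype_name(dtype_str):
    """Look up a dtype name, raising a clear error for unknown element types."""
    try:
        return DTYPE_NAMES[dtype_str]
    except KeyError:
        supported = ", ".join(sorted(DTYPE_NAMES))
        raise ValueError(
            f"unsupported dtype {dtype_str!r} in HLO signature; supported: {supported}"
        ) from None


name = resolve_dtype_name("bf16")
```

With this in place, an unknown element type in a regenerated HLO file surfaces as a `ValueError` naming the offending string, rather than as a later failure inside `jnp.ones`.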

2 files reviewed, 2 comments
