-
Notifications
You must be signed in to change notification settings - Fork 583
[JAX] HLO FFI tests #2517
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
[JAX] HLO FFI tests #2517
Conversation
Signed-off-by: Jeremy Berchtold <[email protected]>
|
/te-ci jax |
Greptile SummaryAdded FFI backwards compatibility tests for JAX to prevent accidental breaking changes to FFI interfaces. The PR introduces three test methods:
The generated HLO file ( Confidence Score: 4/5
Important Files Changed
Sequence DiagramsequenceDiagram
participant Dev as Developer
participant Test as test_generate_hlo
participant Model as TransformerLayer/Model
participant JAX as JAX Runtime
participant XLA as XLA Compiler
participant HLO as HLO Text File
Note over Dev,HLO: Generate Phase (NVTE_JAX_FFI_HLO_GENERATE=1)
Dev->>Test: Run test_generate_hlo()
Test->>Model: Initialize Model with TransformerLayer
Test->>Model: Call model with dummy inputs
Model->>JAX: Execute TE operations (attention, norm, gemm, etc.)
JAX->>XLA: Compile to StableHLO with FFI calls
XLA->>HLO: Write module.mlir to dump directory
Dev->>HLO: Copy to ffi_hlo/transformer_stablehlo.txt
Note over Dev,HLO: Compatibility Test Phase
participant Compat as test_ffi_compatibility
participant Parser as _make_args_based_on_input_tensor_shape_and_dtype
participant Backend as CUDA Backend
Compat->>HLO: Load HLO text file
Compat->>Parser: Parse @main signature
Parser->>Parser: Extract tensor shapes and dtypes via regex
Parser-->>Compat: Return dummy JAX arrays
Compat->>Backend: Compile HLO with current FFI bindings
Backend->>Backend: Verify FFI interfaces match
Backend->>Backend: Execute with dummy inputs
Backend-->>Compat: Return results (or error if incompatible)
Note over Dev,HLO: Coverage Verification Phase
participant Coverage as test_all_primitive_ffi_tested
participant FFIReg as FFI Registrations
Coverage->>HLO: Scan all HLO files for custom_call
Coverage->>Coverage: Extract FFI names via regex
Coverage->>FFIReg: Get all registered FFI primitives
Coverage->>Coverage: Compare sets and check coverage
Coverage-->>Dev: Assert all primitives tested
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Additional Comments (2)
2 files reviewed, 2 comments
Signed-off-by: Jeremy Berchtold <[email protected]>
4fa4cd9 to
4551979
Compare
|
/te-ci |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
2 files reviewed, 4 comments
huanghua1994
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
|
/te-ci jax |
2 similar comments
|
/te-ci jax |
|
/te-ci jax |
Signed-off-by: Jeremy Berchtold <[email protected]>
|
/te-ci jax |
for more information, see https://pre-commit.ci
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Additional Comments (2)
2 files reviewed, 2 comments
Description
Prevents TE/JAX from changing FFI interfaces and breaking backwards compatibility with older HLO on accident.
Type of change
Changes
test_custom_call.pyand associated HLO text file to ensure both of the followingChecklist: