Skip to content

Commit

Permalink
[TE/JAX] Update required JAX version for FFI custom calls with cudaGr…
Browse files Browse the repository at this point in the history
…aph (#1285)

Update jax version for ffi

Signed-off-by: Phuong Nguyen <[email protected]>
  • Loading branch information
phu0ngng authored Oct 25, 2024
1 parent 7b284fe commit 7cef756
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion transformer_engine/jax/cpp_extensions/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ def is_ffi_enabled():
"""
Helper function checking if XLA Custom Call with FFI is enabled
"""
is_supported = jax_version_meet_requirement("0.4.31")
is_supported = jax_version_meet_requirement("0.4.35")
# New APIs with FFI are enabled by default
is_enabled = int(os.getenv("NVTE_JAX_WITH_FFI", "1"))
assert is_enabled in (0, 1), "Invalid NVTE_JAX_WITH_FFI value"
Expand Down

0 comments on commit 7cef756

Please sign in to comment.