Description
Issue Type
Support
Modules Involved
Documentation/Tutorial/Example
Have you reproduced the bug with SPU HEAD?
Yes
Have you searched existing issues?
Yes
SPU Version
0.9.4
OS Platform and Distribution
Linux Ubuntu 24.04.3 LTS
Python Version
3.10.19
Compiler Version
No response
Current Behavior?
I am running a GPT-2 model inside the SPU environment (OpenBumbleBee / SecInfer), and I consistently encounter the following warnings during execution:
```
TensorFlow and JAX classes are deprecated and will be removed in Transformers v5.
We recommend migrating to PyTorch classes or pinning your version of Transformers.
WARNING: jax._src.xla_bridge: An NVIDIA GPU may be present on this machine,
but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
```
Problem Summary
- JAX detects an NVIDIA GPU on the system but falls back to CPU execution (the warning is emitted by `jax._src.xla_bridge`).
- The installed `jaxlib` (0.6.2) is a CPU-only build, and SPU does not appear to load a CUDA-enabled version.
- HuggingFace Transformers warns that its TensorFlow and JAX model classes are deprecated and will be removed in v5, and suggests migrating to PyTorch.
- As a result, GPT-2 inference inside SPU runs significantly slower than expected.
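One way to confirm what the warning implies is to ask plain JAX, in the same Python environment, which backend and devices it actually selected. This is a minimal sketch that only uses `jax.default_backend()` and `jax.devices()`; it reports what stock JAX sees and says nothing about where SPU's own runtime executes.

```python
def jax_backend_report():
    """Return the backend JAX selected and its devices, or None if JAX is not importable."""
    try:
        import jax
    except ImportError:
        return None
    return {
        "jax_version": jax.__version__,
        # "cpu" when JAX falls back to CPU; "gpu" when a working CUDA backend is loaded.
        "backend": jax.default_backend(),
        "devices": [str(d) for d in jax.devices()],
    }


if __name__ == "__main__":
    print(jax_backend_report())
```

With the CPU-only `jaxlib` described above, the expected result is a `"cpu"` backend even though the RTX 4090 is present.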
Questions
- Does SPU officially support GPU-accelerated JAX (a CUDA-enabled `jaxlib`)? If so, which CUDA / `jaxlib` combinations are supported and documented?
- What is the recommended way to install a CUDA-enabled `jaxlib` inside the SPU Python environment? The standard wheels are CUDA-version-specific and may not match SPU's Python/ABI constraints.
- Is there an official environment specification (`environment.yml` or `requirements.txt`) for running JAX with GPU support inside SPU?
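To make the environment details easy to compare against whatever combination ends up being documented, a small stdlib-only sketch can list which relevant distributions are installed. It assumes that recent JAX releases ship GPU support as separate `jax-cuda12-plugin` / `jax-cuda12-pjrt` distributions (installed via `pip install -U "jax[cuda12]"`); those names come from the JAX installation guide, not from SPU's docs, and should be checked for the JAX version in use.

```python
from importlib import metadata

# Distributions relevant to the JAX backend question. The jax-cuda12-* names
# are an assumption based on how recent JAX releases package CUDA support.
PACKAGES = ("jax", "jaxlib", "jax-cuda12-plugin", "jax-cuda12-pjrt", "spu", "transformers")


def installed_versions(names=PACKAGES):
    """Map each distribution name to its installed version, or None if absent."""
    versions = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


def has_cuda_plugin(versions):
    """True if any installed distribution name starts with 'jax-cuda'."""
    return any(v is not None for k, v in versions.items() if k.startswith("jax-cuda"))


if __name__ == "__main__":
    found = installed_versions()
    for name, version in found.items():
        print(f"{name:20} {version or 'not installed'}")
    print("CUDA plugin present:", has_cuda_plugin(found))
```

In the environment listed below, this would be expected to show `jax` and `jaxlib` at 0.6.2 and no `jax-cuda*` distribution.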
Environment
- OS: Ubuntu 24.04.3 LTS
- Python: 3.10.19
- transformers: 4.57.3
- jax: 0.6.2
- jaxlib: 0.6.2 (CPU-only)
- SPU: 0.9.4
- GPU: NVIDIA GeForce RTX 4090
- CUDA: 13.0
Thank you for your assistance!
Standalone code to reproduce the issue
```python
print("A bug")
```
Relevant log output
```
TensorFlow and JAX classes are deprecated and will be removed in Transformers v5.
We recommend migrating to PyTorch classes or pinning your version of Transformers.
WARNING: jax._src.xla_bridge: An NVIDIA GPU may be present on this machine,
but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
```