[Bug]: GPU not utilized in SPU when running GPT-2: CUDA-enabled jaxlib missing + Transformers JAX deprecation warnings #1351

@danxh136

Issue Type

Support

Modules Involved

Documentation/Tutorial/Example

Have you reproduced the bug with SPU HEAD?

Yes

Have you searched existing issues?

Yes

SPU Version

0.9.4

OS Platform and Distribution

Linux Ubuntu 24.04.3 LTS

Python Version

3.10.19

Compiler Version

No response

Current Behavior?

I am running a GPT-2 model inside the SPU environment (OpenBumbleBee / SecInfer), and I consistently encounter the following warnings during execution:
`TensorFlow and JAX classes are deprecated and will be removed in Transformers v5.
We recommend migrating to PyTorch classes or pinning your version of Transformers.

WARNING: jax._src.xla_bridge: An NVIDIA GPU may be present on this machine,
but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
`

Problem Summary

  • JAX reports that an NVIDIA GPU may be present on this machine, but falls back to CPU execution.
  • The installed jaxlib (0.6.2) is a CPU-only build, and SPU does not appear to load a CUDA-enabled version.
  • HuggingFace Transformers warns that its TensorFlow and JAX classes are deprecated and will be removed in v5, and recommends migrating to PyTorch classes or pinning the Transformers version.
  • As a result, GPT-2 inference inside SPU runs significantly slower than expected.

Questions

  1. Does SPU officially support GPU-accelerated JAX (CUDA-enabled jaxlib)?
    If so, what CUDA / jaxlib combinations are supported and documented?

  2. What is the recommended way to install a CUDA-enabled jaxlib inside the SPU Python environment?
    The standard wheels are CUDA-version-specific and may not match SPU’s Python/ABI constraints.

  3. Is there an official environment specification (environment.yml or requirements.txt) for running JAX + GPU inside SPU?

Environment

OS: Ubuntu 24.04.3 LTS
Python: 3.10.19
transformers: 4.57.3
jax: 0.6.2
jaxlib: 0.6.2 (CPU-only)
SPU: 0.9.4
GPU: NVIDIA GeForce RTX 4090
CUDA: 13.0
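
The package versions listed above can be re-collected with a short script. This is a minimal sketch using only the standard library. The package list and the `jax-cuda` name prefix used to spot CUDA plugin wheels are assumptions made for illustration, not something SPU documents.

```python
from importlib import metadata

# Packages relevant to this report; the list is an assumption for illustration.
PACKAGES = ["jax", "jaxlib", "transformers", "spu"]


def collect_versions(packages):
    """Return {package: installed version string, or None if not installed}."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


def cuda_plugins():
    """List installed distributions whose name starts with 'jax-cuda'.

    The prefix is a heuristic (assumption): an empty list suggests that only
    a CPU build of jaxlib is present in this environment.
    """
    found = set()
    for dist in metadata.distributions():
        name = (dist.metadata.get("Name") or "").lower().replace("_", "-")
        if name.startswith("jax-cuda"):
            found.add(name)
    return sorted(found)


if __name__ == "__main__":
    for pkg, version in collect_versions(PACKAGES).items():
        print(f"{pkg}: {version or 'not installed'}")
    print("JAX CUDA plugin packages:", cuda_plugins() or "none found")
```

Running it inside the same Python environment that launches SPU shows whether any CUDA plugin package is installed alongside jaxlib.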

Thank you for your assistance!

Standalone code to reproduce the issue

print("A bug")
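
The placeholder above does not exercise JAX. The following is a minimal sketch that should trigger the same `xla_bridge` fallback warning on an affected machine. It involves neither SPU nor GPT-2, and the helper name `jax_backend_report` is hypothetical. It relies only on the public `jax.default_backend()` and `jax.devices()` calls.

```python
def jax_backend_report():
    """Describe which backend JAX selected, or note that JAX cannot be imported."""
    try:
        import jax
    except ImportError:
        return {"installed": False}
    return {
        "installed": True,
        # Backend name chosen by JAX, e.g. "cpu" or "gpu".
        "backend": jax.default_backend(),
        # String form of every device visible to the default backend.
        "devices": [str(d) for d in jax.devices()],
    }


report = jax_backend_report()
print(report)
```

On the environment described in this issue, the reported backend would be expected to be the CPU one, matching the warning in the log output.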

Relevant log output

TensorFlow and JAX classes are deprecated and will be removed in Transformers v5.
We recommend migrating to PyTorch classes or pinning your version of Transformers.

WARNING: jax._src.xla_bridge: An NVIDIA GPU may be present on this machine, 
but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
