Fix batched Cholesky OOM on ROCm by bypassing hipSolver's external hipMalloc#717
Open
FlemingH wants to merge 2 commits intoROCm:amd-mainfrom
Open
Fix batched Cholesky OOM on ROCm by bypassing hipSolver's external hipMalloc#717FlemingH wants to merge 2 commits intoROCm:amd-mainfrom
FlemingH wants to merge 2 commits intoROCm:amd-mainfrom
Conversation
…pMalloc Made-with: Cursor
hawkinsp
reviewed
Feb 27, 2026
| return ffi::Error::Success(); | ||
| } | ||
|
|
||
| #ifdef JAX_GPU_HIP |
There was a problem hiding this comment.
Is it possible to just reuse the CUDA case with the appropriate defines? That's usually how it works since the APIs are almost identical.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
jax.vmap(jnp.linalg.cholesky) with batch >= 2 crashes with OOM on ROCm.
Root cause: JAX calls hipsolverDnXpotrfBatched (dense API, no workspace parameter), which internally allocates workspace via hipMalloc. This bypasses XLA's BFC allocator, and since XLA preallocates ~75% of GPU VRAM by default, the external hipMalloc fails.
Fix: Switch to hipsolverXpotrfBatched (standard API), which accepts an external workspace buffer. The workspace is allocated through XLA's scratch allocator, keeping all GPU memory within XLA's control. CUDA path is unchanged.