Skip to content

Fix batched Cholesky OOM on ROCm by bypassing hipSolver's external hipMalloc#717

Open
FlemingH wants to merge 2 commits intoROCm:amd-mainfrom
FlemingH:amd-main
Open

Fix batched Cholesky OOM on ROCm by bypassing hipSolver's external hipMalloc#717
FlemingH wants to merge 2 commits intoROCm:amd-mainfrom
FlemingH:amd-main

Conversation

@FlemingH
Copy link

@FlemingH FlemingH commented Feb 26, 2026

jax.vmap(jnp.linalg.cholesky) with batch >= 2 crashes with OOM on ROCm.
Root cause: JAX calls hipsolverDnXpotrfBatched (dense API, no workspace parameter), which internally allocates workspace via hipMalloc. This bypasses XLA's BFC allocator, and since XLA preallocates ~75% of GPU VRAM by default, the external hipMalloc fails.
Fix: Switch to hipsolverXpotrfBatched (standard API), which accepts an external workspace buffer. The workspace is allocated through XLA's scratch allocator, keeping all GPU memory within XLA's control. CUDA path is unchanged.

return ffi::Error::Success();
}

#ifdef JAX_GPU_HIP

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it possible to just reuse the CUDA case with the appropriate defines? That's usually how it works since the APIs are almost identical.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants