I'm stuck on the same problem.
Hello, I am wondering how JAX uses GPU memory.
Since `XLA_PYTHON_CLIENT_MEM_FRACTION` is set to 0.9 by default, about 33 GB of each GPU's VRAM is preallocated when the script starts on 2× A100-40G. When I load the LLaMA 30B checkpoint (61 GB in float16 precision), an OOM error occurs, even though there seems to be enough capacity to hold it (about 2 × 33 GB = 66 GB in total).
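The "enough capacity" intuition can be made concrete with a back-of-the-envelope calculation. The sketch below uses the figures from this post and assumes the weights are split perfectly evenly across both GPUs with nothing else allocated; the helper name `shard_headroom_gb` is hypothetical. Under those assumptions each device is left with only about 2.5 GB of slack.

```python
def shard_headroom_gb(checkpoint_gb: float, num_devices: int,
                      pool_gb_per_device: float) -> float:
    """Per-device headroom if the checkpoint is split evenly across devices.

    Assumes perfectly even sharding and ignores allocator fragmentation,
    temporary copies made while loading, and any other buffers on the device.
    """
    if num_devices <= 0:
        raise ValueError("num_devices must be positive")
    shard_gb = checkpoint_gb / num_devices
    return pool_gb_per_device - shard_gb


# Figures from the post: 61 GB float16 checkpoint, 2 GPUs, ~33 GB preallocated each.
headroom = shard_headroom_gb(61, 2, 33)
print(f"headroom per device: {headroom:.1f} GB")  # headroom per device: 2.5 GB
```

The arithmetic only shows how thin the per-device margin is; it does not say what consumed that margin during loading.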
So I changed `XLA_PYTHON_CLIENT_MEM_FRACTION` from the default to 0.99, and then I could load the checkpoint successfully. There are two hypotheses that could explain this phenomenon.
`XLA_PYTHON_CLIENT_MEM_FRACTION`.
I think the latter is true, but I'm asking the community for clarification.
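A minimal sketch of the workaround described above, plus a way to check how much memory each GPU's allocator is actually allowed to use, which could help tell the hypotheses apart. It assumes JAX reads `XLA_PYTHON_CLIENT_MEM_FRACTION` when it creates its GPU client, so the variable is set before `import jax`. The helper names are hypothetical; `XLA_PYTHON_CLIENT_MEM_FRACTION` and `Device.memory_stats()` come from JAX itself, but the exact keys in the stats dict may differ between JAX versions and backends.

```python
import os
import sys
import warnings
from typing import Mapping, Optional


def set_xla_mem_fraction(fraction: float) -> str:
    """Set the fraction of each GPU's memory that JAX preallocates.

    Assumption: JAX reads this variable when it creates its GPU client, so
    call this before `import jax` (or at least before the first computation).
    """
    if not 0.0 < fraction <= 1.0:
        raise ValueError(f"fraction must be in (0, 1], got {fraction}")
    if "jax" in sys.modules:
        # A late change may be silently ignored if the backend already exists.
        warnings.warn("jax is already imported; the new fraction may be ignored")
    value = f"{fraction:g}"
    os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = value
    return value


def describe_memory_stats(stats: Optional[Mapping[str, int]]) -> str:
    """Format the dict returned by a JAX device's memory_stats().

    On GPU the dict normally includes 'bytes_limit' (the pool the allocator
    may use) and 'peak_bytes_in_use'; some backends return None instead.
    """
    if not stats:
        return "no memory stats available"
    gib = 1024 ** 3
    limit = stats.get("bytes_limit", 0) / gib
    peak = stats.get("peak_bytes_in_use", 0) / gib
    return f"limit {limit:.1f} GiB, peak {peak:.1f} GiB"


set_xla_mem_fraction(0.99)

# On a GPU machine, after the call above:
#   import jax
#   for d in jax.devices():
#       print(d, describe_memory_stats(d.memory_stats()))
```

Launching the script as `XLA_PYTHON_CLIENT_MEM_FRACTION=.99 python your_script.py` (placeholder script name) sets the same variable from the shell instead.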
Thank you