-
b2654c0?diff=split&w=0
-
Hi, thanks for the question! It looks like this change was purely on the JAX side and not associated with any XLA commits. Pinging @rmlarsen, who authored the change, for more info.
-
I found that NaN is checked at the end of the SVD calculation, but before that point, it might have returned early here:
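The snippet referenced in the post is not reproduced in this thread. As a hypothetical illustration only (not JAX's actual implementation), the "check at the end" pattern can be sketched as follows: the full SVD always runs, and a single NaN test on the input afterwards decides whether the outputs are replaced with NaN. The helper name `svd_nan_checked_at_end` and the test matrix are assumptions.

```python
import jax
import jax.numpy as jnp


@jax.jit
def svd_nan_checked_at_end(a):
    # Hypothetical helper, not JAX's implementation: the full SVD always runs.
    u, s, vh = jnp.linalg.svd(a, full_matrices=False)
    # A single check after the decomposition: did the input contain NaNs?
    has_nan = jnp.any(jnp.isnan(a))
    # Overwrite every output with NaN if so; otherwise pass the results through.
    u = jnp.where(has_nan, jnp.nan, u)
    s = jnp.where(has_nan, jnp.nan, s)
    vh = jnp.where(has_nan, jnp.nan, vh)
    return u, s, vh


a = jnp.array([[3.0, 0.0], [0.0, 4.0]])
u, s, vh = svd_nan_checked_at_end(a)
print(s)
# Same function on an input containing a NaN: all outputs become NaN.
print(svd_nan_checked_at_end(a.at[0, 0].set(jnp.nan))[1])
```

Because the NaN test only selects between two already-computed values, the cost of the SVD is paid whether or not the input contains NaNs.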
-
When using JIT compilation there is no such thing as "returning early", so we decided to get rid of the special casing to simplify the code and slightly speed up the common case.
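One way to see what "no returning early" means in practice, sketched under assumptions (the function names are hypothetical and this is not the removed JAX code): a Python `if` on a data-dependent condition works when called eagerly, but under `jax.jit` the condition is an abstract tracer and cannot be converted to a Python `bool`, so tracing fails with a `jax.errors.ConcretizationTypeError` (recent JAX versions raise the subclass `TracerBoolConversionError`). A jittable version has to be branch-free, so the SVD is always computed.

```python
import jax
import jax.numpy as jnp


def singular_values_early_return(a):
    # Hypothetical special case: bail out before the SVD if the input has NaNs.
    # This Python `if` needs a concrete boolean, which a traced array cannot give.
    if jnp.any(jnp.isnan(a)):
        return jnp.full((min(a.shape),), jnp.nan, dtype=a.dtype)
    return jnp.linalg.svd(a, compute_uv=False)


def singular_values_traced(a):
    # Branch-free version: the SVD always runs; the NaN check only selects
    # which value is returned.
    s = jnp.linalg.svd(a, compute_uv=False)
    return jnp.where(jnp.any(jnp.isnan(a)), jnp.nan, s)


a = jnp.eye(3)

# Eagerly (no jit) the early return works because the condition is concrete.
print(singular_values_early_return(a))

# Under jit the condition is a tracer, so the Python `if` raises during tracing.
try:
    jax.jit(singular_values_early_return)(a)
except jax.errors.ConcretizationTypeError as err:
    print("jit failed:", type(err).__name__)

print(jax.jit(singular_values_traced)(a))
```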
-
Hi, are there any specific benchmarks or cases that show the benefit of this optimization? I tried JAX v0.4.26 (before the optimization) and JAX v0.4.28 (after the optimization) with the public JAX unit tests, looping each case 100 times to measure performance, but saw no improvement after the optimization...
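For comparing versions, a minimal timing sketch may help make results reproducible. It is not the public unit tests mentioned above; the helper name `time_per_call`, the 128x128 random matrix, and the iteration count are assumptions. It excludes compilation with a warm-up call and blocks on each result, since JAX dispatches work asynchronously and an unblocked loop would mostly time dispatch.

```python
import time

import jax
import jax.numpy as jnp


def time_per_call(fn, *args, iters=100):
    """Average wall-clock seconds per call of a jitted function (hypothetical helper)."""
    # Warm-up call: triggers compilation so it is not counted in the timing.
    jax.block_until_ready(fn(*args))
    start = time.perf_counter()
    for _ in range(iters):
        # JAX dispatches asynchronously; block so each call's device work
        # is included in the measurement.
        jax.block_until_ready(fn(*args))
    return (time.perf_counter() - start) / iters


svd_jit = jax.jit(lambda m: jnp.linalg.svd(m, full_matrices=False))
x = jax.random.normal(jax.random.PRNGKey(0), (128, 128))
print(f"jax {jax.__version__}: {time_per_call(svd_jit, x) * 1e3:.3f} ms per call")
```

Running the same script in two environments (one per JAX version) on the same device gives a like-for-like per-call number, independent of test-harness overhead.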
-
Hi @rmlarsen, @jakevdp, I made the changes below and retested:
It seems that the SVD optimization still affects part of the GPU performance:
As you mentioned, JAX v0.4.28 includes other changes that affect performance, but this SVD optimization still adds a further ~5 s regression here.