
Document if differentiability of implementation is dependent on hardware #23650

Open
unalmis opened this issue Sep 16, 2024 · 5 comments
Labels
enhancement New feature or request

Comments

@unalmis

unalmis commented Sep 16, 2024

On my machine, eigh_tridiagonal is reverse differentiable on CPU. On another machine, it is not reverse differentiable on GPU, causing unexpected nan leaks. It would be nice if a function's documentation mentioned when the differentiability of the current implementation depends on hardware.

@unalmis unalmis added the enhancement New feature or request label Sep 16, 2024
@unalmis unalmis changed the title Document reverse-mode differentiability of eigh_tridiagonal Document when differentiability of implementation is dependent on hardware Sep 16, 2024
@unalmis unalmis changed the title Document when differentiability of implementation is dependent on hardware Document if differentiability of implementation is dependent on hardware Sep 16, 2024
@dfm
Member

dfm commented Sep 16, 2024

Can you provide a minimal, reproducible example that shows the problem, and define clearly what you mean by "not reverse differentiable"? In other words, do you see an error, or do different numerics on different platforms lead to this inconsistency? Thanks.

@unalmis
Author

unalmis commented Sep 16, 2024

A function that depends on the eigenvalues returned by eigh_tridiagonal gives a nan when jax.grad is used on GPU. I believe this is expected given the current implementation.

@dfm
Member

dfm commented Sep 16, 2024

As I suggested above, can you provide a function I can run that demonstrates this, and a pointer to where that belief comes from? I'm not very familiar with this specific function, but I'm happy to try and debug or suggest where the docs should be updated if you can give me some more info!

@hawkinsp
Member

The structure of the spectrum is probably important here. In general eigendecomposition is not differentiable if, e.g., you have repeated eigenvalues. And it's possible that it's very sensitive to small numerical differences, e.g., between hardware.
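The sensitivity to repeated eigenvalues can be seen directly with a dense eigh; this hypothetical function depends on the eigenvectors, whose derivative involves a division by eigenvalue gaps, so a fully degenerate matrix typically yields non-finite gradients:

```python
import jax
import jax.numpy as jnp

def eigvec_sum(a):
    # Depends on the eigenvectors; their derivative divides by
    # eigenvalue gaps, which are zero for a degenerate spectrum.
    _, v = jnp.linalg.eigh(a)
    return jnp.sum(v * jnp.arange(1.0, 5.0).reshape(2, 2))

# The identity matrix has a fully degenerate spectrum (both eigenvalues 1),
# so the gradient is typically nan or inf rather than finite.
g = jax.grad(eigvec_sum)(jnp.eye(2))
print(jnp.isfinite(g).all())
```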

@PhilipVinc
Contributor

Just a minor comment: I think eigendecomposition is in general differentiable (at least as long as the matrix is non-singular).

For the case of a degenerate spectrum there are some complications, but the derivative is still defined. The problem is that the formula is more complicated and the operation is more costly, so no AD framework defaults to it. For eigh I use this gist; for tridiagonal, you'd have to work your way through it:

https://gist.github.com/PhilipVinc/f92499294efb3787e123f89ace3c5b29
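For reference, the gap division that breaks down in the degenerate case appears in the standard first-order perturbation formula for a symmetric matrix. A NumPy sketch (not JAX's or the gist's exact implementation):

```python
import numpy as np

def eigh_jvp_sketch(a, da):
    # First-order perturbation of eigh for symmetric a:
    #   dw_i = (v^T da v)_ii
    #   dv   = v @ (F * (v^T da v)),  F[j, i] = 1 / (w_i - w_j), F_ii = 0.
    # A repeated eigenvalue makes a gap zero, so F blows up to inf --
    # this is the non-finite gradient seen in the degenerate case.
    w, v = np.linalg.eigh(a)
    vt_da_v = v.T @ da @ v
    gaps = w[None, :] - w[:, None]
    with np.errstate(divide="ignore"):
        f = np.where(np.eye(len(w), dtype=bool), 0.0, 1.0 / gaps)
    dw = np.diag(vt_da_v)
    dv = v @ (f * vt_da_v)
    return dw, dv
```

For a matrix with a simple spectrum, `dw` agrees with a finite-difference check; the degenerate-safe formulas in the gist replace the naive `1 / gaps` term.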
