When I compute gradients and perform updates using the same values in Torch and JAX, I find that after multiple iterations, the inference results differ significantly. #23646
Labels: bug
Description
Please set the device to `cuda:0` at the very beginning of the script.
System info (python version, jaxlib version, accelerator, etc.)
Download the code: https://drive.google.com/file/d/1H8uPgPdslVpizmSsif6oK4ey2e-oum9x/view?usp=sharing
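For readers without the linked code, here is a minimal JAX sketch of the kind of loop being compared: compute a gradient, apply an update, and repeat for many iterations. It is not the reporter's code. The single linear layer, the mean-squared-error loss, plain SGD, and the names `loss_fn`, `sgd_step` and `train` are all hypothetical. It also pins `jax_default_matmul_precision` to `"highest"`. This is an assumption worth checking when comparing against PyTorch: on recent NVIDIA GPUs, JAX may use TF32 for float32 matmuls by default, whereas PyTorch disables TF32 for matmuls by default. Tiny per-step differences of that kind can compound over many updates.

```python
import jax
import jax.numpy as jnp

# Assumption: force full float32 matmul precision so a GPU run does not
# silently use TF32 (PyTorch disables TF32 for matmuls by default).
jax.config.update("jax_default_matmul_precision", "highest")


def loss_fn(w, x, y):
    # Hypothetical model: one linear layer with mean-squared-error loss.
    pred = x @ w
    return jnp.mean((pred - y) ** 2)


@jax.jit
def sgd_step(w, x, y, lr):
    # Gradient of the loss with respect to w, then one plain SGD update.
    g = jax.grad(loss_fn)(w, x, y)
    return w - lr * g


def train(w0, x, y, lr=0.1, steps=100):
    # Repeated updates: any per-step numerical difference accumulates here.
    w = w0
    for _ in range(steps):
        w = sgd_step(w, x, y, lr)
    return w


if __name__ == "__main__":
    key = jax.random.PRNGKey(0)
    kx, _ = jax.random.split(key)
    x = jax.random.normal(kx, (64, 8), dtype=jnp.float32)
    true_w = jnp.arange(8, dtype=jnp.float32) / 8.0
    y = x @ true_w
    w = train(jnp.zeros(8, dtype=jnp.float32), x, y)
    # Largest absolute error of the learned weights after training.
    print(float(jnp.max(jnp.abs(w - true_w))))
```

To compare with PyTorch, one option is to feed the same `x`, `y` and initial weights to an equivalent PyTorch loop and log the maximum absolute weight difference after each step. This shows whether the gap grows gradually, which suggests precision effects, or jumps at a particular step, which suggests a logic difference.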