Refactor UNet, Residual UNet, MIRNet2D, and Utils from PyTorch to JAX:

- [x] [UNet](https://github.com/htem/raygun/blob/ce53ec95e2d90e52133d5f6ea2e30df7502e477c/raygun/jax/networks/UNet.py)
- [x] [Residual UNet](https://github.com/htem/raygun/blob/ce53ec95e2d90e52133d5f6ea2e30df7502e477c/raygun/jax/networks/ResidualUNet.py)
- [ ] [MIRNet2D](https://github.com/htem/raygun/blob/ce53ec95e2d90e52133d5f6ea2e30df7502e477c/raygun/jax/networks/MIRNet2D.py)
- [x] [Utils](https://github.com/htem/raygun/blob/ce53ec95e2d90e52133d5f6ea2e30df7502e477c/raygun/jax/networks/utils.py)