Pinned
- distributed_kron (Public): An implementation of PSGD Kron in JAX for distributed training in JAX or Flax
Python 1
- kron_torch (Public): An implementation of PSGD Kron second-order optimizer for PyTorch
- image-classification-jax (Public): Image classification in JAX with ViT, resnet, cifar10, cifar100, imagenette, and imagenet
Python 2