|
| 1 | +## An MNIST example with single-program multiple-data (SPMD) data parallelism. |
| 2 | + |
| 3 | +The aim here is to illustrate how to use JAX's [`pmap`](https://jax.readthedocs.io/en/latest/_autosummary/jax.pmap.html) to express and execute |
| 4 | +[SPMD](https://jax.readthedocs.io/en/latest/glossary.html#term-SPMD) programs for data parallelism along a batch dimension, while also |
| 5 | +minimizing dependencies by avoiding the use of higher-level layers and |
| 6 | +optimizers libraries. |
| 7 | + |
| 8 | +Adapted from https://github.com/jax-ml/jax/blob/main/examples/spmd_mnist_classifier_fromscratch.py. |
| 9 | + |
| 10 | +```bash |
| 11 | +$ kubectl apply -f examples/jax/jax-dist-spmd-mnist/jaxjob_dist_spmd_mnist_gloo.yaml |
| 12 | +``` |
| 13 | + |
| 14 | +--- |
| 15 | + |
| 16 | +```bash |
| 17 | +$ kubectl get pods -n kubeflow -l training.kubeflow.org/job-name=jaxjob-mnist |
| 18 | +``` |
| 19 | + |
| 20 | +``` |
| 21 | +NAME READY STATUS RESTARTS AGE |
| 22 | +jaxjob-mnist-worker-0 0/1 Completed 0 108m |
| 23 | +jaxjob-mnist-worker-1 0/1 Completed 0 108m |
| 24 | +``` |
| 25 | + |
| 26 | +--- |
| 27 | +```bash |
| 28 | +$ PODNAME=$(kubectl get pods -l training.kubeflow.org/job-name=jaxjob-simple,training.kubeflow.org/replica-type=worker,training.kubeflow.org/replica-index=0 -o name -n kubeflow) |
| 29 | +$ kubectl logs -f ${PODNAME} -n kubeflow |
| 30 | +``` |
| 31 | + |
| 32 | +``` |
| 33 | +downloaded https://storage.googleapis.com/cvdf-datasets/mnist/train-images-idx3-ubyte.gz to /tmp/jax_example_data/ |
| 34 | +downloaded https://storage.googleapis.com/cvdf-datasets/mnist/train-labels-idx1-ubyte.gz to /tmp/jax_example_data/ |
| 35 | +downloaded https://storage.googleapis.com/cvdf-datasets/mnist/t10k-images-idx3-ubyte.gz to /tmp/jax_example_data/ |
| 36 | +downloaded https://storage.googleapis.com/cvdf-datasets/mnist/t10k-labels-idx1-ubyte.gz to /tmp/jax_example_data/ |
| 37 | +JAX global devices:[CpuDevice(id=0), CpuDevice(id=1), CpuDevice(id=2), CpuDevice(id=3), CpuDevice(id=4), CpuDevice(id=5), CpuDevice(id=6), CpuDevice(id=7), CpuDevice(id=131072), CpuDevice(id=131073), CpuDevice(id=131074), CpuDevice(id=131075), CpuDevice(id=131076), CpuDevice(id=131077), CpuDevice(id=131078), CpuDevice(id=131079)] |
| 38 | +JAX local devices:[CpuDevice(id=0), CpuDevice(id=1), CpuDevice(id=2), CpuDevice(id=3), CpuDevice(id=4), CpuDevice(id=5), CpuDevice(id=6), CpuDevice(id=7)] |
| 39 | +JAX device count:16 |
| 40 | +JAX local device count:8 |
| 41 | +JAX process count:2 |
| 42 | +Epoch 0 in 1809.25 sec |
| 43 | +Training set accuracy 0.09871666878461838 |
| 44 | +Test set accuracy 0.09799999743700027 |
| 45 | +Epoch 1 in 0.51 sec |
| 46 | +Training set accuracy 0.09871666878461838 |
| 47 | +Test set accuracy 0.09799999743700027 |
| 48 | +Epoch 2 in 0.69 sec |
| 49 | +Training set accuracy 0.09871666878461838 |
| 50 | +Test set accuracy 0.09799999743700027 |
| 51 | +Epoch 3 in 0.81 sec |
| 52 | +Training set accuracy 0.09871666878461838 |
| 53 | +Test set accuracy 0.09799999743700027 |
| 54 | +Epoch 4 in 0.91 sec |
| 55 | +Training set accuracy 0.09871666878461838 |
| 56 | +Test set accuracy 0.09799999743700027 |
| 57 | +Epoch 5 in 0.97 sec |
| 58 | +Training set accuracy 0.09871666878461838 |
| 59 | +Test set accuracy 0.09799999743700027 |
| 60 | +Epoch 6 in 1.12 sec |
| 61 | +Training set accuracy 0.09035000205039978 |
| 62 | +Test set accuracy 0.08919999748468399 |
| 63 | +Epoch 7 in 1.11 sec |
| 64 | +Training set accuracy 0.09871666878461838 |
| 65 | +Test set accuracy 0.09799999743700027 |
| 66 | +Epoch 8 in 1.21 sec |
| 67 | +Training set accuracy 0.09871666878461838 |
| 68 | +Test set accuracy 0.09799999743700027 |
| 69 | +Epoch 9 in 1.29 sec |
| 70 | +Training set accuracy 0.09871666878461838 |
| 71 | +Test set accuracy 0.09799999743700027 |
| 72 | +
|
| 73 | +``` |
| 74 | + |
| 75 | +--- |
| 76 | + |
| 77 | +```bash |
| 78 | +$ kubectl get -o yaml jaxjobs jaxjob-mnist -n kubeflow |
| 79 | +``` |
| 80 | + |
| 81 | +``` |
| 82 | +apiVersion: kubeflow.org/v1 |
| 83 | +kind: JAXJob |
| 84 | +metadata: |
| 85 | + annotations: |
| 86 | + kubectl.kubernetes.io/last-applied-configuration: | |
| 87 | + {"apiVersion":"kubeflow.org/v1","kind":"JAXJob","metadata":{"annotations":{},"name":"jaxjob-mnist","namespace":"kubeflow"},"spec":{"jaxReplicaSpecs":{"Worker":{"replicas":2,"restartPolicy":"OnFailure","template":{"spec":{"containers":[{"image":"docker.io/sandipanify/jaxjob-spmd-mnist:latest","imagePullPolicy":"Always","name":"jax"}]}}}}}} |
| 88 | + creationTimestamp: "2024-12-18T16:47:28Z" |
| 89 | + generation: 1 |
| 90 | + name: jaxjob-mnist |
| 91 | + namespace: kubeflow |
| 92 | + resourceVersion: "3620" |
| 93 | + uid: 15f1db77-3326-405d-95e6-3d9a0d581611 |
| 94 | +spec: |
| 95 | + jaxReplicaSpecs: |
| 96 | + Worker: |
| 97 | + replicas: 2 |
| 98 | + restartPolicy: OnFailure |
| 99 | + template: |
| 100 | + spec: |
| 101 | + containers: |
| 102 | + - image: docker.io/sandipanify/jaxjob-spmd-mnist:latest |
| 103 | + imagePullPolicy: Always |
| 104 | + name: jax |
| 105 | +status: |
| 106 | + completionTime: "2024-12-18T17:22:11Z" |
| 107 | + conditions: |
| 108 | + - lastTransitionTime: "2024-12-18T16:47:28Z" |
| 109 | + lastUpdateTime: "2024-12-18T16:47:28Z" |
| 110 | + message: JAXJob jaxjob-mnist is created. |
| 111 | + reason: JAXJobCreated |
| 112 | + status: "True" |
| 113 | + type: Created |
| 114 | + - lastTransitionTime: "2024-12-18T16:50:57Z" |
| 115 | + lastUpdateTime: "2024-12-18T16:50:57Z" |
| 116 | + message: JAXJob kubeflow/jaxjob-mnist is running. |
| 117 | + reason: JAXJobRunning |
| 118 | + status: "False" |
| 119 | + type: Running |
| 120 | + - lastTransitionTime: "2024-12-18T17:22:11Z" |
| 121 | + lastUpdateTime: "2024-12-18T17:22:11Z" |
| 122 | + message: JAXJob kubeflow/jaxjob-mnist successfully completed. |
| 123 | + reason: JAXJobSucceeded |
| 124 | + status: "True" |
| 125 | + type: Succeeded |
| 126 | + replicaStatuses: |
| 127 | + Worker: |
| 128 | + selector: training.kubeflow.org/job-name=jaxjob-mnist,training.kubeflow.org/operator-name=jaxjob-controller,training.kubeflow.org/replica-type=worker |
| 129 | + succeeded: 2 |
| 130 | + startTime: "2024-12-18T16:47:28Z" |
| 131 | +
|
| 132 | +``` |
0 commit comments