
Commit 6d58ea9

Testing CI in JAX example (#2385)
* Add MNIST example with SPMD for JAX. Illustrate how to use JAX's `pmap` to express and execute single-program multiple-data (SPMD) programs for data parallelism along a batch dimension.
* Update CONTRIBUTING.md: use `--server-side` to install the latest local changes of the Training Operator control plane.
* Add JAXJob output
* Update JAXJob CI images
* Adjust jaxjob spmd example batch size
* Add JAX Example Docker Image Build in CI
* Fix script name typo
* Update script permissions
* Add KIND_CLUSTER env var
* Increase timeouts
* Test higher resources
* Increase timeout
* Remove resource reqs
* Test low batch size
* Test small batch size
* Hardcode number of batches

Signed-off-by: Sandipan Panda <[email protected]>
Signed-off-by: sailesh duddupudi <[email protected]>
Co-authored-by: Sandipan Panda <[email protected]>
1 parent 1dfa40c commit 6d58ea9

File tree: 13 files changed (+494 −7 lines)

`.github/workflows/integration-tests.yaml` (+14)

```diff
@@ -65,12 +65,26 @@ jobs:
           python-version: ${{ matrix.python-version }}
           gang-scheduler-name: ${{ matrix.gang-scheduler-name }}
 
+      - name: Build JAX Job Example Image
+        run: |
+          ./scripts/gha/build-jax-mnist-image.sh
+        env:
+          JAX_JOB_CI_IMAGE: kubeflow/jaxjob-dist-spmd-mnist:test
+
+      - name: Load JAX Job Example Image
+        run: |
+          kind load docker-image ${{ env.JAX_JOB_CI_IMAGE }} --name ${{ env.KIND_CLUSTER }}
+        env:
+          KIND_CLUSTER: training-operator-cluster
+          JAX_JOB_CI_IMAGE: kubeflow/jaxjob-dist-spmd-mnist:test
+
       - name: Run tests
         run: |
           pip install pytest
           python3 -m pip install -e sdk/python; pytest -s sdk/python/test/e2e --log-cli-level=debug --namespace=default
         env:
           GANG_SCHEDULER_NAME: ${{ matrix.gang-scheduler-name }}
+          JAX_JOB_IMAGE: kubeflow/jaxjob-dist-spmd-mnist:test
 
       - name: Collect volcano logs
         if: ${{ failure() && matrix.gang-scheduler-name == 'volcano' }}
```

`.github/workflows/publish-example-images.yaml` (+4)

```diff
@@ -74,3 +74,7 @@ jobs:
           platforms: linux/amd64
           dockerfile: examples/pytorch/deepspeed-demo/Dockerfile
           context: examples/pytorch/deepspeed-demo
+        - component-name: jaxjob-dist-spmd-mnist
+          platforms: linux/amd64,linux/arm64
+          dockerfile: examples/jax/jax-dist-spmd-mnist/Dockerfile
+          context: examples/jax/jax-dist-spmd-mnist/
```

`CONTRIBUTING.md` (+1 −1)

````diff
@@ -66,7 +66,7 @@ Note, that for the example job below, the PyTorchJob uses the `kubeflow` namespace
 
 From here we can apply the manifests to the cluster.
 ```sh
-kubectl apply -k "github.com/kubeflow/training-operator/manifests/overlays/standalone"
+kubectl apply --server-side -k "github.com/kubeflow/training-operator/manifests/overlays/standalone"
 ```
 
 Then we can patch it with the latest operator image.
````
`examples/jax/jax-dist-spmd-mnist/Dockerfile` (new file, +29)

```dockerfile
FROM python:3.13

RUN pip install --upgrade pip
RUN pip install --upgrade jax[k8s] absl-py

RUN apt-get update && apt-get install -y \
    build-essential \
    cmake \
    git \
    libgoogle-glog-dev \
    libgflags-dev \
    libprotobuf-dev \
    protobuf-compiler \
    && rm -rf /var/lib/apt/lists/*

RUN git clone https://github.com/facebookincubator/gloo.git \
    && cd gloo \
    && git checkout 43b7acbf372cdce14075f3526e39153b7e433b53 \
    && mkdir build \
    && cd build \
    && cmake ../ \
    && make \
    && make install

WORKDIR /app

ADD datasets.py spmd_mnist_classifier_fromscratch.py /app/

ENTRYPOINT ["python3", "spmd_mnist_classifier_fromscratch.py"]
```
`examples/jax/jax-dist-spmd-mnist/README.md` (new file, +132)

## An MNIST example with single-program multiple-data (SPMD) data parallelism

The aim here is to illustrate how to use JAX's [`pmap`](https://jax.readthedocs.io/en/latest/_autosummary/jax.pmap.html) to express and execute
[SPMD](https://jax.readthedocs.io/en/latest/glossary.html#term-SPMD) programs for data parallelism along a batch dimension, while also
minimizing dependencies by avoiding the use of higher-level layers and
optimizer libraries.

Adapted from https://github.com/jax-ml/jax/blob/main/examples/spmd_mnist_classifier_fromscratch.py.

```bash
$ kubectl apply -f examples/jax/jax-dist-spmd-mnist/jaxjob_dist_spmd_mnist_gloo.yaml
```

---

```bash
$ kubectl get pods -n kubeflow -l training.kubeflow.org/job-name=jaxjob-mnist
```

```
NAME                    READY   STATUS      RESTARTS   AGE
jaxjob-mnist-worker-0   0/1     Completed   0          108m
jaxjob-mnist-worker-1   0/1     Completed   0          108m
```

---

```bash
$ PODNAME=$(kubectl get pods -l training.kubeflow.org/job-name=jaxjob-mnist,training.kubeflow.org/replica-type=worker,training.kubeflow.org/replica-index=0 -o name -n kubeflow)
$ kubectl logs -f ${PODNAME} -n kubeflow
```

```
downloaded https://storage.googleapis.com/cvdf-datasets/mnist/train-images-idx3-ubyte.gz to /tmp/jax_example_data/
downloaded https://storage.googleapis.com/cvdf-datasets/mnist/train-labels-idx1-ubyte.gz to /tmp/jax_example_data/
downloaded https://storage.googleapis.com/cvdf-datasets/mnist/t10k-images-idx3-ubyte.gz to /tmp/jax_example_data/
downloaded https://storage.googleapis.com/cvdf-datasets/mnist/t10k-labels-idx1-ubyte.gz to /tmp/jax_example_data/
JAX global devices:[CpuDevice(id=0), CpuDevice(id=1), CpuDevice(id=2), CpuDevice(id=3), CpuDevice(id=4), CpuDevice(id=5), CpuDevice(id=6), CpuDevice(id=7), CpuDevice(id=131072), CpuDevice(id=131073), CpuDevice(id=131074), CpuDevice(id=131075), CpuDevice(id=131076), CpuDevice(id=131077), CpuDevice(id=131078), CpuDevice(id=131079)]
JAX local devices:[CpuDevice(id=0), CpuDevice(id=1), CpuDevice(id=2), CpuDevice(id=3), CpuDevice(id=4), CpuDevice(id=5), CpuDevice(id=6), CpuDevice(id=7)]
JAX device count:16
JAX local device count:8
JAX process count:2
Epoch 0 in 1809.25 sec
Training set accuracy 0.09871666878461838
Test set accuracy 0.09799999743700027
Epoch 1 in 0.51 sec
Training set accuracy 0.09871666878461838
Test set accuracy 0.09799999743700027
Epoch 2 in 0.69 sec
Training set accuracy 0.09871666878461838
Test set accuracy 0.09799999743700027
Epoch 3 in 0.81 sec
Training set accuracy 0.09871666878461838
Test set accuracy 0.09799999743700027
Epoch 4 in 0.91 sec
Training set accuracy 0.09871666878461838
Test set accuracy 0.09799999743700027
Epoch 5 in 0.97 sec
Training set accuracy 0.09871666878461838
Test set accuracy 0.09799999743700027
Epoch 6 in 1.12 sec
Training set accuracy 0.09035000205039978
Test set accuracy 0.08919999748468399
Epoch 7 in 1.11 sec
Training set accuracy 0.09871666878461838
Test set accuracy 0.09799999743700027
Epoch 8 in 1.21 sec
Training set accuracy 0.09871666878461838
Test set accuracy 0.09799999743700027
Epoch 9 in 1.29 sec
Training set accuracy 0.09871666878461838
Test set accuracy 0.09799999743700027
```

---

```bash
$ kubectl get -o yaml jaxjobs jaxjob-mnist -n kubeflow
```

```yaml
apiVersion: kubeflow.org/v1
kind: JAXJob
metadata:
  annotations:
    kubectl.kubernetes.io/last-applied-configuration: |
      {"apiVersion":"kubeflow.org/v1","kind":"JAXJob","metadata":{"annotations":{},"name":"jaxjob-mnist","namespace":"kubeflow"},"spec":{"jaxReplicaSpecs":{"Worker":{"replicas":2,"restartPolicy":"OnFailure","template":{"spec":{"containers":[{"image":"docker.io/sandipanify/jaxjob-spmd-mnist:latest","imagePullPolicy":"Always","name":"jax"}]}}}}}}
  creationTimestamp: "2024-12-18T16:47:28Z"
  generation: 1
  name: jaxjob-mnist
  namespace: kubeflow
  resourceVersion: "3620"
  uid: 15f1db77-3326-405d-95e6-3d9a0d581611
spec:
  jaxReplicaSpecs:
    Worker:
      replicas: 2
      restartPolicy: OnFailure
      template:
        spec:
          containers:
          - image: docker.io/sandipanify/jaxjob-spmd-mnist:latest
            imagePullPolicy: Always
            name: jax
status:
  completionTime: "2024-12-18T17:22:11Z"
  conditions:
  - lastTransitionTime: "2024-12-18T16:47:28Z"
    lastUpdateTime: "2024-12-18T16:47:28Z"
    message: JAXJob jaxjob-mnist is created.
    reason: JAXJobCreated
    status: "True"
    type: Created
  - lastTransitionTime: "2024-12-18T16:50:57Z"
    lastUpdateTime: "2024-12-18T16:50:57Z"
    message: JAXJob kubeflow/jaxjob-mnist is running.
    reason: JAXJobRunning
    status: "False"
    type: Running
  - lastTransitionTime: "2024-12-18T17:22:11Z"
    lastUpdateTime: "2024-12-18T17:22:11Z"
    message: JAXJob kubeflow/jaxjob-mnist successfully completed.
    reason: JAXJobSucceeded
    status: "True"
    type: Succeeded
  replicaStatuses:
    Worker:
      selector: training.kubeflow.org/job-name=jaxjob-mnist,training.kubeflow.org/operator-name=jaxjob-controller,training.kubeflow.org/replica-type=worker
      succeeded: 2
  startTime: "2024-12-18T16:47:28Z"
```
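To make the `pmap` pattern described above concrete, here is a minimal, self-contained sketch of SPMD data parallelism on simulated CPU devices. The XLA flag, function, and arrays are illustrative assumptions for this sketch, not code from `spmd_mnist_classifier_fromscratch.py`:

```python
# Minimal sketch of pmap-based SPMD data parallelism (illustrative only).
import os
# Pretend this single CPU host has 4 devices, so pmap has something to map over.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=4"

from functools import partial

import jax
import jax.numpy as jnp

n_devices = jax.local_device_count()  # 4, because of the XLA flag above


@partial(jax.pmap, axis_name="batch")
def sharded_mean_sq(x):
    # Each device computes a local statistic, then all-reduces it across the
    # "batch" axis -- the same collective pattern data-parallel training uses
    # to average per-device gradients.
    return jax.lax.pmean(jnp.mean(x**2), axis_name="batch")


# The leading axis of the input must equal the local device count.
batch = jnp.arange(8.0).reshape(n_devices, 2)
out = sharded_mean_sq(batch)  # one identical value per device
```

Because `pmean` averages over every participant in the named axis, all devices return the same value; this is how the example keeps its replicated model parameters in sync after each step.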
`examples/jax/jax-dist-spmd-mnist/datasets.py` (new file, +97)

```python
# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Datasets used in examples."""


import array
import gzip
import os
import struct
import urllib.request
from os import path

import numpy as np

_DATA = "/tmp/jax_example_data/"


def _download(url, filename):
    """Download a url to a file in the JAX data temp directory."""
    if not path.exists(_DATA):
        os.makedirs(_DATA)
    out_file = path.join(_DATA, filename)
    if not path.isfile(out_file):
        urllib.request.urlretrieve(url, out_file)
        print(f"downloaded {url} to {_DATA}")


def _partial_flatten(x):
    """Flatten all but the first dimension of an ndarray."""
    return np.reshape(x, (x.shape[0], -1))


def _one_hot(x, k, dtype=np.float32):
    """Create a one-hot encoding of x of size k."""
    return np.array(x[:, None] == np.arange(k), dtype)


def mnist_raw():
    """Download and parse the raw MNIST dataset."""
    # CVDF mirror of http://yann.lecun.com/exdb/mnist/
    base_url = "https://storage.googleapis.com/cvdf-datasets/mnist/"

    def parse_labels(filename):
        with gzip.open(filename, "rb") as fh:
            _ = struct.unpack(">II", fh.read(8))
            return np.array(array.array("B", fh.read()), dtype=np.uint8)

    def parse_images(filename):
        with gzip.open(filename, "rb") as fh:
            _, num_data, rows, cols = struct.unpack(">IIII", fh.read(16))
            return np.array(array.array("B", fh.read()), dtype=np.uint8).reshape(
                num_data, rows, cols
            )

    for filename in [
        "train-images-idx3-ubyte.gz",
        "train-labels-idx1-ubyte.gz",
        "t10k-images-idx3-ubyte.gz",
        "t10k-labels-idx1-ubyte.gz",
    ]:
        _download(base_url + filename, filename)

    train_images = parse_images(path.join(_DATA, "train-images-idx3-ubyte.gz"))
    train_labels = parse_labels(path.join(_DATA, "train-labels-idx1-ubyte.gz"))
    test_images = parse_images(path.join(_DATA, "t10k-images-idx3-ubyte.gz"))
    test_labels = parse_labels(path.join(_DATA, "t10k-labels-idx1-ubyte.gz"))

    return train_images, train_labels, test_images, test_labels


def mnist(permute_train=False):
    """Download, parse and process MNIST data to unit scale and one-hot labels."""
    train_images, train_labels, test_images, test_labels = mnist_raw()

    train_images = _partial_flatten(train_images) / np.float32(255.0)
    test_images = _partial_flatten(test_images) / np.float32(255.0)
    train_labels = _one_hot(train_labels, 10)
    test_labels = _one_hot(test_labels, 10)

    if permute_train:
        perm = np.random.RandomState(0).permutation(train_images.shape[0])
        train_images = train_images[perm]
        train_labels = train_labels[perm]

    return train_images, train_labels, test_images, test_labels
```
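As a quick worked example of what the two preprocessing helpers in `datasets.py` produce, the snippet below copies their definitions so it is self-contained (the toy shapes are illustrative, not taken from the example):

```python
import numpy as np


def _partial_flatten(x):
    """Flatten all but the first dimension of an ndarray."""
    return np.reshape(x, (x.shape[0], -1))


def _one_hot(x, k, dtype=np.float32):
    """Create a one-hot encoding of x of size k."""
    return np.array(x[:, None] == np.arange(k), dtype)


# Two fake 28x28 "images" flatten to shape (2, 784), matching the
# input layout the MNIST classifier expects.
images = np.zeros((2, 28, 28), dtype=np.uint8)
flat = _partial_flatten(images)

# Labels 3 and 7 become one-hot rows of length 10.
labels = _one_hot(np.array([3, 7]), 10)
```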
`examples/jax/jax-dist-spmd-mnist/jaxjob_dist_spmd_mnist_gloo.yaml` (new file, +16)

```yaml
apiVersion: "kubeflow.org/v1"
kind: JAXJob
metadata:
  name: jaxjob-mnist
  namespace: kubeflow
spec:
  jaxReplicaSpecs:
    Worker:
      replicas: 2
      restartPolicy: OnFailure
      template:
        spec:
          containers:
            - name: jax
              image: docker.io/kubeflow/jaxjob-dist-spmd-mnist:latest
              imagePullPolicy: Always
```
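With the two workers above, each process runs the same program and `pmap` maps the leading axis of each training batch onto that process's local devices. A NumPy sketch of the batch-sharding reshape this relies on (the device count and batch size here are illustrative assumptions, not values from the example):

```python
import numpy as np

num_devices = 8    # e.g. what jax.local_device_count() reported in the logs
batch_size = 128   # must be divisible by num_devices
features = 784     # flattened 28x28 MNIST images

batch = np.zeros((batch_size, features), dtype=np.float32)

# Reshape to (devices, per-device batch, features): pmap then maps the
# leading axis onto local devices, one shard per device.
sharded = batch.reshape(num_devices, batch_size // num_devices, features)
```

If `batch_size` is not divisible by the device count, the reshape fails, which is one reason the CI commits above adjust and eventually hardcode the number of batches.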
