diff --git a/ODML/jax_to_tflite.ipynb b/ODML/jax_to_tflite.ipynb
new file mode 100644
index 0000000..ab5b3ca
--- /dev/null
+++ b/ODML/jax_to_tflite.ipynb
@@ -0,0 +1,475 @@
+{
+ "nbformat": 4,
+ "nbformat_minor": 0,
+ "metadata": {
+ "colab": {
+ "provenance": []
+ },
+ "kernelspec": {
+ "name": "python3",
+ "display_name": "Python 3"
+ },
+ "language_info": {
+ "name": "python"
+ }
+ },
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "8vD3L4qeREvg"
+ },
+ "source": [
+ "##### Copyright 2021 The TensorFlow Authors."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "qLCxmWRyRMZE"
+ },
+ "source": [
+ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
+ "# you may not use this file except in compliance with the License.\n",
+ "# You may obtain a copy of the License at\n",
+ "#\n",
+ "# https://www.apache.org/licenses/LICENSE-2.0\n",
+ "#\n",
+ "# Unless required by applicable law or agreed to in writing, software\n",
+ "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
+ "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
+ "# See the License for the specific language governing permissions and\n",
+ "# limitations under the License."
+ ],
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "4k5PoHrgJQOU"
+ },
+ "source": [
+ "# Jax Model Conversion For TFLite\n",
+ "## Overview\n",
+ "Note: This API is new and only available via pip install tf-nightly. It will be available in TensorFlow version 2.7. Also, the API is still experimental and subject to changes.\n",
+ "\n",
+ "This CodeLab demonstrates how to build a model for MNIST recognition using Jax, and how to convert it to TensorFlow Lite. This codelab will also demonstrate how to optimize the Jax-converted TFLite model with post-training quantiztion."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "i8cfOBcjSByO"
+ },
+ "source": [
+ "
"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "lq-T8XZMJ-zv"
+ },
+ "source": [
+ "## Prerequisites\n",
+ "It's recommended to try this feature with the newest TensorFlow nightly pip build."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "EV04hKdrnE4f"
+ },
+ "source": [
+ "!pip install tf-nightly --upgrade\n",
+ "!pip install jax --upgrade"
+ ],
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# Make sure your JAX version is at least 0.4.20 or above.\n",
+ "import jax\n",
+ "jax.__version__"
+ ],
+ "metadata": {
+ "id": "vsilblGuGQa2",
+ "outputId": "d2a08dc2-0c83-4792-ecc9-3178b91d43f6",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 35
+ }
+ },
+ "execution_count": null,
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ "'0.4.21'"
+ ],
+ "application/vnd.google.colaboratory.intrinsic+json": {
+ "type": "string"
+ }
+ },
+ "metadata": {},
+ "execution_count": 3
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "!pip install orbax-export --upgrade"
+ ],
+ "metadata": {
+ "id": "PJeQhMUwH0oX"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "from orbax.export import ExportManager\n",
+ "from orbax.export import JaxModule\n",
+ "from orbax.export import ServingConfig"
+ ],
+ "metadata": {
+ "id": "j9_CVA0THQNc"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "QAeY43k9KM55"
+ },
+ "source": [
+ "## Data Preparation\n",
+ "Download the MNIST data with Keras dataset and pre-process."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "qSOPSZJn1_Tj"
+ },
+ "source": [
+ "import numpy as np\n",
+ "import tensorflow as tf\n",
+ "import functools\n",
+ "\n",
+ "import time\n",
+ "import itertools\n",
+ "\n",
+ "import numpy.random as npr\n",
+ "\n",
+ "import jax.numpy as jnp\n",
+ "from jax import jit, grad, random\n",
+ "from jax.example_libraries import optimizers\n",
+ "from jax.example_libraries import stax\n"
+ ],
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "hdJIt3Da2Qn1"
+ },
+ "source": [
+ "def _one_hot(x, k, dtype=np.float32):\n",
+ " \"\"\"Create a one-hot encoding of x of size k.\"\"\"\n",
+ " return np.array(x[:, None] == np.arange(k), dtype)\n",
+ "\n",
+ "(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()\n",
+ "train_images, test_images = train_images / 255.0, test_images / 255.0\n",
+ "train_images = train_images.astype(np.float32)\n",
+ "test_images = test_images.astype(np.float32)\n",
+ "\n",
+ "train_labels = _one_hot(train_labels, 10)\n",
+ "test_labels = _one_hot(test_labels, 10)"
+ ],
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "0eFhx85YKlEY"
+ },
+ "source": [
+ "## Build the MNIST model with Jax"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "mi3TKB9nnQdK"
+ },
+ "source": [
+ "def loss(params, batch):\n",
+ " inputs, targets = batch\n",
+ " preds = predict(params, inputs)\n",
+ " return -jnp.mean(jnp.sum(preds * targets, axis=1))\n",
+ "\n",
+ "def accuracy(params, batch):\n",
+ " inputs, targets = batch\n",
+ " target_class = jnp.argmax(targets, axis=1)\n",
+ " predicted_class = jnp.argmax(predict(params, inputs), axis=1)\n",
+ " return jnp.mean(predicted_class == target_class)\n",
+ "\n",
+ "init_random_params, predict = stax.serial(\n",
+ " stax.Flatten,\n",
+ " stax.Dense(1024), stax.Relu,\n",
+ " stax.Dense(1024), stax.Relu,\n",
+ " stax.Dense(10), stax.LogSoftmax)\n",
+ "\n",
+ "rng = random.PRNGKey(0)"
+ ],
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "bRtnOBdJLd63"
+ },
+ "source": [
+ "## Train & Evaluate the model"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "SWbYRyj7LYZt"
+ },
+ "source": [
+ "step_size = 0.001\n",
+ "num_epochs = 10\n",
+ "batch_size = 128\n",
+ "momentum_mass = 0.9\n",
+ "\n",
+ "\n",
+ "num_train = train_images.shape[0]\n",
+ "num_complete_batches, leftover = divmod(num_train, batch_size)\n",
+ "num_batches = num_complete_batches + bool(leftover)\n",
+ "\n",
+ "def data_stream():\n",
+ " rng = npr.RandomState(0)\n",
+ " while True:\n",
+ " perm = rng.permutation(num_train)\n",
+ " for i in range(num_batches):\n",
+ " batch_idx = perm[i * batch_size:(i + 1) * batch_size]\n",
+ " yield train_images[batch_idx], train_labels[batch_idx]\n",
+ "batches = data_stream()\n",
+ "\n",
+ "opt_init, opt_update, get_params = optimizers.momentum(step_size, mass=momentum_mass)\n",
+ "\n",
+ "@jit\n",
+ "def update(i, opt_state, batch):\n",
+ " params = get_params(opt_state)\n",
+ " return opt_update(i, grad(loss)(params, batch), opt_state)\n",
+ "\n",
+ "_, init_params = init_random_params(rng, (-1, 28 * 28))\n",
+ "opt_state = opt_init(init_params)\n",
+ "itercount = itertools.count()\n",
+ "\n",
+ "print(\"\\nStarting training...\")\n",
+ "for epoch in range(num_epochs):\n",
+ " start_time = time.time()\n",
+ " for _ in range(num_batches):\n",
+ " opt_state = update(next(itercount), opt_state, next(batches))\n",
+ " epoch_time = time.time() - start_time\n",
+ "\n",
+ " params = get_params(opt_state)\n",
+ " train_acc = accuracy(params, (train_images, train_labels))\n",
+ " test_acc = accuracy(params, (test_images, test_labels))\n",
+ " print(\"Epoch {} in {:0.2f} sec\".format(epoch, epoch_time))\n",
+ " print(\"Training set accuracy {}\".format(train_acc))\n",
+ " print(\"Test set accuracy {}\".format(test_acc))"
+ ],
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "7Y1OZBhfQhOj"
+ },
+ "source": [
+ "## Convert to TFLite model.\n",
+ "Note here, we\n",
+ "1. Export the `JAX` model to `TF SavedModel` using `orbax`.\n",
+ "2. Call TFLite converter API to convert the `TF SavedModel` to `.tflite` model:\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "6pcqKZqdNTmn"
+ },
+ "source": [
+ "jax_module = JaxModule(params, predict, input_polymorphic_shape='b, ...')\n",
+ "converter = tf.lite.TFLiteConverter.from_concrete_functions(\n",
+ " [\n",
+ " jax_module.methods[JaxModule.DEFAULT_METHOD_KEY].get_concrete_function(\n",
+ " tf.TensorSpec(shape=(1, 28, 28), dtype=tf.float32, name=\"input\")\n",
+ " )\n",
+ " ]\n",
+ ")\n",
+ "\n",
+ "tflite_model = converter.convert()"
+ ],
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "sqEhzaJPSPS1"
+ },
+ "source": [
+ "## Check the Converted TFLite Model\n",
+ "Compare the converted model's results with the Jax model."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "acj2AYzjSlaY"
+ },
+ "source": [
+ "serving_func = functools.partial(predict, params)\n",
+ "expected = serving_func(train_images[0:1])\n",
+ "\n",
+ "# Run the model with TensorFlow Lite\n",
+ "interpreter = tf.lite.Interpreter(model_content=tflite_model)\n",
+ "interpreter.allocate_tensors()\n",
+ "input_details = interpreter.get_input_details()\n",
+ "output_details = interpreter.get_output_details()\n",
+ "interpreter.set_tensor(input_details[0][\"index\"], train_images[0:1, :, :])\n",
+ "interpreter.invoke()\n",
+ "result = interpreter.get_tensor(output_details[0][\"index\"])\n",
+ "\n",
+ "# Assert if the result of TFLite model is consistent with the JAX model.\n",
+ "np.testing.assert_almost_equal(expected, result, 1e-5)"
+ ],
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "Qy9Gp4H2SjBL"
+ },
+ "source": [
+ "## Optimize the Model\n",
+ "We will provide a `representative_dataset` to do post-training quantiztion to optimize the model.\n",
+ "\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "KI0rLV-Meg-2"
+ },
+ "source": [
+ "def representative_dataset():\n",
+ " for i in range(1000):\n",
+ " x = train_images[i:i+1]\n",
+ " yield [x]\n",
+ "x_input = jnp.zeros((1, 28, 28))\n",
+ "converter = tf.lite.TFLiteConverter.experimental_from_jax(\n",
+ " [serving_func], [[('x', x_input)]])\n",
+ "tflite_model = converter.convert()\n",
+ "converter.optimizations = [tf.lite.Optimize.DEFAULT]\n",
+ "converter.representative_dataset = representative_dataset\n",
+ "converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]\n",
+ "tflite_quant_model = converter.convert()\n",
+ "with open('jax_mnist_quant.tflite', 'wb') as f:\n",
+ " f.write(tflite_quant_model)"
+ ],
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "15xQR3JZS8TV"
+ },
+ "source": [
+ "## Evaluate the Optimized Model"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "X3oOm0OaevD6"
+ },
+ "source": [
+ "expected = serving_func(train_images[0:1])\n",
+ "\n",
+ "# Run the model with TensorFlow Lite\n",
+ "interpreter = tf.lite.Interpreter(model_content=tflite_quant_model)\n",
+ "interpreter.allocate_tensors()\n",
+ "input_details = interpreter.get_input_details()\n",
+ "output_details = interpreter.get_output_details()\n",
+ "interpreter.set_tensor(input_details[0][\"index\"], train_images[0:1, :, :])\n",
+ "interpreter.invoke()\n",
+ "result = interpreter.get_tensor(output_details[0][\"index\"])\n",
+ "\n",
+ "# Assert if the result of TFLite model is consistent with the Jax model.\n",
+ "np.testing.assert_almost_equal(expected, result, 1e-5)"
+ ],
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "QqHXCNa3myor"
+ },
+ "source": [
+ "## Compare the Quantized Model size\n",
+ "We should be able to see the quantized model is four times smaller than the original model."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "imFPw007juVG"
+ },
+ "source": [
+ "!du -h jax_mnist.tflite\n",
+ "!du -h jax_mnist_quant.tflite"
+ ],
+ "execution_count": null,
+ "outputs": []
+ }
+ ]
+}