diff --git a/ODML/jax_to_tflite.ipynb b/ODML/jax_to_tflite.ipynb new file mode 100644 index 0000000..ab5b3ca --- /dev/null +++ b/ODML/jax_to_tflite.ipynb @@ -0,0 +1,475 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "provenance": [] + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "language_info": { + "name": "python" + } + }, + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "8vD3L4qeREvg" + }, + "source": [ + "##### Copyright 2021 The TensorFlow Authors." + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "qLCxmWRyRMZE" + }, + "source": [ + "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", + "# you may not use this file except in compliance with the License.\n", + "# You may obtain a copy of the License at\n", + "#\n", + "# https://www.apache.org/licenses/LICENSE-2.0\n", + "#\n", + "# Unless required by applicable law or agreed to in writing, software\n", + "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", + "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", + "# See the License for the specific language governing permissions and\n", + "# limitations under the License." + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "4k5PoHrgJQOU" + }, + "source": [ + "# Jax Model Conversion For TFLite\n", + "## Overview\n", + "Note: This API is new and only available via pip install tf-nightly. It will be available in TensorFlow version 2.7. Also, the API is still experimental and subject to changes.\n", + "\n", + "This CodeLab demonstrates how to build a model for MNIST recognition using Jax, and how to convert it to TensorFlow Lite. This codelab will also demonstrate how to optimize the Jax-converted TFLite model with post-training quantiztion." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "i8cfOBcjSByO" + }, + "source": [ + "\n", + " \n", + " \n", + " \n", + " \n", + "
\n", + " View on TensorFlow.org\n", + " \n", + " Run in Google Colab\n", + " \n", + " View source on GitHub\n", + " \n", + " Download notebook\n", + "
" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "lq-T8XZMJ-zv" + }, + "source": [ + "## Prerequisites\n", + "It's recommended to try this feature with the newest TensorFlow nightly pip build." + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "EV04hKdrnE4f" + }, + "source": [ + "!pip install tf-nightly --upgrade\n", + "!pip install jax --upgrade" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# Make sure your JAX version is at least 0.4.20 or above.\n", + "import jax\n", + "jax.__version__" + ], + "metadata": { + "id": "vsilblGuGQa2", + "outputId": "d2a08dc2-0c83-4792-ecc9-3178b91d43f6", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 35 + } + }, + "execution_count": null, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "'0.4.21'" + ], + "application/vnd.google.colaboratory.intrinsic+json": { + "type": "string" + } + }, + "metadata": {}, + "execution_count": 3 + } + ] + }, + { + "cell_type": "code", + "source": [ + "!pip install orbax-export --upgrade" + ], + "metadata": { + "id": "PJeQhMUwH0oX" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "from orbax.export import ExportManager\n", + "from orbax.export import JaxModule\n", + "from orbax.export import ServingConfig" + ], + "metadata": { + "id": "j9_CVA0THQNc" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "QAeY43k9KM55" + }, + "source": [ + "## Data Preparation\n", + "Download the MNIST data with Keras dataset and pre-process." + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "qSOPSZJn1_Tj" + }, + "source": [ + "import numpy as np\n", + "import tensorflow as tf\n", + "import functools\n", + "\n", + "import time\n", + "import itertools\n", + "\n", + "import numpy.random as npr\n", + "\n", + "import jax.numpy as jnp\n", + "from jax import jit, grad, random\n", + "from jax.example_libraries import optimizers\n", + "from jax.example_libraries import stax\n" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "hdJIt3Da2Qn1" + }, + "source": [ + "def _one_hot(x, k, dtype=np.float32):\n", + " \"\"\"Create a one-hot encoding of x of size k.\"\"\"\n", + " return np.array(x[:, None] == np.arange(k), dtype)\n", + "\n", + "(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()\n", + "train_images, test_images = train_images / 255.0, test_images / 255.0\n", + "train_images = train_images.astype(np.float32)\n", + "test_images = test_images.astype(np.float32)\n", + "\n", + "train_labels = _one_hot(train_labels, 10)\n", + "test_labels = _one_hot(test_labels, 10)" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "0eFhx85YKlEY" + }, + "source": [ + "## Build the MNIST model with Jax" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "mi3TKB9nnQdK" + }, + "source": [ + "def loss(params, batch):\n", + " inputs, targets = batch\n", + " preds = predict(params, inputs)\n", + " return -jnp.mean(jnp.sum(preds * targets, axis=1))\n", + "\n", + "def accuracy(params, batch):\n", + " inputs, targets = batch\n", + " target_class = jnp.argmax(targets, axis=1)\n", + " predicted_class = jnp.argmax(predict(params, inputs), axis=1)\n", + " return jnp.mean(predicted_class == target_class)\n", + "\n", + "init_random_params, predict = stax.serial(\n", + " stax.Flatten,\n", + " stax.Dense(1024), stax.Relu,\n", + " stax.Dense(1024), stax.Relu,\n", + " stax.Dense(10), stax.LogSoftmax)\n", + "\n", + "rng = random.PRNGKey(0)" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "bRtnOBdJLd63" + }, + "source": [ + "## Train & Evaluate the model" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "SWbYRyj7LYZt" + }, + "source": [ + "step_size = 0.001\n", + "num_epochs = 10\n", + "batch_size = 128\n", + "momentum_mass = 0.9\n", + "\n", + "\n", + "num_train = train_images.shape[0]\n", + "num_complete_batches, leftover = divmod(num_train, batch_size)\n", + "num_batches = num_complete_batches + bool(leftover)\n", + "\n", + "def data_stream():\n", + " rng = npr.RandomState(0)\n", + " while True:\n", + " perm = rng.permutation(num_train)\n", + " for i in range(num_batches):\n", + " batch_idx = perm[i * batch_size:(i + 1) * batch_size]\n", + " yield train_images[batch_idx], train_labels[batch_idx]\n", + "batches = data_stream()\n", + "\n", + "opt_init, opt_update, get_params = optimizers.momentum(step_size, mass=momentum_mass)\n", + "\n", + "@jit\n", + "def update(i, opt_state, batch):\n", + " params = get_params(opt_state)\n", + " return opt_update(i, grad(loss)(params, batch), opt_state)\n", + "\n", + "_, init_params = init_random_params(rng, (-1, 28 * 28))\n", + "opt_state = opt_init(init_params)\n", + "itercount = itertools.count()\n", + "\n", + "print(\"\\nStarting training...\")\n", + "for epoch in range(num_epochs):\n", + " start_time = time.time()\n", + " for _ in range(num_batches):\n", + " opt_state = update(next(itercount), opt_state, next(batches))\n", + " epoch_time = time.time() - start_time\n", + "\n", + " params = get_params(opt_state)\n", + " train_acc = accuracy(params, (train_images, train_labels))\n", + " test_acc = accuracy(params, (test_images, test_labels))\n", + " print(\"Epoch {} in {:0.2f} sec\".format(epoch, epoch_time))\n", + " print(\"Training set accuracy {}\".format(train_acc))\n", + " print(\"Test set accuracy {}\".format(test_acc))" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "7Y1OZBhfQhOj" + }, + "source": [ + "## Convert to TFLite model.\n", + "Note here, we\n", + "1. Export the `JAX` model to `TF SavedModel` using `orbax`.\n", + "2. Call TFLite converter API to convert the `TF SavedModel` to `.tflite` model:\n", + "\n", + "\n", + "\n", + "\n", + "\n" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "6pcqKZqdNTmn" + }, + "source": [ + "jax_module = JaxModule(params, predict, input_polymorphic_shape='b, ...')\n", + "converter = tf.lite.TFLiteConverter.from_concrete_functions(\n", + " [\n", + " jax_module.methods[JaxModule.DEFAULT_METHOD_KEY].get_concrete_function(\n", + " tf.TensorSpec(shape=(1, 28, 28), dtype=tf.float32, name=\"input\")\n", + " )\n", + " ]\n", + ")\n", + "\n", + "tflite_model = converter.convert()" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "sqEhzaJPSPS1" + }, + "source": [ + "## Check the Converted TFLite Model\n", + "Compare the converted model's results with the Jax model." + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "acj2AYzjSlaY" + }, + "source": [ + "serving_func = functools.partial(predict, params)\n", + "expected = serving_func(train_images[0:1])\n", + "\n", + "# Run the model with TensorFlow Lite\n", + "interpreter = tf.lite.Interpreter(model_content=tflite_model)\n", + "interpreter.allocate_tensors()\n", + "input_details = interpreter.get_input_details()\n", + "output_details = interpreter.get_output_details()\n", + "interpreter.set_tensor(input_details[0][\"index\"], train_images[0:1, :, :])\n", + "interpreter.invoke()\n", + "result = interpreter.get_tensor(output_details[0][\"index\"])\n", + "\n", + "# Assert if the result of TFLite model is consistent with the JAX model.\n", + "np.testing.assert_almost_equal(expected, result, 1e-5)" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Qy9Gp4H2SjBL" + }, + "source": [ + "## Optimize the Model\n", + "We will provide a `representative_dataset` to do post-training quantiztion to optimize the model.\n", + "\n", + "\n" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "KI0rLV-Meg-2" + }, + "source": [ + "def representative_dataset():\n", + " for i in range(1000):\n", + " x = train_images[i:i+1]\n", + " yield [x]\n", + "x_input = jnp.zeros((1, 28, 28))\n", + "converter = tf.lite.TFLiteConverter.experimental_from_jax(\n", + " [serving_func], [[('x', x_input)]])\n", + "tflite_model = converter.convert()\n", + "converter.optimizations = [tf.lite.Optimize.DEFAULT]\n", + "converter.representative_dataset = representative_dataset\n", + "converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]\n", + "tflite_quant_model = converter.convert()\n", + "with open('jax_mnist_quant.tflite', 'wb') as f:\n", + " f.write(tflite_quant_model)" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "15xQR3JZS8TV" + }, + "source": [ + "## Evaluate the Optimized Model" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "X3oOm0OaevD6" + }, + "source": [ + "expected = serving_func(train_images[0:1])\n", + "\n", + "# Run the model with TensorFlow Lite\n", + "interpreter = tf.lite.Interpreter(model_content=tflite_quant_model)\n", + "interpreter.allocate_tensors()\n", + "input_details = interpreter.get_input_details()\n", + "output_details = interpreter.get_output_details()\n", + "interpreter.set_tensor(input_details[0][\"index\"], train_images[0:1, :, :])\n", + "interpreter.invoke()\n", + "result = interpreter.get_tensor(output_details[0][\"index\"])\n", + "\n", + "# Assert if the result of TFLite model is consistent with the Jax model.\n", + "np.testing.assert_almost_equal(expected, result, 1e-5)" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "QqHXCNa3myor" + }, + "source": [ + "## Compare the Quantized Model size\n", + "We should be able to see the quantized model is four times smaller than the original model." + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "imFPw007juVG" + }, + "source": [ + "!du -h jax_mnist.tflite\n", + "!du -h jax_mnist_quant.tflite" + ], + "execution_count": null, + "outputs": [] + } + ] +}