This repository has been archived by the owner on Jul 1, 2024. It is now read-only.

keras-apache-mxnet update with bug fixes and new features (#148)
* Add saving mxnet model checkpoint callback (#132)

* Add saving mxnet model checkpoint callback

* Add tests for MXNetModelCheckpoint callback

* Fix MXNetModelCheckpoint test case

* Fix MXNetModelCheckpoint tests. Add dependency on keras_applications and keras_preprocessing

* Fixed CR comments. Split tests into multiple independent tests

* Fix CR comments on the code documentation

* Add additional test to verify only one model is saved

* Add examples of monitors

* update pr and nightly buildspec,add into source control (#141)

* Fix batchnorm gamma (#137)

* fix gamma and beta equal to None

* fix style

* fix initializer, enable unit test

* update comments

* remove +, remove repeated install, add clear message (#142)

* fix conv1d channels first (#143)

* fix conv1d channels first

* update data format for causal test

* fix style

* Adding get_mxnet_model_info API to allow users to query underlying MXNet model info (#144)

* Adding get_mxnet_model_info API to allow users to query underlying MXNet model info

* resolve merge conflicts in conv1d

* Add more tests - functional model, compare with return values of save_mxnet_model API

* update save mxnet model API document. (#147)

Co-authored-by: Sandeep Krishnamurthy <[email protected]>

* implement sparse categorical crossentropy, enable unitests (#145)

* implement sparse categorical crossentropy, enable unitests

* fix elementwise operators, fix sparse categorical accuracy, enabled unitests

* fix element wise operators

* simplify using ndim

* fix style

* reduce number of returns

* fix operator name in error message

* update comments and doc string

* update comment spelling

* update documentation, improve CI test commands (#151)

* update documentation, improve CI test commands

* fix conv1d initialization, fix conv1d unit test

* fix documentation hyperlink

* update buildspec

* remove official keras installed with dependencies

* update ci commands
roywei authored Aug 8, 2018
1 parent 332c62f commit c0f4ef6
Showing 15 changed files with 565 additions and 77 deletions.
83 changes: 70 additions & 13 deletions docs/mxnet_backend/save_mxnet_model.md
@@ -3,32 +3,42 @@
## Table of Contents

1. [Objective](#objective)
2. [Train and Save a Convolutional Neural Network (CNN) Model for MNIST Dataset](#train-and-save-a-convolutional-neural-network-cnn-model-for-mnist-dataset)
3. [Import the model in MXNet for Inference](#import-the-model-in-mxnet-for-inference)
4. [References](#references)
2. [Export MXNet model using save_mxnet_model() API](#export-mxnet-model-using-save_mxnet_model-api)
3. [Checkpoint MXNet model using MXNetModelCheckpoint callback](#checkpoint-mxnet-model-using-mxnetmodelcheckpoint-callback)
4. [Train and Save a Convolutional Neural Network (CNN) Model for MNIST Dataset](#train-and-save-a-convolutional-neural-network-cnn-model-for-mnist-dataset)
5. [Import the model in MXNet for Inference](#import-the-model-in-mxnet-for-inference)
6. [What's next](#whats-next)
7. [References](#references)

## Objective

In this tutorial, we show how to train a model in Keras-MXNet, export the trained model as Apache MXNet model using `keras.models.save_mxnet_model()` API, and use MXNet natively for inference.
In this tutorial, we show how to train a model in Keras-MXNet, export the trained model as an Apache MXNet model, and use MXNet natively for inference.

The Keras interface is known for its easy to use APIs, enabling fast prototyping in deep learning research.
MXNet is known for its high performance, production ready engine. With Keras-MXNet, you get an out-of-the-box API to
export most trained Keras models in MXNet model format.
You can now use Keras-MXNet for training the model and MXNet for inference in production. You can use `keras.models.save_mxnet_model()` API to save
the models trained on the CPU, a single GPU or multiple GPUs.
You can use either of the following APIs to export a native MXNet model from Keras-MXNet:
1. The `keras.models.save_mxnet_model()` API.
2. The `keras.callbacks.MXNetModelCheckpoint` callback.

The Keras interface is known for its easy-to-use APIs, which enable fast prototyping in deep learning research.
MXNet is known for its high-performance, production-ready engine. With Keras-MXNet, you get an out-of-the-box API to
export trained Keras models in MXNet model format.
You can now use Keras-MXNet for training the model and MXNet for inference in production. You can use the `keras.models.save_mxnet_model()` API to save
models trained on the CPU, a single GPU, or multiple GPUs, or you can use the `keras.callbacks.MXNetModelCheckpoint` callback to checkpoint the model during training.

You can use any language bindings supported by MXNet (Scala/Python/Julia/C++/R/Perl) for performing inference with these models!

`Warning` Not all Keras operators and functionalities are supported with the MXNet backend. For more information, view the list
of known issues and unsupported functionalities [here](https://github.com/awslabs/keras-apache-mxnet/issues/18).


## Export MXNet model using save_mxnet_model() API

This API accepts the following arguments:
* model: Keras model instance to be saved as MXNet model.
* prefix: Prefix name of the saved Model (symbol and params) files. Model will be saved as 'prefix-symbol.json' and 'prefix-epoch.params'.
* epoch: (Optional) Tag the params file with epoch of the model being saved. Default is 0. Model params file is saved as 'prefix-epoch.params' or 'prefix-0000.params' by default.
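
The naming convention described above can be sketched in plain Python. The helper name `mxnet_model_files` is hypothetical and not part of Keras-MXNet; it only assumes the zero-padded, four-digit epoch tag shown in 'prefix-0000.params'.

```python
def mxnet_model_files(prefix, epoch=0):
    """Return the (symbol, params) file names expected for a saved model.

    Hypothetical helper for illustration only: it mirrors the
    'prefix-symbol.json' / 'prefix-epoch.params' convention, with the
    epoch zero-padded to four digits.
    """
    symbol_file = '%s-symbol.json' % prefix
    params_file = '%s-%04d.params' % (prefix, epoch)
    return symbol_file, params_file


# Default epoch (0) and an explicit epoch.
print(mxnet_model_files('my_model'))
print(mxnet_model_files('my_model', epoch=12))
```

A helper like this can be handy for checking that both files exist before handing the prefix to MXNet.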


To summarize, all you have to do is to call the `keras.models.save_mxnet_model()` API by passing the trained Keras
To summarize, all you have to do is to call the `keras.models.save_mxnet_model()` API by passing the trained Keras
model to be exported in MXNet model format.

```python
@@ -38,18 +48,45 @@ model.fit(x_train, y_train,
epochs=epochs,
verbose=1,
validation_data=(x_test, y_test))

# Save the trained Keras Model as MXNet Model
keras.models.save_mxnet_model(model=model, prefix='my_model')

# You get the MXNet model - (my_model-symbol.json, my_model-0000.params) in your current directory.
# Symbol and Params are two files representing a native MXNet model.

```


## Checkpoint MXNet model using MXNetModelCheckpoint callback

Using `MXNetModelCheckpoint` is similar to using the Keras `ModelCheckpoint` callback. See https://keras.io/callbacks/#modelcheckpoint for the callback parameters.
`MXNetModelCheckpoint` differs from `ModelCheckpoint` in the following ways:
1. It accepts one additional parameter, `prefix`: the prefix name of the saved model (symbol and params) files. The model is saved as 'prefix-symbol.json' and 'prefix-epoch.params'.
2. Model checkpoints are always saved in the current working directory, i.e., `MXNetModelCheckpoint` does not accept `filepath`. (This will be supported soon - [#131](https://github.com/awslabs/keras-apache-mxnet/issues/131))

You can use `MXNetModelCheckpoint` to checkpoint (save) the model as a native MXNet model based on your checkpoint criteria, such as saving only the best model or saving the model after each epoch.

To obtain the MXNet model details needed for later binding to an MXNet Module, use the `K.get_mxnet_model_info()` API, which returns the `data_names` and `data_shapes` to use when binding the model.

```python
# ... Assuming you have built the model.

# Use MXNetModelCheckpoint callback to save best model during the training.
checkpoint = MXNetModelCheckpoint(mxnet_model_prefix="my_model", monitor='val_loss', verbose=1, save_best_only=True, mode='min')
callbacks_list = [checkpoint]
model.compile(loss=keras.losses.categorical_crossentropy,
optimizer=keras.optimizers.Adadelta(),
metrics=['accuracy'])
# Callbacks are passed to fit(), not compile().
model.fit(x_train, y_train, validation_data=(x_test, y_test), callbacks=callbacks_list)

# You get the best MXNet model - (my_model-symbol.json, my_model-0000.params) in your current directory.
# Symbol and Params are two files representing a native MXNet model.
```

NOTE: If you do not set `save_best_only = True`, i.e., you want to save the model after each epoch, there will be one symbol file (my_model-symbol.json) and one params file per epoch, for example my_model-0001.params and my_model-0002.params.
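
When a params file is written for every epoch, it can help to list which epochs are available before choosing one to load. The following is a minimal sketch in plain Python; `checkpoint_epochs` is a hypothetical helper, not a Keras-MXNet API, and it assumes only the 'prefix-epoch.params' naming described above.

```python
import os
import re


def checkpoint_epochs(prefix, filenames):
    """Return the sorted epoch numbers for which '<prefix>-<epoch>.params' exists.

    Hypothetical helper for illustration: `filenames` is any iterable of
    file names, for example the result of os.listdir('.').
    """
    pattern = re.compile(r'^' + re.escape(prefix) + r'-(\d+)\.params$')
    epochs = []
    for name in filenames:
        match = pattern.match(name)
        if match:
            epochs.append(int(match.group(1)))
    return sorted(epochs)


# Checkpoints are saved in the current working directory.
print(checkpoint_epochs('my_model', os.listdir('.')))
```

Taking the file names as an argument keeps the helper independent of the file system, so the same function can be pointed at another directory listing if needed.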

## Train and save a Convolutional Neural Network (CNN) model for MNIST dataset

We provide the following example for building a simple CNN model in Keras for [MNIST](http://yann.lecun.com/exdb/mnist/) handwritten digit recognition dataset. As you follow the example, you will save the model in the
We provide the following example for building a simple CNN model in Keras for [MNIST](http://yann.lecun.com/exdb/mnist/) handwritten digit recognition dataset. As you follow the example, you will save the model in the
MXNet model format. You will use the `keras.models.save_mxnet_model()` API.

```python
@@ -111,6 +148,15 @@ model.add(Dense(128, activation='relu'))
model.add(Dropout(0.5))
model.add(Dense(num_classes, activation='softmax'))

'''
# You can also use MXNetModelCheckpoint callback.
checkpoint = MXNetModelCheckpoint(mxnet_model_prefix="my_model", monitor='val_loss', verbose=1, save_best_only=True, mode='min')
callbacks_list = [checkpoint]
# Then pass callbacks=callbacks_list to model.fit() below; callbacks are not a compile() argument.
'''
model.compile(loss=keras.losses.categorical_crossentropy,
optimizer=keras.optimizers.Adadelta(),
metrics=['accuracy'])
@@ -124,6 +170,12 @@ score = model.evaluate(x_test, y_test, verbose=0)
print('Test loss:', score[0])
print('Test accuracy:', score[1])


'''
# If you have used the MXNetModelCheckpoint callback, you need a way to get data_names and data_shapes.
data_names, data_shapes = K.get_mxnet_model_info(model)
'''

# Step 2: Save the model in MXNet model format.
# data_names and data_shapes are values of the parameters to be used when loading the Model in MXNet.
data_names, data_shapes = save_mxnet_model(model=model, prefix='mnist_cnn', epoch=0)
@@ -179,6 +231,11 @@ That's it! We trained a CNN model with Keras interface and used MXNet native eng
note that we can use any language binding supported by MXNet (Scala/Python/Julia/C++/R/Perl) for inference based on your
production environment setup and requirements.
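
As a recap of the inference step, here is a minimal sketch of loading an exported model into an MXNet Module. The function name `load_mxnet_module` is hypothetical. The sketch assumes that MXNet is installed, that the model files sit in the current directory, and that `data_names` and `data_shapes` are the values returned by `save_mxnet_model()` or `K.get_mxnet_model_info()`. Depending on your model, the batch dimension in `data_shapes` may need adjusting.

```python
def load_mxnet_module(prefix, epoch, data_names, data_shapes):
    """Load '<prefix>-symbol.json' and '<prefix>-<epoch>.params' for inference.

    Hypothetical helper for illustration. MXNet is imported lazily so the
    function can be defined even where MXNet is not installed.
    """
    import mxnet as mx

    # Read the symbol and parameters written at export time.
    sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch)

    # Inference only: no label inputs are needed.
    mod = mx.mod.Module(symbol=sym, data_names=data_names,
                        label_names=None, context=mx.cpu())
    mod.bind(for_training=False, data_shapes=data_shapes)
    mod.set_params(arg_params, aux_params, allow_missing=True)
    return mod


# Example (assumes the 'mnist_cnn' model saved above):
# mod = load_mxnet_module('mnist_cnn', 0, data_names, data_shapes)
```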

## What's next
After you save an MXNet model, check out our blog posts on how to use it with [MXNet Model Server](https://github.com/awslabs/mxnet-model-server) and the [MXNet Scala Infer API](https://mxnet.incubator.apache.org/api/scala/infer.html):
1. [Deploy a Smile Detector with Keras MXNet and MXNet Model Server](https://medium.com/apache-mxnet/deploy-a-smile-detector-with-keras-mxnet-and-mxnet-model-server-48cd9741b6d2)
2. [Train using Keras MXNet and Inference using MXNet Scala API](https://medium.com/apache-mxnet/train-using-keras-mxnet-and-inference-using-mxnet-scala-api-49476a16a46a)

## References
1. [MXNet Module](https://mxnet.incubator.apache.org/api/python/module/module.html)
2. [MXNet Predicting with Pre-Trained Model](https://mxnet.incubator.apache.org/tutorials/python/predict_image.html)