AWS Machine Learning Blog

Model Server for Apache MXNet adds support for serving Gluon models

Today AWS released Model Server for Apache MXNet (MMS) v0.4, which adds support for serving Gluon models. Gluon is an imperative and dynamic interface for MXNet that enables rapid model development while maintaining MXNet's performance. With this release, MMS adds support for packaging and serving Gluon models at scale. In this blog post, we describe the v0.4 release in detail and walk through an example of serving a Gluon model.

What is Model Server for Apache MXNet (MMS)?

MMS is an open source model-serving framework, designed to simplify the task of serving deep learning models at scale. Here are some key advantages of MMS:

  • Provides a packaging tool to generate a model archive that contains all of the artifacts needed to serve MXNet, Gluon, and ONNX neural network models.
  • Gives you the ability to customize every step in the inference execution pipeline using custom code packaged into the model archive, which enables overriding initialization, pre-processing, and post-processing.
  • Comes with a preconfigured serving stack, including REST API endpoints and an inference engine.
  • Provides prebuilt and optimized container images for scalable model serving.
  • Includes real-time operational metrics to monitor the server and its endpoints.

What is Gluon?

Gluon is a clear, concise, and simple Python interface to MXNet. It enables engineers to write imperative code to build neural networks, without losing the performance benefits of a symbolic implementation. Gluon can automatically generate optimized symbolic code from the imperative implementation.
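
To make this concrete, here is a minimal sketch (not part of the Crepe example that follows) that builds a small network imperatively and then calls hybridize() to switch to the optimized symbolic execution path:

from mxnet import nd
from mxnet.gluon import nn

# Build a small network imperatively from Gluon's predefined layers
net = nn.HybridSequential()
net.add(nn.Dense(64, activation='relu'),
        nn.Dense(10))
net.initialize()

# Runs eagerly, like ordinary Python code
x = nd.random.uniform(shape=(1, 20))
print(net(x).shape)  # (1, 10)

# hybridize() compiles the network into an optimized symbolic graph;
# the calling code stays exactly the same
net.hybridize()
print(net(x).shape)  # (1, 10)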

Serving a Gluon model in MMS

Let’s go over an example of serving a Gluon model with MMS. In this example, we’ll use Gluon to implement a model based on Xiang Zhang’s Character-level Convolutional Neural Network (char-CNN), then package the model and serve it with MMS. A char-CNN uses a convolutional neural network (CNN), rather than a recurrent neural network (RNN), to understand vocabulary and classify text. This enables classification without the need to understand language semantics or syntax, and it means data errors such as spelling mistakes have less impact on prediction performance.

Figure 1. Illustration of char-CNN model (source)

This image shows the basic char-CNN model structure, known as the Crepe model: the input text is encoded, character by character, into a tensor that is then fed into the convolutional layers. In the following GluonCrepe code, we use the Crepe model to predict a product category based on the text of a product review. The model is trained on the Amazon product dataset. (For details on how the model is trained, follow this detailed tutorial provided by Thomas Delteil.)

The model definition is as follows.

from mxnet.gluon import nn, HybridBlock


class GluonCrepe(HybridBlock):
    """
    Hybrid Block Gluon Crepe model
    """
    def __init__(self, classes=7, **kwargs):
        super(GluonCrepe, self).__init__(**kwargs)
        self.NUM_FILTERS = 256  # number of convolutional filters per convolutional layer
        self.NUM_OUTPUTS = classes  # number of output classes
        self.FULLY_CONNECTED = 1024  # number of units in the fully connected dense layer
        self.features = nn.HybridSequential()
        with self.name_scope():
            self.features.add(
                nn.Conv1D(channels=self.NUM_FILTERS, kernel_size=7, activation='relu'),
                nn.MaxPool1D(pool_size=3, strides=3),
                nn.Conv1D(channels=self.NUM_FILTERS, kernel_size=7, activation='relu'),
                nn.MaxPool1D(pool_size=3, strides=3),
                nn.Conv1D(channels=self.NUM_FILTERS, kernel_size=3, activation='relu'),
                nn.Conv1D(channels=self.NUM_FILTERS, kernel_size=3, activation='relu'),
                nn.Conv1D(channels=self.NUM_FILTERS, kernel_size=3, activation='relu'),
                nn.Conv1D(channels=self.NUM_FILTERS, kernel_size=3, activation='relu'),
                nn.MaxPool1D(pool_size=3, strides=3),
                nn.Flatten(),
                nn.Dense(self.FULLY_CONNECTED, activation='relu'),
                nn.Dense(self.FULLY_CONNECTED, activation='relu'),
            )
            self.output = nn.Dense(self.NUM_OUTPUTS)

    def hybrid_forward(self, F, x):
        # Run the convolutional feature extractor, then the output classification layer
        x = self.features(x)
        return self.output(x)
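
As a quick sanity check, the following sketch (our addition, using randomly initialized weights rather than the pre-trained ones) instantiates the network and runs a forward pass on a dummy encoded review to confirm the input and output shapes:

from mxnet import nd

net = GluonCrepe(classes=7)
net.initialize()   # random weights, for shape checking only
net.hybridize()    # compile to the optimized symbolic graph

# A batch of one review: 69 alphabet rows by 1,014 character positions
dummy_review = nd.zeros((1, 69, 1014))
print(net(dummy_review).shape)  # (1, 7), one score per product category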

To run a model as a service on MMS, we require the following files:

  1. The custom service Python file – Defines the pre-process, inference, and post-process steps. It is also used to load pre-trained weights.
  2. The signature JSON file – Defines the expected input and output shapes of the model.
  3. The model file – For regular MXNet models, this is a JSON file describing the network symbol. For Gluon models, however, the network is defined as a class within the custom service Python file.
  4. The parameters file – Stores the pre-trained weights of the model.
  5. The synset text file – Unique to classification models, this file contains the labels for the output classes.

Trained model weights are required to run inference. Download the pre-trained weights. MMS uses custom service class derivations to define the pre-process, inference, and post-process functionality; these custom services are specific to the model. For this example’s Crepe model, we define the service as follows.

import numpy as np
from mxnet import nd

# GluonImperativeBaseService is provided by the MMS model service package;
# the exact import path depends on the MMS version


class CharacterCNNService(GluonImperativeBaseService):
    """
    Gluon Character-level Convolution Service
    """
    def __init__(self, model_name, model_dir, manifest, gpu=None):
        net = GluonCrepe()
        super(CharacterCNNService, self).__init__(model_name, model_dir, manifest, net, gpu)
        # The 69 characters as specified in the paper
        self.ALPHABET = list("abcdefghijklmnopqrstuvwxyz0123456789-,;.!?:'\"/\\|_@#$%^&*~`+ =<>()[]{}")
        # Map each character to its index
        self.ALPHABET_INDEX = {letter: index for index, letter in enumerate(self.ALPHABET)}
        # Max length, in characters, for one document
        self.FEATURE_LEN = 1014
        # Hybridize imperative model for best performance
        self.net.hybridize()

    def _preprocess(self, data):
        # Build the text from the request
        text = '{}|{}'.format(data[0][0]['review_title'], data[0][0]['review'])

        # One-hot encode each character, truncating the text to FEATURE_LEN characters
        encoded = np.zeros([len(self.ALPHABET), self.FEATURE_LEN], dtype='float32')
        review = text.lower()[:self.FEATURE_LEN]
        for i, letter in enumerate(review):
            if letter in self.ALPHABET_INDEX:
                encoded[self.ALPHABET_INDEX[letter]][i] = 1
        return nd.array([encoded], ctx=self.ctx)

    def _inference(self, data):
        # Call forward/hybrid_forward
        output = self.net(data)
        return output.softmax()

    def _postprocess(self, data):
        # Post-process and output the most likely category
        predicted = self.labels[np.argmax(data[0].asnumpy())]
        return [{'category': predicted}]

In the preprocess step, the input text is truncated to 1,014 characters and one-hot encoded, character by character, based on the 69-character alphabet defined in the paper.
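
To make the encoding concrete, here is a small standalone sketch of the same one-hot scheme, outside of the service class:

import numpy as np

ALPHABET = list("abcdefghijklmnopqrstuvwxyz0123456789-,;.!?:'\"/\\|_@#$%^&*~`+ =<>()[]{}")
ALPHABET_INDEX = {letter: index for index, letter in enumerate(ALPHABET)}
FEATURE_LEN = 1014

def encode(text):
    # One row per alphabet character, one column per character position
    encoded = np.zeros([len(ALPHABET), FEATURE_LEN], dtype='float32')
    for i, letter in enumerate(text.lower()[:FEATURE_LEN]):
        if letter in ALPHABET_INDEX:
            encoded[ALPHABET_INDEX[letter]][i] = 1
    return encoded

matrix = encode('great direction and story')
print(matrix.shape)           # (69, 1014)
print(matrix[:, 0].argmax())  # 6, the index of 'g' in the alphabet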

Next, we define the signature JSON file to inform MMS about the input and output shapes and their data types.

"inputs": [{  
    "data_name": "data",  
    "data_shape": [1, 1014]  
}], "input_type": "application/json", "outputs": [{  
    "data_name": "softmax",  
    "data_shape": [0, 7]  
}], "output_type": "application/json"  
}  

Both the inputs and outputs are text, passed in JSON format. The input size of [1, 1014] is a recommendation from the paper, and the Amazon product dataset has seven output classes. The synset text file stores the labels of these seven output classes.

All of these files can be packaged into a single archive using the MMS export tool. A complete custom service file artifact is needed to build the archive. Install MMS, place the model artifacts in the same folder, and run the following command:

$ mxnet-model-export --model-name="character_cnn" --model-path="/path/to/model/folder" --service-file-path="/path/to/model/folder/service_file.py" 

A pre-packaged model is available for download.

You can now serve the model with the following commands:

$ pip install mxnet-model-server
$ mxnet-model-server --models crepe=character_cnn.model

These commands start serving the model at http://127.0.0.1:8080. For further details on the MMS server and export command-line options, refer to the MMS documentation.

To get a prediction from the server, open a new terminal window and run the following command:

$ curl -X POST http://127.0.0.1:8080/crepe/predict -F "data=[{'review_title':'Inception is the best','review': 'great direction and story'}]"

The output from the server will be as follows.

{"prediction":[{"category":"Movies_and_TV"}]}

Let’s look at another example.

$ curl -X POST http://127.0.0.1:8080/crepe/predict -F "data=[{'review_title':'fantastic quality','review': 'quality sound playback'}]"

The response for this request follows.

{"prediction":[{"category":"CDs_and_Vinyl"}]}

For scalable production use cases, we recommend using containers. You can build MMS containers from source, or you can pull the image you need from the MMS Docker Hub repository.

Learn more and contribute

To learn more about MMS, start with our Single Shot Multi Object Detection (SSD) tutorial, which walks you through exporting and serving an SSD model. You can find more examples and documentation in the repository’s model zoo and documentation folder.

As we continue to develop MMS, we welcome community participation in the form of questions, requests, and contributions. If you are already using MMS, we welcome your feedback via the repository’s GitHub issues. Head over to awslabs/mxnet-model-server to get started!

Citations

R. He and J. McAuley. "Ups and Downs: Modeling the Visual Evolution of Fashion Trends with One-Class Collaborative Filtering." WWW, 2016.


About the Authors

Rakesh Vasudevan is a Software Development Engineer with AWS Deep Learning. He is passionate about building scalable deep learning systems. In his spare time, he enjoys gaming, cricket, and hanging out with friends and family.

Vamshidhar Dantu is a Software Developer with AWS Deep Learning. He focuses on building scalable and easily deployable deep learning systems. In his spare time, he enjoys spending time with family and playing badminton.

Hagay Lupesko is an Engineering Leader for AWS Deep Learning. He focuses on building deep learning systems that enable developers and scientists to build intelligent applications. In his spare time, he enjoys reading, hiking, and spending time with his family.