AWS Machine Learning Blog
Train and deploy deep learning models using JAX with Amazon SageMaker
Amazon SageMaker is a fully managed service that enables developers and data scientists to quickly and easily build, train, and deploy machine learning (ML) models at any scale. Typically, you can use the pre-built training and inference containers that have been optimized for AWS hardware. Although those containers cover many deep learning workloads, you may have use cases where you want to use a different framework or otherwise customize the OS libraries within the container. To accommodate this, SageMaker provides the flexibility to train models using any framework that can run in a Docker container. This lets you keep using existing SageMaker training capabilities such as training jobs, hyperparameter tuning, and Managed Spot Training.
In this post, we show how to utilize the Bring Your Own Container (BYOC) paradigm to train ML models on GPUs using the increasingly popular JAX library from Google. As a bonus, we serialize our trained model into the TensorFlow SavedModel format so that we can use the existing TensorFlow Serving infrastructure provided by SageMaker.
The scripts and notebooks used in this post are available in our GitHub repository.
Overview of solution
JAX is an increasingly popular deep learning framework that enables composable function transformations of native Python or NumPy functions. You can use the transformations for automatic differentiation, for acceleration, or for both together. Many native Python and NumPy functions are available within the automatic differentiation framework. When a JAX program runs, it is compiled using XLA for execution on GPUs and other accelerators. This means that you can write NumPy programs that are automatically differentiated and accelerated on GPUs, resulting in a more flexible framework to support modern deep learning architectures.
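To make the idea of composable transformations concrete, the following minimal sketch differentiates and compiles a plain NumPy-style loss function. The `loss` function and the toy data are assumptions made for illustration and are not part of the repository for this post.

```python
import jax
import jax.numpy as jnp


def loss(w, x, y):
    """Mean squared error of a linear model, written as ordinary NumPy-style code."""
    pred = jnp.dot(x, w)
    return jnp.mean((pred - y) ** 2)


# Compose two transformations: grad differentiates with respect to the first
# argument (w), and jit compiles the resulting function with XLA.
grad_loss = jax.jit(jax.grad(loss))

# Toy data, assumed for illustration only.
x = jnp.array([[1.0, 2.0], [3.0, 4.0]])
y = jnp.array([1.0, 2.0])
w = jnp.zeros(2)

g = grad_loss(w, x, y)  # same shape as w; equals (2 / n) * x.T @ (x @ w - y)
print(g)
```

The same code runs on CPU or GPU without changes; JAX picks whichever backend is available.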
In this solution, we use a custom container to train three different neural networks on SageMaker. The first is a standard JAX model, the second uses a submodule within JAX called stax, and the third uses a higher-level library called Trax. This is possible on a single container because we use the sagemaker-training-toolkit, which allows you to use script mode within your own custom containers. The custom container can use built-in SageMaker training job features like Spot Training and hyperparameter tuning.
After you train the model, you can deploy your trained models to managed endpoints. As previously mentioned, SageMaker has inference containers that have optimized versions of popular frameworks for AWS hardware. One of these optimizations is for the TensorFlow framework. Because JAX supports model export into TensorFlow SavedModel format, we use that functionality to show how to deploy trained models on optimized SageMaker TensorFlow inference endpoints.
The following walkthrough is also outlined in the Jupyter notebook corresponding to this post. The steps are as follows:
1. Create a Docker image and push it to Amazon Elastic Container Registry (Amazon ECR).
2. Create a custom framework estimator using the SageMaker SDK in order to classify model outputs as a TensorFlowModel.
3. Choose a training script. The repository has scripts to train estimators using three different abstractions, but in this post we use the Trax convolutional neural network example.
4. Train each of the models using SageMaker training jobs on GPUs.
5. Deploy the model to a fully managed endpoint.
After you complete Steps 1 and 2, you can complete Steps 3–5 with just a few lines of code.
Create a custom Docker container
To train models using JAX and SageMaker, we first create a Docker image that contains the necessary Python packages for model training. We do this using a Dockerfile with its content as follows:
The Docker image is built on top of a CUDA-enabled container provided by NVIDIA. To ensure that the jaxlib package, which underlies the functionality in JAX, is CUDA-enabled, we download it from the jax_releases repository. We build and push this image from a SageMaker notebook instance to Amazon ECR. The code to do this is provided in this notebook. SageMaker training jobs can consume a Docker container created through a similar process regardless of the language. Although this example uses Python end to end, you can also submit a training job that uses a custom Docker container from the AWS Command Line Interface (AWS CLI).
Create a custom framework estimator
As a convenience, we create a subclass of the base SageMaker framework estimator to specify the model type of our estimator as a TensorFlow model. To do this, we specify a custom create_model method that uses the existing TensorFlowModel class to launch inference containers. The code snippet is as follows:
Train script modifications to enable deployments to managed endpoints
We utilize the trax.AsKeras method to export our model in the required SavedModel format (see the following code). For JAX and stax, you can call a jax2tf function to perform the same operation. Code is available in the repository. It’s important to set the correct path, /opt/ml/model/1, which is where the SageMaker wrapper assumes the model has been stored.
Train and deploy the model, and perform inference
In the previous sections, we discussed the three primary components for enabling JAX training jobs and deployments using existing SageMaker functionality. After you implement these, you can train the model, deploy it, and run inference through it with a conventional SageMaker Python SDK workflow. We make sure to import and initialize the JaxEstimator that was defined in the code snippet for the custom framework estimator, and then run the standard .fit() and .deploy() calls.
Finally, we query the endpoint to verify the results of the training and deployment.
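The overall flow can be sketched as a small helper that accepts any estimator exposing the standard `.fit()` and `.deploy()` methods. The helper name, the default instance type, and the argument names are hypothetical; only the `fit`, `deploy`, and `predict` calls correspond to the SageMaker Python SDK.

```python
def train_deploy_query(estimator, train_inputs, sample,
                       deploy_instance_type="ml.m5.xlarge"):
    """Run a training job, host the model, and send one request.

    estimator            -- for example, an initialized JaxEstimator
    train_inputs         -- S3 URI (or channel dict) of the training data
    sample               -- payload to send to the endpoint
    deploy_instance_type -- hosting instance type (assumed default)
    """
    # Launch the SageMaker training job and wait for it to finish.
    estimator.fit(train_inputs)

    # Create a managed endpoint backed by the trained model.
    predictor = estimator.deploy(
        initial_instance_count=1,
        instance_type=deploy_instance_type,
    )

    # Query the endpoint to verify training and deployment.
    prediction = predictor.predict(sample)
    return predictor, prediction
```

Returning the predictor alongside the prediction keeps a handle on the endpoint, so that it can be deleted later with `predictor.delete_endpoint()`.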
You can review the status of the training jobs and endpoints on the SageMaker console via the appropriate Region, or acquire the information programmatically using the AWS CLI or other tools.
Clean up
As a last step, we recommend deleting your endpoints if you no longer need them.
Conclusion
In this post, we showed how to integrate JAX with SageMaker by creating a custom framework estimator. We also showed how to train a model using the high-level Trax API to implement neural networks trained on the Fashion MNIST dataset. We took advantage of the fact that these models can be saved into the SavedModel format to deploy them to managed SageMaker TensorFlow endpoints.
As a call to action, we want you to run the notebook here and start building your own JAX-based neural networks today. We encourage you to utilize SageMaker for JAX model training and hosting.
About the Authors
Archis Joglekar is an AI/ML Partner Solutions Architect in the Emerging Technologies team. He is interested in performant, scalable deep learning and scientific computing using the building blocks at AWS. His past experiences range from computational physics research to machine learning platform development in academia, national labs, and startups. His time away from the computer is spent playing soccer and with friends and family.
Sean Morgan is an AI/ML Solutions Architect at AWS. He has experience in the semiconductor and academic research fields, and uses his experience to help customers reach their goals on AWS. In his free time, Sean is an active open source contributor/maintainer and is the special interest group lead for TensorFlow Addons.