AWS Machine Learning Blog

Train and deploy deep learning models using JAX with Amazon SageMaker

Amazon SageMaker is a fully managed service that enables developers and data scientists to quickly and easily build, train, and deploy machine learning (ML) models at any scale. Typically, you can use the pre-built training and inference containers that have been optimized for AWS hardware. Although those containers cover many deep learning workloads, you may have use cases where you want to use a different framework or otherwise customize the OS libraries within the container. To accommodate this, SageMaker provides the flexibility to train models using any framework that can run in a Docker container. This lets you keep using existing SageMaker training capabilities such as training jobs, hyperparameter tuning, and Managed Spot Training.

Although we primarily cover training jobs in this post, it's worth remembering that Managed Spot Training can reduce training costs by up to 90% compared to On-Demand Instances. You enable it by setting the use_spot_instances keyword argument, along with a maximum wait time, on the SageMaker estimator. Similarly, SageMaker hyperparameter tuning removes the undifferentiated heavy lifting of maintaining your own MLOps pipeline for tuning ML model hyperparameters.
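As a minimal sketch of that keyword switch, the helper below assembles the Managed Spot Training arguments that SageMaker Python SDK estimators accept (use_spot_instances, max_run, and max_wait, with times in seconds). The helper name spot_training_kwargs is hypothetical. It only builds a dictionary, which you would then pass to an estimator such as the JaxEstimator described later in this post. SageMaker requires the maximum wait time to be at least the maximum run time, so the helper checks that condition first.

```python
def spot_training_kwargs(max_run_seconds: int, max_wait_seconds: int) -> dict:
    """Build Managed Spot Training keyword arguments for a SageMaker estimator.

    Hypothetical helper: the dictionary keys are SageMaker Python SDK estimator
    parameters, but the function itself is only an illustration.
    """
    if max_run_seconds <= 0:
        raise ValueError("max_run_seconds must be positive")
    # The wait time includes the run time, so it can never be shorter.
    if max_wait_seconds < max_run_seconds:
        raise ValueError("max_wait_seconds must be >= max_run_seconds")
    return {
        "use_spot_instances": True,    # request Spot capacity for the training job
        "max_run": max_run_seconds,    # maximum training run time
        "max_wait": max_wait_seconds,  # run time plus time spent waiting for Spot capacity
    }


# Usage: allow one hour of training and up to two hours in total.
kwargs = spot_training_kwargs(3600, 7200)
print(kwargs)
```

Setting a checkpoint location on the estimator is also worth considering so that an interrupted Spot job can resume. That choice depends on your training script.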

In this post, we show how to use the Bring Your Own Container (BYOC) paradigm to train ML models on GPUs using the increasingly popular JAX library from Google. As a bonus, we serialize our trained model into the TensorFlow SavedModel format so that we can use the existing TensorFlow Serving infrastructure provided by SageMaker.

The scripts and notebooks used in this post are available in our GitHub repository.

Overview of solution

JAX is an increasingly popular deep learning framework that enables composable transformations of native Python and NumPy functions. You can combine these transformations to get both automatic differentiation and acceleration, and many native Python and NumPy functions work within the automatic differentiation framework. When a JAX program runs, it is compiled with XLA so that it can execute on GPUs and other accelerators. As a result, you can write NumPy-style programs that are automatically differentiated and accelerated on GPUs, which makes JAX a flexible framework for modern deep learning architectures.
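The following minimal sketch shows how the transformations compose. It writes a mean squared error loss for a linear model with jax.numpy, differentiates the loss with jax.grad, and compiles the resulting gradient function with jax.jit, which uses XLA. The toy data is made up for illustration. With a CUDA-enabled jaxlib, the same code runs on the GPU without changes.

```python
import jax
import jax.numpy as jnp


def loss(w, x, y):
    """Mean squared error of a linear model, written with NumPy-style operations."""
    pred = jnp.dot(x, w)
    return jnp.mean((pred - y) ** 2)


# Compose two transformations: differentiate with respect to the first
# argument (w), then compile the gradient function with XLA.
grad_loss = jax.jit(jax.grad(loss))

x = jnp.array([[1.0, 2.0], [3.0, 4.0]])
y = jnp.array([1.0, 2.0])
w = jnp.zeros(2)

# Analytically, the gradient is (2 / len(y)) * x.T @ (x @ w - y), here [-7, -10].
g = grad_loss(w, x, y)
print(g)

# One step of plain gradient descent using the compiled gradient.
w_new = w - 0.01 * g
```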

In this solution, we use a custom container to train three different neural networks on SageMaker. The first is a standard JAX model, the second uses stax (a submodule within JAX), and the third uses Trax (a higher-level library). A single container can support all three because we use the sagemaker-training-toolkit, which enables script mode within your own custom containers. The custom container can also use built-in SageMaker training job features such as Managed Spot Training and hyperparameter tuning.
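In script mode, the sagemaker-training toolkit runs your entry-point script inside the container. It passes hyperparameters as command-line arguments (--name value) and exposes paths through environment variables such as SM_MODEL_DIR and SM_CHANNEL_<NAME>. The sketch below shows a typical argument parser for such a script. The hyperparameter names (epochs, learning_rate) and the channel name training are assumptions for illustration. They are not necessarily the arguments used by the repository's scripts.

```python
import argparse
import os


def parse_args(argv=None):
    """Parse hyperparameters and SageMaker paths for a script-mode entry point.

    The hyperparameters are illustrative; SageMaker passes each one as --name value.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--epochs", type=int, default=5)
    parser.add_argument("--learning_rate", type=float, default=1e-3)
    # Paths come from environment variables set by the training toolkit.
    # The defaults are the standard locations inside a SageMaker training container.
    parser.add_argument("--model_dir", type=str,
                        default=os.environ.get("SM_MODEL_DIR", "/opt/ml/model"))
    parser.add_argument("--train_dir", type=str,
                        default=os.environ.get("SM_CHANNEL_TRAINING",
                                               "/opt/ml/input/data/training"))
    return parser.parse_args(argv)


# Usage: inside the container you would call parse_args() with no arguments.
# Here we pass an explicit argument list to show the parsing.
args = parse_args(["--epochs", "3", "--learning_rate", "0.01"])
print(args)
```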

After training, you can deploy your models to managed endpoints. As previously mentioned, SageMaker provides inference containers with versions of popular frameworks that are optimized for AWS hardware, and TensorFlow is one of them. Because JAX models can be exported to the TensorFlow SavedModel format, we use that capability to deploy the trained models on optimized SageMaker TensorFlow inference endpoints.

The following walkthrough is also outlined in the Jupyter notebook corresponding to this post. The steps are as follows:

  1. Create a Docker image and push it to Amazon Elastic Container Registry (Amazon ECR).
  2. Create a custom framework estimator using the SageMaker SDK so that trained model outputs are treated as a TensorFlowModel.
  3. Choose a training script. The repository has scripts that train models using three different abstractions; in this post we use the Trax convolutional neural network example.
  4. Train each of the models using SageMaker training jobs on GPUs.
  5. Deploy the model to a fully managed endpoint.

After you complete Steps 1 and 2, you can complete Steps 3–5 with just a few lines of code.

Create a custom Docker container

To train models using JAX and SageMaker, we first create a Docker image that contains the Python packages needed for model training. We do this with the following Dockerfile:

# Dockerfile for training models using JAX
# We build from NVIDIA container so that CUDA is available for GPU acceleration should the AWS instance support it
FROM nvidia/cuda:11.1-cudnn8-devel-ubuntu18.04

# Install python3
RUN apt-get update && apt-get install -y python3-pip && rm -rf /var/lib/apt/lists/*

RUN ln -sf /usr/bin/python3 /usr/bin/python && \
ln -sf /usr/bin/pip3 /usr/bin/pip

RUN pip --no-cache-dir install --upgrade pip setuptools_rust

# Install ML Packages built with CUDA11 support
RUN ln -s /usr/lib/cuda /usr/local/cuda-11.1
RUN pip --no-cache-dir install --upgrade jax==0.2.6 jaxlib==0.1.57+cuda111 -f https://storage.googleapis.com/jax-releases/jax_releases.html
RUN pip --no-cache-dir install tensorflow==2.3.1 trax==1.3.7
RUN pip --no-cache-dir install sagemaker-training matplotlib

# Setting some environment variables related to logging

The Docker image is built on top of a CUDA-enabled container provided by NVIDIA. To ensure that the jaxlib package underlying JAX is CUDA-enabled, we download it from the jax_releases repository. We build this image on a SageMaker notebook instance and push it to Amazon ECR. The code to do this is provided in the accompanying notebook. SageMaker training jobs can consume a container built this way regardless of the language you use to launch them. Although we use Python end to end in this example, you could also use the AWS Command Line Interface (AWS CLI) to submit a training job that uses this custom Docker container.
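The notebook handles the build and push steps. To show roughly what they involve, the helper below composes the Amazon ECR image URI, which follows the pattern <account>.dkr.ecr.<region>.amazonaws.com/<repository>:<tag>. It also composes the standard AWS CLI and Docker commands for pushing to that URI. The helper name, account ID, Region, and repository name are placeholders. The function only returns the commands as strings and does not run them.

```python
def ecr_push_commands(account_id: str, region: str, repository: str, tag: str = "latest"):
    """Return the ECR image URI and the shell commands to build and push to it.

    Hypothetical helper for illustration; it does not execute anything.
    """
    registry = f"{account_id}.dkr.ecr.{region}.amazonaws.com"
    image_uri = f"{registry}/{repository}:{tag}"
    commands = [
        # Create the repository (returns an error you can ignore if it already exists).
        f"aws ecr create-repository --repository-name {repository} --region {region}",
        # Authenticate the local Docker client against the registry.
        f"aws ecr get-login-password --region {region} | "
        f"docker login --username AWS --password-stdin {registry}",
        # Build from the Dockerfile in the current directory, tag the image, and push it.
        f"docker build -t {repository}:{tag} .",
        f"docker tag {repository}:{tag} {image_uri}",
        f"docker push {image_uri}",
    ]
    return image_uri, commands


image_uri, commands = ecr_push_commands("123456789012", "us-east-1", "sagemaker-jax")
for cmd in commands:
    print(cmd)
```

The resulting image_uri is the value you later pass as the estimator's image_uri.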

Create a custom framework estimator

As a convenience, we create a subclass of the base SageMaker Framework estimator whose model type is a TensorFlow model. To do this, we override the create_model method so that it uses the existing TensorFlowModel class to launch inference containers. The code snippet is as follows:

Custom Framework Estimator for JAX
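from sagemaker.estimator import Framework
from sagemaker.tensorflow.model import TensorFlowModel
from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT

class JaxEstimator(Framework):
    def __init__(
        self, entry_point, source_dir=None, hyperparameters=None, image_uri=None, **kwargs
    ):
        super(JaxEstimator, self).__init__(
            entry_point, source_dir, hyperparameters, image_uri=image_uri, **kwargs
        )

    def create_model(
        self, role=None, vpc_config_override=VPC_CONFIG_DEFAULT,
        entry_point=None, source_dir=None, dependencies=None, **kwargs
    ):
        """Creates ``TensorFlowModel`` object to be used for creating SageMaker model entities"""
        kwargs["name"] = self._get_or_create_name(kwargs.get("name"))

        if "enable_network_isolation" not in kwargs:
            kwargs["enable_network_isolation"] = self.enable_network_isolation()

        # framework_version is assumed to match tensorflow==2.3.1 in the training image
        return TensorFlowModel(
            model_data=self.model_data,
            role=role or self.role,
            framework_version="2.3.1",
            container_log_level=self.container_log_level,
            vpc_config=self.get_vpc_config(vpc_config_override),
            sagemaker_session=self.sagemaker_session,
            entry_point=entry_point,
            source_dir=source_dir,
            dependencies=dependencies,
            **kwargs
        )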

Train script modifications to enable deployments to managed endpoints

We use trax.AsKeras to wrap the trained Trax model as a Keras layer and export it in the required SavedModel format (see the following code). For the JAX and stax models, you can use jax2tf (jax.experimental.jax2tf) to perform the same conversion; that code is available in the repository. It's important to save the model to /opt/ml/model/1. SageMaker packages the contents of /opt/ml/model as the model artifact, and the TensorFlow Serving container expects each SavedModel to be in a numbered version directory.

import tensorflow as tf
import trax


def save_model_tf(model_to_save):
    """
    Serialize a TensorFlow graph from trained Trax Model
    :param model_to_save: Trax Model
    """
    keras_layer = trax.AsKeras(model_to_save, batch_size=1)
    inputs = tf.keras.Input(shape=(28, 28, 1))
    hidden = keras_layer(inputs)

    keras_model = tf.keras.Model(inputs=inputs, outputs=hidden)
    # Version directory "1" under /opt/ml/model, as expected by TensorFlow Serving
    keras_model.save("/opt/ml/model/1", save_format="tf")
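The TensorFlow Serving container loads models from numbered version directories, so a quick local check of the export layout can catch path mistakes before you deploy. The helper below is a hypothetical sketch. It looks for <model_dir>/<integer version>/saved_model.pb, the file that a Keras model.save(..., save_format="tf") call writes into its target directory. It uses only the Python standard library and does not load the model.

```python
from pathlib import Path


def find_savedmodel_versions(model_dir):
    """Return the sorted integer version directories that contain a SavedModel.

    Hypothetical helper: it only checks the directory layout that TensorFlow
    Serving expects (<model_dir>/<version>/saved_model.pb).
    """
    versions = []
    for child in Path(model_dir).iterdir():
        # Only purely numeric directory names count as model versions.
        if child.is_dir() and child.name.isdigit() and (child / "saved_model.pb").is_file():
            versions.append(int(child.name))
    return sorted(versions)


# Usage sketch at the end of a training script:
#   versions = find_savedmodel_versions("/opt/ml/model")
#   assert versions == [1], f"unexpected layout: {versions}"
```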

Train and deploy the model, and perform inference

In the previous sections, we discussed the three primary components needed to enable JAX training jobs and deployments using existing SageMaker functionality. After you implement them, you can train, deploy, and run inference with the model through a conventional SageMaker Python SDK workflow. We import and initialize the JaxEstimator defined in the custom framework estimator snippet, and then run the standard .fit() and .deploy() calls.

Modifying the Estimator instantiation

Creating a JaxEstimator for use with Amazon SageMaker Training Jobs
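The following is a sketch of this step, not the post's exact code. It shows how the JaxEstimator from the custom framework estimator section could be instantiated and trained. The entry-point file name, source directory, instance type, and hyperparameters are placeholder assumptions. The estimator class is passed in as an argument so that the logic can be exercised without the SageMaker SDK or AWS credentials.

```python
def jax_estimator_kwargs(image_uri, role, instance_type="ml.p3.2xlarge"):
    """Keyword arguments for a JaxEstimator training job (placeholder values)."""
    return {
        "entry_point": "train_trax.py",     # hypothetical script name
        "source_dir": "training_scripts",   # hypothetical directory
        "image_uri": image_uri,             # ECR URI of the custom JAX image
        "role": role,                       # IAM role ARN with SageMaker permissions
        "instance_count": 1,
        "instance_type": instance_type,     # a GPU instance type
        "hyperparameters": {"epochs": 5},   # illustrative; passed to the script as --epochs 5
    }


def train(estimator_cls, image_uri, role, inputs=None):
    """Instantiate the estimator class (for example, JaxEstimator) and run fit().

    Calling this for real requires the SageMaker Python SDK and AWS credentials.
    """
    estimator = estimator_cls(**jax_estimator_kwargs(image_uri, role))
    # inputs can map channel names to S3 URIs if the script reads data from S3.
    estimator.fit(inputs)
    return estimator


# Usage in the notebook, after defining JaxEstimator:
#   estimator = train(JaxEstimator, image_uri, sagemaker.get_execution_role())
```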

Deploying the estimator

Deploying the JaxEstimator object
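Deployment uses the standard .deploy() call. Because create_model returns a TensorFlowModel, the endpoint runs the SageMaker TensorFlow Serving container, and the returned predictor is a TensorFlowPredictor. Here is a minimal sketch in which the instance type is a placeholder assumption:

```python
def deploy(estimator, instance_type="ml.m5.xlarge"):
    """Deploy a trained JaxEstimator to a real-time SageMaker endpoint.

    deploy() calls the estimator's create_model(), so the endpoint uses the
    TensorFlow Serving inference container. Requires AWS credentials.
    """
    return estimator.deploy(
        initial_instance_count=1,
        instance_type=instance_type,  # placeholder instance type
    )


# Usage in the notebook:
#   predictor = deploy(estimator)
```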

Finally, we query the endpoint to verify the results of the training and deployment.

Querying the model

Querying the deployed model
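The helpers below sketch the query step. One builds a request body in the TensorFlow Serving REST format ({"instances": [...]}) for 28 x 28 x 1 Fashion MNIST images. The other picks the highest-scoring class from each row of the {"predictions": [...]} response. The helper names are hypothetical. With the TensorFlowPredictor returned by .deploy(), you would pass the request dictionary to predictor.predict(). We send a single image, matching the batch_size=1 used when exporting the Trax model.

```python
def make_request(images):
    """Build a TensorFlow Serving predict request body for 28x28x1 images.

    images: nested lists of shape (N, 28, 28, 1) with float pixel values.
    """
    for img in images:
        if len(img) != 28 or any(len(row) != 28 for row in img):
            raise ValueError("each image must be 28x28")
    return {"instances": images}


def predicted_classes(response):
    """Return the index of the highest score for each row of a TF Serving response."""
    return [max(range(len(scores)), key=scores.__getitem__)
            for scores in response["predictions"]]


# Build a request for one all-zero image. Against a live endpoint you would call:
#   response = predictor.predict(request)
#   print(predicted_classes(response))
blank = [[[0.0] for _ in range(28)] for _ in range(28)]
request = make_request([blank])
```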

You can review the status of the training jobs and endpoints on the SageMaker console in the appropriate Region, or retrieve the information programmatically using the AWS CLI or other tools.
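The same information is available programmatically from the DescribeTrainingJob and DescribeEndpoint APIs. The AWS CLI equivalents are aws sagemaker describe-training-job and aws sagemaker describe-endpoint. The following is a minimal sketch that takes a boto3 SageMaker client. The job and endpoint names are whatever your notebook created.

```python
def job_and_endpoint_status(sm_client, training_job_name, endpoint_name):
    """Return (training job status, endpoint status).

    sm_client is expected to be boto3.client("sagemaker").
    """
    job = sm_client.describe_training_job(TrainingJobName=training_job_name)
    endpoint = sm_client.describe_endpoint(EndpointName=endpoint_name)
    # For example ("Completed", "InService") once training and deployment succeed.
    return job["TrainingJobStatus"], endpoint["EndpointStatus"]


# Usage:
#   import boto3
#   print(job_and_endpoint_status(boto3.client("sagemaker"), job_name, endpoint_name))
```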

Clean up

As a last step, we recommend deleting your endpoints if you no longer need them.

Deleting the deployed JaxEstimator endpoints
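The following is a minimal cleanup sketch that uses the predictor returned by .deploy(). It deletes the SageMaker model first, and then deletes the endpoint together with its endpoint configuration. The model is deleted first because the SDK looks up the model names through the endpoint configuration.

```python
def clean_up(predictor):
    """Delete the SageMaker model, the endpoint, and its endpoint configuration.

    predictor is the object returned by estimator.deploy().
    """
    # Delete the model while the endpoint configuration still exists.
    predictor.delete_model()
    # Then remove the endpoint and the endpoint configuration that backs it.
    predictor.delete_endpoint(delete_endpoint_config=True)
```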

Conclusion

In this post, we showed how to integrate JAX with SageMaker by creating a custom framework estimator. We also showed how to use the high-level Trax API to train a neural network on the Fashion MNIST dataset. Finally, we took advantage of the fact that these models can be saved in the SavedModel format to deploy them to managed SageMaker TensorFlow endpoints.

As a call to action, we encourage you to run the notebook in our GitHub repository and start building your own JAX-based neural networks today, using SageMaker for model training and hosting.

About the Authors

Archis Joglekar is an AI/ML Partner Solutions Architect in the Emerging Technologies team. He is interested in performant, scalable deep learning and scientific computing using the building blocks at AWS. His past experiences range from computational physics research to machine learning platform development in academia, national labs, and startups. His time away from the computer is spent playing soccer and with friends and family.

Sean Morgan is an AI/ML Solutions Architect at AWS. He has experience in the semiconductor and academic research fields, and uses his experience to help customers reach their goals on AWS. In his free time, Sean is an active open source contributor and maintainer, and is the special interest group lead for TensorFlow Addons.