AWS Open Source Blog

How Amazon retail systems run machine learning predictions with Apache Spark using Deep Java Library

Today more and more companies are taking a personalized approach to content and marketing. For example, retailers are personalizing product recommendations and promotions for customers. An important step toward providing personalized recommendations is to identify a customer’s propensity to take action for a certain category. This propensity is based on a customer’s preferences and past behaviors, and it can be used to personalize marketing (e.g., more relevant email campaigns, ads, and website banners).

At Amazon, the retail systems team created a multi-label classification model in MXNet to understand customer action propensity across thousands of product categories, and we use these propensities to create a personalized experience for our customers. In this post, we describe the key challenges we faced while building these propensity models and how we solved them at Amazon scale with Apache Spark and the Deep Java Library (DJL), an open source library to build and deploy deep learning models in Java.


A key challenge was building a production system that can grow to Amazon scale and is easy to maintain. We found that Apache Spark helped us scale within the desired runtime. As the machine learning (ML) framework for building our models, we chose MXNet: it scaled to our data requirement of hundreds of millions of records and gave us better execution time and model accuracy than other available ML frameworks.

Our team consists of a mix of software development engineers and research scientists. Our engineering team wanted to build a production system using Apache Spark in Java/Scala, whereas the scientists preferred Python frameworks. This posed another challenge when choosing between Java- and Python-based systems. We looked for a way for both teams to work together, each in its preferred programming language, and found that we could use DJL with MXNet to solve this problem. Now, scientists build models using the MXNet Python API and share the model artifacts with the engineering team. The engineering team uses DJL to run inference on the provided model with Apache Spark and Scala. Because DJL is framework-agnostic, the engineering team doesn't need to make code changes if the scientists later migrate their model to a different ML framework (e.g., PyTorch or TensorFlow).


To train the classification model, we need two sets of data: features and labels.

Feature data

To build any machine learning model, one of the most important inputs is the feature data. One benefit of using multi-label classification is that we can have a single pipeline to generate feature data. This pipeline captures signals from multiple categories and uses that single dataset to find customer propensity for each category. This reduces operational overhead because we only need to maintain a single multi-label classification model rather than multiple binary classification models.

For our multi-label classification, we generated high-dimensional feature data: hundreds of thousands of features per customer for hundreds of millions of customers. Because each customer has non-zero values for only a small fraction of these features, the data is sparse and can be stored efficiently in a sparse vector representation.


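To make the sparse representation concrete, here is a minimal Scala sketch of a sparse feature vector that stores only the positions and values of non-zero features. `SparseFeatures` and its methods are hypothetical names used for illustration. They are not the team's code or a library API.

```scala
// Hypothetical sparse vector type for illustration: only non-zero entries are stored.
final case class SparseFeatures(size: Int, indices: Array[Int], values: Array[Float]) {
  require(indices.length == values.length, "indices and values must align")

  // Expands back to a dense array (only sensible for small examples).
  def toDense: Array[Float] = {
    val dense = new Array[Float](size)
    indices.zip(values).foreach { case (i, v) => dense(i) = v }
    dense
  }
}

object SparseFeatures {
  // Builds a sparse vector from a dense array by keeping only non-zero entries.
  def fromDense(dense: Array[Float]): SparseFeatures = {
    val nonZero = dense.indices.filter(i => dense(i) != 0f).toArray
    SparseFeatures(dense.length, nonZero, nonZero.map(dense(_)))
  }
}
```

With hundreds of thousands of possible features per customer, storing only the non-zero entries keeps each record small.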

Label data

A propensity model predicts the likelihood of a given customer taking action in a particular category. For each region, we have thousands of categories for which we want to generate customer propensities. Each label has a binary value: 1 if the customer took the required action in a given category, 0 otherwise. These labels of past behavior are used to predict the propensity of a customer taking the same action in a given category in the future. The following is an example of the initial labels represented as a multi-hot (binary) encoding for four categories, where a customer can have a 1 in more than one category:

[Figure: Initial labels for four categories, encoded as binary (multi-hot) vectors]

In this example, customer A only took actions in category 1 and category 3 in the past, whereas customer B only took actions in category 2.
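The label vectors in this example can be reproduced with a small helper. This is a minimal sketch. `LabelEncoding.multiHot` is a hypothetical name, and categories are numbered from 1 as in the example.

```scala
object LabelEncoding {
  // Returns a binary (multi-hot) label vector: position i - 1 is 1 when the customer
  // took the action in category i, and 0 otherwise. Categories are numbered from 1.
  def multiHot(actedCategories: Set[Int], numCategories: Int): Array[Int] =
    Array.tabulate(numCategories)(i => if (actedCategories.contains(i + 1)) 1 else 0)
}

// Customer A acted in categories 1 and 3; customer B only in category 2.
val customerA = LabelEncoding.multiHot(Set(1, 3), numCategories = 4) // Array(1, 0, 1, 0)
val customerB = LabelEncoding.multiHot(Set(2), numCategories = 4)    // Array(0, 1, 0, 0)
```

Because a customer can have a 1 in several positions, this is a multi-hot rather than a one-hot encoding.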


Model architecture

The propensity model, implemented in MXNet using the Python API, is a feed-forward network consisting of a sparse input layer, hidden layers, and N output layers, where N is the number of categories we are interested in. Although each output layer could be represented as a logistic regression output, we chose to implement the network with softmax outputs to allow flexibility in training models with more than two classes. The following is a high-level diagram of a network with four target outputs:

[Figure: High-level diagram of a network with four target outputs]
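A two-node softmax head can stand in for a logistic regression output. With two classes, the softmax probability of the first node reduces to a sigmoid of the difference between the two logits. The logits are written here as $z_0$ and $z_1$, symbols introduced for illustration:

$$
\frac{e^{z_0}}{e^{z_0} + e^{z_1}} = \frac{1}{1 + e^{-(z_0 - z_1)}} = \sigma(z_0 - z_1)
$$

With two nodes per category, each head therefore still yields a single propensity. The same softmax construction extends to more than two classes when needed.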

Below is the pseudocode for the network architecture:

data <- variable of stype 'csr'
weight <- variable of stype 'row_sparse'
bias <- variable whose length equals the number of nodes in the first hidden layer
first_hidden_layer <- broadcast_add(dot(data, weight), bias)
hidden_layers <- subsequent hidden layers, each with activation and dropout
classification_layer <- N FullyConnected layers, each with two nodes
output_layer <- a SoftmaxOutput for each of the N classification layers
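The following Scala sketch mirrors the first hidden layer and one classification head of the pseudocode for a single customer, using plain arrays rather than MXNet. It shows why pairing a CSR input with a row_sparse weight is efficient: only the weight rows for the customer's non-zero features are read. The object and method names are hypothetical. Later hidden layers, activations, and dropout are omitted.

```scala
object ForwardSketch {
  // First hidden layer for one customer: bias + sum over non-zero features of value * weight row.
  // Only rows of `weight` at the non-zero feature indices are touched.
  def firstHidden(indices: Array[Int], values: Array[Float],
                  weight: Array[Array[Float]], bias: Array[Float]): Array[Float] = {
    val out = bias.clone()
    for (n <- indices.indices; j <- out.indices)
      out(j) += values(n) * weight(indices(n))(j)
    out
  }

  // Two-node softmax for one category head. Node 0 is "took the action" and node 1 is
  // its complement, matching the label layout used in training.
  def softmax2(z0: Float, z1: Float): (Float, Float) = {
    val m = math.max(z0, z1) // subtract the max for numerical stability
    val e0 = math.exp(z0 - m)
    val e1 = math.exp(z1 - m)
    ((e0 / (e0 + e1)).toFloat, (e1 / (e0 + e1)).toFloat)
  }
}
```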

Model training

To train the model, we wrote a custom iterator to process the sparse data and convert it to MXNet arrays. In each iteration, we read a batch of data consisting of customerIds, labels, and sparse features. We then constructed a sparse MXNet CSR matrix to encode the features. We specified the non-zero values, the column indices of those values, the index pointers (where each row's entries start), and the shape of the matrix. The following example constructs a sparse MXNet CSR matrix with batch size = 3 and feature size = 5.

[Figure: Constructing a sparse MXNet CSR matrix with batch size = 3 and feature size = 5]
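The following Scala sketch shows how the three CSR arrays (non-zero values, column indices, and index pointers) are derived from a batch of sparse rows. The batch is a hypothetical example with batch size 3 and feature size 5, not the values from the figure. The `Csr` type is an illustration rather than an MXNet or DJL class.

```scala
object CsrSketch {
  // Illustrative container for the CSR components and the (rows, columns) shape.
  final case class Csr(data: Array[Float], indices: Array[Long], indptr: Array[Long], shape: (Int, Int))

  // Each row is a list of (featureIndex, value) pairs holding the non-zero features.
  def toCsr(rows: Seq[Seq[(Int, Float)]], numFeatures: Int): Csr = {
    val data = Array.newBuilder[Float]
    val indices = Array.newBuilder[Long]
    val indptr = Array.newBuilder[Long]
    var nnz = 0L
    indptr += 0L // row 0 starts at offset 0
    for (row <- rows) {
      for ((col, v) <- row.sortBy(_._1)) { data += v; indices += col.toLong; nnz += 1 }
      indptr += nnz // end of this row = start of the next
    }
    Csr(data.result(), indices.result(), indptr.result(), (rows.size, numFeatures))
  }
}

// Hypothetical batch: the second customer has no non-zero features.
val batch = Seq(
  Seq(0 -> 1f, 3 -> 2f),
  Seq.empty[(Int, Float)],
  Seq(1 -> 3f, 2 -> 4f, 4 -> 5f)
)
val csr = CsrSketch.toCsr(batch, numFeatures = 5)
// csr.data holds 1, 2, 3, 4, 5; csr.indices holds 0, 3, 1, 2, 4; csr.indptr holds 0, 2, 2, 5
```

An empty row simply repeats the previous index pointer. This lets customers with no active features sit in the same batch.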

The labels fed into the MXNet module are a list of MXNet NDArrays, with one element per target category: the i-th element holds the training labels of the batch for category i. Each element is a 2-D array with one row per record. The first column is the label for product category i, and the second column is the complement of that label. The following is an example with batch size = 3 and number of categories = 4.

[Figure: Example labels with batch size = 3 and number of categories = 4]
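The following Scala sketch follows the label layout described above. It converts a batch of multi-hot label rows into the per-category list, with one two-column array per category. It uses plain Scala arrays in place of MXNet NDArrays. The names and batch values are hypothetical.

```scala
object LabelSketch {
  // multiHot(r)(i) is 1 if record r took the action in category i, else 0.
  // Returns one (batchSize x 2) array per category: column 0 is the label and
  // column 1 is its complement.
  def perCategoryLabels(multiHot: Array[Array[Int]]): IndexedSeq[Array[Array[Float]]] = {
    val numCategories = multiHot.head.length
    (0 until numCategories).map { i =>
      multiHot.map(row => Array(row(i).toFloat, 1f - row(i)))
    }
  }
}

// Hypothetical batch of 3 records and 4 categories.
val labels = LabelSketch.perCategoryLabels(Array(
  Array(1, 0, 1, 0),
  Array(0, 1, 0, 0),
  Array(0, 0, 0, 1)
))
// labels(0) holds rows (1, 0), (0, 1), (0, 1) for category 1
```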

We then passed the features and labels as an MXNet DataBatch to be used in training. We used the multi-label log-loss metric to train the neural network.
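The post does not spell out the exact multi-label log-loss formula. A common definition, assumed here, averages the binary cross-entropy over every (record, category) pair:

$-\frac{1}{BN}\sum_{r}\sum_{i}\left[y_{ri}\log p_{ri} + (1 - y_{ri})\log(1 - p_{ri})\right]$

In this formula, $B$ is the batch size, $N$ is the number of categories, $y_{ri}$ is the label, and $p_{ri}$ is the predicted propensity. Below is a minimal Scala sketch of that assumed definition, with probabilities clipped to avoid $\log 0$.

```scala
object MultiLabelLogLoss {
  // Mean binary cross-entropy over all (record, category) pairs.
  // labels(r)(i) is 0 or 1; probs(r)(i) is the predicted probability of label 1.
  def apply(labels: Array[Array[Int]], probs: Array[Array[Double]], eps: Double = 1e-15): Double = {
    var total = 0.0
    var count = 0
    for (r <- labels.indices; i <- labels(r).indices) {
      val p = math.min(math.max(probs(r)(i), eps), 1 - eps) // clip to (0, 1)
      total += (if (labels(r)(i) == 1) -math.log(p) else -math.log(1 - p))
      count += 1
    }
    total / count
  }
}
```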

Inference and performance

As mentioned previously, model training was done with the Apache MXNet Python API, while inference was done in Apache Spark with Scala. Because DJL provides Java APIs, it integrates easily into a Scala application.


To add DJL to the project, we included the following dependencies:

dependencies {
    implementation group: 'ai.djl', name: 'repository', version: '0.4.1'
    implementation group: 'ai.djl.mxnet', name: 'mxnet-engine', version: '0.4.1'
    // Native MXNet binaries (MKL build), needed only at run time
    runtimeOnly group: 'ai.djl.mxnet', name: 'mxnet-native-mkl', version: '1.6.0'
}


DJL operates on NDList internally. It provides a Translator interface that converts a custom input type to an NDList and converts the output NDList back to a custom output type. DJL supports sparse data in CSR format and can score a batch of data in a single call.

First, we loaded the model artifacts.

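import java.nio.file.{Path, Paths}
import ai.djl.Model

// Load the model artifacts exported from the MXNet Python training code
val modelDir: Path = Paths.get("/Your/Model/Directory")
val modelName: String = "your_model_name"
val model: Model = Model.newInstance()
model.load(modelDir, modelName)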

We defined a Translator that converts the input feature vectors to an NDList containing CSR data and converts the output predictions from an NDList to Array[Array[Float]].

import java.nio.Buffer
import ai.djl.ndarray.{NDArray, NDList}
import ai.djl.ndarray.types.Shape
import ai.djl.translate.{Translator, TranslatorContext}
import org.apache.spark.ml.linalg.SparseVector // assumed: Spark ML sparse vectors

class InputOutputTranslator
  extends Translator[Array[SparseVector], Array[Array[Float]]] {

  override def processInput(translatorContext: TranslatorContext,
                            input: Array[SparseVector]): NDList = {
    // Convert Array[SparseVector] to a CSR NDArray
    val indices: Array[Long] = ...
    val indptr: Array[Long] = ...
    val data: Buffer = ...
    val shape: Shape = ...
    // Use the translator context's NDManager instead of referencing the outer `model`
    val csrFeatures: NDArray =
      translatorContext.getNDManager.createCSR(data, indptr, indices, shape)
    new NDList(csrFeatures)
  }

  override def processOutput(translatorContext: TranslatorContext,
                             predictionNDList: NDList): Array[Array[Float]] = {
    // Batch prediction for a multi-label classification model: the output type
    // Array[Array[Float]] holds one Array[Float] per record, and each Array[Float]
    // holds the prediction for each label.
    ...
  }
}

The Translator above is used to create a Predictor object, which generates the predictions.

import ai.djl.inference.Predictor

val predictor: Predictor[Array[SparseVector], Array[Array[Float]]] =
  model.newPredictor(new InputOutputTranslator)
val featureVectorArray: Array[SparseVector] = ...
val predictions: Array[Array[Float]] = predictor.predict(featureVectorArray)
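Because the predictor scores a whole array of records at once, a large stream of records can be split into fixed-size batches before calling it. The following Scala sketch shows that batching step. `BatchScoringSketch` is a hypothetical helper, and `predictBatch` stands in for a call such as `predictor.predict`. The article does not describe how the team batched records.

```scala
import scala.reflect.ClassTag

object BatchScoringSketch {
  // Splits a stream of records into fixed-size batches and scores each batch with one
  // call. `predictBatch` is any function returning one prediction array per record.
  def scoreInBatches[A: ClassTag](records: Iterator[A], batchSize: Int)
                                 (predictBatch: Array[A] => Array[Array[Float]]): Iterator[Array[Float]] =
    records.grouped(batchSize).flatMap { batch =>
      val preds = predictBatch(batch.toArray)
      require(preds.length == batch.size, "expected one prediction per record")
      preds.iterator
    }
}
```

In a Spark job, one way to apply this is inside `mapPartitions`, so that each partition creates a single Predictor and reuses it across batches. That wiring is an assumption, not a detail from the article.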

The final output was generated by combining these predictions with the category names and customerId:

[
  {
    "customerId": 1,
    "predictions": {
      "category_a": 0.813611214,
      "category_b": 0.580259696,
      "category_c": 7.5886305E-4,
      "category_d": 0.7010947181
    }
  },
  {
    "customerId": 2,
    "predictions": {
      "category_a": 0.0066125533,
      "category_b": 0.304356237,
      "category_c": 0.908850298,
      "category_d": 2.3412544E-6
    }
  }
]
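The combining step could look like the following Scala sketch, which pairs each prediction row with its customer ID and the category names. It makes two assumptions. First, predictions come back in the same order as the input records. Second, each Array[Float] follows the order of the model's output heads. `CustomerPredictions` and `combine` are hypothetical names.

```scala
object OutputSketch {
  final case class CustomerPredictions(customerId: Long, predictions: Map[String, Float])

  // Pairs each record's prediction array with its customer ID and the category names.
  def combine(customerIds: Array[Long], categoryNames: Array[String],
              predictions: Array[Array[Float]]): Array[CustomerPredictions] = {
    require(customerIds.length == predictions.length, "one prediction row per customer")
    customerIds.zip(predictions).map { case (id, probs) =>
      require(probs.length == categoryNames.length, "one probability per category")
      CustomerPredictions(id, categoryNames.zip(probs).toMap)
    }
  }
}
```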


Before DJL, running predictions with this model on such high-dimensional data took around 24 hours and ran into multiple memory issues. DJL reduced the prediction time for this model by 85%, from around one day to a few hours. DJL worked out of the box, without requiring any time on engineering tasks such as memory tuning. In contrast, before DJL we spent more than two weeks on memory tuning.
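As a rough check on these figures, an 85% reduction from a baseline of about 24 hours leaves

$$
24\ \text{h} \times (1 - 0.85) = 3.6\ \text{h}
$$

of prediction time per run.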

More about DJL

Deep Java Library (DJL) is an open source library to build and deploy deep learning models in Java. The project launched in December 2019 and is widely used among teams at Amazon. It was inspired by other deep learning frameworks but was developed from the ground up to better suit Java development practices. DJL is framework-agnostic, with support for Apache MXNet, PyTorch, TensorFlow 2.x (experimental), and fastText (experimental). DJL also offers a repository of pre-trained models in its ModelZoo that simplifies implementation and streamlines model sharing across projects.

Key advantages of using DJL

Ease of integration and deployment. With DJL, you integrate ML in your applications natively in Java. Because DJL runs in the same JVM process as other Java applications, you don’t need to manage (or pay for) a separate model serving service or container. We have customers who have integrated DJL easily into existing Spark applications written in Scala, eliminating the need to write an additional Scala wrapper on top of a deep learning framework.

Highly performant. DJL offers microsecond-level latency by eliminating the need for gRPC or web service calls. DJL also uses multi-threading during inference to further improve latency and throughput. You can use DJL with Spark for large-scale deep learning applications.

Framework agnostic. DJL provides a unified, Java-friendly API regardless of the framework you use: MXNet, TensorFlow, or PyTorch. True to its Java roots, DJL lets you write your code once and run it with the framework of your choice. You also have the flexibility to access low-level, framework-specific features.

To learn more about DJL, check out the website, GitHub repository, and Slack channel.

Feature image via Pixabay.

Vaibhav Goel

Vaibhav Goel is a Sr. Software Development Engineer on the Amazon Customer Behavior Analytics team. He has a background in large-scale optimization of both low-latency distributed services and big data processing systems. He also has experience applying machine learning to provide a personalized shopping experience to customers.

Raja Hafiz Affandi

Raja Hafiz Affandi is a Research Scientist in Customer Behavior Analytics (CBA) at Amazon, where he focuses on predictive targeting, customer analytics, and segmentation. Before joining CBA, Raja worked in various organizations within Amazon, focusing on predictive modeling and optimization in advertising, as well as time series forecasting. Raja received his PhD from the University of Pennsylvania in probabilistic and statistical machine learning. There he built expertise in modeling diversity using Determinantal Point Processes, with applications in ML and statistics such as diverse recommender systems, sparse clustering, and automated text and video summarization.