AWS Open Source Blog
Leverage deep learning in Scala with GPU on Spark 3.0
This post was contributed by Qing Lan, Carol McDonald, and Kong Zhao.
With the growing interest in deep learning (DL), more and more teams are bringing DL into their production environments. Because DL requires intensive computational power, developers are leveraging GPUs to run their training and inference jobs.
As part of a major Apache Spark initiative to better unify DL and data processing on Spark, GPUs are now a schedulable resource in Apache Spark 3.0. Spark conveys these resource requests to the underlying cluster manager. Because this functionality allows you to run distributed inference at scale, it can help big data pipelines incorporate DL applications.
Before Apache Spark 3.0, using GPUs was difficult. Users had to manually assign GPU devices to a Spark job and hardcode all configurations for every executor/task to leverage different GPUs on a single machine. Because the Apache Hadoop 3.1 YARN cluster manager allows GPU coordination among different machines, Apache Spark can now work alongside it to help pass the device arrangement to different tasks. Users can simply specify the number of GPUs to use and how those GPUs should be shared between tasks. Spark handles the assignment and coordination of the tasks.
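These resource requests boil down to a handful of Spark properties. The following is a minimal sketch in Scala; the values and the discovery script path are illustrative, not taken from this post's setup:

```scala
import org.apache.spark.SparkConf

// Illustrative values: request one GPU per executor and let two tasks share it.
// The discovery script is whatever script lists the GPUs visible on each node;
// its path here is a placeholder.
val conf = new SparkConf()
  .set("spark.executor.resource.gpu.amount", "1")
  .set("spark.task.resource.gpu.amount", "0.5")
  .set("spark.executor.resource.gpu.discoveryScript", "/path/to/getGpusResources.sh")
```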
In this tutorial, we demonstrate how to create a cluster of GPU machines and use Apache Spark with Deep Java Library (DJL) on Amazon EMR to run large-scale image classification in Scala. DJL now provides a GPU-based deep learning Java package that is designed to work smoothly in Spark.
DJL provides a viable solution for users who are interested in Scala/Java or are looking for a way to integrate DL into their Scala-based big data pipelines. DJL aims to make open source deep learning tools accessible to developers and data engineers who primarily use Java/Scala, by offering familiar concepts and intuitive APIs. DJL is built on top of modern DL frameworks (for example, TensorFlow, PyTorch, and Apache MXNet). You can easily use DJL to train your model, or to deploy a model trained in Python, on any of these engines without additional conversion.
Prepare Spark application
Setup
For full setup information, refer to the Gradle project setup. The following section highlights some key components you need to know.
First, we import the Spark dependencies. The Spark SQL and ML libraries are used to store and process the images. The Spark dependencies are only needed at compile time; because they are provided by the cluster at runtime, the jar task excludes them when everything is packaged.
Next, we import the DJL-related dependencies. We use the DJL API and PyTorch packages, which provide the core DJL features and load a DL engine to run inference. We also leverage pytorch-native-cu101 to run on the GPU with CUDA 10.1.
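As a rough sketch, the relevant parts of the Gradle dependency block, covering both the Spark and DJL dependencies described above, look something like the following (version numbers and classifiers are assumptions; check the linked Gradle project setup for the exact coordinates):

```groovy
dependencies {
    // Spark SQL and ML are needed only at compile time; the cluster provides them at runtime
    compileOnly "org.apache.spark:spark-sql_2.12:3.0.1"
    compileOnly "org.apache.spark:spark-mllib_2.12:3.0.1"

    // Core DJL API plus the PyTorch engine
    implementation "ai.djl:api:0.10.0"
    implementation "ai.djl.pytorch:pytorch-engine:0.10.0"

    // Native PyTorch binaries built against CUDA 10.1 so inference can run on the GPU
    runtimeOnly "ai.djl.pytorch:pytorch-native-cu101:1.7.0:linux-x86_64"
}
```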
Load model
To load a model in DJL, we provide a URL (for example, file://, hdfs://, s3://, https://) hosting the model. The model will be downloaded and imported from that URL. DJL also offers a powerful ModelZoo. The ModelZoo allows you to manage pretrained models and load them in a single line. The built-in ModelZoo currently supports more than 70 pretrained and ready-to-use models from GluonCV, HuggingFace, TorchHub, and Keras.
The input type here is a Row in Spark SQL. The output type is a Classifications result. We also defined a Translator named MyTranslator that deals with the preprocessing and post-processing work. The model we load here is a pretrained PyTorch ResNet18 model from torchvision.
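A minimal sketch of that loading step using DJL's Criteria API follows. The model URL is a placeholder for wherever the traced ResNet18 archive is hosted, and MyTranslator is the Translator[Row, Classifications] described above, defined elsewhere in the project:

```scala
import ai.djl.Device
import ai.djl.modality.Classifications
import ai.djl.repository.zoo.{Criteria, ModelZoo, ZooModel}
import ai.djl.training.util.ProgressBar
import org.apache.spark.sql.Row

// Tell DJL what input/output types to expect, where the model lives, and how to
// translate between Spark SQL Rows and the model's tensors.
def loadModel(device: Device): ZooModel[Row, Classifications] = {
  val modelUrl = "https://example.com/models/traced_resnet18.zip" // placeholder URL
  val criteria = Criteria.builder
    .setTypes(classOf[Row], classOf[Classifications])
    .optModelUrls(modelUrl)
    .optTranslator(new MyTranslator())   // preprocessing and post-processing
    .optDevice(device)                   // load the model onto the assigned GPU
    .optProgress(new ProgressBar)
    .build()
  ModelZoo.loadModel(criteria)
}
```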
Main logic
In the main function, we download images (downloadImages) and store them in HDFS. After that, we create a SparkSession and use the built-in Spark image loading mechanism to load all images into Spark SQL. After this step, we use mapPartitions to fetch the GPU information.
As shown in the sketch that follows, TaskContext.resources()("gpu") stores the GPU assigned to this partition's task. We can pass the GPU ID to the model so that it is loaded on that particular GPU, which ensures that all GPUs on a single machine are properly used. To run inference, we call predictor.predict(row).
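Putting it together, a condensed sketch of the main logic might look like the following. It assumes the loadModel helper sketched earlier and that downloadImages has already copied the images to hdfs:///images (a placeholder path):

```scala
import ai.djl.Device
import org.apache.spark.TaskContext
import org.apache.spark.sql.{Encoders, SparkSession}
import org.apache.spark.sql.functions.col

def main(args: Array[String]): Unit = {
  val spark = SparkSession.builder().appName("Image Classification").getOrCreate()

  // Spark's built-in image data source turns each file into a Row with an "image" struct
  val df = spark.read.format("image").option("dropInvalid", true).load("hdfs:///images")

  val result = df.select(col("image.*")).mapPartitions { partition =>
    // Read the GPU address that Spark assigned to this task and load the model on that device
    val gpuId = TaskContext.get().resources()("gpu").addresses(0)
    val predictor = loadModel(Device.gpu(gpuId.toInt)).newPredictor()
    partition.map(row => predictor.predict(row).toString)
  }(Encoders.STRING)

  result.collect().foreach(println)
  spark.stop()
}
```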
Wrap it up
Next, we run ./gradlew jar to bundle everything we need into a single jar, which we can then run on a Spark cluster.
Set up the Spark cluster with GPU
Since the release of Amazon EMR 6.2.0, Spark 3.0 has been available on all of EMR's GPU instance types.
To set up a Spark cluster, create a GPU cluster with three instances using the AWS CLI. To run the command successfully, you'll need to change myKey to your Amazon Elastic Compute Cloud (Amazon EC2) PEM key name. The --region flag can be removed if you have a region preconfigured in your AWS CLI.
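The command has roughly the following shape. This is a hedged sketch: the cluster name and region are placeholders, and depending on your setup you may also need to pass additional configuration to enable YARN's GPU scheduling on the cluster:

```bash
aws emr create-cluster \
  --name "spark-gpu-cluster" \
  --release-label emr-6.2.0 \
  --applications Name=Hadoop Name=Spark \
  --ec2-attributes KeyName=myKey \
  --instance-groups \
      InstanceGroupType=MASTER,InstanceCount=1,InstanceType=g3s.xlarge \
      InstanceGroupType=CORE,InstanceCount=2,InstanceType=g3s.xlarge \
  --use-default-roles \
  --region us-east-1
```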
We use the g3s.xlarge instance type for testing purposes. You can choose from the variety of GPU instances that are available in AWS. The total run time for the cluster setup is around 10 to 15 minutes.
Execute the Spark job
Now, we can run the distributed inference job on Spark. You can choose to do it on the EMR console or from the command line.
The following command tells Spark to run on a YARN cluster and uses a discovery script to find the GPUs on each machine. The GPU amount per task is set to 0.5, which means that two tasks share one GPU. You may also need to set the CPU count per task accordingly so that the two match. For example, if you have an 8-core CPU and you set spark.task.cpus to 2, four tasks can run in parallel on a single machine.
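A hedged sketch of that spark-submit invocation is shown below; the discovery script path, main class, and jar name are placeholders for your own values:

```bash
spark-submit \
  --master yarn \
  --conf spark.executor.resource.gpu.amount=1 \
  --conf spark.task.resource.gpu.amount=0.5 \
  --conf spark.task.cpus=2 \
  --conf spark.executor.resource.gpu.discoveryScript=/path/to/getGpusResources.sh \
  --class com.example.ImageClassificationExample \
  image-classification-example.jar
```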
To achieve the best performance, you can set spark.task.resource.gpu.amount to 0.25, which allows four tasks to share the same GPU. This helps maximize performance because all of the cores in the GPU and CPU are kept busy. Without a balanced setup, some cores sit idle, which wastes resources.
This script takes around 4 to 6 minutes to finish, and the inference results are printed as output.
Summary
In this tutorial, we built the application package from scratch and submitted the job to a GPU cluster to run inference tasks. You can try the same setup for your own application. If you are interested in more of the features that DJL provides, follow our GitHub, demo repository, Slack channel, and Twitter for documentation and examples of DJL.