AWS Machine Learning Blog

Reducing training time with Apache MXNet and Horovod on Amazon SageMaker

Amazon SageMaker is a fully managed service that provides every developer and data scientist with the ability to build, train, and deploy machine learning (ML) models quickly. Amazon SageMaker removes the heavy lifting from each step of the ML process to make it easier to develop high-quality models. As datasets continue to grow, additional compute is required to keep training time in check. One way to scale horizontally and add these resources on Amazon SageMaker is to use Horovod with Apache MXNet. In this post, we show how you can reduce training time with MXNet and Horovod on Amazon SageMaker. We also demonstrate how to further improve performance in advanced sections on Horovod autotuning, Horovod Timeline, Tensor Fusion, and MXNet optimization.

Distributed training

Distributed training of neural networks for computer vision (CV) and natural language processing (NLP) applications has become ubiquitous. With Apache MXNet, you only need to modify a few lines of code to enable distributed training.

Distributed training allows you to reduce training time by scaling horizontally. The goal is to split training tasks into independent subtasks and run these across multiple devices. There are primarily two approaches for training in parallel:

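  • Data parallelism – Each compute resource holds a full copy of the model and processes a different shard of the data. The resulting gradients are then aggregated across all resources.
  • Model parallelism – You split the model across multiple compute resources and pass intermediate (transformed) data between them.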

In this post, we focus on data parallelism. Specifically, we discuss how Horovod and MXNet allow you to train efficiently on Amazon SageMaker.
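To see why data parallelism works, consider a toy linear model trained with mean squared error. Suppose each worker computes the gradient on an equal-size shard of the batch. Averaging those per-worker gradients, which is what an allreduce does, then gives exactly the full-batch gradient. The following pure-Python sketch illustrates this. The model, data, and helper names are hypothetical and chosen only for illustration.

```python
def mse_gradient(w, xs, ys):
    """Gradient of mean((w * x - y) ** 2) with respect to w for a 1-D linear model."""
    return sum(2 * x * (w * x - y) for x, y in zip(xs, ys)) / len(xs)


def shard(seq, num_parts, part_index):
    """Contiguous shard owned by one worker; assumes len(seq) divides evenly by num_parts."""
    size = len(seq) // num_parts
    return seq[part_index * size:(part_index + 1) * size]


def data_parallel_gradient(w, xs, ys, num_workers):
    """Each simulated worker computes a local gradient on its shard; the results are averaged."""
    local = [
        mse_gradient(w, shard(xs, num_workers, rank), shard(ys, num_workers, rank))
        for rank in range(num_workers)
    ]
    return sum(local) / num_workers


xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
ys = [3.0 * x for x in xs]  # Targets generated by the "true" weight w = 3.
print(mse_gradient(1.0, xs, ys))               # -102.0 (full batch on one device)
print(data_parallel_gradient(1.0, xs, ys, 4))  # -102.0 (same value from 4 workers)
```

The equality relies on equal shard sizes. With uneven shards, a plain average of per-worker means weights the samples differently from the full-batch mean.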

Horovod overview

Horovod is an open-source distributed deep learning framework. It uses efficient inter-GPU and inter-node communication libraries, such as the NVIDIA Collective Communications Library (NCCL) and Message Passing Interface (MPI), to distribute and aggregate model parameters between workers. Horovod makes distributed deep learning fast and easy: you take a single-GPU training script and scale it across many GPUs in parallel. Horovod is built on the ring-allreduce communication algorithm. In ring-allreduce, each training process (for example, a process running on a single GPU) communicates only with its neighbors in a logical ring. Each process splits its gradients into chunks. In the first phase, the processes pass chunks around the ring and sum (reduce) them. In the second phase, they pass the fully reduced chunks around so that every process ends up with the same averaged gradients, without a central parameter server. The following diagram illustrates how ring-allreduce works.
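The following pure-Python simulation sketches the two phases of ring-allreduce: scatter-reduce, then allgather. It illustrates the algorithm under simplifying assumptions: all simulated workers run in lockstep and hold equal-length gradient lists. It is not Horovod's implementation, which performs these transfers with NCCL or MPI.

```python
def ring_allreduce(worker_grads):
    """Average equal-length gradient lists across simulated workers with ring-allreduce."""
    n = len(worker_grads)
    length = len(worker_grads[0])
    if any(len(g) != length for g in worker_grads):
        raise ValueError("all workers must hold gradients of the same length")
    # Every worker splits its buffer into n chunks using the same boundaries.
    bounds = [(k * length // n, (k + 1) * length // n) for k in range(n)]
    bufs = [list(g) for g in worker_grads]

    def send_round(chunk_for_worker, combine):
        # Snapshot all outgoing chunks first so the round behaves like simultaneous sends.
        outgoing = []
        for i in range(n):
            c = chunk_for_worker(i)
            lo, hi = bounds[c]
            outgoing.append((c, bufs[i][lo:hi]))
        for i, (c, data) in enumerate(outgoing):
            receiver = (i + 1) % n  # Each worker only sends to its right-hand neighbor.
            lo, _ = bounds[c]
            for k, value in enumerate(data):
                bufs[receiver][lo + k] = combine(bufs[receiver][lo + k], value)

    # Phase 1, scatter-reduce: after n - 1 rounds, worker i holds the full sum of chunk (i + 1) % n.
    for step in range(n - 1):
        send_round(lambda i: (i - step) % n, lambda old, new: old + new)
    # Phase 2, allgather: the completed chunks circulate until every worker has all of them.
    for step in range(n - 1):
        send_round(lambda i: (i + 1 - step) % n, lambda old, new: new)
    # Divide the sums by the number of workers to turn them into averages.
    return [[value / n for value in buf] for buf in bufs]


grads = [[1.0, 2.0, 3.0], [3.0, 4.0, 5.0], [5.0, 6.0, 7.0]]
print(ring_allreduce(grads))  # Every worker ends with [3.0, 4.0, 5.0].
```

Each worker sends and receives only one chunk per round, over 2 * (n - 1) rounds. This is why the per-worker bandwidth cost of ring-allreduce stays roughly constant as workers are added.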

Fig. 1 The ring-allreduce algorithm allows worker nodes to average gradients and disperse them to all nodes without the need for a parameter server (source)

Apache MXNet is integrated with Horovod through the distributed training APIs defined in Horovod. You can convert a non-distributed training script by following the high-level code skeleton that we show later in this post.

Although this greatly simplifies the process of using Horovod, you must consider other complexities. For example, you may need to install additional software and libraries, and resolve incompatibilities between them, before distributed training works. Horovod requires a compatible version of Open MPI, and if you want high-performance training on NVIDIA GPUs, you need to install the NCCL libraries. These complexities grow when you scale across multiple instances, because you need to make sure all the software and libraries on the new nodes are properly installed and configured. Amazon SageMaker includes all the libraries required to run distributed training with MXNet and Horovod. Prebuilt Amazon SageMaker Docker images come with popular open-source deep learning frameworks and preconfigured CUDA, cuDNN, MPI, and NCCL libraries. Amazon SageMaker also handles the difficult process of properly setting up and configuring your cluster.

Test problem and dataset

To benchmark the efficiencies realized by Horovod, we trained two notoriously resource-intensive model architectures: Mask-RCNN and Faster-RCNN. These architectures were first introduced in 2017 and 2015, respectively. They are currently considered the baseline architectures for two popular CV tasks: instance segmentation (Mask-RCNN) and object detection (Faster-RCNN). Mask-RCNN builds on Faster-RCNN by adding a branch that predicts a segmentation mask for each detected object. Apache MXNet provides pre-built Mask-RCNN and Faster-RCNN models as part of the GluonCV model zoo, which simplifies training these models.

To train our object detection and instance segmentation models, we used the popular COCO2017 dataset. This dataset provides more than 200,000 images and their corresponding labels. The COCO2017 dataset is considered an industry standard for benchmarking CV models.

GluonCV is a CV toolkit built on top of MXNet. It provides out-of-the-box support for various CV tasks, including data loading and preprocessing for many common algorithms available within its model zoo. It also provides a tutorial on getting the COCO2017 dataset.

To make this process replicable for Amazon SageMaker users, we show the entire end-to-end process for training Mask-RCNN and Faster-RCNN with Horovod and MXNet. To begin, open the Jupyter environment on your Amazon SageMaker notebook instance and select the conda_mxnet_p36 kernel. Next, install the required Python packages:

! pip install gluoncv
! pip install pycocotools

We use the GluonCV toolkit to download the COCO2017 dataset onto our Amazon SageMaker notebook:

import gluoncv as gcv
gcv.utils.download('https://gluon-cv.mxnet.io/_downloads/b6ade342998e03f5eaa0f129ad5eee80/mscoco.py',path='./')
# Download and extract COCO2017 into ./data. Warning: this may take a while.
! python mscoco.py --download-dir data

We upload COCO2017 to the specified Amazon Simple Storage Service (Amazon S3) bucket using the following command:

! aws s3 cp './data/' s3://<INSERT BUCKET NAME>/data/ --recursive --quiet

Training script with Horovod support

To use Horovod in your training script, you only need to make a few modifications. For code samples and instructions, see Horovod with MXNet. In addition, many GluonCV models in the model zoo have scripts that already support Horovod out of the box. In this section, we review the key changes required for Horovod to work correctly on Amazon SageMaker with Apache MXNet. The following code is adapted from the Horovod documentation, with a data-sharding example taken from GluonCV's train_mask_rcnn.py script:

import mxnet as mx
import gluoncv as gcv
import horovod.mxnet as hvd
from mxnet import autograd

# Initialize Horovod. This must be done first because it activates Horovod.
hvd.init()

# GPU setup: local_rank is the index of the GPU this process uses on its instance.
ctx = [mx.gpu(hvd.local_rank())]
num_gpus = hvd.size()  # Total number of processes (one per GPU) across all instances.

# Typically, you shard your dataset in the data loader, as in GluonCV's
# train_mask_rcnn.py script (args comes from that script's argument parser):
train_sampler = \
        gcv.nn.sampler.SplitSortedBucketSampler(...,
                                                num_parts=hvd.size() if args.horovod else 1,
                                                part_index=hvd.rank() if args.horovod else 0)

# Normally, you shard the dataset first for Horovod, as above.
# "..." stands for your other DataLoader arguments.
val_loader = mx.gluon.data.DataLoader(dataset, len(ctx), ...)

# Build and initialize your model as usual.
model = ...

# Fetch the parameters and broadcast them from rank 0 so all workers start identically.
params = model.collect_params()
if params is not None:
    hvd.broadcast_parameters(params, root_rank=0)

# Create your optimizer as usual (for example, with mx.optimizer.create).
opt = ...

# Create DistributedTrainer, a subclass of gluon.Trainer that averages gradients across workers.
trainer = hvd.DistributedTrainer(params, opt)

# Create your loss function and train your model as usual.
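The skeleton relies on three Horovod values:

  • hvd.size() – the total number of processes
  • hvd.rank() – this process's global index
  • hvd.local_rank() – its index on the current instance, used to pick the GPU

The pure-Python sketch below shows how these values typically relate and how num_parts and part_index can split a dataset into disjoint shards. It assumes MPI places ranks on hosts in order with one process per GPU, and it uses a simple strided split. Both helper functions are hypothetical; this is not how GluonCV's SplitSortedBucketSampler is implemented.

```python
def rank_layout(num_hosts, processes_per_host):
    """(rank, host, local_rank) for every process, assuming MPI fills each host's slots in order."""
    return [
        (host * processes_per_host + local_rank, host, local_rank)
        for host in range(num_hosts)
        for local_rank in range(processes_per_host)
    ]


def shard_indices(num_samples, num_parts, part_index):
    """Indices of the samples owned by one worker, using a simple strided split."""
    if not 0 <= part_index < num_parts:
        raise ValueError("part_index must be in [0, num_parts)")
    return list(range(part_index, num_samples, num_parts))


layout = rank_layout(num_hosts=2, processes_per_host=8)
size = len(layout)  # Plays the role of hvd.size(): 16 processes in total.
print(layout[9])    # (9, 1, 1): global rank 9 is local rank 1 (GPU 1) on the second host.
print(shard_indices(num_samples=40, num_parts=size, part_index=9))  # [9, 25]
```

Because every rank receives a disjoint slice and the slices together cover the dataset, each epoch processes every sample exactly once across the cluster.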

Training job configuration

The Amazon SageMaker MXNet estimator class supports Horovod through the distributions parameter. To enable it, add a predefined mpi entry with the enabled flag set to True, and define the following additional parameters:

  • processes_per_host (int) – Number of processes MPI should launch on each host. This parameter is usually equal to the number of GPU devices available on any given instance.
  • custom_mpi_options (str) – Any custom mpirun flags passed in this field are added to the mpirun command and run by Amazon SageMaker for Horovod training.

The following example code initializes the distributions parameter:

distributions = {
    'mpi': {
        'enabled': True,
        'processes_per_host': 8,  # Each instance has 8 GPUs.
        'custom_mpi_options': '-verbose --NCCL_DEBUG=INFO'
    }
}
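processes_per_host usually matches the GPU count, and the total number of Horovod processes (what hvd.size() reports) is the instance count times processes_per_host. A small helper can therefore derive both values from the instance type. The sketch below is a hypothetical helper, not part of the SageMaker SDK. GPUS_PER_INSTANCE covers only P3 instance types, so check it against the current Amazon EC2 documentation before relying on it.

```python
# GPU counts per instance type (P3 family only; verify against the EC2 docs before use).
GPUS_PER_INSTANCE = {
    "ml.p3.2xlarge": 1,
    "ml.p3.8xlarge": 4,
    "ml.p3.16xlarge": 8,
    "ml.p3dn.24xlarge": 8,
}


def build_distributions(instance_type, custom_mpi_options="-verbose --NCCL_DEBUG=INFO"):
    """Hypothetical helper: an MPI 'distributions' dict with one process per GPU."""
    return {
        "mpi": {
            "enabled": True,
            "processes_per_host": GPUS_PER_INSTANCE[instance_type],
            "custom_mpi_options": custom_mpi_options,
        }
    }


def horovod_world_size(instance_type, instance_count):
    """Total number of Horovod processes in the job, i.e. the value hvd.size() returns."""
    return GPUS_PER_INSTANCE[instance_type] * instance_count


distributions = build_distributions("ml.p3dn.24xlarge")
print(distributions["mpi"]["processes_per_host"])  # 8
print(horovod_world_size("ml.p3dn.24xlarge", 2))   # 16
```

Deriving processes_per_host from the instance type helps avoid a mismatch that would leave GPUs idle or oversubscribe them.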

Next, we need to configure other parameters of our training job, such as hyperparameters, and the input and output Amazon S3 locations. To do this, we use the MXNet estimator class from the Amazon SageMaker Python SDK:

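import sagemaker
from sagemaker.mxnet import MXNet

# Define the basic configuration of your Horovod-enabled SageMaker training
# cluster.
role = sagemaker.get_execution_role()  # IAM role of your notebook instance.
num_instances = 2  # How many instances (nodes) you want to use.
instance_type = 'ml.p3dn.24xlarge'  # Which instance type you want to use.
hyperparameters = {}  # Fill in the hyperparameters your training script expects.

# Parameter names follow SageMaker Python SDK v1.x; SDK v2 renames them
# (for example, instance_type, instance_count, volume_size, distribution).
estimator = MXNet(
                entry_point='<source_name>.py',       # Script entry point.
                source_dir='./source',                # Script location.
                role=role,
                train_instance_type=instance_type,
                train_instance_count=num_instances,
                framework_version='1.6.0',            # MXNet version.
                train_volume_size=100,                # EBS volume size in GB; must fit the dataset.
                py_version='py3',                     # Python version.
                hyperparameters=hyperparameters,
                distributions=distributions)          # For use with Horovod.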

We’re now ready to start our first Horovod-powered training job with the following command:

bucket_name = '<INSERT BUCKET NAME>'  # The bucket you uploaded COCO2017 to.
estimator.fit({'data': 's3://' + bucket_name + '/data'})
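Inside the training container, Amazon SageMaker downloads each input channel to local disk. For the data channel above, the path is exposed through the SM_CHANNEL_DATA environment variable and corresponds to /opt/ml/input/data/data. The sketch below shows one way a training script could resolve that path. The fallback path follows SageMaker's documented /opt/ml/input/data/<channel> layout; the function name is hypothetical.

```python
import os


def channel_dir(name, env=None):
    """Local directory for a SageMaker input channel, e.g. 'data' -> $SM_CHANNEL_DATA."""
    env = os.environ if env is None else env
    default = os.path.join("/opt/ml/input/data", name)
    return env.get("SM_CHANNEL_" + name.upper(), default)


data_dir = channel_dir("data")  # Inside the job: the local copy of s3://<bucket>/data.
print(data_dir)
```

Reading the path from the environment, rather than hardcoding it, keeps the same script usable outside SageMaker: you set the variable yourself when running locally.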

Results

We performed these benchmarks on two similar GPU instance types: the p3.16xlarge and the more powerful p3dn.24xlarge. Both have 8 NVIDIA V100 GPUs, but the p3dn.24xlarge is designed with distributed training in mind:

  • Networking – a 100 Gbps network interface (compared to 25 Gbps on the p3.16xlarge), which speeds up the inter-node data transfers inherent in distributed training
  • GPU memory – 32 GB per GPU instead of 16 GB
  • Compute and system memory – 96 vCPUs and 768 GiB of system memory, versus 64 vCPUs and 488 GiB

We ran benchmarks in three different use cases. In the first use case, we trained the models on a single instance using all 8 local GPUs without Horovod, relying on the default parameter-server approach. In the second use case, we trained on the same single-instance setup with Horovod, to demonstrate the efficiencies gained by using Horovod to manage local training across multiple GPUs. In the third use case, we used Horovod for distributed training across three instances, each with 8 local GPUs, to demonstrate the additional efficiency gained by scaling horizontally.

The following table summarizes the time and accuracy for each training scenario.

Model        Instance type   | 1 instance, 8 GPUs,        | 1 instance, 8 GPUs,        | 3 instances, 8 GPUs each,
                             | without Horovod            | with Horovod               | with Horovod
                             | Time       mAP             | Time       mAP             | Time       mAP
Faster RCNN  p3.16xlarge     | 35 h 47 m  37.6            | 8 h 26 m   37.5            | 4 h 58 m   37.4
Faster RCNN  p3dn.24xlarge   | 32 h 24 m  37.5            | 7 h 27 m   37.5            | 3 h 37 m   37.3
Mask RCNN    p3.16xlarge     | 45 h 28 m  38.5 / 34.8     | 10 h 28 m  34.4 / 31.3     | 5 h 34 m   36.8 / 33.5
Mask RCNN    p3dn.24xlarge   | 40 h 49 m  38.3 / 34.8     | 8 h 41 m   34.6 / 31.5     | 4 h 2 m    37.0 / 33.4

Table 1: Training time and accuracy (COCO mAP) for the three training scenarios. For Mask RCNN, accuracy is given as bbox mAP / segm mAP.
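You can compute the relative speedups directly from the training times in Table 1. The short script below converts each duration to minutes. It then divides the baseline (one instance without Horovod) by each Horovod time. The times are copied from Table 1; the helper names are our own.

```python
def to_minutes(duration):
    """Convert a Table 1 duration such as '35 h 47 m' into minutes."""
    hours, h_unit, minutes, m_unit = duration.split()
    if (h_unit, m_unit) != ("h", "m"):
        raise ValueError(f"unexpected duration format: {duration!r}")
    return int(hours) * 60 + int(minutes)


# Training times from Table 1: (w/o Horovod, 1 instance with Horovod, 3 instances with Horovod).
TABLE_1_TIMES = {
    ("Faster RCNN", "p3.16xlarge"): ("35 h 47 m", "8 h 26 m", "4 h 58 m"),
    ("Faster RCNN", "p3dn.24xlarge"): ("32 h 24 m", "7 h 27 m", "3 h 37 m"),
    ("Mask RCNN", "p3.16xlarge"): ("45 h 28 m", "10 h 28 m", "5 h 34 m"),
    ("Mask RCNN", "p3dn.24xlarge"): ("40 h 49 m", "8 h 41 m", "4 h 2 m"),
}

for (model, instance), times in TABLE_1_TIMES.items():
    baseline, one_instance, three_instances = (to_minutes(t) for t in times)
    print(f"{model:12} {instance:14} "
          f"1 instance: {baseline / one_instance:.2f}x  "
          f"3 instances: {baseline / three_instances:.2f}x")
```

With these numbers, the script reports speedups of roughly 4.2x to 4.7x for a single instance with Horovod, and roughly 7.2x to 10.1x for three instances.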

As expected, using Horovod to distribute training across multiple instances significantly reduces the time to convergence. For example, Faster RCNN on p3dn.24xlarge trains in 3 h 37 m on three instances, compared to 32 h 24 m on a single instance without Horovod. Horovod also substantially increases training efficiency on a single instance with multiple local GPUs, compared to the default parameter-server approach. Horovod's simplified APIs and abstractions let you unlock these gains when training across multiple GPUs, whether on a single machine or many. For more information about scaling batch size and learning rate as you add workers, see Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour.
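The linear scaling rule from that paper multiplies the learning rate by the same factor as the global batch size. It then reaches the scaled value through a gradual linear warmup. The sketch below computes both. The base learning rate, batch sizes, and warmup length in the usage lines are hypothetical values for illustration, not the settings used for Table 1.

```python
def scaled_learning_rate(base_lr, base_batch_size, global_batch_size):
    """Linear scaling rule: grow the learning rate in proportion to the global batch size."""
    return base_lr * global_batch_size / base_batch_size


def warmup_learning_rate(step, warmup_steps, start_lr, target_lr):
    """Increase the learning rate linearly from start_lr to target_lr over warmup_steps."""
    if step >= warmup_steps:
        return target_lr
    return start_lr + (target_lr - start_lr) * step / warmup_steps


# Hypothetical setup: 3 instances x 8 GPUs, 2 images per GPU, base_lr tuned for a batch of 16.
global_batch = 3 * 8 * 2
target = scaled_learning_rate(base_lr=0.02, base_batch_size=16, global_batch_size=global_batch)
for step in (0, 250, 500, 1000):
    print(step, warmup_learning_rate(step, warmup_steps=1000, start_lr=0.02, target_lr=target))
```

In a Horovod script, the global batch size is typically the per-GPU batch size times hvd.size().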

With the improvement in training time enabled by Horovod and Amazon SageMaker, you can focus more on improving your algorithms instead of waiting for jobs to finish training. You can train in parallel across multiple instances with a limited impact on mean Average Precision (mAP). In our runs, Faster RCNN lost at most 0.2 mAP points, and Mask RCNN on three instances lost 1.3 to 1.7 points.

Optimizing Horovod training

Horovod provides several additional utilities that allow you to analyze and optimize training performance.

Horovod autotuning

Finding the optimal values of Horovod's performance parameters (such as the Tensor Fusion threshold and cycle time) for a given combination of model and cluster size may require several iterations of trial and error.

The autotune feature automates this trial-and-error process within a single training job. It uses Bayesian optimization to search the parameter space for the best-performing combination of parameters. Horovod runs this search during the first steps of a training job. Once it identifies the best combination, Horovod writes it to the autotune log and uses it for the remainder of the training job. For more information, see Autotune: Automated Performance Tuning.

To enable autotuning and capture the search log, pass the following parameters in your MPI configuration:

{
    'mpi':
    {
        'enabled': True,
        'custom_mpi_options': '-x HOROVOD_AUTOTUNE=1 -x HOROVOD_AUTOTUNE_LOG=/opt/ml/output/autotune_log.csv'
    }
}

Horovod Timeline

Horovod Timeline is a trace of Horovod's activity during training. It records when each tensor goes through stages such as negotiation and allreduce across the workers. This helps you see which operations take the longest and identify optimization opportunities. You can view the trace in a browser-based trace viewer such as chrome://tracing. For more information, see Analyze Performance.

To generate a timeline file, add the following option to your MPI configuration:

{
    'mpi':
    {
        'enabled': True,
        'custom_mpi_options': '-x HOROVOD_TIMELINE=/opt/ml/output/timeline.json'
    }
}
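Besides opening the timeline in a trace viewer, you can summarize it programmatically. The sketch below makes these assumptions:

  • The file contains a JSON array of events in the Chrome Trace Event Format, with nested begin ("ph": "B") and end ("ph": "E") events per pid and tid.
  • The closing bracket may be missing, which can happen when a trace is still being written.

Check these assumptions against your own timeline file. The helper names are hypothetical.

```python
import json
from collections import defaultdict


def load_trace(text):
    """Parse a JSON array of trace events, tolerating a trailing comma or a missing ']'."""
    body = text.strip()
    if body.endswith("]"):
        body = body[:-1]
    body = body.rstrip().rstrip(",")
    return json.loads(body + "]")


def time_by_activity(events):
    """Total duration per activity name, pairing nested B/E events per (pid, tid) with a stack."""
    open_events = defaultdict(list)
    totals = defaultdict(float)
    for event in events:
        key = (event.get("pid"), event.get("tid"))
        if event.get("ph") == "B":
            open_events[key].append((event.get("name", "?"), event["ts"]))
        elif event.get("ph") == "E" and open_events[key]:
            name, start = open_events[key].pop()
            totals[name] += event["ts"] - start
    # Longest-running activities first; trace timestamps are typically in microseconds.
    return sorted(totals.items(), key=lambda item: item[1], reverse=True)


# Usage, after downloading and extracting the job output:
# with open("timeline.json") as f:
#     for name, total in time_by_activity(load_trace(f.read()))[:10]:
#         print(f"{name:40} {total:>12.0f}")
```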

/opt/ml/output is a directory with a special purpose in Amazon SageMaker. After the training job completes, Amazon SageMaker automatically archives the files in this directory. It then uploads the archive to the Amazon S3 output location that you define with the Amazon SageMaker Python SDK (for example, through the estimator's output_path parameter).

Tensor Fusion

Tensor Fusion combines many small allreduce operations into fewer, larger ones at training time by packing tensors into a shared fusion buffer. This reduces communication overhead and typically improves overall performance. For more information, see Tensor Fusion. Tensor Fusion is enabled by default with a 64 MB buffer. You can change the buffer size with a custom MPI flag that sets HOROVOD_FUSION_THRESHOLD in bytes. For our use case, we override the default 64 MB buffer with 32 MB:

{
    'mpi':
    {
        'enabled': True,
        'custom_mpi_options': '-x HOROVOD_FUSION_THRESHOLD=33554432'
    }
}
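The threshold is given in bytes: 32 MB is 32 * 1024 * 1024 = 33,554,432 bytes, the value passed above. The sketch below also shows, in a deliberately simplified form, why fusion helps. Tensors that are ready together are packed into buffers no larger than the threshold, so several small allreduce operations become one larger one. This greedy packing is an illustration only. Horovod's actual fusion logic also considers data types and which tensors are ready in each cycle.

```python
def fusion_threshold_bytes(megabytes):
    """Convert a buffer size in MB (MiB) into the byte count HOROVOD_FUSION_THRESHOLD expects."""
    return megabytes * 1024 * 1024


def fuse(tensor_sizes, threshold):
    """Greedily group tensor sizes (in bytes), in order, into batches that fit the threshold.

    A tensor larger than the threshold gets a batch of its own.
    """
    batches, current, used = [], [], 0
    for size in tensor_sizes:
        if current and used + size > threshold:
            batches.append(current)
            current, used = [], 0
        current.append(size)
        used += size
    if current:
        batches.append(current)
    return batches


threshold = fusion_threshold_bytes(32)
print(threshold)  # 33554432
# Hypothetical gradient sizes: ten 4 MB tensors plus one 40 MB tensor.
sizes = [4 * 1024 * 1024] * 10 + [40 * 1024 * 1024]
print(len(sizes), "tensors ->", len(fuse(sizes, threshold)), "allreduce operations")  # 11 -> 3
```

A larger threshold means fewer, bigger transfers but more memory for the buffer. This is why the best value depends on the model and cluster and is worth tuning or autotuning.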

You can also adjust how often Horovod runs fusion cycles with the HOROVOD_CYCLE_TIME parameter, which sets the time between cycles in milliseconds. See the following code:

{
    'mpi':
    {
        'enabled': True,
        'custom_mpi_options': '-x HOROVOD_CYCLE_TIME=5'
    }
}
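Each example above sets custom_mpi_options to a single option. Because the field is one string, combining features means joining the flags yourself. The hypothetical helper below builds that string from a dictionary of Horovod environment variables, using Open MPI's standard mpirun syntax (-x VAR=value) for exporting a variable to every rank. It assumes the options string is split on whitespace, so it rejects values containing spaces.

```python
def mpi_env_options(env_vars, extra_flags=()):
    """Build a custom_mpi_options string that exports each variable with mpirun's -x flag."""
    parts = list(extra_flags)
    for name, value in env_vars.items():
        parts.extend(["-x", f"{name}={value}"])
    if any(not part or any(ch.isspace() for ch in part) for part in parts):
        raise ValueError("options are split on whitespace, so parts must not contain spaces")
    return " ".join(parts)


options = mpi_env_options(
    {
        "HOROVOD_FUSION_THRESHOLD": 33554432,
        "HOROVOD_CYCLE_TIME": 5,
        "HOROVOD_TIMELINE": "/opt/ml/output/timeline.json",
    },
    extra_flags=["-verbose"],
)
distributions = {"mpi": {"enabled": True, "processes_per_host": 8, "custom_mpi_options": options}}
print(options)
```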

Optimizing MXNet models

Another optimization technique is to tune MXNet itself. We recommend running your training code with os.environ['MXNET_CUDNN_AUTOTUNE_DEFAULT'] = '1', which lets cuDNN select the fastest convolution algorithms. You can then experiment with other MXNet environment variables and reuse the best-performing settings in future training jobs. In our testing, the following settings gave the best results:

import os

# Setting these before importing mxnet is the safest way to make sure the runtime picks them up.
os.environ['MXNET_GPU_MEM_POOL_TYPE'] = 'Round'
os.environ['MXNET_GPU_MEM_POOL_ROUND_LINEAR_CUTOFF'] = '26'
os.environ['MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN_FWD'] = '999'
os.environ['MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN_BWD'] = '25'
os.environ['MXNET_GPU_COPY_NTHREADS'] = '1'
os.environ['MXNET_OPTIMIZER_AGGREGATION_SIZE'] = '54'
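MXNET_GPU_MEM_POOL_TYPE='Round' makes the GPU memory pool round allocation requests up, so that freed blocks are more likely to be reused. MXNet's environment variable documentation describes the rounding rule in terms of the cutoff T = MXNET_GPU_MEM_POOL_ROUND_LINEAR_CUTOFF:

  • Requests below 2**T bytes are rounded up to the next power of two.
  • Larger requests are rounded up to a multiple of 2**T.

The sketch below models that rule so you can see what a cutoff of 26 (64 MB) means in practice. It is a simplified model, not MXNet's allocator code.

```python
def round_pool_request(nbytes, linear_cutoff=26):
    """Simplified model of the 'Round' pool's size rounding (not MXNet's actual code)."""
    if nbytes <= 0:
        raise ValueError("nbytes must be positive")
    cutoff = 1 << linear_cutoff  # 2**26 bytes = 64 MB for the setting above.
    if nbytes <= cutoff:
        # Round up to the smallest power of two that is at least nbytes.
        return 1 << (nbytes - 1).bit_length()
    # Round up to the next multiple of the cutoff.
    return -(-nbytes // cutoff) * cutoff


for request in (3_000_000, 50_000_000, 100_000_000):
    print(request, "->", round_pool_request(request))
```

Raising the cutoff from the default means more requests use power-of-two rounding. This can improve block reuse at the cost of some wasted memory per allocation.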

Conclusion

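In this post, we demonstrated how to reduce training time with Horovod and Apache MXNet on Amazon SageMaker. Amazon SageMaker provides prebuilt containers with the required libraries and manages cluster setup. As a result, you can train your model with Horovod out of the box, without installing and configuring MPI, NCCL, or Horovod yourself.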

For more information about deep learning and MXNet, see the MXNet crash course and Dive into Deep Learning book. You can also get started on the MXNet website and MXNet GitHub examples directory. If you’re new to distributed training and want to dive deeper, we highly recommend reading the paper Horovod: fast and easy distributed deep learning in TensorFlow. If you use the AWS Deep Learning Containers and AWS Deep Learning AMIs, you can learn how to set up this workflow in that environment in our recent post How to run distributed training using Horovod and MXNet on AWS DL containers and AWS Deep Learning AMIs.


About the Authors

Vadim Dabravolski is an AI/ML Solutions Architect with the FinServe team. He is focused on Computer Vision and NLP technologies and how to apply them to business use cases. After hours, Vadim enjoys jogging in NYC boroughs, reading non-fiction (business, history, culture, politics, you name it), and rarely just doing nothing.

Corey Barrett is a Data Scientist in the Amazon ML Solutions Lab. As a member of the ML Solutions Lab, he leverages Machine Learning and Deep Learning to solve critical business problems for AWS customers. Outside of work, you can find him enjoying the outdoors, sipping on scotch, and spending time with his family.

Chaitanya Bapat is a Software Engineer with the AWS Deep Learning team. He works on Apache MXNet and on integrating the framework with Amazon SageMaker, DLC, and DLAMI. In his spare time, he loves watching sports and enjoys reading books and learning Spanish.

Karan Jariwala is a Software Development Engineer on the AWS Deep Learning team. His work focuses on training deep neural networks. Outside of work, he enjoys hiking, swimming, and playing tennis.