AWS Machine Learning Blog

The importance of hyperparameter tuning for scaling deep learning training to multiple GPUs

Parallel processing with multiple GPUs is an important step in scaling the training of deep models. Each training iteration typically processes a small subset of the dataset, called a mini-batch. With a single GPU, that GPU processes the entire mini-batch in each iteration. With multiple GPUs, the mini-batch is split across the available GPUs to spread the processing load evenly. To keep each GPU fully utilized, you must increase the mini-batch size linearly with the number of GPUs. Mini-batch size affects not only training speed but also the quality of the trained model. As you increase the mini-batch size, it is important to tune other hyperparameters so that training becomes faster while model quality stays similar.

Multi-GPU and distributed training with Gluon

With the vast amount of data required by modern deep learning models, scaling to multiple GPUs and distributed machines can save significant time in both research and production. With services such as Amazon SageMaker and Amazon Elastic Compute Cloud (Amazon EC2), setting up distributed training with several hundred GPUs is not only painless but also economical, because you pay only for what you use and you don’t have to maintain an expensive, underutilized hardware fleet.

Apache MXNet is a flexible and efficient deep learning platform. It’s particularly suitable for multi-GPU training and for distributed training across multiple hosts. The Gluon library in Apache MXNet provides a clear, concise, and simple API for deep learning. The tutorials "Training on multiple GPUs with gluon" and "Distributed training with multiple machines" demonstrate how easy it is to set up multi-GPU and distributed training.

Training hyperparameters

Training hyperparameters constitute all of the parameters that are not learnable by gradient descent, but have an impact on the final model quality. These parameters include optimization parameters, such as learning rate and momentum, augmentation parameters, such as random color shift amount, and any other non-learnable parameter.

The MXNet Gluon API makes it easy to modify training code to take advantage of multiple GPUs by seamlessly creating model parameters on all GPUs and providing the gluon.utils.split_and_load() function to split the data between multiple GPUs. Although upgrading the training code to use multiple GPUs is easy, the process requires tuning hyperparameters and applying optimization tricks to take advantage of the larger computing resources without sacrificing model quality.
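As a rough illustration of what splitting a mini-batch across GPUs means, here is a minimal pure-Python sketch. The helper split_batch() is hypothetical and only mimics the slicing behavior of gluon.utils.split_and_load() under its default even_split=True setting; the real function also copies each slice to its device context, which this sketch does not model.

```python
def split_batch(batch, num_devices):
    """Split a batch (a list of samples) into num_devices equal slices.

    Hypothetical stand-in for the slicing done by gluon.utils.split_and_load();
    no data is moved to any device here.
    """
    size = len(batch)
    if size % num_devices != 0:
        raise ValueError(
            "batch of %d samples cannot be split evenly across %d devices"
            % (size, num_devices))
    step = size // num_devices
    # Slice i goes to device i
    return [batch[i * step:(i + 1) * step] for i in range(num_devices)]


# A mini-batch of 512 samples on 4 GPUs gives 4 slices of 128 samples each
shards = split_batch(list(range(512)), 4)
print([len(s) for s in shards])  # [128, 128, 128, 128]
```

Keeping the slice per device at 128 samples is exactly the "constant mini-batch size per GPU" heuristic discussed next.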

Consequences of increasing mini-batch size

When moving from training on a single GPU to training on multiple GPUs, a good heuristic is to multiply the mini-batch size by the number of GPUs so that the mini-batch size per GPU stays constant. For example, if a mini-batch size of 128 keeps a single GPU fully utilized, you should increase the mini-batch size to 512 when using four GPUs. Although a larger mini-batch size increases data throughput, training often does not converge much faster in wall-clock time. With a larger mini-batch, the gradient is less noisy from batch to batch, so each stochastic gradient descent step points more accurately toward the optimum. But if the learning rate stays the same, the average step size does not change, so the number of steps required to converge decreases only slightly. Because a step on several GPUs takes roughly as long as a step on one GPU, the wall-clock time to convergence also improves only slightly.
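The heuristic can be made concrete with a little arithmetic. This sketch assumes the 50,000-image CIFAR-10 training set used later in this post and that a partial final mini-batch is kept (the Gluon DataLoader default); the helper names are hypothetical.

```python
import math

NUM_TRAIN = 50000     # CIFAR-10 training images
PER_GPU_BATCH = 128   # mini-batch size assumed to keep one GPU fully busy


def scaled_batch(per_gpu_batch, num_gpus):
    """Global mini-batch size that keeps the per-GPU mini-batch constant."""
    return per_gpu_batch * num_gpus


def steps_per_epoch(num_samples, batch_size):
    """Optimizer updates per epoch, counting a partial final batch."""
    return math.ceil(num_samples / batch_size)


for gpus in (1, 4, 16):
    b = scaled_batch(PER_GPU_BATCH, gpus)
    print(gpus, b, steps_per_epoch(NUM_TRAIN, b))
# 1 128 391
# 4 512 98
# 16 2048 25
```

With the learning rate unchanged, the 16-GPU configuration makes about 16 times fewer parameter updates per epoch (25 versus 391), each with roughly the same average step size, which is why faster epochs alone do not translate into proportionally faster convergence.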

A bigger problem with increasing the mini-batch size is its negative impact on model quality. Increasing the mini-batch size lowers the variance of the gradient of the loss function. In theory, in a convex optimization setting, lower gradient variance results in better optimization. In practice, however, increasing the mini-batch size tends to produce models that generalize poorly (see Keskar et al. [1]). This is an active area of deep learning research, and several theories have been explored. Although researchers have not settled on a single explanation, the negative impact of a larger mini-batch size on generalization is well documented in the literature.
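To see the variance reduction numerically, the following toy sketch (not from the original experiments) uses a one-parameter least-squares loss in which each example's gradient is w - x_i. The variance of a mini-batch gradient estimate scales roughly as 1/B, so a mini-batch 16 times larger should give roughly 16 times less gradient noise. The dataset, its spread, and the trial count are all assumptions chosen for illustration.

```python
import random
import statistics

random.seed(0)
# Toy dataset: per-example loss l_i(w) = 0.5 * (w - x_i) ** 2,
# so the per-example gradient at w is (w - x_i). Variance of x_i is about 4.
xs = [random.gauss(0.0, 2.0) for _ in range(50000)]
w = 1.0


def minibatch_grad(batch_size):
    """Mean gradient over a mini-batch sampled without replacement."""
    batch = random.sample(xs, batch_size)
    return sum(w - x for x in batch) / batch_size


def grad_variance(batch_size, trials=200):
    """Empirical variance of the mini-batch gradient across many draws."""
    return statistics.variance([minibatch_grad(batch_size) for _ in range(trials)])


v_small = grad_variance(128)
v_large = grad_variance(2048)
print("B=128 : %.5f" % v_small)
print("B=2048: %.5f" % v_large)
print("ratio : %.1f" % (v_small / v_large))
```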

One of the leading theories argues that the non-convex surface of the loss function contains many local minima and saddle points. With a smaller mini-batch, the gradient of the loss per mini-batch is noisier and can result in the optimization process bouncing out of a local minimum or saddle point. A large mini-batch, however, results in a gradient with less stochasticity and optimization may get stuck in a local minimum or on a saddle point.

Hyperparameter tuning

To increase the rate of convergence with a larger mini-batch size, you must increase the learning rate of the SGD optimizer. However, as demonstrated by Keskar et al. [1], optimizing a network with a large learning rate is difficult. Some optimization tricks have proven effective in addressing this difficulty (see Goyal et al. [2]). Primarily, these techniques involve priming (also known as warming up) the network by using a small learning rate for the first few epochs.
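Goyal et al. [2] describe a gradual warm-up in which the learning rate ramps linearly, iteration by iteration, from a small value to the target rate over the first few epochs. The function below is a minimal sketch of that idea under assumed values (0.1 to 1.6 over 5 epochs of 25 iterations each, matching the 2048 mini-batch CIFAR-10 setup later in this post); the example in this post uses a coarser, per-epoch version.

```python
def warmup_lr(iteration, base_lr, target_lr, warmup_iters):
    """Linear warm-up from base_lr to target_lr over warmup_iters iterations.

    After warm-up the target rate is held; any later decay is assumed to be
    handled elsewhere.
    """
    if iteration >= warmup_iters:
        return target_lr
    frac = iteration / warmup_iters
    return base_lr + frac * (target_lr - base_lr)


# Assumed setup: 5 warm-up epochs x 25 iterations per epoch
warmup_iters = 5 * 25
for it in (0, 25, 62, 124, 125, 500):
    print(it, round(warmup_lr(it, 0.1, 1.6, warmup_iters), 4))
```

In a Gluon training loop, one could call trainer.set_learning_rate(warmup_lr(...)) before each trainer.step(); this per-iteration variant is an assumption and not what the train() function in this post does.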

Interestingly, increasing the learning rate along with the mini-batch size has also been shown to reduce the impact of a large mini-batch on model quality. Goyal et al. [2] show that for a ResNet-50 model, the mini-batch size can be scaled to 8192 without any increase in validation error if the learning rate is scaled by the same factor. This mini-batch size allows a ResNet-50 model to be trained on the ImageNet dataset in just 1 hour on 256 GPUs. However, the validation error grows very quickly once the mini-batch size is scaled beyond 8192.
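The linear scaling rule itself is a one-line calculation: multiply the learning rate by the same factor as the mini-batch size. The sketch below is a hypothetical helper; the base values are the ones used in this post (0.1 at a mini-batch size of 128) and Goyal et al.'s reference rate of 0.1 per 256 images for ResNet-50.

```python
def linear_scaled_lr(base_lr, base_batch, batch):
    """Linear scaling rule: scale the learning rate by batch / base_batch."""
    return base_lr * batch / base_batch


# This post: 0.1 at mini-batch 128 -> mini-batch 2048 on 16 GPUs
print(linear_scaled_lr(0.1, 128, 2048))   # 1.6
# Goyal et al.: 0.1 per 256 images -> mini-batch 8192
print(linear_scaled_lr(0.1, 256, 8192))   # 3.2
```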

Depending on the network architecture and training hyperparameters, hyperparameters other than mini-batch size and learning rate may also need tuning, and optimization tricks beyond “priming” may be required to successfully scale training to multiple GPUs and distributed settings.

Large mini-batch example

In this example, I demonstrate the negative impact of increasing the mini-batch size without changing other hyperparameters by training a ResNet-18 model on the CIFAR-10 image classification task. I use a machine with 16 GPUs, but the concept applies to any configuration in which the mini-batch size is increased by a large factor.

System configuration

The system I am using in this example is an Amazon EC2 p2.16xlarge instance, which has 16 GPUs. I am using the MXNet package mxnet_cu90mkl version 1.1.0, installed via pip, which is the latest version of MXNet at the time of writing. The minimum MXNet version required for this example is 1.0.0. Instructions for installing MXNet can be found on the MXNet installation page.

I use Deep Learning AMI (Ubuntu) Version 7.0 (ami-139a476c) as the machine image for my EC2 instance. This Amazon Machine Image (AMI) is the latest Deep Learning AMI available on AWS Marketplace at the time of writing. For more information, see the getting started guide for the DLAMI. Alternatively, you can use Amazon SageMaker, which has MXNet pre-installed. If you are using Amazon SageMaker notebooks, you might need to upgrade the MXNet package, because as of this writing the notebook starts with version 0.12.1, which is older than the minimum version 1.0.0 required for this example.

CIFAR-10 dataset

The machine learning problem addressed in this example is image classification. The CIFAR-10 dataset consists of 60,000 color images in 10 classes with a resolution of 32×32 pixels, split into a training set of 50,000 images and a test set of 10,000 images. For a description of the dataset and the methodology used to collect the data, see Chapter 3 of Learning Multiple Layers of Features from Tiny Images [3].

We use a number of transformations to normalize pixel values and to add data augmentation for training. Pixel values are first scaled to the [0, 1] range and then normalized per channel by subtracting the mean of the training dataset and dividing by its standard deviation. Random crop and mirror augmentation is applied only to the training dataset: each training image is zero-padded by 4 pixels on each side, randomly cropped back to 32×32, and randomly mirrored horizontally.

from __future__ import print_function
from time import time
import mxnet as mx
from mxnet import autograd, gluon, nd
import numpy as np

print("MXNet Version:", mx.__version__)

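def pad_3d(data, pad_width):
    # nd.pad only accepts 4D or 5D arrays, so add two leading axes to the
    # (height, width, channel) image, zero-pad it, then drop those axes again
    data = data.reshape((1, 1) + data.shape)
    data = nd.pad(data, pad_width=(0, 0, 0, 0) + pad_width,
                  mode='constant', constant_value=0)
    data = data.reshape(data.shape[2:])
    return data


def transform(data, label, rand_aug):
    # Scale 8-bit pixel values to the [0, 1] range
    data = data.astype(np.float32) / 255
    if rand_aug:
        # Zero-pad height and width by 4 pixels on each side before cropping
        data = pad_3d(data, (4, 4, 4, 4, 0, 0))
    # Random 32x32 crop and horizontal mirror (training only), followed by
    # per-channel mean/std normalization
    auglist = mx.image.CreateAugmenter(
        data_shape=(3, 32, 32),
        rand_mirror=rand_aug,
        rand_crop=rand_aug,
        mean=np.array([0.4914, 0.4822, 0.4465]),
        std=np.array([0.2023, 0.1994, 0.2010]))
    for aug in auglist:
        data = aug(data)
    # Convert HWC to the CHW layout expected by the network
    return nd.transpose(data, (2, 0, 1)), nd.array([label]).astype(np.float32)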

MXNet Version: 1.1.0
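The normalization performed by the augmenters in transform() amounts to per-channel arithmetic. The following pure-Python sketch applies it to a single 8-bit RGB pixel using the mean and standard deviation values from transform(); normalize_pixel() is a hypothetical helper, not part of MXNet.

```python
# Per-channel statistics copied from transform() above
MEAN = (0.4914, 0.4822, 0.4465)
STD = (0.2023, 0.1994, 0.2010)


def normalize_pixel(rgb):
    """Normalize one 8-bit RGB pixel: scale to [0, 1], subtract the
    per-channel mean, then divide by the per-channel standard deviation."""
    return tuple((v / 255.0 - m) / s for v, m, s in zip(rgb, MEAN, STD))


print(normalize_pixel((255, 255, 255)))  # saturated channels
print(normalize_pixel((0, 0, 0)))        # zero-valued channels
```

A saturated channel maps to roughly +2.5 to +2.8 and a zero-valued channel to roughly -2.2 to -2.4, so normalized inputs are centered near zero.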

The data_loader() function loads the training or test dataset depending on the train parameter. Note that random augmentation is disabled for the test dataset.

def data_loader(train, batch_size, num_workers):
    dataset = gluon.data.vision.CIFAR10(
        train=train,
        transform=lambda x, y: transform(x, y, rand_aug=train))
    return gluon.data.DataLoader(
        dataset, batch_size, shuffle=train, num_workers=num_workers)

We use the Gluon model zoo to create a ResNet-18 network for classifying 10 classes in the CIFAR-10 dataset. Training is done using cross entropy loss on a softmax output. Note that the network is hybridized to optimize the performance of the computational graph. Please refer to the tutorial on hybridizing a gluon network for more information.

net = mx.gluon.model_zoo.vision.resnet18_v1(pretrained=False, classes=10)
net.hybridize()
softmax_ce = gluon.loss.SoftmaxCrossEntropyLoss()

For calculating accuracy, target labels don’t need to be loaded to the GPU, because the comparison between predicted labels and target labels is performed on the CPU. We do, however, use all available GPUs to calculate the predictions. Please refer to the tutorial on using multiple GPUs in Gluon for more information.

def evaluate_accuracy(data_iterator, net):
    acc = mx.metric.Accuracy()
    for data, label in data_iterator:
        data = gluon.utils.split_and_load(data, ctx)
        label = gluon.utils.split_and_load(label, [mx.cpu() for _ in ctx])
        acc.update(
            preds=[nd.argmax(net(d), axis=1, keepdims=True) for d in data],
            labels=label)
    return acc.get()[1]
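mx.metric.Accuracy accumulates the number of correct predictions and the number of samples across its update() calls, so evaluate_accuracy() effectively pools all GPU shards before dividing. The pure-Python sketch below, with a hypothetical sharded_accuracy() helper, illustrates why this matters: pooling counts gives per-sample accuracy, whereas averaging per-shard accuracies would weight small shards too heavily.

```python
def sharded_accuracy(shards):
    """shards: list of (predicted_labels, true_labels) pairs, one per device.

    Correct predictions and sample counts are pooled across shards before
    dividing, so every sample carries equal weight.
    """
    correct = total = 0
    for preds, labels in shards:
        correct += sum(int(p == t) for p, t in zip(preds, labels))
        total += len(labels)
    return correct / total


shards = [([1, 2, 3, 4], [1, 2, 0, 4]),   # 3 of 4 correct on device 0
          ([5, 6], [5, 6])]               # 2 of 2 correct on device 1
print(sharded_accuracy(shards))  # 5 of 6 correct, about 0.833
```

Averaging the two per-shard accuracies instead would give 0.875, overstating the result.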

The training loop splits each mini-batch across the GPUs in the ctx list, performs the forward and backward passes on the network, and calculates the average loss per epoch. At the beginning of each epoch, the learning rate is adjusted according to the learning rate schedule passed to the train() function as lr_sched. The schedule is a dictionary that maps an epoch number (starting from epoch 0) to the learning rate that takes effect at that epoch; the rate stays in effect until the next epoch listed in the dictionary.

def train(epochs, batch_size, lr_sched):
    num_workers = 64
    train_data = data_loader(train=True,
                             batch_size=batch_size,
                             num_workers=num_workers)
    test_data = data_loader(train=False,
                            batch_size=80,
                            num_workers=num_workers)

    # Initialize parameters randomly
    net.collect_params().initialize(mx.init.Xavier(magnitude=2.24),
                                    ctx=ctx,
                                    force_reinit=True)
    trainer = gluon.Trainer(
        net.collect_params(),
        'sgd',
        {'learning_rate': lr_sched[0], 'momentum': 0.9, 'wd': 0.0001})

    train_start = time()
    avg_loss = nd.zeros((1,), ctx=ctx[0])
    for e in range(epochs):
        if e in lr_sched:
            trainer.set_learning_rate(lr_sched[e])
        avg_loss *= 0  # Zero average loss of each epoch
        for i, td in enumerate(train_data):
            if i == 0:
                e_start = time()
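            data, label = td
            # The last mini-batch of an epoch may be smaller than batch_size
            n_samples = data.shape[0]
            data = gluon.utils.split_and_load(data, ctx)
            label = gluon.utils.split_and_load(label, ctx)
            # Wait for completion of previous iteration to
            # avoid unnecessary memory allocation
            nd.waitall()
            with autograd.record():
                output = [net(x) for x in data]
                loss = [softmax_ce(o, l) for o, l in zip(output, label)]
            for l in loss:
                l.backward()
            # Normalize gradients by the actual number of samples in this batch
            trainer.step(n_samples)
            # Calculate average loss
            for l in loss:
                avg_loss += l.mean().as_in_context(avg_loss.context)

        # i is zero-based, so this epoch had i + 1 mini-batches, each split
        # into len(ctx) per-GPU losses
        avg_loss /= ((i + 1) * len(ctx))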
        epoch_time = time() - e_start

        if e < 6 or (e + 1) % 5 == 0 or (e + 1) == epochs:
            print("\tEPOCH {:2}: train loss {:4.2} | batch {:4} | "
                  "lr {:5.3f} | Time per epoch {:5.2f} seconds".format(
                      e, avg_loss.asscalar(), batch_size,
                      trainer.learning_rate, epoch_time))
    train_accuracy = evaluate_accuracy(train_data, net)
    test_accuracy = evaluate_accuracy(test_data, net)
    print("Training time {:6.2f} seconds | train accuracy {:6.4} | "
          "test accuracy {:6.4}".format(
              time() - train_start, train_accuracy, test_accuracy))
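train() calls trainer.set_learning_rate() only at epochs that appear as keys in lr_sched, so each rate stays in effect until the next listed epoch. The hypothetical lr_for_epoch() helper below reproduces that lookup in pure Python, which can be handy for checking a schedule before launching a long run.

```python
def lr_for_epoch(lr_sched, epoch):
    """Return the learning rate in effect at `epoch` under train()'s rules:
    the rate set at the most recent listed epoch <= `epoch` stays active."""
    current = lr_sched[0]
    for start in sorted(lr_sched):
        if start > epoch:
            break
        current = lr_sched[start]
    return current


sched = {0: 0.1, 35: 0.05, 40: 0.02, 44: 0.01}
print([lr_for_epoch(sched, e) for e in (0, 34, 35, 39, 40, 44)])
# [0.1, 0.1, 0.05, 0.05, 0.02, 0.01]
```

These values agree with the lr column in the training logs, for example 0.050 at epoch 39 and 0.010 at epoch 44.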

Single GPU (mini-batch size of 128)

First, we train the network on a single GPU using a mini-batch size of 128, an initial learning rate of 0.1 (decayed to 0.05, 0.02, and 0.01 at epochs 35, 40, and 44), momentum of 0.9, and weight decay of 0.0001.

ctx = [mx.gpu(0)]
train(epochs=45, batch_size=128, lr_sched={0: 0.1, 35: 0.05, 40: 0.02, 44: 0.01})

    EPOCH  0: train loss  2.3 | batch  128 | lr 0.100 | Time per epoch 24.11 seconds
    EPOCH  1: train loss  1.6 | batch  128 | lr 0.100 | Time per epoch 20.55 seconds
    EPOCH  2: train loss  1.4 | batch  128 | lr 0.100 | Time per epoch 20.75 seconds
    EPOCH  3: train loss  1.3 | batch  128 | lr 0.100 | Time per epoch 20.74 seconds
    EPOCH  4: train loss  1.2 | batch  128 | lr 0.100 | Time per epoch 20.72 seconds
    EPOCH  5: train loss  1.1 | batch  128 | lr 0.100 | Time per epoch 20.69 seconds
    EPOCH  9: train loss 0.86 | batch  128 | lr 0.100 | Time per epoch 20.60 seconds
    EPOCH 14: train loss 0.72 | batch  128 | lr 0.100 | Time per epoch 20.91 seconds
    EPOCH 19: train loss 0.63 | batch  128 | lr 0.100 | Time per epoch 20.85 seconds
    EPOCH 24: train loss 0.57 | batch  128 | lr 0.100 | Time per epoch 20.83 seconds
    EPOCH 29: train loss 0.54 | batch  128 | lr 0.100 | Time per epoch 20.77 seconds
    EPOCH 34: train loss  0.5 | batch  128 | lr 0.100 | Time per epoch 20.79 seconds
    EPOCH 39: train loss 0.37 | batch  128 | lr 0.050 | Time per epoch 20.73 seconds
    EPOCH 44: train loss 0.25 | batch  128 | lr 0.010 | Time per epoch 20.59 seconds
Training time 999.79 seconds | train accuracy 0.9261 | test accuracy 0.8475

Multi-GPU (mini-batch size of 2048)

Now let’s increase the mini-batch size by a factor of 16, to 2048, and train across 16 GPUs to accelerate training, keeping all other hyperparameters the same.

ctx = [mx.gpu(i) for i in range(16)]
train(epochs=45, batch_size=2048, lr_sched={0: 0.1, 35: 0.05, 40: 0.02, 44: 0.01})

    EPOCH  0: train loss  3.5 | batch 2048 | lr 0.100 | Time per epoch 22.35 seconds
    EPOCH  1: train loss  2.2 | batch 2048 | lr 0.100 | Time per epoch  6.36 seconds
    EPOCH  2: train loss  2.0 | batch 2048 | lr 0.100 | Time per epoch  6.42 seconds
    EPOCH  3: train loss  1.9 | batch 2048 | lr 0.100 | Time per epoch  6.15 seconds
    EPOCH  4: train loss  1.8 | batch 2048 | lr 0.100 | Time per epoch  6.39 seconds
    EPOCH  5: train loss  1.7 | batch 2048 | lr 0.100 | Time per epoch  6.47 seconds
    EPOCH  9: train loss  1.4 | batch 2048 | lr 0.100 | Time per epoch  6.40 seconds
    EPOCH 14: train loss  1.1 | batch 2048 | lr 0.100 | Time per epoch  6.04 seconds
    EPOCH 19: train loss 0.99 | batch 2048 | lr 0.100 | Time per epoch  6.43 seconds
    EPOCH 24: train loss 0.84 | batch 2048 | lr 0.100 | Time per epoch  6.50 seconds
    EPOCH 29: train loss 0.75 | batch 2048 | lr 0.100 | Time per epoch  6.27 seconds
    EPOCH 34: train loss 0.67 | batch 2048 | lr 0.100 | Time per epoch  6.39 seconds
    EPOCH 39: train loss 0.57 | batch 2048 | lr 0.050 | Time per epoch  6.05 seconds
    EPOCH 44: train loss 0.51 | batch 2048 | lr 0.010 | Time per epoch  6.15 seconds
Training time 977.40 seconds | train accuracy 0.8379 | test accuracy 0.7829

You can see that the time per epoch improved significantly, but the test accuracy dropped from 0.8475 to 0.7829. You might also notice that the total training time didn’t improve nearly as much as the time per epoch. This is because the overhead of launching the DataLoader at each epoch, and of calculating the train and test accuracy, is roughly constant. With datasets much larger than CIFAR-10, such as ImageNet, this overhead is typically negligible.

Now let’s introduce a simple 5-epoch warm-up stage in which the learning rate increases linearly from 0.1 to 1.6 (a factor of 16, matching the increase in mini-batch size), and from then on follow the same learning rate schedule as before, with every learning rate scaled by 16.

train(epochs=45,
      batch_size=2048,
      lr_sched={
          0: 0.1,
          1: 0.1 + 0.3,
          2: 0.1 + 0.6,
          3: 0.1 + 0.9,
          4: 0.1 + 1.2,
          5: 0.1 + 1.5,
          35: 0.05 * 16,
          40: 0.02 * 16,
          44: 0.01 * 16})

    EPOCH  0: train loss  4.1 | batch 2048 | lr 0.100 | Time per epoch  7.11 seconds
    EPOCH  1: train loss  2.8 | batch 2048 | lr 0.400 | Time per epoch  6.48 seconds
    EPOCH  2: train loss  2.2 | batch 2048 | lr 0.700 | Time per epoch  6.41 seconds
    EPOCH  3: train loss  1.8 | batch 2048 | lr 1.000 | Time per epoch  6.69 seconds
    EPOCH  4: train loss  1.6 | batch 2048 | lr 1.300 | Time per epoch  6.42 seconds
    EPOCH  5: train loss  1.6 | batch 2048 | lr 1.600 | Time per epoch  6.37 seconds
    EPOCH  9: train loss  1.2 | batch 2048 | lr 1.600 | Time per epoch  6.50 seconds
    EPOCH 14: train loss 0.89 | batch 2048 | lr 1.600 | Time per epoch  6.44 seconds
    EPOCH 19: train loss 0.73 | batch 2048 | lr 1.600 | Time per epoch  6.46 seconds
    EPOCH 24: train loss 0.65 | batch 2048 | lr 1.600 | Time per epoch  6.41 seconds
    EPOCH 29: train loss 0.58 | batch 2048 | lr 1.600 | Time per epoch  6.53 seconds
    EPOCH 34: train loss 0.56 | batch 2048 | lr 1.600 | Time per epoch  6.46 seconds
    EPOCH 39: train loss  0.4 | batch 2048 | lr 0.800 | Time per epoch  6.51 seconds
    EPOCH 44: train loss 0.29 | batch 2048 | lr 0.160 | Time per epoch  6.48 seconds
Training time 916.39 seconds | train accuracy 0.9115 | test accuracy 0.8416

You can see that by scaling the learning rate by the same factor as the mini-batch size, the test accuracy improved significantly, from 0.7829 to 0.8416, nearly matching the 0.8475 achieved with a mini-batch size of 128. Further hyperparameter optimization might bring the test accuracy even closer. Note that hyperparameter optimization can also be applied to the mini-batch size of 128 to maximize the achievable test accuracy.
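The warm-up schedule in the previous run was written out by hand. A small, hypothetical helper can generate the same dictionary from a base learning rate, a scaling factor, and the number of warm-up epochs; this is a sketch under those assumptions, not part of the original training code.

```python
def make_lr_sched(base_lr, scale, warmup_epochs, decay):
    """Build a train()-style schedule {epoch: lr}.

    Epochs 0..warmup_epochs ramp linearly from base_lr to base_lr * scale.
    `decay` maps later epochs to unscaled rates, which are multiplied by scale.
    """
    peak = base_lr * scale
    sched = {}
    for e in range(warmup_epochs + 1):
        sched[e] = base_lr + (peak - base_lr) * e / warmup_epochs
    for e, lr in decay.items():
        sched[e] = lr * scale
    return sched


sched = make_lr_sched(0.1, 16, 5, {35: 0.05, 40: 0.02, 44: 0.01})
print({e: round(lr, 3) for e, lr in sched.items()})
# {0: 0.1, 1: 0.4, 2: 0.7, 3: 1.0, 4: 1.3, 5: 1.6, 35: 0.8, 40: 0.32, 44: 0.16}
```

Generating the schedule this way makes it easy to try other scaling factors or warm-up lengths, for example when searching hyperparameters.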

Amazon SageMaker

As mentioned previously, the code above can be run on an Amazon EC2 instance with the Deep Learning AMI or on an Amazon SageMaker notebook instance. Given the large space of possible hyperparameter combinations, tuning hyperparameters manually can be very time-consuming. Amazon SageMaker provides a Hyperparameter Optimization (HPO) tool (in preview as of this writing) that uses an intelligent algorithm to explore a large hyperparameter space efficiently and automatically, running multiple training clusters in parallel to further improve the model.

Conclusion

Scaling the training of a model to multiple GPUs is a logical and economical step in accelerating research and reducing delays in retraining production models. However, you should be aware that simply scaling the mini-batch size with the added computing capacity doesn’t necessarily result in faster convergence and can result in a lower-quality model. Hyperparameter tuning, in particular scaling the learning rate with the mini-batch size, is a crucial step in maintaining model quality with an increased mini-batch size. Certain optimization tricks, such as priming the network, may also be necessary for successful training. The Amazon SageMaker HPO tool can be used to explore the hyperparameter space efficiently.

References

[1] On Large-Batch Training for Deep Learning: Generalization Gap and Sharp Minima, Keskar et al., 2016.
[2] Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour, Goyal et al., 2017.
[3] Learning Multiple Layers of Features from Tiny Images, Alex Krizhevsky, 2009.


About the author

Sina Afrooze is an AWS Software Engineer, focusing on application of Apache MXNet in deep learning for artificial intelligence. His domain expertise is in digital imaging and computer vision. He enjoys helping AWS customers achieve scale in deep learning solutions using Apache MXNet.