Fine-tuning a PyTorch BERT model and deploying it with Amazon Elastic Inference on Amazon SageMaker
Text classification is a technique for putting text into different categories, and has a wide range of applications: email providers use text classification to detect spam emails, marketing agencies use it for sentiment analysis of customer reviews, and discussion forum moderators use it to detect inappropriate comments.
In the past, data scientists used methods such as tf-idf, word2vec, or bag-of-words (BOW) to generate features for training classification models. Although these techniques have been very successful in many natural language processing (NLP) tasks, they don’t always capture the meanings of words accurately when they appear in different contexts. Recently, there has been increasing interest in using Bidirectional Encoder Representations from Transformers (BERT) to achieve better results in text classification tasks, because it encodes the meaning of words in different contexts more accurately.
Amazon SageMaker is a fully managed service that provides developers and data scientists the ability to build, train, and deploy machine learning (ML) models quickly. Amazon SageMaker removes the heavy lifting from each step of the ML process to make it easier to develop high-quality models. The Amazon SageMaker Python SDK provides open-source APIs and containers that make it easy to train and deploy models in Amazon SageMaker with several different ML and deep learning frameworks.
Our customers often ask for quick fine-tuning and easy deployment of their NLP models. Furthermore, customers prefer low inference latency and low model inference cost. Amazon Elastic Inference enables attaching GPU-powered inference acceleration to endpoints, which reduces the cost of deep learning inference without sacrificing performance.
This post demonstrates how to use Amazon SageMaker to fine-tune a PyTorch BERT model and deploy it with Elastic Inference. The code from this post is available in the GitHub repo. For more information about BERT fine-tuning, see BERT Fine-Tuning Tutorial with PyTorch.
What is BERT?
First published in November 2018, BERT is a revolutionary model that is pretrained on two tasks. In the first, masked language modeling, one or more words in each sentence are intentionally masked; BERT takes these masked sentences as input and trains itself to predict the masked words. In the second, next sentence prediction, BERT learns to predict whether one sentence follows another, which pretrains text-pair representations.
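To make the masking step concrete, here is a minimal sketch of how a single masked language modeling example can be built. It follows the scheme described in the BERT paper (select roughly 15% of positions; replace 80% of those with [MASK], 10% with a random token, and leave 10% unchanged), but it works on whole words rather than BERT's WordPiece token IDs, and the helper name mask_tokens is our own.

```python
import random
from typing import List, Optional, Tuple

MASK_TOKEN = "[MASK]"


def mask_tokens(
    tokens: List[str],
    vocab: List[str],
    mask_prob: float = 0.15,
    rng: Optional[random.Random] = None,
) -> Tuple[List[str], List[Optional[str]]]:
    """Return (masked_tokens, labels) for one masked language modeling example.

    labels[i] holds the original token at positions the model must predict
    and None everywhere else (those positions contribute no loss).
    """
    rng = rng or random.Random()
    masked = list(tokens)
    labels: List[Optional[str]] = [None] * len(tokens)
    for i, token in enumerate(tokens):
        if rng.random() >= mask_prob:
            continue  # position not selected for prediction
        labels[i] = token
        roll = rng.random()
        if roll < 0.8:
            masked[i] = MASK_TOKEN  # 80% of selected positions: [MASK]
        elif roll < 0.9:
            masked[i] = rng.choice(vocab)  # 10%: a random token
        # remaining 10%: keep the original token unchanged
    return masked, labels


sentence = "the cat sat on the mat because it was tired".split()
masked, labels = mask_tokens(sentence, vocab=sentence, rng=random.Random(0))
print(masked)
print(labels)
```

The model sees only the masked list; the loss is computed at the positions where labels is not None, which is why some selected tokens are deliberately left unchanged or randomized rather than always replaced with [MASK].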
BERT is a substantial breakthrough and has helped researchers and data engineers across the industry achieve state-of-the-art results in many NLP tasks. BERT produces a representation of each word that is conditioned on its context (the rest of the sentence). For more information about BERT, see BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding.
One of the biggest challenges data scientists face in NLP projects is a lack of training data; you often have only a few thousand pieces of human-labeled text for model training. However, modern deep learning models for NLP tasks typically require large amounts of labeled data. One way to solve this problem is to use transfer learning.
Transfer learning is an ML method where a pretrained model, such as a pretrained ResNet model for image classification, is reused as the starting point for a different but related problem. By reusing parameters from pretrained models, you can save significant amounts of training time and cost.
BERT was trained on BookCorpus [1] and English Wikipedia, which contain 800 million words and 2,500 million words, respectively. Training BERT from scratch would be prohibitively expensive. By taking advantage of transfer learning, you can quickly fine-tune BERT for another use case with a relatively small amount of training data and achieve state-of-the-art results for common NLP tasks, such as text classification and question answering.
In this post, we walk through our dataset, the training process, and finally model deployment.
We use an Amazon SageMaker notebook instance for running the code. For more information about using Jupyter notebooks on Amazon SageMaker, see Using Amazon SageMaker Notebook Instances or Getting Started with Amazon SageMaker Studio.
Problem and dataset
For this post, we use the Corpus of Linguistic Acceptability (CoLA), a dataset of 10,657 English sentences drawn from published linguistics literature and labeled as grammatical or ungrammatical. In our notebook, we download and unzip the data using the following code:
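As a minimal sketch of that step in plain Python, the helper below downloads a zip archive and extracts it. The CoLA URL is an assumption based on the dataset's public release and may change; download_and_unzip is a hypothetical helper name.

```python
import os
import urllib.request
import zipfile

# Assumed location of the CoLA 1.1 public release; verify it on the CoLA project page.
COLA_URL = "https://nyu-mll.github.io/CoLA/cola_public_1.1.zip"


def download_and_unzip(url: str, dest_dir: str) -> str:
    """Download a zip archive from url and extract its contents into dest_dir."""
    os.makedirs(dest_dir, exist_ok=True)
    archive_path = os.path.join(dest_dir, os.path.basename(url))
    if not os.path.exists(archive_path):  # skip the download when re-running the cell
        urllib.request.urlretrieve(url, archive_path)
    with zipfile.ZipFile(archive_path) as archive:
        archive.extractall(dest_dir)
    return dest_dir


# In the notebook (requires network access):
# download_and_unzip(COLA_URL, ".")
```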
In the training data, the only two columns we need are the sentence itself and its label:
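A sketch of the loading step with pandas follows. It assumes the raw CoLA TSV files have no header row and four columns (source code, label, original annotation, sentence); the column names and the load_cola helper are our own labels.

```python
import csv

import pandas as pd

# Assumed layout of the raw CoLA TSV files (no header row):
# source code, label (1 = grammatical, 0 = ungrammatical), original annotation, sentence
COLA_COLUMNS = ["sentence_source", "label", "label_notes", "sentence"]


def load_cola(path: str) -> pd.DataFrame:
    """Load a raw CoLA TSV file and keep only the sentence and its label."""
    df = pd.read_csv(
        path,
        sep="\t",
        header=None,
        names=COLA_COLUMNS,
        quoting=csv.QUOTE_NONE,  # treat quote characters inside sentences as plain text
    )
    return df[["sentence", "label"]]


# Path assumes the archive was extracted into the current directory:
# df = load_cola("./cola_public/raw/in_domain_train.tsv")
```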
If we print out a few sentences, we can see how sentences are labeled based on their grammatical completeness. See the following code:
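One way to eyeball the labels is to sample a few sentences per class. The sketch assumes a DataFrame with sentence and label columns; show_examples is a hypothetical helper.

```python
import pandas as pd


def show_examples(df: pd.DataFrame, n: int = 3, seed: int = 0) -> pd.DataFrame:
    """Return up to n randomly chosen sentences for each label value."""
    samples = [
        group.sample(min(n, len(group)), random_state=seed)
        for _, group in df.groupby("label")
    ]
    return pd.concat(samples)[["label", "sentence"]]


# In the notebook:
# print(show_examples(df).to_string(index=False))
```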
We then split the dataset for training and testing before uploading both to Amazon S3 for use later. The SageMaker Python SDK provides a helpful function for uploading to Amazon S3:
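A minimal sketch of the split and upload follows. It uses pandas sampling for a reproducible 90/10 split rather than any particular library helper, writes CSV files, and uploads them with Session.upload_data from the SageMaker Python SDK, which returns the S3 URI of each uploaded file. The file names, split ratio, and S3 key prefix are illustrative assumptions.

```python
import os

import pandas as pd


def split_train_test(df: pd.DataFrame, test_frac: float = 0.1, seed: int = 1):
    """Randomly hold out test_frac of the rows; the seed makes the split repeatable."""
    test = df.sample(frac=test_frac, random_state=seed)
    train = df.drop(test.index)
    return train, test


def upload_splits(train, test, data_dir="./data", key_prefix="sagemaker-bert-cola"):
    """Write both splits to CSV and upload them to the session's default S3 bucket."""
    import sagemaker  # only needed (and only installed) in the SageMaker environment

    os.makedirs(data_dir, exist_ok=True)
    train_path = os.path.join(data_dir, "train.csv")
    test_path = os.path.join(data_dir, "test.csv")
    train.to_csv(train_path, index=False)
    test.to_csv(test_path, index=False)

    session = sagemaker.Session()
    # upload_data returns the S3 URI of the uploaded object
    inputs_train = session.upload_data(train_path, key_prefix=key_prefix)
    inputs_test = session.upload_data(test_path, key_prefix=key_prefix)
    return inputs_train, inputs_test


# In the notebook:
# train, test = split_train_test(df)
# inputs_train, inputs_test = upload_splits(train, test)
```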
For this post, we use the PyTorch-Transformers library, which contains PyTorch implementations and pretrained model weights for many NLP models, including BERT. See the following code:
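The sketch below shows the two pieces the training script needs from the library: loading the bert-base-uncased tokenizer and classification model, and turning sentences into fixed-length token ID and attention mask lists. It assumes a maximum length of 64 tokens and relies on [PAD] having ID 0 in the bert-base-uncased vocabulary; encode_and_pad and load_bert are hypothetical helper names, and the library import sits inside load_bert only so the padding logic can run without the package installed.

```python
from typing import List, Sequence, Tuple

MAX_LEN = 64  # assumed maximum sequence length; CoLA sentences are short
PAD_ID = 0    # [PAD] has ID 0 in the bert-base-uncased vocabulary


def encode_and_pad(
    sentences: Sequence[str], tokenizer, max_len: int = MAX_LEN
) -> Tuple[List[List[int]], List[List[int]]]:
    """Tokenize sentences and pad/truncate them to max_len with attention masks."""
    input_ids, attention_masks = [], []
    for sentence in sentences:
        # add_special_tokens=True wraps the sentence in [CLS] ... [SEP]
        ids = tokenizer.encode(sentence, add_special_tokens=True)
        if len(ids) > max_len:
            ids = ids[: max_len - 1] + ids[-1:]  # truncate but keep the final [SEP]
        padding = max_len - len(ids)
        attention_masks.append([1] * len(ids) + [0] * padding)  # 1 = real token
        input_ids.append(ids + [PAD_ID] * padding)
    return input_ids, attention_masks


def load_bert(num_labels: int = 2):
    """Load the pretrained uncased tokenizer and a BERT sequence classifier."""
    from pytorch_transformers import BertForSequenceClassification, BertTokenizer

    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", do_lower_case=True)
    model = BertForSequenceClassification.from_pretrained(
        "bert-base-uncased", num_labels=num_labels
    )
    return tokenizer, model
```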
Our training script should save the model artifacts learned during training to the directory given by model_dir, as stipulated by the Amazon SageMaker PyTorch image. Upon completion of training, Amazon SageMaker uploads the model artifacts saved in model_dir to Amazon S3 so they are available for deployment. The following code is used in the script to save trained model artifacts:
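A minimal sketch of that saving logic follows. It reads the output directory from the SM_MODEL_DIR environment variable that SageMaker sets in the training container (normally /opt/ml/model) and saves with save_pretrained, so the directory can later be reloaded with from_pretrained. The --model-dir argument name and the ./model fallback for local runs are assumptions.

```python
import argparse
import os


def parse_args(argv=None):
    parser = argparse.ArgumentParser()
    # SageMaker sets SM_MODEL_DIR (normally /opt/ml/model) inside the training container
    parser.add_argument("--model-dir", type=str, default=os.environ.get("SM_MODEL_DIR", "./model"))
    return parser.parse_args(argv)


def save_model(model, model_dir):
    """Save weights and config so the directory can be reloaded with from_pretrained()."""
    os.makedirs(model_dir, exist_ok=True)
    # DataParallel / DistributedDataParallel keep the wrapped model in .module
    model_to_save = model.module if hasattr(model, "module") else model
    model_to_save.save_pretrained(model_dir)


# At the end of training:
# save_model(model, parse_args().model_dir)
```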
We save this script in a file named train_deploy.py and put the file in a directory named code/. The full training script is available in that directory of the GitHub repo.
Because PyTorch-Transformers isn’t included natively in the Amazon SageMaker PyTorch images, we have to provide a requirements.txt file so that Amazon SageMaker installs this library for training and inference. A requirements.txt file is a text file that lists the packages to install with pip install; you can also specify the version of each package. To install PyTorch-Transformers, we add the following line to the requirements.txt file:
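The file needs only the package's PyPI name, pytorch-transformers; append ==<version> to pin a release. As a small convenience sketch, you can write the file from the notebook; write_requirements is a hypothetical helper, and the code directory matches the layout described earlier.

```python
from pathlib import Path


def write_requirements(source_dir="code", packages=("pytorch-transformers",)):
    """Write requirements.txt; SageMaker pip-installs these packages before running the script."""
    path = Path(source_dir) / "requirements.txt"
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text("\n".join(packages) + "\n")  # one package specifier per line
    return path
```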
Training on Amazon SageMaker
We use Amazon SageMaker to train and deploy a model using our custom PyTorch code. The Amazon SageMaker Python SDK makes it easier to run a PyTorch script in Amazon SageMaker using its PyTorch estimator. After that, we can use the SageMaker Python SDK to deploy the trained model and run predictions. For more information about using this SDK with PyTorch, see Using PyTorch with the SageMaker Python SDK.
To start, we use the PyTorch estimator class to train our model. When creating the estimator, we make sure to specify the following:
- entry_point – The name of the PyTorch script
- source_dir – The location of the training script and the requirements.txt file
- framework_version – The PyTorch version we want to use
The PyTorch estimator supports multi-machine, distributed PyTorch training. To use this, we just set train_instance_count to a value greater than 1. Our training script supports distributed training only on GPU instances.
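A sketch of the estimator setup follows. It assumes SageMaker Python SDK v2, where the instance settings are named instance_count and instance_type (v1 used train_instance_count and train_instance_type). The framework version, instance type, and hyperparameters are illustrative; the hyperparameter names must match whatever arguments train_deploy.py actually parses. Building the keyword arguments in a separate function keeps the configuration easy to inspect.

```python
def estimator_kwargs(role, instance_count=1):
    """Illustrative PyTorch estimator settings (SageMaker Python SDK v2 names)."""
    return dict(
        entry_point="train_deploy.py",  # script inside source_dir
        source_dir="code",              # also holds requirements.txt
        role=role,
        framework_version="1.3.1",      # example PyTorch version
        py_version="py3",
        instance_count=instance_count,  # > 1 enables distributed training
        instance_type="ml.p3.2xlarge",  # example GPU instance
        hyperparameters={"epochs": 1, "num_labels": 2},  # hypothetical script arguments
    )


def make_estimator(instance_count=1):
    import sagemaker
    from sagemaker.pytorch import PyTorch

    role = sagemaker.get_execution_role()  # works inside SageMaker notebooks
    return PyTorch(**estimator_kwargs(role, instance_count))
```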
After creating the estimator, we call fit(), which launches a training job. We pass in the Amazon S3 URIs of the training and test data we uploaded earlier. See the following code:
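The fit() call maps channel names to S3 URIs. In the sketch below, the channel names training and testing are assumptions that must match the channels the training script reads; start_training is a hypothetical wrapper.

```python
def start_training(estimator, inputs_train, inputs_test):
    """Launch the training job and return the S3 location of the model artifacts."""
    # Inside the container, each channel appears under /opt/ml/input/data/<channel>
    # and its path is exposed in the SM_CHANNEL_<CHANNEL> environment variable.
    estimator.fit({"training": inputs_train, "testing": inputs_test})
    return estimator.model_data  # S3 URI of model.tar.gz once training finishes
```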
After training starts, Amazon SageMaker displays the training progress in the notebook, reporting the epoch, training loss, and accuracy on the test data.
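Those log lines come from the training script itself. The following is a minimal sketch of a fine-tuning loop that reports average training loss and test accuracy after each epoch. It assumes a model that, like BertForSequenceClassification, returns a tuple whose first element is the logits, and data loaders that yield (input_ids, attention_mask, labels) batches. The learning rate of 2e-5, gradient clipping at 1.0, and torch.optim.AdamW are common fine-tuning choices, not necessarily the exact settings in train_deploy.py.

```python
import logging

import torch
import torch.nn.functional as F

logger = logging.getLogger(__name__)


def train_epoch(model, loader, optimizer, device="cpu"):
    """Run one pass over the training data and return the mean batch loss."""
    model.train()
    total_loss = 0.0
    for input_ids, attention_mask, labels in loader:
        input_ids, attention_mask, labels = (
            t.to(device) for t in (input_ids, attention_mask, labels)
        )
        optimizer.zero_grad()
        logits = model(input_ids, attention_mask=attention_mask)[0]
        loss = F.cross_entropy(logits, labels)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)  # guard against exploding gradients
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(loader)


@torch.no_grad()
def evaluate(model, loader, device="cpu"):
    """Return classification accuracy over all batches in loader."""
    model.eval()
    correct, total = 0, 0
    for input_ids, attention_mask, labels in loader:
        logits = model(input_ids.to(device), attention_mask=attention_mask.to(device))[0]
        correct += (logits.argmax(dim=1) == labels.to(device)).sum().item()
        total += labels.size(0)
    return correct / total


def run_training(model, train_loader, test_loader, epochs, lr=2e-5, device="cpu"):
    """Fine-tune for several epochs, logging training loss and test accuracy each epoch."""
    model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    history = []
    for epoch in range(1, epochs + 1):
        loss = train_epoch(model, train_loader, optimizer, device)
        accuracy = evaluate(model, test_loader, device)
        logger.info("Epoch %d: train loss %.4f, test accuracy %.4f", epoch, loss, accuracy)
        history.append((loss, accuracy))
    return history
```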
We can monitor the training progress and make sure it succeeds before proceeding with the rest of the notebook.
After training our model, we host it on an Amazon SageMaker endpoint by calling deploy() on the PyTorch estimator. The endpoint runs an Amazon SageMaker PyTorch model server. We need to configure two components of the server: model loading and model serving. We implement these two components in our inference script, train_deploy.py. The complete file is available in the GitHub repo.
model_fn() is the function that loads the saved model and returns a model object that can be used for model serving. The SageMaker PyTorch model server loads our model by invoking model_fn():
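A minimal sketch of model_fn follows, assuming the model was saved with save_pretrained during training. The library import sits inside the function only so the sketch can be loaded without pytorch-transformers installed; in a real script it would normally go at the top of the file.

```python
import logging

import torch

logger = logging.getLogger(__name__)


def model_fn(model_dir):
    """Load the fine-tuned model that training saved into model_dir."""
    from pytorch_transformers import BertForSequenceClassification

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = BertForSequenceClassification.from_pretrained(model_dir)
    model.to(device)
    model.eval()  # disable dropout for deterministic inference
    logger.info("Loaded model from %s onto %s", model_dir, device)
    return model
```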
input_fn() deserializes and prepares the prediction input. In this use case, the request body is first serialized to JSON and then sent to the model serving endpoint. Therefore, in input_fn(), we first deserialize the JSON-formatted request body and return the input as a torch.tensor, as required for BERT:
predict_fn() performs the prediction and returns the result. See the following code:
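A minimal predict_fn sketch follows, assuming input_fn returns an (input_ids, attention_mask) pair and the model returns a tuple with the logits first. It disables gradient tracking and returns the raw logits, leaving serialization to the container's default output handling.

```python
import torch


def predict_fn(input_data, model):
    """Run the model on (input_ids, attention_mask) and return the class logits."""
    device = next(model.parameters()).device
    input_ids, attention_mask = (t.to(device) for t in input_data)
    model.eval()
    with torch.no_grad():
        # BertForSequenceClassification returns a tuple whose first element is the logits
        return model(input_ids, attention_mask=attention_mask)[0]
```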
We take advantage of the prebuilt Amazon SageMaker PyTorch image’s default support for serializing the prediction result.
Deploying the endpoint
To deploy our endpoint, we call deploy() on our PyTorch estimator object, passing in our desired number of instances and instance type:
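A sketch of the call follows; ml.m4.xlarge is only an example of a CPU hosting instance type, and deploy_endpoint is a hypothetical wrapper around the estimator's deploy() method.

```python
def deploy_endpoint(estimator, instance_type="ml.m4.xlarge", instance_count=1):
    """Host the trained model on a real-time SageMaker endpoint and return its predictor."""
    return estimator.deploy(initial_instance_count=instance_count, instance_type=instance_type)


# In the notebook:
# predictor = deploy_endpoint(estimator)
```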
We then configure the predictor to use "application/json" for the content type when sending requests to our endpoint:
Finally, we use the returned predictor object to call the endpoint:
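A sketch of the JSON configuration and the endpoint call follows. It assumes SageMaker Python SDK v2, where serializer and deserializer objects set the request content type and parse responses, and that the endpoint returns the logits as a nested JSON list with one row per input sentence; the predicted class is the index of the largest score. use_json and classify are hypothetical helpers, and the sentence in the usage line is only an example.

```python
def use_json(predictor):
    """Send requests as JSON and parse JSON responses (SageMaker Python SDK v2)."""
    from sagemaker.deserializers import JSONDeserializer
    from sagemaker.serializers import JSONSerializer

    predictor.serializer = JSONSerializer()  # requests use Content-Type: application/json
    predictor.deserializer = JSONDeserializer()  # decodes the JSON response body
    return predictor


def classify(predictor, sentence):
    """Send one sentence to the endpoint and return the predicted class index."""
    logits = predictor.predict(sentence)  # assumed shape: [[score_0, score_1]]
    scores = logits[0]
    return max(range(len(scores)), key=scores.__getitem__)  # argmax


# In the notebook:
# predictor = use_json(predictor)
# print(classify(predictor, "The cat sat on the mat."))
```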
The predicted class is 1, which is expected because the test sentence is grammatically correct.
Deploying the endpoint with Elastic Inference
Selecting the right instance type for inference requires deciding between different amounts of GPU, CPU, and memory resources. Optimizing for one of these resources on a standalone GPU instance usually leads to underutilization of other resources. Elastic Inference solves this problem by enabling you to attach the right amount of GPU-powered inference acceleration to your endpoint. In March 2020, Elastic Inference support for PyTorch became available for both Amazon SageMaker and Amazon EC2.
To use Elastic Inference, we must first convert our trained model to TorchScript. For more information, see Reduce ML inference costs on Amazon SageMaker for PyTorch models using Amazon Elastic Inference.
We first download the trained model artifacts from Amazon S3. The location of the model artifacts is estimator.model_data. We then convert the model to TorchScript using the following code:
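A sketch of the conversion follows. It assumes the artifacts were saved with save_pretrained, reloads them with torchscript=True (a pytorch-transformers configuration flag for TorchScript export), traces the model with dummy inputs of the training sequence length, and packages the result as traced_bert.tar.gz. The file names, the 64-token length, and the helper names are assumptions; S3Downloader comes from SageMaker Python SDK v2.

```python
import os
import tarfile

import torch


def download_model_artifacts(model_data: str, local_dir: str = "model") -> str:
    """Download model.tar.gz from S3 (estimator.model_data) and extract it."""
    from sagemaker.s3 import S3Downloader  # SageMaker Python SDK v2

    S3Downloader.download(model_data, local_dir)
    with tarfile.open(os.path.join(local_dir, "model.tar.gz")) as tar:
        tar.extractall(local_dir)  # only extract archives you trust
    return local_dir


def load_for_tracing(model_dir: str):
    """Reload the fine-tuned model with the TorchScript configuration flag set."""
    from pytorch_transformers import BertForSequenceClassification

    return BertForSequenceClassification.from_pretrained(model_dir, torchscript=True)


def trace_and_package(model, max_len: int, pt_path: str, tar_path: str):
    """Trace model with dummy inputs, save it, and wrap it in a .tar.gz archive."""
    model.eval()
    example_ids = torch.zeros(1, max_len, dtype=torch.long)
    example_mask = torch.ones(1, max_len, dtype=torch.long)
    # Inputs are passed positionally: check that this order matches the
    # forward() signature of your library version.
    traced = torch.jit.trace(model, (example_ids, example_mask))
    torch.jit.save(traced, pt_path)
    with tarfile.open(tar_path, "w:gz") as tar:
        tar.add(pt_path, arcname=os.path.basename(pt_path))
    return traced


# In the notebook:
# model_dir = download_model_artifacts(estimator.model_data)
# trace_and_package(load_for_tracing(model_dir), 64, "traced_bert.pt", "traced_bert.tar.gz")
```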
Loading the TorchScript model and using it for prediction requires small changes in our model loading and prediction functions. We create a new script, deploy_ei.py, that is slightly different from train_deploy.py. For model loading, we use torch.jit.load instead of the BertForSequenceClassification.from_pretrained call from before:
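A sketch of the revised model_fn follows. It assumes the traced model was packaged as traced_bert.pt at the root of the model archive and loads it onto the host CPU.

```python
import os

import torch


def model_fn(model_dir):
    """Load the TorchScript model that was packaged as traced_bert.pt."""
    model_path = os.path.join(model_dir, "traced_bert.pt")
    # Load onto the host CPU; with Elastic Inference, the accelerator is
    # selected when predictions run rather than at load time.
    return torch.jit.load(model_path, map_location=torch.device("cpu"))
```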
For prediction, we take advantage of torch.jit.optimized_execution for the final return statement:
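The sketch below assumes the Elastic Inference-enabled PyTorch build, whose torch.jit.optimized_execution accepts a second argument naming the target accelerator (eia:0 is the first attached accelerator), as described in the Elastic Inference documentation for PyTorch 1.3.1. Stock PyTorch accepts only the boolean, so the hypothetical USE_ELASTIC_INFERENCE flag keeps the sketch runnable outside an EI endpoint. Because a traced module takes positional arguments, input_ids and attention_mask must be passed in the order used during tracing.

```python
import torch

# Hypothetical switch: True only inside an Elastic Inference-enabled PyTorch container
USE_ELASTIC_INFERENCE = False


def predict_fn(input_data, model):
    """Run the traced model on (input_ids, attention_mask) and return the logits."""
    input_ids, attention_mask = input_data
    with torch.no_grad():
        if USE_ELASTIC_INFERENCE:
            # EI-enabled build: run the optimized graph on the first attached accelerator
            with torch.jit.optimized_execution(True, {"target_device": "eia:0"}):
                return model(input_ids, attention_mask)[0]
        # Stock PyTorch: optimized_execution takes only the boolean
        with torch.jit.optimized_execution(True):
            return model(input_ids, attention_mask)[0]
```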
The complete deploy_ei.py script is available in the GitHub repo. With this script, we can now deploy our model using Elastic Inference:
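A sketch of the deployment follows, assuming SageMaker Python SDK v2 and that the traced archive has already been uploaded to S3 (for example with Session.upload_data). The instance type ml.m5.large and accelerator type ml.eia2.xlarge are example choices, and the helper names are hypothetical.

```python
def ei_model_kwargs(model_data, role):
    """Illustrative PyTorchModel settings for the TorchScript artifacts."""
    return dict(
        model_data=model_data,         # S3 URI of traced_bert.tar.gz
        role=role,
        entry_point="deploy_ei.py",    # TorchScript-aware inference script
        source_dir="code",
        framework_version="1.3.1",     # example; must be an EI-supported version
        py_version="py3",
    )


def deploy_with_ei(model_data, role, instance_type="ml.m5.large", accelerator_type="ml.eia2.xlarge"):
    from sagemaker.pytorch import PyTorchModel

    model = PyTorchModel(**ei_model_kwargs(model_data, role))
    # accelerator_type attaches an Elastic Inference accelerator to the endpoint
    return model.deploy(
        initial_instance_count=1,
        instance_type=instance_type,
        accelerator_type=accelerator_type,
    )
```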
We attach the Elastic Inference accelerator to our endpoint by using the accelerator_type parameter of deploy().
Cleaning up resources
To avoid ongoing charges, remember to delete the Amazon SageMaker endpoints and the Amazon SageMaker notebook instance you created. See the following code:
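A minimal sketch follows; pass every predictor you created (for example, those for the standard and Elastic Inference endpoints). delete_endpoint removes each endpoint, while the notebook instance itself is stopped or deleted from the Amazon SageMaker console. clean_up is a hypothetical helper.

```python
def clean_up(*predictors):
    """Delete every endpoint created in this walkthrough to stop hosting charges."""
    for predictor in predictors:
        predictor.delete_endpoint()


# In the notebook (variable names are illustrative):
# clean_up(predictor, predictor_ei)
```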
In this post, we used Amazon SageMaker to take BERT as a starting point and train a model for labeling sentences on their grammatical completeness. We then deployed the model to an Amazon SageMaker endpoint, both with and without Elastic Inference acceleration. You can use this solution to tune BERT in other ways, or use other pretrained models provided by PyTorch-Transformers. For more about using PyTorch with Amazon SageMaker, see Using PyTorch with the SageMaker Python SDK.
[1] Yukun Zhu, Ryan Kiros, Rich Zemel, Ruslan Salakhutdinov, Raquel Urtasun, Antonio Torralba, and Sanja Fidler. 2015. Aligning books and movies: Towards story-like visual explanations by watching movies and reading books. In Proceedings of the IEEE International Conference on Computer Vision, pages 19–27.
About the Authors
Qingwei Li is a Machine Learning Specialist at Amazon Web Services. He received his Ph.D. in Operations Research after he broke his advisor’s research grant account and failed to deliver the Nobel Prize he promised. Currently he helps customers in financial service and insurance industry build machine learning solutions on AWS. In his spare time, he likes reading and teaching.
David Ping is a Principal Solutions Architect with the AWS Solutions Architecture organization. He works with our customers to build cloud and machine learning solutions using AWS. He lives in the NY metro area and enjoys learning the latest machine learning technologies.
Lauren Yu is a Software Development Engineer at Amazon SageMaker. She works primarily on the SageMaker Python SDK, as well as toolkits for integrating PyTorch, TensorFlow, and MXNet with Amazon SageMaker. In her spare time, she enjoys playing viola in the Amazon Symphony Orchestra and Doppler Quartet.