AWS Machine Learning Blog

Fine-tuning a PyTorch BERT model and deploying it with Amazon Elastic Inference on Amazon SageMaker

Text classification is a technique for putting text into different categories, and has a wide range of applications: email providers use text classification to detect spam emails, marketing agencies use it for sentiment analysis of customer reviews, and discussion forum moderators use it to detect inappropriate comments.

In the past, data scientists used methods such as tf-idf, word2vec, or bag-of-words (BOW) to generate features for training classification models. Although these techniques have been very successful in many natural language processing (NLP) tasks, they don’t always capture the meanings of words accurately when they appear in different contexts. More recently, interest has grown in using Bidirectional Encoder Representations from Transformers (BERT) for text classification, because it encodes the meaning of words in different contexts more accurately.

Amazon SageMaker is a fully managed service that provides developers and data scientists the ability to build, train, and deploy machine learning (ML) models quickly. Amazon SageMaker removes the heavy lifting from each step of the ML process to make it easier to develop high-quality models. The Amazon SageMaker Python SDK provides open-source APIs and containers that make it easy to train and deploy models in Amazon SageMaker with several different ML and deep learning frameworks.

Our customers often ask for quick fine-tuning and easy deployment of their NLP models. Furthermore, customers prefer low inference latency and low model inference cost. Amazon Elastic Inference enables attaching GPU-powered inference acceleration to endpoints, which reduces the cost of deep learning inference without sacrificing performance.

This post demonstrates how to use Amazon SageMaker to fine-tune a PyTorch BERT model and deploy it with Elastic Inference. The code from this post is available in the GitHub repo. For more information about BERT fine-tuning, see BERT Fine-Tuning Tutorial with PyTorch.

What is BERT?

First published in November 2018, BERT is a revolutionary model. It is pretrained on two tasks. In the first, one or more words in each sentence are intentionally masked; BERT takes these masked sentences as input and trains itself to predict the masked words. In the second, a next sentence prediction task, BERT pretrains text-pair representations.
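To make the masking step concrete, the following is a minimal sketch of how masked inputs and their prediction targets could be built. The function name, the whitespace tokenization, and the default masking probability are assumptions for illustration; the sketch always substitutes a [MASK] token, which is simpler than the replacement scheme used in the BERT paper.

```python
import random

MASK_TOKEN = "[MASK]"


def mask_tokens(tokens, mask_prob=0.15, seed=0):
    """Return (masked_tokens, targets) for a masked-word prediction task.

    targets maps each masked position to the original token that the
    model would be trained to predict.
    """
    rng = random.Random(seed)
    masked, targets = [], {}
    for position, token in enumerate(tokens):
        if rng.random() < mask_prob:
            masked.append(MASK_TOKEN)
            targets[position] = token
        else:
            masked.append(token)
    # Guarantee at least one masked position so there is something to predict.
    if not targets:
        position = rng.randrange(len(tokens))
        targets[position] = tokens[position]
        masked[position] = MASK_TOKEN
    return masked, targets


# Whitespace tokenization is a simplification; BERT uses WordPiece tokens.
tokens = "we yelled ourselves hoarse".split()
masked, targets = mask_tokens(tokens)
print(masked)
print(targets)
```

During pretraining, the model sees only the masked sequence and is scored on how well it recovers the tokens stored in targets.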

BERT is a substantial breakthrough and has helped researchers and data engineers across the industry achieve state-of-the-art results in many NLP tasks. BERT offers a representation of each word conditioned on its context (the rest of the sentence). For more information about BERT, see BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding.

BERT fine-tuning

One of the biggest challenges data scientists face for NLP projects is lack of training data; you often have only a few thousand pieces of human-labeled text data for your model training. However, modern deep learning NLP tasks require a large amount of labeled data. One way to solve this problem is to use transfer learning.

Transfer learning is an ML method where a pretrained model, such as a pretrained ResNet model for image classification, is reused as the starting point for a different but related problem. By reusing parameters from pretrained models, you can save significant amounts of training time and cost.

BERT was trained on BookCorpus and English Wikipedia data, which contain 800 million words and 2,500 million words, respectively [1]. Training BERT from scratch would be prohibitively expensive. By taking advantage of transfer learning, you can quickly fine-tune BERT for another use case with a relatively small amount of training data to achieve state-of-the-art results for common NLP tasks, such as text classification and question answering.

Solution overview

In this post, we walk through our dataset, the training process, and finally model deployment.

We use an Amazon SageMaker notebook instance for running the code. For more information about using Jupyter notebooks on Amazon SageMaker, see Using Amazon SageMaker Notebook Instances or Getting Started with Amazon SageMaker Studio.

The notebook and code from this post are available on GitHub. To run it yourself, clone the GitHub repository and open the Jupyter notebook file.

Problem and dataset

For this post, we use the Corpus of Linguistic Acceptability (CoLA), a dataset of 10,657 English sentences from published linguistics literature, each labeled as grammatical or ungrammatical. In our notebook, we download and unzip the data using the following code:

if not os.path.exists("./"):
    !curl -o ./
if not os.path.exists("./cola_public/"):

In the training data, the only two columns we need are the sentence itself and its label:

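df = pd.read_csv(
    train_file,  # placeholder for the path to the CoLA training file
    sep="\t",  # the CoLA files are tab-separated
    usecols=[1, 3],
    names=["label", "sentence"],
)
sentences = df.sentence.values
labels = df.label.values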

If we print a few sentences and their labels, we can see how sentences are labeled by grammatical acceptability (1 for acceptable, 0 for unacceptable). The output looks like the following:


["The professor talked us." "We yelled ourselves hoarse."
 "We yelled ourselves." "We yelled Harry hoarse."
 "Harry coughed himself into a fit."]
[0 1 0 0 1]
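The column selection can also be sketched with only the standard library, which shows what keeping columns 1 and 3 means in practice. This is an illustration, not the notebook's code: it assumes the CoLA layout of four tab-separated columns (source code, label, original annotation, sentence), and the source codes in the sample rows are made up.

```python
import csv
import io

# Hypothetical rows in a CoLA-style layout: source code, label,
# original annotation, sentence. The source codes are invented.
raw = (
    "src1\t0\t*\tThe professor talked us.\n"
    "src1\t1\t\tWe yelled ourselves hoarse.\n"
    "src1\t0\t*\tWe yelled ourselves.\n"
)


def read_label_and_sentence(handle):
    """Keep column 1 (label) and column 3 (sentence) from tab-separated rows."""
    # QUOTE_NONE keeps any quotation marks inside sentences as literal text.
    reader = csv.reader(handle, delimiter="\t", quoting=csv.QUOTE_NONE)
    sentences, labels = [], []
    for row in reader:
        labels.append(int(row[1]))
        sentences.append(row[3])
    return sentences, labels


sentences, labels = read_label_and_sentence(io.StringIO(raw))
print(sentences)
print(labels)
```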

We then split the dataset for training and testing before uploading both to Amazon S3 for use later. The SageMaker Python SDK provides a helpful function for uploading to Amazon S3:

from sagemaker.session import Session
from sklearn.model_selection import train_test_split

train, test = train_test_split(df)
train.to_csv("./cola_public/train.csv", index=False)
test.to_csv("./cola_public/test.csv", index=False)

session = Session()
inputs_train = session.upload_data("./cola_public/train.csv", key_prefix="sagemaker-bert/training/data")
inputs_test = session.upload_data("./cola_public/test.csv", key_prefix="sagemaker-bert/testing/data")
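Because train_test_split is called with its defaults, a quarter of the rows is held out for testing. The following sketch illustrates that idea with the standard library only; the function name and seed are assumptions, and it does not reproduce scikit-learn's exact shuffling.

```python
import math
import random


def simple_train_test_split(rows, test_size=0.25, seed=42):
    """Shuffle rows and hold out a fraction of them for testing."""
    shuffled = list(rows)
    random.Random(seed).shuffle(shuffled)
    n_test = math.ceil(len(shuffled) * test_size)
    return shuffled[n_test:], shuffled[:n_test]


# Hypothetical row identifiers standing in for DataFrame rows.
rows = list(range(8))
train_rows, test_rows = simple_train_test_split(rows)
print(len(train_rows), len(test_rows))
```

Fixing the seed makes the split reproducible, which helps when comparing training runs.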

Training script

For this post, we use the PyTorch-Transformers library, which contains PyTorch implementations and pretrained model weights for many NLP models, including BERT. See the following code:

model = BertForSequenceClassification.from_pretrained(
    "bert-base-uncased",  # Use the 12-layer BERT model, with an uncased vocab.
    num_labels=2,  # The number of output labels--2 for binary classification.
    output_attentions=False,  # Whether the model returns attentions weights.
    output_hidden_states=False,  # Whether the model returns all hidden-states.
)

Our training script should save model artifacts learned during training to a file path called model_dir, as stipulated by the Amazon SageMaker PyTorch image. Upon completion of training, Amazon SageMaker uploads model artifacts saved in model_dir to Amazon S3 so they are available for deployment. The following code is used in the script to save trained model artifacts:

model_2_save = model.module if hasattr(model, "module") else model
model_2_save.save_pretrained(save_directory=model_dir)
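The hasattr check matters when training is distributed: PyTorch wrappers such as torch.nn.DataParallel and DistributedDataParallel keep the original model in a module attribute, and it is that inner model we want to save. The following sketch shows the pattern with stand-in classes, so TinyModel, ParallelWrapper, and unwrap are hypothetical names rather than part of the training script.

```python
class TinyModel:
    """Stand-in for the model being fine-tuned."""


class ParallelWrapper:
    """Stand-in for a wrapper that stores the real model as .module."""

    def __init__(self, module):
        self.module = module


def unwrap(model):
    """Return the underlying model whether or not it has been wrapped."""
    return model.module if hasattr(model, "module") else model


plain = TinyModel()
wrapped = ParallelWrapper(plain)
print(unwrap(plain) is plain)
print(unwrap(wrapped) is plain)
```

Saving the unwrapped model keeps the artifact independent of how many devices were used for training.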

We save this script in a file and put the file in a directory named code/, where the full training script is viewable.

Because PyTorch-Transformers isn’t included natively in Amazon SageMaker PyTorch images, we have to provide a requirements.txt file so that Amazon SageMaker installs this library for training and inference. A requirements.txt file is a text file that lists the packages to install with pip install. You can also pin the version of each package. To install PyTorch-Transformers, we add a line for it to the requirements.txt file.


You can view the entire file in the GitHub repo, and it also goes into the code/ directory. For more information about the format of a requirements.txt file, see Requirements Files.

Training on Amazon SageMaker

We use Amazon SageMaker to train and deploy a model using our custom PyTorch code. The Amazon SageMaker Python SDK makes it easier to run a PyTorch script in Amazon SageMaker using its PyTorch estimator. After that, we can use the SageMaker Python SDK to deploy the trained model and run predictions. For more information about using this SDK with PyTorch, see Using PyTorch with the SageMaker Python SDK.

To start, we use the PyTorch estimator class to train our model. When creating the estimator, we make sure to specify the following:

  • entry_point – The name of the PyTorch script
  • source_dir – The location of the training script and requirements.txt file
  • framework_version – The PyTorch version we want to use

The PyTorch estimator supports multi-machine, distributed PyTorch training. To use this, we set train_instance_count to a value greater than 1. Our training script supports distributed training only on GPU instances.

After creating the estimator, we call fit(), which launches a training job. We use the Amazon S3 URIs we uploaded the training data to earlier. See the following code:

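from sagemaker.pytorch import PyTorch

estimator = PyTorch(
    # entry_point, source_dir, framework_version, and the instance
    # settings described above are also passed here
    hyperparameters={
        "epochs": 1,
        "num_labels": 2,
        "backend": "gloo",
    },
)
estimator.fit({"training": inputs_train, "testing": inputs_test})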

After training starts, Amazon SageMaker displays training progress, as shown in the following output. Epochs, training loss, and accuracy on test data are reported:

2020-06-10 01:00:41 Starting - Starting the training job...
2020-06-10 01:00:44 Starting - Launching requested ML instances......
2020-06-10 01:02:04 Starting - Preparing the instances for training............
2020-06-10 01:03:48 Downloading - Downloading input data...
2020-06-10 01:04:15 Training - Downloading the training image..
2020-06-10 01:05:03 Training - Training image download completed. Training in progress.
Train Epoch: 1 [0/3207 (0%)] Loss: 0.626472
Train Epoch: 1 [350/3207 (98%)] Loss: 0.241283
Average training loss: 0.5248292144022736
Test set: Accuracy: 0.782608695652174

We can monitor the training progress and make sure it succeeds before proceeding with the rest of the notebook.
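If we want to track these numbers programmatically rather than by eye, the log lines can be parsed with regular expressions. The following is a sketch that assumes the log format shown above; the pattern and function names are our own and are not part of the SageMaker SDK.

```python
import re

# Sample lines copied from the training output above.
log_lines = [
    "Train Epoch: 1 [0/3207 (0%)] Loss: 0.626472",
    "Train Epoch: 1 [350/3207 (98%)] Loss: 0.241283",
    "Average training loss: 0.5248292144022736",
    "Test set: Accuracy: 0.782608695652174",
]

LOSS_RE = re.compile(
    r"^Train Epoch: (\d+) \[(\d+)/(\d+) \((\d+)%\)\] Loss: ([\d.]+)$"
)
ACCURACY_RE = re.compile(r"^Test set: Accuracy: ([\d.]+)$")


def parse_training_log(lines):
    """Collect (epoch, loss) pairs and the last reported test accuracy."""
    losses, accuracy = [], None
    for line in lines:
        match = LOSS_RE.match(line)
        if match:
            losses.append((int(match.group(1)), float(match.group(5))))
            continue
        match = ACCURACY_RE.match(line)
        if match:
            accuracy = float(match.group(1))
    return losses, accuracy


losses, accuracy = parse_training_log(log_lines)
print(losses)
print(accuracy)
```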

Deployment script

After training our model, we host it on an Amazon SageMaker endpoint by calling deploy on the PyTorch estimator. The endpoint runs an Amazon SageMaker PyTorch model server. We need to configure two components of the server: model loading and model serving. We implement these two components in our inference script. The complete file is available in the GitHub repo.

model_fn() is the function defined to load the saved model and return a model object that can be used for model serving. The SageMaker PyTorch model server loads our model by invoking model_fn:

def model_fn(model_dir):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = BertForSequenceClassification.from_pretrained(model_dir)
    return model.to(device)

input_fn() deserializes and prepares the prediction input. In this use case, our request body is first serialized to JSON and then sent to the model serving endpoint. Therefore, in input_fn(), we first deserialize the JSON-formatted request body and return the input as a torch.tensor, as required for BERT:

def input_fn(request_body, request_content_type):
    if request_content_type == "application/json":
        sentence = json.loads(request_body)

        # tokenize the sentence and add the special tokens
        encoded_sent = tokenizer.encode(sentence, add_special_tokens=True)
        input_ids = [encoded_sent]

        # pad shorter sentences
        input_ids_padded = []
        for i in input_ids:
            while len(i) < MAX_LEN:
                i.append(0)
            input_ids_padded.append(i)
        input_ids = input_ids_padded

        # mask; 0: added, 1: otherwise
        attention_masks = [
            [int(token_id > 0) for token_id in sent] for sent in input_ids
        ]

        # convert to PyTorch data types.
        train_inputs = torch.tensor(input_ids)
        train_masks = torch.tensor(attention_masks)
        return train_inputs, train_masks
    raise ValueError("Unsupported content type: {}".format(request_content_type))
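The padding and masking logic is easier to follow in isolation. The following sketch pads one sequence of token IDs to a fixed length and builds the matching attention mask. The helper name and the sample IDs are hypothetical, and unlike input_fn() it derives the mask from the sequence length and truncates overlong input by cutting the tail.

```python
def pad_and_mask(token_ids, max_len, pad_id=0):
    """Pad (or truncate) one sequence of token IDs and build its attention mask."""
    ids = list(token_ids)[:max_len]
    mask = [1] * len(ids)  # 1 marks a real token
    padding = max_len - len(ids)
    # 0 marks a padding position that the model should ignore.
    return ids + [pad_id] * padding, mask + [0] * padding


# Hypothetical token IDs; real IDs come from the tokenizer.
ids, mask = pad_and_mask([101, 2057, 102], max_len=6)
print(ids)   # [101, 2057, 102, 0, 0, 0]
print(mask)  # [1, 1, 1, 0, 0, 0]
```

Building the mask from the length avoids relying on the padding ID being the only ID equal to 0. Cutting the tail would drop the closing special token, so production code would usually leave truncation to the tokenizer.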

predict_fn() performs the prediction and returns the result. See the following code:

def predict_fn(input_data, model):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()
    input_id, input_mask = input_data
    input_id = input_id.to(device)
    input_mask = input_mask.to(device)
    with torch.no_grad():
        return model(input_id, token_type_ids=None, attention_mask=input_mask)[0]

We take advantage of the prebuilt Amazon SageMaker PyTorch image’s default support for serializing the prediction result.

Deploying the endpoint

To deploy our endpoint, we call deploy() on our PyTorch estimator object, passing in our desired number of instances and instance type:

predictor = estimator.deploy(initial_instance_count=1, instance_type="ml.m4.xlarge")

We then configure the predictor to use "application/json" for the content type when sending requests to our endpoint:

from sagemaker.predictor import json_deserializer, json_serializer

predictor.content_type = "application/json"
predictor.accept = "application/json"
predictor.serializer = json_serializer
predictor.deserializer = json_deserializer
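To see how the JSON content type connects the client to input_fn(), the following sketch reproduces the round trip with the standard library. The two helper names are hypothetical and only stand in for what the serializer and the first step of input_fn() do; this is not the SDK's implementation.

```python
import json


def serialize_request(sentence):
    """Client side: turn the sentence into a JSON request body."""
    return json.dumps(sentence)


def deserialize_request(request_body, content_type):
    """Server side: mirror of the first step of input_fn()."""
    if content_type != "application/json":
        raise ValueError("Unsupported content type: {}".format(content_type))
    return json.loads(request_body)


body = serialize_request("Somebody just left - guess who.")
sentence = deserialize_request(body, "application/json")
print(body)
print(sentence)
```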

Finally, we use the returned predictor object to call the endpoint:

result = predictor.predict("Somebody just left - guess who.")
print(np.argmax(result, axis=1))


The predicted class is 1, which is expected because the test sentence is a grammatically correct sentence.
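The endpoint returns one row of raw scores (logits) per sentence, and np.argmax picks the index of the largest score. The following sketch shows the same step without NumPy and adds a softmax to turn the scores into probabilities. The logit values are made up for illustration and are not output from the model.

```python
import math


def softmax(logits):
    """Convert raw scores into probabilities that sum to 1."""
    largest = max(logits)
    exps = [math.exp(x - largest) for x in logits]  # shift for stability
    total = sum(exps)
    return [e / total for e in exps]


def predicted_classes(batch_logits):
    """Index of the largest score in each row, like np.argmax(..., axis=1)."""
    return [max(range(len(row)), key=row.__getitem__) for row in batch_logits]


# Hypothetical logits for one sentence: [score for class 0, score for class 1].
result = [[-1.2, 2.3]]
classes = predicted_classes(result)
probabilities = [softmax(row) for row in result]
print(classes)
print(probabilities)
```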

Deploying the endpoint with Elastic Inference

Selecting the right instance type for inference requires deciding between different amounts of GPU, CPU, and memory resources. Optimizing for one of these resources on a standalone GPU instance usually leads to underutilization of other resources. Elastic Inference solves this problem by enabling you to attach the right amount of GPU-powered inference acceleration to your endpoint. In March 2020, Elastic Inference support for PyTorch became available for both Amazon SageMaker and Amazon EC2.

To use Elastic Inference, we must first convert our trained model to TorchScript. For more information, see Reduce ML inference costs on Amazon SageMaker for PyTorch models using Amazon Elastic Inference.

We first download the trained model artifacts from Amazon S3. The location of the model artifacts is estimator.model_data. We then convert the model to TorchScript using the following code:

model_torchScript = BertForSequenceClassification.from_pretrained("model/", torchscript=True)
device = "cpu"
# dummy inputs of length 64, used only to trace the model
for_jit_trace_input_ids = [0] * 64
for_jit_trace_attention_masks = [0] * 64
for_jit_trace_input = torch.tensor([for_jit_trace_input_ids])
for_jit_trace_masks = torch.tensor([for_jit_trace_attention_masks])

traced_model = torch.jit.trace(
    model_torchScript, (for_jit_trace_input.to(device), for_jit_trace_masks.to(device))
)
torch.jit.save(traced_model, "")
subprocess.call(["tar", "-czvf", "traced_bert.tar.gz", ""])

Loading the TorchScript model and using it for prediction requires small changes in our model loading and prediction functions. We create a new inference script that differs only slightly from the previous one.

For model loading, we use torch.jit.load instead of the BertForSequenceClassification.from_pretrained call from before:

loaded_model = torch.jit.load(os.path.join(model_dir, ""))

For prediction, we take advantage of torch.jit.optimized_execution for the final return statement:

with torch.no_grad():
    with torch.jit.optimized_execution(True, {"target_device": "eia:0"}):
        return model(input_id,attention_mask=input_mask)[0]

The entire script is available in the GitHub repo. With this script, we can now deploy our model using Elastic Inference:

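predictor = pytorch.deploy(
    # the instance count and instance type arguments are also passed here
    accelerator_type="ml.eia2.xlarge",
)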

We attach the Elastic Inference accelerator to our endpoint by using the accelerator_type="ml.eia2.xlarge" parameter.

Cleaning up resources

Remember to delete the Amazon SageMaker endpoint and the Amazon SageMaker notebook instance you created to avoid charges. The endpoint can be deleted with the following code:

predictor.delete_endpoint()



Conclusion

In this post, we used Amazon SageMaker to take BERT as a starting point and train a model that labels sentences by their grammatical acceptability. We then deployed the model to an Amazon SageMaker endpoint, both with and without Elastic Inference acceleration. You can use this solution to tune BERT in other ways, or use other pretrained models provided by PyTorch-Transformers. For more about using PyTorch with Amazon SageMaker, see Using PyTorch with the SageMaker Python SDK.


References

[1] Yukun Zhu, Ryan Kiros, Rich Zemel, Ruslan Salakhutdinov, Raquel Urtasun, Antonio Torralba, and Sanja Fidler. 2015. Aligning books and movies: Towards story-like visual explanations by watching movies and reading books. In Proceedings of the IEEE International Conference on Computer Vision, pages 19–27.

About the Authors

Qingwei Li is a Machine Learning Specialist at Amazon Web Services. He received his Ph.D. in Operations Research after he broke his advisor’s research grant account and failed to deliver the Nobel Prize he promised. Currently he helps customers in financial service and insurance industry build machine learning solutions on AWS. In his spare time, he likes reading and teaching.




David Ping is a Principal Solutions Architect with the AWS Solutions Architecture organization. He works with our customers to build cloud and machine learning solutions using AWS. He lives in the NY metro area and enjoys learning the latest machine learning technologies.




Lauren Yu is a Software Development Engineer at Amazon SageMaker. She works primarily on the SageMaker Python SDK, as well as toolkits for integrating PyTorch, TensorFlow, and MXNet with Amazon SageMaker. In her spare time, she enjoys playing viola in the Amazon Symphony Orchestra and Doppler Quartet.