Train your ML Models with AWS Trainium and Amazon SageMaker



In this tutorial, you’ll learn how to use Amazon SageMaker to train a machine learning (ML) model on AWS Trainium instances. Amazon EC2 Trn1 instances, powered by AWS Trainium accelerators, are purpose-built for high-performance deep learning (DL) training while offering up to 50% cost-to-train savings over comparable Amazon EC2 instances. Amazon SageMaker is a fully managed service that provides every developer and data scientist the ability to build, train, and deploy ML models at scale.

For this tutorial, you’ll use Amazon SageMaker Studio, an integrated development environment (IDE) for ML that provides a fully managed Jupyter notebook interface, to build and run the training job on AWS Trainium. You will build a BERT-based sentiment analysis model using the IMDB dataset. BERT is a transformer model pretrained on a large corpus of English text in a self-supervised fashion. It is primarily intended to be fine-tuned on tasks that use the whole sentence (potentially masked) to make decisions, such as sequence classification, token classification, and question answering. The dataset consists of IMDB movie reviews, each labeled as positive or negative. The goal of this exercise is to build a model that predicts whether a given review is positive or negative.

What you will accomplish

In this guide, you will:

  • Train a BERT model for text classification using Amazon SageMaker Training on an AWS Trainium (Trn1) instance


Before starting this guide, you will need:

  • An AWS account that you are logged into

 Time to complete

15 minutes

 Cost to complete

See Amazon SageMaker pricing to estimate cost for this tutorial.

 Services used

Amazon SageMaker Training

 Last updated

May 02, 2023

Step 1: Set up a SageMaker Studio domain

An AWS account can have only one SageMaker Studio domain per AWS Region. If you already have a SageMaker Studio domain in the US East (N. Virginia) Region, follow the SageMaker Studio setup guide to attach the required AWS IAM policies to your SageMaker Studio account, then skip the rest of Step 1 and proceed directly to Step 2 to set up a SageMaker Studio notebook.

If you don't have an existing SageMaker Studio domain, continue with Step 1 to run an AWS CloudFormation template that creates a SageMaker Studio domain and adds the permissions required for the rest of this tutorial.
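If you are not sure whether the account already has a domain in US East (N. Virginia), you can check from any Python environment that has AWS credentials. The following is a minimal sketch, assuming a boto3 SageMaker client created for us-east-1; list_studio_domains is a hypothetical helper that pages through the SageMaker ListDomains API.

```python
from typing import Any, Dict, List


def list_studio_domains(sagemaker_client: Any) -> List[Dict[str, str]]:
    """Return name, ID, and status for every SageMaker domain in the client's Region.

    sagemaker_client is assumed to behave like boto3.client("sagemaker"):
    list_domains() returns {"Domains": [...]} plus a "NextToken" when more pages exist.
    """
    domains: List[Dict[str, str]] = []
    kwargs: Dict[str, Any] = {}
    while True:
        response = sagemaker_client.list_domains(**kwargs)
        for domain in response.get("Domains", []):
            domains.append(
                {
                    "name": domain["DomainName"],
                    "id": domain["DomainId"],
                    "status": domain["Status"],
                }
            )
        next_token = response.get("NextToken")
        if not next_token:
            return domains
        # Request the next page of results
        kwargs = {"NextToken": next_token}


# Example (requires boto3 and AWS credentials):
# import boto3
# print(list_studio_domains(boto3.client("sagemaker", region_name="us-east-1")))
```

An empty list means there is no domain in that Region yet, so the CloudFormation path below applies.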

Choose the AWS CloudFormation stack link. This link opens the AWS CloudFormation console and creates your SageMaker Studio domain and a user named studio-user. It also adds the required permissions to your SageMaker Studio account. In the CloudFormation console, confirm that US East (N. Virginia) is the Region displayed in the upper-right corner. The stack name should be CFN-SM-IM-Lambda-catalog and should not be changed. This stack takes about 10 minutes to create all the resources.

This stack assumes that you already have a public VPC set up in your account. If you do not have a public VPC, see VPC with a single public subnet to learn how to create a public VPC.

Select I acknowledge that AWS CloudFormation might create IAM resources, and then choose Create stack.

On the CloudFormation pane, choose Stacks. When the stack is created, the status of the stack should change from CREATE_IN_PROGRESS to CREATE_COMPLETE.
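If you prefer to watch the stack from code instead of the console, the sketch below polls the CloudFormation DescribeStacks API until the stack reaches CREATE_COMPLETE. It assumes a boto3 CloudFormation client; wait_for_stack and its polling interval and timeout defaults are illustrative choices, not part of the tutorial.

```python
import time
from typing import Any

# CREATE_COMPLETE means success; the other statuses mean creation failed
# and CloudFormation is rolling back (or has finished rolling back).
SUCCESS = "CREATE_COMPLETE"
FAILURES = {"CREATE_FAILED", "ROLLBACK_IN_PROGRESS", "ROLLBACK_COMPLETE", "ROLLBACK_FAILED"}


def wait_for_stack(cfn_client: Any, stack_name: str,
                   poll_seconds: float = 30.0, timeout_seconds: float = 1800.0) -> str:
    """Poll describe_stacks until the stack is created, fails, or the timeout expires."""
    deadline = time.monotonic() + timeout_seconds
    while True:
        stack = cfn_client.describe_stacks(StackName=stack_name)["Stacks"][0]
        status = stack["StackStatus"]
        if status == SUCCESS:
            return status
        if status in FAILURES:
            raise RuntimeError(f"Stack {stack_name} failed with status {status}")
        if time.monotonic() >= deadline:
            raise TimeoutError(f"Stack {stack_name} still {status} after {timeout_seconds}s")
        time.sleep(poll_seconds)


# Example (requires boto3 and AWS credentials):
# import boto3
# cfn = boto3.client("cloudformation", region_name="us-east-1")
# print(wait_for_stack(cfn, "CFN-SM-IM-Lambda-catalog"))
```

boto3 also provides a built-in waiter, cfn.get_waiter("stack_create_complete").wait(StackName=...), which does the same job if you do not need custom failure handling.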

Step 2: Set up a SageMaker Studio notebook

In this step, you'll launch a new SageMaker Studio notebook, install the necessary open source libraries, and set up the SageMaker variables required to interact with other services, including Amazon Simple Storage Service (Amazon S3).

Enter “SageMaker Studio” into the console search bar, and then choose SageMaker Studio.

Choose US East (N. Virginia) from the Region dropdown list in the upper-right corner of the SageMaker console. For Launch app, select Studio to open SageMaker Studio using the studio-user profile.

Open the SageMaker Studio interface. On the navigation bar, choose File, New, Notebook.  

In the ‘Set up notebook environment’ dialog box, under Image, select Data Science 2.0. The Python 3 kernel is selected automatically. Choose Select.

The kernel in the top-right corner of the notebook should now display Python 3 (Data Science 2.0).

Step 3: Prepare the data

In this step, you use your Amazon SageMaker Studio notebook to preprocess the data that you need to train your machine learning model and then upload the data to Amazon S3.

We will use the IMDB dataset from Hugging Face to train the model. To download the dataset, you need to install the datasets and transformers libraries. To install specific versions of these libraries, copy and paste the following code snippet into a cell in the notebook, and press Shift+Enter to run the current cell. Ignore any warnings to restart the kernel or any dependency conflict errors.

!pip install transformers==4.21.3 datasets==2.5.2
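To confirm that the pinned versions were installed, you can read the installed package metadata. This is a minimal sketch using only the standard library; version_mismatches is a hypothetical helper. It reports what pip installed on disk, which can differ from what an already-running kernel has imported until the kernel is restarted.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Dict, Optional

# Pins used by the pip install cell above
EXPECTED = {"transformers": "4.21.3", "datasets": "2.5.2"}


def version_mismatches(expected: Dict[str, str]) -> Dict[str, Optional[str]]:
    """Return {package: installed_version} for every package that differs from its pin.

    An installed_version of None means the package is not installed at all.
    """
    mismatches: Dict[str, Optional[str]] = {}
    for package, pinned in expected.items():
        try:
            installed: Optional[str] = version(package)
        except PackageNotFoundError:
            installed = None
        if installed != pinned:
            mismatches[package] = installed
    return mismatches


print(version_mismatches(EXPECTED) or "All pinned versions are installed")
```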

Add the following imports to a notebook cell:

from datasets import load_dataset
from datasets.filesystems import S3FileSystem
from tqdm.auto import tqdm

Now use the load_dataset function to download the data from Hugging Face datasets and split it into train and test sets. Copy and paste the following code snippet into a notebook cell.

dataset = load_dataset("imdb",split="train",ignore_verifications=True)
dataset = dataset.train_test_split()
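Called with no arguments, train_test_split shuffles the rows and holds out 25% of them as the test split, so IMDB's 25,000-review train split becomes 18,750 training and 6,250 test reviews. The plain-Python sketch below illustrates that behaviour; shuffled_split is a hypothetical helper, not the datasets implementation, and its shuffle order differs. In the notebook itself, passing seed= to train_test_split makes the split reproducible.

```python
import math
import random
from typing import List, Optional, Sequence, Tuple, TypeVar

T = TypeVar("T")


def shuffled_split(rows: Sequence[T], test_size: float = 0.25,
                   seed: Optional[int] = None) -> Tuple[List[T], List[T]]:
    """Shuffle row indices, then hold out ceil(test_size * n) rows as the test split."""
    indices = list(range(len(rows)))
    random.Random(seed).shuffle(indices)
    n_test = math.ceil(test_size * len(rows))
    test_idx, train_idx = indices[:n_test], indices[n_test:]
    return [rows[i] for i in train_idx], [rows[i] for i in test_idx]


train_rows, test_rows = shuffled_split(list(range(25_000)), seed=42)
print(len(train_rows), len(test_rows))  # 18750 6250
```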

Finally, upload the dataset to Amazon S3. You can specify an S3 bucket to use; if you don't, the code uses the SageMaker default bucket. Copy and paste the following code into a notebook cell and run it. The cell output should print the role, bucket, and Region.

import sagemaker
from sagemaker.pytorch import PyTorch

# Set this to the name of an existing bucket to use it instead of the default bucket
sagemaker_session_bucket = None

sess = sagemaker.Session()
# SageMaker session bucket -> used for uploading data, models, and logs
# SageMaker automatically creates this bucket if it does not exist
if sagemaker_session_bucket is None and sess is not None:
    # set to the default bucket if a bucket name is not given
    sagemaker_session_bucket = sess.default_bucket()

role = sagemaker.get_execution_role()

sess = sagemaker.Session(default_bucket=sagemaker_session_bucket)

print(f"sagemaker role arn: {role}")
print(f"sagemaker bucket: {sess.default_bucket()}")
print(f"sagemaker session region: {sess.boto_region_name}")

Now that you have a bucket, use the Hugging Face datasets API to upload the dataset to S3. Copy and paste the following code to upload the dataset.

s3 = S3FileSystem()
s3_prefix = 'HFDatasets/imdb'

# save the train and test splits to S3
training_input_path = f's3://{sagemaker_session_bucket}/{s3_prefix}'
dataset.save_to_disk(training_input_path, fs=s3)
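After the upload, you can check that the expected objects exist under training_input_path. DatasetDict.save_to_disk writes a dataset_dict.json file plus one folder per split; the sketch below assumes that layout and an fsspec-style filesystem with an exists() method, such as the s3 object created above. missing_dataset_parts is a hypothetical helper.

```python
from typing import Any, Iterable, List


def missing_dataset_parts(fs: Any, path: str,
                          splits: Iterable[str] = ("train", "test")) -> List[str]:
    """List the expected DatasetDict.save_to_disk entries that are missing under path."""
    path = path.rstrip("/")
    # dataset_dict.json marks a saved DatasetDict; each split is saved in its own folder
    expected = [f"{path}/dataset_dict.json"] + [f"{path}/{split}" for split in splits]
    return [entry for entry in expected if not fs.exists(entry)]


# Example in the notebook:
# print(missing_dataset_parts(s3, training_input_path) or "Dataset uploaded")
```

The example call prints Dataset uploaded when nothing is missing, and otherwise lists the paths that were not found.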

Step 4: Build the training script

With SageMaker, you can bring your own training logic in Python scripts. By encapsulating training logic in a script, you can incorporate custom training routines and model configurations while still using common ML framework containers such as PyTorch. In this tutorial, you will prepare a training script that uses the BERT model from the Hugging Face Transformers library to train a text classification model on the IMDB data that you uploaded in the previous step.

The first level of script mode is the ability to define your own training process in a self-contained, customized Python script and to use that script as the entry point when defining your SageMaker estimator. Copy and paste the following code block to write a Python script encapsulating the model training logic.


%%writefile train.py
# train.py: SageMaker training entry point (the file name must match entry_point in Step 5)
import argparse
import os
import random

import evaluate
import torch
import torch.distributed as dist
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_backend  # registers the "xla" process-group backend
from datasets import load_from_disk
from torch.optim import AdamW
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from tqdm.auto import tqdm
from transformers import AutoModelForSequenceClassification, AutoTokenizer, set_seed

# Initialize XLA process group for torchrun
dist.init_process_group("xla")

device = "xla"


def parse_args():
    parser = argparse.ArgumentParser(description="Finetune a transformers model on a text classification task")
    parser.add_argument(
        "--train_data", type=str, default=os.environ["SM_CHANNEL_TRAIN"])
    parser.add_argument(
        "--model_name_or_path",
        type=str,
        required=True,
        help="Path to pretrained model or model identifier from the Hugging Face Hub.",
    )
    parser.add_argument(
        "--max_length", type=int, default=128,
        help="Maximum tokenized sequence length; longer reviews are truncated.",
    )
    parser.add_argument(
        "--per_device_train_batch_size", type=int, default=8,
        help="Batch size (per device) for the training dataloader.",
    )
    parser.add_argument(
        "--per_device_eval_batch_size", type=int, default=8,
        help="Batch size (per device) for the evaluation dataloader.",
    )
    parser.add_argument(
        "--learning_rate", type=float, default=5e-5,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument("--num_train_epochs", type=int, default=2, help="Total number of training epochs to perform.")
    parser.add_argument(
        "--max_train_steps", type=int, default=None,
        help="Total number of training steps to perform. If provided, training stops after this many steps even if num_train_epochs has not finished.",
    )
    parser.add_argument("--output_dir", type=str, default=os.environ["SM_MODEL_DIR"], help="Where to store the final model.")
    parser.add_argument("--seed", type=int, default=100, help="A seed for reproducible training.")
    args = parser.parse_args()

    # Sanity checks
    if args.train_data is None:
        raise ValueError("Need a training file.")

    # torchrun sets these environment variables for every worker process
    args.local_rank = int(os.environ["LOCAL_RANK"])
    args.world_rank = int(os.environ["RANK"])
    args.world_size = int(os.environ["WORLD_SIZE"])

    print("Local rank {} , World Rank {} , World Size {}".format(args.local_rank, args.world_rank, args.world_size))

    return args


def gather(tensor, name="gather tensor"):
    # Collect the tensor from every worker and concatenate the copies (returned on CPU)
    return xm.mesh_reduce(name, tensor, torch.cat)


def main():

    # Retrieve args passed to the training script
    args = parse_args()
    set_seed(args.seed)

    dataset = load_from_disk(args.train_data)

    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)

    # tokenizer helper function: pad/truncate every review to exactly max_length tokens
    def tokenize(batch):
        return tokenizer(batch['text'], max_length=args.max_length, padding='max_length', truncation=True)

    # load dataset; the same seed on every worker keeps one consistent order for the samplers
    train_dataset = dataset['train'].shuffle(seed=args.seed)
    eval_dataset = dataset['test'].shuffle(seed=args.seed)

    # tokenize dataset
    train_dataset = train_dataset.map(tokenize, batched=True)
    eval_dataset = eval_dataset.map(tokenize, batched=True)

    # set format for pytorch
    train_dataset = train_dataset.rename_column("label", "labels")
    train_dataset.set_format('torch', columns=['input_ids', 'attention_mask', 'labels'])

    # Log a few random samples from the training set
    if args.world_rank == 0:
        for index in random.sample(range(len(train_dataset)), 3):
            print(f"Sample {index} of the training set: {train_dataset[index]}.")

    eval_dataset = eval_dataset.rename_column("label", "labels")
    eval_dataset.set_format('torch', columns=['input_ids', 'attention_mask', 'labels'])

    if args.world_rank == 0:
        for index in random.sample(range(len(eval_dataset)), 3):
            print(f"Sample {index} of the evaluation set: {eval_dataset[index]}.")

    # Set up distributed data loaders: each worker reads its own shard of the data
    train_sampler = None
    if args.world_size > 1:  # if more than one core
        train_sampler = DistributedSampler(
            train_dataset,
            num_replicas=args.world_size,
            rank=args.world_rank,
            shuffle=True,
        )
    train_loader = DataLoader(
        train_dataset,
        batch_size=args.per_device_train_batch_size,
        sampler=train_sampler,
        shuffle=False if train_sampler else True,
    )

    eval_sampler = None
    if args.world_size > 1:  # if more than one core
        eval_sampler = DistributedSampler(
            eval_dataset,
            num_replicas=args.world_size,
            rank=args.world_rank,
            shuffle=True,
        )
    eval_loader = DataLoader(
        eval_dataset,
        batch_size=args.per_device_eval_batch_size,
        sampler=eval_sampler,
        shuffle=False if eval_sampler else True,
    )

    # MpDeviceLoader moves batches to the XLA device and overlaps loading with execution
    train_device_loader = pl.MpDeviceLoader(train_loader, device)
    eval_device_loader = pl.MpDeviceLoader(eval_loader, device)

    # max_train_steps, when given, caps the total number of optimizer steps
    num_training_steps = args.num_train_epochs * len(train_loader)
    if args.max_train_steps is not None:
        num_training_steps = min(num_training_steps, args.max_train_steps)
    progress_bar = tqdm(range(num_training_steps), disable=args.world_rank != 0)

    model = AutoModelForSequenceClassification.from_pretrained(args.model_name_or_path)
    model.to(device)
    optimizer = AdamW(model.parameters(), lr=args.learning_rate)

    # Get the metric function
    metric = evaluate.load("accuracy")
    completed_steps = 0
    for epoch in range(args.num_train_epochs):
        if train_sampler is not None:
            train_sampler.set_epoch(epoch)  # reshuffle the shards each epoch
        model.train()
        for batch in train_device_loader:
            batch = {k: v.to(device) for k, v in batch.items()}
            optimizer.zero_grad()
            outputs = model(**batch)
            loss = outputs.loss
            loss.backward()
            xm.optimizer_step(optimizer)  # gather gradient updates from all cores and apply them
            completed_steps += 1
            progress_bar.update(1)
            if completed_steps >= num_training_steps:
                break
        if args.world_rank == 0:
            print("Epoch {}, rank {}, Loss {:0.4f}".format(epoch, args.world_rank, loss.detach().to("cpu")))

        # Run evaluation after each epoch
        if args.world_rank == 0:
            print("Running evaluation for the model")
        model.eval()
        for eval_batch in eval_device_loader:
            with torch.no_grad():
                batch = {k: v.to(device) for k, v in eval_batch.items()}
                outputs = model(**batch)
            predictions = outputs.logits.argmax(dim=-1)
            # Gather predictions and labels from all workers to compute accuracy.
            predictions = gather(predictions, "predictions")
            references = gather(batch["labels"], "labels")
            metric.add_batch(predictions=predictions, references=references)
        eval_metric = metric.compute()
        if args.world_rank == 0:
            print(f"epoch {epoch} : Validation Accuracy -: {eval_metric}")
        if completed_steps >= num_training_steps:
            break

    # Save the model config and weights for evaluation. xm.save, used as the save function,
    # moves tensors to CPU and writes from the master process only, so every worker calls it.
    if args.output_dir is not None:
        model.save_pretrained(args.output_dir, is_main_process=(args.world_rank == 0), save_function=xm.save)
        if args.world_rank == 0:
            tokenizer.save_pretrained(args.output_dir)
    if args.world_rank == 0:
        print('----------End Training ---------------')


if __name__ == '__main__':
    main()

In the training script, there are several important details worth mentioning:

1.    The smaller Trainium instance (trn1.2xlarge) contains 2 NeuronCores, and trn1.32xlarge contains 32 NeuronCores. To train efficiently, we need a mechanism to distribute the training across the available NeuronCores, and we use PyTorch/XLA to achieve this. PyTorch/XLA is a Python package that uses the XLA deep learning compiler to connect the PyTorch deep learning framework to accelerators such as AWS Trainium. Building a new PyTorch network or converting an existing one to run on XLA devices requires only a few lines of XLA-specific code:
    • GPU devices are replaced with PyTorch/XLA devices. Because we use torch distributed training, we initialize the training with XLA as the device, as shown below.
    device = "xla"
    • PyTorch/XLA MpDeviceLoader is used for the data ingestion pipeline. MpDeviceLoader helps improve performance by overlapping three execution steps: tracing, compilation, and loading data batches to the device. We wrap the PyTorch DataLoader with MpDeviceLoader as shown below.
    train_device_loader = pl.MpDeviceLoader(train_loader, "xla")
    • The optimizer step runs through the XLA-provided API, as shown below. This consolidates the gradients across cores and issues the XLA device step computation.
    xm.optimizer_step(optimizer)

2.    The data from S3 is copied to the training instance, and its local path is made available in the environment variable for the train channel, SM_CHANNEL_TRAIN.
3.    The hyperparameters passed to the training job are made available to the script as command-line arguments. We use argparse to read them, as shown below.

parser = argparse.ArgumentParser(description="Finetune a transformers model on a text classification task")
parser.add_argument(
    "--train_data", type=str, default=os.environ["SM_CHANNEL_TRAIN"])

4.    Load the dataset from the channel path SM_CHANNEL_TRAIN using the datasets API:

dataset = load_from_disk(args.train_data)

5.    The trained model config and weights are stored in the path provided by the environment variable SM_MODEL_DIR. Once training is complete, Amazon SageMaker copies the files in SM_MODEL_DIR to the S3 bucket as a model.tar.gz archive, and you can then deploy the model to the hardware of your choice. This is why the script saves the model files to the path provided by SM_MODEL_DIR.
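xm.optimizer_step(optimizer) all-reduces the gradients across the participating cores, averaging them, before calling optimizer.step(), so every replica applies the same update and the model copies stay identical. The pure-Python sketch below simulates that synchronous data-parallel pattern for a one-parameter linear model trained with plain SGD; it does not use torch_xla, and every name in it is hypothetical.

```python
from typing import List, Sequence, Tuple

Batch = Sequence[Tuple[float, float]]  # (x, y) pairs


def local_gradient(w: float, batch: Batch) -> float:
    """d/dw of the mean squared error of the prediction w * x on one replica's batch."""
    return sum(2 * (w * x - y) * x for x, y in batch) / len(batch)


def data_parallel_step(weights: List[float], shards: List[Batch], lr: float) -> List[float]:
    """One synchronous step: each replica computes a local gradient, the gradients
    are averaged (the all-reduce), and every replica applies the same SGD update."""
    grads = [local_gradient(w, shard) for w, shard in zip(weights, shards)]
    mean_grad = sum(grads) / len(grads)  # all-reduce sum scaled by 1 / world_size
    return [w - lr * mean_grad for w in weights]


# Two replicas (for example, the two NeuronCores of a trn1.2xlarge) start from the
# same weight and each sees a different half of the data (true relation: y = 2x).
data = [(1.0, 2.0), (2.0, 4.0), (3.0, 6.0), (4.0, 8.0)]
weights = [0.0, 0.0]
for _ in range(50):
    weights = data_parallel_step(weights, [data[:2], data[2:]], lr=0.02)
print(weights)  # both replicas hold the same value, close to 2.0
```

Because the two shards are the same size, the averaged gradient equals the full-batch gradient, so for the gradient computation two cores with a per-device batch of 8 match one device with a batch of 16.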

SageMaker also provides a mechanism to install additional libraries required for training: include a requirements.txt file along with the training script. For this example, you need to install a few additional libraries in order to use the transformers library. Copy and paste the following code snippet into a notebook cell and run it to create a requirements.txt file.

%%writefile requirements.txt
transformers==4.21.3
datasets==2.5.2
evaluate
scikit-learn


We successfully created the training script and the requirements.txt file.  
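To see how the DistributedSampler in the training script divides the training set between workers, the sketch below reimplements its index selection in plain Python, assuming drop_last=False and one worker per NeuronCore. shard_indices is hypothetical and uses Python's random module instead of torch.randperm, so the exact order differs from PyTorch's. As with the real sampler, the order only changes between epochs when a different epoch is supplied, which is what sampler.set_epoch(epoch) does in PyTorch.

```python
import math
import random
from typing import List


def shard_indices(num_rows: int, num_replicas: int, rank: int,
                  shuffle: bool = True, seed: int = 0, epoch: int = 0) -> List[int]:
    """Return the row indices one worker reads in one epoch.

    Shuffle with seed + epoch, pad by wrapping around to the start so every rank
    gets the same number of rows, then take every num_replicas-th index from rank.
    """
    indices = list(range(num_rows))
    if shuffle:
        # Python's RNG stands in for torch.randperm, so the order differs from PyTorch's
        random.Random(seed + epoch).shuffle(indices)
    num_samples = math.ceil(num_rows / num_replicas)
    total_size = num_samples * num_replicas
    indices += [indices[i % num_rows] for i in range(total_size - num_rows)]
    return indices[rank:total_size:num_replicas]


# Two workers (one per NeuronCore on ml.trn1.2xlarge) splitting 18,750 training rows
shards = [shard_indices(18_750, num_replicas=2, rank=r) for r in range(2)]
print([len(s) for s in shards])  # [9375, 9375]
print(set(shards[0]).isdisjoint(shards[1]))  # True
```

When the row count does not divide evenly, a few rows are repeated on some workers so that every worker runs the same number of batches, which keeps the cores in lockstep.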

Step 5: Train the ML model

We start by defining the hyperparameters to pass to the script. Note that model_name_or_path names a pretrained model from the Hugging Face Transformers library (here, bert-base-uncased), which we fine-tune for text classification. Copy and paste the following code into a notebook cell and run it.

base_job_name = "imdb-sentiment-classification"

hyperparameters = {}

hyperparameters["model_name_or_path"] = "bert-base-uncased"
hyperparameters["seed"] = 100
hyperparameters["max_length"] = 128
hyperparameters["per_device_train_batch_size"] = 8
hyperparameters["per_device_eval_batch_size"] = 8
hyperparameters["learning_rate"] = 5e-5
hyperparameters["max_train_steps"] = 2000
hyperparameters["num_train_epochs"] = 1
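SageMaker script mode hands each entry in hyperparameters to the entry point as a command-line flag of the form --name value, which is why the training script reads them with argparse. The sketch below imitates that hand-off locally with a cut-down parser; to_cli_args is a hypothetical helper, not part of the SageMaker SDK, and the exact string formatting SageMaker applies to values may differ.

```python
import argparse
from typing import Any, Dict, List


def to_cli_args(hyperparameters: Dict[str, Any]) -> List[str]:
    """Flatten a hyperparameter dict into ["--name", "value", ...] strings."""
    args: List[str] = []
    for name, value in hyperparameters.items():
        args.extend([f"--{name}", str(value)])
    return args


# A cut-down version of the training script's parser; argparse converts the
# strings back to the declared types.
parser = argparse.ArgumentParser()
parser.add_argument("--model_name_or_path", type=str, required=True)
parser.add_argument("--learning_rate", type=float, default=5e-5)
parser.add_argument("--per_device_train_batch_size", type=int, default=8)

hp = {"model_name_or_path": "bert-base-uncased", "learning_rate": 5e-5,
      "per_device_train_batch_size": 8}
parsed = parser.parse_args(to_cli_args(hp))
print(parsed.learning_rate, parsed.per_device_train_batch_size)  # 5e-05 8
```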

Next you will instantiate a SageMaker estimator. You will use the AWS managed PyTorch estimator to run your custom script. To instantiate the PyTorch estimator, copy and paste the following code.

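pt_estimator = PyTorch(
    entry_point="train.py",  # the training script written in Step 4
    source_dir="./",  # also uploads requirements.txt, which is installed before training starts
    role=role,
    framework_version="1.11.0",
    py_version="py38",
    instance_count=1,
    instance_type="ml.trn1.2xlarge",
    base_job_name=base_job_name,
    hyperparameters=hyperparameters,
    distribution={"torch_distributed": {"enabled": True}},
)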

A few things to note here. The entry_point points to the script created in Step 4. The source_dir attribute points to the current directory; this is required because the requirements.txt file must be copied to, and installed on, the training instance. For the instance type, we specify AWS Trainium, which comes in two sizes: ml.trn1.2xlarge and ml.trn1.32xlarge. For this exercise, we use the ml.trn1.2xlarge instance.
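The hyperparameters and the instance type together determine how many optimizer steps a run takes. The arithmetic below is a sketch under stated assumptions: 18,750 training rows (the 75% left by the default train_test_split of IMDB's 25,000-review train split), one worker per NeuronCore, a DistributedSampler without drop_last, and a DataLoader that keeps the final partial batch. steps_per_epoch is a hypothetical helper.

```python
import math


def steps_per_epoch(num_train_rows: int, num_workers: int, per_device_batch_size: int) -> int:
    """Optimizer steps each worker takes per epoch under the assumptions above."""
    rows_per_worker = math.ceil(num_train_rows / num_workers)
    return math.ceil(rows_per_worker / per_device_batch_size)


# ml.trn1.2xlarge: 2 NeuronCores, per_device_train_batch_size = 8
steps = steps_per_epoch(18_750, 2, 8)
print(steps, "steps per epoch; global batch size =", 2 * 8)
# 1172 steps per epoch; global batch size = 16
```

Under these assumptions a single epoch finishes in 1,172 steps, before max_train_steps = 2000 is reached. The same arithmetic gives 74 steps per epoch on an ml.trn1.32xlarge with 32 NeuronCores and a global batch size of 256.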

Now that the estimator is created, invoke its fit method with the S3 data path to start the training job. Copy and paste the following cell and run it to start the training job.

pt_estimator.fit({"train": training_input_path})

This runs the training job on an ml.trn1.2xlarge instance. You can see the training logs in the Studio notebook.
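When the job finishes, SageMaker stores the contents of SM_MODEL_DIR in S3 as a model.tar.gz archive, and pt_estimator.model_data holds its S3 URI. To inspect the trained files locally, you could download and unpack the archive; the sketch below covers the unpacking step with the standard library. extract_model_archive is a hypothetical helper, and the commented download lines assume the SageMaker Python SDK's S3Downloader.

```python
import tarfile
from pathlib import Path
from typing import List


def extract_model_archive(archive_path: str, target_dir: str) -> List[str]:
    """Extract a model.tar.gz archive into target_dir and return the file names it contained.

    Members whose paths would resolve outside target_dir (absolute paths or "..")
    are rejected before anything is written. On Python versions that provide
    tarfile's "data" extraction filter, that filter is applied as well.
    """
    target = Path(target_dir).resolve()
    target.mkdir(parents=True, exist_ok=True)
    with tarfile.open(archive_path, "r:gz") as archive:
        members = archive.getmembers()
        for member in members:
            destination = (target / member.name).resolve()
            if destination != target and target not in destination.parents:
                raise ValueError(f"Unsafe path in archive: {member.name}")
        if hasattr(tarfile, "data_filter"):
            archive.extractall(target, filter="data")
        else:
            archive.extractall(target)
    return sorted(member.name for member in members if member.isfile())


# Example in the notebook, after training completes:
# from sagemaker.s3 import S3Downloader
# S3Downloader.download(pt_estimator.model_data, "model")
# print(extract_model_archive("model/model.tar.gz", "model/extracted"))
```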


Congratulations! You have finished the Train your ML Models with AWS Trainium and Amazon SageMaker tutorial.

In this tutorial, you used Amazon SageMaker to train a BERT-based classification model on an AWS Trainium instance. AWS Trainium provides a cost-effective way to train ML models at scale; combined with the ease of building ML models in Amazon SageMaker, it lets you experiment and scale quickly.
